diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-13 22:52:46 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-02-13 22:52:46 +0200 |
| commit | cd5a3614baab756a41d764b79308afeea93f12dd (patch) | |
| tree | efc8c31e8b162ca2121ba92c841322119e6d3b04 /internal/config | |
| parent | bf7c6ade292a6444877797c8d699d147aceb57cc (diff) | |
Remove Perl version and build files; add .gitignore for .serena/
Amp-Thread-ID: https://ampcode.com/threads/T-019c58b3-06fb-733d-8fc1-f268fe7f70d5
Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'internal/config')
| -rw-r--r-- | internal/config/config.go | 263 | ||||
| -rw-r--r-- | internal/config/config_test.go | 123 |
2 files changed, 386 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..7551a10 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,263 @@ +package config + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/loadbars/loadbars/internal/constants" +) + +// Config holds all loadbars configuration (file + CLI). +// Defaults match the Perl Shared.pm %C. +type Config struct { + Hosts []string // Each entry is "host" or "host:user" + Title string + BarWidth int + CPUAverage int + Extended bool + HasAgent bool + Height int + MaxWidth int + NetAverage int + NetInt string + NetLink string + ShowCores bool + ShowMem bool + ShowNet bool + SSHOpts string + Cluster string +} + +// Default returns a Config with default values. +func Default() Config { + return Config{ + BarWidth: 20, + CPUAverage: 10, + Extended: false, + HasAgent: false, + Height: 150, + MaxWidth: 1900, + NetAverage: 15, + NetLink: "gbit", + ShowCores: false, + ShowMem: false, + ShowNet: false, + } +} + +// ConfFilePath returns the full path to the config file (~/.loadbarsrc). +func ConfFilePath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("home dir: %w", err) + } + return filepath.Join(home, constants.ConfFile), nil +} + +// Load reads config from the config file and merges into c. Unknown keys are ignored. +func (c *Config) Load() error { + path, err := ConfFilePath() + if err != nil { + return err + } + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("open config: %w", err) + } + defer f.Close() + return c.parseReader(f) +} + +func (c *Config) parseReader(f *os.File) error { + validKeys := map[string]bool{ + "title": true, "barwidth": true, "cpuaverage": true, "extended": true, + "hasagent": true, "height": true, "maxwidth": true, "netaverage": true, + "netint": true, "netlink": true, "showcores": true, "showmem": true, + "shownet": true, "sshopts": true, "cluster": true, + } + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if idx := strings.Index(line, "#"); idx >= 0 { + line = strings.TrimSpace(line[:idx]) + } + if line == "" { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(parts[0]) + val := strings.TrimSpace(parts[1]) + if !validKeys[key] { + continue + } + c.set(key, val) + } + return scanner.Err() +} + +func (c *Config) set(key, val string) { + switch key { + case "title": + c.Title = val + case "barwidth": + if n, err := strconv.Atoi(val); err == nil { + c.BarWidth = n + } + case "cpuaverage": + if n, err := strconv.Atoi(val); err == nil { + c.CPUAverage = n + } + case "extended": + c.Extended = parseBool(val) + case "hasagent": + c.HasAgent = parseBool(val) + case "height": + if n, err := strconv.Atoi(val); err == nil { + c.Height = n + } + case "maxwidth": + if n, err := strconv.Atoi(val); err == nil { + c.MaxWidth = n + } + case "netaverage": + if n, err := strconv.Atoi(val); err == nil { + c.NetAverage = n + } + case "netint": + c.NetInt = val + case "netlink": + c.NetLink = val + case "showcores": + c.ShowCores = parseBool(val) + case "showmem": + c.ShowMem = parseBool(val) + case "shownet": + c.ShowNet = parseBool(val) + case "sshopts": + c.SSHOpts = val + case "cluster": + c.Cluster = val + } +} + +func parseBool(s string) bool { + s = strings.TrimSpace(strings.ToLower(s)) + return s == "1" || s == "true" || s == "yes" +} + +// Write saves the current config to the config file (excluding title). +func (c *Config) Write() error { + path, err := ConfFilePath() + if err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("create config: %w", err) + } + defer f.Close() + return c.writeTo(f) +} + +func (c *Config) writeTo(f *os.File) error { + w := bufio.NewWriter(f) + writeInt := func(key string, v int) { fmt.Fprintf(w, "%s=%d\n", key, v) } + writeStr := func(key, v string) { fmt.Fprintf(w, "%s=%s\n", key, v) } + writeBool := func(key string, v bool) { + val := "0" + if v { + val = "1" + } + fmt.Fprintf(w, "%s=%s\n", key, val) + } + writeInt("barwidth", c.BarWidth) + writeInt("cpuaverage", c.CPUAverage) + writeBool("extended", c.Extended) + writeBool("hasagent", c.HasAgent) + writeInt("height", c.Height) + writeInt("maxwidth", c.MaxWidth) + writeInt("netaverage", c.NetAverage) + writeStr("netint", c.NetInt) + writeStr("netlink", c.NetLink) + writeBool("showcores", c.ShowCores) + writeBool("showmem", c.ShowMem) + writeBool("shownet", c.ShowNet) + writeStr("sshopts", c.SSHOpts) + writeStr("cluster", c.Cluster) + return w.Flush() +} + +// GetClusterHosts resolves a cluster name from /etc/clusters into a list of hosts. +func GetClusterHosts(cluster string) ([]string, error) { + return GetClusterHostsFromFile(cluster, constants.CSSHConfFile) +} + +// GetClusterHostsFromFile resolves a cluster from a clusters file (for testing or custom path). +// Supports recursive cluster references with cycle detection. +func GetClusterHostsFromFile(cluster, path string) ([]string, error) { + return getClusterHostsRec(cluster, path, 1, nil) +} + +func getClusterHostsRec(cluster, path string, depth int, seen map[string]bool) ([]string, error) { + if depth > constants.CSSHMaxRecursion { + return nil, fmt.Errorf("cluster recursion limit reached in %s (possible cycle)", path) + } + if seen == nil { + seen = make(map[string]bool) + } + if seen[cluster] { + return nil, fmt.Errorf("cluster cycle detected: %s", cluster) + } + + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open %s: %w", path, err) + } + defer f.Close() + + var line string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + ln := strings.TrimSpace(scanner.Text()) + if ln == "" || strings.HasPrefix(ln, "#") { + continue + } + fields := strings.Fields(ln) + if len(fields) >= 1 && fields[0] == cluster { + if len(fields) > 1 { + line = strings.Join(fields[1:], " ") + } + break + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + + if line == "" { + return []string{cluster}, nil + } + + seen[cluster] = true + defer delete(seen, cluster) + + var out []string + for _, part := range strings.Fields(line) { + hosts, err := getClusterHostsRec(part, path, depth+1, seen) + if err != nil { + return nil, err + } + out = append(out, hosts...) + } + return out, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..d51feb7 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,123 @@ +package config + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestConfig_parseReader(t *testing.T) { + tests := []struct { + name string + input string + wantBar int + wantExt bool + }{ + {"empty", "", 20, false}, + {"barwidth", "barwidth=42\n", 42, false}, + {"extended_1", "extended=1\n", 20, true}, + {"extended_true", "extended=true\n", 20, true}, + {"comments", "# foo\nbarwidth=10\n# bar\n", 10, false}, + {"unknown_key", "barwidth=5\nunknown=ignored\n", 5, false}, + {"multiple", "barwidth=30\nextended=1\nshowcores=1\n", 30, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Default() + f, _ := os.Open(os.DevNull) + defer f.Close() + // Use a temp file with the content since parseReader takes *os.File + dir := t.TempDir() + path := filepath.Join(dir, "rc") + if err := os.WriteFile(path, []byte(tt.input), 0600); err != nil { + t.Fatal(err) + } + f2, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + if err := c.parseReader(f2); err != nil { + t.Fatal(err) + } + if c.BarWidth != tt.wantBar { + t.Errorf("BarWidth = %d, want %d", c.BarWidth, tt.wantBar) + } + if c.Extended != tt.wantExt { + t.Errorf("Extended = %v, want %v", c.Extended, tt.wantExt) + } + }) + } +} + +func TestConfig_writeTo(t *testing.T) { + c := Default() + c.BarWidth = 25 + c.ShowCores = true + dir := t.TempDir() + path := filepath.Join(dir, "out") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + err = c.writeTo(f) + f.Close() + if err != nil { + t.Fatal(err) + } + data, _ := os.ReadFile(path) + if len(data) == 0 { + t.Error("writeTo wrote nothing") + } + if !bytes.Contains(data, []byte("barwidth=25")) { + t.Errorf("expected barwidth=25 in %s", data) + } + if !bytes.Contains(data, []byte("showcores=1")) { + t.Errorf("expected showcores=1 in %s", data) + } +} + +func TestGetClusterHostsFromFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "clusters") + + tests := []struct { + name string + content string + cluster string + wantHosts []string + wantErr bool + }{ + {"single_host", "foo host1\n", "foo", []string{"host1"}, false}, + {"two_hosts", "bar host1 host2\n", "bar", []string{"host1", "host2"}, false}, + {"missing_returns_cluster", "x y\n", "missing", []string{"missing"}, false}, + {"recursive", "a b\nb c\nc d\n", "a", []string{"d"}, false}, + {"cycle", "a b\nb a\n", "a", nil, true}, + {"comment_ignored", "# comment\na h1\n", "a", []string{"h1"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := os.WriteFile(path, []byte(tt.content), 0600); err != nil { + t.Fatal(err) + } + got, err := GetClusterHostsFromFile(tt.cluster, path) + if (err != nil) != tt.wantErr { + t.Errorf("GetClusterHostsFromFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if len(got) != len(tt.wantHosts) { + t.Errorf("got %v, want %v", got, tt.wantHosts) + return + } + for i := range got { + if got[i] != tt.wantHosts[i] { + t.Errorf("got[%d] = %s, want %s", i, got[i], tt.wantHosts[i]) + } + } + }) + } +} |
