summaryrefslogtreecommitdiff
path: root/internal/config
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-02-13 22:52:46 +0200
committerPaul Buetow <paul@buetow.org>2026-02-13 22:52:46 +0200
commitcd5a3614baab756a41d764b79308afeea93f12dd (patch)
treeefc8c31e8b162ca2121ba92c841322119e6d3b04 /internal/config
parentbf7c6ade292a6444877797c8d699d147aceb57cc (diff)
Remove Perl version and build files; add .gitignore for .serena/
Amp-Thread-ID: https://ampcode.com/threads/T-019c58b3-06fb-733d-8fc1-f268fe7f70d5 Co-authored-by: Amp <amp@ampcode.com>
Diffstat (limited to 'internal/config')
-rw-r--r--internal/config/config.go263
-rw-r--r--internal/config/config_test.go123
2 files changed, 386 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 0000000..7551a10
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,263 @@
+package config
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "github.com/loadbars/loadbars/internal/constants"
+)
+
+// Config holds all loadbars configuration (file + CLI).
+// Defaults match the Perl Shared.pm %C.
+type Config struct {
+ Hosts []string // Each entry is "host" or "host:user"
+ Title string
+ BarWidth int
+ CPUAverage int
+ Extended bool
+ HasAgent bool
+ Height int
+ MaxWidth int
+ NetAverage int
+ NetInt string
+ NetLink string
+ ShowCores bool
+ ShowMem bool
+ ShowNet bool
+ SSHOpts string
+ Cluster string
+}
+
+// Default returns a Config with default values.
+func Default() Config {
+ return Config{
+ BarWidth: 20,
+ CPUAverage: 10,
+ Extended: false,
+ HasAgent: false,
+ Height: 150,
+ MaxWidth: 1900,
+ NetAverage: 15,
+ NetLink: "gbit",
+ ShowCores: false,
+ ShowMem: false,
+ ShowNet: false,
+ }
+}
+
+// ConfFilePath returns the full path to the config file (~/.loadbarsrc).
+func ConfFilePath() (string, error) {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", fmt.Errorf("home dir: %w", err)
+ }
+ return filepath.Join(home, constants.ConfFile), nil
+}
+
+// Load reads config from the config file and merges into c. Unknown keys are ignored.
+func (c *Config) Load() error {
+ path, err := ConfFilePath()
+ if err != nil {
+ return err
+ }
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return fmt.Errorf("open config: %w", err)
+ }
+ defer f.Close()
+ return c.parseReader(f)
+}
+
+func (c *Config) parseReader(f *os.File) error {
+ validKeys := map[string]bool{
+ "title": true, "barwidth": true, "cpuaverage": true, "extended": true,
+ "hasagent": true, "height": true, "maxwidth": true, "netaverage": true,
+ "netint": true, "netlink": true, "showcores": true, "showmem": true,
+ "shownet": true, "sshopts": true, "cluster": true,
+ }
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if idx := strings.Index(line, "#"); idx >= 0 {
+ line = strings.TrimSpace(line[:idx])
+ }
+ if line == "" {
+ continue
+ }
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ continue
+ }
+ key := strings.TrimSpace(parts[0])
+ val := strings.TrimSpace(parts[1])
+ if !validKeys[key] {
+ continue
+ }
+ c.set(key, val)
+ }
+ return scanner.Err()
+}
+
+func (c *Config) set(key, val string) {
+ switch key {
+ case "title":
+ c.Title = val
+ case "barwidth":
+ if n, err := strconv.Atoi(val); err == nil {
+ c.BarWidth = n
+ }
+ case "cpuaverage":
+ if n, err := strconv.Atoi(val); err == nil {
+ c.CPUAverage = n
+ }
+ case "extended":
+ c.Extended = parseBool(val)
+ case "hasagent":
+ c.HasAgent = parseBool(val)
+ case "height":
+ if n, err := strconv.Atoi(val); err == nil {
+ c.Height = n
+ }
+ case "maxwidth":
+ if n, err := strconv.Atoi(val); err == nil {
+ c.MaxWidth = n
+ }
+ case "netaverage":
+ if n, err := strconv.Atoi(val); err == nil {
+ c.NetAverage = n
+ }
+ case "netint":
+ c.NetInt = val
+ case "netlink":
+ c.NetLink = val
+ case "showcores":
+ c.ShowCores = parseBool(val)
+ case "showmem":
+ c.ShowMem = parseBool(val)
+ case "shownet":
+ c.ShowNet = parseBool(val)
+ case "sshopts":
+ c.SSHOpts = val
+ case "cluster":
+ c.Cluster = val
+ }
+}
+
+func parseBool(s string) bool {
+ s = strings.TrimSpace(strings.ToLower(s))
+ return s == "1" || s == "true" || s == "yes"
+}
+
+// Write saves the current config to the config file (excluding title).
+func (c *Config) Write() error {
+ path, err := ConfFilePath()
+ if err != nil {
+ return err
+ }
+ f, err := os.Create(path)
+ if err != nil {
+ return fmt.Errorf("create config: %w", err)
+ }
+ defer f.Close()
+ return c.writeTo(f)
+}
+
+func (c *Config) writeTo(f *os.File) error {
+ w := bufio.NewWriter(f)
+ writeInt := func(key string, v int) { fmt.Fprintf(w, "%s=%d\n", key, v) }
+ writeStr := func(key, v string) { fmt.Fprintf(w, "%s=%s\n", key, v) }
+ writeBool := func(key string, v bool) {
+ val := "0"
+ if v {
+ val = "1"
+ }
+ fmt.Fprintf(w, "%s=%s\n", key, val)
+ }
+ writeInt("barwidth", c.BarWidth)
+ writeInt("cpuaverage", c.CPUAverage)
+ writeBool("extended", c.Extended)
+ writeBool("hasagent", c.HasAgent)
+ writeInt("height", c.Height)
+ writeInt("maxwidth", c.MaxWidth)
+ writeInt("netaverage", c.NetAverage)
+ writeStr("netint", c.NetInt)
+ writeStr("netlink", c.NetLink)
+ writeBool("showcores", c.ShowCores)
+ writeBool("showmem", c.ShowMem)
+ writeBool("shownet", c.ShowNet)
+ writeStr("sshopts", c.SSHOpts)
+ writeStr("cluster", c.Cluster)
+ return w.Flush()
+}
+
+// GetClusterHosts resolves a cluster name from /etc/clusters into a list of hosts.
+func GetClusterHosts(cluster string) ([]string, error) {
+ return GetClusterHostsFromFile(cluster, constants.CSSHConfFile)
+}
+
+// GetClusterHostsFromFile resolves a cluster from a clusters file (for testing or custom path).
+// Supports recursive cluster references with cycle detection.
+func GetClusterHostsFromFile(cluster, path string) ([]string, error) {
+ return getClusterHostsRec(cluster, path, 1, nil)
+}
+
+func getClusterHostsRec(cluster, path string, depth int, seen map[string]bool) ([]string, error) {
+ if depth > constants.CSSHMaxRecursion {
+ return nil, fmt.Errorf("cluster recursion limit reached in %s (possible cycle)", path)
+ }
+ if seen == nil {
+ seen = make(map[string]bool)
+ }
+ if seen[cluster] {
+ return nil, fmt.Errorf("cluster cycle detected: %s", cluster)
+ }
+
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("open %s: %w", path, err)
+ }
+ defer f.Close()
+
+ var line string
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ ln := strings.TrimSpace(scanner.Text())
+ if ln == "" || strings.HasPrefix(ln, "#") {
+ continue
+ }
+ fields := strings.Fields(ln)
+ if len(fields) >= 1 && fields[0] == cluster {
+ if len(fields) > 1 {
+ line = strings.Join(fields[1:], " ")
+ }
+ break
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+
+ if line == "" {
+ return []string{cluster}, nil
+ }
+
+ seen[cluster] = true
+ defer delete(seen, cluster)
+
+ var out []string
+ for _, part := range strings.Fields(line) {
+ hosts, err := getClusterHostsRec(part, path, depth+1, seen)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, hosts...)
+ }
+ return out, nil
+}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..d51feb7
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,123 @@
+package config
+
+import (
+ "bytes"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestConfig_parseReader(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantBar int
+ wantExt bool
+ }{
+ {"empty", "", 20, false},
+ {"barwidth", "barwidth=42\n", 42, false},
+ {"extended_1", "extended=1\n", 20, true},
+ {"extended_true", "extended=true\n", 20, true},
+ {"comments", "# foo\nbarwidth=10\n# bar\n", 10, false},
+ {"unknown_key", "barwidth=5\nunknown=ignored\n", 5, false},
+ {"multiple", "barwidth=30\nextended=1\nshowcores=1\n", 30, true},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := Default()
+ f, _ := os.Open(os.DevNull)
+ defer f.Close()
+ // Use a temp file with the content since parseReader takes *os.File
+ dir := t.TempDir()
+ path := filepath.Join(dir, "rc")
+ if err := os.WriteFile(path, []byte(tt.input), 0600); err != nil {
+ t.Fatal(err)
+ }
+ f2, err := os.Open(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer f2.Close()
+ if err := c.parseReader(f2); err != nil {
+ t.Fatal(err)
+ }
+ if c.BarWidth != tt.wantBar {
+ t.Errorf("BarWidth = %d, want %d", c.BarWidth, tt.wantBar)
+ }
+ if c.Extended != tt.wantExt {
+ t.Errorf("Extended = %v, want %v", c.Extended, tt.wantExt)
+ }
+ })
+ }
+}
+
+func TestConfig_writeTo(t *testing.T) {
+ c := Default()
+ c.BarWidth = 25
+ c.ShowCores = true
+ dir := t.TempDir()
+ path := filepath.Join(dir, "out")
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = c.writeTo(f)
+ f.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(path)
+ if len(data) == 0 {
+ t.Error("writeTo wrote nothing")
+ }
+ if !bytes.Contains(data, []byte("barwidth=25")) {
+ t.Errorf("expected barwidth=25 in %s", data)
+ }
+ if !bytes.Contains(data, []byte("showcores=1")) {
+ t.Errorf("expected showcores=1 in %s", data)
+ }
+}
+
+func TestGetClusterHostsFromFile(t *testing.T) {
+ dir := t.TempDir()
+ path := filepath.Join(dir, "clusters")
+
+ tests := []struct {
+ name string
+ content string
+ cluster string
+ wantHosts []string
+ wantErr bool
+ }{
+ {"single_host", "foo host1\n", "foo", []string{"host1"}, false},
+ {"two_hosts", "bar host1 host2\n", "bar", []string{"host1", "host2"}, false},
+ {"missing_returns_cluster", "x y\n", "missing", []string{"missing"}, false},
+ {"recursive", "a b\nb c\nc d\n", "a", []string{"d"}, false},
+ {"cycle", "a b\nb a\n", "a", nil, true},
+ {"comment_ignored", "# comment\na h1\n", "a", []string{"h1"}, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := os.WriteFile(path, []byte(tt.content), 0600); err != nil {
+ t.Fatal(err)
+ }
+ got, err := GetClusterHostsFromFile(tt.cluster, path)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("GetClusterHostsFromFile() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+ if len(got) != len(tt.wantHosts) {
+ t.Errorf("got %v, want %v", got, tt.wantHosts)
+ return
+ }
+ for i := range got {
+ if got[i] != tt.wantHosts[i] {
+ t.Errorf("got[%d] = %s, want %s", i, got[i], tt.wantHosts[i])
+ }
+ }
+ })
+ }
+}