summaryrefslogtreecommitdiff
path: root/internal/config
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2024-09-03 10:58:34 +0300
committerPaul Buetow <paul@buetow.org>2024-09-03 10:58:34 +0300
commit1ade563588a8de2dfa9356e2c0e2ce33a3861227 (patch)
treee6c6eb61576fb07d62404c47e73a39f417fc401a /internal/config
parent3e33b52642b2b4d5a0c7a8d3a7a6bb269553be1f (diff)
initial generic handling of environment variables
Diffstat (limited to 'internal/config')
-rw-r--r--internal/config/client/client.go12
-rw-r--r--internal/config/config.go80
-rw-r--r--internal/config/config_test.go38
-rw-r--r--internal/config/enver.go93
-rw-r--r--internal/config/server/server.go20
5 files changed, 128 insertions, 115 deletions
diff --git a/internal/config/client/client.go b/internal/config/client/client.go
index d6f2939..ada1a76 100644
--- a/internal/config/client/client.go
+++ b/internal/config/client/client.go
@@ -26,16 +26,16 @@ func New(configFile string) (ClientConfig, error) {
log.Println("Skipping config file:", err)
}
- conf.Servers = config.EnvToStrSlice("GOS_SERVERS", conf.Servers)
- conf.APIKey = config.EnvToStr("GOS_API_KEY", conf.APIKey)
- conf.Editor = config.EnvToStr("GOS_EDITOR", "EDITOR", conf.Editor, "vi")
+ conf.Servers = config.Env[config.ToStringSlice]("GOS_SERVERS", conf.Servers)
+ conf.APIKey = config.Env[config.ToString]("GOS_API_KEY", conf.APIKey)
+ conf.Editor = config.Env[config.ToString]("GOS_EDITOR", "EDITOR", conf.Editor, "vi")
defaultDataDir := fmt.Sprintf("%s/.gos/data", os.Getenv("HOME"))
- conf.DataDir = config.EnvToStr("GOS_DATA_DIR", conf.DataDir, defaultDataDir)
- conf.ComposeFile = config.EnvToStr("GOS_COMPOSE_FILE", conf.ComposeFile, "compose.txt")
+ conf.DataDir = config.Env[config.ToString]("GOS_DATA_DIR", conf.DataDir, defaultDataDir)
+ conf.ComposeFile = config.Env[config.ToString]("GOS_COMPOSE_FILE", conf.ComposeFile, "compose.txt")
defaultLogFile := fmt.Sprintf("%s/.gos/gos.log", os.Getenv("HOME"))
- conf.LogFile = config.EnvToStr("GOS_LOG_FILE", conf.LogFile, defaultLogFile)
+ conf.LogFile = config.Env[config.ToString]("GOS_LOG_FILE", conf.LogFile, defaultLogFile)
return conf, nil
}
diff --git a/internal/config/config.go b/internal/config/config.go
index f71911b..6f697e7 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -4,8 +4,6 @@ import (
"encoding/json"
"io"
"os"
- "strconv"
- "strings"
"unicode"
)
@@ -27,84 +25,6 @@ func FromFile[T any](configFile string) (T, error) {
return conf, err
}
-// Set config from environment variable if present, e.g. hansWurst from GOS_HANS_WURST
-func EnvToStr(keys ...any) string {
- for _, key := range keys {
- switch key := key.(type) {
- case string:
- if key == "" {
- continue
- }
- if !isAllUpperCase(key) {
- return key
- }
- if value := os.Getenv(key); value != "" {
- return value
- }
- case func() string:
- return key()
- }
- }
-
- return ""
-}
-
-func EnvToStrSlice(keys ...any) []string {
- result := strings.Split(EnvToStr(keys...), ",")
- if len(result) == 1 && result[0] == "" {
- return []string{}
- }
- return result
-}
-
-func EnvToInt(keys ...any) int {
- for _, key := range keys {
- switch key := key.(type) {
- case string:
- if key == "" || !isAllUpperCase(key) {
- continue
- }
- strValue := os.Getenv(key)
- if strValue == "" {
- continue
- }
- if intValue, err := strconv.Atoi(strValue); err == nil {
- return intValue
- }
- case int:
- return key
- case func() int:
- return key()
- }
- }
-
- return 0
-}
-
-func EnvToBool(keys ...any) bool {
- for _, key := range keys {
- switch key := key.(type) {
- case string:
- if key == "" || !isAllUpperCase(key) {
- continue
- }
- strValue := os.Getenv(key)
- if strValue == "" {
- continue
- }
- if boolValue, err := strconv.ParseBool(strValue); err == nil {
- return boolValue
- }
- case bool:
- return key
- case func() bool:
- return key()
- }
- }
-
- return false
-}
-
func isAllUpperCase(s string) bool {
for _, r := range s {
if unicode.IsLetter(r) && !unicode.IsUpper(r) {
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index f56d44f..ea96ed9 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -14,7 +14,7 @@ func TestEnvToStr(t *testing.T) {
var (
expected = "foobarbaz"
- got = EnvToStr("GOS_TEST_FROM_ENV")
+ got = Env[ToString]("GOS_TEST_FROM_ENV")
)
if got != expected {
@@ -23,19 +23,19 @@ func TestEnvToStr(t *testing.T) {
t.Logf("got '%s' as expected", expected)
expected = "default value"
- got = EnvToStr("NON_EXISTENT_ENV", expected)
+ got = Env[ToString]("NON_EXISTENT_ENV", expected)
if got != expected {
t.Errorf("got '%s' but expected '%s'", got, expected)
}
t.Logf("got '%s' as expected", expected)
- if got = EnvToStr("NON_EXISTENT_ENV"); got != "" {
+ if got = Env[ToString]("NON_EXISTENT_ENV"); got != "" {
t.Errorf("got '%s' but expected empty string", got)
}
t.Logf("got empty string as expected")
expected = "casio g-shock"
- got = EnvToStr("GOS_WATCH", "", "", "", expected, "")
+ got = Env[ToString]("GOS_WATCH", "", "", "", expected, "")
if got != expected {
t.Errorf("got '%s' but expected '%s'", got, expected)
}
@@ -49,7 +49,7 @@ func TestEnvToStrSlice(t *testing.T) {
var (
expected = []string{"foo", "bar", "baz"}
- got = EnvToStrSlice("GOS_TEST_SLICE_FROM_ENV")
+ got = Env[ToStringSlice]("GOS_TEST_SLICE_FROM_ENV")
)
if !slices.Equal(got, expected) {
t.Errorf("got '%v' but expected '%v'", got, expected)
@@ -57,20 +57,20 @@ func TestEnvToStrSlice(t *testing.T) {
t.Logf("got '%v' as expected", expected)
expected = []string{"default value"}
- got = EnvToStrSlice("NON_EXISTENT_ENV_SLICE", "default value")
+ got = Env[ToStringSlice]("NON_EXISTENT_ENV_SLICE", "default value")
if !slices.Equal(got, expected) {
t.Errorf("got '%v' but expected '%v'", got, expected)
}
t.Logf("got '%v' as expected", expected)
os.Unsetenv("NON_EXISTENT_ENV")
- if got = EnvToStrSlice("NON_EXISTENT_ENV"); len(got) > 0 {
+ if got = Env[ToStringSlice]("NON_EXISTENT_ENV"); len(got) > 0 {
t.Errorf("got '%s' of len '%d' but expected empty slice", got, len(got))
}
t.Logf("got empty slice as expected")
expected = []string{"casio", "g-shock"}
- got = EnvToStrSlice("NON_EXISTENT_ENV", "", "", "", "casio,g-shock", "")
+ got = Env[ToStringSlice]("NON_EXISTENT_ENV", "", "", "", "casio,g-shock", "")
if !slices.Equal(got, expected) {
t.Errorf("got '%v' but expected '%v'", got, expected)
}
@@ -85,7 +85,7 @@ func TestEnvToInt(t *testing.T) {
var (
expected = 1
- got = EnvToInt(t, "GOS_TEST_INT_FROM_ENV")
+ got = Env[ToInteger](t, "GOS_TEST_INT_FROM_ENV")
)
if got != expected {
@@ -94,19 +94,19 @@ func TestEnvToInt(t *testing.T) {
t.Logf("got '%d' as expected", expected)
expected = 999
- got = EnvToInt("NON_EXISTENT_ENV", expected)
+ got = Env[ToInteger]("NON_EXISTENT_ENV", expected)
if got != expected {
t.Errorf("got '%d' but expected '%d'", got, expected)
}
t.Logf("got '%d' as expected", expected)
- if got = EnvToInt("NON_EXISTENT_ENV"); got != 0 {
+ if got = Env[ToInteger]("NON_EXISTENT_ENV"); got != 0 {
t.Errorf("got '%d' but expected zero", got)
}
t.Logf("got zero as expected")
expected = 1234
- got = EnvToInt("GOS_WATCH", "", "", "", expected, "")
+ got = Env[ToInteger]("GOS_WATCH", "", "", "", expected, "")
if got != expected {
t.Errorf("got '%d' but expected '%d'", got, expected)
}
@@ -121,7 +121,7 @@ func TestEnvToBool(t *testing.T) {
var (
expected = true
- got = EnvToBool(t, "GOS_TEST_BOOL_FROM_ENV")
+ got = Env[ToBool](t, "GOS_TEST_BOOL_FROM_ENV")
)
if got != expected {
@@ -130,19 +130,19 @@ func TestEnvToBool(t *testing.T) {
t.Logf("got '%t' as expected", expected)
expected = false
- got = EnvToBool("NON_EXISTENT_ENV", expected)
+ got = Env[ToBool]("NON_EXISTENT_ENV", expected)
if got != expected {
t.Errorf("got '%t' but expected '%t'", got, expected)
}
t.Logf("got '%t' as expected", expected)
- if got = EnvToBool("NON_EXISTENT_ENV"); got {
+ if got = Env[ToBool]("NON_EXISTENT_ENV"); got {
t.Errorf("got '%t' but expected false", got)
}
t.Logf("got 'false' as expected")
expected = true
- got = EnvToBool("NON_EXISTENT_ENV", "", "", "", expected, "")
+ got = Env[ToBool]("NON_EXISTENT_ENV", "", "", "", expected, "")
if got != expected {
t.Errorf("got '%t' but expected '%t'", got, expected)
}
@@ -157,7 +157,7 @@ func TestSecondENV(t *testing.T) {
var (
expected = "hx"
- got = EnvToStr("GOS_NONEXISTANT", "EDITOR", "notepad.exe")
+ got = Env[ToString]("GOS_NONEXISTANT", "EDITOR", "notepad.exe")
)
if expected != got {
@@ -183,7 +183,7 @@ func TestDefaultStrCB(t *testing.T) {
var (
expected = "hello"
- got = EnvToStr("GOS_NONEXISTANT", func() string {
+ got = Env[ToString]("GOS_NONEXISTANT", func() string {
return "hello"
})
)
@@ -199,7 +199,7 @@ func TestDefaultIntCB(t *testing.T) {
var (
expected = 666
- got = EnvToInt("GOS_NONEXISTANT", func() int {
+ got = Env[ToInteger]("GOS_NONEXISTANT", func() int {
return 666
})
)
diff --git a/internal/config/enver.go b/internal/config/enver.go
new file mode 100644
index 0000000..c9b3360
--- /dev/null
+++ b/internal/config/enver.go
@@ -0,0 +1,93 @@
+package config
+
+import (
+ "os"
+ "strconv"
+ "strings"
+)
+
+type enverConstraint interface {
+ ~int | ~bool | ~string | []string
+}
+
+type enver[T enverConstraint] interface {
+ // Return T value from input string
+ fromStr(value string) T
+ // Return T's zero value
+ zero() T
+}
+
+func Env[U enver[T], T enverConstraint](keys ...any) T {
+ var enver U
+
+ for _, key := range keys {
+ switch key := key.(type) {
+ case string:
+ if key == "" {
+ continue
+ }
+ if !isAllUpperCase(key) {
+ return enver.fromStr(key)
+ }
+ if value := os.Getenv(key); value != "" {
+ return enver.fromStr(value)
+ }
+ case func() T:
+ return key()
+ }
+ }
+
+ return enver.zero()
+}
+
+type ToString struct{}
+
+func (ToString) fromStr(str string) string {
+ return str
+}
+
+func (ToString) zero() string {
+ return ""
+}
+
+type ToStringSlice struct{}
+
+func (s ToStringSlice) fromStr(str string) []string {
+ result := strings.Split(str, ",")
+ if len(result) == 1 && result[0] == "" {
+ return s.zero()
+ }
+ return result
+}
+
+func (ToStringSlice) zero() []string {
+ return []string{}
+}
+
+type ToInteger struct{}
+
+// TODO: Return an error if can't convert to int
+func (s ToInteger) fromStr(str string) int {
+ if result, err := strconv.Atoi(str); err == nil {
+ return result
+ }
+ return s.zero()
+}
+
+func (ToInteger) zero() int {
+ return 0
+}
+
+type ToBool struct{}
+
+// TODO: Return an error if can't convert to bool
+func (s ToBool) fromStr(str string) bool {
+ if result, err := strconv.ParseBool(str); err == nil {
+ return result
+ }
+ return s.zero()
+}
+
+func (ToBool) zero() bool {
+ return false
+}
diff --git a/internal/config/server/server.go b/internal/config/server/server.go
index f0a0be3..48dd632 100644
--- a/internal/config/server/server.go
+++ b/internal/config/server/server.go
@@ -35,14 +35,14 @@ func New(configFile, secretsFile string) (ServerConfig, error) {
return conf, err
}
- conf.ListenAddr = config.EnvToStr("GOS_LISTEN_ADDR", conf.ListenAddr, "localhost:8080")
- conf.Partners = config.EnvToStrSlice("GOS_PARTNERS", conf.Partners)
- conf.APIKey = config.EnvToStr("GOS_API_KEY", conf.APIKey)
- conf.DataDir = config.EnvToStr("GOS_DATA_DIR", conf.DataDir, "data")
- conf.EmailTo = config.EnvToStr("GOS_EMAIL_TO", conf.EmailTo)
- conf.EmailFrom = config.EnvToStr("GOS_EMAIL_FROM", conf.EmailFrom)
-
- conf.SMTPServer = config.EnvToStr("GOS_SMTP_SERVER", conf.SMTPServer, func() string {
+ conf.ListenAddr = config.Env[config.ToString]("GOS_LISTEN_ADDR", conf.ListenAddr, "localhost:8080")
+ conf.Partners = config.Env[config.ToStringSlice]("GOS_PARTNERS", conf.Partners)
+ conf.APIKey = config.Env[config.ToString]("GOS_API_KEY", conf.APIKey)
+ conf.DataDir = config.Env[config.ToString]("GOS_DATA_DIR", conf.DataDir, "data")
+ conf.EmailTo = config.Env[config.ToString]("GOS_EMAIL_TO", conf.EmailTo)
+ conf.EmailFrom = config.Env[config.ToString]("GOS_EMAIL_FROM", conf.EmailFrom)
+
+ conf.SMTPServer = config.Env[config.ToString]("GOS_SMTP_SERVER", conf.SMTPServer, func() string {
hostname, err := os.Hostname()
if err != nil {
log.Fatal(err)
@@ -51,8 +51,8 @@ func New(configFile, secretsFile string) (ServerConfig, error) {
})
const oneHour = 3600
- conf.MergeIntervalS = config.EnvToInt("GOS_MERGE_INTERVAL", oneHour)
- conf.ScheduleIntervalS = config.EnvToInt("GOS_SCHEDULER_INTERVAL", oneHour*6)
+ conf.MergeIntervalS = config.Env[config.ToInteger]("GOS_MERGE_INTERVAL", oneHour)
+ conf.ScheduleIntervalS = config.Env[config.ToInteger]("GOS_SCHEDULER_INTERVAL", oneHour*6)
return conf, nil
}