diff options
| author | Paul Buetow <paul@buetow.org> | 2024-09-03 10:58:34 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2024-09-03 10:58:34 +0300 |
| commit | 1ade563588a8de2dfa9356e2c0e2ce33a3861227 (patch) | |
| tree | e6c6eb61576fb07d62404c47e73a39f417fc401a /internal/config | |
| parent | 3e33b52642b2b4d5a0c7a8d3a7a6bb269553be1f (diff) | |
initial generic handling of environment variables
Diffstat (limited to 'internal/config')
| -rw-r--r-- | internal/config/client/client.go | 12 | ||||
| -rw-r--r-- | internal/config/config.go | 80 | ||||
| -rw-r--r-- | internal/config/config_test.go | 38 | ||||
| -rw-r--r-- | internal/config/enver.go | 93 | ||||
| -rw-r--r-- | internal/config/server/server.go | 20 |
5 files changed, 128 insertions, 115 deletions
diff --git a/internal/config/client/client.go b/internal/config/client/client.go index d6f2939..ada1a76 100644 --- a/internal/config/client/client.go +++ b/internal/config/client/client.go @@ -26,16 +26,16 @@ func New(configFile string) (ClientConfig, error) { log.Println("Skipping config file:", err) } - conf.Servers = config.EnvToStrSlice("GOS_SERVERS", conf.Servers) - conf.APIKey = config.EnvToStr("GOS_API_KEY", conf.APIKey) - conf.Editor = config.EnvToStr("GOS_EDITOR", "EDITOR", conf.Editor, "vi") + conf.Servers = config.Env[config.ToStringSlice]("GOS_SERVERS", conf.Servers) + conf.APIKey = config.Env[config.ToString]("GOS_API_KEY", conf.APIKey) + conf.Editor = config.Env[config.ToString]("GOS_EDITOR", "EDITOR", conf.Editor, "vi") defaultDataDir := fmt.Sprintf("%s/.gos/data", os.Getenv("HOME")) - conf.DataDir = config.EnvToStr("GOS_DATA_DIR", conf.DataDir, defaultDataDir) - conf.ComposeFile = config.EnvToStr("GOS_COMPOSE_FILE", conf.ComposeFile, "compose.txt") + conf.DataDir = config.Env[config.ToString]("GOS_DATA_DIR", conf.DataDir, defaultDataDir) + conf.ComposeFile = config.Env[config.ToString]("GOS_COMPOSE_FILE", conf.ComposeFile, "compose.txt") defaultLogFile := fmt.Sprintf("%s/.gos/gos.log", os.Getenv("HOME")) - conf.LogFile = config.EnvToStr("GOS_LOG_FILE", conf.LogFile, defaultLogFile) + conf.LogFile = config.Env[config.ToString]("GOS_LOG_FILE", conf.LogFile, defaultLogFile) return conf, nil } diff --git a/internal/config/config.go b/internal/config/config.go index f71911b..6f697e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,8 +4,6 @@ import ( "encoding/json" "io" "os" - "strconv" - "strings" "unicode" ) @@ -27,84 +25,6 @@ func FromFile[T any](configFile string) (T, error) { return conf, err } -// Set config from environment variable if present, e.g. hansWurst from GOS_HANS_WURST -func EnvToStr(keys ...any) string { - for _, key := range keys { - switch key := key.(type) { - case string: - if key == "" { - continue - } - if !isAllUpperCase(key) { - return key - } - if value := os.Getenv(key); value != "" { - return value - } - case func() string: - return key() - } - } - - return "" -} - -func EnvToStrSlice(keys ...any) []string { - result := strings.Split(EnvToStr(keys...), ",") - if len(result) == 1 && result[0] == "" { - return []string{} - } - return result -} - -func EnvToInt(keys ...any) int { - for _, key := range keys { - switch key := key.(type) { - case string: - if key == "" || !isAllUpperCase(key) { - continue - } - strValue := os.Getenv(key) - if strValue == "" { - continue - } - if intValue, err := strconv.Atoi(strValue); err == nil { - return intValue - } - case int: - return key - case func() int: - return key() - } - } - - return 0 -} - -func EnvToBool(keys ...any) bool { - for _, key := range keys { - switch key := key.(type) { - case string: - if key == "" || !isAllUpperCase(key) { - continue - } - strValue := os.Getenv(key) - if strValue == "" { - continue - } - if boolValue, err := strconv.ParseBool(strValue); err == nil { - return boolValue - } - case bool: - return key - case func() bool: - return key() - } - } - - return false -} - func isAllUpperCase(s string) bool { for _, r := range s { if unicode.IsLetter(r) && !unicode.IsUpper(r) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f56d44f..ea96ed9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -14,7 +14,7 @@ func TestEnvToStr(t *testing.T) { var ( expected = "foobarbaz" - got = EnvToStr("GOS_TEST_FROM_ENV") + got = Env[ToString]("GOS_TEST_FROM_ENV") ) if got != expected { @@ -23,19 +23,19 @@ func TestEnvToStr(t *testing.T) { t.Logf("got '%s' as expected", expected) expected = "default value" - got = EnvToStr("NON_EXISTENT_ENV", expected) + got = Env[ToString]("NON_EXISTENT_ENV", expected) if got != expected { t.Errorf("got '%s' but expected '%s'", got, expected) } t.Logf("got '%s' as expected", expected) - if got = EnvToStr("NON_EXISTENT_ENV"); got != "" { + if got = Env[ToString]("NON_EXISTENT_ENV"); got != "" { t.Errorf("got '%s' but expected empty string", got) } t.Logf("got empty string as expected") expected = "casio g-shock" - got = EnvToStr("GOS_WATCH", "", "", "", expected, "") + got = Env[ToString]("GOS_WATCH", "", "", "", expected, "") if got != expected { t.Errorf("got '%s' but expected '%s'", got, expected) } @@ -49,7 +49,7 @@ func TestEnvToStrSlice(t *testing.T) { var ( expected = []string{"foo", "bar", "baz"} - got = EnvToStrSlice("GOS_TEST_SLICE_FROM_ENV") + got = Env[ToStringSlice]("GOS_TEST_SLICE_FROM_ENV") ) if !slices.Equal(got, expected) { t.Errorf("got '%v' but expected '%v'", got, expected) @@ -57,20 +57,20 @@ func TestEnvToStrSlice(t *testing.T) { t.Logf("got '%v' as expected", expected) expected = []string{"default value"} - got = EnvToStrSlice("NON_EXISTENT_ENV_SLICE", "default value") + got = Env[ToStringSlice]("NON_EXISTENT_ENV_SLICE", "default value") if !slices.Equal(got, expected) { t.Errorf("got '%v' but expected '%v'", got, expected) } t.Logf("got '%v' as expected", expected) os.Unsetenv("NON_EXISTENT_ENV") - if got = EnvToStrSlice("NON_EXISTENT_ENV"); len(got) > 0 { + if got = Env[ToStringSlice]("NON_EXISTENT_ENV"); len(got) > 0 { t.Errorf("got '%s' of len '%d' but expected empty slice", got, len(got)) } t.Logf("got empty slice as expected") expected = []string{"casio", "g-shock"} - got = EnvToStrSlice("NON_EXISTENT_ENV", "", "", "", "casio,g-shock", "") + got = Env[ToStringSlice]("NON_EXISTENT_ENV", "", "", "", "casio,g-shock", "") if !slices.Equal(got, expected) { t.Errorf("got '%v' but expected '%v'", got, expected) } @@ -85,7 +85,7 @@ func TestEnvToInt(t *testing.T) { var ( expected = 1 - got = EnvToInt(t, "GOS_TEST_INT_FROM_ENV") + got = Env[ToInteger](t, "GOS_TEST_INT_FROM_ENV") ) if got != expected { @@ -94,19 +94,19 @@ func TestEnvToInt(t *testing.T) { t.Logf("got '%d' as expected", expected) expected = 999 - got = EnvToInt("NON_EXISTENT_ENV", expected) + got = Env[ToInteger]("NON_EXISTENT_ENV", expected) if got != expected { t.Errorf("got '%d' but expected '%d'", got, expected) } t.Logf("got '%d' as expected", expected) - if got = EnvToInt("NON_EXISTENT_ENV"); got != 0 { + if got = Env[ToInteger]("NON_EXISTENT_ENV"); got != 0 { t.Errorf("got '%d' but expected zero", got) } t.Logf("got zero as expected") expected = 1234 - got = EnvToInt("GOS_WATCH", "", "", "", expected, "") + got = Env[ToInteger]("GOS_WATCH", "", "", "", expected, "") if got != expected { t.Errorf("got '%d' but expected '%d'", got, expected) } @@ -121,7 +121,7 @@ func TestEnvToBool(t *testing.T) { var ( expected = true - got = EnvToBool(t, "GOS_TEST_BOOL_FROM_ENV") + got = Env[ToBool](t, "GOS_TEST_BOOL_FROM_ENV") ) if got != expected { @@ -130,19 +130,19 @@ func TestEnvToBool(t *testing.T) { t.Logf("got '%t' as expected", expected) expected = false - got = EnvToBool("NON_EXISTENT_ENV", expected) + got = Env[ToBool]("NON_EXISTENT_ENV", expected) if got != expected { t.Errorf("got '%t' but expected '%t'", got, expected) } t.Logf("got '%t' as expected", expected) - if got = EnvToBool("NON_EXISTENT_ENV"); got { + if got = Env[ToBool]("NON_EXISTENT_ENV"); got { t.Errorf("got '%t' but expected false", got) } t.Logf("got 'false' as expected") expected = true - got = EnvToBool("NON_EXISTENT_ENV", "", "", "", expected, "") + got = Env[ToBool]("NON_EXISTENT_ENV", "", "", "", expected, "") if got != expected { t.Errorf("got '%t' but expected '%t'", got, expected) } @@ -157,7 +157,7 @@ func TestSecondENV(t *testing.T) { var ( expected = "hx" - got = EnvToStr("GOS_NONEXISTANT", "EDITOR", "notepad.exe") + got = Env[ToString]("GOS_NONEXISTANT", "EDITOR", "notepad.exe") ) if expected != got { @@ -183,7 +183,7 @@ func TestDefaultStrCB(t *testing.T) { var ( expected = "hello" - got = EnvToStr("GOS_NONEXISTANT", func() string { + got = Env[ToString]("GOS_NONEXISTANT", func() string { return "hello" }) ) @@ -199,7 +199,7 @@ func TestDefaultIntCB(t *testing.T) { var ( expected = 666 - got = EnvToInt("GOS_NONEXISTANT", func() int { + got = Env[ToInteger]("GOS_NONEXISTANT", func() int { return 666 }) ) diff --git a/internal/config/enver.go b/internal/config/enver.go new file mode 100644 index 0000000..c9b3360 --- /dev/null +++ b/internal/config/enver.go @@ -0,0 +1,93 @@ +package config + +import ( + "os" + "strconv" + "strings" +) + +type enverConstraint interface { + ~int | ~bool | ~string | []string +} + +type enver[T enverConstraint] interface { + // Return T value from input string + fromStr(value string) T + // Return T's zero value + zero() T +} + +func Env[U enver[T], T enverConstraint](keys ...any) T { + var enver U + + for _, key := range keys { + switch key := key.(type) { + case string: + if key == "" { + continue + } + if !isAllUpperCase(key) { + return enver.fromStr(key) + } + if value := os.Getenv(key); value != "" { + return enver.fromStr(value) + } + case func() T: + return key() + } + } + + return enver.zero() +} + +type ToString struct{} + +func (ToString) fromStr(str string) string { + return str +} + +func (ToString) zero() string { + return "" +} + +type ToStringSlice struct{} + +func (s ToStringSlice) fromStr(str string) []string { + result := strings.Split(str, ",") + if len(result) == 1 && result[0] == "" { + return s.zero() + } + return result +} + +func (ToStringSlice) zero() []string { + return []string{} +} + +type ToInteger struct{} + +// TODO: Return an error if can't convert to int +func (s ToInteger) fromStr(str string) int { + if result, err := strconv.Atoi(str); err == nil { + return result + } + return s.zero() +} + +func (ToInteger) zero() int { + return 0 +} + +type ToBool struct{} + +// TODO: Return an error if can't convert to bool +func (s ToBool) fromStr(str string) bool { + if result, err := strconv.ParseBool(str); err == nil { + return result + } + return s.zero() +} + +func (ToBool) zero() bool { + return false +} diff --git a/internal/config/server/server.go b/internal/config/server/server.go index f0a0be3..48dd632 100644 --- a/internal/config/server/server.go +++ b/internal/config/server/server.go @@ -35,14 +35,14 @@ func New(configFile, secretsFile string) (ServerConfig, error) { return conf, err } - conf.ListenAddr = config.EnvToStr("GOS_LISTEN_ADDR", conf.ListenAddr, "localhost:8080") - conf.Partners = config.EnvToStrSlice("GOS_PARTNERS", conf.Partners) - conf.APIKey = config.EnvToStr("GOS_API_KEY", conf.APIKey) - conf.DataDir = config.EnvToStr("GOS_DATA_DIR", conf.DataDir, "data") - conf.EmailTo = config.EnvToStr("GOS_EMAIL_TO", conf.EmailTo) - conf.EmailFrom = config.EnvToStr("GOS_EMAIL_FROM", conf.EmailFrom) - - conf.SMTPServer = config.EnvToStr("GOS_SMTP_SERVER", conf.SMTPServer, func() string { + conf.ListenAddr = config.Env[config.ToString]("GOS_LISTEN_ADDR", conf.ListenAddr, "localhost:8080") + conf.Partners = config.Env[config.ToStringSlice]("GOS_PARTNERS", conf.Partners) + conf.APIKey = config.Env[config.ToString]("GOS_API_KEY", conf.APIKey) + conf.DataDir = config.Env[config.ToString]("GOS_DATA_DIR", conf.DataDir, "data") + conf.EmailTo = config.Env[config.ToString]("GOS_EMAIL_TO", conf.EmailTo) + conf.EmailFrom = config.Env[config.ToString]("GOS_EMAIL_FROM", conf.EmailFrom) + + conf.SMTPServer = config.Env[config.ToString]("GOS_SMTP_SERVER", conf.SMTPServer, func() string { hostname, err := os.Hostname() if err != nil { log.Fatal(err) @@ -51,8 +51,8 @@ func New(configFile, secretsFile string) (ServerConfig, error) { }) const oneHour = 3600 - conf.MergeIntervalS = config.EnvToInt("GOS_MERGE_INTERVAL", oneHour) - conf.ScheduleIntervalS = config.EnvToInt("GOS_SCHEDULER_INTERVAL", oneHour*6) + conf.MergeIntervalS = config.Env[config.ToInteger]("GOS_MERGE_INTERVAL", oneHour) + conf.ScheduleIntervalS = config.Env[config.ToInteger]("GOS_SCHEDULER_INTERVAL", oneHour*6) return conf, nil } |
