diff options
Diffstat (limited to 'internal/ssh/server')
| -rw-r--r-- | internal/ssh/server/hostkey.go | 28 | ||||
| -rw-r--r-- | internal/ssh/server/publickeycallback.go | 4 |
2 files changed, 19 insertions, 13 deletions
diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 1df2287..1315351 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -1,7 +1,8 @@ package server import ( - "os" + "errors" + iofs "io/fs" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" @@ -31,18 +32,21 @@ func PrivateHostKey(hostKeyFile string, hostKeyBits int) []byte { } _, err = hostKeyPath.Stat() - - if os.IsNotExist(err) { - dlog.Server.Info("Generating private server RSA host key") - pem, err := generatePrivateHostKey(hostKeyBits) - if err != nil { - dlog.Server.FatalPanic("Failed to generate private server RSA host key", err) - } - if err := storePrivateHostKey(hostKeyPath, pem); err != nil { - dlog.Server.Error("Unable to write private server RSA host key to file", - hostKeyFile, err) + if err != nil { + // os.IsNotExist does not unwrap fmt.Errorf chains from RootedPath.Stat; use errors.Is. + if errors.Is(err, iofs.ErrNotExist) { + dlog.Server.Info("Generating private server RSA host key") + pem, genErr := generatePrivateHostKey(hostKeyBits) + if genErr != nil { + dlog.Server.FatalPanic("Failed to generate private server RSA host key", genErr) + } + if storeErr := storePrivateHostKey(hostKeyPath, pem); storeErr != nil { + dlog.Server.Error("Unable to write private server RSA host key to file", + hostKeyFile, storeErr) + } + return pem } - return pem + dlog.Server.FatalPanic("Cannot stat private server RSA host key path", hostKeyFile, err) } dlog.Server.Info("Reading private server RSA host key from file", hostKeyFile) diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index 3afbfba..df83bf6 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -1,7 +1,9 @@ package server import ( + "errors" "fmt" + iofs "io/fs" "os" goUser "os/user" "path/filepath" @@ -142,7 +144,7 @@ func findAuthorizedKeysPath(user *user.User, cacheDir, cwd string, if _, err = rootedAuthorizedKeysPath.Stat(); err == nil { return rootedAuthorizedKeysPath, nil } - if !os.IsNotExist(err) { + if !errors.Is(err, iofs.ErrNotExist) { return fs.RootedPath{}, err } |
