diff options
| author | Paul Buetow <paul@buetow.org> | 2026-04-10 18:03:29 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-04-10 18:03:29 +0300 |
| commit | 28f6319b77d35c6da6b99ad7e35d0d5602dc2ee6 (patch) | |
| tree | 687b2c38755a087694cacacb73cd73b8ef244ce7 /internal | |
| parent | 13b21feb07c86f65760f7338f284f3b492364cd9 (diff) | |
Fix known-hosts trust deadlock, host key stat, and optional nozstd build
- stdout logger: release mutex while waiting on pause resume so prompt
callbacks can log (fixes hang after trusting new hosts; known_hosts
was written but Resume never ran).
- known hosts callback: stop borrowing the SSH dial throttle channel
(could block or interact badly with parallel handshakes).
- host key path: use errors.Is(..., fs.ErrNotExist) for RootedPath.Stat
wrapped errors; stat errors now fail fast instead of mis-read.
- public key path: same ErrNotExist check for authorized_keys miss.
- Build: optional DTAIL_NO_ZSTD=yes / nozstd tag for CGO-free builds;
split zstd readers into tagged files.
- Docs/examples: firewalld note for port 2222, log prune timer+script,
SSHBindAddress note, dserver unit disabled-by-default comment;
firewalld helper script example.
- Regression test for stdout pause/mutex behavior.
Made-with: Cursor
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 2 | ||||
| -rw-r--r-- | internal/io/dlog/loggers/stdout.go | 9 | ||||
| -rw-r--r-- | internal/io/dlog/loggers/stdout_test.go | 36 | ||||
| -rw-r--r-- | internal/io/fs/readfile.go | 7 | ||||
| -rw-r--r-- | internal/io/fs/readfile_nozstd.go | 16 | ||||
| -rw-r--r-- | internal/io/fs/readfile_zstd.go | 20 | ||||
| -rw-r--r-- | internal/ssh/client/authmethods.go | 8 | ||||
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 9 | ||||
| -rw-r--r-- | internal/ssh/client/knownhostscallback_test.go | 5 | ||||
| -rw-r--r-- | internal/ssh/server/hostkey.go | 28 | ||||
| -rw-r--r-- | internal/ssh/server/publickeycallback.go | 4 |
11 files changed, 105 insertions, 39 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 71b8d02..bc7c2f1 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -73,7 +73,7 @@ func (c *baseClient) init() { } c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods( c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, - c.throttleCh, c.Args.SSHPrivateKeyFilePath, c.Args.SSHAgentKeyIndex) + c.Args.SSHPrivateKeyFilePath, c.Args.SSHAgentKeyIndex) } func (c *baseClient) makeConnections(maker maker) { diff --git a/internal/io/dlog/loggers/stdout.go b/internal/io/dlog/loggers/stdout.go index a2575c8..b4e695a 100644 --- a/internal/io/dlog/loggers/stdout.go +++ b/internal/io/dlog/loggers/stdout.go @@ -44,14 +44,17 @@ func (s *stdout) RawWithColors(now time.Time, message, coloredMessage string) { func (s *stdout) log(message string, nl bool) { s.mutex.Lock() - defer s.mutex.Unlock() - select { case <-s.pauseCh: - // Pause until resumed. + // Wait for Resume without holding the mutex: the prompt path calls + // dlog after the user answers while Pause is still active; holding the + // mutex here would deadlock (Info blocks on Lock, Resume never runs). + s.mutex.Unlock() <-s.resumeCh + s.mutex.Lock() default: } + defer s.mutex.Unlock() if nl { fmt.Println(message) diff --git a/internal/io/dlog/loggers/stdout_test.go b/internal/io/dlog/loggers/stdout_test.go new file mode 100644 index 0000000..4f70efc --- /dev/null +++ b/internal/io/dlog/loggers/stdout_test.go @@ -0,0 +1,36 @@ +package loggers + +import ( + "testing" + "time" +) + +// Regression: during an interactive prompt, dlog.Common.Pause() unblocks when some +// goroutine hits stdout.log(); that goroutine must not hold the stdout mutex while +// waiting on resume, or dlog.Client.Info from the prompt callback deadlocks forever. +func TestStdoutSecondLogDuringPauseWaitDoesNotDeadlock(t *testing.T) { + s := newStdout() + + go s.Pause() + time.Sleep(50 * time.Millisecond) + + go func() { + s.Log(time.Now(), "first log consumes pause and waits on resume") + }() + time.Sleep(50 * time.Millisecond) + + secondDone := make(chan struct{}) + go func() { + s.Log(time.Now(), "second log must acquire mutex while first waits for Resume") + close(secondDone) + }() + + select { + case <-secondDone: + case <-time.After(2 * time.Second): + t.Fatal("deadlock: second Log blocked on mutex while first waits for Resume") + } + + s.Resume() + time.Sleep(50 * time.Millisecond) +} diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go index d305c4d..5241556 100644 --- a/internal/io/fs/readfile.go +++ b/internal/io/fs/readfile.go @@ -19,8 +19,6 @@ import ( "github.com/mimecast/dtail/internal/io/pool" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/regex" - - "github.com/DataDog/zstd" ) type readStatus int @@ -193,10 +191,7 @@ func (f *readFile) makeCompressedFileReader(fd *os.File) (reader *bufio.Reader, decompressor = gzipReader reader = bufio.NewReader(gzipReader) case strings.HasSuffix(f.FilePath(), ".zst"): - dlog.Common.Info(f.FilePath(), "Detected zstd compression format") - zstdReader := zstd.NewReader(fd) - decompressor = zstdReader - reader = bufio.NewReader(zstdReader) + return f.makeZstdReader(fd) default: reader = bufio.NewReader(fd) } diff --git a/internal/io/fs/readfile_nozstd.go b/internal/io/fs/readfile_nozstd.go new file mode 100644 index 0000000..afd4523 --- /dev/null +++ b/internal/io/fs/readfile_nozstd.go @@ -0,0 +1,16 @@ +//go:build nozstd + +package fs + +import ( + "bufio" + "fmt" + "io" + "os" +) + +func (f *readFile) makeZstdReader(fd *os.File) (reader *bufio.Reader, decompressor io.Closer, err error) { + _ = fd + err = fmt.Errorf("%s: zstd is not supported in this build (built with -tags nozstd)", f.FilePath()) + return +} diff --git a/internal/io/fs/readfile_zstd.go b/internal/io/fs/readfile_zstd.go new file mode 100644 index 0000000..a7e479b --- /dev/null +++ b/internal/io/fs/readfile_zstd.go @@ -0,0 +1,20 @@ +//go:build !nozstd + +package fs + +import ( + "bufio" + "io" + "os" + + "github.com/DataDog/zstd" + "github.com/mimecast/dtail/internal/io/dlog" +) + +func (f *readFile) makeZstdReader(fd *os.File) (reader *bufio.Reader, decompressor io.Closer, err error) { + dlog.Common.Info(f.FilePath(), "Detected zstd compression format") + zstdReader := zstd.NewReader(fd) + decompressor = zstdReader + reader = bufio.NewReader(zstdReader) + return +} diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 7ac4d0c..3cd1bb3 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -18,7 +18,7 @@ var ( // InitSSHAuthMethods initialises all known SSH auth methods on the client side. func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, - hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, + hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { if len(sshAuthMethods) > 0 { @@ -28,10 +28,10 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, } return sshAuthMethods, simpleCallback } - return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex) + return initKnownHostsAuthMethods(trustAllHosts, privateKeyPath, agentKeyIndex) } -func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, +func initKnownHostsAuthMethods(trustAllHosts bool, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME")) @@ -40,7 +40,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, knownHostsFile = "./known_hosts" } - knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts, throttleCh) + knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts) if err != nil { dlog.Client.FatalPanic(knownHostsFile, err) } diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index 174f6aa..da1b29b 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -41,7 +41,6 @@ type KnownHostsCallback struct { knownHostsPath string knownHostsFile fs.RootedPath unknownCh chan unknownHost - throttleCh chan struct{} trustAllHostsCh chan struct{} untrustedHosts map[string]bool mutex *sync.Mutex @@ -50,8 +49,7 @@ type KnownHostsCallback struct { var _ HostKeyCallback = (*KnownHostsCallback)(nil) // NewKnownHostsCallback returns a new wrapper. -func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, - throttleCh chan struct{}) (HostKeyCallback, error) { +func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool) (HostKeyCallback, error) { knownHostsFile, err := fs.NewRootedPath(knownHostsPath) if err != nil { @@ -65,7 +63,6 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, knownHostsFile: knownHostsFile, unknownCh: make(chan unknownHost), trustAllHostsCh: make(chan struct{}), - throttleCh: throttleCh, untrustedHosts: untrustedHosts, mutex: &sync.Mutex{}, } @@ -103,10 +100,6 @@ func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback { // OK return nil } - // Make sure that interactive user callback does not interfere with - // SSH connection throttler. - <-c.throttleCh - defer func() { c.throttleCh <- struct{}{} }() unknown := unknownHost{ server: server, diff --git a/internal/ssh/client/knownhostscallback_test.go b/internal/ssh/client/knownhostscallback_test.go index 596aea8..1765598 100644 --- a/internal/ssh/client/knownhostscallback_test.go +++ b/internal/ssh/client/knownhostscallback_test.go @@ -112,10 +112,7 @@ func TestTrustHostsRejectsEscapingKnownHostsSymlink(t *testing.T) { func testKnownHostsCallback(t *testing.T, knownHostsPath string) *KnownHostsCallback { t.Helper() - throttleCh := make(chan struct{}, 1) - throttleCh <- struct{}{} - - callback, err := NewKnownHostsCallback(knownHostsPath, false, throttleCh) + callback, err := NewKnownHostsCallback(knownHostsPath, false) if err != nil { t.Fatalf("NewKnownHostsCallback failed: %v", err) } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 1df2287..1315351 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -1,7 +1,8 @@ package server import ( - "os" + "errors" + iofs "io/fs" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" @@ -31,18 +32,21 @@ func PrivateHostKey(hostKeyFile string, hostKeyBits int) []byte { } _, err = hostKeyPath.Stat() - - if os.IsNotExist(err) { - dlog.Server.Info("Generating private server RSA host key") - pem, err := generatePrivateHostKey(hostKeyBits) - if err != nil { - dlog.Server.FatalPanic("Failed to generate private server RSA host key", err) - } - if err := storePrivateHostKey(hostKeyPath, pem); err != nil { - dlog.Server.Error("Unable to write private server RSA host key to file", - hostKeyFile, err) + if err != nil { + // os.IsNotExist does not unwrap fmt.Errorf chains from RootedPath.Stat; use errors.Is. + if errors.Is(err, iofs.ErrNotExist) { + dlog.Server.Info("Generating private server RSA host key") + pem, genErr := generatePrivateHostKey(hostKeyBits) + if genErr != nil { + dlog.Server.FatalPanic("Failed to generate private server RSA host key", genErr) + } + if storeErr := storePrivateHostKey(hostKeyPath, pem); storeErr != nil { + dlog.Server.Error("Unable to write private server RSA host key to file", + hostKeyFile, storeErr) + } + return pem } - return pem + dlog.Server.FatalPanic("Cannot stat private server RSA host key path", hostKeyFile, err) } dlog.Server.Info("Reading private server RSA host key from file", hostKeyFile) diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index 3afbfba..df83bf6 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -1,7 +1,9 @@ package server import ( + "errors" "fmt" + iofs "io/fs" "os" goUser "os/user" "path/filepath" @@ -142,7 +144,7 @@ func findAuthorizedKeysPath(user *user.User, cacheDir, cwd string, if _, err = rootedAuthorizedKeysPath.Stat(); err == nil { return rootedAuthorizedKeysPath, nil } - if !os.IsNotExist(err) { + if !errors.Is(err, iofs.ErrNotExist) { return fs.RootedPath{}, err } |
