diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 20:52:54 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 20:52:54 +0200 |
| commit | 1b34e1f2501b8def0a0fb4eae28bf6c19a8adde2 (patch) | |
| tree | 4898ab4ff4a7dd4ea102726a845e3935c39ee320 /internal/clients | |
| parent | 07d654f76e1002b6ac18a43aab3c64797dcd2a32 (diff) | |
Fix serverless output draining regressions
Diffstat (limited to 'internal/clients')
| -rw-r--r-- | internal/clients/connectors/serverless.go | 26 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 4 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler_test.go | 32 | ||||
| -rw-r--r-- | internal/clients/session_spec_test.go | 14 |
4 files changed, 75 insertions, 1 deletions
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index 72e3fda..4e4d57e 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -3,6 +3,7 @@ package connectors import ( "context" "io" + "sync" "time" "github.com/mimecast/dtail/internal/clients/handlers" @@ -76,13 +77,16 @@ func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { dlog.Client.Debug("Starting serverless connector") + done := make(chan struct{}) go func() { + defer close(done) defer cancel() if err := s.handle(ctx, cancel); err != nil { dlog.Client.Warn(err) } }() <-ctx.Done() + <-done } func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error { @@ -111,9 +115,12 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Error tracking errChan := make(chan error, 4) + var ioWg sync.WaitGroup // Read from client handler + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(toServer) buf := make([]byte, 32*1024) for { @@ -137,7 +144,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro }() // Write to server handler + ioWg.Add(1) go func() { + defer ioWg.Done() for data := range toServer { if _, err := serverHandler.Write(data); err != nil { errChan <- err @@ -147,7 +156,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro }() // Read from server handler + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(fromServer) buf := make([]byte, 64*1024) // Larger buffer for server responses for { @@ -172,7 +183,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Write to client handler serverDone := make(chan struct{}) + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(serverDone) for data := range fromServer { if _, err := s.handler.Write(data); err != nil { @@ -192,6 +205,18 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro select { case <-s.handler.Done(): dlog.Client.Trace("<-s.handler.Done()") + // The client handler marks itself done as soon as it receives the + // hidden close message. Keep the in-process server alive long enough + // for the remaining output and close ACK to drain instead of canceling + // the whole session immediately. + select { + case <-serverDone: + dlog.Client.Trace("Server transfer done after client close") + case <-ctx.Done(): + dlog.Client.Trace("<-ctx.Done() while waiting for server transfer") + case <-time.After(6 * time.Second): + dlog.Client.Debug("Timed out waiting for server transfer after client close") + } case <-serverDone: dlog.Client.Trace("Server transfer done") case <-ctx.Done(): @@ -201,6 +226,7 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Wait for completion <-ctx.Done() + ioWg.Wait() // Check for errors select { diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 8da4556..2979091 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -192,7 +192,9 @@ func (h *baseHandler) handleHiddenMessage(message string) { strings.HasPrefix(message, protocol.HiddenSessionErrorPrefix): h.handleSessionAckMessage(message) case strings.HasPrefix(message, ".syn close connection"): - go h.SendMessage(".ack close connection") + if err := h.SendMessage(".ack close connection"); err != nil { + dlog.Client.Debug(h.server, "Unable to acknowledge close connection", err) + } h.Shutdown() } } diff --git a/internal/clients/handlers/basehandler_test.go b/internal/clients/handlers/basehandler_test.go index 7db2bb8..3e8aaa1 100644 --- a/internal/clients/handlers/basehandler_test.go +++ b/internal/clients/handlers/basehandler_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/protocol" ) @@ -172,3 +173,34 @@ func TestHandleSessionAckMessage(t *testing.T) { t.Fatalf("unexpected session ack: %#v", ack) } } + +func TestHandleCloseConnectionAcknowledgesBeforeShutdown(t *testing.T) { + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) + + handler := baseHandler{ + done: internal.NewDone(), + server: "server-under-test", + commands: make(chan string, 1), + } + + handler.handleHiddenMessage(".syn close connection") + + select { + case command := <-handler.commands: + if command == "" { + t.Fatal("expected close acknowledgement command") + } + case <-time.After(10 * time.Millisecond): + t.Fatal("expected close acknowledgement command to be queued") + } + + select { + case <-handler.Done(): + default: + t.Fatal("expected handler to be shut down after close acknowledgement") + } +} diff --git a/internal/clients/session_spec_test.go b/internal/clients/session_spec_test.go index aa3c45d..8133bc9 100644 --- a/internal/clients/session_spec_test.go +++ b/internal/clients/session_spec_test.go @@ -131,3 +131,17 @@ func TestNewSessionSpecSplitsFiles(t *testing.T) { t.Fatalf("unexpected timeout: %d", spec.Timeout) } } + +func TestNewSessionSpecUsesPipeSentinelForServerlessStdin(t *testing.T) { + t.Parallel() + + spec := NewSessionSpec(config.Args{ + Mode: omode.GrepClient, + Serverless: true, + RegexStr: "ERROR", + }) + + if len(spec.Files) != 1 || spec.Files[0] != "-" { + t.Fatalf("unexpected files for serverless stdin: %#v", spec.Files) + } +} |
