summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2021-09-12 19:04:42 +0300
committerPaul Buetow <paul@buetow.org>2021-10-02 12:26:29 +0300
commit2ebe7e9d63ba62c6f19749c39fe0a577d86ca775 (patch)
tree2ae6d11a3cbc82152085a9d7755adef436b3ce46 /internal/server
parent842fd5800000bb68d6306a9ecad80a98ed762a2f (diff)
bugfix: dmap skipped the last couple of mapreduce lines
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/readcommand.go15
-rw-r--r--internal/server/handlers/serverhandler.go86
2 files changed, 40 insertions, 61 deletions
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index 5bab26f..69dd4a5 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/io/fs"
+ "github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/regex"
@@ -113,16 +114,20 @@ func (r *readCommand) readFile(ctx context.Context, path, globID string, re rege
}
lines := r.server.lines
-
- // Plug in mappreduce engine
- if r.server.aggregate != nil {
- lines = r.server.aggregate.Lines
- }
+ aggregate := r.server.aggregate
for {
+ if aggregate != nil {
+ lines = make(chan line.Line, 100)
+ aggregate.NextLinesCh <- lines
+ }
if err := reader.Start(ctx, lines, re); err != nil {
logger.Error(r.server.user, path, globID, err)
}
+ if aggregate != nil {
+ // Also makes aggregate to Flush
+ close(lines)
+ }
select {
case <-ctx.Done():
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index e74e686..ed19412 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -32,36 +32,35 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
- quiet bool
- readBuf bytes.Buffer
- writeBuf bytes.Buffer
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
}
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ user: user,
}
fqdn, err := os.Hostname()
@@ -108,7 +107,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
h.readBuf.WriteByte(protocol.MessageDelimiter)
n = copy(p, h.readBuf.Bytes())
- case message := <-h.aggregatedMessages:
+ case message := <-h.maprMessages:
// Send mapreduce-aggregated data as a message.
h.readBuf.WriteString("AGGREGATE")
h.readBuf.WriteString(protocol.FieldDelimiter)
@@ -260,14 +259,6 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
h.shutdown()
}
}
- readerFinished := func() {
- if h.decrementActiveReaders() == 0 {
- if h.aggregate == nil {
- return
- }
- h.aggregate.Shutdown()
- }
- }
splitted := strings.Split(args[0], ":")
commandName := splitted[0]
@@ -289,18 +280,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
- h.incrementActiveReaders()
command.Start(ctx, argc, args, 1)
- readerFinished()
commandFinished()
}()
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
- h.incrementActiveReaders()
command.Start(ctx, argc, args, 10)
- readerFinished()
commandFinished()
}()
@@ -315,7 +302,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
h.aggregate = aggregate
go func() {
- command.Start(ctx, h.aggregatedMessages)
+ command.Start(ctx, h.maprMessages)
commandFinished()
}()
@@ -361,15 +348,11 @@ func (h *ServerHandler) serverMessageC() chan<- string {
return h.serverMessages
}
-func (h *ServerHandler) flush() {
- logger.Debug(h.user, "flush()")
-
- if h.aggregate != nil {
- h.aggregate.Flush()
- }
+func (h *ServerHandler) flushMessages() {
+ logger.Debug(h.user, "flushMessages()")
unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
}
for i := 0; i < 3; i++ {
if unsentMessages() == 0 {
@@ -385,7 +368,7 @@ func (h *ServerHandler) flush() {
func (h *ServerHandler) shutdown() {
logger.Debug(h.user, "shutdown()")
- h.flush()
+ h.flushMessages()
go func() {
select {
@@ -413,15 +396,6 @@ func (h *ServerHandler) decrementActiveCommands() int32 {
return atomic.LoadInt32(&h.activeCommands)
}
-func (h *ServerHandler) incrementActiveReaders() {
- atomic.AddInt32(&h.activeReaders, 1)
-}
-
-func (h *ServerHandler) decrementActiveReaders() int32 {
- atomic.AddInt32(&h.activeReaders, -1)
- return atomic.LoadInt32(&h.activeReaders)
-}
-
func readOptions(opts []string) (map[string]string, error) {
options := make(map[string]string, len(opts))