diff options
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/handlers/authkeycommand_test.go | 8 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 32 | ||||
| -rw-r--r-- | internal/server/server.go | 21 |
3 files changed, 43 insertions, 18 deletions
diff --git a/internal/server/handlers/authkeycommand_test.go b/internal/server/handlers/authkeycommand_test.go index f510038..a454e94 100644 --- a/internal/server/handlers/authkeycommand_test.go +++ b/internal/server/handlers/authkeycommand_test.go @@ -33,11 +33,10 @@ func TestHandleAuthKeyCommandSuccess(t *testing.T) { if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY OK\n" { t.Fatalf("Unexpected response: %q", message) } - if !sshserver.AuthKeys().Has(handler.user.Name, key) { + if !handler.authKeyStore.Has(handler.user.Name, key) { t.Fatalf("Expected key to be stored for user") } - - sshserver.AuthKeys().Remove(handler.user.Name, key) + handler.authKeyStore.Remove(handler.user.Name, key) } func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) { @@ -51,7 +50,7 @@ func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) { if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY ERR feature disabled\n" { t.Fatalf("Unexpected response: %q", message) } - if sshserver.AuthKeys().Has(handler.user.Name, key) { + if handler.authKeyStore.Has(handler.user.Name, key) { t.Fatalf("Expected no key to be stored while feature is disabled") } } @@ -84,6 +83,7 @@ func newAuthKeyTestHandler(userName string, authKeyEnabled bool) *ServerHandler serverCfg: &config.ServerConfig{ AuthKeyEnabled: authKeyEnabled, }, + authKeyStore: sshserver.NewAuthKeyStore(time.Hour, 5), } } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 078fd27..732cc06 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -23,11 +23,12 @@ import ( // This handler implements the handler of the SSH server. type ServerHandler struct { baseHandler - catLimiter chan struct{} - tailLimiter chan struct{} - serverCfg *config.ServerConfig - regex string - commands map[string]commandHandler + catLimiter chan struct{} + tailLimiter chan struct{} + serverCfg *config.ServerConfig + authKeyStore *sshserver.AuthKeyStore + regex string + commands map[string]commandHandler // Track pending files waiting for limiter slots pendingFiles int32 } @@ -38,7 +39,8 @@ var _ Handler = (*ServerHandler)(nil) // NewServerHandler returns the server handler. func NewServerHandler(user *user.User, catLimiter, - tailLimiter chan struct{}, serverCfg *config.ServerConfig) *ServerHandler { + tailLimiter chan struct{}, serverCfg *config.ServerConfig, + authKeyStore *sshserver.AuthKeyStore) *ServerHandler { dlog.Server.Debug(user, "Creating new server handler") if serverCfg == nil { @@ -55,10 +57,14 @@ func NewServerHandler(user *user.User, catLimiter, user: user, codec: newProtocolCodec(user), }, - catLimiter: catLimiter, - tailLimiter: tailLimiter, - serverCfg: serverCfg, - regex: ".", + catLimiter: catLimiter, + tailLimiter: tailLimiter, + serverCfg: serverCfg, + authKeyStore: authKeyStore, + regex: ".", + } + if h.authKeyStore == nil { + h.authKeyStore = sshserver.AuthKeys() } h.handleCommandCb = h.handleUserCommand h.commands = h.newCommandRegistry() @@ -180,6 +186,10 @@ func (h *ServerHandler) handleAuthKeyCommand(_ context.Context, _ lcontext.LCont return } - sshserver.AuthKeys().Add(h.user.Name, pubKey) + if h.authKeyStore == nil { + h.sendln(h.serverMessages, "AUTHKEY ERR internal key store unavailable") + return + } + h.authKeyStore.Add(h.user.Name, pubKey) h.sendln(h.serverMessages, "AUTHKEY OK") } diff --git a/internal/server/server.go b/internal/server/server.go index 943defa..72094ef 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -37,6 +37,8 @@ type Server struct { cont *continuous // Authentication strategies keyed by SSH username. authStrategies map[string]authStrategy + // In-memory auth key cache for fast reconnect. + authKeyStore *server.AuthKeyStore } type authStrategy func(*user.User, string, string) bool @@ -48,7 +50,6 @@ func New(cfg config.RuntimeConfig) *Server { } dlog.Server.Info("Starting server", version.String()) - server.ConfigureAuthKeyStore(cfg.Server.AuthKeyTTLSeconds, cfg.Server.AuthKeyMaxPerUser) s := Server{ cfg: cfg, @@ -64,11 +65,19 @@ func New(cfg config.RuntimeConfig) *Server { tailLimiter: make(chan struct{}, cfg.Server.MaxConcurrentTails), sched: newScheduler(cfg), cont: newContinuous(cfg), + authKeyStore: server.NewAuthKeyStore( + time.Duration(cfg.Server.AuthKeyTTLSeconds)*time.Second, + cfg.Server.AuthKeyMaxPerUser, + ), } s.authStrategies = s.newAuthStrategies() s.sshServerConfig.PasswordCallback = s.Callback - s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback + s.sshServerConfig.PublicKeyCallback = server.NewPublicKeyCallback( + cfg.Server.AuthKeyEnabled, + cfg.Common.CacheDir, + s.authKeyStore, + ) private, err := gossh.ParsePrivateKey(server.PrivateHostKey()) if err != nil { @@ -222,7 +231,13 @@ func (s *Server) handleShellRequest(ctx context.Context, sshConn gossh.Conn, case config.HealthUser: handler = handlers.NewHealthHandler(user) default: - handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.cfg.Server) + handler = handlers.NewServerHandler( + user, + s.catLimiter, + s.tailLimiter, + s.cfg.Server, + s.authKeyStore, + ) } terminate := func() { |
