summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/authkeycommand_test.go8
-rw-r--r--internal/server/handlers/serverhandler.go32
-rw-r--r--internal/server/server.go21
3 files changed, 43 insertions, 18 deletions
diff --git a/internal/server/handlers/authkeycommand_test.go b/internal/server/handlers/authkeycommand_test.go
index f510038..a454e94 100644
--- a/internal/server/handlers/authkeycommand_test.go
+++ b/internal/server/handlers/authkeycommand_test.go
@@ -33,11 +33,10 @@ func TestHandleAuthKeyCommandSuccess(t *testing.T) {
if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY OK\n" {
t.Fatalf("Unexpected response: %q", message)
}
- if !sshserver.AuthKeys().Has(handler.user.Name, key) {
+ if !handler.authKeyStore.Has(handler.user.Name, key) {
t.Fatalf("Expected key to be stored for user")
}
-
- sshserver.AuthKeys().Remove(handler.user.Name, key)
+ handler.authKeyStore.Remove(handler.user.Name, key)
}
func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) {
@@ -51,7 +50,7 @@ func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) {
if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY ERR feature disabled\n" {
t.Fatalf("Unexpected response: %q", message)
}
- if sshserver.AuthKeys().Has(handler.user.Name, key) {
+ if handler.authKeyStore.Has(handler.user.Name, key) {
t.Fatalf("Expected no key to be stored while feature is disabled")
}
}
@@ -84,6 +83,7 @@ func newAuthKeyTestHandler(userName string, authKeyEnabled bool) *ServerHandler
serverCfg: &config.ServerConfig{
AuthKeyEnabled: authKeyEnabled,
},
+ authKeyStore: sshserver.NewAuthKeyStore(time.Hour, 5),
}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 078fd27..732cc06 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -23,11 +23,12 @@ import (
// This handler implements the handler of the SSH server.
type ServerHandler struct {
baseHandler
- catLimiter chan struct{}
- tailLimiter chan struct{}
- serverCfg *config.ServerConfig
- regex string
- commands map[string]commandHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ serverCfg *config.ServerConfig
+ authKeyStore *sshserver.AuthKeyStore
+ regex string
+ commands map[string]commandHandler
// Track pending files waiting for limiter slots
pendingFiles int32
}
@@ -38,7 +39,8 @@ var _ Handler = (*ServerHandler)(nil)
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter,
- tailLimiter chan struct{}, serverCfg *config.ServerConfig) *ServerHandler {
+ tailLimiter chan struct{}, serverCfg *config.ServerConfig,
+ authKeyStore *sshserver.AuthKeyStore) *ServerHandler {
dlog.Server.Debug(user, "Creating new server handler")
if serverCfg == nil {
@@ -55,10 +57,14 @@ func NewServerHandler(user *user.User, catLimiter,
user: user,
codec: newProtocolCodec(user),
},
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- serverCfg: serverCfg,
- regex: ".",
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ serverCfg: serverCfg,
+ authKeyStore: authKeyStore,
+ regex: ".",
+ }
+ if h.authKeyStore == nil {
+ h.authKeyStore = sshserver.AuthKeys()
}
h.handleCommandCb = h.handleUserCommand
h.commands = h.newCommandRegistry()
@@ -180,6 +186,10 @@ func (h *ServerHandler) handleAuthKeyCommand(_ context.Context, _ lcontext.LCont
return
}
- sshserver.AuthKeys().Add(h.user.Name, pubKey)
+ if h.authKeyStore == nil {
+ h.sendln(h.serverMessages, "AUTHKEY ERR internal key store unavailable")
+ return
+ }
+ h.authKeyStore.Add(h.user.Name, pubKey)
h.sendln(h.serverMessages, "AUTHKEY OK")
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 943defa..72094ef 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -37,6 +37,8 @@ type Server struct {
cont *continuous
// Authentication strategies keyed by SSH username.
authStrategies map[string]authStrategy
+ // In-memory auth key cache for fast reconnect.
+ authKeyStore *server.AuthKeyStore
}
type authStrategy func(*user.User, string, string) bool
@@ -48,7 +50,6 @@ func New(cfg config.RuntimeConfig) *Server {
}
dlog.Server.Info("Starting server", version.String())
- server.ConfigureAuthKeyStore(cfg.Server.AuthKeyTTLSeconds, cfg.Server.AuthKeyMaxPerUser)
s := Server{
cfg: cfg,
@@ -64,11 +65,19 @@ func New(cfg config.RuntimeConfig) *Server {
tailLimiter: make(chan struct{}, cfg.Server.MaxConcurrentTails),
sched: newScheduler(cfg),
cont: newContinuous(cfg),
+ authKeyStore: server.NewAuthKeyStore(
+ time.Duration(cfg.Server.AuthKeyTTLSeconds)*time.Second,
+ cfg.Server.AuthKeyMaxPerUser,
+ ),
}
s.authStrategies = s.newAuthStrategies()
s.sshServerConfig.PasswordCallback = s.Callback
- s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback
+ s.sshServerConfig.PublicKeyCallback = server.NewPublicKeyCallback(
+ cfg.Server.AuthKeyEnabled,
+ cfg.Common.CacheDir,
+ s.authKeyStore,
+ )
private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
if err != nil {
@@ -222,7 +231,13 @@ func (s *Server) handleShellRequest(ctx context.Context, sshConn gossh.Conn,
case config.HealthUser:
handler = handlers.NewHealthHandler(user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.cfg.Server)
+ handler = handlers.NewServerHandler(
+ user,
+ s.catLimiter,
+ s.tailLimiter,
+ s.cfg.Server,
+ s.authKeyStore,
+ )
}
terminate := func() {