diff options
Diffstat (limited to 'internal/clients/connectors/serverconnection.go')
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 59 |
1 files changed, 58 insertions, 1 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 3c29ac0..ca1fc43 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -2,9 +2,11 @@ package connectors import ( "context" + "encoding/base64" "fmt" "io" "net" + "os" "strconv" "strings" "time" @@ -29,6 +31,7 @@ type ServerConnection struct { config *ssh.ClientConfig handler handlers.Handler commands []string + authKeyPath string hostKeyCallback client.HostKeyCallback throttlingDone bool } @@ -38,7 +41,7 @@ var _ Connector = (*ServerConnection)(nil) // NewServerConnection returns a new DTail SSH server connection. func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, - handler handlers.Handler, commands []string) *ServerConnection { + handler handlers.Handler, commands []string, authKeyPath string) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond @@ -51,6 +54,7 @@ func NewServerConnection(server string, userName string, server: server, handler: handler, commands: commands, + authKeyPath: resolveAuthKeyPath(authKeyPath), config: &ssh.ClientConfig{ User: userName, Auth: authMethods, @@ -224,6 +228,7 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc dlog.Client.Debug(err) } } + c.sendAuthKeyRegistrationCommand() if !c.throttlingDone { dlog.Client.Debug(c.server, "Unthrottling connection (2)", @@ -236,3 +241,55 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc c.handler.Shutdown() return nil } + +func resolveAuthKeyPath(authKeyPath string) string { + if strings.TrimSpace(authKeyPath) != "" { + return authKeyPath + } + return os.Getenv("HOME") + "/.ssh/id_rsa" +} + +func (c *ServerConnection) sendAuthKeyRegistrationCommand() { + authKeyPubPath := c.authKeyPath + ".pub" + authKeyPubBytes, err := os.ReadFile(authKeyPubPath) + if err != nil { + dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, unable to read public key", authKeyPubPath, err) + return + } + + authKeyBase64, err := extractAuthKeyBase64(authKeyPubBytes) + if err != nil { + dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, invalid public key file", authKeyPubPath, err) + return + } + + if err := c.handler.SendMessage("AUTHKEY " + authKeyBase64); err != nil { + dlog.Client.Debug(c.server, "Unable to send AUTHKEY registration command", err) + return + } + dlog.Client.Debug(c.server, "Sent AUTHKEY registration command", authKeyPubPath) +} + +func extractAuthKeyBase64(authKeyPubBytes []byte) (string, error) { + authKeyPubContent := string(authKeyPubBytes) + for _, line := range strings.Split(authKeyPubContent, "\n") { + trimmedLine := strings.TrimSpace(line) + if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") { + continue + } + + fields := strings.Fields(trimmedLine) + if len(fields) < 2 { + return "", fmt.Errorf("expected authorized key format '<type> <base64-key> [comment]'") + } + + authKeyBase64 := strings.TrimSpace(fields[1]) + if _, err := base64.StdEncoding.DecodeString(authKeyBase64); err != nil { + return "", fmt.Errorf("invalid base64 public key: %w", err) + } + + return authKeyBase64, nil + } + + return "", fmt.Errorf("no public key found") +} |
