summaryrefslogtreecommitdiff
path: root/internal/clients/connectors/serverconnection.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/connectors/serverconnection.go')
-rw-r--r--internal/clients/connectors/serverconnection.go59
1 files changed, 58 insertions, 1 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 3c29ac0..ca1fc43 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -2,9 +2,11 @@ package connectors
import (
"context"
+ "encoding/base64"
"fmt"
"io"
"net"
+ "os"
"strconv"
"strings"
"time"
@@ -29,6 +31,7 @@ type ServerConnection struct {
config *ssh.ClientConfig
handler handlers.Handler
commands []string
+ authKeyPath string
hostKeyCallback client.HostKeyCallback
throttlingDone bool
}
@@ -38,7 +41,7 @@ var _ Connector = (*ServerConnection)(nil)
// NewServerConnection returns a new DTail SSH server connection.
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
- handler handlers.Handler, commands []string) *ServerConnection {
+ handler handlers.Handler, commands []string, authKeyPath string) *ServerConnection {
dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond
@@ -51,6 +54,7 @@ func NewServerConnection(server string, userName string,
server: server,
handler: handler,
commands: commands,
+ authKeyPath: resolveAuthKeyPath(authKeyPath),
config: &ssh.ClientConfig{
User: userName,
Auth: authMethods,
@@ -224,6 +228,7 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
dlog.Client.Debug(err)
}
}
+ c.sendAuthKeyRegistrationCommand()
if !c.throttlingDone {
dlog.Client.Debug(c.server, "Unthrottling connection (2)",
@@ -236,3 +241,55 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
c.handler.Shutdown()
return nil
}
+
+func resolveAuthKeyPath(authKeyPath string) string {
+ if strings.TrimSpace(authKeyPath) != "" {
+ return authKeyPath
+ }
+ return os.Getenv("HOME") + "/.ssh/id_rsa"
+}
+
+func (c *ServerConnection) sendAuthKeyRegistrationCommand() {
+ authKeyPubPath := c.authKeyPath + ".pub"
+ authKeyPubBytes, err := os.ReadFile(authKeyPubPath)
+ if err != nil {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, unable to read public key", authKeyPubPath, err)
+ return
+ }
+
+ authKeyBase64, err := extractAuthKeyBase64(authKeyPubBytes)
+ if err != nil {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, invalid public key file", authKeyPubPath, err)
+ return
+ }
+
+ if err := c.handler.SendMessage("AUTHKEY " + authKeyBase64); err != nil {
+ dlog.Client.Debug(c.server, "Unable to send AUTHKEY registration command", err)
+ return
+ }
+ dlog.Client.Debug(c.server, "Sent AUTHKEY registration command", authKeyPubPath)
+}
+
+func extractAuthKeyBase64(authKeyPubBytes []byte) (string, error) {
+ authKeyPubContent := string(authKeyPubBytes)
+ for _, line := range strings.Split(authKeyPubContent, "\n") {
+ trimmedLine := strings.TrimSpace(line)
+ if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") {
+ continue
+ }
+
+ fields := strings.Fields(trimmedLine)
+ if len(fields) < 2 {
+ return "", fmt.Errorf("expected authorized key format '<type> <base64-key> [comment]'")
+ }
+
+ authKeyBase64 := strings.TrimSpace(fields[1])
+ if _, err := base64.StdEncoding.DecodeString(authKeyBase64); err != nil {
+ return "", fmt.Errorf("invalid base64 public key: %w", err)
+ }
+
+ return authKeyBase64, nil
+ }
+
+ return "", fmt.Errorf("no public key found")
+}