diff options
Diffstat (limited to 'ssh')
| -rw-r--r-- | ssh/client/authmethods.go | 45 | ||||
| -rw-r--r-- | ssh/client/hostkeycallback.go | 285 | ||||
| -rw-r--r-- | ssh/server/hostkey.go | 37 | ||||
| -rw-r--r-- | ssh/server/publickeycallback.go | 61 | ||||
| -rw-r--r-- | ssh/ssh.go | 112 |
5 files changed, 540 insertions, 0 deletions
diff --git a/ssh/client/authmethods.go b/ssh/client/authmethods.go new file mode 100644 index 0000000..84b7ce3 --- /dev/null +++ b/ssh/client/authmethods.go @@ -0,0 +1,45 @@ +package client + +import ( + "dtail/config" + "dtail/logger" + "dtail/ssh" + "os" + + gossh "golang.org/x/crypto/ssh" +) + +// InitSSHAuthMethods initialises all known SSH auth methods on othe client side. +func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) { + var sshAuthMethods []gossh.AuthMethod + + if config.Common.ExperimentalFeaturesEnable { + sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) + logger.Info("Added experimental method to list of auth methods") + } + + keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + keyPath = os.Getenv("HOME") + "/.ssh/id_dsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + if authMethod, err := ssh.Agent(); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added SSH Agent to list of auth methods") + } + + knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" + hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh) + if err != nil { + logger.FatalExit(knownHostsPath, err) + } + + return sshAuthMethods, hostKeyCallback +} diff --git a/ssh/client/hostkeycallback.go b/ssh/client/hostkeycallback.go new file mode 100644 index 0000000..7279f5e --- /dev/null +++ b/ssh/client/hostkeycallback.go @@ -0,0 +1,285 @@ +package client + +import ( + "bufio" + "dtail/logger" + "dtail/prompt" + "fmt" + "net" + "os" + "strings" + "sync" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type response int + +const ( + trustHost response = iota + dontTrustHost response = iota +) + +// Represents an unknown host. +type unknownHost struct { + server string + remote net.Addr + key ssh.PublicKey + hostLine string + ipLine string + responseCh chan response +} + +// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type HostKeyCallback struct { + knownHostsPath string + unknownCh chan unknownHost + throttleCh chan struct{} + trustAllHostsCh chan struct{} + untrustedHosts map[string]bool + mutex sync.Mutex +} + +// NewHostKeyCallback returns a new wrapper. +func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) { + // Ensure file exists + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + + h := HostKeyCallback{ + knownHostsPath: knownHostsPath, + unknownCh: make(chan unknownHost), + trustAllHostsCh: make(chan struct{}), + throttleCh: throttleCh, + untrustedHosts: make(map[string]bool), + } + + if trustAllHosts { + close(h.trustAllHostsCh) + } + + return &h, nil +} + +// Wrap the host key callback. +func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + // Parse known_hosts file + knownHostsCb, err := knownhosts.New(h.knownHostsPath) + if err != nil { + // Problem parsing it + return err + } + + // Check for valid entry in known_hosts file + err = knownHostsCb(server, remote, key) + if err == nil { + // OK + return nil + } + + // Make sure that interactive user callback does not interfere with + // SSH connection throttler. + <-h.throttleCh + defer func() { h.throttleCh <- struct{}{} }() + + unknown := unknownHost{ + server: server, + remote: remote, + key: key, + hostLine: knownhosts.Line([]string{server}, key), + ipLine: knownhosts.Line([]string{remote.String()}, key), + responseCh: make(chan response), + } + + logger.Warn("Encountered unknown host", unknown) + // Notify user that there is an unknown host + h.unknownCh <- unknown + + // Wait for user input. + switch <-unknown.responseCh { + case trustHost: + // End user acknowledged host key + return nil + case dontTrustHost: + } + + h.mutex.Lock() + defer h.mutex.Unlock() + h.untrustedHosts[server] = true + + return err + } +} + +// PromptAddHosts prompts a question to the user whether unknown hosts should +// be added to the known hosts or not. +func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { + var hosts []unknownHost + + for { + // Check whether there is a unknown host + select { + case unknown := <-h.unknownCh: + hosts = append(hosts, unknown) + // Ask every 50 unknown hosts + if len(hosts) >= 50 { + h.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-time.After(2 * time.Second): + // Or ask when after 2 seconds no new unknown hosts were added. + if len(hosts) > 0 { + h.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-stop: + logger.Debug("Stopping goroutine prompting new hosts...") + return + } + } +} + +func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) { + var servers []string + + for _, host := range hosts { + servers = append(servers, host.server) + } + + select { + case <-h.trustAllHostsCh: + logger.Warn("Trusting host keys of servers", servers) + h.trustHosts(hosts) + return + default: + } + + question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s", + len(servers), + strings.Join(servers, ","), + "Do you want to trust these hosts?", + ) + + p := prompt.New(question) + + a := prompt.Answer{ + Long: "yes", + Short: "y", + Callback: func() { + h.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "all", + Short: "a", + Callback: func() { + close(h.trustAllHostsCh) + h.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "no", + Short: "n", + Callback: func() { + h.dontTrustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "details", + Short: "d", + AskAgain: true, + Callback: func() { + for _, unknown := range hosts { + fmt.Println(unknown.hostLine) + fmt.Println(unknown.ipLine) + } + }, + } + p.Add(a) + + p.Ask() +} + +func (h *HostKeyCallback) trustHosts(hosts []unknownHost) { + tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath) + + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) + } + defer newFd.Close() + + // Newly trusted hosts in normalized form + addresses := make(map[string]struct{}) + + // First write to new known hosts file, and keep track of addresses + for _, unknown := range hosts { + unknown.responseCh <- trustHost + + // Add once as [HOSTNAME]:PORT + addresses[knownhosts.Normalize(unknown.server)] = struct{}{} + // And once as [IP]:PORT + addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} + + newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) + newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) + } + + // Read old known hosts file, to see which are old and new entries + os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + oldFd, err := os.Open(h.knownHostsPath) + if err != nil { + panic(err) + } + defer oldFd.Close() + + scanner := bufio.NewScanner(oldFd) + + // Now, append all still valid old entries to the new host file + for scanner.Scan() { + line := scanner.Text() + address := strings.SplitN(line, " ", 2)[0] + + if _, ok := addresses[address]; !ok { + newFd.WriteString(fmt.Sprintf("%s\n", line)) + } + } + + // Now, replace old known hosts file + if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil { + panic(err) + } +} + +func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) { + for _, unknown := range hosts { + unknown.responseCh <- dontTrustHost + } +} + +// Untrusted returns true if the host is not trusted. False otherwise. +func (h *HostKeyCallback) Untrusted(server string) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + _, ok := h.untrustedHosts[server] + + return ok +} diff --git a/ssh/server/hostkey.go b/ssh/server/hostkey.go new file mode 100644 index 0000000..ff1eb82 --- /dev/null +++ b/ssh/server/hostkey.go @@ -0,0 +1,37 @@ +package server + +import ( + "dtail/config" + "dtail/logger" + "dtail/ssh" + "io/ioutil" + "os" +) + +// PrivateHostKey retrieves the private server RSA host key. +func PrivateHostKey() []byte { + hostKeyFile := config.Server.HostKeyFile + _, err := os.Stat(hostKeyFile) + + if os.IsNotExist(err) { + logger.Info("Generating private server RSA host key") + privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits) + + if err != nil { + logger.FatalExit("Failed to generate private server RSA host key", err) + } + + pem := ssh.EncodePrivateKeyToPEM(privateKey) + if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil { + logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err) + } + return pem + } + + logger.Info("Reading private server RSA host key from file", hostKeyFile) + pem, err := ioutil.ReadFile(hostKeyFile) + if err != nil { + logger.FatalExit("Failed to load private server RSA host key", err) + } + return pem +} diff --git a/ssh/server/publickeycallback.go b/ssh/server/publickeycallback.go new file mode 100644 index 0000000..867f639 --- /dev/null +++ b/ssh/server/publickeycallback.go @@ -0,0 +1,61 @@ +package server + +import ( + "dtail/config" + "dtail/logger" + "dtail/server/user" + "fmt" + "io/ioutil" + "os" + osUser "os/user" + + gossh "golang.org/x/crypto/ssh" +) + +// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not. +func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) { + user := user.New(c.User(), c.RemoteAddr().String()) + logger.Info(user, "Incoming authorization") + + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error()) + } + + authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name) + if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) { + user, err := osUser.Lookup(user.Name) + if err != nil { + return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error()) + } + // Fallback to ~ + authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys" + } + + logger.Info(user, "Reading", authorizedKeysFile) + authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile) + if err != nil { + return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error()) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error()) + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + if authorizedKeysMap[string(pubKey.Marshal())] { + logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user) + return &gossh.Permissions{ + Extensions: map[string]string{ + "pubkey-fp": gossh.FingerprintSHA256(pubKey), + }, + }, nil + } + + return nil, fmt.Errorf("Unknown public key|%s", user) +} diff --git a/ssh/ssh.go b/ssh/ssh.go new file mode 100644 index 0000000..6cd28a2 --- /dev/null +++ b/ssh/ssh.go @@ -0,0 +1,112 @@ +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "dtail/logger" + "encoding/pem" + "fmt" + "io/ioutil" + "net" + "os" + "syscall" + + gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/terminal" +) + +// GeneratePrivateRSAKey is used by the server to generate its key. +func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, size) + if err != nil { + return nil, err + } + + err = privateKey.Validate() + if err != nil { + return nil, err + } + + return privateKey, nil +} + +// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format. +func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { + derFormat := x509.MarshalPKCS1PrivateKey(privateKey) + + block := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: derFormat, + } + + return pem.EncodeToMemory(&block) +} + +// Agent used for SSH auth. +func Agent() (gossh.AuthMethod, error) { + sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err != nil { + return nil, err + } + agentClient := agent.NewClient(sshAgent) + keys, err := agentClient.List() + if err != nil { + return nil, err + } + for i, key := range keys { + logger.Debug("Public key", i, key) + } + return gossh.PublicKeysCallback(agentClient.Signers), nil +} + +// EnterKeyPhrase is required to read phrase protected private keys. +func EnterKeyPhrase(keyFile string) []byte { + fmt.Printf("Enter phrase for key %s: ", keyFile) + phrase, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + panic(err) + } + fmt.Printf("%s\n", string(phrase)) + return phrase +} + +// KeyFile returns the key as a SSH auth method. +func KeyFile(keyFile string) (gossh.AuthMethod, error) { + buffer, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, err + } + + key, err := gossh.ParsePrivateKey(buffer) + if err != nil { + return nil, err + } + + // Key phrase support disabled as password will be printed to stdout! + /* + if err == nil { + return gossh.PublicKeys(key), nil + } + + keyPhrase := EnterKeyPhrase(keyFile) + key, err = gossh.ParsePrivateKeyWithPassphrase(buffer, keyPhrase) + if err != nil { + return nil, err + } + */ + + return gossh.PublicKeys(key), nil +} + +// PrivateKey returns the private key as a SSH auth method. +func PrivateKey(keyFile string) (gossh.AuthMethod, error) { + signer, err := KeyFile(keyFile) + if err != nil { + logger.Debug(keyFile, err) + return nil, err + } + return gossh.AuthMethod(signer), nil +} |
