summaryrefslogtreecommitdiff
path: root/ssh
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-09 20:30:15 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-01-09 20:30:15 +0000
commit3755a9911ecb05886577095f2b8cc8b9e4066a3a (patch)
tree86e24bc466986cb5c9c6d167a918e6064defeafc /ssh
Release of DTail v1.0.0v1.0.0
Diffstat (limited to 'ssh')
-rw-r--r--ssh/client/authmethods.go45
-rw-r--r--ssh/client/hostkeycallback.go285
-rw-r--r--ssh/server/hostkey.go37
-rw-r--r--ssh/server/publickeycallback.go61
-rw-r--r--ssh/ssh.go112
5 files changed, 540 insertions, 0 deletions
diff --git a/ssh/client/authmethods.go b/ssh/client/authmethods.go
new file mode 100644
index 0000000..84b7ce3
--- /dev/null
+++ b/ssh/client/authmethods.go
@@ -0,0 +1,45 @@
+package client
+
+import (
+ "dtail/config"
+ "dtail/logger"
+ "dtail/ssh"
+ "os"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// InitSSHAuthMethods initialises all known SSH auth methods on othe client side.
+func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) {
+ var sshAuthMethods []gossh.AuthMethod
+
+ if config.Common.ExperimentalFeaturesEnable {
+ sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
+ logger.Info("Added experimental method to list of auth methods")
+ }
+
+ keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
+ if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added path to list of auth methods", keyPath)
+ }
+
+ keyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
+ if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added path to list of auth methods", keyPath)
+ }
+
+ if authMethod, err := ssh.Agent(); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added SSH Agent to list of auth methods")
+ }
+
+ knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
+ hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh)
+ if err != nil {
+ logger.FatalExit(knownHostsPath, err)
+ }
+
+ return sshAuthMethods, hostKeyCallback
+}
diff --git a/ssh/client/hostkeycallback.go b/ssh/client/hostkeycallback.go
new file mode 100644
index 0000000..7279f5e
--- /dev/null
+++ b/ssh/client/hostkeycallback.go
@@ -0,0 +1,285 @@
+package client
+
+import (
+ "bufio"
+ "dtail/logger"
+ "dtail/prompt"
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
+)
+
+type response int
+
+const (
+ trustHost response = iota
+ dontTrustHost response = iota
+)
+
+// Represents an unknown host.
+type unknownHost struct {
+ server string
+ remote net.Addr
+ key ssh.PublicKey
+ hostLine string
+ ipLine string
+ responseCh chan response
+}
+
+// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all
+// unknown hosts in a single batch to the known_hosts file.
+type HostKeyCallback struct {
+ knownHostsPath string
+ unknownCh chan unknownHost
+ throttleCh chan struct{}
+ trustAllHostsCh chan struct{}
+ untrustedHosts map[string]bool
+ mutex sync.Mutex
+}
+
+// NewHostKeyCallback returns a new wrapper.
+func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) {
+ // Ensure file exists
+ os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+
+ h := HostKeyCallback{
+ knownHostsPath: knownHostsPath,
+ unknownCh: make(chan unknownHost),
+ trustAllHostsCh: make(chan struct{}),
+ throttleCh: throttleCh,
+ untrustedHosts: make(map[string]bool),
+ }
+
+ if trustAllHosts {
+ close(h.trustAllHostsCh)
+ }
+
+ return &h, nil
+}
+
+// Wrap the host key callback.
+func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback {
+ return func(server string, remote net.Addr, key ssh.PublicKey) error {
+ // Parse known_hosts file
+ knownHostsCb, err := knownhosts.New(h.knownHostsPath)
+ if err != nil {
+ // Problem parsing it
+ return err
+ }
+
+ // Check for valid entry in known_hosts file
+ err = knownHostsCb(server, remote, key)
+ if err == nil {
+ // OK
+ return nil
+ }
+
+ // Make sure that interactive user callback does not interfere with
+ // SSH connection throttler.
+ <-h.throttleCh
+ defer func() { h.throttleCh <- struct{}{} }()
+
+ unknown := unknownHost{
+ server: server,
+ remote: remote,
+ key: key,
+ hostLine: knownhosts.Line([]string{server}, key),
+ ipLine: knownhosts.Line([]string{remote.String()}, key),
+ responseCh: make(chan response),
+ }
+
+ logger.Warn("Encountered unknown host", unknown)
+ // Notify user that there is an unknown host
+ h.unknownCh <- unknown
+
+ // Wait for user input.
+ switch <-unknown.responseCh {
+ case trustHost:
+ // End user acknowledged host key
+ return nil
+ case dontTrustHost:
+ }
+
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.untrustedHosts[server] = true
+
+ return err
+ }
+}
+
+// PromptAddHosts prompts a question to the user whether unknown hosts should
+// be added to the known hosts or not.
+func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
+ var hosts []unknownHost
+
+ for {
+ // Check whether there is a unknown host
+ select {
+ case unknown := <-h.unknownCh:
+ hosts = append(hosts, unknown)
+ // Ask every 50 unknown hosts
+ if len(hosts) >= 50 {
+ h.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-time.After(2 * time.Second):
+ // Or ask when after 2 seconds no new unknown hosts were added.
+ if len(hosts) > 0 {
+ h.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-stop:
+ logger.Debug("Stopping goroutine prompting new hosts...")
+ return
+ }
+ }
+}
+
+func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) {
+ var servers []string
+
+ for _, host := range hosts {
+ servers = append(servers, host.server)
+ }
+
+ select {
+ case <-h.trustAllHostsCh:
+ logger.Warn("Trusting host keys of servers", servers)
+ h.trustHosts(hosts)
+ return
+ default:
+ }
+
+ question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s",
+ len(servers),
+ strings.Join(servers, ","),
+ "Do you want to trust these hosts?",
+ )
+
+ p := prompt.New(question)
+
+ a := prompt.Answer{
+ Long: "yes",
+ Short: "y",
+ Callback: func() {
+ h.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "all",
+ Short: "a",
+ Callback: func() {
+ close(h.trustAllHostsCh)
+ h.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "no",
+ Short: "n",
+ Callback: func() {
+ h.dontTrustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "details",
+ Short: "d",
+ AskAgain: true,
+ Callback: func() {
+ for _, unknown := range hosts {
+ fmt.Println(unknown.hostLine)
+ fmt.Println(unknown.ipLine)
+ }
+ },
+ }
+ p.Add(a)
+
+ p.Ask()
+}
+
+func (h *HostKeyCallback) trustHosts(hosts []unknownHost) {
+ tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath)
+
+ newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ if err != nil {
+ panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
+ }
+ defer newFd.Close()
+
+ // Newly trusted hosts in normalized form
+ addresses := make(map[string]struct{})
+
+ // First write to new known hosts file, and keep track of addresses
+ for _, unknown := range hosts {
+ unknown.responseCh <- trustHost
+
+ // Add once as [HOSTNAME]:PORT
+ addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
+ // And once as [IP]:PORT
+ addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
+
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine))
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine))
+ }
+
+ // Read old known hosts file, to see which are old and new entries
+ os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ oldFd, err := os.Open(h.knownHostsPath)
+ if err != nil {
+ panic(err)
+ }
+ defer oldFd.Close()
+
+ scanner := bufio.NewScanner(oldFd)
+
+ // Now, append all still valid old entries to the new host file
+ for scanner.Scan() {
+ line := scanner.Text()
+ address := strings.SplitN(line, " ", 2)[0]
+
+ if _, ok := addresses[address]; !ok {
+ newFd.WriteString(fmt.Sprintf("%s\n", line))
+ }
+ }
+
+ // Now, replace old known hosts file
+ if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil {
+ panic(err)
+ }
+}
+
+func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) {
+ for _, unknown := range hosts {
+ unknown.responseCh <- dontTrustHost
+ }
+}
+
+// Untrusted returns true if the host is not trusted. False otherwise.
+func (h *HostKeyCallback) Untrusted(server string) bool {
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ _, ok := h.untrustedHosts[server]
+
+ return ok
+}
diff --git a/ssh/server/hostkey.go b/ssh/server/hostkey.go
new file mode 100644
index 0000000..ff1eb82
--- /dev/null
+++ b/ssh/server/hostkey.go
@@ -0,0 +1,37 @@
+package server
+
+import (
+ "dtail/config"
+ "dtail/logger"
+ "dtail/ssh"
+ "io/ioutil"
+ "os"
+)
+
+// PrivateHostKey retrieves the private server RSA host key.
+func PrivateHostKey() []byte {
+ hostKeyFile := config.Server.HostKeyFile
+ _, err := os.Stat(hostKeyFile)
+
+ if os.IsNotExist(err) {
+ logger.Info("Generating private server RSA host key")
+ privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits)
+
+ if err != nil {
+ logger.FatalExit("Failed to generate private server RSA host key", err)
+ }
+
+ pem := ssh.EncodePrivateKeyToPEM(privateKey)
+ if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil {
+ logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err)
+ }
+ return pem
+ }
+
+ logger.Info("Reading private server RSA host key from file", hostKeyFile)
+ pem, err := ioutil.ReadFile(hostKeyFile)
+ if err != nil {
+ logger.FatalExit("Failed to load private server RSA host key", err)
+ }
+ return pem
+}
diff --git a/ssh/server/publickeycallback.go b/ssh/server/publickeycallback.go
new file mode 100644
index 0000000..867f639
--- /dev/null
+++ b/ssh/server/publickeycallback.go
@@ -0,0 +1,61 @@
+package server
+
+import (
+ "dtail/config"
+ "dtail/logger"
+ "dtail/server/user"
+ "fmt"
+ "io/ioutil"
+ "os"
+ osUser "os/user"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not.
+func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ user := user.New(c.User(), c.RemoteAddr().String())
+ logger.Info(user, "Incoming authorization")
+
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error())
+ }
+
+ authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name)
+ if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) {
+ user, err := osUser.Lookup(user.Name)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error())
+ }
+ // Fallback to ~
+ authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys"
+ }
+
+ logger.Info(user, "Reading", authorizedKeysFile)
+ authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error())
+ }
+
+ authorizedKeysMap := map[string]bool{}
+ for len(authorizedKeysBytes) > 0 {
+ pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error())
+ }
+ authorizedKeysMap[string(pubKey.Marshal())] = true
+ authorizedKeysBytes = rest
+ }
+
+ if authorizedKeysMap[string(pubKey.Marshal())] {
+ logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user)
+ return &gossh.Permissions{
+ Extensions: map[string]string{
+ "pubkey-fp": gossh.FingerprintSHA256(pubKey),
+ },
+ }, nil
+ }
+
+ return nil, fmt.Errorf("Unknown public key|%s", user)
+}
diff --git a/ssh/ssh.go b/ssh/ssh.go
new file mode 100644
index 0000000..6cd28a2
--- /dev/null
+++ b/ssh/ssh.go
@@ -0,0 +1,112 @@
+package ssh
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "dtail/logger"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "syscall"
+
+ gossh "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/agent"
+ "golang.org/x/crypto/ssh/terminal"
+)
+
+// GeneratePrivateRSAKey is used by the server to generate its key.
+func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, size)
+ if err != nil {
+ return nil, err
+ }
+
+ err = privateKey.Validate()
+ if err != nil {
+ return nil, err
+ }
+
+ return privateKey, nil
+}
+
+// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format.
+func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
+ derFormat := x509.MarshalPKCS1PrivateKey(privateKey)
+
+ block := pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Headers: nil,
+ Bytes: derFormat,
+ }
+
+ return pem.EncodeToMemory(&block)
+}
+
+// Agent used for SSH auth.
+func Agent() (gossh.AuthMethod, error) {
+ sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
+ if err != nil {
+ return nil, err
+ }
+ agentClient := agent.NewClient(sshAgent)
+ keys, err := agentClient.List()
+ if err != nil {
+ return nil, err
+ }
+ for i, key := range keys {
+ logger.Debug("Public key", i, key)
+ }
+ return gossh.PublicKeysCallback(agentClient.Signers), nil
+}
+
+// EnterKeyPhrase is required to read phrase protected private keys.
+func EnterKeyPhrase(keyFile string) []byte {
+ fmt.Printf("Enter phrase for key %s: ", keyFile)
+ phrase, err := terminal.ReadPassword(int(syscall.Stdin))
+ if err != nil {
+ panic(err)
+ }
+ fmt.Printf("%s\n", string(phrase))
+ return phrase
+}
+
+// KeyFile returns the key as a SSH auth method.
+func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+ buffer, err := ioutil.ReadFile(keyFile)
+ if err != nil {
+ return nil, err
+ }
+
+ key, err := gossh.ParsePrivateKey(buffer)
+ if err != nil {
+ return nil, err
+ }
+
+ // Key phrase support disabled as password will be printed to stdout!
+ /*
+ if err == nil {
+ return gossh.PublicKeys(key), nil
+ }
+
+ keyPhrase := EnterKeyPhrase(keyFile)
+ key, err = gossh.ParsePrivateKeyWithPassphrase(buffer, keyPhrase)
+ if err != nil {
+ return nil, err
+ }
+ */
+
+ return gossh.PublicKeys(key), nil
+}
+
+// PrivateKey returns the private key as a SSH auth method.
+func PrivateKey(keyFile string) (gossh.AuthMethod, error) {
+ signer, err := KeyFile(keyFile)
+ if err != nil {
+ logger.Debug(keyFile, err)
+ return nil, err
+ }
+ return gossh.AuthMethod(signer), nil
+}