summaryrefslogtreecommitdiff
path: root/ssh
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-20 18:41:05 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-01-21 14:35:23 +0000
commitc128865c4c7411c29a59fca9a3a2f95537686d7b (patch)
tree193bccc70d942c8b70cc93fae2670263701e43aa /ssh
parent3755a9911ecb05886577095f2b8cc8b9e4066a3a (diff)
Move commands to cmd/ and move internal dependencies to internal/
Diffstat (limited to 'ssh')
-rw-r--r--ssh/client/authmethods.go45
-rw-r--r--ssh/client/hostkeycallback.go285
-rw-r--r--ssh/server/hostkey.go37
-rw-r--r--ssh/server/publickeycallback.go61
-rw-r--r--ssh/ssh.go112
5 files changed, 0 insertions, 540 deletions
diff --git a/ssh/client/authmethods.go b/ssh/client/authmethods.go
deleted file mode 100644
index 84b7ce3..0000000
--- a/ssh/client/authmethods.go
+++ /dev/null
@@ -1,45 +0,0 @@
-package client
-
-import (
- "dtail/config"
- "dtail/logger"
- "dtail/ssh"
- "os"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-// InitSSHAuthMethods initialises all known SSH auth methods on othe client side.
-func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) {
- var sshAuthMethods []gossh.AuthMethod
-
- if config.Common.ExperimentalFeaturesEnable {
- sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
- logger.Info("Added experimental method to list of auth methods")
- }
-
- keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
- if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Info("Added path to list of auth methods", keyPath)
- }
-
- keyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
- if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Info("Added path to list of auth methods", keyPath)
- }
-
- if authMethod, err := ssh.Agent(); err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Info("Added SSH Agent to list of auth methods")
- }
-
- knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
- hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh)
- if err != nil {
- logger.FatalExit(knownHostsPath, err)
- }
-
- return sshAuthMethods, hostKeyCallback
-}
diff --git a/ssh/client/hostkeycallback.go b/ssh/client/hostkeycallback.go
deleted file mode 100644
index 7279f5e..0000000
--- a/ssh/client/hostkeycallback.go
+++ /dev/null
@@ -1,285 +0,0 @@
-package client
-
-import (
- "bufio"
- "dtail/logger"
- "dtail/prompt"
- "fmt"
- "net"
- "os"
- "strings"
- "sync"
- "time"
-
- "golang.org/x/crypto/ssh"
- "golang.org/x/crypto/ssh/knownhosts"
-)
-
-type response int
-
-const (
- trustHost response = iota
- dontTrustHost response = iota
-)
-
-// Represents an unknown host.
-type unknownHost struct {
- server string
- remote net.Addr
- key ssh.PublicKey
- hostLine string
- ipLine string
- responseCh chan response
-}
-
-// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all
-// unknown hosts in a single batch to the known_hosts file.
-type HostKeyCallback struct {
- knownHostsPath string
- unknownCh chan unknownHost
- throttleCh chan struct{}
- trustAllHostsCh chan struct{}
- untrustedHosts map[string]bool
- mutex sync.Mutex
-}
-
-// NewHostKeyCallback returns a new wrapper.
-func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) {
- // Ensure file exists
- os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
-
- h := HostKeyCallback{
- knownHostsPath: knownHostsPath,
- unknownCh: make(chan unknownHost),
- trustAllHostsCh: make(chan struct{}),
- throttleCh: throttleCh,
- untrustedHosts: make(map[string]bool),
- }
-
- if trustAllHosts {
- close(h.trustAllHostsCh)
- }
-
- return &h, nil
-}
-
-// Wrap the host key callback.
-func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback {
- return func(server string, remote net.Addr, key ssh.PublicKey) error {
- // Parse known_hosts file
- knownHostsCb, err := knownhosts.New(h.knownHostsPath)
- if err != nil {
- // Problem parsing it
- return err
- }
-
- // Check for valid entry in known_hosts file
- err = knownHostsCb(server, remote, key)
- if err == nil {
- // OK
- return nil
- }
-
- // Make sure that interactive user callback does not interfere with
- // SSH connection throttler.
- <-h.throttleCh
- defer func() { h.throttleCh <- struct{}{} }()
-
- unknown := unknownHost{
- server: server,
- remote: remote,
- key: key,
- hostLine: knownhosts.Line([]string{server}, key),
- ipLine: knownhosts.Line([]string{remote.String()}, key),
- responseCh: make(chan response),
- }
-
- logger.Warn("Encountered unknown host", unknown)
- // Notify user that there is an unknown host
- h.unknownCh <- unknown
-
- // Wait for user input.
- switch <-unknown.responseCh {
- case trustHost:
- // End user acknowledged host key
- return nil
- case dontTrustHost:
- }
-
- h.mutex.Lock()
- defer h.mutex.Unlock()
- h.untrustedHosts[server] = true
-
- return err
- }
-}
-
-// PromptAddHosts prompts a question to the user whether unknown hosts should
-// be added to the known hosts or not.
-func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
- var hosts []unknownHost
-
- for {
- // Check whether there is a unknown host
- select {
- case unknown := <-h.unknownCh:
- hosts = append(hosts, unknown)
- // Ask every 50 unknown hosts
- if len(hosts) >= 50 {
- h.promptAddHosts(hosts)
- hosts = []unknownHost{}
- }
- case <-time.After(2 * time.Second):
- // Or ask when after 2 seconds no new unknown hosts were added.
- if len(hosts) > 0 {
- h.promptAddHosts(hosts)
- hosts = []unknownHost{}
- }
- case <-stop:
- logger.Debug("Stopping goroutine prompting new hosts...")
- return
- }
- }
-}
-
-func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) {
- var servers []string
-
- for _, host := range hosts {
- servers = append(servers, host.server)
- }
-
- select {
- case <-h.trustAllHostsCh:
- logger.Warn("Trusting host keys of servers", servers)
- h.trustHosts(hosts)
- return
- default:
- }
-
- question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s",
- len(servers),
- strings.Join(servers, ","),
- "Do you want to trust these hosts?",
- )
-
- p := prompt.New(question)
-
- a := prompt.Answer{
- Long: "yes",
- Short: "y",
- Callback: func() {
- h.trustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Added hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "all",
- Short: "a",
- Callback: func() {
- close(h.trustAllHostsCh)
- h.trustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Added hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "no",
- Short: "n",
- Callback: func() {
- h.dontTrustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "details",
- Short: "d",
- AskAgain: true,
- Callback: func() {
- for _, unknown := range hosts {
- fmt.Println(unknown.hostLine)
- fmt.Println(unknown.ipLine)
- }
- },
- }
- p.Add(a)
-
- p.Ask()
-}
-
-func (h *HostKeyCallback) trustHosts(hosts []unknownHost) {
- tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath)
-
- newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
- if err != nil {
- panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
- }
- defer newFd.Close()
-
- // Newly trusted hosts in normalized form
- addresses := make(map[string]struct{})
-
- // First write to new known hosts file, and keep track of addresses
- for _, unknown := range hosts {
- unknown.responseCh <- trustHost
-
- // Add once as [HOSTNAME]:PORT
- addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
- // And once as [IP]:PORT
- addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
-
- newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine))
- newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine))
- }
-
- // Read old known hosts file, to see which are old and new entries
- os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
- oldFd, err := os.Open(h.knownHostsPath)
- if err != nil {
- panic(err)
- }
- defer oldFd.Close()
-
- scanner := bufio.NewScanner(oldFd)
-
- // Now, append all still valid old entries to the new host file
- for scanner.Scan() {
- line := scanner.Text()
- address := strings.SplitN(line, " ", 2)[0]
-
- if _, ok := addresses[address]; !ok {
- newFd.WriteString(fmt.Sprintf("%s\n", line))
- }
- }
-
- // Now, replace old known hosts file
- if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil {
- panic(err)
- }
-}
-
-func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) {
- for _, unknown := range hosts {
- unknown.responseCh <- dontTrustHost
- }
-}
-
-// Untrusted returns true if the host is not trusted. False otherwise.
-func (h *HostKeyCallback) Untrusted(server string) bool {
- h.mutex.Lock()
- defer h.mutex.Unlock()
- _, ok := h.untrustedHosts[server]
-
- return ok
-}
diff --git a/ssh/server/hostkey.go b/ssh/server/hostkey.go
deleted file mode 100644
index ff1eb82..0000000
--- a/ssh/server/hostkey.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package server
-
-import (
- "dtail/config"
- "dtail/logger"
- "dtail/ssh"
- "io/ioutil"
- "os"
-)
-
-// PrivateHostKey retrieves the private server RSA host key.
-func PrivateHostKey() []byte {
- hostKeyFile := config.Server.HostKeyFile
- _, err := os.Stat(hostKeyFile)
-
- if os.IsNotExist(err) {
- logger.Info("Generating private server RSA host key")
- privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits)
-
- if err != nil {
- logger.FatalExit("Failed to generate private server RSA host key", err)
- }
-
- pem := ssh.EncodePrivateKeyToPEM(privateKey)
- if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil {
- logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err)
- }
- return pem
- }
-
- logger.Info("Reading private server RSA host key from file", hostKeyFile)
- pem, err := ioutil.ReadFile(hostKeyFile)
- if err != nil {
- logger.FatalExit("Failed to load private server RSA host key", err)
- }
- return pem
-}
diff --git a/ssh/server/publickeycallback.go b/ssh/server/publickeycallback.go
deleted file mode 100644
index 867f639..0000000
--- a/ssh/server/publickeycallback.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package server
-
-import (
- "dtail/config"
- "dtail/logger"
- "dtail/server/user"
- "fmt"
- "io/ioutil"
- "os"
- osUser "os/user"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not.
-func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) {
- user := user.New(c.User(), c.RemoteAddr().String())
- logger.Info(user, "Incoming authorization")
-
- cwd, err := os.Getwd()
- if err != nil {
- return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error())
- }
-
- authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name)
- if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) {
- user, err := osUser.Lookup(user.Name)
- if err != nil {
- return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error())
- }
- // Fallback to ~
- authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys"
- }
-
- logger.Info(user, "Reading", authorizedKeysFile)
- authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile)
- if err != nil {
- return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error())
- }
-
- authorizedKeysMap := map[string]bool{}
- for len(authorizedKeysBytes) > 0 {
- pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
- if err != nil {
- return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error())
- }
- authorizedKeysMap[string(pubKey.Marshal())] = true
- authorizedKeysBytes = rest
- }
-
- if authorizedKeysMap[string(pubKey.Marshal())] {
- logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user)
- return &gossh.Permissions{
- Extensions: map[string]string{
- "pubkey-fp": gossh.FingerprintSHA256(pubKey),
- },
- }, nil
- }
-
- return nil, fmt.Errorf("Unknown public key|%s", user)
-}
diff --git a/ssh/ssh.go b/ssh/ssh.go
deleted file mode 100644
index 6cd28a2..0000000
--- a/ssh/ssh.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package ssh
-
-import (
- "crypto/rand"
- "crypto/rsa"
- "crypto/x509"
- "dtail/logger"
- "encoding/pem"
- "fmt"
- "io/ioutil"
- "net"
- "os"
- "syscall"
-
- gossh "golang.org/x/crypto/ssh"
- "golang.org/x/crypto/ssh/agent"
- "golang.org/x/crypto/ssh/terminal"
-)
-
-// GeneratePrivateRSAKey is used by the server to generate its key.
-func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
- privateKey, err := rsa.GenerateKey(rand.Reader, size)
- if err != nil {
- return nil, err
- }
-
- err = privateKey.Validate()
- if err != nil {
- return nil, err
- }
-
- return privateKey, nil
-}
-
-// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format.
-func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
- derFormat := x509.MarshalPKCS1PrivateKey(privateKey)
-
- block := pem.Block{
- Type: "RSA PRIVATE KEY",
- Headers: nil,
- Bytes: derFormat,
- }
-
- return pem.EncodeToMemory(&block)
-}
-
-// Agent used for SSH auth.
-func Agent() (gossh.AuthMethod, error) {
- sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
- if err != nil {
- return nil, err
- }
- agentClient := agent.NewClient(sshAgent)
- keys, err := agentClient.List()
- if err != nil {
- return nil, err
- }
- for i, key := range keys {
- logger.Debug("Public key", i, key)
- }
- return gossh.PublicKeysCallback(agentClient.Signers), nil
-}
-
-// EnterKeyPhrase is required to read phrase protected private keys.
-func EnterKeyPhrase(keyFile string) []byte {
- fmt.Printf("Enter phrase for key %s: ", keyFile)
- phrase, err := terminal.ReadPassword(int(syscall.Stdin))
- if err != nil {
- panic(err)
- }
- fmt.Printf("%s\n", string(phrase))
- return phrase
-}
-
-// KeyFile returns the key as a SSH auth method.
-func KeyFile(keyFile string) (gossh.AuthMethod, error) {
- buffer, err := ioutil.ReadFile(keyFile)
- if err != nil {
- return nil, err
- }
-
- key, err := gossh.ParsePrivateKey(buffer)
- if err != nil {
- return nil, err
- }
-
- // Key phrase support disabled as password will be printed to stdout!
- /*
- if err == nil {
- return gossh.PublicKeys(key), nil
- }
-
- keyPhrase := EnterKeyPhrase(keyFile)
- key, err = gossh.ParsePrivateKeyWithPassphrase(buffer, keyPhrase)
- if err != nil {
- return nil, err
- }
- */
-
- return gossh.PublicKeys(key), nil
-}
-
-// PrivateKey returns the private key as a SSH auth method.
-func PrivateKey(keyFile string) (gossh.AuthMethod, error) {
- signer, err := KeyFile(keyFile)
- if err != nil {
- logger.Debug(keyFile, err)
- return nil, err
- }
- return gossh.AuthMethod(signer), nil
-}