summaryrefslogtreecommitdiff
path: root/server/server.go
blob: 4637458690df53efb31246b0f24e2f8e4888e2d4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
package server

import (
	"dtail/config"
	"dtail/logger"
	"dtail/server/handlers"
	"dtail/server/user"
	"dtail/ssh/server"
	"dtail/version"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"

	gossh "golang.org/x/crypto/ssh"
)

// Server is the main server data structure. It owns the SSH server
// configuration and the shared state handed to every connection it serves.
type Server struct {
	// stats collects the server statistics counters.
	stats stats
	// sshServerConfig holds the SSH authentication callbacks and host key.
	sshServerConfig *gossh.ServerConfig
	// catLimiterCh is passed to every server handler to cap the number of
	// concurrent cats, which can cause a lot of I/O on the server.
	catLimiterCh chan struct{}
	// tailLimiterCh is passed to every server handler to cap the number of
	// concurrent tails.
	tailLimiterCh chan struct{}
	// stop is closed to ask the server and its connections to shut down.
	stop chan struct{}
}

// New creates a Server ready to be started.
//
// The cat and tail limiter channels get their capacities from
// config.Server.MaxConcurrentCats and config.Server.MaxConcurrentTails.
// Password logins are checked by controlUserCallback and public-key logins by
// server.PublicKeyCallback. The host key comes from server.PrivateHostKey();
// if it cannot be parsed, logger.FatalExit is called.
func New() *Server {
	logger.Info("Creating server", version.String())

	srv := &Server{
		sshServerConfig: &gossh.ServerConfig{},
		catLimiterCh:    make(chan struct{}, config.Server.MaxConcurrentCats),
		tailLimiterCh:   make(chan struct{}, config.Server.MaxConcurrentTails),
		stop:            make(chan struct{}),
	}

	sshCfg := srv.sshServerConfig
	sshCfg.PasswordCallback = srv.controlUserCallback
	sshCfg.PublicKeyCallback = server.PublicKeyCallback

	hostKey, err := gossh.ParsePrivateKey(server.PrivateHostKey())
	if err != nil {
		logger.FatalExit(err)
	}
	sshCfg.AddHostKey(hostKey)

	return srv
}

// Start binds the SSH listener and serves incoming connections until Stop is
// called, then returns 0. The listen address is
// config.Server.SSHBindAddress:config.Common.SSHPort.
//
// wg.Done is called when Start returns. Connections are rejected while
// stats.serverLimitExceeded reports an error.
//
// Accept errors are logged and the loop keeps going, except after a shutdown
// was requested: then the listener has been closed and Start returns.
// If the listening socket cannot be opened, logger.FatalExit is called.
func (s *Server) Start(wg *sync.WaitGroup) int {
	defer wg.Done()
	logger.Info("Starting server")

	// JoinHostPort brackets IPv6 literals (e.g. "::" -> "[::]:port"), which a
	// plain "%s:%d" format would turn into an unparsable address.
	bindAt := net.JoinHostPort(config.Server.SSHBindAddress, fmt.Sprint(config.Common.SSHPort))
	logger.Info("Binding server", bindAt)
	listener, err := net.Listen("tcp", bindAt)
	if err != nil {
		logger.FatalExit("Failed to open listening TCP socket", err)
		// Only reached if FatalExit does not terminate the process; never
		// continue with a nil listener.
		return 1
	}

	go s.stats.periodicLogServerStats(s.stop)

	// Closing the listener is the only way to unblock a pending Accept, so do
	// it as soon as a shutdown is requested.
	go func() {
		<-s.stop
		listener.Close()
	}()

	for {
		conn, err := listener.Accept() // Blocking
		if err != nil {
			select {
			case <-s.stop:
				// The error comes from the listener being closed on shutdown.
				logger.Info("Stopped accepting connections")
				return 0
			default:
			}
			logger.Error("Failed to accept incoming connection", err)
			continue
		}

		if err := s.stats.serverLimitExceeded(); err != nil {
			logger.Error(err)
			conn.Close()
			continue
		}

		go s.handleConnection(conn)
	}
}

// handleConnection performs the SSH handshake on an accepted TCP connection.
// Global requests are discarded, and every channel the client opens is passed
// to handleChannel in its own goroutine. It returns once the ssh package
// closes chans.
//
// The stats connection counter is maintained in handleRequests, next to its
// matching decrement, so every increment is paired with exactly one decrement.
func (s *Server) handleConnection(conn net.Conn) {
	logger.Info("Handling connection")

	sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
	if err != nil {
		logger.Error("Something just happened", err)
		return
	}

	go gossh.DiscardRequests(reqs)
	for newChannel := range chans {
		go s.handleChannel(sshConn, newChannel)
	}
}

// handleChannel accepts "session" channels and serves their requests through
// handleRequests. Any other channel type is rejected with gossh.Prohibited.
// If handleRequests returns an error, the whole SSH connection is closed.
func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
	user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
	logger.Info(user, "Invoking channel handler")

	if newChannel.ChannelType() != "session" {
		err := errors.New("Don't allow other channel types than session")
		logger.Error(user, err)
		newChannel.Reject(gossh.Prohibited, err.Error())
		return
	}

	channel, requests, err := newChannel.Accept()
	if err != nil {
		logger.Error(user, "Could not accept channel", err)
		return
	}

	if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
		logger.Error(user, err)
		sshConn.Close()
	}
}

// handleRequests serves the requests sent on one session channel.
//
// The first "shell" request starts a handler wired bi-directionally to the
// channel: the control handler for config.ControlUser, a server handler for
// everyone else. Further "shell" requests on the same channel are declined.
// Any other request type is declined and returned as an error, which makes
// the caller close the SSH connection.
//
// Each started handler increments the stats connection counter. It is
// decremented after the SSH connection terminates and the handler is closed.
// handleRequests returns nil once in is closed.
func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
	logger.Info(user, "Invoking request handler")

	shellStarted := false
	for req := range in {
		// Best effort: the decoded value is only used in the error message
		// below. A "shell" request has no type-specific data (RFC 4254 §6.5),
		// so the decode error is ignored on purpose.
		var payload = struct{ Value string }{}
		_ = gossh.Unmarshal(req.Payload, &payload)

		switch req.Type {
		case "shell":
			if shellStarted {
				// Two handlers copying from the same channel would race for its input.
				logger.Error(user, "Declining additional shell request on channel")
				req.Reply(false, nil)
				continue
			}
			shellStarted = true

			var handler handlers.Handler
			switch user.Name {
			case config.ControlUser:
				handler = handlers.NewControlHandler(user)
			default:
				handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
			}

			// Paired with the deferred decrement in the cleanup goroutine below.
			s.stats.incrementConnections()

			// Bi-directionally connect SSH stream to SSH handler
			brokenPipe1 := make(chan struct{})
			go func() {
				defer close(brokenPipe1)
				io.Copy(channel, handler)
			}()

			brokenPipe2 := make(chan struct{})
			go func() {
				defer close(brokenPipe2)
				io.Copy(handler, channel)
			}()

			// Ensure to close all fd's and stop all goroutines once ssh connection terminated
			go func() {
				defer s.stats.decrementConnections()
				defer handler.Close()

				if err := sshConn.Wait(); err != nil && !errors.Is(err, io.EOF) {
					logger.Error(user, err)
				}
				logger.Info(user, "Good bye Mister!")
			}()

			// Close the underlying ssh socket when server shuts down
			go func() {
				select {
				case <-s.stop:
					logger.Debug(user, "Server initiating shutdown on handler")
				case <-handler.Wait():
					logger.Debug(user, "Handler initiating shutdown by its own")
				case <-brokenPipe1:
					logger.Debug(user, "Broken pipe1")
				case <-brokenPipe2:
					logger.Debug(user, "Broken pipe2")
				}
				sshConn.Close()
				logger.Info(user, "Closed SSH connection")
			}()

			// Only serving shell type
			req.Reply(true, nil)

		default:
			req.Reply(false, nil)

			return fmt.Errorf("Closing SSH connection as unknown request received|%s|%v",
				req.Type, payload.Value)
		}
	}

	return nil
}

// controlUserCallback is the SSH password callback. Only the control user
// authenticates this way, and its password must equal config.ControlUser;
// every other password login is refused with a "Not authorized" error.
// Successful logins get nil Permissions.
func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
	u := user.New(c.User(), c.RemoteAddr().String())

	isControlLogin := u.Name == config.ControlUser && string(authPayload) == config.ControlUser
	if !isControlLogin {
		return nil, errors.New("Not authorized")
	}

	logger.Debug(u, "Initiating master control program")
	return nil, nil
}

// Stop signals shutdown by closing s.stop, then calls
// stats.waitForConnections (presumably blocking until the active
// connections have drained — confirm against stats).
//
// Calling Stop again after an earlier call has closed s.stop skips the close
// instead of panicking on a double close. The guard is not atomic, so Stop
// must still not be called concurrently from several goroutines.
func (s *Server) Stop() {
	select {
	case <-s.stop:
		// Already signalled; closing a closed channel would panic.
	default:
		close(s.stop)
	}
	s.stats.waitForConnections()
}