summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go47
1 files changed, 21 insertions, 26 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index b3d4bff..0cb5e27 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -24,9 +24,9 @@ type Server struct {
stats stats
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
- // To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
+ // To control the max amount of concurrent cats.
catLimiter chan struct{}
- // To control the max amount of concurrent tails
+ // To control the max amount of concurrent tails.
tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
@@ -61,7 +61,6 @@ func New() *Server {
// Start the server.
func (s *Server) Start(ctx context.Context) int {
dlog.Server.Info("Starting server")
-
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
dlog.Server.Info("Binding server", bindAt)
@@ -76,14 +75,12 @@ func (s *Server) Start(ctx context.Context) int {
go s.listenerLoop(ctx, listener)
<-ctx.Done()
-
// For future use.
return 0
}
func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
dlog.Server.Debug("Starting listener loop")
-
for {
conn, err := listener.Accept() // Blocking
if err != nil {
@@ -101,7 +98,6 @@ func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
conn.Close()
continue
}
-
go s.handleConnection(ctx, conn)
}
}
@@ -116,22 +112,23 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
}
s.stats.incrementConnections()
-
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn,
+ newChannel gossh.NewChannel) {
+
user, err := user.New(sshConn.User(), sshConn.RemoteAddr().String())
if err != nil {
dlog.Server.Error(user, err)
newChannel.Reject(gossh.Prohibited, err.Error())
return
}
- dlog.Server.Info(user, "Invoking channel handler")
+ dlog.Server.Info(user, "Invoking channel handler")
if newChannel.ChannelType() != "session" {
err := errors.New("Don'w allow other channel types than session")
dlog.Server.Error(user, err)
@@ -151,9 +148,10 @@ func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChann
}
}
-func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
- dlog.Server.Info(user, "Invoking request handler")
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn,
+ in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+ dlog.Server.Info(user, "Invoking request handler")
for req := range in {
var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
@@ -167,7 +165,6 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
default:
handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
}
-
terminate := func() {
handler.Shutdown()
sshConn.Close()
@@ -178,13 +175,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
io.Copy(channel, handler)
terminate()
}()
-
go func() {
// Broken pipe, cancel
io.Copy(handler, channel)
terminate()
}()
-
go func() {
select {
case <-ctx.Done():
@@ -192,7 +187,6 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
}
terminate()
}()
-
go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
dlog.Server.Error(user, err)
@@ -204,20 +198,19 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
// Only serving shell type
req.Reply(true, nil)
-
default:
req.Reply(false, nil)
-
return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v",
req.Type, payload.Value)
}
}
-
return nil
}
// Callback for SSH authentication.
-func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
+func (s *Server) Callback(c gossh.ConnMetadata,
+ authPayload []byte) (*gossh.Permissions, error) {
+
user, err := user.New(c.User(), c.RemoteAddr().String())
if err != nil {
return nil, err
@@ -229,7 +222,6 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm
}
authInfo := string(authPayload)
-
splitted := strings.Split(c.RemoteAddr().String(), ":")
remoteIP := splitted[0]
@@ -259,23 +251,26 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm
return nil, fmt.Errorf("user %s not authorized", user)
}
-func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool {
- dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
+func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP,
+ allowedJobName string, allowFrom []string) bool {
+ dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
if jobName != allowedJobName {
- dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH",
+ "Job name does not match, skipping to next one...", allowedJobName)
return false
}
for _, myAddr := range allowFrom {
ips, err := net.LookupIP(myAddr)
if err != nil {
- dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP "+
+ "address for allowed hosts lookup, skipping to next one...", myAddr, err)
continue
}
-
for _, ip := range ips {
- dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String())
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses",
+ remoteIP, ip.String())
if remoteIP == ip.String() {
return true
}