diff --git a/sshd/server.go b/sshd/server.go index a8b60ba7..e22fc260 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -6,12 +6,17 @@ import ( "fmt" "net" "sync" + "time" "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) +// defaultHandshakeTimeout defines the timeout duration +// that should be used for SSH-Handshake +var defaultHandshakeTimeout = 10 * time.Second + type SSHServer struct { config *ssh.ServerConfig l *logrus.Entry @@ -176,7 +181,12 @@ func (s *SSHServer) run() { return } - conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + conn, chans, reqs, err := s.handshakeWithTimeout(c, defaultHandshakeTimeout) + if err != nil { + s.l.WithField("remoteAddress", c.RemoteAddr()).WithError(err).Warn("failed to handshake") + continue + } + fp := "" if conn != nil { fp = conn.Permissions.Extensions["fp"] @@ -216,6 +226,30 @@ func (s *SSHServer) run() { } } +func (s *SSHServer) handshakeWithTimeout(c net.Conn, timeout time.Duration) (*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + type connResult struct { + conn *ssh.ServerConn + chans <-chan ssh.NewChannel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan connResult, 1) + + go func() { + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + resultChan <- connResult{conn, chans, reqs, err} + }() + + select { + case result := <-resultChan: + return result.conn, result.chans, result.reqs, result.err + case <-time.After(timeout): + s.l.WithField("remoteAddress", c.RemoteAddr()).Warn("handshake timeout") + c.Close() + return nil, nil, nil, errors.New("handshake timeout") + } +} + func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil {