mirror of
https://github.com/slackhq/nebula.git
synced 2026-03-09 08:11:29 -07:00
ssh: add a handshake timeout for SSH handshake to prevent hangs when client blocks (prevent possible DoS in SSH management interface)
This commit is contained in:
parent
5f6bfc56df
commit
1f3b31f692
1 changed files with 167 additions and 0 deletions
167
sshd/server_test.go
Normal file
167
sshd/server_test.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
package sshd
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// newTCPConnPair returns two connected net.Conn via TCP loopback.
|
||||
// net.Pipe() is intentionally avoided: SSH handshakes require both sides
|
||||
// to write simultaneously (kexInit, auth), which deadlocks on an
|
||||
// unbuffered synchronous pipe. TCP loopback provides kernel buffering.
|
||||
func newTCPConnPair(t *testing.T) (serverConn net.Conn, clientConn net.Conn) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan dialResult, 1)
|
||||
go func() {
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
ch <- dialResult{c, err}
|
||||
}()
|
||||
|
||||
s, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
ln.Close()
|
||||
|
||||
r := <-ch
|
||||
require.NoError(t, r.err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
r.conn.Close()
|
||||
})
|
||||
|
||||
return s, r.conn
|
||||
}
|
||||
|
||||
// newTestSSHServer creates a minimal SSHServer with an ephemeral ed25519 host key.
|
||||
// If addAuthorizedKey is true, the returned signer's public key is registered for "testuser".
|
||||
func newTestSSHServer(t *testing.T, addAuthorizedKey bool) (*SSHServer, ssh.Signer) {
|
||||
t.Helper()
|
||||
|
||||
l := test.NewLogger()
|
||||
server, err := NewSSHServer(l.WithField("subsystem", "sshd"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate ephemeral host key.
|
||||
_, hostPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
hostKeyBlock, err := ssh.MarshalPrivateKey(hostPriv, "")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.SetHostKey(pem.EncodeToMemory(hostKeyBlock)))
|
||||
|
||||
// Generate ephemeral client key.
|
||||
_, clientPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := ssh.NewSignerFromKey(clientPriv)
|
||||
require.NoError(t, err)
|
||||
|
||||
if addAuthorizedKey {
|
||||
authorizedKey := string(ssh.MarshalAuthorizedKey(clientSigner.PublicKey()))
|
||||
require.NoError(t, server.AddAuthorizedKey("testuser", authorizedKey))
|
||||
}
|
||||
|
||||
return server, clientSigner
|
||||
}
|
||||
|
||||
func newClientConfig(user string, signer ssh.Signer) *ssh.ClientConfig {
|
||||
return &ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // test only
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandshakeWithTimeout_Success verifies that a valid SSH handshake completes
|
||||
// and returns a non-nil ServerConn with the correct user extensions.
|
||||
func TestHandshakeWithTimeout_Success(t *testing.T) {
|
||||
server, clientSigner := newTestSSHServer(t, true)
|
||||
serverConn, clientConn := newTCPConnPair(t)
|
||||
|
||||
type clientResult struct {
|
||||
conn ssh.Conn
|
||||
err error
|
||||
}
|
||||
clientDone := make(chan clientResult, 1)
|
||||
go func() {
|
||||
c, chans, reqs, err := ssh.NewClientConn(clientConn, "", newClientConfig("testuser", clientSigner))
|
||||
if err == nil {
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go func() {
|
||||
for range chans {
|
||||
}
|
||||
}()
|
||||
}
|
||||
clientDone <- clientResult{c, err}
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
if r := <-clientDone; r.conn != nil {
|
||||
r.conn.Close()
|
||||
}
|
||||
})
|
||||
|
||||
conn, chans, reqs, err := server.handshakeWithTimeout(serverConn, 5*time.Second)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
assert.NotNil(t, chans)
|
||||
assert.NotNil(t, reqs)
|
||||
assert.Equal(t, "testuser", conn.Permissions.Extensions["user"])
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// TestHandshakeWithTimeout_HandshakeError verifies that an authentication failure
|
||||
// (unknown user) is returned as an error without triggering the timeout path.
|
||||
func TestHandshakeWithTimeout_HandshakeError(t *testing.T) {
|
||||
// Server has no authorized keys → any client will be rejected.
|
||||
server, clientSigner := newTestSSHServer(t, false)
|
||||
serverConn, clientConn := newTCPConnPair(t)
|
||||
|
||||
clientDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, _, _, err := ssh.NewClientConn(clientConn, "", newClientConfig("testuser", clientSigner))
|
||||
clientDone <- err
|
||||
}()
|
||||
t.Cleanup(func() { <-clientDone })
|
||||
|
||||
conn, chans, reqs, err := server.handshakeWithTimeout(serverConn, 5*time.Second)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.NotEqual(t, "handshake timeout", err.Error())
|
||||
assert.Nil(t, conn)
|
||||
assert.Nil(t, chans)
|
||||
assert.Nil(t, reqs)
|
||||
}
|
||||
|
||||
// TestHandshakeWithTimeout_Timeout verifies that when no client sends SSH traffic,
|
||||
// the function returns a "handshake timeout" error and closes the connection.
|
||||
func TestHandshakeWithTimeout_Timeout(t *testing.T) {
|
||||
server, _ := newTestSSHServer(t, true)
|
||||
serverConn, _ := newTCPConnPair(t)
|
||||
// The client side is intentionally idle — no SSH traffic is sent.
|
||||
|
||||
conn, chans, reqs, err := server.handshakeWithTimeout(serverConn, 1*time.Millisecond)
|
||||
|
||||
require.EqualError(t, err, "handshake timeout")
|
||||
assert.Nil(t, conn)
|
||||
assert.Nil(t, chans)
|
||||
assert.Nil(t, reqs)
|
||||
|
||||
// Confirm that handshakeWithTimeout closed the connection.
|
||||
_, writeErr := serverConn.Write([]byte("probe"))
|
||||
assert.Error(t, writeErr, "serverConn should be closed after timeout")
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue