// Mirror of https://github.com/slackhq/nebula.git
// (synced 2026-05-10 22:22:27 -07:00; 340 lines, 8.7 KiB, Go).

package nebula
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/slackhq/nebula/config"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type stubDNSWriter struct{}
|
|
|
|
func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }
|
|
func (stubDNSWriter) RemoteAddr() net.Addr {
|
|
return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
|
|
}
|
|
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }
|
|
func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil }
|
|
func (stubDNSWriter) Close() error { return nil }
|
|
func (stubDNSWriter) TsigStatus() error { return nil }
|
|
func (stubDNSWriter) TsigTimersOnly(bool) {}
|
|
func (stubDNSWriter) Hijack() {}
|
|
|
|
func TestParsequery(t *testing.T) {
|
|
l := slog.New(slog.DiscardHandler)
|
|
hostMap := &HostMap{}
|
|
ds := &dnsServer{
|
|
l: l,
|
|
dnsMap4: make(map[string]netip.Addr),
|
|
dnsMap6: make(map[string]netip.Addr),
|
|
hostMap: hostMap,
|
|
}
|
|
ds.enabled.Store(true)
|
|
addrs := []netip.Addr{
|
|
netip.MustParseAddr("1.2.3.4"),
|
|
netip.MustParseAddr("1.2.3.5"),
|
|
netip.MustParseAddr("fd01::24"),
|
|
netip.MustParseAddr("fd01::25"),
|
|
}
|
|
ds.Add("test.com.com", addrs)
|
|
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
|
|
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
|
|
|
|
m := &dns.Msg{}
|
|
m.SetQuestion("test.com.com", dns.TypeA)
|
|
ds.parseQuery(m, nil)
|
|
assert.NotNil(t, m.Answer)
|
|
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
|
|
|
m = &dns.Msg{}
|
|
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
|
ds.parseQuery(m, nil)
|
|
assert.NotNil(t, m.Answer)
|
|
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
|
|
|
// A known name with no record of the requested type should return NODATA
|
|
// (NOERROR with empty answer), not NXDOMAIN.
|
|
m = &dns.Msg{}
|
|
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
|
|
ds.parseQuery(m, nil)
|
|
assert.Empty(t, m.Answer)
|
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
|
|
|
m = &dns.Msg{}
|
|
m.SetQuestion("v6only.com.com", dns.TypeA)
|
|
ds.parseQuery(m, nil)
|
|
assert.Empty(t, m.Answer)
|
|
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
|
|
|
|
// An unknown name should still return NXDOMAIN.
|
|
m = &dns.Msg{}
|
|
m.SetQuestion("unknown.com.com", dns.TypeA)
|
|
ds.parseQuery(m, nil)
|
|
assert.Empty(t, m.Answer)
|
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
|
|
|
// short lookups should not fail
|
|
m = &dns.Msg{}
|
|
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
|
ds.parseQuery(m, stubDNSWriter{})
|
|
assert.Empty(t, m.Answer)
|
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
|
|
|
m = &dns.Msg{}
|
|
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
|
|
ds.parseQuery(m, stubDNSWriter{})
|
|
assert.Empty(t, m.Answer)
|
|
assert.Equal(t, dns.RcodeNameError, m.Rcode)
|
|
}
|
|
|
|
func Test_getDnsServerAddr(t *testing.T) {
|
|
c := config.NewC(nil)
|
|
|
|
c.Settings["lighthouse"] = map[string]any{
|
|
"dns": map[string]any{
|
|
"host": "0.0.0.0",
|
|
"port": "1",
|
|
},
|
|
}
|
|
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
|
|
|
|
c.Settings["lighthouse"] = map[string]any{
|
|
"dns": map[string]any{
|
|
"host": "::",
|
|
"port": "1",
|
|
},
|
|
}
|
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
|
|
|
c.Settings["lighthouse"] = map[string]any{
|
|
"dns": map[string]any{
|
|
"host": "[::]",
|
|
"port": "1",
|
|
},
|
|
}
|
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
|
|
|
// Make sure whitespace doesn't mess us up
|
|
c.Settings["lighthouse"] = map[string]any{
|
|
"dns": map[string]any{
|
|
"host": "[::] ",
|
|
"port": "1",
|
|
},
|
|
}
|
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
|
}
|
|
|
|
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
|
t.Helper()
|
|
sl := slog.New(slog.DiscardHandler)
|
|
ds := &dnsServer{
|
|
l: sl,
|
|
ctx: context.Background(),
|
|
dnsMap4: make(map[string]netip.Addr),
|
|
dnsMap6: make(map[string]netip.Addr),
|
|
hostMap: &HostMap{},
|
|
}
|
|
ds.mux = dns.NewServeMux()
|
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
|
return ds, config.NewC(nil)
|
|
}
|
|
|
|
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
|
c.Settings["lighthouse"] = map[string]any{
|
|
"am_lighthouse": amLighthouse,
|
|
"serve_dns": serveDns,
|
|
"dns": map[string]any{
|
|
"host": host,
|
|
"port": port,
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestDnsServer_reload_initial_disabled(t *testing.T) {
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", "0", true, false)
|
|
|
|
require.NoError(t, ds.reload(c, true))
|
|
assert.False(t, ds.enabled.Load())
|
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
|
assert.Nil(t, ds.server)
|
|
}
|
|
|
|
func TestDnsServer_reload_initial_enabled(t *testing.T) {
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
|
|
|
require.NoError(t, ds.reload(c, true))
|
|
assert.True(t, ds.enabled.Load())
|
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
|
// initial never starts a runner; that's Control.Start's job
|
|
assert.Nil(t, ds.server)
|
|
}
|
|
|
|
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", "0", false, true)
|
|
|
|
require.NoError(t, ds.reload(c, true))
|
|
// Wants DNS but isn't a lighthouse: gated off, no runner.
|
|
assert.False(t, ds.enabled.Load())
|
|
}
|
|
|
|
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
|
|
|
require.NoError(t, ds.reload(c, true))
|
|
// No server running yet, no addr change. Reload should not spawn anything.
|
|
require.NoError(t, ds.reload(c, false))
|
|
assert.True(t, ds.enabled.Load())
|
|
assert.Nil(t, ds.server)
|
|
}
|
|
|
|
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
|
|
// Bind to a real (random) UDP port so we exercise the actual
|
|
// ListenAndServe + Shutdown plumbing including the started-chan race fix.
|
|
port := freeUDPPort(t)
|
|
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
|
require.NoError(t, ds.reload(c, true))
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
ds.Start()
|
|
close(done)
|
|
}()
|
|
|
|
waitFor(t, func() bool {
|
|
ds.serverMu.Lock()
|
|
started := ds.started
|
|
ds.serverMu.Unlock()
|
|
if started == nil {
|
|
return false
|
|
}
|
|
select {
|
|
case <-started:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
})
|
|
|
|
ds.Stop()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Start did not return after Stop")
|
|
}
|
|
}
|
|
|
|
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
|
|
// Stop called immediately after Start should not deadlock even if bind
|
|
// hasn't completed yet. This exercises the started-chan close-on-bind-fail
|
|
// path: by binding to an obviously bad port (privileged) we get a fast
|
|
// bind error before NotifyStartedFunc fires.
|
|
ds, c := newTestDnsServer(t)
|
|
// Use a port that should fail to bind (negative would be invalid, use a
|
|
// host that won't resolve to ensure listenUDP fails quickly).
|
|
setDnsConfig(c, "256.256.256.256", "53", true, true)
|
|
require.NoError(t, ds.reload(c, true))
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
ds.Start()
|
|
close(done)
|
|
}()
|
|
|
|
// Give Start a moment to attempt the bind and fail.
|
|
select {
|
|
case <-done:
|
|
// Bind failed and Start returned; Stop should be a no-op.
|
|
case <-time.After(time.Second):
|
|
t.Fatal("Start did not return after a bad bind")
|
|
}
|
|
|
|
stopped := make(chan struct{})
|
|
go func() {
|
|
ds.Stop()
|
|
close(stopped)
|
|
}()
|
|
select {
|
|
case <-stopped:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("Stop hung after a failed bind")
|
|
}
|
|
}
|
|
|
|
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
|
|
port := freeUDPPort(t)
|
|
ds, c := newTestDnsServer(t)
|
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
|
require.NoError(t, ds.reload(c, true))
|
|
|
|
startReturned := make(chan struct{})
|
|
go func() {
|
|
ds.Start()
|
|
close(startReturned)
|
|
}()
|
|
waitForBind(t, ds)
|
|
|
|
// Toggle serve_dns off; reload should shut the running server down.
|
|
setDnsConfig(c, "127.0.0.1", port, true, false)
|
|
require.NoError(t, ds.reload(c, false))
|
|
select {
|
|
case <-startReturned:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Start did not return after reload disabled DNS")
|
|
}
|
|
assert.False(t, ds.enabled.Load())
|
|
}
|
|
|
|
func freeUDPPort(t *testing.T) string {
|
|
t.Helper()
|
|
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
port := conn.LocalAddr().(*net.UDPAddr).Port
|
|
require.NoError(t, conn.Close())
|
|
return strconv.Itoa(port)
|
|
}
|
|
|
|
func waitForBind(t *testing.T, ds *dnsServer) {
|
|
t.Helper()
|
|
waitFor(t, func() bool {
|
|
ds.serverMu.Lock()
|
|
started := ds.started
|
|
ds.serverMu.Unlock()
|
|
if started == nil {
|
|
return false
|
|
}
|
|
select {
|
|
case <-started:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
})
|
|
}
|
|
|
|
// waitFor polls cond every 5ms for up to 5s. It returns as soon as cond
// reports true, and otherwise fails the test with t.Fatal.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()

	const (
		timeout  = 5 * time.Second
		interval = 5 * time.Millisecond
	)

	start := time.Now()
	for time.Since(start) < timeout {
		if cond() {
			return
		}
		time.Sleep(interval)
	}
	t.Fatal("timed out waiting for condition")
}
|