nebula/dns_server_test.go
2026-04-27 09:41:47 -05:00

340 lines
8.7 KiB
Go

package nebula
import (
"context"
"log/slog"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/miekg/dns"
"github.com/slackhq/nebula/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubDNSWriter is a do-nothing response writer for tests that must hand
// parseQuery a non-nil writer. Every method returns a fixed value and performs
// no I/O.
type stubDNSWriter struct{}

// Compile-time check that the stub implements the whole dns.ResponseWriter
// interface, so a missing or mistyped method is reported here instead of at
// whichever call site happens to use the stub first.
var _ dns.ResponseWriter = stubDNSWriter{}

// LocalAddr returns an empty UDP address.
func (stubDNSWriter) LocalAddr() net.Addr { return &net.UDPAddr{} }

// RemoteAddr reports a fixed loopback peer, 127.0.0.1:5353.
func (stubDNSWriter) RemoteAddr() net.Addr {
	return &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5353}
}

// Write discards the payload and reports zero bytes written with no error.
func (stubDNSWriter) Write([]byte) (int, error) { return 0, nil }

// WriteMsg discards the message and reports success.
func (stubDNSWriter) WriteMsg(*dns.Msg) error { return nil }

// Close is a no-op that always succeeds.
func (stubDNSWriter) Close() error { return nil }

// TsigStatus always reports no error.
func (stubDNSWriter) TsigStatus() error { return nil }

// TsigTimersOnly is a no-op.
func (stubDNSWriter) TsigTimersOnly(bool) {}

// Hijack is a no-op.
func (stubDNSWriter) Hijack() {}
func TestParsequery(t *testing.T) {
l := slog.New(slog.DiscardHandler)
hostMap := &HostMap{}
ds := &dnsServer{
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
}
ds.enabled.Store(true)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
ds.Add("v4only.com.com", []netip.Addr{netip.MustParseAddr("1.2.3.6")})
ds.Add("v6only.com.com", []netip.Addr{netip.MustParseAddr("fd01::26")})
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// A known name with no record of the requested type should return NODATA
// (NOERROR with empty answer), not NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("v4only.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
m = &dns.Msg{}
m.SetQuestion("v6only.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeSuccess, m.Rcode)
// An unknown name should still return NXDOMAIN.
m = &dns.Msg{}
m.SetQuestion("unknown.com.com", dns.TypeA)
ds.parseQuery(m, nil)
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
// short lookups should not fail
m = &dns.Msg{}
m.Question = []dns.Question{{Name: "", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
m = &dns.Msg{}
m.Question = []dns.Question{{Name: ".", Qtype: dns.TypeTXT, Qclass: dns.ClassINET}}
ds.parseQuery(m, stubDNSWriter{})
assert.Empty(t, m.Answer)
assert.Equal(t, dns.RcodeNameError, m.Rcode)
}
// Test_getDnsServerAddr checks how getDnsServerAddr joins lighthouse.dns.host
// and lighthouse.dns.port: an IPv4 host is joined as-is, an IPv6 host gains
// brackets, an already-bracketed host is not bracketed twice, and trailing
// whitespace on the host is ignored.
func Test_getDnsServerAddr(t *testing.T) {
	c := config.NewC(nil)

	// addrFor overwrites the lighthouse settings with the given host (the port
	// is always "1") and returns the address computed from them.
	addrFor := func(host string) string {
		dnsSettings := map[string]any{
			"host": host,
			"port": "1",
		}
		c.Settings["lighthouse"] = map[string]any{"dns": dnsSettings}
		return getDnsServerAddr(c)
	}

	assert.Equal(t, "0.0.0.0:1", addrFor("0.0.0.0"))
	assert.Equal(t, "[::]:1", addrFor("::"))
	assert.Equal(t, "[::]:1", addrFor("[::]"))

	// Make sure whitespace doesn't mess us up
	assert.Equal(t, "[::]:1", addrFor("[::] "))
}
// newTestDnsServer returns a dnsServer wired for tests together with a fresh,
// empty config. The server logs to a discard handler, carries a background
// context, starts with empty record maps and an empty HostMap, and has a mux
// that routes every name (".") to handleDnsRequest.
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
	t.Helper()

	mux := dns.NewServeMux()
	ds := &dnsServer{
		l:       slog.New(slog.DiscardHandler),
		ctx:     context.Background(),
		dnsMap4: make(map[string]netip.Addr),
		dnsMap6: make(map[string]netip.Addr),
		hostMap: &HostMap{},
	}
	ds.mux = mux
	mux.HandleFunc(".", ds.handleDnsRequest)

	return ds, config.NewC(nil)
}
// setDnsConfig replaces the "lighthouse" section of c with one that carries
// the given am_lighthouse and serve_dns flags and a dns listener at host:port.
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
	listener := map[string]any{
		"host": host,
		"port": port,
	}
	lighthouse := map[string]any{
		"am_lighthouse": amLighthouse,
		"serve_dns":     serveDns,
		"dns":           listener,
	}
	c.Settings["lighthouse"] = lighthouse
}
// TestDnsServer_reload_initial_disabled checks that an initial reload on a
// lighthouse with serve_dns off records the configured address but leaves the
// server disabled with no server instance.
func TestDnsServer_reload_initial_disabled(t *testing.T) {
	const initial = true

	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, false)

	err := ds.reload(c, initial)
	require.NoError(t, err)

	assert.False(t, ds.enabled.Load())
	assert.Equal(t, "127.0.0.1:0", ds.addr)
	assert.Nil(t, ds.server)
}
// TestDnsServer_reload_initial_enabled checks that an initial reload on a
// lighthouse with serve_dns on marks the server enabled and records the
// address, without creating a server instance.
func TestDnsServer_reload_initial_enabled(t *testing.T) {
	const initial = true

	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)

	err := ds.reload(c, initial)
	require.NoError(t, err)

	assert.True(t, ds.enabled.Load())
	assert.Equal(t, "127.0.0.1:0", ds.addr)

	// initial never starts a runner; that's Control.Start's job
	assert.Nil(t, ds.server)
}
// TestDnsServer_reload_initial_serveDnsWithoutLighthouse checks that serve_dns
// alone is not enough: with am_lighthouse off, an initial reload succeeds but
// leaves the server disabled.
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
	const initial = true

	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", false, true)

	err := ds.reload(c, initial)
	require.NoError(t, err)

	// Wants DNS but isn't a lighthouse: gated off, no runner.
	assert.False(t, ds.enabled.Load())
}
// TestDnsServer_reload_sameAddr_noOp applies the same enabled configuration
// twice, first as the initial load and then as a live reload, and checks that
// the server stays enabled and that no server instance is created.
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", "0", true, true)

	// First pass is the initial load. On the second pass there is no server
	// running yet and no addr change, so reload should not spawn anything.
	for _, initial := range []bool{true, false} {
		require.NoError(t, ds.reload(c, initial))
	}

	assert.True(t, ds.enabled.Load())
	assert.Nil(t, ds.server)
}
// TestDnsServer_StartStop_lifecycle starts the server on a real loopback UDP
// port, waits until the started channel has been signalled, then calls Stop
// and requires Start to return within five seconds.
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
	// Bind to a real (random) UDP port so we exercise the actual
	// ListenAndServe + Shutdown plumbing including the started-chan race fix.
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))

	done := make(chan struct{})
	go func() {
		ds.Start()
		close(done)
	}()

	// Use the shared readiness poll rather than a private copy of it, so this
	// test and the other lifecycle tests wait on exactly the same condition.
	waitForBind(t, ds)

	ds.Stop()
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("Start did not return after Stop")
	}
}
// TestDnsServer_Stop_beforeBind_doesNotHang covers the case where the listener
// never comes up: Start must return on its own once the bind has failed, and a
// Stop issued afterwards must return promptly instead of waiting on a server
// that never started.
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
	ds, c := newTestDnsServer(t)

	// "256.256.256.256" is not a valid IPv4 address, so the listen attempt is
	// expected to fail quickly. reload itself still accepts the setting; the
	// failure only shows up once Start tries to bind.
	setDnsConfig(c, "256.256.256.256", "53", true, true)
	require.NoError(t, ds.reload(c, true))

	// returnsWithin runs fn on its own goroutine and reports whether it
	// finished before d elapsed.
	returnsWithin := func(d time.Duration, fn func()) bool {
		done := make(chan struct{})
		go func() {
			fn()
			close(done)
		}()
		select {
		case <-done:
			return true
		case <-time.After(d):
			return false
		}
	}

	// Start should give up by itself once the bind fails.
	if !returnsWithin(time.Second, func() { ds.Start() }) {
		t.Fatal("Start did not return after a bad bind")
	}

	// With nothing running, Stop should be a no-op rather than a deadlock.
	if !returnsWithin(time.Second, func() { ds.Stop() }) {
		t.Fatal("Stop hung after a failed bind")
	}
}
// TestDnsServer_reload_disable_stopsRunningServer starts the server on a real
// loopback UDP port, then reloads with serve_dns switched off and requires
// both that Start returns within five seconds and that the server ends up
// disabled.
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
	port := freeUDPPort(t)
	ds, c := newTestDnsServer(t)
	setDnsConfig(c, "127.0.0.1", port, true, true)
	require.NoError(t, ds.reload(c, true))

	exited := make(chan struct{})
	go func() {
		ds.Start()
		close(exited)
	}()
	waitForBind(t, ds)

	// Toggle serve_dns off; reload should shut the running server down.
	setDnsConfig(c, "127.0.0.1", port, true, false)
	err := ds.reload(c, false)
	require.NoError(t, err)

	limit := time.NewTimer(5 * time.Second)
	defer limit.Stop()
	select {
	case <-exited:
	case <-limit.C:
		t.Fatal("Start did not return after reload disabled DNS")
	}

	assert.False(t, ds.enabled.Load())
}
// freeUDPPort asks the OS for an unused loopback UDP port by listening on
// 127.0.0.1:0, closes the socket again, and returns the port number in decimal
// form. The port is released before returning, so nothing prevents another
// process from taking it before the caller binds.
func freeUDPPort(t *testing.T) string {
	t.Helper()

	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	require.NoError(t, err)

	udpAddr := pc.LocalAddr().(*net.UDPAddr)
	picked := strconv.Itoa(udpAddr.Port)

	err = pc.Close()
	require.NoError(t, err)

	return picked
}
// waitForBind polls (via waitFor) until ds.started is non-nil and a receive
// from it no longer blocks. The channel is read under ds.serverMu; the test
// fails if that does not happen before waitFor's deadline.
func waitForBind(t *testing.T, ds *dnsServer) {
	t.Helper()

	signalled := func() bool {
		ds.serverMu.Lock()
		ch := ds.started
		ds.serverMu.Unlock()

		if ch != nil {
			// Non-blocking receive: succeeds only once the channel is ready.
			select {
			case <-ch:
				return true
			default:
			}
		}
		return false
	}

	waitFor(t, signalled)
}
// waitFor polls cond every 5ms and returns as soon as it reports true. If cond
// is still false once 5 seconds have passed, the test is failed with t.Fatal.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()

	const (
		timeout  = 5 * time.Second
		interval = 5 * time.Millisecond
	)

	for stop := time.Now().Add(timeout); time.Now().Before(stop); time.Sleep(interval) {
		if cond() {
			return
		}
	}
	t.Fatal("timed out waiting for condition")
}