diff --git a/dns_server.go b/dns_server.go index 73576546..1f839c58 100644 --- a/dns_server.go +++ b/dns_server.go @@ -184,11 +184,41 @@ func getDnsServerAddr(c *config.C) string { if dnsHost == "[::]" { dnsHost = "::" } + return dnsHost +} + +func getDnsServerAddrPort(c *config.C) string { + dnsHost := getDnsServerAddr(c) return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } +func shouldServeDns(c *config.C) (bool, error) { + if !c.GetBool("lighthouse.serve_dns", false) { + return false, nil + } + + dnsHostStr := getDnsServerAddr(c) + if dnsHostStr == "" { //setting an ip address is required + return false, fmt.Errorf("no DNS server IP address set") + } + + if c.GetBool("lighthouse.am_lighthouse", false) { + return true, nil + } + + dnsHost, err := netip.ParseAddr(dnsHostStr) + if err != nil { + return false, fmt.Errorf("failed to parse lighthouse.dns.host(%s) %v", dnsHostStr, err) + } + if !dnsHost.IsLoopback() { + return false, fmt.Errorf("lighthouse.dns.host(%s) must be loopback on non-lighthouses", dnsHostStr) + } + + return true, nil +} + func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) + dnsAddr = getDnsServerAddrPort(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") err := dnsServer.ListenAndServe() @@ -199,7 +229,7 @@ func startDns(l *logrus.Logger, c *config.C) { } func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { + if dnsAddr == getDnsServerAddrPort(c) { l.Debug("No DNS server config change detected") return } diff --git a/dns_server_test.go b/dns_server_test.go index 356e5890..6b60f98c 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -35,7 +35,7 @@ func TestParsequery(t *testing.T) { assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } -func Test_getDnsServerAddr(t *testing.T) { +func Test_getDnsServerAddrPort(t *testing.T) { c := config.NewC(nil) c.Settings["lighthouse"] = map[string]any{ @@ -44,7 +44,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) + assert.Equal(t, "0.0.0.0:1", getDnsServerAddrPort(c)) c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ @@ -52,7 +52,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ @@ -60,7 +60,7 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) // Make sure whitespace doesn't mess us up c.Settings["lighthouse"] = map[string]any{ @@ -69,5 +69,64 @@ func Test_getDnsServerAddr(t *testing.T) { "port": "1", }, } - assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + assert.Equal(t, "[::]:1", getDnsServerAddrPort(c)) +} + +func Test_shouldServeDns(t *testing.T) { + c := config.NewC(nil) + notLoopback := map[interface{}]interface{}{"host": "0.0.0.0", "port": "1"} + yesLoopbackv4 := map[interface{}]interface{}{"host": "127.0.0.2", "port": "1"} + yesLoopbackv6 := map[interface{}]interface{}{"host": "::1", "port": "1"} + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "serve_dns": false, + } + serveDns, err := shouldServeDns(c) + assert.NoError(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": true, + "serve_dns": true, + } + serveDns, err = shouldServeDns(c) + assert.Error(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": true, + "serve_dns": true, + "dns": notLoopback, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) + + //non-lighthouses must do DNS on loopback + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": notLoopback, + } + serveDns, err = shouldServeDns(c) + assert.Error(t, err) + assert.False(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": yesLoopbackv4, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "am_lighthouse": false, + "serve_dns": true, + "dns": yesLoopbackv6, + } + serveDns, err = shouldServeDns(c) + assert.NoError(t, err) + assert.True(t, serveDns) } diff --git a/main.go b/main.go index 17aaa548..4e2a87da 100644 --- a/main.go +++ b/main.go @@ -219,13 +219,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to run because this host is not a lighthouse.") - } + serveDns, dnsErr := shouldServeDns(c) + if dnsErr != nil { + l.Warnf("failed to configure DNS server: %v", dnsErr) } ifConfig := &InterfaceConfig{ @@ -286,7 +282,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() - if lightHouse.amLighthouse && serveDns { + if serveDns { l.Debugln("Starting dns server") dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) }