This commit is contained in:
Andreas Bell Martinsen 2025-12-05 10:21:41 -06:00 committed by GitHub
commit d7cfa522ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 117 additions and 113 deletions

View file

@ -8,9 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
) )
@ -19,129 +19,106 @@ import (
var dnsR *dnsRecords var dnsR *dnsRecords
var dnsServer *dns.Server var dnsServer *dns.Server
var dnsAddr string var dnsAddr string
var dnsSuffix string
type dnsRecords struct { type dnsRecords struct {
sync.RWMutex sync.RWMutex
l *logrus.Logger l *logrus.Logger
dnsMap4 map[string]netip.Addr dnsMap map[dns.Question][]dns.RR
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
} }
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { func newDnsRecords(l *logrus.Logger) *dnsRecords {
return &dnsRecords{ return &dnsRecords{
l: l, l: l,
dnsMap4: make(map[string]netip.Addr), dnsMap: make(map[dns.Question][]dns.RR),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
} }
} }
func (d *dnsRecords) Query(q uint16, data string) netip.Addr { func (d *dnsRecords) AddA(name string, addresses []netip.Addr) {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return netip.Addr{}
}
func (d *dnsRecords) QueryCert(data string) string {
ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil {
return ""
}
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
}
q := hostinfo.GetCert()
if q == nil {
return ""
}
b, err := q.Certificate.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
haveV4 := false q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeA}
haveV6 := false d.dnsMap[q] = nil
for _, addr := range addresses { for _, addr := range addresses {
if addr.Is4() && !haveV4 { if addr.Is4() {
d.dnsMap4[host] = addr qType := dns.TypeToString[q.Qtype]
haveV4 = true rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
} else if addr.Is6() && !haveV6 { if err == nil {
d.dnsMap6[host] = addr d.dnsMap[q] = append(d.dnsMap[q], rr)
haveV6 = true d.l.Debugf("DNS record added %s", rr.String())
} }
if haveV4 && haveV6 {
break
} }
} }
} }
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { func (d *dnsRecords) AddAAAA(name string, addresses []netip.Addr) {
a, _, _ := net.SplitHostPort(addr) d.Lock()
b, err := netip.ParseAddr(a) defer d.Unlock()
if err != nil { q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeAAAA}
return false d.dnsMap[q] = nil
}
if b.IsLoopback() { for _, addr := range addresses {
return true if addr.Is6() {
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
if err == nil {
d.dnsMap[q] = append(d.dnsMap[q], rr)
d.l.Debugf("DNS record added %s", rr.String())
}
}
} }
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
} }
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { func (d *dnsRecords) AddPTR(name string, addresses []netip.Addr) {
d.Lock()
defer d.Unlock()
for _, addr := range addresses {
arpa, err := dns.ReverseAddr(addr.String())
if err == nil {
q := dns.Question{Name: arpa, Qclass: dns.ClassINET, Qtype: dns.TypePTR}
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", arpa, qType, name))
if err == nil {
d.dnsMap[q] = []dns.RR{rr}
d.l.Debugf("DNS record added %s", rr.String())
}
}
}
}
func (d *dnsRecords) AddTXT(name string, crt cert.Certificate) {
d.Lock()
defer d.Unlock()
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeTXT}
d.dnsMap[q] = nil
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s \"Name: %v\" \"Networks: %v\" \"Groups: %v\" \"UnsafeNetworks: %v\"", name, qType, crt.Name(), crt.Networks(), crt.Groups(), crt.UnsafeNetworks()))
if err == nil {
d.dnsMap[q] = []dns.RR{rr}
d.l.Debugf("DNS record added %s", rr.String())
}
}
func (d *dnsRecords) Add(crt cert.Certificate, addresses []netip.Addr) {
host := dns.Fqdn(strings.ToLower(crt.Name() + dnsSuffix))
d.AddA(host, addresses)
d.AddAAAA(host, addresses)
d.AddPTR(host, addresses)
d.AddTXT(host, crt)
}
func (d *dnsRecords) parseQuery(m *dns.Msg) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA, dns.TypePTR, dns.TypeTXT:
qType := dns.TypeToString[q.Qtype] d.RLock()
d.l.Debugf("Query for %s %s", qType, q.Name) if rr, ok := d.dnsMap[q]; ok {
ip := d.Query(q.Qtype, q.Name) m.Answer = append(m.Answer, rr...)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
// We only answer these queries from nebula nodes or localhost
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
} }
d.RUnlock()
} }
} }
@ -157,14 +134,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
switch r.Opcode { switch r.Opcode {
case dns.OpcodeQuery: case dns.OpcodeQuery:
d.parseQuery(m, w) d.parseQuery(m)
} }
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { func dnsMain(l *logrus.Logger, cs *CertState, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap) dnsR = newDnsRecords(l)
dnsSuffix = getDnsSuffix(c)
// Add self to dns records
dnsR.Add(cs.GetDefaultCertificate(), cs.myVpnAddrs)
// attach request handler func // attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest) dns.HandleFunc(".", dnsR.handleDnsRequest)
@ -187,6 +168,11 @@ func getDnsServerAddr(c *config.C) string {
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
} }
func getDnsSuffix(c *config.C) string {
suffix := strings.TrimSpace(c.GetString("lighthouse.dns.suffix", ""))
return suffix
}
func startDns(l *logrus.Logger, c *config.C) { func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c) dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}

View file

@ -12,27 +12,44 @@ import (
func TestParsequery(t *testing.T) { func TestParsequery(t *testing.T) {
l := logrus.New() l := logrus.New()
hostMap := &HostMap{} ds := newDnsRecords(l)
ds := newDnsRecords(l, &CertState{}, hostMap)
addrs := []netip.Addr{ addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"), netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"), netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"), netip.MustParseAddr("fd01::25"),
} }
ds.Add("test.com.com", addrs) dnsSuffix = ".com"
crt := &dummyCert{
name: "test.com",
}
ds.Add(crt, addrs)
m := &dns.Msg{} m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA) m.SetQuestion("test.com.com.", dns.TypeA)
ds.parseQuery(m, nil) ds.parseQuery(m)
assert.NotNil(t, m.Answer) assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
assert.Equal(t, "1.2.3.5", m.Answer[1].(*dns.A).A.String())
m = &dns.Msg{} m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA) m.SetQuestion("test.com.com.", dns.TypeAAAA)
ds.parseQuery(m, nil) ds.parseQuery(m)
assert.NotNil(t, m.Answer) assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
assert.Equal(t, "fd01::25", m.Answer[1].(*dns.AAAA).AAAA.String())
m = &dns.Msg{}
m.SetQuestion("4.3.2.1.in-addr.arpa.", dns.TypePTR)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
m = &dns.Msg{}
m.SetQuestion("4.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.d.f.ip6.arpa.", dns.TypePTR)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
} }
func Test_getDnsServerAddr(t *testing.T) { func Test_getDnsServerAddr(t *testing.T) {

View file

@ -52,6 +52,7 @@ lighthouse:
# The DNS host defines the IP to bind the dns listener to. This also allows binding to the nebula node IP. # The DNS host defines the IP to bind the dns listener to. This also allows binding to the nebula node IP.
#host: 0.0.0.0 #host: 0.0.0.0
#port: 53 #port: 53
#suffix: ".nebula"
# interval is the number of seconds between updates from this node to a lighthouse. # interval is the number of seconds between updates from this node to a lighthouse.
# during updates, a node sends information about its current IP addresses to each node. # during updates, a node sends information about its current IP addresses to each node.
interval: 60 interval: 60

View file

@ -606,7 +606,7 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns { if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) dnsR.Add(remoteCert.Certificate, hostinfo.vpnAddrs)
} }
for _, addr := range hostinfo.vpnAddrs { for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f) hm.unlockedInnerAddHostInfo(addr, hostinfo, f)

View file

@ -287,7 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func() var dnsStart func()
if lightHouse.amLighthouse && serveDns { if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) dnsStart = dnsMain(l, pki.getCertState(), c)
} }
return &Control{ return &Control{