This commit is contained in:
Andreas Bell Martinsen 2025-12-05 10:21:41 -06:00 committed by GitHub
commit d7cfa522ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 117 additions and 113 deletions

View file

@ -8,9 +8,9 @@ import (
"strings"
"sync"
"github.com/gaissmai/bart"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
)
@ -19,129 +19,106 @@ import (
var dnsR *dnsRecords
var dnsServer *dns.Server
var dnsAddr string
var dnsSuffix string
type dnsRecords struct {
sync.RWMutex
l *logrus.Logger
dnsMap4 map[string]netip.Addr
dnsMap6 map[string]netip.Addr
hostMap *HostMap
myVpnAddrsTable *bart.Lite
l *logrus.Logger
dnsMap map[dns.Question][]dns.RR
}
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
func newDnsRecords(l *logrus.Logger) *dnsRecords {
return &dnsRecords{
l: l,
dnsMap4: make(map[string]netip.Addr),
dnsMap6: make(map[string]netip.Addr),
hostMap: hostMap,
myVpnAddrsTable: cs.myVpnAddrsTable,
l: l,
dnsMap: make(map[dns.Question][]dns.RR),
}
}
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
switch q {
case dns.TypeA:
if r, ok := d.dnsMap4[data]; ok {
return r
}
case dns.TypeAAAA:
if r, ok := d.dnsMap6[data]; ok {
return r
}
}
return netip.Addr{}
}
func (d *dnsRecords) QueryCert(data string) string {
ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil {
return ""
}
hostinfo := d.hostMap.QueryVpnAddr(ip)
if hostinfo == nil {
return ""
}
q := hostinfo.GetCert()
if q == nil {
return ""
}
b, err := q.Certificate.MarshalJSON()
if err != nil {
return ""
}
return string(b)
}
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
host = strings.ToLower(host)
func (d *dnsRecords) AddA(name string, addresses []netip.Addr) {
d.Lock()
defer d.Unlock()
haveV4 := false
haveV6 := false
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeA}
d.dnsMap[q] = nil
for _, addr := range addresses {
if addr.Is4() && !haveV4 {
d.dnsMap4[host] = addr
haveV4 = true
} else if addr.Is6() && !haveV6 {
d.dnsMap6[host] = addr
haveV6 = true
}
if haveV4 && haveV6 {
break
if addr.Is4() {
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
if err == nil {
d.dnsMap[q] = append(d.dnsMap[q], rr)
d.l.Debugf("DNS record added %s", rr.String())
}
}
}
}
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
a, _, _ := net.SplitHostPort(addr)
b, err := netip.ParseAddr(a)
if err != nil {
return false
}
func (d *dnsRecords) AddAAAA(name string, addresses []netip.Addr) {
d.Lock()
defer d.Unlock()
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeAAAA}
d.dnsMap[q] = nil
if b.IsLoopback() {
return true
for _, addr := range addresses {
if addr.Is6() {
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
if err == nil {
d.dnsMap[q] = append(d.dnsMap[q], rr)
d.l.Debugf("DNS record added %s", rr.String())
}
}
}
//if we found it in this table, it's good
return d.myVpnAddrsTable.Contains(b)
}
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
func (d *dnsRecords) AddPTR(name string, addresses []netip.Addr) {
d.Lock()
defer d.Unlock()
for _, addr := range addresses {
arpa, err := dns.ReverseAddr(addr.String())
if err == nil {
q := dns.Question{Name: arpa, Qclass: dns.ClassINET, Qtype: dns.TypePTR}
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", arpa, qType, name))
if err == nil {
d.dnsMap[q] = []dns.RR{rr}
d.l.Debugf("DNS record added %s", rr.String())
}
}
}
}
func (d *dnsRecords) AddTXT(name string, crt cert.Certificate) {
d.Lock()
defer d.Unlock()
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeTXT}
d.dnsMap[q] = nil
qType := dns.TypeToString[q.Qtype]
rr, err := dns.NewRR(fmt.Sprintf("%s %s \"Name: %v\" \"Networks: %v\" \"Groups: %v\" \"UnsafeNetworks: %v\"", name, qType, crt.Name(), crt.Networks(), crt.Groups(), crt.UnsafeNetworks()))
if err == nil {
d.dnsMap[q] = []dns.RR{rr}
d.l.Debugf("DNS record added %s", rr.String())
}
}
func (d *dnsRecords) Add(crt cert.Certificate, addresses []netip.Addr) {
host := dns.Fqdn(strings.ToLower(crt.Name() + dnsSuffix))
d.AddA(host, addresses)
d.AddAAAA(host, addresses)
d.AddPTR(host, addresses)
d.AddTXT(host, crt)
}
func (d *dnsRecords) parseQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
// We only answer these queries from nebula nodes or localhost
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
case dns.TypeA, dns.TypeAAAA, dns.TypePTR, dns.TypeTXT:
d.RLock()
if rr, ok := d.dnsMap[q]; ok {
m.Answer = append(m.Answer, rr...)
}
d.RUnlock()
}
}
@ -157,14 +134,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
switch r.Opcode {
case dns.OpcodeQuery:
d.parseQuery(m, w)
d.parseQuery(m)
}
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
func dnsMain(l *logrus.Logger, cs *CertState, c *config.C) func() {
dnsR = newDnsRecords(l)
dnsSuffix = getDnsSuffix(c)
// Add self to dns records
dnsR.Add(cs.GetDefaultCertificate(), cs.myVpnAddrs)
// attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest)
@ -187,6 +168,11 @@ func getDnsServerAddr(c *config.C) string {
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}
func getDnsSuffix(c *config.C) string {
suffix := strings.TrimSpace(c.GetString("lighthouse.dns.suffix", ""))
return suffix
}
func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}

View file

@ -12,27 +12,44 @@ import (
func TestParsequery(t *testing.T) {
l := logrus.New()
hostMap := &HostMap{}
ds := newDnsRecords(l, &CertState{}, hostMap)
ds := newDnsRecords(l)
addrs := []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
netip.MustParseAddr("1.2.3.5"),
netip.MustParseAddr("fd01::24"),
netip.MustParseAddr("fd01::25"),
}
ds.Add("test.com.com", addrs)
dnsSuffix = ".com"
crt := &dummyCert{
name: "test.com",
}
ds.Add(crt, addrs)
m := &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeA)
ds.parseQuery(m, nil)
m.SetQuestion("test.com.com.", dns.TypeA)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
assert.Equal(t, "1.2.3.5", m.Answer[1].(*dns.A).A.String())
m = &dns.Msg{}
m.SetQuestion("test.com.com", dns.TypeAAAA)
ds.parseQuery(m, nil)
m.SetQuestion("test.com.com.", dns.TypeAAAA)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
assert.Equal(t, "fd01::25", m.Answer[1].(*dns.AAAA).AAAA.String())
m = &dns.Msg{}
m.SetQuestion("4.3.2.1.in-addr.arpa.", dns.TypePTR)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
m = &dns.Msg{}
m.SetQuestion("4.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.d.f.ip6.arpa.", dns.TypePTR)
ds.parseQuery(m)
assert.NotNil(t, m.Answer)
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
}
func Test_getDnsServerAddr(t *testing.T) {

View file

@ -52,6 +52,7 @@ lighthouse:
# The DNS host defines the IP to bind the dns listener to. This also allows binding to the nebula node IP.
#host: 0.0.0.0
#port: 53
#suffix: ".nebula"
# interval is the number of seconds between updates from this node to a lighthouse.
# during updates, a node sends information about its current IP addresses to each node.
interval: 60

View file

@ -606,7 +606,7 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
if f.serveDns {
remoteCert := hostinfo.ConnectionState.peerCert
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
dnsR.Add(remoteCert.Certificate, hostinfo.vpnAddrs)
}
for _, addr := range hostinfo.vpnAddrs {
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)

View file

@ -287,7 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
dnsStart = dnsMain(l, pki.getCertState(), c)
}
return &Control{