mirror of
https://github.com/slackhq/nebula.git
synced 2025-12-06 02:30:57 -08:00
Merge 1e82d35b00 into 59e24b98bd
This commit is contained in:
commit
d7cfa522ad
5 changed files with 117 additions and 113 deletions
194
dns_server.go
194
dns_server.go
|
|
@ -8,9 +8,9 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
)
|
||||
|
||||
|
|
@ -19,129 +19,106 @@ import (
|
|||
var dnsR *dnsRecords
|
||||
var dnsServer *dns.Server
|
||||
var dnsAddr string
|
||||
var dnsSuffix string
|
||||
|
||||
type dnsRecords struct {
|
||||
sync.RWMutex
|
||||
l *logrus.Logger
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Lite
|
||||
l *logrus.Logger
|
||||
dnsMap map[dns.Question][]dns.RR
|
||||
}
|
||||
|
||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||
func newDnsRecords(l *logrus.Logger) *dnsRecords {
|
||||
return &dnsRecords{
|
||||
l: l,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
l: l,
|
||||
dnsMap: make(map[dns.Question][]dns.RR),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
switch q {
|
||||
case dns.TypeA:
|
||||
if r, ok := d.dnsMap4[data]; ok {
|
||||
return r
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if r, ok := d.dnsMap6[data]; ok {
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
q := hostinfo.GetCert()
|
||||
if q == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := q.Certificate.MarshalJSON()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
||||
host = strings.ToLower(host)
|
||||
func (d *dnsRecords) AddA(name string, addresses []netip.Addr) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
haveV4 := false
|
||||
haveV6 := false
|
||||
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeA}
|
||||
d.dnsMap[q] = nil
|
||||
|
||||
for _, addr := range addresses {
|
||||
if addr.Is4() && !haveV4 {
|
||||
d.dnsMap4[host] = addr
|
||||
haveV4 = true
|
||||
} else if addr.Is6() && !haveV6 {
|
||||
d.dnsMap6[host] = addr
|
||||
haveV6 = true
|
||||
}
|
||||
if haveV4 && haveV6 {
|
||||
break
|
||||
if addr.Is4() {
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
|
||||
if err == nil {
|
||||
d.dnsMap[q] = append(d.dnsMap[q], rr)
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
a, _, _ := net.SplitHostPort(addr)
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
func (d *dnsRecords) AddAAAA(name string, addresses []netip.Addr) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeAAAA}
|
||||
d.dnsMap[q] = nil
|
||||
|
||||
if b.IsLoopback() {
|
||||
return true
|
||||
for _, addr := range addresses {
|
||||
if addr.Is6() {
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
|
||||
if err == nil {
|
||||
d.dnsMap[q] = append(d.dnsMap[q], rr)
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//if we found it in this table, it's good
|
||||
return d.myVpnAddrsTable.Contains(b)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
func (d *dnsRecords) AddPTR(name string, addresses []netip.Addr) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
|
||||
for _, addr := range addresses {
|
||||
arpa, err := dns.ReverseAddr(addr.String())
|
||||
if err == nil {
|
||||
q := dns.Question{Name: arpa, Qclass: dns.ClassINET, Qtype: dns.TypePTR}
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", arpa, qType, name))
|
||||
if err == nil {
|
||||
d.dnsMap[q] = []dns.RR{rr}
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) AddTXT(name string, crt cert.Certificate) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeTXT}
|
||||
d.dnsMap[q] = nil
|
||||
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s \"Name: %v\" \"Networks: %v\" \"Groups: %v\" \"UnsafeNetworks: %v\"", name, qType, crt.Name(), crt.Networks(), crt.Groups(), crt.UnsafeNetworks()))
|
||||
if err == nil {
|
||||
d.dnsMap[q] = []dns.RR{rr}
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Add(crt cert.Certificate, addresses []netip.Addr) {
|
||||
host := dns.Fqdn(strings.ToLower(crt.Name() + dnsSuffix))
|
||||
d.AddA(host, addresses)
|
||||
d.AddAAAA(host, addresses)
|
||||
d.AddPTR(host, addresses)
|
||||
d.AddTXT(host, crt)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg) {
|
||||
for _, q := range m.Question {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
// We only answer these queries from nebula nodes or localhost
|
||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||
return
|
||||
}
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR, dns.TypeTXT:
|
||||
d.RLock()
|
||||
if rr, ok := d.dnsMap[q]; ok {
|
||||
m.Answer = append(m.Answer, rr...)
|
||||
}
|
||||
d.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -157,14 +134,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
d.parseQuery(m, w)
|
||||
d.parseQuery(m)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l)
|
||||
dnsSuffix = getDnsSuffix(c)
|
||||
|
||||
// Add self to dns records
|
||||
dnsR.Add(cs.GetDefaultCertificate(), cs.myVpnAddrs)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
|
@ -187,6 +168,11 @@ func getDnsServerAddr(c *config.C) string {
|
|||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||
}
|
||||
|
||||
func getDnsSuffix(c *config.C) string {
|
||||
suffix := strings.TrimSpace(c.GetString("lighthouse.dns.suffix", ""))
|
||||
return suffix
|
||||
}
|
||||
|
||||
func startDns(l *logrus.Logger, c *config.C) {
|
||||
dnsAddr = getDnsServerAddr(c)
|
||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||
|
|
|
|||
|
|
@ -12,27 +12,44 @@ import (
|
|||
|
||||
func TestParsequery(t *testing.T) {
|
||||
l := logrus.New()
|
||||
hostMap := &HostMap{}
|
||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
||||
ds := newDnsRecords(l)
|
||||
addrs := []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
netip.MustParseAddr("1.2.3.5"),
|
||||
netip.MustParseAddr("fd01::24"),
|
||||
netip.MustParseAddr("fd01::25"),
|
||||
}
|
||||
ds.Add("test.com.com", addrs)
|
||||
dnsSuffix = ".com"
|
||||
crt := &dummyCert{
|
||||
name: "test.com",
|
||||
}
|
||||
ds.Add(crt, addrs)
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeA)
|
||||
ds.parseQuery(m, nil)
|
||||
m.SetQuestion("test.com.com.", dns.TypeA)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||
assert.Equal(t, "1.2.3.5", m.Answer[1].(*dns.A).A.String())
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||
ds.parseQuery(m, nil)
|
||||
m.SetQuestion("test.com.com.", dns.TypeAAAA)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||
assert.Equal(t, "fd01::25", m.Answer[1].(*dns.AAAA).AAAA.String())
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("4.3.2.1.in-addr.arpa.", dns.TypePTR)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("4.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.d.f.ip6.arpa.", dns.TypePTR)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
|
||||
func Test_getDnsServerAddr(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ lighthouse:
|
|||
# The DNS host defines the IP to bind the dns listener to. This also allows binding to the nebula node IP.
|
||||
#host: 0.0.0.0
|
||||
#port: 53
|
||||
#suffix: ".nebula"
|
||||
# interval is the number of seconds between updates from this node to a lighthouse.
|
||||
# during updates, a node sends information about its current IP addresses to each node.
|
||||
interval: 60
|
||||
|
|
|
|||
|
|
@ -606,7 +606,7 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
|
|||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
if f.serveDns {
|
||||
remoteCert := hostinfo.ConnectionState.peerCert
|
||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||
dnsR.Add(remoteCert.Certificate, hostinfo.vpnAddrs)
|
||||
}
|
||||
for _, addr := range hostinfo.vpnAddrs {
|
||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||
|
|
|
|||
2
main.go
2
main.go
|
|
@ -287,7 +287,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
var dnsStart func()
|
||||
if lightHouse.amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||
dnsStart = dnsMain(l, pki.getCertState(), c)
|
||||
}
|
||||
|
||||
return &Control{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue