mirror of
https://github.com/slackhq/nebula.git
synced 2025-12-06 02:30:57 -08:00
Add dns suffix an reverse dns lookup
This commit is contained in:
parent
061e733007
commit
47ba7a5108
4 changed files with 85 additions and 100 deletions
155
dns_server.go
155
dns_server.go
|
|
@ -8,7 +8,6 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
|
|
@ -19,80 +18,77 @@ import (
|
|||
var dnsR *dnsRecords
|
||||
var dnsServer *dns.Server
|
||||
var dnsAddr string
|
||||
var dnsSuffix string
|
||||
|
||||
type dnsRecords struct {
|
||||
sync.RWMutex
|
||||
l *logrus.Logger
|
||||
dnsMap4 map[string]netip.Addr
|
||||
dnsMap6 map[string]netip.Addr
|
||||
hostMap *HostMap
|
||||
myVpnAddrsTable *bart.Lite
|
||||
l *logrus.Logger
|
||||
dnsMap map[dns.Question]dns.RR
|
||||
}
|
||||
|
||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
||||
func newDnsRecords(l *logrus.Logger) *dnsRecords {
|
||||
return &dnsRecords{
|
||||
l: l,
|
||||
dnsMap4: make(map[string]netip.Addr),
|
||||
dnsMap6: make(map[string]netip.Addr),
|
||||
hostMap: hostMap,
|
||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||
l: l,
|
||||
dnsMap: make(map[dns.Question]dns.RR),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
||||
data = strings.ToLower(data)
|
||||
d.RLock()
|
||||
defer d.RUnlock()
|
||||
switch q {
|
||||
case dns.TypeA:
|
||||
if r, ok := d.dnsMap4[data]; ok {
|
||||
return r
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if r, ok := d.dnsMap6[data]; ok {
|
||||
return r
|
||||
func (d *dnsRecords) AddA(name string, addr netip.Addr) {
|
||||
if addr.Is4() {
|
||||
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeA}
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
|
||||
if err == nil {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.dnsMap[q] = rr
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) QueryCert(data string) string {
|
||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||
if err != nil {
|
||||
return ""
|
||||
func (d *dnsRecords) AddAAAA(name string, addr netip.Addr) {
|
||||
if addr.Is6() {
|
||||
q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeAAAA}
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", name, qType, addr.String()))
|
||||
if err == nil {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.dnsMap[q] = rr
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hostinfo := d.hostMap.QueryVpnAddr(ip)
|
||||
if hostinfo == nil {
|
||||
return ""
|
||||
func (d *dnsRecords) AddPTR(name string, addr netip.Addr) {
|
||||
arpa, err := dns.ReverseAddr(addr.String())
|
||||
if err == nil {
|
||||
q := dns.Question{Name: arpa, Qclass: dns.ClassINET, Qtype: dns.TypePTR}
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", arpa, qType, name))
|
||||
if err == nil {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
d.dnsMap[q] = rr
|
||||
d.l.Debugf("DNS record added %s", rr.String())
|
||||
}
|
||||
}
|
||||
|
||||
q := hostinfo.GetCert()
|
||||
if q == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := q.Certificate.MarshalJSON()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
||||
host = strings.ToLower(host)
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
func (d *dnsRecords) Add(name string, addresses []netip.Addr) {
|
||||
host := dns.Fqdn(strings.ToLower(name + dnsSuffix))
|
||||
haveV4 := false
|
||||
haveV6 := false
|
||||
for _, addr := range addresses {
|
||||
if addr.Is4() && !haveV4 {
|
||||
d.dnsMap4[host] = addr
|
||||
d.AddA(host, addr)
|
||||
d.AddPTR(host, addr)
|
||||
haveV4 = true
|
||||
} else if addr.Is6() && !haveV6 {
|
||||
d.dnsMap6[host] = addr
|
||||
d.AddAAAA(host, addr)
|
||||
d.AddPTR(host, addr)
|
||||
haveV6 = true
|
||||
}
|
||||
if haveV4 && haveV6 {
|
||||
|
|
@ -101,47 +97,15 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
||||
a, _, _ := net.SplitHostPort(addr)
|
||||
b, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if b.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
//if we found it in this table, it's good
|
||||
return d.myVpnAddrsTable.Contains(b)
|
||||
}
|
||||
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||
func (d *dnsRecords) parseQuery(m *dns.Msg) {
|
||||
for _, q := range m.Question {
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
qType := dns.TypeToString[q.Qtype]
|
||||
d.l.Debugf("Query for %s %s", qType, q.Name)
|
||||
ip := d.Query(q.Qtype, q.Name)
|
||||
if ip.IsValid() {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
// We only answer these queries from nebula nodes or localhost
|
||||
if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
|
||||
return
|
||||
}
|
||||
d.l.Debugf("Query for TXT %s", q.Name)
|
||||
ip := d.QueryCert(q.Name)
|
||||
if ip != "" {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
|
||||
if err == nil {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypePTR:
|
||||
d.RLock()
|
||||
if rr, ok := d.dnsMap[q]; ok {
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
d.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -157,14 +121,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||
|
||||
switch r.Opcode {
|
||||
case dns.OpcodeQuery:
|
||||
d.parseQuery(m, w)
|
||||
d.parseQuery(m)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l, cs, hostMap)
|
||||
func dnsMain(l *logrus.Logger, cs *CertState, c *config.C) func() {
|
||||
dnsR = newDnsRecords(l)
|
||||
dnsSuffix = getDnsSuffix(c)
|
||||
|
||||
// Add self to dns records
|
||||
dnsR.Add(cs.GetDefaultCertificate().Name(), cs.myVpnAddrs)
|
||||
|
||||
// attach request handler func
|
||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
||||
|
|
@ -187,6 +155,11 @@ func getDnsServerAddr(c *config.C) string {
|
|||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||
}
|
||||
|
||||
func getDnsSuffix(c *config.C) string {
|
||||
suffix := strings.TrimSpace(c.GetString("lighthouse.dns.suffix", ""))
|
||||
return suffix
|
||||
}
|
||||
|
||||
func startDns(l *logrus.Logger, c *config.C) {
|
||||
dnsAddr = getDnsServerAddr(c)
|
||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
||||
|
|
|
|||
|
|
@ -12,27 +12,39 @@ import (
|
|||
|
||||
func TestParsequery(t *testing.T) {
|
||||
l := logrus.New()
|
||||
hostMap := &HostMap{}
|
||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
||||
ds := newDnsRecords(l)
|
||||
addrs := []netip.Addr{
|
||||
netip.MustParseAddr("1.2.3.4"),
|
||||
netip.MustParseAddr("1.2.3.5"),
|
||||
netip.MustParseAddr("fd01::24"),
|
||||
netip.MustParseAddr("fd01::25"),
|
||||
}
|
||||
ds.Add("test.com.com", addrs)
|
||||
dnsSuffix = ".com"
|
||||
ds.Add("test.com", addrs)
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeA)
|
||||
ds.parseQuery(m, nil)
|
||||
m.SetQuestion("test.com.com.", dns.TypeA)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("test.com.com", dns.TypeAAAA)
|
||||
ds.parseQuery(m, nil)
|
||||
m.SetQuestion("test.com.com.", dns.TypeAAAA)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("4.3.2.1.in-addr.arpa.", dns.TypePTR)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetQuestion("4.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.d.f.ip6.arpa.", dns.TypePTR)
|
||||
ds.parseQuery(m)
|
||||
assert.NotNil(t, m.Answer)
|
||||
assert.Equal(t, "test.com.com.", m.Answer[0].(*dns.PTR).Ptr)
|
||||
}
|
||||
|
||||
func Test_getDnsServerAddr(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -565,7 +565,7 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
|
|||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||
if f.serveDns {
|
||||
remoteCert := hostinfo.ConnectionState.peerCert
|
||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||
dnsR.Add(remoteCert.Certificate.Name(), hostinfo.vpnAddrs)
|
||||
}
|
||||
for _, addr := range hostinfo.vpnAddrs {
|
||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||
|
|
|
|||
2
main.go
2
main.go
|
|
@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
var dnsStart func()
|
||||
if lightHouse.amLighthouse && serveDns {
|
||||
l.Debugln("Starting dns server")
|
||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
||||
dnsStart = dnsMain(l, pki.getCertState(), c)
|
||||
}
|
||||
|
||||
return &Control{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue