diff --git a/dns_server.go b/dns_server.go index ce146f52..1a6701be 100644 --- a/dns_server.go +++ b/dns_server.go @@ -34,9 +34,7 @@ func newDnsRecords(l *logrus.Logger) *dnsRecords { } } -func (d *dnsRecords) AddA(name string, addresses []netip.Addr) { - d.Lock() - defer d.Unlock() +func (d *dnsRecords) addA(name string, addresses []netip.Addr) { q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeA} d.dnsMap[q] = nil @@ -50,11 +48,9 @@ func (d *dnsRecords) AddA(name string, addresses []netip.Addr) { } } } -} +} -func (d *dnsRecords) AddAAAA(name string, addresses []netip.Addr) { - d.Lock() - defer d.Unlock() +func (d *dnsRecords) addAAAA(name string, addresses []netip.Addr) { q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeAAAA} d.dnsMap[q] = nil @@ -68,12 +64,9 @@ func (d *dnsRecords) AddAAAA(name string, addresses []netip.Addr) { } } } -} - -func (d *dnsRecords) AddPTR(name string, addresses []netip.Addr) { - d.Lock() - defer d.Unlock() +} +func (d *dnsRecords) addPTR(name string, addresses []netip.Addr) { for _, addr := range addresses { arpa, err := dns.ReverseAddr(addr.String()) if err == nil { @@ -86,11 +79,9 @@ func (d *dnsRecords) AddPTR(name string, addresses []netip.Addr) { } } } -} +} -func (d *dnsRecords) AddTXT(name string, crt cert.Certificate) { - d.Lock() - defer d.Unlock() +func (d *dnsRecords) addTXT(name string, crt cert.Certificate) { q := dns.Question{Name: name, Qclass: dns.ClassINET, Qtype: dns.TypeTXT} d.dnsMap[q] = nil @@ -100,14 +91,18 @@ func (d *dnsRecords) AddTXT(name string, crt cert.Certificate) { d.dnsMap[q] = []dns.RR{rr} d.l.Debugf("DNS record added %s", rr.String()) } -} +} func (d *dnsRecords) Add(crt cert.Certificate, addresses []netip.Addr) { host := dns.Fqdn(strings.ToLower(crt.Name() + dnsSuffix)) - d.AddA(host, addresses) - d.AddAAAA(host, addresses) - d.AddPTR(host, addresses) - d.AddTXT(host, crt) + + d.Lock() + defer d.Unlock() + + d.addA(host, addresses) + d.addAAAA(host, addresses) + d.addPTR(host, addresses) + d.addTXT(host, crt) } func (d *dnsRecords) parseQuery(m *dns.Msg) {