diff --git a/firewall.go b/firewall.go index dddc1a0e..41a9e6e4 100644 --- a/firewall.go +++ b/firewall.go @@ -23,6 +23,8 @@ import ( "github.com/slackhq/nebula/firewall" ) +var ErrCannotSNAT = errors.New("cannot snat this packet") + type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } @@ -463,7 +465,7 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr return c.snat.SrcVpnIp } -func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { +func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error { if c.snat.Valid() { //old flow fp.RemoteAddr = f.snatAddr @@ -475,6 +477,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo fp.RemoteAddr = f.snatAddr //find a new port to use, if needed + numPortsChecked := 0 for { existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now if existingFlow == nil { @@ -482,18 +485,18 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo } //increment and retry. There's probably better strategies out there fp.RemotePort++ + numPortsChecked++ if fp.RemotePort < 0x7ff { fp.RemotePort += 0x7ff // keep it ephemeral for now - } //todo if we're totally out of ports this loops forever. Probably not good. + } //without this, if we're totally out of ports, this would loop forever. + if numPortsChecked >= 0x7ff { + return ErrCannotSNAT + } } + f.dupeConn(*fp, c) //track the snatted flow with the same expiration as the unsnatted version c.snat.SnatPort = fp.RemotePort } else { - f.l.WithFields(logrus.Fields{ - "fp": *fp, - "conn": *c, - "hostinfo": hostinfo, - }).Error("this packet cannot be snatted") - return + return ErrCannotSNAT } newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) @@ -515,7 +518,10 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo //src port is at offset 0 binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) recalcTCPv4Checksum(data, c.snat.Src, newIP) + default: + return ErrCannotSNAT } + return nil } // Drop returns an error if the packet should be dropped, explaining why. It @@ -547,7 +553,7 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn if h.vpnAddrs[0] != fp.RemoteAddr { if !specialSnatMode { f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP //todo! + return ErrInvalidRemoteIP } //else we're in special snat mode, and we need to apply more checks below } //else? all good, fall through } else { @@ -594,8 +600,10 @@ snat: if incoming { if specialSnatMode { //todo do not snat if you are not a router for the destination -- for now, just if you're not a router - f.applySnat(pkt, &fp, c, h) - f.dupeConn(fp, c) //track the snatted flow with the same expiration as the unsnatted version + err := f.applySnat(pkt, &fp, c, h) + if err != nil { + return err + } } } //outgoing snat is handled before this function is called (for now!)