Rough draft of what reworking this might look like.

This commit is contained in:
Henry Graham 2025-10-06 19:59:22 -05:00
parent b7726b8a70
commit 49eeee7f8c

148
info.go
View file

@ -2,17 +2,29 @@ package nebula
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"log"
"net"
"net/netip"
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
)
func handleHostmapList(l *logrus.Logger, hm *HostMap, w http.ResponseWriter, r *http.Request) {
// TODO Firm up how we return errors and accept messages + data
type Message struct {
Command string `json:"command"`
Data string `json:"data"`
}
type ErrorResponse struct {
Error string `json:"error"`
}
func handleHostmapList(l *logrus.Logger, hm *HostMap) ([]byte, error) {
type HostListItem struct {
VpnAddrs []netip.Addr `json:"vpnAddrs"`
//Remote netip.AddrPort `json:"remote"`
@ -20,7 +32,6 @@ func handleHostmapList(l *logrus.Logger, hm *HostMap, w http.ResponseWriter, r *
LastHandshakeTime time.Time `json:"lastHandshakeTime"`
Groups []string `json:"groups"`
}
out := map[string]HostListItem{}
hm.ForEachVpnAddr(func(hi *HostInfo) {
cert := hi.GetCert().Certificate
@ -32,94 +43,111 @@ func handleHostmapList(l *logrus.Logger, hm *HostMap, w http.ResponseWriter, r *
Groups: cert.Groups(),
}
})
w.Header().Set("Content-Type", "application/json")
js := json.NewEncoder(w)
err := js.Encode(out)
js, err := json.Marshal(out)
if err != nil {
http.Error(w, "json error: "+err.Error(), http.StatusInternalServerError)
return
return nil, fmt.Errorf("json error: %w", err)
}
return js, nil
}
func handleHostCertLookup(l *logrus.Logger, hm *HostMap, w http.ResponseWriter, r *http.Request) {
ipStr := r.PathValue("ipStr")
func handleHostCertLookup(l *logrus.Logger, hm *HostMap, msg *Message) ([]byte, error) {
ipStr := msg.Data //TODO how do we want to structure this? What if we expand to more ssh commands?
if ipStr == "" {
http.Error(w, "you must provide an IP address", http.StatusNotFound)
return
return nil, fmt.Errorf("you must provide an IP address")
}
addr, err := netip.ParseAddr(ipStr)
if err != nil {
//todo filter non-Nebula IPs?
http.Error(w, fmt.Sprintf("Invalid IP address: %s", ipStr), http.StatusBadRequest)
return
return nil, fmt.Errorf("invalid IP address: %s", ipStr)
}
hi := hm.QueryVpnAddr(addr)
if hi == nil {
http.Error(w, "IP address not found", http.StatusNotFound)
return
return nil, fmt.Errorf("ip address not found: %s", ipStr)
} else if hi.ConnectionState == nil {
http.Error(w, "Host not connected", http.StatusNotFound)
return
return nil, fmt.Errorf("host not connected: %s", ipStr)
}
out, err := hi.ConnectionState.peerCert.Certificate.MarshalJSON()
if err != nil {
l.WithError(err).Error("failed to marshal peer certificate")
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
return nil, fmt.Errorf("failed to marshal peer certificate: %w", err)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(out)
return out, nil
}
func setupInfoServer(l *logrus.Logger, hm *HostMap) *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("GET /hostmap", func(w http.ResponseWriter, r *http.Request) { handleHostmapList(l, hm, w, r) })
mux.HandleFunc("GET /host/{ipStr}", func(w http.ResponseWriter, r *http.Request) { handleHostCertLookup(l, hm, w, r) })
return mux
}
// startInfo stands up a REST API that serves information about what Nebula is doing to other services
// Right now, this is just hostmap info,
func startInfo(l *logrus.Logger, c *config.C, configTest bool, hm *HostMap) (func(), error) {
listen := c.GetString("info.listen", "")
if listen == "" {
return nil, nil
}
addrPort, err := netip.ParseAddrPort(listen)
if err != nil {
return nil, fmt.Errorf("failed to parse info.listen address: %w", err)
}
if err = shouldAllowBinding(addrPort.Addr()); err != nil {
l.WithError(err).Warn("Specified info.listen address is not private") // TODO phrasing, what if we add non-nebula-ip check?
}
listenAddr := c.GetString("info.listen", "")
var startFn func()
if configTest {
//TODO validate that lisstenAddr is an acceptable value as part of the config test
return startFn, nil
}
if err := os.RemoveAll(listenAddr); err != nil {
l.WithError(err).Fatal("failed to remove unix socket")
}
startFn = func() {
mux := setupInfoServer(l, hm)
l.WithField("bind", listen).Info("Info listener starting")
err := http.ListenAndServe(listen, mux)
if errors.Is(err, http.ErrServerClosed) {
return
}
listener, err := net.Listen("unix", listenAddr)
if err != nil {
l.Fatal(err)
log.Fatalf("Failed to listen on unix socket: %v", err)
}
defer listener.Close()
defer os.Remove(listenAddr)
l.WithField("bind", listenAddr).Info("Info listener starting")
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Failed to accept connection: %v", err)
continue
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 1024) // Arbitrary
n, err := c.Read(buf)
if err != nil {
l.WithError(err).Error("Failed to read from connection")
return
}
var msg Message
if err := json.Unmarshal(buf[:n], &msg); err != nil {
l.WithError(err).Error("Failed to unmarshal JSON")
return
}
l.WithField("command", msg.Command).WithField("Data", msg.Data).Debug("Received Command")
err = handleCommand(l, c, hm, &msg)
if err != nil {
l.WithError(err).Error("Failed to handle command")
out, err := json.Marshal(ErrorResponse{Error: err.Error()})
if err != nil {
l.WithError(err).Error("Failed to marshal error response")
return
}
c.Write(out)
return
}
}(conn)
}
}
return startFn, nil
}
// https://github.com/slackhq/nebula/pull/1457#issuecomment-3275781278
// > Refusing to bind to (non-localhost || non-nebula-ip) feels right to me
// If in the future we want to check for a non-nebula-ip we can add that check in here
func shouldAllowBinding(listen netip.Addr) error {
if !listen.IsLoopback() {
return fmt.Errorf("info.listen is not a loopback address: %s", listen.String())
// maybe we can add more of the supported SSH commands here?
func handleCommand(l *logrus.Logger, c net.Conn, hm *HostMap, msg *Message) error {
switch msg.Command {
case "ping": // TODO remove test command
c.Write([]byte("pong\n"))
case "hostmap":
out, err := handleHostmapList(l, hm)
if err != nil {
return err
}
c.Write(out)
case "hostinfo":
out, err := handleHostCertLookup(l, hm, msg)
if err != nil {
return err
}
c.Write(out)
default:
c.Write([]byte("unknown command\n"))
}
return nil
}