From b1fda24567ce12e7b568900cbfd855ea0a766bd8 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 23:25:28 -0500 Subject: [PATCH] Add a way to set the network type on windows + tests --- .github/workflows/smoke-extra.yml | 23 ++ .github/workflows/smoke/smoke-windows.ps1 | 126 ++++++++ examples/config.yml | 10 + overlay/network_category_windows.go | 358 ++++++++++++++++++++++ overlay/network_category_windows_test.go | 109 +++++++ overlay/tun_windows.go | 38 ++- 6 files changed, 654 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/smoke/smoke-windows.ps1 create mode 100644 overlay/network_category_windows.go create mode 100644 overlay/network_category_windows_test.go diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index cca7678b..e33a5fa9 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -81,3 +81,26 @@ jobs: run: make smoke-vagrant/linux-386 timeout-minutes: 30 + + smoke-windows: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: Run windows smoke test + runs-on: windows-latest + steps: + + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + - name: build + run: make bin-windows + + - name: run smoke-windows + shell: pwsh + working-directory: ./.github/workflows/smoke + run: ./smoke-windows.ps1 + + timeout-minutes: 10 diff --git a/.github/workflows/smoke/smoke-windows.ps1 b/.github/workflows/smoke/smoke-windows.ps1 new file mode 100644 index 00000000..b1216c78 --- /dev/null +++ b/.github/workflows/smoke/smoke-windows.ps1 @@ -0,0 +1,126 @@ +#!/usr/bin/env pwsh +$ErrorActionPreference = 'Stop' + +$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.." +$Nebula = Join-Path $RepoRoot 'nebula.exe' +$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe' + +if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" } +if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" } + +$WorkDir = Join-Path $env:TEMP "nebula-smoke-windows" +if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir } +New-Item -ItemType Directory -Path $WorkDir | Out-Null + +$DevName = "nebula-smoke" +$CaCrt = Join-Path $WorkDir 'ca.crt' +$CaKey = Join-Path $WorkDir 'ca.key' +$HostCrt = Join-Path $WorkDir 'host.crt' +$HostKey = Join-Path $WorkDir 'host.key' + +& $NebulaCert ca -name "smoke-ca" -out-crt $CaCrt -out-key $CaKey +if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" } + +& $NebulaCert sign -name "smoke" -networks "192.168.241.1/24" -ca-crt $CaCrt -ca-key $CaKey -out-crt $HostCrt -out-key $HostKey +if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign failed (exit $LASTEXITCODE)" } + +function Write-Config { + param([string]$Category) + $cfg = Join-Path $WorkDir 'config.yml' + @" +pki: + ca: $CaCrt + cert: $HostCrt + key: $HostKey +static_host_map: {} +lighthouse: + am_lighthouse: true + interval: 60 + hosts: [] +listen: + host: 0.0.0.0 + port: 4242 +tun: + disabled: false + dev: $DevName + drop_local_broadcast: false + drop_multicast: false + tx_queue: 500 + mtu: 1300 + network_category: $Category +logging: + level: info + format: text +firewall: + outbound_action: drop + inbound_action: drop + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +"@ | Out-File -FilePath $cfg -Encoding utf8 + return $cfg +} + +function Test-Category { + param( + [Parameter(Mandatory)] [string]$ConfigValue, + [Parameter(Mandatory)] [string]$ExpectedCategory + ) + Write-Host "" + Write-Host "=== smoke: network_category=$ConfigValue (expecting $ExpectedCategory) ===" + + $cfg = Write-Config -Category $ConfigValue + $stdoutLog = Join-Path $WorkDir "nebula-$ConfigValue.out.log" + $stderrLog = Join-Path $WorkDir "nebula-$ConfigValue.err.log" + + $proc = Start-Process -FilePath $Nebula -ArgumentList @('-config', $cfg) ` + -PassThru -NoNewWindow ` + -RedirectStandardOutput $stdoutLog ` + -RedirectStandardError $stderrLog + + try { + $deadline = (Get-Date).AddSeconds(30) + $observed = $null + while ((Get-Date) -lt $deadline) { + if ($proc.HasExited) { + Get-Content $stdoutLog -ErrorAction SilentlyContinue | Out-Host + Get-Content $stderrLog -ErrorAction SilentlyContinue | Out-Host + throw "nebula exited prematurely (code $($proc.ExitCode))" + } + $netProfile = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue + if ($netProfile) { + $observed = "$($netProfile.NetworkCategory)" + if ($observed -ieq $ExpectedCategory) { + Write-Host "OK: $DevName NetworkCategory=$observed" + return + } + } + Start-Sleep -Milliseconds 500 + } + + Get-Content $stdoutLog -ErrorAction SilentlyContinue | Out-Host + Get-Content $stderrLog -ErrorAction SilentlyContinue | Out-Host + throw "expected NetworkCategory=$ExpectedCategory, observed='$observed' within 30s" + } + finally { + if (-not $proc.HasExited) { + Stop-Process -Id $proc.Id -Force -ErrorAction SilentlyContinue + $proc.WaitForExit(5000) | Out-Null + } + } +} + +Test-Category -ConfigValue 'private' -ExpectedCategory 'Private' +Test-Category -ConfigValue 'public' -ExpectedCategory 'Public' + +Write-Host "" +Write-Host "All smoke checks passed." diff --git a/examples/config.yml b/examples/config.yml index ac4810e6..364506c5 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -286,6 +286,16 @@ tun: # metric: 100 # install: true + # On Windows only, sets the network category of the nebula interface. Without this, Windows often + # leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more + # restrictive than you usually want for an overlay between trusted peers. Valid values: + # private - treat the nebula network as a private/trusted network (default) + # public - treat it as a public/untrusted network + # domain - treat it as a domain-authenticated network + # unset - leave whatever Windows decided alone + # Not reloadable. + #network_category: private + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false diff --git a/overlay/network_category_windows.go b/overlay/network_category_windows.go new file mode 100644 index 00000000..cbf87f00 --- /dev/null +++ b/overlay/network_category_windows.go @@ -0,0 +1,358 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "errors" + "fmt" + "log/slog" + "runtime" + "strings" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h. +type networkCategory int32 + +const ( + networkCategoryPublic networkCategory = 0 + networkCategoryPrivate networkCategory = 1 + networkCategoryDomainAuthenticated networkCategory = 2 +) + +func (c networkCategory) String() string { + switch c { + case networkCategoryPublic: + return "public" + case networkCategoryPrivate: + return "private" + case networkCategoryDomainAuthenticated: + return "domain" + } + return fmt.Sprintf("unknown(%d)", c) +} + +// parseNetworkCategory accepts the user-supplied tun.network_category. A +// second return of false means "leave the category alone". +func parseNetworkCategory(s string) (networkCategory, bool, error) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "", "unset": + return 0, false, nil + case "public": + return networkCategoryPublic, true, nil + case "private": + return networkCategoryPrivate, true, nil + case "domain", "domainauthenticated": + return networkCategoryDomainAuthenticated, true, nil + } + return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s) +} + +// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B} +var clsidNetworkListManager = windows.GUID{ + Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B} +var iidINetworkListManager = windows.GUID{ + Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves. +var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance") + +const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER | + windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER + +const ( + hrSFALSE = 0x00000001 + hrRPCEChangedMode = 0x80010106 +) + +type hresult uint32 + +func (h hresult) failed() bool { return int32(h) < 0 } +func (h hresult) String() string { + return fmt.Sprintf("HRESULT 0x%08x", uint32(h)) +} + +var errAdapterNotFound = errors.New("adapter not present in network connections enumeration") + +// Vtable layouts. Slot order must match the declaration order in netlistmgr.h. +// All NLM interfaces here derive from IDispatch, which derives from IUnknown. + +type iUnknownVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr +} + +type iDispatchVtbl struct { + iUnknownVtbl + GetTypeInfoCount uintptr + GetTypeInfo uintptr + GetIDsOfNames uintptr + Invoke uintptr +} + +type iNetworkListManagerVtbl struct { + iDispatchVtbl + GetNetworks uintptr + GetNetwork uintptr + GetNetworkConnections uintptr + GetNetworkConnection uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr +} + +type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl } + +func (n *iNetworkListManager) Release() { + syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n))) +} + +func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) { + var enum *iEnumNetworkConnections + r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections, + uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr) + } + return enum, nil +} + +type iEnumNetworkConnectionsVtbl struct { + iDispatchVtbl + NewEnum uintptr + Next uintptr + Skip uintptr + Reset uintptr + Clone uintptr +} + +type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl } + +func (e *iEnumNetworkConnections) Release() { + syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e))) +} + +// Next returns the next connection, or (nil, nil) at the end of the enumeration. +func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) { + var conn *iNetworkConnection + var fetched uint32 + r1, _, _ := syscall.SyscallN(e.Vtbl.Next, + uintptr(unsafe.Pointer(e)), 1, + uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr) + } + if fetched == 0 { + return nil, nil + } + return conn, nil +} + +type iNetworkConnectionVtbl struct { + iDispatchVtbl + GetNetwork uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetConnectionId uintptr + GetAdapterId uintptr + GetDomainType uintptr +} + +type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl } + +func (c *iNetworkConnection) Release() { + syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c))) +} + +func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) { + var g windows.GUID + r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)), + ) + if hr := hresult(r1); hr.failed() { + return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr) + } + return g, nil +} + +func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) { + var net *iNetwork + r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr) + } + return net, nil +} + +type iNetworkVtbl struct { + iDispatchVtbl + GetName uintptr + SetName uintptr + GetDescription uintptr + SetDescription uintptr + GetNetworkId uintptr + GetDomainType uintptr + GetNetworkConnections uintptr + GetTimeCreatedAndConnected uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetCategory uintptr + SetCategory uintptr +} + +type iNetwork struct{ Vtbl *iNetworkVtbl } + +func (n *iNetwork) Release() { + syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n))) +} + +func (n *iNetwork) GetCategory() (networkCategory, error) { + var c networkCategory + r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory, + uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)), + ) + if hr := hresult(r1); hr.failed() { + return 0, fmt.Errorf("INetwork.GetCategory: %s", hr) + } + return c, nil +} + +func (n *iNetwork) SetCategory(c networkCategory) error { + r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory, + uintptr(unsafe.Pointer(n)), uintptr(int32(c)), + ) + if hr := hresult(r1); hr.failed() { + return fmt.Errorf("INetwork.SetCategory: %s", hr) + } + return nil +} + +// coInit initializes COM for the current OS thread. The returned function must +// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is +// already initialized in a different mode on this thread, which is still fine +// for our calls but we must not Uninitialize in that case. +func coInit() (func(), error) { + err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED) + if err == nil { + return windows.CoUninitialize, nil + } + if e, ok := err.(syscall.Errno); ok { + switch uint32(e) { + case hrSFALSE: + return windows.CoUninitialize, nil + case hrRPCEChangedMode: + return func() {}, nil + } + } + return nil, fmt.Errorf("CoInitializeEx: %w", err) +} + +func createNetworkListManager() (*iNetworkListManager, error) { + var nlm *iNetworkListManager + r1, _, _ := procCoCreateInstance.Call( + uintptr(unsafe.Pointer(&clsidNetworkListManager)), + 0, + uintptr(clsCtxAll), + uintptr(unsafe.Pointer(&iidINetworkListManager)), + uintptr(unsafe.Pointer(&nlm)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr) + } + return nlm, nil +} + +// setNetworkCategory locates the network connection bound to adapterGUID and +// sets the category of its parent network. Returns errAdapterNotFound if the +// adapter is not yet visible in the NLM enumeration. +func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error { + deinit, err := coInit() + if err != nil { + return err + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + return err + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + return err + } + defer enum.Release() + + for { + conn, err := enum.Next() + if err != nil { + return err + } + if conn == nil { + return errAdapterNotFound + } + + guid, err := conn.GetAdapterId() + if err != nil || guid != adapterGUID { + conn.Release() + continue + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + return err + } + err = net.SetCategory(cat) + net.Release() + return err + } +} + +// applyNetworkCategory polls until the wintun adapter shows up in the NLM +// enumeration, then sets the category. Intended to run in its own goroutine. +func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) { + // COM Init/Uninit must be paired on the same OS thread. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + const ( + attempts = 30 + interval = 500 * time.Millisecond + ) + for i := 0; i < attempts; i++ { + err := setNetworkCategory(adapterGUID, cat) + if err == nil { + l.Info("Set Windows network category", "category", cat.String()) + return + } + if !errors.Is(err, errAdapterNotFound) { + l.Warn("Failed to set Windows network category", "error", err, "category", cat.String()) + return + } + time.Sleep(interval) + } + l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set", + "category", cat.String(), + "waited", time.Duration(attempts)*interval, + ) +} diff --git a/overlay/network_category_windows_test.go b/overlay/network_category_windows_test.go new file mode 100644 index 00000000..c679f8c4 --- /dev/null +++ b/overlay/network_category_windows_test.go @@ -0,0 +1,109 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "testing" +) + +func Test_parseNetworkCategory(t *testing.T) { + cases := []struct { + in string + wantCat networkCategory + wantApply bool + wantErr bool + }{ + {"", 0, false, false}, + {"unset", 0, false, false}, + {" UNSET ", 0, false, false}, + {"private", networkCategoryPrivate, true, false}, + {"Private", networkCategoryPrivate, true, false}, + {" PRIVATE ", networkCategoryPrivate, true, false}, + {"public", networkCategoryPublic, true, false}, + {"PUBLIC", networkCategoryPublic, true, false}, + {"domain", networkCategoryDomainAuthenticated, true, false}, + {"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false}, + {"garbage", 0, false, true}, + {"privates", 0, false, true}, + } + for _, tc := range cases { + cat, apply, err := parseNetworkCategory(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr) + continue + } + if cat != tc.wantCat || apply != tc.wantApply { + t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply) + } + } +} + +// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory +// without mutating the host's network state. It validates the CLSID/IID +// constants and every vtable index by enumerating connections, fetching the +// adapter id and parent network, reading the current category, and writing it +// back unchanged. +// +// Requires Windows but does not require admin or the wintun driver. Skips if +// no network connections are available (unlikely outside of an isolated +// container). +func Test_NLM_round_trip(t *testing.T) { + deinit, err := coInit() + if err != nil { + t.Fatalf("coInit: %v", err) + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + t.Fatalf("createNetworkListManager: %v", err) + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + t.Fatalf("GetNetworkConnections: %v", err) + } + defer enum.Release() + + saw := 0 + for { + conn, err := enum.Next() + if err != nil { + t.Fatalf("EnumNetworkConnections.Next: %v", err) + } + if conn == nil { + break + } + saw++ + + if _, err := conn.GetAdapterId(); err != nil { + conn.Release() + t.Fatalf("INetworkConnection.GetAdapterId: %v", err) + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + t.Fatalf("INetworkConnection.GetNetwork: %v", err) + } + + cat, err := net.GetCategory() + if err != nil { + net.Release() + t.Fatalf("INetwork.GetCategory: %v", err) + } + // Set to the current value so the host's NLM state is unchanged but + // SetCategory's vtable slot is still validated end-to-end. + if err := net.SetCategory(cat); err != nil { + net.Release() + t.Fatalf("INetwork.SetCategory(%v): %v", cat, err) + } + net.Release() + } + + if saw == 0 { + t.Skip("no NLM network connections available; skipping round-trip") + } +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 14c8d499..703ed48e 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -28,12 +28,15 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *slog.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + guid windows.GUID + networkCategory networkCategory + setCategory bool + l *slog.Logger tun *wintun.NativeTun } @@ -54,11 +57,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w return nil, fmt.Errorf("generate GUID failed: %w", err) } + cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private")) + if err != nil { + return nil, err + } + t := &winTun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + guid: *guid, + networkCategory: cat, + setCategory: setCat, + l: l, } err = t.reload(c, true) @@ -142,6 +153,13 @@ func (t *winTun) Activate() error { return err } + if t.setCategory { + // The wintun adapter takes a moment to register with the Network List + // Manager, so we apply the category in the background and retry until + // it shows up. + go applyNetworkCategory(t.l, t.guid, t.networkCategory) + } + return nil }