mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-10 14:12:43 -07:00
Add a way to set the network type on windows + tests
This commit is contained in:
parent
c82db210ef
commit
b1fda24567
6 changed files with 654 additions and 10 deletions
23
.github/workflows/smoke-extra.yml
vendored
23
.github/workflows/smoke-extra.yml
vendored
|
|
@ -81,3 +81,26 @@ jobs:
|
|||
run: make smoke-vagrant/linux-386
|
||||
|
||||
timeout-minutes: 30
|
||||
|
||||
smoke-windows:
|
||||
if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
|
||||
name: Run windows smoke test
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.25'
|
||||
check-latest: true
|
||||
|
||||
- name: build
|
||||
run: make bin-windows
|
||||
|
||||
- name: run smoke-windows
|
||||
shell: pwsh
|
||||
working-directory: ./.github/workflows/smoke
|
||||
run: ./smoke-windows.ps1
|
||||
|
||||
timeout-minutes: 10
|
||||
|
|
|
|||
126
.github/workflows/smoke/smoke-windows.ps1
vendored
Normal file
126
.github/workflows/smoke/smoke-windows.ps1
vendored
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
#!/usr/bin/env pwsh
|
||||
$ErrorActionPreference = 'Stop'
|
||||
|
||||
$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.."
|
||||
$Nebula = Join-Path $RepoRoot 'nebula.exe'
|
||||
$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe'
|
||||
|
||||
if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" }
|
||||
if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" }
|
||||
|
||||
$WorkDir = Join-Path $env:TEMP "nebula-smoke-windows"
|
||||
if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir }
|
||||
New-Item -ItemType Directory -Path $WorkDir | Out-Null
|
||||
|
||||
$DevName = "nebula-smoke"
|
||||
$CaCrt = Join-Path $WorkDir 'ca.crt'
|
||||
$CaKey = Join-Path $WorkDir 'ca.key'
|
||||
$HostCrt = Join-Path $WorkDir 'host.crt'
|
||||
$HostKey = Join-Path $WorkDir 'host.key'
|
||||
|
||||
& $NebulaCert ca -name "smoke-ca" -out-crt $CaCrt -out-key $CaKey
|
||||
if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" }
|
||||
|
||||
& $NebulaCert sign -name "smoke" -networks "192.168.241.1/24" -ca-crt $CaCrt -ca-key $CaKey -out-crt $HostCrt -out-key $HostKey
|
||||
if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign failed (exit $LASTEXITCODE)" }
|
||||
|
||||
function Write-Config {
|
||||
param([string]$Category)
|
||||
$cfg = Join-Path $WorkDir 'config.yml'
|
||||
@"
|
||||
pki:
|
||||
ca: $CaCrt
|
||||
cert: $HostCrt
|
||||
key: $HostKey
|
||||
static_host_map: {}
|
||||
lighthouse:
|
||||
am_lighthouse: true
|
||||
interval: 60
|
||||
hosts: []
|
||||
listen:
|
||||
host: 0.0.0.0
|
||||
port: 4242
|
||||
tun:
|
||||
disabled: false
|
||||
dev: $DevName
|
||||
drop_local_broadcast: false
|
||||
drop_multicast: false
|
||||
tx_queue: 500
|
||||
mtu: 1300
|
||||
network_category: $Category
|
||||
logging:
|
||||
level: info
|
||||
format: text
|
||||
firewall:
|
||||
outbound_action: drop
|
||||
inbound_action: drop
|
||||
conntrack:
|
||||
tcp_timeout: 12m
|
||||
udp_timeout: 3m
|
||||
default_timeout: 10m
|
||||
outbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
inbound:
|
||||
- port: any
|
||||
proto: any
|
||||
host: any
|
||||
"@ | Out-File -FilePath $cfg -Encoding utf8
|
||||
return $cfg
|
||||
}
|
||||
|
||||
function Test-Category {
|
||||
param(
|
||||
[Parameter(Mandatory)] [string]$ConfigValue,
|
||||
[Parameter(Mandatory)] [string]$ExpectedCategory
|
||||
)
|
||||
Write-Host ""
|
||||
Write-Host "=== smoke: network_category=$ConfigValue (expecting $ExpectedCategory) ==="
|
||||
|
||||
$cfg = Write-Config -Category $ConfigValue
|
||||
$stdoutLog = Join-Path $WorkDir "nebula-$ConfigValue.out.log"
|
||||
$stderrLog = Join-Path $WorkDir "nebula-$ConfigValue.err.log"
|
||||
|
||||
$proc = Start-Process -FilePath $Nebula -ArgumentList @('-config', $cfg) `
|
||||
-PassThru -NoNewWindow `
|
||||
-RedirectStandardOutput $stdoutLog `
|
||||
-RedirectStandardError $stderrLog
|
||||
|
||||
try {
|
||||
$deadline = (Get-Date).AddSeconds(30)
|
||||
$observed = $null
|
||||
while ((Get-Date) -lt $deadline) {
|
||||
if ($proc.HasExited) {
|
||||
Get-Content $stdoutLog -ErrorAction SilentlyContinue | Out-Host
|
||||
Get-Content $stderrLog -ErrorAction SilentlyContinue | Out-Host
|
||||
throw "nebula exited prematurely (code $($proc.ExitCode))"
|
||||
}
|
||||
$netProfile = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue
|
||||
if ($netProfile) {
|
||||
$observed = "$($netProfile.NetworkCategory)"
|
||||
if ($observed -ieq $ExpectedCategory) {
|
||||
Write-Host "OK: $DevName NetworkCategory=$observed"
|
||||
return
|
||||
}
|
||||
}
|
||||
Start-Sleep -Milliseconds 500
|
||||
}
|
||||
|
||||
Get-Content $stdoutLog -ErrorAction SilentlyContinue | Out-Host
|
||||
Get-Content $stderrLog -ErrorAction SilentlyContinue | Out-Host
|
||||
throw "expected NetworkCategory=$ExpectedCategory, observed='$observed' within 30s"
|
||||
}
|
||||
finally {
|
||||
if (-not $proc.HasExited) {
|
||||
Stop-Process -Id $proc.Id -Force -ErrorAction SilentlyContinue
|
||||
$proc.WaitForExit(5000) | Out-Null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Test-Category -ConfigValue 'private' -ExpectedCategory 'Private'
|
||||
Test-Category -ConfigValue 'public' -ExpectedCategory 'Public'
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "All smoke checks passed."
|
||||
|
|
@ -286,6 +286,16 @@ tun:
|
|||
# metric: 100
|
||||
# install: true
|
||||
|
||||
# On Windows only, sets the network category of the nebula interface. Without this, Windows often
|
||||
# leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more
|
||||
# restrictive than you usually want for an overlay between trusted peers. Valid values:
|
||||
# private - treat the nebula network as a private/trusted network (default)
|
||||
# public - treat it as a public/untrusted network
|
||||
# domain - treat it as a domain-authenticated network
|
||||
# unset - leave whatever Windows decided alone
|
||||
# Not reloadable.
|
||||
#network_category: private
|
||||
|
||||
# On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of
|
||||
# in nebula configuration files. Default false, not reloadable.
|
||||
#use_system_route_table: false
|
||||
|
|
|
|||
358
overlay/network_category_windows.go
Normal file
358
overlay/network_category_windows.go
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h.
|
||||
type networkCategory int32
|
||||
|
||||
const (
|
||||
networkCategoryPublic networkCategory = 0
|
||||
networkCategoryPrivate networkCategory = 1
|
||||
networkCategoryDomainAuthenticated networkCategory = 2
|
||||
)
|
||||
|
||||
func (c networkCategory) String() string {
|
||||
switch c {
|
||||
case networkCategoryPublic:
|
||||
return "public"
|
||||
case networkCategoryPrivate:
|
||||
return "private"
|
||||
case networkCategoryDomainAuthenticated:
|
||||
return "domain"
|
||||
}
|
||||
return fmt.Sprintf("unknown(%d)", c)
|
||||
}
|
||||
|
||||
// parseNetworkCategory accepts the user-supplied tun.network_category. A
|
||||
// second return of false means "leave the category alone".
|
||||
func parseNetworkCategory(s string) (networkCategory, bool, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "unset":
|
||||
return 0, false, nil
|
||||
case "public":
|
||||
return networkCategoryPublic, true, nil
|
||||
case "private":
|
||||
return networkCategoryPrivate, true, nil
|
||||
case "domain", "domainauthenticated":
|
||||
return networkCategoryDomainAuthenticated, true, nil
|
||||
}
|
||||
return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s)
|
||||
}
|
||||
|
||||
// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B}
|
||||
var clsidNetworkListManager = windows.GUID{
|
||||
Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B,
|
||||
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
|
||||
}
|
||||
|
||||
// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B}
|
||||
var iidINetworkListManager = windows.GUID{
|
||||
Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B,
|
||||
Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B},
|
||||
}
|
||||
|
||||
// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves.
|
||||
var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance")
|
||||
|
||||
const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER |
|
||||
windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER
|
||||
|
||||
const (
|
||||
hrSFALSE = 0x00000001
|
||||
hrRPCEChangedMode = 0x80010106
|
||||
)
|
||||
|
||||
type hresult uint32
|
||||
|
||||
func (h hresult) failed() bool { return int32(h) < 0 }
|
||||
func (h hresult) String() string {
|
||||
return fmt.Sprintf("HRESULT 0x%08x", uint32(h))
|
||||
}
|
||||
|
||||
var errAdapterNotFound = errors.New("adapter not present in network connections enumeration")
|
||||
|
||||
// Vtable layouts. Slot order must match the declaration order in netlistmgr.h.
|
||||
// All NLM interfaces here derive from IDispatch, which derives from IUnknown.
|
||||
|
||||
type iUnknownVtbl struct {
|
||||
QueryInterface uintptr
|
||||
AddRef uintptr
|
||||
Release uintptr
|
||||
}
|
||||
|
||||
type iDispatchVtbl struct {
|
||||
iUnknownVtbl
|
||||
GetTypeInfoCount uintptr
|
||||
GetTypeInfo uintptr
|
||||
GetIDsOfNames uintptr
|
||||
Invoke uintptr
|
||||
}
|
||||
|
||||
type iNetworkListManagerVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetNetworks uintptr
|
||||
GetNetwork uintptr
|
||||
GetNetworkConnections uintptr
|
||||
GetNetworkConnection uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
}
|
||||
|
||||
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }
|
||||
|
||||
func (n *iNetworkListManager) Release() {
|
||||
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
|
||||
}
|
||||
|
||||
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
|
||||
var enum *iEnumNetworkConnections
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
|
||||
}
|
||||
return enum, nil
|
||||
}
|
||||
|
||||
type iEnumNetworkConnectionsVtbl struct {
|
||||
iDispatchVtbl
|
||||
NewEnum uintptr
|
||||
Next uintptr
|
||||
Skip uintptr
|
||||
Reset uintptr
|
||||
Clone uintptr
|
||||
}
|
||||
|
||||
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }
|
||||
|
||||
func (e *iEnumNetworkConnections) Release() {
|
||||
syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
|
||||
}
|
||||
|
||||
// Next returns the next connection, or (nil, nil) at the end of the enumeration.
|
||||
func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) {
|
||||
var conn *iNetworkConnection
|
||||
var fetched uint32
|
||||
r1, _, _ := syscall.SyscallN(e.Vtbl.Next,
|
||||
uintptr(unsafe.Pointer(e)), 1,
|
||||
uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr)
|
||||
}
|
||||
if fetched == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type iNetworkConnectionVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetNetwork uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
GetConnectionId uintptr
|
||||
GetAdapterId uintptr
|
||||
GetDomainType uintptr
|
||||
}
|
||||
|
||||
type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl }
|
||||
|
||||
func (c *iNetworkConnection) Release() {
|
||||
syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c)))
|
||||
}
|
||||
|
||||
func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) {
|
||||
var g windows.GUID
|
||||
r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId,
|
||||
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) {
|
||||
var net *iNetwork
|
||||
r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork,
|
||||
uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr)
|
||||
}
|
||||
return net, nil
|
||||
}
|
||||
|
||||
type iNetworkVtbl struct {
|
||||
iDispatchVtbl
|
||||
GetName uintptr
|
||||
SetName uintptr
|
||||
GetDescription uintptr
|
||||
SetDescription uintptr
|
||||
GetNetworkId uintptr
|
||||
GetDomainType uintptr
|
||||
GetNetworkConnections uintptr
|
||||
GetTimeCreatedAndConnected uintptr
|
||||
IsConnectedToInternet uintptr
|
||||
IsConnected uintptr
|
||||
GetConnectivity uintptr
|
||||
GetCategory uintptr
|
||||
SetCategory uintptr
|
||||
}
|
||||
|
||||
type iNetwork struct{ Vtbl *iNetworkVtbl }
|
||||
|
||||
func (n *iNetwork) Release() {
|
||||
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
|
||||
}
|
||||
|
||||
func (n *iNetwork) GetCategory() (networkCategory, error) {
|
||||
var c networkCategory
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return 0, fmt.Errorf("INetwork.GetCategory: %s", hr)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (n *iNetwork) SetCategory(c networkCategory) error {
|
||||
r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory,
|
||||
uintptr(unsafe.Pointer(n)), uintptr(int32(c)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return fmt.Errorf("INetwork.SetCategory: %s", hr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// coInit initializes COM for the current OS thread. The returned function must
|
||||
// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is
|
||||
// already initialized in a different mode on this thread, which is still fine
|
||||
// for our calls but we must not Uninitialize in that case.
|
||||
func coInit() (func(), error) {
|
||||
err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
|
||||
if err == nil {
|
||||
return windows.CoUninitialize, nil
|
||||
}
|
||||
if e, ok := err.(syscall.Errno); ok {
|
||||
switch uint32(e) {
|
||||
case hrSFALSE:
|
||||
return windows.CoUninitialize, nil
|
||||
case hrRPCEChangedMode:
|
||||
return func() {}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("CoInitializeEx: %w", err)
|
||||
}
|
||||
|
||||
func createNetworkListManager() (*iNetworkListManager, error) {
|
||||
var nlm *iNetworkListManager
|
||||
r1, _, _ := procCoCreateInstance.Call(
|
||||
uintptr(unsafe.Pointer(&clsidNetworkListManager)),
|
||||
0,
|
||||
uintptr(clsCtxAll),
|
||||
uintptr(unsafe.Pointer(&iidINetworkListManager)),
|
||||
uintptr(unsafe.Pointer(&nlm)),
|
||||
)
|
||||
if hr := hresult(r1); hr.failed() {
|
||||
return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr)
|
||||
}
|
||||
return nlm, nil
|
||||
}
|
||||
|
||||
// setNetworkCategory locates the network connection bound to adapterGUID and
|
||||
// sets the category of its parent network. Returns errAdapterNotFound if the
|
||||
// adapter is not yet visible in the NLM enumeration.
|
||||
func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error {
|
||||
deinit, err := coInit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer deinit()
|
||||
|
||||
nlm, err := createNetworkListManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nlm.Release()
|
||||
|
||||
enum, err := nlm.GetNetworkConnections()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer enum.Release()
|
||||
|
||||
for {
|
||||
conn, err := enum.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errAdapterNotFound
|
||||
}
|
||||
|
||||
guid, err := conn.GetAdapterId()
|
||||
if err != nil || guid != adapterGUID {
|
||||
conn.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
net, err := conn.GetNetwork()
|
||||
conn.Release()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = net.SetCategory(cat)
|
||||
net.Release()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// applyNetworkCategory polls until the wintun adapter shows up in the NLM
|
||||
// enumeration, then sets the category. Intended to run in its own goroutine.
|
||||
func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) {
|
||||
// COM Init/Uninit must be paired on the same OS thread.
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
const (
|
||||
attempts = 30
|
||||
interval = 500 * time.Millisecond
|
||||
)
|
||||
for i := 0; i < attempts; i++ {
|
||||
err := setNetworkCategory(adapterGUID, cat)
|
||||
if err == nil {
|
||||
l.Info("Set Windows network category", "category", cat.String())
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errAdapterNotFound) {
|
||||
l.Warn("Failed to set Windows network category", "error", err, "category", cat.String())
|
||||
return
|
||||
}
|
||||
time.Sleep(interval)
|
||||
}
|
||||
l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set",
|
||||
"category", cat.String(),
|
||||
"waited", time.Duration(attempts)*interval,
|
||||
)
|
||||
}
|
||||
109
overlay/network_category_windows_test.go
Normal file
109
overlay/network_category_windows_test.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
//go:build !e2e_testing
|
||||
// +build !e2e_testing
|
||||
|
||||
package overlay
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_parseNetworkCategory(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
wantCat networkCategory
|
||||
wantApply bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"", 0, false, false},
|
||||
{"unset", 0, false, false},
|
||||
{" UNSET ", 0, false, false},
|
||||
{"private", networkCategoryPrivate, true, false},
|
||||
{"Private", networkCategoryPrivate, true, false},
|
||||
{" PRIVATE ", networkCategoryPrivate, true, false},
|
||||
{"public", networkCategoryPublic, true, false},
|
||||
{"PUBLIC", networkCategoryPublic, true, false},
|
||||
{"domain", networkCategoryDomainAuthenticated, true, false},
|
||||
{"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false},
|
||||
{"garbage", 0, false, true},
|
||||
{"privates", 0, false, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
cat, apply, err := parseNetworkCategory(tc.in)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
|
||||
continue
|
||||
}
|
||||
if cat != tc.wantCat || apply != tc.wantApply {
|
||||
t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory
|
||||
// without mutating the host's network state. It validates the CLSID/IID
|
||||
// constants and every vtable index by enumerating connections, fetching the
|
||||
// adapter id and parent network, reading the current category, and writing it
|
||||
// back unchanged.
|
||||
//
|
||||
// Requires Windows but does not require admin or the wintun driver. Skips if
|
||||
// no network connections are available (unlikely outside of an isolated
|
||||
// container).
|
||||
func Test_NLM_round_trip(t *testing.T) {
|
||||
deinit, err := coInit()
|
||||
if err != nil {
|
||||
t.Fatalf("coInit: %v", err)
|
||||
}
|
||||
defer deinit()
|
||||
|
||||
nlm, err := createNetworkListManager()
|
||||
if err != nil {
|
||||
t.Fatalf("createNetworkListManager: %v", err)
|
||||
}
|
||||
defer nlm.Release()
|
||||
|
||||
enum, err := nlm.GetNetworkConnections()
|
||||
if err != nil {
|
||||
t.Fatalf("GetNetworkConnections: %v", err)
|
||||
}
|
||||
defer enum.Release()
|
||||
|
||||
saw := 0
|
||||
for {
|
||||
conn, err := enum.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("EnumNetworkConnections.Next: %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
break
|
||||
}
|
||||
saw++
|
||||
|
||||
if _, err := conn.GetAdapterId(); err != nil {
|
||||
conn.Release()
|
||||
t.Fatalf("INetworkConnection.GetAdapterId: %v", err)
|
||||
}
|
||||
|
||||
net, err := conn.GetNetwork()
|
||||
conn.Release()
|
||||
if err != nil {
|
||||
t.Fatalf("INetworkConnection.GetNetwork: %v", err)
|
||||
}
|
||||
|
||||
cat, err := net.GetCategory()
|
||||
if err != nil {
|
||||
net.Release()
|
||||
t.Fatalf("INetwork.GetCategory: %v", err)
|
||||
}
|
||||
// Set to the current value so the host's NLM state is unchanged but
|
||||
// SetCategory's vtable slot is still validated end-to-end.
|
||||
if err := net.SetCategory(cat); err != nil {
|
||||
net.Release()
|
||||
t.Fatalf("INetwork.SetCategory(%v): %v", cat, err)
|
||||
}
|
||||
net.Release()
|
||||
}
|
||||
|
||||
if saw == 0 {
|
||||
t.Skip("no NLM network connections available; skipping round-trip")
|
||||
}
|
||||
}
|
||||
|
|
@ -28,12 +28,15 @@ import (
|
|||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||
|
||||
type winTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *slog.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
guid windows.GUID
|
||||
networkCategory networkCategory
|
||||
setCategory bool
|
||||
l *slog.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
|
|
@ -54,11 +57,19 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w
|
|||
return nil, fmt.Errorf("generate GUID failed: %w", err)
|
||||
}
|
||||
|
||||
cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
guid: *guid,
|
||||
networkCategory: cat,
|
||||
setCategory: setCat,
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
|
|
@ -142,6 +153,13 @@ func (t *winTun) Activate() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if t.setCategory {
|
||||
// The wintun adapter takes a moment to register with the Network List
|
||||
// Manager, so we apply the category in the background and retry until
|
||||
// it shows up.
|
||||
go applyNetworkCategory(t.l, t.guid, t.networkCategory)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue