This commit is contained in:
Nate Brown 2026-05-08 03:41:04 +00:00 committed by GitHub
commit 5ec0f46685
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 718 additions and 57 deletions

View file

@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else {
// out-key is meaningless under PKCS#11 because the private key never
// leaves the HSM; reject it so we never silently accept or claim a
// stdout slot for it.
outKeySet := false
cf.set.Visit(func(f *flag.Flag) {
if f.Name == "out-key" {
outKeySet = true
}
})
if outKeySet {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
}
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
return err
@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-crt", *cf.outCertPath,
"out-qr", *cf.outQRPath,
); err != nil {
return err
}
var passphrase []byte
if !isP11 && *cf.encryption {
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
if len(passphrase) == 0 {
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if err == ErrNoTerminal {
@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
Curve: curve,
}
if !isP11 {
if !isP11 && !isStdio(*cf.outKeyPath) {
if _, err := os.Stat(*cf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
}
}
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
if !isStdio(*cf.outCertPath) {
if _, err := os.Stat(*cf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
}
}
var c cert.Certificate
@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
}
err = os.WriteFile(*cf.outKeyPath, b, 0600)
err = writeOutput(*cf.outKeyPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while marshalling certificate: %s", err)
}
err = os.WriteFile(*cf.outCertPath, b, 0600)
err = writeOutput(*cf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*cf.outQRPath, b, 0600)
err = writeOutput(*cf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@ -332,6 +356,7 @@ func caSummary() string {
func caHelp(out io.Writer) {
cf := newCaFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}

View file

@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -argon-iterations uint\n"+
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
" -argon-memory uint\n"+
@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
err: nil,
}
pwPromptOb := "Enter passphrase: "
pwPromptEB := "Enter passphrase: "
// required args
assertHelpError(t, ca(
@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test encrypted key with passphrase environment variable
os.Remove(keyF.Name())
@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, pwPromptEB, eb.String())
// test when user fails to enter a password
os.Remove(keyF.Name())
@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
// create valid cert/key for overwrite tests
os.Remove(keyF.Name())
@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
os.Remove(keyF.Name())
}
func Test_ca_stdio(t *testing.T) {
nopw := &StubPasswordReader{}
keyF, err := os.CreateTemp("", "ca.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
crtF, err := os.CreateTemp("", "ca.crt")
require.NoError(t, err)
os.Remove(crtF.Name())
defer os.Remove(crtF.Name())
// out-crt on stdout, out-key on disk
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
assert.Empty(t, eb.String())
c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.True(t, c.IsCA())
assert.Equal(t, "test-ca", c.Name())
// out-key on stdout, out-crt on disk
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
assert.Empty(t, eb.String())
_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
// dual stdout is rejected up front
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
require.EqualError(t,
ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
// an output conflict combined with -encrypt must error BEFORE prompting
// for a passphrase; pr would record any read attempt
tracker := &trackingPasswordReader{}
ob.Reset()
eb.Reset()
require.EqualError(t,
ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
}
type trackingPasswordReader struct {
calls int
}
func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) {
pr.calls++
return []byte(""), nil
}

View file

@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
return err
}
} else if *cf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
return err
@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
}
}
var claims ioClaims
if err := reserveOutputs(&claims,
"out-key", *cf.outKeyPath,
"out-pub", *cf.outPubPath,
); err != nil {
return err
}
if isP11 {
p11Client, err := pkclient.FromUrl(*cf.p11url)
if err != nil {
@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while getting public key: %w", err)
}
} else {
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
}
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-pub: %s", err)
}
@ -102,6 +112,7 @@ func keygenSummary() string {
func keygenHelp(out io.Writer) {
cf := newKeygenFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
cf.set.SetOutput(out)
cf.set.PrintDefaults()
}

View file

@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -curve string\n"+
" \tECDH Curve (25519, P256) (default \"25519\")\n"+
" -out-key string\n"+
@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
require.NoError(t, err)
assert.Len(t, lPub, 32)
}
func Test_keygen_stdio(t *testing.T) {
keyF, err := os.CreateTemp("", "test.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
pubF, err := os.CreateTemp("", "test.pub")
require.NoError(t, err)
os.Remove(pubF.Name())
defer os.Remove(pubF.Name())
// out-pub on stdout, out-key on disk
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
assert.Empty(t, eb.String())
lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, lPub, 32)
// out-key on stdout, out-pub on disk
os.Remove(keyF.Name())
ob.Reset()
eb.Reset()
require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
assert.Empty(t, eb.String())
lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Len(t, lKey, 32)
// both on stdout is a conflict caught up front
ob.Reset()
eb.Reset()
require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
`-out-key and -out-pub both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
}

View file

@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
}
password, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
// Terminal echo is off while reading, so the user's Enter key does not
// produce a visible newline. Emit one on stderr to match the prompt.
fmt.Fprintln(os.Stderr)
return password, err
}

View file

@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return err
}
rawCert, err := os.ReadFile(*pf.path)
var claims ioClaims
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
return err
}
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
return err
}
rawCert, err := readInput("path", *pf.path, &claims)
if err != nil {
return fmt.Errorf("unable to read cert; %s", err)
}
// When the QR is going to stdout, suppress the human-readable text/json
// output so the binary stream is not contaminated.
qrToStdout := isStdio(*pf.outQRPath)
var c cert.Certificate
var qrBytes []byte
part := 0
@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while unmarshaling cert: %s", err)
}
if *pf.json {
jsonCerts = append(jsonCerts, c)
} else {
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
if !qrToStdout {
if *pf.json {
jsonCerts = append(jsonCerts, c)
} else {
_, _ = out.Write([]byte(c.String()))
_, _ = out.Write([]byte("\n"))
}
}
if *pf.outQRPath != "" {
@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
part++
}
if *pf.json {
if *pf.json && !qrToStdout {
b, _ := json.Marshal(jsonCerts)
_, _ = out.Write(b)
_, _ = out.Write([]byte("\n"))
@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*pf.outQRPath, b, 0600)
err = writeOutput(*pf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@ -107,6 +121,7 @@ func printSummary() string {
func printHelp(out io.Writer) {
pf := newPrintFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
out.Write([]byte(stdioHelpText))
pf.set.SetOutput(out)
pf.set.PrintDefaults()
}

View file

@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -json\n"+
" \tOptional: outputs certificates in json format\n"+
" -out-qr string\n"+
@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
ob.String(),
)
assert.Empty(t, eb.String())
// read cert from stdin
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
require.NoError(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
assert.Empty(t, eb.String())
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
stdout := ob.Bytes()
require.NotEmpty(t, stdout)
// PNG magic, no PEM/JSON noise prepended
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
assert.NotContains(t, string(stdout), "NebulaCertificate")
assert.NotContains(t, string(stdout), `"details"`)
// json + out-qr - still suppresses json
ob.Reset()
eb.Reset()
withStdin(t, bytes.NewReader(p))
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
require.NoError(t, err)
assert.Empty(t, eb.String())
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
assert.NotContains(t, ob.String(), `"details"`)
}
// NewTestCaCert will generate a CA cert

View file

@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set both -in-pub and -out-key")
}
if isP11 && *sf.outKeyPath != "" {
return newHelpErrorf("cannot set -out-key with -pkcs11")
}
var v4Networks []netip.Prefix
var v6Networks []netip.Prefix
@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
var claims ioClaims
if err := reserveInputs(&claims,
"ca-key", *sf.caKeyPath,
"ca-crt", *sf.caCertPath,
"in-pub", *sf.inPubPath,
); err != nil {
return err
}
if err := reserveOutputs(&claims,
"out-key", *sf.outKeyPath,
"out-crt", *sf.outCertPath,
"out-qr", *sf.outQRPath,
); err != nil {
return err
}
var curve cert.Curve
var caKey []byte
if !isP11 {
var rawCAKey []byte
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-key: %s", err)
}
@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if len(passphrase) == 0 {
// ask for a passphrase until we get one
for i := 0; i < 5; i++ {
out.Write([]byte("Enter passphrase: "))
errOut.Write([]byte("Enter passphrase: "))
passphrase, err = pr.ReadPassword()
if errors.Is(err, ErrNoTerminal) {
@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
}
rawCACert, err := os.ReadFile(*sf.caCertPath)
rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca-crt: %s", err)
}
@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
if *sf.inPubPath != "" {
var pubCurve cert.Curve
rawPub, err := os.ReadFile(*sf.inPubPath)
rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
if err != nil {
return fmt.Errorf("error while reading in-pub: %s", err)
}
@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
pub, rawPriv = newKeypair(curve)
}
if *sf.outKeyPath == "" {
*sf.outKeyPath = *sf.name + ".key"
}
if *sf.outCertPath == "" {
*sf.outCertPath = *sf.name + ".crt"
}
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
if !isStdio(*sf.outCertPath) {
if _, err := os.Stat(*sf.outCertPath); err == nil {
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
}
}
var crts []cert.Certificate
@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
if !isP11 && *sf.inPubPath == "" {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
if !isStdio(*sf.outKeyPath) {
if _, err := os.Stat(*sf.outKeyPath); err == nil {
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
}
}
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-key: %s", err)
}
@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
b = append(b, sb...)
}
err = os.WriteFile(*sf.outCertPath, b, 0600)
err = writeOutput(*sf.outCertPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-crt: %s", err)
}
@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("error while generating qr code: %s", err)
}
err = os.WriteFile(*sf.outQRPath, b, 0600)
err = writeOutput(*sf.outQRPath, b, 0600, out)
if err != nil {
return fmt.Errorf("error while writing out-qr: %s", err)
}
@ -440,6 +461,7 @@ func signSummary() string {
func signHelp(out io.Writer) {
sf := newSignFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
out.Write([]byte(stdioHelpText))
sf.set.SetOutput(out)
sf.set.PrintDefaults()
}

View file

@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca-crt string\n"+
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
" -ca-key string\n"+
@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
// test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the proper password in the environment
os.Remove(crtF.Name())
os.Remove(keyF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
ob.Reset()
eb.Reset()
require.NoError(t, signCert(args, ob, eb, testpw))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
os.Setenv("NEBULA_CA_PASSPHRASE", "")
@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
// test with the wrong password in environment
ob.Reset()
@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
// test an error condition
ob.Reset()
@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String())
assert.Empty(t, ob.String())
assert.Equal(t, "Enter passphrase: ", eb.String())
}
func Test_signCert_stdio(t *testing.T) {
nopw := &StubPasswordReader{
password: []byte(""),
err: nil,
}
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
rawCACrt, _ := ca.MarshalPEM()
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
require.NoError(t, err)
defer os.Remove(caCrtF.Name())
caCrtF.Write(rawCACrt)
caKeyF, err := os.CreateTemp("", "sign-cert.key")
require.NoError(t, err)
defer os.Remove(caKeyF.Name())
caKeyF.Write(rawCAKey)
keyF, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF.Name())
defer os.Remove(keyF.Name())
// ca-key on stdin, cert to stdout
withStdin(t, bytes.NewReader(rawCAKey))
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "stdin-test", lCrt.Name())
assert.True(t, lCrt.CheckSignature(caPub))
// two flags reading from stdin should error before any read attempt;
// otherwise an interactive shell would hang on io.ReadAll
stdinIn := bytes.NewReader(rawCAKey)
withStdin(t, stdinIn)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")
// two flags writing to stdout should error before any output is written
// AND before stdin is consumed
stdinR := bytes.NewReader(rawCAKey)
withStdin(t, stdinR)
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
require.EqualError(t, signCert(args, ob, eb, nopw),
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
assert.Empty(t, ob.String())
// stdin should be untouched because the conflict was caught up front
assert.Equal(t, len(rawCAKey), stdinR.Len())
// out-key on stdout, cert on disk
keyF2, err := os.CreateTemp("", "sign.key")
require.NoError(t, err)
os.Remove(keyF2.Name())
defer os.Remove(keyF2.Name())
crtF, err := os.CreateTemp("", "sign.crt")
require.NoError(t, err)
os.Remove(crtF.Name())
defer os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, cert.Curve_CURVE25519, curve)
// in-pub on stdin (caller already has a keypair, only the cert is generated)
inPub, _ := x25519Keypair()
rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
withStdin(t, bytes.NewReader(rawInPub))
os.Remove(crtF.Name())
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, eb.String())
stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
require.NoError(t, err)
assert.Equal(t, "in-pub-test", stdinCrt.Name())
assert.Equal(t, inPub, stdinCrt.PublicKey())
}

117
cmd/nebula-cert/stdio.go Normal file
View file

@ -0,0 +1,117 @@
package main
import (
"fmt"
"io"
"os"
)
// stdioPath is the special path value that selects stdin (for inputs) or
// stdout (for outputs) instead of a file on disk.
const stdioPath = "-"
// stdioHelpText is rendered just under the Usage line of each subcommand
// help so the - convention is documented once instead of on every flag.
const stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
// stdinReader is the source used when an input flag is set to "-".
// It is a package level var so tests can swap in a deterministic reader.
// Tests that mutate stdinReader cannot run with t.Parallel().
var stdinReader io.Reader = os.Stdin
// ioClaims tracks which flags have claimed stdin and stdout during a single
// command invocation so we can refuse a second flag asking for the same
// stream.
type ioClaims struct {
in string
out string
}
func (c *ioClaims) claimIn(flagName string) error {
if c.in != "" && c.in != flagName {
return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath)
}
c.in = flagName
return nil
}
func (c *ioClaims) claimOut(flagName string) error {
if c.out != "" && c.out != flagName {
return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath)
}
c.out = flagName
return nil
}
// reserveInputs walks alternating (flagName, path) pairs and claims stdin
// for any path equal to stdioPath. It must be called before any input is
// read so a conflict can be reported immediately instead of blocking on
// io.ReadAll while waiting for input that will never arrive.
func reserveInputs(claims *ioClaims, pairs ...string) error {
return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs)
}
// reserveOutputs walks alternating (flagName, path) pairs and claims stdout
// for any path equal to stdioPath. It must be called before any output is
// written so a conflict cannot leave one stream half written before the
// second flag fails.
func reserveOutputs(claims *ioClaims, pairs ...string) error {
return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs)
}
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
if len(pairs)%2 != 0 {
panic(who + " requires alternating name, path pairs")
}
for i := 0; i < len(pairs); i += 2 {
name, path := pairs[i], pairs[i+1]
if path != stdioPath {
continue
}
if err := claim(claims, name); err != nil {
return err
}
}
return nil
}
// readInput returns the bytes referenced by path, reading from stdin when
// path is stdioPath.
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
if path == stdioPath {
if err := claims.claimIn(flagName); err != nil {
return nil, err
}
return io.ReadAll(stdinReader)
}
return os.ReadFile(path)
}
// openInput returns a reader for path. When path is stdioPath the returned
// reader wraps stdin and Close is a no-op.
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
if path == stdioPath {
if err := claims.claimIn(flagName); err != nil {
return nil, err
}
return io.NopCloser(stdinReader), nil
}
return os.Open(path)
}
// writeOutput writes data to path, or to stdout when path is stdioPath. perm
// is only used for file output. The caller must have already claimed stdout
// via reserveOutputs before invoking with stdioPath.
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
if path == stdioPath {
_, err := stdout.Write(data)
return err
}
return os.WriteFile(path, data, perm)
}
// isStdio reports whether path is the stdio sentinel and so should skip
// existence checks like "refuse to overwrite".
func isStdio(path string) bool {
return path == stdioPath
}

View file

@ -0,0 +1,167 @@
package main
import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// withStdin temporarily replaces stdinReader for the duration of t.
func withStdin(t *testing.T, r io.Reader) {
t.Helper()
prev := stdinReader
stdinReader = r
t.Cleanup(func() { stdinReader = prev })
}
func Test_readInput_stdin(t *testing.T) {
withStdin(t, bytes.NewBufferString("hello"))
var claims ioClaims
got, err := readInput("path", "-", &claims)
require.NoError(t, err)
assert.Equal(t, []byte("hello"), got)
assert.Equal(t, "path", claims.in)
}
func Test_readInput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
require.NoError(t, os.WriteFile(p, []byte("file"), 0600))
var claims ioClaims
got, err := readInput("path", p, &claims)
require.NoError(t, err)
assert.Equal(t, []byte("file"), got)
assert.Empty(t, claims.in)
}
func Test_readInput_doubleStdinErrors(t *testing.T) {
withStdin(t, bytes.NewBufferString("hello"))
var claims ioClaims
_, err := readInput("ca-key", "-", &claims)
require.NoError(t, err)
_, err = readInput("ca-crt", "-", &claims)
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
func Test_openInput_stdin(t *testing.T) {
withStdin(t, bytes.NewBufferString("hi"))
var claims ioClaims
r, err := openInput("ca", "-", &claims)
require.NoError(t, err)
defer r.Close()
b, err := io.ReadAll(r)
require.NoError(t, err)
assert.Equal(t, []byte("hi"), b)
}
func Test_openInput_doubleStdinErrors(t *testing.T) {
withStdin(t, bytes.NewBufferString("hi"))
var claims ioClaims
r, err := openInput("ca", "-", &claims)
require.NoError(t, err)
r.Close()
_, err = openInput("crt", "-", &claims)
require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`)
}
func Test_writeOutput_stdout(t *testing.T) {
out := &bytes.Buffer{}
err := writeOutput("-", []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Equal(t, "payload", out.String())
}
func Test_writeOutput_file(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
out := &bytes.Buffer{}
err := writeOutput(p, []byte("payload"), 0600, out)
require.NoError(t, err)
assert.Empty(t, out.String())
got, err := os.ReadFile(p)
require.NoError(t, err)
assert.Equal(t, []byte("payload"), got)
}
func Test_reserveOutputs_noConflict(t *testing.T) {
var claims ioClaims
require.NoError(t, reserveOutputs(&claims,
"out-key", "/tmp/key",
"out-crt", "-",
"out-qr", "",
))
assert.Equal(t, "out-crt", claims.out)
}
func Test_reserveOutputs_conflict(t *testing.T) {
var claims ioClaims
err := reserveOutputs(&claims,
"out-key", "-",
"out-crt", "-",
)
require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`)
}
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
}()
var claims ioClaims
_ = reserveOutputs(&claims, "out-key")
}
func Test_reserveInputs_noConflict(t *testing.T) {
var claims ioClaims
require.NoError(t, reserveInputs(&claims,
"ca-key", "/tmp/ca.key",
"ca-crt", "-",
"in-pub", "",
))
assert.Equal(t, "ca-crt", claims.in)
}
func Test_reserveInputs_conflict(t *testing.T) {
var claims ioClaims
err := reserveInputs(&claims,
"ca-key", "-",
"ca-crt", "-",
)
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
}
func Test_claimIn_idempotent(t *testing.T) {
// pre-claim then a lazy re-claim of the same flag should be a no-op
var claims ioClaims
require.NoError(t, claims.claimIn("ca-key"))
require.NoError(t, claims.claimIn("ca-key"))
assert.Equal(t, "ca-key", claims.in)
}
func Test_claimOut_idempotent(t *testing.T) {
var claims ioClaims
require.NoError(t, claims.claimOut("out-crt"))
require.NoError(t, claims.claimOut("out-crt"))
assert.Equal(t, "out-crt", claims.out)
}
func Test_isStdio(t *testing.T) {
assert.True(t, isStdio("-"))
assert.False(t, isStdio(""))
assert.False(t, isStdio("./-"))
assert.False(t, isStdio("foo"))
}

View file

@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err
}
caFile, err := os.Open(*vf.caPath)
var claims ioClaims
if err := reserveInputs(&claims,
"ca", *vf.caPath,
"crt", *vf.certPath,
); err != nil {
return err
}
caReader, err := openInput("ca", *vf.caPath, &claims)
if err != nil {
return fmt.Errorf("error while reading ca: %w", err)
}
defer caFile.Close()
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if err != nil && !errors.Is(err, cert.ErrExpired) {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
rawCert, err := os.ReadFile(*vf.certPath)
rawCert, err := readInput("crt", *vf.certPath, &claims)
if err != nil {
return fmt.Errorf("unable to read crt: %w", err)
}
@ -85,6 +93,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) {
vf := newVerifyFlags()
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
_, _ = out.Write([]byte(stdioHelpText))
vf.set.SetOutput(out)
vf.set.PrintDefaults()
}

View file

@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
assert.Equal(
t,
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
" -ca string\n"+
" \tRequired: path to a file containing one or more ca certificates\n"+
" -crt string\n"+
@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
assert.Empty(t, eb.String())
require.NoError(t, err)
}
func Test_verify_stdio(t *testing.T) {
ob := &bytes.Buffer{}
eb := &bytes.Buffer{}
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
caPEM, _ := ca.MarshalPEM()
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
crtPEM, _ := crt.MarshalPEM()
caFile, err := os.CreateTemp("", "verify-ca")
require.NoError(t, err)
defer os.Remove(caFile.Name())
caFile.Write(caPEM)
// crt on stdin, ca on disk
withStdin(t, bytes.NewReader(crtPEM))
require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// ca on stdin, crt on disk
certFile, err := os.CreateTemp("", "verify-cert")
require.NoError(t, err)
defer os.Remove(certFile.Name())
certFile.Write(crtPEM)
withStdin(t, bytes.NewReader(caPEM))
ob.Reset()
eb.Reset()
require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb))
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())
// both flags on stdin should error
withStdin(t, bytes.NewReader(caPEM))
ob.Reset()
eb.Reset()
require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb),
`-ca and -crt both set to "-", only one input may read from stdin`)
}