mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-10 22:22:27 -07:00
Allow for - to stand in for stdin/out
This commit is contained in:
parent
1ada3d4dd9
commit
7c31a8eb35
13 changed files with 718 additions and 57 deletions
|
|
@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// out-key is meaningless under PKCS#11 because the private key never
|
||||
// leaves the HSM; reject it so we never silently accept or claim a
|
||||
// stdout slot for it.
|
||||
outKeySet := false
|
||||
cf.set.Visit(func(f *flag.Flag) {
|
||||
if f.Name == "out-key" {
|
||||
outKeySet = true
|
||||
}
|
||||
})
|
||||
if outKeySet {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
}
|
||||
if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
|
||||
return err
|
||||
|
|
@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
}
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *cf.outKeyPath,
|
||||
"out-crt", *cf.outCertPath,
|
||||
"out-qr", *cf.outQRPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var passphrase []byte
|
||||
if !isP11 && *cf.encryption {
|
||||
passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
|
||||
if len(passphrase) == 0 {
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
errOut.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if err == ErrNoTerminal {
|
||||
|
|
@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
Curve: curve,
|
||||
}
|
||||
|
||||
if !isP11 {
|
||||
if !isP11 && !isStdio(*cf.outKeyPath) {
|
||||
if _, err := os.Stat(*cf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(*cf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
if !isStdio(*cf.outCertPath) {
|
||||
if _, err := os.Stat(*cf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
|
||||
}
|
||||
}
|
||||
|
||||
var c cert.Certificate
|
||||
|
|
@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outKeyPath, b, 0600)
|
||||
err = writeOutput(*cf.outKeyPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
|
|
@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
return fmt.Errorf("error while marshalling certificate: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outCertPath, b, 0600)
|
||||
err = writeOutput(*cf.outCertPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-crt: %s", err)
|
||||
}
|
||||
|
|
@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
|
|||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*cf.outQRPath, b, 0600)
|
||||
err = writeOutput(*cf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
|
|
@ -332,6 +356,7 @@ func caSummary() string {
|
|||
func caHelp(out io.Writer) {
|
||||
cf := newCaFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
cf.set.SetOutput(out)
|
||||
cf.set.PrintDefaults()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) {
|
|||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" ca <flags>: create a self signed certificate authority\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -argon-iterations uint\n"+
|
||||
" \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+
|
||||
" -argon-memory uint\n"+
|
||||
|
|
@ -84,7 +85,7 @@ func Test_ca(t *testing.T) {
|
|||
err: nil,
|
||||
}
|
||||
|
||||
pwPromptOb := "Enter passphrase: "
|
||||
pwPromptEB := "Enter passphrase: "
|
||||
|
||||
// required args
|
||||
assertHelpError(t, ca(
|
||||
|
|
@ -168,8 +169,8 @@ func Test_ca(t *testing.T) {
|
|||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.NoError(t, ca(args, ob, eb, testpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, pwPromptEB, eb.String())
|
||||
|
||||
// test encrypted key with passphrase environment variable
|
||||
os.Remove(keyF.Name())
|
||||
|
|
@ -207,8 +208,8 @@ func Test_ca(t *testing.T) {
|
|||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.Error(t, ca(args, ob, eb, errpw))
|
||||
assert.Equal(t, pwPromptOb, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, pwPromptEB, eb.String())
|
||||
|
||||
// test when user fails to enter a password
|
||||
os.Remove(keyF.Name())
|
||||
|
|
@ -217,8 +218,8 @@ func Test_ca(t *testing.T) {
|
|||
eb.Reset()
|
||||
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
|
||||
require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
|
||||
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up
|
||||
|
||||
// create valid cert/key for overwrite tests
|
||||
os.Remove(keyF.Name())
|
||||
|
|
@ -247,3 +248,67 @@ func Test_ca(t *testing.T) {
|
|||
os.Remove(keyF.Name())
|
||||
|
||||
}
|
||||
|
||||
func Test_ca_stdio(t *testing.T) {
|
||||
nopw := &StubPasswordReader{}
|
||||
|
||||
keyF, err := os.CreateTemp("", "ca.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
crtF, err := os.CreateTemp("", "ca.crt")
|
||||
require.NoError(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
defer os.Remove(crtF.Name())
|
||||
|
||||
// out-crt on stdout, out-key on disk
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, c.IsCA())
|
||||
assert.Equal(t, "test-ca", c.Name())
|
||||
|
||||
// out-key on stdout, out-crt on disk
|
||||
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
_, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
|
||||
// dual stdout is rejected up front
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t,
|
||||
ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
|
||||
// an output conflict combined with -encrypt must error BEFORE prompting
|
||||
// for a passphrase; pr would record any read attempt
|
||||
tracker := &trackingPasswordReader{}
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t,
|
||||
ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Zero(t, tracker.calls, "passphrase prompt should not have been called")
|
||||
}
|
||||
|
||||
type trackingPasswordReader struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) {
|
||||
pr.calls++
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
|||
if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if *cf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
|
||||
return err
|
||||
|
|
@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
|||
}
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *cf.outKeyPath,
|
||||
"out-pub", *cf.outPubPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isP11 {
|
||||
p11Client, err := pkclient.FromUrl(*cf.p11url)
|
||||
if err != nil {
|
||||
|
|
@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
|
|||
return fmt.Errorf("error while getting public key: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
}
|
||||
err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
|
||||
err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-pub: %s", err)
|
||||
}
|
||||
|
|
@ -102,6 +112,7 @@ func keygenSummary() string {
|
|||
func keygenHelp(out io.Writer) {
|
||||
cf := newKeygenFlags()
|
||||
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
|
||||
_, _ = out.Write([]byte(stdioHelpText))
|
||||
cf.set.SetOutput(out)
|
||||
cf.set.PrintDefaults()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) {
|
|||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -curve string\n"+
|
||||
" \tECDH Curve (25519, P256) (default \"25519\")\n"+
|
||||
" -out-key string\n"+
|
||||
|
|
@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.Len(t, lPub, 32)
|
||||
}
|
||||
|
||||
func Test_keygen_stdio(t *testing.T) {
|
||||
keyF, err := os.CreateTemp("", "test.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
pubF, err := os.CreateTemp("", "test.pub")
|
||||
require.NoError(t, err)
|
||||
os.Remove(pubF.Name())
|
||||
defer os.Remove(pubF.Name())
|
||||
|
||||
// out-pub on stdout, out-key on disk
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb))
|
||||
assert.Empty(t, eb.String())
|
||||
lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Len(t, lPub, 32)
|
||||
|
||||
// out-key on stdout, out-pub on disk
|
||||
os.Remove(keyF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb))
|
||||
assert.Empty(t, eb.String())
|
||||
lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
assert.Len(t, lKey, 32)
|
||||
|
||||
// both on stdout is a conflict caught up front
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb),
|
||||
`-out-key and -out-pub both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) {
|
|||
}
|
||||
|
||||
password, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
// Terminal echo is off while reading, so the user's Enter key does not
|
||||
// produce a visible newline. Emit one on stderr to match the prompt.
|
||||
fmt.Fprintln(os.Stderr)
|
||||
|
||||
return password, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
rawCert, err := os.ReadFile(*pf.path)
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims, "path", *pf.path); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawCert, err := readInput("path", *pf.path, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read cert; %s", err)
|
||||
}
|
||||
|
||||
// When the QR is going to stdout, suppress the human-readable text/json
|
||||
// output so the binary stream is not contaminated.
|
||||
qrToStdout := isStdio(*pf.outQRPath)
|
||||
|
||||
var c cert.Certificate
|
||||
var qrBytes []byte
|
||||
part := 0
|
||||
|
|
@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||
return fmt.Errorf("error while unmarshaling cert: %s", err)
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
jsonCerts = append(jsonCerts, c)
|
||||
} else {
|
||||
_, _ = out.Write([]byte(c.String()))
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
if !qrToStdout {
|
||||
if *pf.json {
|
||||
jsonCerts = append(jsonCerts, c)
|
||||
} else {
|
||||
_, _ = out.Write([]byte(c.String()))
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
if *pf.outQRPath != "" {
|
||||
|
|
@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||
part++
|
||||
}
|
||||
|
||||
if *pf.json {
|
||||
if *pf.json && !qrToStdout {
|
||||
b, _ := json.Marshal(jsonCerts)
|
||||
_, _ = out.Write(b)
|
||||
_, _ = out.Write([]byte("\n"))
|
||||
|
|
@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
|
|||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*pf.outQRPath, b, 0600)
|
||||
err = writeOutput(*pf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
|
|
@ -107,6 +121,7 @@ func printSummary() string {
|
|||
func printHelp(out io.Writer) {
|
||||
pf := newPrintFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
pf.set.SetOutput(out)
|
||||
pf.set.PrintDefaults()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) {
|
|||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" print <flags>: prints details about a certificate\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -json\n"+
|
||||
" \tOptional: outputs certificates in json format\n"+
|
||||
" -out-qr string\n"+
|
||||
|
|
@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) {
|
|||
ob.String(),
|
||||
)
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// read cert from stdin
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-json", "-path", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
|
||||
`,
|
||||
ob.String(),
|
||||
)
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// -out-qr - sends only the PNG to stdout, suppressing the cert dump
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, eb.String())
|
||||
stdout := ob.Bytes()
|
||||
require.NotEmpty(t, stdout)
|
||||
// PNG magic, no PEM/JSON noise prepended
|
||||
assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8])
|
||||
assert.NotContains(t, string(stdout), "NebulaCertificate")
|
||||
assert.NotContains(t, string(stdout), `"details"`)
|
||||
|
||||
// json + out-qr - still suppresses json
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
withStdin(t, bytes.NewReader(p))
|
||||
err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4])
|
||||
assert.NotContains(t, ob.String(), `"details"`)
|
||||
}
|
||||
|
||||
// NewTestCaCert will generate a CA cert
|
||||
|
|
|
|||
|
|
@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set both -in-pub and -out-key")
|
||||
}
|
||||
if isP11 && *sf.outKeyPath != "" {
|
||||
return newHelpErrorf("cannot set -out-key with -pkcs11")
|
||||
}
|
||||
|
||||
var v4Networks []netip.Prefix
|
||||
var v6Networks []netip.Prefix
|
||||
|
|
@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
|
||||
}
|
||||
|
||||
if *sf.outKeyPath == "" {
|
||||
*sf.outKeyPath = *sf.name + ".key"
|
||||
}
|
||||
if *sf.outCertPath == "" {
|
||||
*sf.outCertPath = *sf.name + ".crt"
|
||||
}
|
||||
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims,
|
||||
"ca-key", *sf.caKeyPath,
|
||||
"ca-crt", *sf.caCertPath,
|
||||
"in-pub", *sf.inPubPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := reserveOutputs(&claims,
|
||||
"out-key", *sf.outKeyPath,
|
||||
"out-crt", *sf.outCertPath,
|
||||
"out-qr", *sf.outQRPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var curve cert.Curve
|
||||
var caKey []byte
|
||||
|
||||
if !isP11 {
|
||||
var rawCAKey []byte
|
||||
rawCAKey, err := os.ReadFile(*sf.caKeyPath)
|
||||
|
||||
rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca-key: %s", err)
|
||||
}
|
||||
|
|
@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
if len(passphrase) == 0 {
|
||||
// ask for a passphrase until we get one
|
||||
for i := 0; i < 5; i++ {
|
||||
out.Write([]byte("Enter passphrase: "))
|
||||
errOut.Write([]byte("Enter passphrase: "))
|
||||
passphrase, err = pr.ReadPassword()
|
||||
|
||||
if errors.Is(err, ErrNoTerminal) {
|
||||
|
|
@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
}
|
||||
}
|
||||
|
||||
rawCACert, err := os.ReadFile(*sf.caCertPath)
|
||||
rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca-crt: %s", err)
|
||||
}
|
||||
|
|
@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
|
||||
if *sf.inPubPath != "" {
|
||||
var pubCurve cert.Curve
|
||||
rawPub, err := os.ReadFile(*sf.inPubPath)
|
||||
rawPub, err := readInput("in-pub", *sf.inPubPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading in-pub: %s", err)
|
||||
}
|
||||
|
|
@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
pub, rawPriv = newKeypair(curve)
|
||||
}
|
||||
|
||||
if *sf.outKeyPath == "" {
|
||||
*sf.outKeyPath = *sf.name + ".key"
|
||||
}
|
||||
|
||||
if *sf.outCertPath == "" {
|
||||
*sf.outCertPath = *sf.name + ".crt"
|
||||
}
|
||||
|
||||
if _, err := os.Stat(*sf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||
if !isStdio(*sf.outCertPath) {
|
||||
if _, err := os.Stat(*sf.outCertPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
|
||||
}
|
||||
}
|
||||
|
||||
var crts []cert.Certificate
|
||||
|
|
@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
}
|
||||
|
||||
if !isP11 && *sf.inPubPath == "" {
|
||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||
if !isStdio(*sf.outKeyPath) {
|
||||
if _, err := os.Stat(*sf.outKeyPath); err == nil {
|
||||
return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
|
||||
}
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
|
||||
err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-key: %s", err)
|
||||
}
|
||||
|
|
@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
b = append(b, sb...)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outCertPath, b, 0600)
|
||||
err = writeOutput(*sf.outCertPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-crt: %s", err)
|
||||
}
|
||||
|
|
@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
|
|||
return fmt.Errorf("error while generating qr code: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(*sf.outQRPath, b, 0600)
|
||||
err = writeOutput(*sf.outQRPath, b, 0600, out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while writing out-qr: %s", err)
|
||||
}
|
||||
|
|
@ -440,6 +461,7 @@ func signSummary() string {
|
|||
func signHelp(out io.Writer) {
|
||||
sf := newSignFlags()
|
||||
out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n"))
|
||||
out.Write([]byte(stdioHelpText))
|
||||
sf.set.SetOutput(out)
|
||||
sf.set.PrintDefaults()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) {
|
|||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" sign <flags>: create and sign a certificate\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -ca-crt string\n"+
|
||||
" \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+
|
||||
" -ca-key string\n"+
|
||||
|
|
@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) {
|
|||
// test with the proper password
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
|
||||
// test with the proper password in the environment
|
||||
os.Remove(crtF.Name())
|
||||
os.Remove(keyF.Name())
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, signCert(args, ob, eb, testpw))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
os.Setenv("NEBULA_CA_PASSPHRASE", "")
|
||||
|
||||
|
|
@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) {
|
|||
testpw.password = []byte("invalid password")
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, testpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
|
||||
// test with the wrong password in environment
|
||||
ob.Reset()
|
||||
|
|
@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) {
|
|||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, nopw))
|
||||
// normally the user hitting enter on the prompt would add newlines between these
|
||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String())
|
||||
|
||||
// test an error condition
|
||||
ob.Reset()
|
||||
|
|
@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) {
|
|||
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
|
||||
require.Error(t, signCert(args, ob, eb, errpw))
|
||||
assert.Equal(t, "Enter passphrase: ", ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Equal(t, "Enter passphrase: ", eb.String())
|
||||
}
|
||||
|
||||
func Test_signCert_stdio(t *testing.T) {
|
||||
nopw := &StubPasswordReader{
|
||||
password: []byte(""),
|
||||
err: nil,
|
||||
}
|
||||
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)
|
||||
|
||||
ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
|
||||
rawCACrt, _ := ca.MarshalPEM()
|
||||
|
||||
caCrtF, err := os.CreateTemp("", "sign-cert.crt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caCrtF.Name())
|
||||
caCrtF.Write(rawCACrt)
|
||||
|
||||
caKeyF, err := os.CreateTemp("", "sign-cert.key")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caKeyF.Name())
|
||||
caKeyF.Write(rawCAKey)
|
||||
|
||||
keyF, err := os.CreateTemp("", "sign.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF.Name())
|
||||
defer os.Remove(keyF.Name())
|
||||
|
||||
// ca-key on stdin, cert to stdout
|
||||
withStdin(t, bytes.NewReader(rawCAKey))
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "stdin-test", lCrt.Name())
|
||||
assert.True(t, lCrt.CheckSignature(caPub))
|
||||
|
||||
// two flags reading from stdin should error before any read attempt;
|
||||
// otherwise an interactive shell would hang on io.ReadAll
|
||||
stdinIn := bytes.NewReader(rawCAKey)
|
||||
withStdin(t, stdinIn)
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw),
|
||||
`-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front")
|
||||
|
||||
// two flags writing to stdout should error before any output is written
|
||||
// AND before stdin is consumed
|
||||
stdinR := bytes.NewReader(rawCAKey)
|
||||
withStdin(t, stdinR)
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"}
|
||||
require.EqualError(t, signCert(args, ob, eb, nopw),
|
||||
`-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
assert.Empty(t, ob.String())
|
||||
// stdin should be untouched because the conflict was caught up front
|
||||
assert.Equal(t, len(rawCAKey), stdinR.Len())
|
||||
|
||||
// out-key on stdout, cert on disk
|
||||
keyF2, err := os.CreateTemp("", "sign.key")
|
||||
require.NoError(t, err)
|
||||
os.Remove(keyF2.Name())
|
||||
defer os.Remove(keyF2.Name())
|
||||
crtF, err := os.CreateTemp("", "sign.crt")
|
||||
require.NoError(t, err)
|
||||
os.Remove(crtF.Name())
|
||||
defer os.Remove(crtF.Name())
|
||||
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
_, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cert.Curve_CURVE25519, curve)
|
||||
|
||||
// in-pub on stdin (caller already has a keypair, only the cert is generated)
|
||||
inPub, _ := x25519Keypair()
|
||||
rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)
|
||||
|
||||
withStdin(t, bytes.NewReader(rawInPub))
|
||||
os.Remove(crtF.Name())
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"}
|
||||
require.NoError(t, signCert(args, ob, eb, nopw))
|
||||
assert.Empty(t, eb.String())
|
||||
stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "in-pub-test", stdinCrt.Name())
|
||||
assert.Equal(t, inPub, stdinCrt.PublicKey())
|
||||
}
|
||||
|
|
|
|||
117
cmd/nebula-cert/stdio.go
Normal file
117
cmd/nebula-cert/stdio.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// stdioPath is the special path value that selects stdin (for inputs) or
|
||||
// stdout (for outputs) instead of a file on disk.
|
||||
const stdioPath = "-"
|
||||
|
||||
// stdioHelpText is rendered just under the Usage line of each subcommand
|
||||
// help so the - convention is documented once instead of on every flag.
|
||||
const stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"
|
||||
|
||||
// stdinReader is the source used when an input flag is set to "-".
|
||||
// It is a package level var so tests can swap in a deterministic reader.
|
||||
// Tests that mutate stdinReader cannot run with t.Parallel().
|
||||
var stdinReader io.Reader = os.Stdin
|
||||
|
||||
// ioClaims tracks which flags have claimed stdin and stdout during a single
|
||||
// command invocation so we can refuse a second flag asking for the same
|
||||
// stream.
|
||||
type ioClaims struct {
|
||||
in string
|
||||
out string
|
||||
}
|
||||
|
||||
func (c *ioClaims) claimIn(flagName string) error {
|
||||
if c.in != "" && c.in != flagName {
|
||||
return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath)
|
||||
}
|
||||
c.in = flagName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ioClaims) claimOut(flagName string) error {
|
||||
if c.out != "" && c.out != flagName {
|
||||
return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath)
|
||||
}
|
||||
c.out = flagName
|
||||
return nil
|
||||
}
|
||||
|
||||
// reserveInputs walks alternating (flagName, path) pairs and claims stdin
|
||||
// for any path equal to stdioPath. It must be called before any input is
|
||||
// read so a conflict can be reported immediately instead of blocking on
|
||||
// io.ReadAll while waiting for input that will never arrive.
|
||||
func reserveInputs(claims *ioClaims, pairs ...string) error {
|
||||
return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs)
|
||||
}
|
||||
|
||||
// reserveOutputs walks alternating (flagName, path) pairs and claims stdout
|
||||
// for any path equal to stdioPath. It must be called before any output is
|
||||
// written so a conflict cannot leave one stream half written before the
|
||||
// second flag fails.
|
||||
func reserveOutputs(claims *ioClaims, pairs ...string) error {
|
||||
return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs)
|
||||
}
|
||||
|
||||
func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error {
|
||||
if len(pairs)%2 != 0 {
|
||||
panic(who + " requires alternating name, path pairs")
|
||||
}
|
||||
for i := 0; i < len(pairs); i += 2 {
|
||||
name, path := pairs[i], pairs[i+1]
|
||||
if path != stdioPath {
|
||||
continue
|
||||
}
|
||||
if err := claim(claims, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readInput returns the bytes referenced by path, reading from stdin when
|
||||
// path is stdioPath.
|
||||
func readInput(flagName, path string, claims *ioClaims) ([]byte, error) {
|
||||
if path == stdioPath {
|
||||
if err := claims.claimIn(flagName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.ReadAll(stdinReader)
|
||||
}
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// openInput returns a reader for path. When path is stdioPath the returned
|
||||
// reader wraps stdin and Close is a no-op.
|
||||
func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) {
|
||||
if path == stdioPath {
|
||||
if err := claims.claimIn(flagName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return io.NopCloser(stdinReader), nil
|
||||
}
|
||||
return os.Open(path)
|
||||
}
|
||||
|
||||
// writeOutput writes data to path, or to stdout when path is stdioPath. perm
|
||||
// is only used for file output. The caller must have already claimed stdout
|
||||
// via reserveOutputs before invoking with stdioPath.
|
||||
func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error {
|
||||
if path == stdioPath {
|
||||
_, err := stdout.Write(data)
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
// isStdio reports whether path is the stdio sentinel and so should skip
|
||||
// existence checks like "refuse to overwrite".
|
||||
func isStdio(path string) bool {
|
||||
return path == stdioPath
|
||||
}
|
||||
167
cmd/nebula-cert/stdio_test.go
Normal file
167
cmd/nebula-cert/stdio_test.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// withStdin temporarily replaces stdinReader for the duration of t.
|
||||
func withStdin(t *testing.T, r io.Reader) {
|
||||
t.Helper()
|
||||
prev := stdinReader
|
||||
stdinReader = r
|
||||
t.Cleanup(func() { stdinReader = prev })
|
||||
}
|
||||
|
||||
func Test_readInput_stdin(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hello"))
|
||||
var claims ioClaims
|
||||
|
||||
got, err := readInput("path", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello"), got)
|
||||
assert.Equal(t, "path", claims.in)
|
||||
}
|
||||
|
||||
func Test_readInput_file(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "f")
|
||||
require.NoError(t, os.WriteFile(p, []byte("file"), 0600))
|
||||
var claims ioClaims
|
||||
|
||||
got, err := readInput("path", p, &claims)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("file"), got)
|
||||
assert.Equal(t, "", claims.in)
|
||||
}
|
||||
|
||||
func Test_readInput_doubleStdinErrors(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hello"))
|
||||
var claims ioClaims
|
||||
|
||||
_, err := readInput("ca-key", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = readInput("ca-crt", "-", &claims)
|
||||
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_openInput_stdin(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hi"))
|
||||
var claims ioClaims
|
||||
|
||||
r, err := openInput("ca", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
defer r.Close()
|
||||
b, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hi"), b)
|
||||
}
|
||||
|
||||
func Test_openInput_doubleStdinErrors(t *testing.T) {
|
||||
withStdin(t, bytes.NewBufferString("hi"))
|
||||
var claims ioClaims
|
||||
|
||||
r, err := openInput("ca", "-", &claims)
|
||||
require.NoError(t, err)
|
||||
r.Close()
|
||||
|
||||
_, err = openInput("crt", "-", &claims)
|
||||
require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_writeOutput_stdout(t *testing.T) {
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := writeOutput("-", []byte("payload"), 0600, out)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "payload", out.String())
|
||||
}
|
||||
|
||||
func Test_writeOutput_file(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
p := filepath.Join(dir, "f")
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := writeOutput(p, []byte("payload"), 0600, out)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, out.String())
|
||||
got, err := os.ReadFile(p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("payload"), got)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_noConflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, reserveOutputs(&claims,
|
||||
"out-key", "/tmp/key",
|
||||
"out-crt", "-",
|
||||
"out-qr", "",
|
||||
))
|
||||
assert.Equal(t, "out-crt", claims.out)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_conflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
err := reserveOutputs(&claims,
|
||||
"out-key", "-",
|
||||
"out-crt", "-",
|
||||
)
|
||||
require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`)
|
||||
}
|
||||
|
||||
func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
require.NotNil(t, r)
|
||||
}()
|
||||
var claims ioClaims
|
||||
_ = reserveOutputs(&claims, "out-key")
|
||||
}
|
||||
|
||||
func Test_reserveInputs_noConflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, reserveInputs(&claims,
|
||||
"ca-key", "/tmp/ca.key",
|
||||
"ca-crt", "-",
|
||||
"in-pub", "",
|
||||
))
|
||||
assert.Equal(t, "ca-crt", claims.in)
|
||||
}
|
||||
|
||||
func Test_reserveInputs_conflict(t *testing.T) {
|
||||
var claims ioClaims
|
||||
err := reserveInputs(&claims,
|
||||
"ca-key", "-",
|
||||
"ca-crt", "-",
|
||||
)
|
||||
require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
||||
func Test_claimIn_idempotent(t *testing.T) {
|
||||
// pre-claim then a lazy re-claim of the same flag should be a no-op
|
||||
var claims ioClaims
|
||||
require.NoError(t, claims.claimIn("ca-key"))
|
||||
require.NoError(t, claims.claimIn("ca-key"))
|
||||
assert.Equal(t, "ca-key", claims.in)
|
||||
}
|
||||
|
||||
func Test_claimOut_idempotent(t *testing.T) {
|
||||
var claims ioClaims
|
||||
require.NoError(t, claims.claimOut("out-crt"))
|
||||
require.NoError(t, claims.claimOut("out-crt"))
|
||||
assert.Equal(t, "out-crt", claims.out)
|
||||
}
|
||||
|
||||
func Test_isStdio(t *testing.T) {
|
||||
assert.True(t, isStdio("-"))
|
||||
assert.False(t, isStdio(""))
|
||||
assert.False(t, isStdio("./-"))
|
||||
assert.False(t, isStdio("foo"))
|
||||
}
|
||||
|
|
@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
|
|||
return err
|
||||
}
|
||||
|
||||
caFile, err := os.Open(*vf.caPath)
|
||||
var claims ioClaims
|
||||
if err := reserveInputs(&claims,
|
||||
"ca", *vf.caPath,
|
||||
"crt", *vf.certPath,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
caReader, err := openInput("ca", *vf.caPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca: %w", err)
|
||||
}
|
||||
defer caFile.Close()
|
||||
defer caReader.Close()
|
||||
|
||||
caPool, err := cert.NewCAPoolFromPEMReader(caFile)
|
||||
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
|
||||
if err != nil && !errors.Is(err, cert.ErrExpired) {
|
||||
return fmt.Errorf("error while adding ca cert to pool: %w", err)
|
||||
}
|
||||
|
||||
rawCert, err := os.ReadFile(*vf.certPath)
|
||||
rawCert, err := readInput("crt", *vf.certPath, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read crt: %w", err)
|
||||
}
|
||||
|
|
@ -85,6 +93,7 @@ func verifySummary() string {
|
|||
func verifyHelp(out io.Writer) {
|
||||
vf := newVerifyFlags()
|
||||
_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
|
||||
_, _ = out.Write([]byte(stdioHelpText))
|
||||
vf.set.SetOutput(out)
|
||||
vf.set.PrintDefaults()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) {
|
|||
assert.Equal(
|
||||
t,
|
||||
"Usage of "+os.Args[0]+" verify <flags>: verifies a certificate isn't expired and was signed by a trusted authority.\n"+
|
||||
" Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+
|
||||
" -ca string\n"+
|
||||
" \tRequired: path to a file containing one or more ca certificates\n"+
|
||||
" -crt string\n"+
|
||||
|
|
@ -122,3 +123,46 @@ func Test_verify(t *testing.T) {
|
|||
assert.Empty(t, eb.String())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_verify_stdio(t *testing.T) {
|
||||
ob := &bytes.Buffer{}
|
||||
eb := &bytes.Buffer{}
|
||||
|
||||
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
|
||||
ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
|
||||
caPEM, _ := ca.MarshalPEM()
|
||||
|
||||
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
crtPEM, _ := crt.MarshalPEM()
|
||||
|
||||
caFile, err := os.CreateTemp("", "verify-ca")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(caFile.Name())
|
||||
caFile.Write(caPEM)
|
||||
|
||||
// crt on stdin, ca on disk
|
||||
withStdin(t, bytes.NewReader(crtPEM))
|
||||
require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// ca on stdin, crt on disk
|
||||
certFile, err := os.CreateTemp("", "verify-cert")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(certFile.Name())
|
||||
certFile.Write(crtPEM)
|
||||
|
||||
withStdin(t, bytes.NewReader(caPEM))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb))
|
||||
assert.Empty(t, ob.String())
|
||||
assert.Empty(t, eb.String())
|
||||
|
||||
// both flags on stdin should error
|
||||
withStdin(t, bytes.NewReader(caPEM))
|
||||
ob.Reset()
|
||||
eb.Reset()
|
||||
require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb),
|
||||
`-ca and -crt both set to "-", only one input may read from stdin`)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue