fix(agent/agentssh): ensure RSA key generation always produces valid keys (#16694)

Modify the RSA key generation algorithm to check that GCD(e, p-1) = 1 and
GCD(e, q-1) = 1 when selecting prime numbers, ensuring that e and φ(n) 
are coprime. This prevents ModInverse from returning nil, which would 
cause private key generation to fail and result in a panic when `Precompute` is called.

Change-Id: I0a453e1e1f8c638e40e7a4b87a6d0d7299e1cb5d
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2025-02-26 11:45:35 +01:00
committed by GitHub
parent 172e52317c
commit 38c0e8a086
3 changed files with 139 additions and 72 deletions
+87
View File
@@ -0,0 +1,87 @@
package agentrsa
import (
"crypto/rsa"
"math/big"
"math/rand"
)
// GenerateDeterministicKey generates an RSA private key deterministically based on the provided seed.
// This function uses a deterministic random source to generate the primes p and q, ensuring that the
// same seed will always produce the same private key. The generated key is 2048 bits in size.
//
// Reference: https://pkg.go.dev/crypto/rsa#GenerateKey
func GenerateDeterministicKey(seed int64) *rsa.PrivateKey {
// Since the standard lib purposefully does not generate
// deterministic rsa keys, we need to do it ourselves.
// Create deterministic random source
// nolint: gosec
deterministicRand := rand.New(rand.NewSource(seed))
// Use fixed values for p and q based on the seed
p := big.NewInt(0)
q := big.NewInt(0)
e := big.NewInt(65537) // Standard RSA public exponent
for {
// Generate deterministic primes using the seeded random
// Each prime should be ~1024 bits to get a 2048-bit key
for {
p.SetBit(p, 1024, 1) // Ensure it's large enough
for i := range 1024 {
if deterministicRand.Int63()%2 == 1 {
p.SetBit(p, i, 1)
} else {
p.SetBit(p, i, 0)
}
}
p1 := new(big.Int).Sub(p, big.NewInt(1))
if p.ProbablyPrime(20) && new(big.Int).GCD(nil, nil, e, p1).Cmp(big.NewInt(1)) == 0 {
break
}
}
for {
q.SetBit(q, 1024, 1) // Ensure it's large enough
for i := range 1024 {
if deterministicRand.Int63()%2 == 1 {
q.SetBit(q, i, 1)
} else {
q.SetBit(q, i, 0)
}
}
q1 := new(big.Int).Sub(q, big.NewInt(1))
if q.ProbablyPrime(20) && p.Cmp(q) != 0 && new(big.Int).GCD(nil, nil, e, q1).Cmp(big.NewInt(1)) == 0 {
break
}
}
// Calculate phi = (p-1) * (q-1)
p1 := new(big.Int).Sub(p, big.NewInt(1))
q1 := new(big.Int).Sub(q, big.NewInt(1))
phi := new(big.Int).Mul(p1, q1)
// Calculate private exponent d
d := new(big.Int).ModInverse(e, phi)
if d != nil {
// Calculate n = p * q
n := new(big.Int).Mul(p, q)
// Create the private key
privateKey := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: n,
E: int(e.Int64()),
},
D: d,
Primes: []*big.Int{p, q},
}
// Compute precomputed values
privateKey.Precompute()
return privateKey
}
}
}
+50
View File
@@ -0,0 +1,50 @@
package agentrsa_test
import (
"crypto/rsa"
"math/rand/v2"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/agent/agentrsa"
)
func TestGenerateDeterministicKey(t *testing.T) {
t.Parallel()
key1 := agentrsa.GenerateDeterministicKey(1234)
key2 := agentrsa.GenerateDeterministicKey(1234)
assert.Equal(t, key1, key2)
assert.EqualExportedValues(t, key1, key2)
}
var result *rsa.PrivateKey
func BenchmarkGenerateDeterministicKey(b *testing.B) {
var r *rsa.PrivateKey
for range b.N {
// always record the result of DeterministicPrivateKey to prevent
// the compiler eliminating the function call.
r = agentrsa.GenerateDeterministicKey(rand.Int64())
}
// always store the result to a package level variable
// so the compiler cannot eliminate the Benchmark itself.
result = r
}
func FuzzGenerateDeterministicKey(f *testing.F) {
testcases := []int64{0, 1234, 1010101010}
for _, tc := range testcases {
f.Add(tc) // Use f.Add to provide a seed corpus
}
f.Fuzz(func(t *testing.T, seed int64) {
key1 := agentrsa.GenerateDeterministicKey(seed)
key2 := agentrsa.GenerateDeterministicKey(seed)
assert.Equal(t, key1, key2)
assert.EqualExportedValues(t, key1, key2)
})
}
+2 -72
View File
@@ -3,12 +3,9 @@ package agentssh
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/rsa"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/big"
"math/rand"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@@ -33,6 +30,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentexec" "github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentrsa"
"github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/pty"
@@ -1092,75 +1090,7 @@ func CoderSigner(seed int64) (gossh.Signer, error) {
// Clients should ignore the host key when connecting. // Clients should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH, // The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security. // so SSH authentication doesn't improve security.
coderHostKey := agentrsa.GenerateDeterministicKey(seed)
// Since the standard lib purposefully does not generate
// deterministic rsa keys, we need to do it ourselves.
coderHostKey := func() *rsa.PrivateKey {
// Create deterministic random source
// nolint: gosec
deterministicRand := rand.New(rand.NewSource(seed))
// Use fixed values for p and q based on the seed
p := big.NewInt(0)
q := big.NewInt(0)
e := big.NewInt(65537) // Standard RSA public exponent
// Generate deterministic primes using the seeded random
// Each prime should be ~1024 bits to get a 2048-bit key
for {
p.SetBit(p, 1024, 1) // Ensure it's large enough
for i := 0; i < 1024; i++ {
if deterministicRand.Int63()%2 == 1 {
p.SetBit(p, i, 1)
} else {
p.SetBit(p, i, 0)
}
}
if p.ProbablyPrime(20) {
break
}
}
for {
q.SetBit(q, 1024, 1) // Ensure it's large enough
for i := 0; i < 1024; i++ {
if deterministicRand.Int63()%2 == 1 {
q.SetBit(q, i, 1)
} else {
q.SetBit(q, i, 0)
}
}
if q.ProbablyPrime(20) && p.Cmp(q) != 0 {
break
}
}
// Calculate n = p * q
n := new(big.Int).Mul(p, q)
// Calculate phi = (p-1) * (q-1)
p1 := new(big.Int).Sub(p, big.NewInt(1))
q1 := new(big.Int).Sub(q, big.NewInt(1))
phi := new(big.Int).Mul(p1, q1)
// Calculate private exponent d
d := new(big.Int).ModInverse(e, phi)
// Create the private key
privateKey := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: n,
E: int(e.Int64()),
},
D: d,
Primes: []*big.Int{p, q},
}
// Compute precomputed values
privateKey.Precompute()
return privateKey
}()
coderSigner, err := gossh.NewSignerFromKey(coderHostKey) coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
return coderSigner, err return coderSigner, err