diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index bc465001..5c03d8ba 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -23,6 +23,7 @@ type EnvConfigSchema struct { SqliteDBPath string `env:"SQLITE_DB_PATH"` PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` UploadPath string `env:"UPLOAD_PATH"` + KeysPath string `env:"KEYS_PATH"` Port string `env:"BACKEND_PORT"` Host string `env:"HOST"` MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"` @@ -37,6 +38,7 @@ var EnvConfig = &EnvConfigSchema{ SqliteDBPath: "data/pocket-id.db", PostgresConnectionString: "", UploadPath: "data/uploads", + KeysPath: "data/keys", AppURL: "http://localhost", Port: "8080", Host: "0.0.0.0", @@ -50,19 +52,21 @@ func init() { if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil { log.Fatal(err) } + // Validate the environment variables - if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres { + switch EnvConfig.DbProvider { + case DbProviderSqlite: + if EnvConfig.SqliteDBPath == "" { + log.Fatal("Missing SQLITE_DB_PATH environment variable") + } + case DbProviderPostgres: + if EnvConfig.PostgresConnectionString == "" { + log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable") + } + default: log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'") } - if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" { - log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable") - } - - if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" { - log.Fatal("Missing SQLITE_DB_PATH environment variable") - } - parsedAppUrl, err := url.Parse(EnvConfig.AppURL) if err != nil { log.Fatal("PUBLIC_APP_URL is not a valid URL") diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index 8a832e88..48f08835 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "encoding/pem" "errors" + "fmt" "log" "math/big" "os" @@ -22,13 +23,12 @@ import ( ) const ( - privateKeyPath = "data/keys/jwt_private_key.pem" - publicKeyPath = "data/keys/jwt_public_key.pem" + privateKeyFile = "jwt_private_key.pem" ) type JwtService struct { - PublicKey *rsa.PublicKey - PrivateKey *rsa.PrivateKey + privateKey *rsa.PrivateKey + keyId string appConfigService *AppConfigService } @@ -38,7 +38,7 @@ func NewJwtService(appConfigService *AppConfigService) *JwtService { } // Ensure keys are generated or loaded - if err := service.loadOrGenerateKeys(); err != nil { + if err := service.loadOrGenerateKey(common.EnvConfig.KeysPath); err != nil { log.Fatalf("Failed to initialize jwt service: %v", err) } @@ -59,30 +59,39 @@ type JWK struct { E string `json:"e"` } -// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist. -func (s *JwtService) loadOrGenerateKeys() error { +// loadOrGenerateKey loads RSA keys from the given paths or generates them if they do not exist. +func (s *JwtService) loadOrGenerateKey(keysPath string) error { + privateKeyPath := filepath.Join(keysPath, privateKeyFile) + if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) { - if err := s.generateKeys(); err != nil { - return err + if err := s.generateKey(keysPath); err != nil { + return fmt.Errorf("can't generate key: %w", err) } } privateKeyBytes, err := os.ReadFile(privateKeyPath) if err != nil { - return errors.New("can't read jwt private key: " + err.Error()) + return fmt.Errorf("can't read jwt private key: %w", err) } - s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes) if err != nil { - return errors.New("can't parse jwt private key: " + err.Error()) + return fmt.Errorf("can't parse jwt private key: %w", err) } - publicKeyBytes, err := os.ReadFile(publicKeyPath) + err = s.SetKey(privateKey) if err != nil { - return errors.New("can't read jwt public key: " + err.Error()) + return fmt.Errorf("failed to set private key: %w", err) } - s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes) + + return nil +} + +func (s *JwtService) SetKey(privateKey *rsa.PrivateKey) (err error) { + s.privateKey = privateKey + + s.keyId, err = s.generateKeyID() if err != nil { - return errors.New("can't parse jwt public key: " + err.Error()) + return fmt.Errorf("can't generate key ID: %w", err) } return nil @@ -100,20 +109,15 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { IsAdmin: user.IsAdmin, } - kid, err := s.generateKeyID(s.PublicKey) - if err != nil { - return "", errors.New("failed to generate key ID: " + err.Error()) - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) - token.Header["kid"] = kid + token.Header["kid"] = s.keyId - return token.SignedString(s.PrivateKey) + return token.SignedString(s.privateKey) } func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) { - return s.PublicKey, nil + return &s.privateKey.PublicKey, nil }) if err != nil || !token.Valid { return nil, errors.New("couldn't handle this token") @@ -146,15 +150,10 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID claims["nonce"] = nonce } - kid, err := s.generateKeyID(s.PublicKey) - if err != nil { - return "", errors.New("failed to generate key ID: " + err.Error()) - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header["kid"] = kid + token.Header["kid"] = s.keyId - return token.SignedString(s.PrivateKey) + return token.SignedString(s.privateKey) } func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { @@ -166,20 +165,15 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) Issuer: common.EnvConfig.AppURL, } - kid, err := s.generateKeyID(s.PublicKey) - if err != nil { - return "", errors.New("failed to generate key ID: " + err.Error()) - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim) - token.Header["kid"] = kid + token.Header["kid"] = s.keyId - return token.SignedString(s.PrivateKey) + return token.SignedString(s.privateKey) } func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - return s.PublicKey, nil + return &s.privateKey.PublicKey, nil }) if err != nil || !token.Valid { return nil, errors.New("couldn't handle this token") @@ -195,7 +189,7 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.Registered func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) { token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { - return s.PublicKey, nil + return &s.privateKey.PublicKey, nil }, jwt.WithIssuer(common.EnvConfig.AppURL)) if err != nil && !errors.Is(err, jwt.ErrTokenExpired) { @@ -212,32 +206,27 @@ func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, e // GetJWK returns the JSON Web Key (JWK) for the public key. func (s *JwtService) GetJWK() (JWK, error) { - if s.PublicKey == nil { + if s.privateKey == nil { return JWK{}, errors.New("public key is not initialized") } - kid, err := s.generateKeyID(s.PublicKey) - if err != nil { - return JWK{}, err - } - jwk := JWK{ - Kid: kid, + Kid: s.keyId, Kty: "RSA", Use: "sig", Alg: "RS256", - N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()), - E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()), + N: base64.RawURLEncoding.EncodeToString(s.privateKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.privateKey.E)).Bytes()), } return jwk, nil } // GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key. -func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) { - pubASN1, err := x509.MarshalPKIXPublicKey(publicKey) +func (s *JwtService) generateKeyID() (string, error) { + pubASN1, err := x509.MarshalPKIXPublicKey(&s.privateKey.PublicKey) if err != nil { - return "", errors.New("failed to marshal public key: " + err.Error()) + return "", fmt.Errorf("failed to marshal public key: %w", err) } // Compute SHA-256 hash of the public key @@ -252,29 +241,22 @@ func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) { return base64.RawURLEncoding.EncodeToString(shortHash), nil } -// generateKeys generates a new RSA key pair and saves them to the specified paths. -func (s *JwtService) generateKeys() error { - if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil { - return errors.New("failed to create directories for keys: " + err.Error()) +// generateKey generates a new RSA key and saves it to the specified path. +func (s *JwtService) generateKey(keysPath string) error { + if err := os.MkdirAll(keysPath, 0700); err != nil { + return fmt.Errorf("failed to create directories for keys: %w", err) } privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return errors.New("failed to generate private key: " + err.Error()) + return fmt.Errorf("failed to generate private key: %w", err) } - s.PrivateKey = privateKey + privateKeyPath := filepath.Join(keysPath, privateKeyFile) if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil { return err } - publicKey := &privateKey.PublicKey - s.PublicKey = publicKey - - if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil { - return err - } - return nil } @@ -282,7 +264,7 @@ func (s *JwtService) generateKeys() error { func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error { keyFile, err := os.Create(path) if err != nil { - return errors.New("failed to create key file: " + err.Error()) + return fmt.Errorf("failed to create key file: %w", err) } defer keyFile.Close() @@ -292,7 +274,7 @@ func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) er }) if _, err := keyFile.Write(keyPEM); err != nil { - return errors.New("failed to write key file: " + err.Error()) + return fmt.Errorf("failed to write key file: %w", err) } return nil diff --git a/backend/internal/service/test_service.go b/backend/internal/service/test_service.go index 1c5d3d3c..b7dc9096 100644 --- a/backend/internal/service/test_service.go +++ b/backend/internal/service/test_service.go @@ -336,8 +336,7 @@ wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI= block, _ := pem.Decode([]byte(privateKeyString)) privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes) - s.jwtService.PrivateKey = privateKey - s.jwtService.PublicKey = &privateKey.PublicKey + s.jwtService.SetKey(privateKey) } // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key