From 8c963818bb90c84dac04018eec93790900d4b0ce Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 25 Mar 2025 07:36:53 -0700 Subject: [PATCH] fix: hash the refresh token in the DB (security) (#379) --- backend/internal/service/oidc_service.go | 238 ++++++++++++----------- backend/internal/service/test_service.go | 2 +- 2 files changed, 128 insertions(+), 112 deletions(-) diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 7579d3ae..1e4a7adb 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -145,121 +145,133 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode return isAllowedToAuthorize } -func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (string, string, string, int, error) { - if grantType == "authorization_code" { - var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return "", "", "", 0, err - } +func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) { + switch grantType { + case "authorization_code": + return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier) + case "refresh_token": + accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret) + return "", accessToken, newRefreshToken, exp, err + default: + return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{} + } +} - // Verify the client secret if the client is not public - if !client.IsPublic { - if clientID == "" || clientSecret == "" { - return "", "", "", 0, &common.OidcMissingClientCredentialsError{} - } - - err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) - if err != nil { - return "", "", "", 0, &common.OidcClientSecretInvalidError{} - } - } - - var authorizationCodeMetaData model.OidcAuthorizationCode - err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error - if err != nil { - return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} - } - - // If the client is public or PKCE is enabled, the code verifier must match the code challenge - if client.IsPublic || client.PkceEnabled { - if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { - return "", "", "", 0, &common.OidcInvalidCodeVerifierError{} - } - } - - if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { - return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} - } - - userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) - if err != nil { - return "", "", "", 0, err - } - - idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) - if err != nil { - return "", "", "", 0, err - } - - // Generate a refresh token - refreshToken, err := s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope) - if err != nil { - return "", "", "", 0, err - } - - accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) - - s.db.Delete(&authorizationCodeMetaData) - - return idToken, accessToken, refreshToken, 3600, nil +func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) { + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return "", "", "", 0, err } - if grantType == "refresh_token" { - if refreshToken == "" { - return "", "", "", 0, &common.OidcMissingRefreshTokenError{} + // Verify the client secret if the client is not public + if !client.IsPublic { + if clientID == "" || clientSecret == "" { + return "", "", "", 0, &common.OidcMissingClientCredentialsError{} } - // Get the client to check if it's public - var client model.OidcClient - if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { - return "", "", "", 0, err - } - - // Verify the client secret if the client is not public - if !client.IsPublic { - if clientID == "" || clientSecret == "" { - return "", "", "", 0, &common.OidcMissingClientCredentialsError{} - } - - err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) - if err != nil { - return "", "", "", 0, &common.OidcClientSecretInvalidError{} - } - } - - // Verify refresh token - var storedRefreshToken model.OidcRefreshToken - if err := s.db.Preload("User").Where("token = ? AND expires_at > ?", refreshToken, datatype.DateTime(time.Now())).First(&storedRefreshToken).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return "", "", "", 0, &common.OidcInvalidRefreshTokenError{} - } - return "", "", "", 0, err - } - - // Verify that the refresh token belongs to the provided client - if storedRefreshToken.ClientID != clientID { - return "", "", "", 0, &common.OidcInvalidRefreshTokenError{} - } - - // Generate a new access token - accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID) + err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) if err != nil { - return "", "", "", 0, err + return "", "", "", 0, &common.OidcClientSecretInvalidError{} } - - // Generate a new refresh token and invalidate the old one - newRefreshToken, err := s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope) - if err != nil { - return "", "", "", 0, err - } - - // Delete the used refresh token - s.db.Delete(&storedRefreshToken) - - return "", accessToken, newRefreshToken, 3600, nil } - return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{} + var authorizationCodeMetaData model.OidcAuthorizationCode + err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error + if err != nil { + return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} + } + + // If the client is public or PKCE is enabled, the code verifier must match the code challenge + if client.IsPublic || client.PkceEnabled { + if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { + return "", "", "", 0, &common.OidcInvalidCodeVerifierError{} + } + } + + if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { + return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{} + } + + userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) + if err != nil { + return "", "", "", 0, err + } + + idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) + if err != nil { + return "", "", "", 0, err + } + + // Generate a refresh token + refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope) + if err != nil { + return "", "", "", 0, err + } + + accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) + + s.db.Delete(&authorizationCodeMetaData) + + return idToken, accessToken, refreshToken, 3600, nil +} + +func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) { + if refreshToken == "" { + return "", "", 0, &common.OidcMissingRefreshTokenError{} + } + + // Get the client to check if it's public + var client model.OidcClient + if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { + return "", "", 0, err + } + + // Verify the client secret if the client is not public + if !client.IsPublic { + if clientID == "" || clientSecret == "" { + return "", "", 0, &common.OidcMissingClientCredentialsError{} + } + + err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) + if err != nil { + return "", "", 0, &common.OidcClientSecretInvalidError{} + } + } + + // Verify refresh token + var storedRefreshToken model.OidcRefreshToken + err = s.db.Preload("User"). + Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())). + First(&storedRefreshToken). + Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", "", 0, &common.OidcInvalidRefreshTokenError{} + } + return "", "", 0, err + } + + // Verify that the refresh token belongs to the provided client + if storedRefreshToken.ClientID != clientID { + return "", "", 0, &common.OidcInvalidRefreshTokenError{} + } + + // Generate a new access token + accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID) + if err != nil { + return "", "", 0, err + } + + // Generate a new refresh token and invalidate the old one + newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope) + if err != nil { + return "", "", 0, err + } + + // Delete the used refresh token + s.db.Delete(&storedRefreshToken) + + return accessToken, newRefreshToken, 3600, nil } func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) { @@ -630,22 +642,26 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca } func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) { - randomString, err := utils.GenerateRandomAlphanumericString(40) + refreshToken, err := utils.GenerateRandomAlphanumericString(40) if err != nil { return "", err } - refreshToken := model.OidcRefreshToken{ + // Compute the hash of the refresh token to store in the DB + // Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough + refreshTokenHash := utils.CreateSha256Hash(refreshToken) + + m := model.OidcRefreshToken{ ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days - Token: randomString, + Token: refreshTokenHash, ClientID: clientID, UserID: userID, Scope: scope, } - if err := s.db.Create(&refreshToken).Error; err != nil { + if err := s.db.Create(&m).Error; err != nil { return "", err } - return randomString, nil + return refreshToken, nil } diff --git a/backend/internal/service/test_service.go b/backend/internal/service/test_service.go index 9b86bed4..f7483ea1 100644 --- a/backend/internal/service/test_service.go +++ b/backend/internal/service/test_service.go @@ -153,7 +153,7 @@ func (s *TestService) SeedDatabase() error { } refreshToken := model.OidcRefreshToken{ - Token: "ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo", + Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"), ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)), Scope: "openid profile email", UserID: users[0].ID,