1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-04 15:39:45 +00:00

feat: add OIDC refresh_token support (#325)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Kyle Mendell
2025-03-23 15:14:26 -05:00
committed by GitHub
parent 7888d70656
commit b8dcda8049
14 changed files with 339 additions and 55 deletions

View File

@@ -145,60 +145,121 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (string, string, string, int, error) {
if grantType == "authorization_code" {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", "", 0, err
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
if err != nil {
return "", "", "", 0, err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", "", 0, err
}
// Generate a refresh token
refreshToken, err := s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
if err != nil {
return "", "", "", 0, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
s.db.Delete(&authorizationCodeMetaData)
return idToken, accessToken, refreshToken, 3600, nil
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
if grantType == "refresh_token" {
if refreshToken == "" {
return "", "", "", 0, &common.OidcMissingRefreshTokenError{}
}
// Get the client to check if it's public
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
if err := s.db.Preload("User").Where("token = ? AND expires_at > ?", refreshToken, datatype.DateTime(time.Now())).First(&storedRefreshToken).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
return "", "", "", 0, err
}
// Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID {
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
// Generate a new access token
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
if err != nil {
return "", "", "", 0, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err := s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
if err != nil {
return "", "", "", 0, err
}
// Delete the used refresh token
s.db.Delete(&storedRefreshToken)
return "", accessToken, newRefreshToken, 3600, nil
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
if err != nil {
return "", "", err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
s.db.Delete(&authorizationCodeMetaData)
return idToken, accessToken, nil
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
}
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
@@ -567,3 +628,24 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
}
refreshToken := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
Token: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
}
if err := s.db.Create(&refreshToken).Error; err != nil {
return "", err
}
return randomString, nil
}