mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-04 15:39:45 +00:00
feat: add OIDC refresh_token support (#325)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -145,60 +145,121 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
|
||||
return isAllowedToAuthorize
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
|
||||
if grantType != "authorization_code" {
|
||||
return "", "", &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Verify the client secret if the client is not public
|
||||
if !client.IsPublic {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", &common.OidcMissingClientCredentialsError{}
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (string, string, string, int, error) {
|
||||
if grantType == "authorization_code" {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
// Verify the client secret if the client is not public
|
||||
if !client.IsPublic {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
if err != nil {
|
||||
return "", "", &common.OidcClientSecretInvalidError{}
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
if err != nil {
|
||||
return "", "", &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
||||
if client.IsPublic || client.PkceEnabled {
|
||||
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||
return "", "", &common.OidcInvalidCodeVerifierError{}
|
||||
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
|
||||
if client.IsPublic || client.PkceEnabled {
|
||||
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
|
||||
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
|
||||
}
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Generate a refresh token
|
||||
refreshToken, err := s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
|
||||
return idToken, accessToken, refreshToken, 3600, nil
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return "", "", &common.OidcInvalidAuthorizationCodeError{}
|
||||
if grantType == "refresh_token" {
|
||||
if refreshToken == "" {
|
||||
return "", "", "", 0, &common.OidcMissingRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Get the client to check if it's public
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Verify the client secret if the client is not public
|
||||
if !client.IsPublic {
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify refresh token
|
||||
var storedRefreshToken model.OidcRefreshToken
|
||||
if err := s.db.Preload("User").Where("token = ? AND expires_at > ?", refreshToken, datatype.DateTime(time.Now())).First(&storedRefreshToken).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Verify that the refresh token belongs to the provided client
|
||||
if storedRefreshToken.ClientID != clientID {
|
||||
return "", "", "", 0, &common.OidcInvalidRefreshTokenError{}
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Generate a new refresh token and invalidate the old one
|
||||
newRefreshToken, err := s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
}
|
||||
|
||||
// Delete the used refresh token
|
||||
s.db.Delete(&storedRefreshToken)
|
||||
|
||||
return "", accessToken, newRefreshToken, 3600, nil
|
||||
}
|
||||
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
|
||||
return idToken, accessToken, nil
|
||||
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||
@@ -567,3 +628,24 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
|
||||
|
||||
return "", &common.OidcInvalidCallbackURLError{}
|
||||
}
|
||||
|
||||
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(40)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
refreshToken := model.OidcRefreshToken{
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
|
||||
Token: randomString,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&refreshToken).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return randomString, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user