mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-02-15 15:45:05 +00:00
feat: add database storage backend (#1091)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
committed by
GitHub
parent
12125713a2
commit
29a1d3b778
@@ -12,7 +12,6 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
@@ -679,19 +678,21 @@ func (s *OidcService) introspectRefreshToken(ctx context.Context, clientID strin
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
|
||||
return s.getClientInternal(ctx, clientID, s.db)
|
||||
return s.getClientInternal(ctx, clientID, s.db, false)
|
||||
}
|
||||
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
q := tx.
|
||||
WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
Preload("AllowedUserGroups").
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
Preload("AllowedUserGroups")
|
||||
if forUpdate {
|
||||
q = q.Clauses(clause.Locking{Strength: "UPDATE"})
|
||||
}
|
||||
q = q.First(&client, "id = ?", clientID)
|
||||
if q.Error != nil {
|
||||
return model.OidcClient{}, q.Error
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
@@ -724,11 +725,6 @@ func (s *OidcService) ListClients(ctx context.Context, name string, listRequestO
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client := model.OidcClient{
|
||||
Base: model.Base{
|
||||
ID: input.ID,
|
||||
@@ -737,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
}
|
||||
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
||||
|
||||
err := tx.
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Create(&client).
|
||||
Error
|
||||
@@ -748,62 +744,65 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// All storage operations must be executed outside of a transaction
|
||||
if input.LogoURL != nil {
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if input.DarkLogoURL != nil {
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() { tx.Rollback() }()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).
|
||||
err := tx.WithContext(ctx).
|
||||
Preload("CreatedBy").
|
||||
First(&client, "id = ?", clientID).Error; err != nil {
|
||||
First(&client, "id = ?", clientID).Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
updateOIDCClientModelFromDto(&client, &input)
|
||||
|
||||
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
|
||||
err = tx.WithContext(ctx).Save(&client).Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
// All storage operations must be executed outside of a transaction
|
||||
if input.LogoURL != nil {
|
||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL, true)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.LogoURL, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if input.DarkLogoURL != nil {
|
||||
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.DarkLogoURL, false)
|
||||
err = s.downloadAndSaveLogoFromURL(ctx, client.ID, *input.DarkLogoURL, false)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, fmt.Errorf("failed to download dark logo: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -836,12 +835,24 @@ func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
|
||||
err := s.db.
|
||||
WithContext(ctx).
|
||||
Where("id = ?", clientID).
|
||||
Clauses(clause.Returning{}).
|
||||
Delete(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete images if present
|
||||
// Note that storage operations must be done outside of a transaction
|
||||
if client.ImageType != nil && *client.ImageType != "" {
|
||||
old := path.Join("oidc-client-images", client.ID+"."+*client.ImageType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
if client.DarkImageType != nil && *client.DarkImageType != "" {
|
||||
old := path.Join("oidc-client-images", client.ID+"-dark."+*client.DarkImageType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -941,57 +952,12 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
if err := s.fileStorage.Save(ctx, imagePath, reader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
|
||||
err = s.updateClientLogoType(ctx, tx, clientID, fileType, light)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
var client model.OidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
err = s.fileStorage.Save(ctx, imagePath, reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
return errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
err = s.updateClientLogoType(ctx, clientID, fileType, light)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -999,7 +965,31 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||
return s.deleteClientLogoInternal(ctx, clientID, "", func(client *model.OidcClient) (string, error) {
|
||||
if client.ImageType == nil {
|
||||
return "", errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.ImageType
|
||||
client.ImageType = nil
|
||||
return oldImageType, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string) error {
|
||||
return s.deleteClientLogoInternal(ctx, clientID, "-dark", func(client *model.OidcClient) (string, error) {
|
||||
if client.DarkImageType == nil {
|
||||
return "", errors.New("image not found")
|
||||
}
|
||||
|
||||
oldImageType := *client.DarkImageType
|
||||
client.DarkImageType = nil
|
||||
return oldImageType, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OidcService) deleteClientLogoInternal(ctx context.Context, clientID string, imagePathSuffix string, setClientImage func(*model.OidcClient) (string, error)) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -1014,13 +1004,11 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
||||
return err
|
||||
}
|
||||
|
||||
if client.DarkImageType == nil {
|
||||
return errors.New("image not found")
|
||||
oldImageType, err := setClientImage(&client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldImageType := *client.DarkImageType
|
||||
client.DarkImageType = nil
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
@@ -1029,12 +1017,14 @@ func (s *OidcService) DeleteClientDarkLogo(ctx context.Context, clientID string)
|
||||
return err
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", client.ID+"-dark."+oldImageType)
|
||||
if err := s.fileStorage.Delete(ctx, imagePath); err != nil {
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
// All storage operations must be performed outside of a database transaction
|
||||
imagePath := path.Join("oidc-client-images", client.ID+imagePathSuffix+"."+oldImageType)
|
||||
err = s.fileStorage.Delete(ctx, imagePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1048,7 +1038,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err = s.getClientInternal(ctx, id, tx)
|
||||
client, err = s.getClientInternal(ctx, id, tx, true)
|
||||
if err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
@@ -1831,7 +1821,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.getClientInternal(ctx, clientID, tx)
|
||||
client, err := s.getClientInternal(ctx, clientID, tx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1976,7 +1966,25 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||
}
|
||||
|
||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string, light bool) error {
|
||||
var errLogoTooLarge = errors.New("logo is too large")
|
||||
|
||||
func httpClientWithCheckRedirect(source *http.Client, checkRedirect func(req *http.Request, via []*http.Request) error) *http.Client {
|
||||
if source == nil {
|
||||
source = http.DefaultClient
|
||||
}
|
||||
|
||||
// Create a new client that clones the transport
|
||||
client := &http.Client{
|
||||
Transport: source.Transport,
|
||||
}
|
||||
|
||||
// Assign the CheckRedirect function
|
||||
client.CheckRedirect = checkRedirect
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, clientID string, raw string, light bool) error {
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1985,18 +1993,29 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r := net.Resolver{}
|
||||
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||
if err != nil || len(ips) == 0 {
|
||||
return fmt.Errorf("cannot resolve hostname")
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
ok, err := utils.IsURLPrivate(ctx, u)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return errors.New("private IP addresses are not allowed")
|
||||
}
|
||||
|
||||
// Prevents SSRF by allowing only public IPs
|
||||
for _, addr := range ips {
|
||||
if utils.IsPrivateIP(addr.IP) {
|
||||
return fmt.Errorf("private IP addresses are not allowed")
|
||||
// We need to check this on redirects too
|
||||
client := httpClientWithCheckRedirect(s.httpClient, func(r *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := utils.IsURLPrivate(r.Context(), r.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return errors.New("private IP addresses are not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
||||
if err != nil {
|
||||
@@ -2005,7 +2024,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
||||
req.Header.Set("Accept", "image/*")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2017,7 +2036,7 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
|
||||
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
||||
if resp.ContentLength > maxLogoSize {
|
||||
return fmt.Errorf("logo is too large")
|
||||
return errLogoTooLarge
|
||||
}
|
||||
|
||||
// Prefer extension in path if supported
|
||||
@@ -2037,48 +2056,70 @@ func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *
|
||||
}
|
||||
|
||||
imagePath := path.Join("oidc-client-images", clientID+darkSuffix+"."+ext)
|
||||
if err := s.fileStorage.Save(ctx, imagePath, io.LimitReader(resp.Body, maxLogoSize+1)); err != nil {
|
||||
err = s.fileStorage.Save(ctx, imagePath, utils.NewLimitReader(resp.Body, maxLogoSize+1))
|
||||
if errors.Is(err, utils.ErrSizeExceeded) {
|
||||
return errLogoTooLarge
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.updateClientLogoType(ctx, tx, clientID, ext, light); err != nil {
|
||||
err = s.updateClientLogoType(ctx, clientID, ext, light)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string, light bool) error {
|
||||
func (s *OidcService) updateClientLogoType(ctx context.Context, clientID string, ext string, light bool) error {
|
||||
var darkSuffix string
|
||||
if !light {
|
||||
darkSuffix = "-dark"
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// We need to acquire an update lock for the row to be locked, since we'll update it later
|
||||
var client model.OidcClient
|
||||
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
First(&client, "id = ?", clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to look up client: %w", err)
|
||||
}
|
||||
|
||||
var currentType *string
|
||||
if light {
|
||||
currentType = client.ImageType
|
||||
client.ImageType = &ext
|
||||
} else {
|
||||
currentType = client.DarkImageType
|
||||
client.DarkImageType = &ext
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Save(&client).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save updated client: %w", err)
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// Storage operations must be executed outside of a transaction
|
||||
if currentType != nil && *currentType != ext {
|
||||
old := path.Join("oidc-client-images", client.ID+darkSuffix+"."+*currentType)
|
||||
_ = s.fileStorage.Delete(ctx, old)
|
||||
}
|
||||
|
||||
var column string
|
||||
if light {
|
||||
column = "image_type"
|
||||
} else {
|
||||
column = "dark_image_type"
|
||||
}
|
||||
|
||||
return tx.WithContext(ctx).
|
||||
Model(&model.OidcClient{}).
|
||||
Where("id = ?", clientID).
|
||||
Update(column, ext).
|
||||
Error
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user