1
0
mirror of https://github.com/pocket-id/pocket-id.git synced 2026-02-15 14:35:06 +00:00

refactor: fix code smells

This commit is contained in:
Elias Schneider
2025-03-27 17:46:10 +01:00
parent c9e0073b63
commit 5c198c280c
15 changed files with 74 additions and 48 deletions

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
@@ -16,10 +17,9 @@ import (
"mime/quotedprintable"
"net/textproto"
"os"
"strings"
ttemplate "text/template"
"time"
"github.com/google/uuid"
"strings"
)
type EmailService struct {
@@ -107,7 +107,7 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
domain = hostname
}
}
c.AddHeader("Message-ID", "<" + uuid.New().String() + "@" + domain + ">")
c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
c.Body(body)
@@ -131,7 +131,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true",
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
}

View File

@@ -3,6 +3,7 @@ package service
import (
"archive/tar"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
@@ -124,8 +125,15 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
// Download the database tar.gz file nolint:gosec
resp, err := http.Get(downloadUrl)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download database: %w", err)
}
@@ -164,6 +172,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tarReader := tar.NewReader(gzr)
var totalSize int64
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
// Iterate over the files in the tar archive
for {
header, err := tarReader.Next()
@@ -176,6 +187,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
// Check if the file is the GeoLite2-City.mmdb file
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
totalSize += header.Size
if totalSize > maxTotalSize {
return errors.New("total decompressed size exceeds maximum allowed limit")
}
// extract to a temporary file to avoid having a corrupted db in case of write failure.
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
@@ -185,7 +201,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tempName := tmpFile.Name()
// Write the file contents directly to the target location
if _, err := io.Copy(tmpFile, tarReader); err != nil {
if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
// if fails to write, then cleanup and throw an error
tmpFile.Close()
os.Remove(tempName)

View File

@@ -38,7 +38,7 @@ func TestJwtService_Init(t *testing.T) {
// Verify the key has been saved to disk as JWK
jwkPath := filepath.Join(tempDir, PrivateKeyFile)
_, err = os.Stat(jwkPath)
assert.NoError(t, err, "JWK file should exist")
require.NoError(t, err, "JWK file should exist")
// Verify the generated key is valid
keyData, err := os.ReadFile(jwkPath)
@@ -229,7 +229,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Check the claims
assert.Equal(t, user.ID, claims.Subject, "Token subject should match user ID")
assert.Equal(t, false, claims.IsAdmin, "IsAdmin should be false")
assert.False(t, claims.IsAdmin, "IsAdmin should be false")
assert.Contains(t, claims.Audience, "https://test.example.com", "Audience should contain the app URL")
// Check token expiration time is approximately 60 minutes from now
@@ -263,7 +263,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
require.NoError(t, err, "Failed to verify generated token")
// Check the IsAdmin claim is true
assert.Equal(t, true, claims.IsAdmin, "IsAdmin should be true for admin users")
assert.True(t, claims.IsAdmin, "IsAdmin should be true for admin users")
assert.Equal(t, adminUser.ID, claims.Subject, "Token subject should match admin ID")
})
@@ -404,7 +404,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Verify should fail due to issuer mismatch
_, err = service.VerifyIdToken(tokenString)
assert.Error(t, err, "Verification should fail with incorrect issuer")
require.Error(t, err, "Verification should fail with incorrect issuer")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}
@@ -492,7 +492,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed))
assert.Error(t, err, "Verification should fail with expired token")
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
@@ -520,7 +520,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString)
assert.Error(t, err, "Verification should fail with invalid signature")
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "couldn't handle this token", "Error message should indicate token verification failure")
})
}

View File

@@ -2,6 +2,7 @@ package service
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
@@ -11,6 +12,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -36,7 +38,7 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify}))
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -65,6 +67,7 @@ func (s *LdapService) SyncAll() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -150,6 +153,9 @@ func (s *LdapService) SyncGroups() error {
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
@@ -180,6 +186,7 @@ func (s *LdapService) SyncGroups() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -296,8 +303,15 @@ func (s *LdapService) SaveProfilePicture(userId string, pictureString string) er
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
// If the photo is a URL, download it
response, err := http.Get(pictureString)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}

View File

@@ -209,6 +209,9 @@ func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSec
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
if err != nil {
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData)