diff --git a/backend/internal/dto/audit_log_dto.go b/backend/internal/dto/audit_log_dto.go
index 7e25df48..dceeea0c 100644
--- a/backend/internal/dto/audit_log_dto.go
+++ b/backend/internal/dto/audit_log_dto.go
@@ -1,7 +1,6 @@
 package dto
 
 import (
-	"github.com/pocket-id/pocket-id/backend/internal/model"
 	datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
 )
 
@@ -9,14 +8,14 @@ type AuditLogDto struct {
 	ID        string            `json:"id"`
 	CreatedAt datatype.DateTime `json:"createdAt"`
 
-	Event     model.AuditLogEvent `json:"event"`
-	IpAddress string              `json:"ipAddress"`
-	Country   string              `json:"country"`
-	City      string              `json:"city"`
-	Device    string              `json:"device"`
-	UserID    string              `json:"userID"`
-	Username  string              `json:"username"`
-	Data      model.AuditLogData  `json:"data"`
+	Event     string            `json:"event"`
+	IpAddress string            `json:"ipAddress"`
+	Country   string            `json:"country"`
+	City      string            `json:"city"`
+	Device    string            `json:"device"`
+	UserID    string            `json:"userID"`
+	Username  string            `json:"username"`
+	Data      map[string]string `json:"data"`
 }
 
 type AuditLogFilterDto struct {
diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go
index 5cba3c73..1313c9b9 100644
--- a/backend/internal/dto/validations.go
+++ b/backend/internal/dto/validations.go
@@ -8,13 +8,13 @@ import (
 	"github.com/go-playground/validator/v10"
 )
 
+// [a-zA-Z0-9] : The username must start with an alphanumeric character
+// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
+// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
+var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
+
 var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
-	// [a-zA-Z0-9] : The username must start with an alphanumeric character
-	// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
-	// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
-	regex := "^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$"
-	matched, _ := regexp.MatchString(regex, fl.Field().String())
-	return matched
+	return validateUsernameRegex.MatchString(fl.Field().String())
 }
 
 func init() {
diff --git a/backend/internal/model/audit_log.go b/backend/internal/model/audit_log.go
index 154a8dfe..0c19943a 100644
--- a/backend/internal/model/audit_log.go
+++ b/backend/internal/model/audit_log.go
@@ -10,7 +10,7 @@ type AuditLog struct {
 	Base
 
 	Event     AuditLogEvent `sortable:"true"`
-	IpAddress string        `sortable:"true"`
+	IpAddress *string       `sortable:"true"`
 	Country   string        `sortable:"true"`
 	City      string        `sortable:"true"`
 	UserAgent string        `sortable:"true"`
diff --git a/backend/internal/service/audit_log_service.go b/backend/internal/service/audit_log_service.go
index a6508677..39ca7ebf 100644
--- a/backend/internal/service/audit_log_service.go
+++ b/backend/internal/service/audit_log_service.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"log"
+	"log/slog"
 
 	userAgentParser "github.com/mileusna/useragent"
 	"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -25,15 +26,15 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
 }
 
 // Create creates a new audit log entry in the database
-func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
+func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) (model.AuditLog, bool) {
 	country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
 	if err != nil {
-		log.Printf("Failed to get IP location: %v", err)
+		// Log the error but don't interrupt the operation
+		slog.Warn("Failed to get IP location", "error", err)
 	}
 
 	auditLog := model.AuditLog{
 		Event:     event,
-		IpAddress: ipAddress,
 		Country:   country,
 		City:      city,
 		UserAgent: userAgent,
@@ -41,22 +42,31 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
 		Data:      data,
 	}
 
+	if ipAddress != "" {
+		// Only set ipAddress if not empty, because on Postgres we use INET columns that don't allow non-null empty values
+		auditLog.IpAddress = &ipAddress
+	}
+
 	// Save the audit log in the database
 	err = tx.
 		WithContext(ctx).
 		Create(&auditLog).
 		Error
 	if err != nil {
-		log.Printf("Failed to create audit log: %v", err)
-		return model.AuditLog{}
+		slog.Error("Failed to create audit log", "error", err)
+		return model.AuditLog{}, false
 	}
 
-	return auditLog
+	return auditLog, true
 }
 
 // CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
 func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
-	createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
+	createdAuditLog, ok := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
+	if !ok {
+		// At this point the transaction has been canceled already, and error has been logged
+		return createdAuditLog
+	}
 
 	// Count the number of times the user has logged in from the same device
 	var count int64
@@ -67,7 +77,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
 		Count(&count).
 		Error
 	if err != nil {
-		log.Printf("Failed to count audit logs: %v\n", err)
+		log.Printf("Failed to count audit logs: %v", err)
 		return createdAuditLog
 	}
diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go
index 97e08a26..c260040b 100644
--- a/backend/internal/service/geolite_service.go
+++ b/backend/internal/service/geolite_service.go
@@ -122,6 +122,10 @@ func (s *GeoLiteService) DisableUpdater() bool {
 
 // GetLocationByIP returns the country and city of the given IP address.
 func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) {
+	if ipAddress == "" {
+		return "", "", nil
+	}
+
 	// Check the IP address against known private IP ranges
 	if ip := net.ParseIP(ipAddress); ip != nil {
 		// Check IPv6 local ranges first
@@ -147,6 +151,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
 		}
 	}
 
+	addr, err := netip.ParseAddr(ipAddress)
+	if err != nil {
+		return "", "", fmt.Errorf("failed to parse IP address: %w", err)
+	}
+
 	// Race condition between reading and writing the database.
 	s.mutex.RLock()
 	defer s.mutex.RUnlock()
@@ -157,11 +166,6 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
 	}
 	defer db.Close()
 
-	addr, err := netip.ParseAddr(ipAddress)
-	if err != nil {
-		return "", "", fmt.Errorf("failed to parse IP address: %w", err)
-	}
-
 	var record struct {
 		City struct {
 			Names map[string]string `maxminddb:"names"`
diff --git a/backend/internal/service/geolite_service_test.go b/backend/internal/service/geolite_service_test.go
index fad5b00b..638c7721 100644
--- a/backend/internal/service/geolite_service_test.go
+++ b/backend/internal/service/geolite_service_test.go
@@ -6,6 +6,8 @@ import (
 	"testing"
 
 	"github.com/pocket-id/pocket-id/backend/internal/common"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
@@ -80,15 +82,9 @@ func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
 					t.Errorf("Expected error or internal network classification for external IP")
 				}
 			} else {
-				if err != nil {
-					t.Errorf("Expected no error for local IP, got: %v", err)
-				}
-				if country != tt.expectedCountry {
-					t.Errorf("Expected country %s, got %s", tt.expectedCountry, country)
-				}
-				if city != tt.expectedCity {
-					t.Errorf("Expected city %s, got %s", tt.expectedCity, city)
-				}
+				require.NoError(t, err)
+				assert.Equal(t, tt.expectedCountry, country)
+				assert.Equal(t, tt.expectedCity, city)
 			}
 		})
 	}
@@ -148,9 +144,7 @@ func TestGeoLiteService_isLocalIPv6(t *testing.T) {
 			}
 
 			result := service.isLocalIPv6(ip)
-			if result != tt.expected {
-				t.Errorf("Expected %v, got %v for IP %s", tt.expected, result, tt.testIP)
-			}
+			assert.Equal(t, tt.expected, result)
 		})
 	}
 }
@@ -214,18 +208,13 @@ func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
 
 			err := service.initializeIPv6LocalRanges()
 
-			if tt.expectError && err == nil {
-				t.Errorf("Expected error but got none")
-			}
-			if !tt.expectError && err != nil {
-				t.Errorf("Expected no error but got: %v", err)
+			if tt.expectError {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
 			}
 
-			rangeCount := len(service.localIPv6Ranges)
-
-			if rangeCount != tt.expectCount {
-				t.Errorf("Expected %d ranges, got %d", tt.expectCount, rangeCount)
-			}
+			assert.Len(t, service.localIPv6Ranges, tt.expectCount)
 		})
 	}
 }
diff --git a/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.down.sql b/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.down.sql
new file mode 100644
index 00000000..6f6f3280
--- /dev/null
+++ b/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.down.sql
@@ -0,0 +1,4 @@
+ALTER TABLE audit_logs ALTER COLUMN ip_address SET NOT NULL;
+
+DROP INDEX IF EXISTS idx_audit_logs_created_at;
+DROP INDEX IF EXISTS idx_audit_logs_user_agent;
diff --git a/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.up.sql b/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.up.sql
new file mode 100644
index 00000000..904400e0
--- /dev/null
+++ b/backend/resources/migrations/postgres/20250628000000_audit_logs_ip_null.up.sql
@@ -0,0 +1,5 @@
+ALTER TABLE audit_logs ALTER COLUMN ip_address DROP NOT NULL;
+
+-- Add missing indexes
+CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
+CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
diff --git a/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.down.sql b/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.down.sql
new file mode 100644
index 00000000..53803e77
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.down.sql
@@ -0,0 +1,30 @@
+-- Re-create the table with non-nullable ip_address
+-- We then move the data and rename the table
+CREATE TABLE audit_logs_new
+(
+    id TEXT NOT NULL PRIMARY KEY,
+    created_at DATETIME,
+    event TEXT NOT NULL,
+    ip_address TEXT NOT NULL,
+    user_agent TEXT NOT NULL,
+    data BLOB NOT NULL,
+    user_id TEXT REFERENCES users,
+    country TEXT,
+    city TEXT
+);
+
+INSERT INTO audit_logs_new
+SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
+FROM audit_logs;
+
+DROP TABLE audit_logs;
+
+ALTER TABLE audit_logs_new RENAME TO audit_logs;
+
+-- Re-create indexes
+CREATE INDEX idx_audit_logs_event ON audit_logs(event);
+CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
+CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
+CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
+CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
+CREATE INDEX idx_audit_logs_country ON audit_logs(country);
diff --git a/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.up.sql b/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.up.sql
new file mode 100644
index 00000000..edbeca9a
--- /dev/null
+++ b/backend/resources/migrations/sqlite/20250628000000_audit_logs_ip_null.up.sql
@@ -0,0 +1,30 @@
+-- Re-create the table with nullable ip_address
+-- We then move the data and rename the table
+CREATE TABLE audit_logs_new
+(
+    id TEXT NOT NULL PRIMARY KEY,
+    created_at DATETIME,
+    event TEXT NOT NULL,
+    ip_address TEXT,
+    user_agent TEXT NOT NULL,
+    data BLOB NOT NULL,
+    user_id TEXT REFERENCES users,
+    country TEXT,
+    city TEXT
+);
+
+INSERT INTO audit_logs_new
+SELECT id, created_at, event, ip_address, user_agent, data, user_id, country, city
+FROM audit_logs;
+
+DROP TABLE audit_logs;
+
+ALTER TABLE audit_logs_new RENAME TO audit_logs;
+
+-- Re-create indexes
+CREATE INDEX idx_audit_logs_event ON audit_logs(event);
+CREATE INDEX idx_audit_logs_created_at ON audit_logs(created_at);
+CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
+CREATE INDEX idx_audit_logs_user_agent ON audit_logs(user_agent);
+CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));
+CREATE INDEX idx_audit_logs_country ON audit_logs(country);