diff --git a/backend/internal/service/geolite_service.go b/backend/internal/service/geolite_service.go index 0d5baa58..34abd23f 100644 --- a/backend/internal/service/geolite_service.go +++ b/backend/internal/service/geolite_service.go @@ -2,6 +2,7 @@ package service import ( "archive/tar" + "bytes" "compress/gzip" "context" "errors" @@ -21,6 +22,7 @@ import ( "github.com/pocket-id/pocket-id/backend/internal/common" ) +const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size type GeoLiteService struct { httpClient *http.Client @@ -151,7 +153,22 @@ func (s *GeoLiteService) isDatabaseUpToDate() bool { // extractDatabase extracts the database file from the tar.gz archive directly to the target location. func (s *GeoLiteService) extractDatabase(reader io.Reader) error { - gzr, err := gzip.NewReader(reader) + // Check for gzip magic number + buf := make([]byte, 2) + _, err := io.ReadFull(reader, buf) + if err != nil { + return fmt.Errorf("failed to read magic number: %w", err) + } + + // Check if the file starts with the gzip magic number + isGzip := buf[0] == 0x1f && buf[1] == 0x8b + + if !isGzip { + // If not gzip, assume it's a regular database file + return s.writeDatabaseFile(io.MultiReader(bytes.NewReader(buf), reader)) + } + + gzr, err := gzip.NewReader(io.MultiReader(bytes.NewReader(buf), reader)) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) } @@ -160,7 +177,6 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error { tarReader := tar.NewReader(gzr) var totalSize int64 - const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size // Iterate over the files in the tar archive for { @@ -222,3 +238,47 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error { return errors.New("GeoLite2-City.mmdb not found in archive") } + +func (s *GeoLiteService) writeDatabaseFile(reader io.Reader) error { + baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath) + tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp") + if err != nil { + return fmt.Errorf("failed to create temporary database file: %w", err) + } + defer tmpFile.Close() + + // Limit the amount we read to maxTotalSize. + // We read one extra byte to detect if the source is larger than the limit. + limitReader := io.LimitReader(reader, maxTotalSize+1) + + // Write the file contents directly to the temporary file + written, err := io.Copy(tmpFile, limitReader) + if err != nil { + os.Remove(tmpFile.Name()) + return fmt.Errorf("failed to write database file: %w", err) + } + + if written > maxTotalSize { + os.Remove(tmpFile.Name()) + return errors.New("total database size exceeds maximum allowed limit") + } + + // Validate the downloaded database file + if db, err := maxminddb.Open(tmpFile.Name()); err == nil { + db.Close() + } else { + os.Remove(tmpFile.Name()) + return fmt.Errorf("failed to open downloaded database file: %w", err) + } + + // Ensure atomic replacement of the old database file + s.mutex.Lock() + err = os.Rename(tmpFile.Name(), common.EnvConfig.GeoLiteDBPath) + s.mutex.Unlock() + + if err != nil { + os.Remove(tmpFile.Name()) + return fmt.Errorf("failed to replace database file: %w", err) + } + return nil +}