package service import ( "bytes" "context" "fmt" "io" "mime/multipart" "strings" "path/filepath" "tss-rocks-be/ent" "tss-rocks-be/internal/storage" "tss-rocks-be/pkg/imageutil" ) type MediaService interface { // Upload uploads a new file and creates a media record Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) // Get retrieves a media file by ID Get(ctx context.Context, id int) (*ent.Media, error) // Delete deletes a media file Delete(ctx context.Context, id int, userID int) error // List lists media files with pagination List(ctx context.Context, limit, offset int) ([]*ent.Media, error) // GetFile gets the file content and info GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) } type mediaService struct { client *ent.Client storage storage.Storage } func NewMediaService(client *ent.Client, storage storage.Storage) MediaService { return &mediaService{ client: client, storage: storage, } } // isValidFilename checks if a filename is valid func isValidFilename(filename string) bool { // Check for illegal characters if strings.Contains(filename, "../") || strings.Contains(filename, "./") || strings.Contains(filename, "\\") { return false } return true } func (s *mediaService) Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error) { // Validate filename if !isValidFilename(file.Filename) { return nil, fmt.Errorf("invalid filename: %s", file.Filename) } // Open the file src, err := file.Open() if err != nil { return nil, fmt.Errorf("failed to open file: %w", err) } defer src.Close() // Read file content for processing fileBytes, err := io.ReadAll(src) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } contentType := file.Header.Get("Content-Type") filename := file.Filename var processedBytes []byte // Process image if it's an image file if imageutil.IsImageFormat(contentType) { opts := imageutil.DefaultOptions() processedBytes, err = imageutil.ProcessImage(bytes.NewReader(fileBytes), opts) if err != nil { return nil, fmt.Errorf("failed to process image: %w", err) } // Update content type and filename for WebP contentType = "image/webp" filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + ".webp" } else { processedBytes = fileBytes } // Save the processed file fileInfo, err := s.storage.Save(ctx, filename, contentType, bytes.NewReader(processedBytes)) if err != nil { return nil, fmt.Errorf("failed to save file: %w", err) } // Create media record in database media, err := s.client.Media.Create(). SetStorageID(fileInfo.ID). SetOriginalName(filename). SetMimeType(contentType). SetSize(int64(len(processedBytes))). SetURL(fmt.Sprintf("/api/media/%s", fileInfo.ID)). SetCreatedBy(fmt.Sprint(userID)). Save(ctx) if err != nil { // Try to cleanup the stored file if database operation fails _ = s.storage.Delete(ctx, fileInfo.ID) return nil, fmt.Errorf("failed to create media record: %w", err) } return media, nil } func (s *mediaService) Get(ctx context.Context, id int) (*ent.Media, error) { media, err := s.client.Media.Get(ctx, id) if err != nil { if ent.IsNotFound(err) { return nil, fmt.Errorf("media not found: %d", id) } return nil, fmt.Errorf("failed to get media: %w", err) } return media, nil } func (s *mediaService) Delete(ctx context.Context, id int, userID int) error { media, err := s.Get(ctx, id) if err != nil { return err } // Check ownership if media.CreatedBy != fmt.Sprintf("%d", userID) { return fmt.Errorf("unauthorized to delete media") } // Delete from storage if err := s.storage.Delete(ctx, media.StorageID); err != nil { return fmt.Errorf("failed to delete file from storage: %w", err) } // Delete from database if err := s.client.Media.DeleteOne(media).Exec(ctx); err != nil { return fmt.Errorf("failed to delete media record: %w", err) } return nil } func (s *mediaService) List(ctx context.Context, limit, offset int) ([]*ent.Media, error) { media, err := s.client.Media.Query(). Order(ent.Desc("created_at")). Limit(limit). Offset(offset). All(ctx) if err != nil { return nil, fmt.Errorf("failed to list media: %w", err) } return media, nil } func (s *mediaService) GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error) { media, err := s.Get(ctx, id) if err != nil { return nil, nil, err } reader, info, err := s.storage.Get(ctx, media.StorageID) if err != nil { return nil, nil, fmt.Errorf("failed to get file from storage: %w", err) } return reader, info, nil }