package service

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"mime/multipart"
	"net/textproto"
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
	"go.uber.org/mock/gomock"

	"tss-rocks-be/ent"
	"tss-rocks-be/internal/storage"
	"tss-rocks-be/internal/storage/mock"
	"tss-rocks-be/internal/testutil"
)

// MediaServiceTestSuite exercises MediaService against a test ent client and a
// gomock-backed storage implementation.
type MediaServiceTestSuite struct {
	suite.Suite
	ctx     context.Context
	client  *ent.Client
	storage *mock.MockStorage
	ctrl    *gomock.Controller
	svc     MediaService
}

// SetupTest creates a fresh client, mock storage and service before every test
// and empties the media table so tests cannot observe each other's rows.
func (s *MediaServiceTestSuite) SetupTest() {
	s.ctx = context.Background()
	s.client = testutil.NewTestClient()
	require.NotNil(s.T(), s.client)

	s.ctrl = gomock.NewController(s.T())
	s.storage = mock.NewMockStorage(s.ctrl)
	s.svc = NewMediaService(s.client, s.storage)

	// Clean the database.
	_, err := s.client.Media.Delete().Exec(s.ctx)
	require.NoError(s.T(), err)
}

// TearDownTest verifies outstanding mock expectations and closes the client.
func (s *MediaServiceTestSuite) TearDownTest() {
	s.ctrl.Finish()
	s.NoError(s.client.Close())
}

// TestMediaServiceSuite is the `go test` entry point for MediaServiceTestSuite.
func TestMediaServiceSuite(t *testing.T) {
	suite.Run(t, new(MediaServiceTestSuite))
}

// mockFileHeader is a hand-rolled stand-in for an uploaded file header.
// NOTE(review): it is not referenced in this file; presumably other tests in
// the package use it (newMockMultipartFile is also defined elsewhere) — confirm
// before removing.
type mockFileHeader struct {
	filename    string
	contentType string
	size        int64
	content     []byte
}

// Open returns a fresh reader over the header's content.
func (h *mockFileHeader) Open() (multipart.File, error) {
	return newMockMultipartFile(h.content), nil
}

// Filename returns the configured file name.
func (h *mockFileHeader) Filename() string {
	return h.filename
}

// Size returns the configured size.
func (h *mockFileHeader) Size() int64 {
	return h.size
}

// Header returns a MIME header carrying only the configured Content-Type.
func (h *mockFileHeader) Header() textproto.MIMEHeader {
	header := make(textproto.MIMEHeader)
	header.Set("Content-Type", h.contentType)
	return header
}

// createTestFile returns a real *multipart.FileHeader whose Open yields
// content, whose Content-Type header is contentType and whose Filename is
// exactly filename.
//
// The header is produced by writing a one-part multipart body and parsing it
// back with multipart.Reader.ReadForm, so no runtime patching of
// (*multipart.FileHeader).Open is needed. Any temp files ReadForm creates are
// removed when the current (sub)test finishes.
func (s *MediaServiceTestSuite) createTestFile(filename, contentType string, content []byte) *multipart.FileHeader {
	const formField = "file"
	const maxMemory = 1 << 20 // larger bodies spill to a temp file; cleaned up below

	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	partHeader := make(textproto.MIMEHeader)
	// Use a fixed placeholder filename so the Content-Disposition is always
	// well formed (a non-empty filename is what makes ReadForm treat the part
	// as a file). The real name is restored after parsing.
	partHeader.Set("Content-Disposition", `form-data; name="`+formField+`"; filename="upload"`)
	partHeader.Set("Content-Type", contentType)

	part, err := w.CreatePart(partHeader)
	s.Require().NoError(err)
	_, err = part.Write(content)
	s.Require().NoError(err)
	s.Require().NoError(w.Close())

	form, err := multipart.NewReader(&body, w.Boundary()).ReadForm(maxMemory)
	s.Require().NoError(err)
	s.T().Cleanup(func() {
		// Best-effort removal of any temp files; nothing to do on failure.
		_ = form.RemoveAll()
	})

	files := form.File[formField]
	s.Require().Len(files, 1)

	header := files[0]
	// Part.FileName applies filepath.Base, which would turn "../test.txt"
	// into "test.txt". Restore the exact name so path-traversal cases reach
	// the service unchanged.
	header.Filename = filename
	return header
}

// TestUpload covers a successful upload, rejection of a path-traversal
// filename before storage is touched, and propagation of a storage error.
func (s *MediaServiceTestSuite) TestUpload() {
	testCases := []struct {
		name        string
		filename    string
		contentType string
		content     []byte
		setupMock   func()
		wantErr     bool
		errMsg      string
	}{
		{
			name:        "Upload text file",
			filename:    "test.txt",
			contentType: "text/plain",
			content:     []byte("test content"),
			setupMock: func() {
				s.storage.EXPECT().
					Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
					DoAndReturn(func(_ context.Context, name, contentType string, reader io.Reader) (*storage.FileInfo, error) {
						got, err := io.ReadAll(reader)
						s.Require().NoError(err)
						s.Equal([]byte("test content"), got)
						return &storage.FileInfo{
							ID:          "test-id",
							Name:        "test.txt",
							ContentType: "text/plain",
							Size:        int64(len(got)),
						}, nil
					})
			},
			wantErr: false,
		},
		{
			name:        "Invalid filename",
			filename:    "../test.txt",
			contentType: "text/plain",
			content:     []byte("test content"),
			// No expectation: any storage call would fail the test.
			setupMock: func() {},
			wantErr:   true,
			errMsg:    "invalid filename",
		},
		{
			name:        "Storage error",
			filename:    "test.txt",
			contentType: "text/plain",
			content:     []byte("test content"),
			setupMock: func() {
				s.storage.EXPECT().
					Save(gomock.Any(), "test.txt", "text/plain", gomock.Any()).
					Return(nil, fmt.Errorf("storage error"))
			},
			wantErr: true,
			errMsg:  "storage error",
		},
	}

	for _, tc := range testCases {
		s.Run(tc.name, func() {
			tc.setupMock()

			fileHeader := s.createTestFile(tc.filename, tc.contentType, tc.content)

			media, err := s.svc.Upload(s.ctx, fileHeader, 1)
			if tc.wantErr {
				s.Require().Error(err)
				s.Contains(err.Error(), tc.errMsg)
				return
			}

			s.Require().NoError(err)
			// Require, not Assert: the checks below dereference media.
			s.Require().NotNil(media)
			s.Equal(tc.filename, media.OriginalName)
			s.Equal(tc.contentType, media.MimeType)
			s.Equal(int64(len(tc.content)), media.Size)
			s.Equal("1", media.CreatedBy)
		})
	}
}

// TestGet checks lookup of an existing record and the error for a missing ID.
func (s *MediaServiceTestSuite) TestGet() {
	media, err := s.client.Media.Create().
		SetStorageID("test-id").
		SetOriginalName("test.txt").
		SetMimeType("text/plain").
		SetSize(12).
		SetURL("/api/media/test-id").
		SetCreatedBy("1").
		Save(s.ctx)
	s.Require().NoError(err)

	// Existing media.
	result, err := s.svc.Get(s.ctx, media.ID)
	s.Require().NoError(err)
	s.Require().NotNil(result)
	s.Equal(media.ID, result.ID)
	s.Equal(media.OriginalName, result.OriginalName)

	// Non-existing media.
	_, err = s.svc.Get(s.ctx, -1)
	s.Require().Error(err)
	s.Contains(err.Error(), "media not found")
}

// TestDelete checks that only the creator may delete a record, that deletion
// removes the stored object, and that the record is gone afterwards.
func (s *MediaServiceTestSuite) TestDelete() {
	media, err := s.client.Media.Create().
		SetStorageID("test-id").
		SetOriginalName("test.txt").
		SetMimeType("text/plain").
		SetSize(12).
		SetURL("/api/media/test-id").
		SetCreatedBy("1").
		Save(s.ctx)
	s.Require().NoError(err)

	// A different user is rejected. No storage expectation is registered yet,
	// so gomock also fails the test if storage is touched here.
	err = s.svc.Delete(s.ctx, media.ID, 2)
	s.Require().Error(err)
	s.Contains(err.Error(), "unauthorized")

	// The owner may delete.
	s.storage.EXPECT().
		Delete(gomock.Any(), "test-id").
		Return(nil)
	err = s.svc.Delete(s.ctx, media.ID, 1)
	s.Require().NoError(err)

	// The record no longer exists.
	_, err = s.svc.Get(s.ctx, media.ID)
	s.Require().Error(err)
	s.Contains(err.Error(), "not found")
}

// TestList inserts five records and checks that limit 3 / offset 1 yields 3.
func (s *MediaServiceTestSuite) TestList() {
	for i := 0; i < 5; i++ {
		_, err := s.client.Media.Create().
			SetStorageID(fmt.Sprintf("test-id-%d", i)).
			SetOriginalName(fmt.Sprintf("test-%d.txt", i)).
			SetMimeType("text/plain").
			SetSize(12).
			SetURL(fmt.Sprintf("/api/media/test-id-%d", i)).
			SetCreatedBy("1").
			Save(s.ctx)
		s.Require().NoError(err)
	}

	media, err := s.svc.List(s.ctx, 3, 1)
	s.Require().NoError(err)
	s.Len(media, 3)
}

// TestGetFile checks that GetFile returns the storage reader and file info for
// an existing record, and an error for a missing ID.
func (s *MediaServiceTestSuite) TestGetFile() {
	media, err := s.client.Media.Create().
		SetStorageID("test-id").
		SetOriginalName("test.txt").
		SetMimeType("text/plain").
		SetSize(12).
		SetURL("/api/media/test-id").
		SetCreatedBy("1").
		Save(s.ctx)
	s.Require().NoError(err)

	mockReader := io.NopCloser(bytes.NewReader([]byte("test content")))
	mockFileInfo := &storage.FileInfo{
		ID:          "test-id",
		Name:        "test.txt",
		ContentType: "text/plain",
		Size:        12,
	}
	s.storage.EXPECT().
		Get(gomock.Any(), "test-id").
		Return(mockReader, mockFileInfo, nil)

	// Existing file.
	reader, info, err := s.svc.GetFile(s.ctx, media.ID)
	s.Require().NoError(err)
	s.Require().NotNil(reader)
	s.Equal(mockFileInfo, info)

	data, err := io.ReadAll(reader)
	s.Require().NoError(err)
	s.Equal("test content", string(data))

	// Non-existing file.
	_, _, err = s.svc.GetFile(s.ctx, -1)
	s.Require().Error(err)
	s.Contains(err.Error(), "not found")
}

// TestIsValidFilename checks acceptance of a plain name and rejection of
// relative-path prefixes and backslashes.
func (s *MediaServiceTestSuite) TestIsValidFilename() {
	testCases := []struct {
		name     string
		filename string
		want     bool
	}{
		{
			name:     "Valid filename",
			filename: "test.txt",
			want:     true,
		},
		{
			name:     "Invalid filename with ../",
			filename: "../test.txt",
			want:     false,
		},
		{
			name:     "Invalid filename with ./",
			filename: "./test.txt",
			want:     false,
		},
		{
			name:     "Invalid filename with backslash",
			filename: "test\\file.txt",
			want:     false,
		},
	}

	for _, tc := range testCases {
		s.Run(tc.name, func() {
			s.Equal(tc.want, isValidFilename(tc.filename))
		})
	}
}