package handler import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "mime/multipart" "net/http" "net/http/httptest" "strings" "testing" "tss-rocks-be/ent" "tss-rocks-be/internal/config" "tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/storage" "net/textproto" "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "go.uber.org/mock/gomock" ) type MediaHandlerTestSuite struct { suite.Suite ctrl *gomock.Controller service *mock.MockService handler *Handler router *gin.Engine } func (s *MediaHandlerTestSuite) SetupTest() { s.ctrl = gomock.NewController(s.T()) s.service = mock.NewMockService(s.ctrl) s.handler = NewHandler(&config.Config{}, s.service) s.router = gin.New() } func (s *MediaHandlerTestSuite) TearDownTest() { s.ctrl.Finish() } func TestMediaHandlerSuite(t *testing.T) { suite.Run(t, new(MediaHandlerTestSuite)) } func (s *MediaHandlerTestSuite) TestListMedia() { testCases := []struct { name string query string setupMock func() expectedStatus int expectedError string }{ { name: "成功列出媒体", query: "?limit=10&offset=0", setupMock: func() { s.service.EXPECT(). ListMedia(gomock.Any(), 10, 0). Return([]*ent.Media{{ID: 1}}, nil) }, expectedStatus: http.StatusOK, }, { name: "使用默认限制和偏移", query: "", setupMock: func() { s.service.EXPECT(). ListMedia(gomock.Any(), 10, 0). Return([]*ent.Media{{ID: 1}}, nil) }, expectedStatus: http.StatusOK, }, { name: "列出媒体失败", query: "", setupMock: func() { s.service.EXPECT(). ListMedia(gomock.Any(), 10, 0). Return(nil, errors.New("failed to list media")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to list media", }, } for _, tc := range testCases { s.Run(tc.name, func() { // 设置 mock tc.setupMock() // 创建请求 req, _ := http.NewRequest(http.MethodGet, "/media"+tc.query, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // 执行请求 s.handler.ListMedia(c) // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.Equal(tc.expectedError, response["error"]) } else { var response []*ent.Media err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.NotEmpty(response) } }) } } func (s *MediaHandlerTestSuite) TestUploadMedia() { testCases := []struct { name string setupRequest func() (*http.Request, error) setupMock func() expectedStatus int expectedError string }{ { name: "成功上传媒体", setupRequest: func() (*http.Request, error) { body := &bytes.Buffer{} writer := multipart.NewWriter(body) // 创建文件部分 fileHeader := make(textproto.MIMEHeader) fileHeader.Set("Content-Type", "image/jpeg") fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`) part, err := writer.CreatePart(fileHeader) if err != nil { return nil, err } testContent := "test content" _, err = io.Copy(part, strings.NewReader(testContent)) if err != nil { return nil, err } writer.Close() req := httptest.NewRequest(http.MethodPost, "/media", body) req.Header.Set("Content-Type", writer.FormDataContentType()) return req, nil }, setupMock: func() { expectedFile := &multipart.FileHeader{ Filename: "test.jpg", Size: int64(len("test content")), Header: textproto.MIMEHeader{ "Content-Type": []string{"image/jpeg"}, }, } s.service.EXPECT(). Upload(gomock.Any(), gomock.Any(), 1). DoAndReturn(func(_ context.Context, f *multipart.FileHeader, uid int) (*ent.Media, error) { s.Equal(expectedFile.Filename, f.Filename) s.Equal(expectedFile.Size, f.Size) s.Equal(expectedFile.Header.Get("Content-Type"), f.Header.Get("Content-Type")) return &ent.Media{ID: 1}, nil }) }, expectedStatus: http.StatusCreated, }, { name: "未授权", setupRequest: func() (*http.Request, error) { req := httptest.NewRequest(http.MethodPost, "/media", nil) return req, nil }, setupMock: func() {}, expectedStatus: http.StatusUnauthorized, expectedError: "Unauthorized", }, { name: "上传失败", setupRequest: func() (*http.Request, error) { body := &bytes.Buffer{} writer := multipart.NewWriter(body) // 创建文件部分 fileHeader := make(textproto.MIMEHeader) fileHeader.Set("Content-Type", "image/jpeg") fileHeader.Set("Content-Disposition", `form-data; name="file"; filename="test.jpg"`) part, err := writer.CreatePart(fileHeader) if err != nil { return nil, err } testContent := "test content" _, err = io.Copy(part, strings.NewReader(testContent)) if err != nil { return nil, err } writer.Close() req := httptest.NewRequest(http.MethodPost, "/media", body) req.Header.Set("Content-Type", writer.FormDataContentType()) return req, nil }, setupMock: func() { s.service.EXPECT(). Upload(gomock.Any(), gomock.Any(), 1). Return(nil, errors.New("failed to upload")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to upload media", }, } for _, tc := range testCases { s.Run(tc.name, func() { // 设置 mock tc.setupMock() // 创建请求 req, err := tc.setupRequest() s.NoError(err) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // 设置用户ID(除了未授权的测试用例) if tc.expectedError != "Unauthorized" { c.Set("user_id", 1) } // 执行请求 s.handler.UploadMedia(c) // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.Equal(tc.expectedError, response["error"]) } else { var response *ent.Media err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.NotNil(response) } }) } } func (s *MediaHandlerTestSuite) TestGetMedia() { testCases := []struct { name string mediaID string setupMock func() expectedStatus int expectedError string }{ { name: "成功获取媒体", mediaID: "1", setupMock: func() { media := &ent.Media{ ID: 1, MimeType: "image/jpeg", OriginalName: "test.jpg", } s.service.EXPECT(). GetMedia(gomock.Any(), 1). Return(media, nil) s.service.EXPECT(). GetFile(gomock.Any(), 1). Return(io.NopCloser(strings.NewReader("test content")), &storage.FileInfo{ Size: 11, Name: "test.jpg", ContentType: "image/jpeg", }, nil) }, expectedStatus: http.StatusOK, }, { name: "无效的媒体ID", mediaID: "invalid", setupMock: func() {}, expectedStatus: http.StatusBadRequest, expectedError: "Invalid media ID", }, { name: "获取媒体元数据失败", mediaID: "1", setupMock: func() { s.service.EXPECT(). GetMedia(gomock.Any(), 1). Return(nil, errors.New("failed to get media")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to get media", }, { name: "获取媒体文件失败", mediaID: "1", setupMock: func() { media := &ent.Media{ ID: 1, MimeType: "image/jpeg", OriginalName: "test.jpg", } s.service.EXPECT(). GetMedia(gomock.Any(), 1). Return(media, nil) s.service.EXPECT(). GetFile(gomock.Any(), 1). Return(nil, nil, errors.New("failed to get file")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to get media file", }, } for _, tc := range testCases { s.Run(tc.name, func() { // 设置 mock tc.setupMock() // 创建请求 req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/media/%s", tc.mediaID), nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // Extract ID from URL path parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") if len(parts) >= 2 { c.Params = []gin.Param{{Key: "id", Value: parts[1]}} } // 执行请求 s.handler.GetMedia(c) // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.Equal(tc.expectedError, response["error"]) } else { s.Equal("image/jpeg", w.Header().Get("Content-Type")) s.Equal("11", w.Header().Get("Content-Length")) s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition")) s.Equal("test content", w.Body.String()) } }) } } func (s *MediaHandlerTestSuite) TestGetMediaFile() { testCases := []struct { name string setupRequest func() (*http.Request, error) setupMock func() expectedStatus int expectedBody []byte }{ { name: "成功获取媒体文件", setupRequest: func() (*http.Request, error) { return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil }, setupMock: func() { fileContent := "test file content" s.service.EXPECT(). GetFile(gomock.Any(), 1). Return(io.NopCloser(strings.NewReader(fileContent)), &storage.FileInfo{ Name: "test.jpg", Size: int64(len(fileContent)), ContentType: "image/jpeg", }, nil) }, expectedStatus: http.StatusOK, expectedBody: []byte("test file content"), }, { name: "无效的媒体ID", setupRequest: func() (*http.Request, error) { return httptest.NewRequest(http.MethodGet, "/media/invalid/file", nil), nil }, setupMock: func() {}, expectedStatus: http.StatusBadRequest, }, { name: "获取媒体文件失败", setupRequest: func() (*http.Request, error) { return httptest.NewRequest(http.MethodGet, "/media/1/file", nil), nil }, setupMock: func() { s.service.EXPECT(). GetFile(gomock.Any(), 1). Return(nil, nil, errors.New("failed to get file")) }, expectedStatus: http.StatusInternalServerError, }, } for _, tc := range testCases { s.Run(tc.name, func() { // Setup req, err := tc.setupRequest() s.Require().NoError(err) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // Extract ID from URL path parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") if len(parts) >= 2 { c.Params = []gin.Param{{Key: "id", Value: parts[1]}} } // Setup mock tc.setupMock() // Test s.handler.GetMediaFile(c) // Verify s.Equal(tc.expectedStatus, w.Code) if tc.expectedBody != nil { s.Equal(tc.expectedBody, w.Body.Bytes()) s.Equal("image/jpeg", w.Header().Get("Content-Type")) s.Equal(fmt.Sprintf("%d", len(tc.expectedBody)), w.Header().Get("Content-Length")) s.Equal("inline; filename=test.jpg", w.Header().Get("Content-Disposition")) } }) } } func (s *MediaHandlerTestSuite) TestDeleteMedia() { testCases := []struct { name string mediaID string setupMock func() expectedStatus int expectedError string }{ { name: "成功删除媒体", mediaID: "1", setupMock: func() { s.service.EXPECT(). DeleteMedia(gomock.Any(), 1, 1). Return(nil) }, expectedStatus: http.StatusNoContent, }, { name: "未授权", mediaID: "1", setupMock: func() {}, expectedStatus: http.StatusUnauthorized, expectedError: "Unauthorized", }, { name: "无效的媒体ID", mediaID: "invalid", setupMock: func() {}, expectedStatus: http.StatusBadRequest, expectedError: "Invalid media ID", }, { name: "删除媒体失败", mediaID: "1", setupMock: func() { s.service.EXPECT(). DeleteMedia(gomock.Any(), 1, 1). Return(errors.New("failed to delete")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to delete media", }, } for _, tc := range testCases { s.Run(tc.name, func() { // 设置 mock tc.setupMock() // 创建请求 req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/media/%s", tc.mediaID), nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req // Extract ID from URL path parts := strings.Split(strings.Trim(req.URL.Path, "/"), "/") if len(parts) >= 2 { c.Params = []gin.Param{{Key: "id", Value: parts[1]}} } // 设置用户ID(除了未授权的测试用例) if tc.expectedError != "Unauthorized" { c.Set("user_id", 1) } // 执行请求 s.handler.DeleteMedia(c) // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { var response map[string]string err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) s.Equal(tc.expectedError, response["error"]) } }) } }