package middleware

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/textproto"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"

	"tss-rocks-be/internal/types"
)

// createMultipartRequest builds a POST /upload request whose multipart body
// contains a single form file under the field name "file".
//
// filename is used as the part's filename and content as its payload. When
// contentType is empty the part is created with multipart.Writer.CreateFormFile,
// which labels it application/octet-stream. When contentType is non-empty, that
// value is used as the part's Content-Type header instead.
//
// t is currently unused; it is kept so the signature stays stable for callers.
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	var (
		part io.Writer
		err  error
	)
	if contentType == "" {
		part, err = writer.CreateFormFile("file", filename)
	} else {
		// Mirror CreateFormFile's Content-Disposition (including its quote
		// escaping) but with a caller-chosen Content-Type.
		escaper := strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
		header := make(textproto.MIMEHeader)
		header.Set("Content-Disposition",
			fmt.Sprintf(`form-data; name="file"; filename="%s"`, escaper.Replace(filename)))
		header.Set("Content-Type", contentType)
		part, err = writer.CreatePart(header)
	}
	if err != nil {
		return nil, err
	}

	if _, err := io.Copy(part, bytes.NewReader(content)); err != nil {
		return nil, err
	}

	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, err
	}

	req := httptest.NewRequest(http.MethodPost, "/upload", body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	return req, nil
}

// TestValidateUpload runs the ValidateUpload middleware against a table of
// requests and checks the resulting HTTP status and, where set, that the JSON
// "error" field contains the expected message.
func TestValidateUpload(t *testing.T) {
	tests := []struct {
		name     string
		config   *types.UploadConfig
		filename string
		content  []byte
		// setupRequest, when non-nil, replaces the default multipart request
		// built from filename/content.
		setupRequest   func(*testing.T) *http.Request
		expectedStatus int
		expectedError  string
	}{
		{
			name: "Valid image upload",
			config: &types.UploadConfig{
				MaxSize:           5, // 5MB
				AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
				AllowedTypes:      []string{"image/jpeg", "image/png"},
			},
			filename: "test.jpg",
			content: []byte{
				0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
				0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
			},
			expectedStatus: http.StatusOK,
		},
		{
			name: "Invalid file extension",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
				AllowedTypes:      []string{"image/jpeg", "image/png"},
			},
			filename:       "test.txt",
			content:        []byte("test content"),
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File extension .txt is not allowed",
		},
		{
			name: "File too large",
			config: &types.UploadConfig{
				MaxSize:           1, // 1MB
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			filename:       "large.jpg",
			content:        make([]byte, 2<<20), // 2MB
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File large.jpg exceeds maximum size of 1 MB",
		},
		{
			name: "Invalid content type",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			filename:       "fake.jpg",
			content:        []byte("not a real image"),
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File type text/plain; charset=utf-8 is not allowed",
		},
		{
			name: "Missing file",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			setupRequest: func(t *testing.T) *http.Request {
				// Multipart Content-Type with no boundary and an empty body.
				req := httptest.NewRequest(http.MethodPost, "/upload", strings.NewReader(""))
				req.Header.Set("Content-Type", "multipart/form-data")
				return req
			},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Failed to parse form",
		},
		{
			name: "Invalid content type header",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			setupRequest: func(t *testing.T) *http.Request {
				return httptest.NewRequest(http.MethodPost, "/upload", nil)
			},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Content-Type must be multipart/form-data",
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)

			if tt.setupRequest != nil {
				c.Request = tt.setupRequest(t)
			} else {
				req, err := createMultipartRequest(t, tt.filename, tt.content, "")
				if err != nil {
					t.Fatalf("Failed to create request: %v", err)
				}
				c.Request = req
			}

			ValidateUpload(tt.config)(c)

			assert.Equal(t, tt.expectedStatus, w.Code)

			if tt.expectedError != "" {
				var response map[string]string
				// Only inspect the message if the body actually decoded;
				// otherwise the Contains check would just add noise.
				if assert.NoError(t, json.NewDecoder(w.Body).Decode(&response)) {
					assert.Contains(t, response["error"], tt.expectedError)
				}
			}
		})
	}
}

// TestGetValidatedFile checks that GetValidatedFile returns the buffer and
// content type stored under the "validated_file_<name>" and
// "validated_content_type_<name>" context keys, and reports not-found
// (nil buffer, empty content type) when they are absent.
func TestGetValidatedFile(t *testing.T) {
	tests := []struct {
		name          string
		setupContext  func(*gin.Context)
		filename      string
		expectedFound bool
	}{
		{
			name: "Get existing file",
			setupContext: func(c *gin.Context) {
				// Create test file content.
				content := []byte("test content")
				buf := bytes.NewBuffer(content)
				// Store the validated file and its content type.
				c.Set("validated_file_test.txt", buf)
				c.Set("validated_content_type_test.txt", "text/plain")
			},
			filename:      "test.txt",
			expectedFound: true,
		},
		{
			name: "File not found",
			setupContext: func(c *gin.Context) {
				// Intentionally store nothing.
			},
			filename:      "nonexistent.txt",
			expectedFound: false,
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)

			if tt.setupContext != nil {
				tt.setupContext(c)
			}

			buffer, contentType, found := GetValidatedFile(c, tt.filename)

			assert.Equal(t, tt.expectedFound, found)
			if tt.expectedFound {
				assert.NotNil(t, buffer)
				assert.NotEmpty(t, contentType)
			} else {
				assert.Nil(t, buffer)
				assert.Empty(t, contentType)
			}
		})
	}
}

// TestContains checks the package-level contains helper for membership,
// non-membership, an empty slice, and an empty search string.
func TestContains(t *testing.T) {
	tests := []struct {
		name     string
		slice    []string
		str      string
		expected bool
	}{
		{
			name:     "String found in slice",
			slice:    []string{"a", "b", "c"},
			str:      "b",
			expected: true,
		},
		{
			name:     "String not found in slice",
			slice:    []string{"a", "b", "c"},
			str:      "d",
			expected: false,
		},
		{
			name:     "Empty slice",
			slice:    []string{},
			str:      "a",
			expected: false,
		},
		{
			name:     "Empty string",
			slice:    []string{"a", "b", "c"},
			str:      "",
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, contains(tt.slice, tt.str))
		})
	}
}