tss-rocks/backend/internal/middleware/upload_test.go
CDN 05ddc1f783
Some checks failed
Build Backend / Build Docker Image (push) Successful in 3m33s
Test Backend / test (push) Failing after 31s
[feature] migrate to monorepo
2025-02-21 00:49:20 +08:00

262 lines
6.2 KiB
Go

package middleware
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"tss-rocks-be/internal/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// createMultipartRequest builds a POST /upload request whose multipart/form-data
// body holds a single file part under the form field "file".
//
// filename becomes the part's filename and content its payload. If
// contentType is non-empty it is sent as the part's own Content-Type header;
// if it is empty the part keeps the "application/octet-stream" default that
// multipart.Writer.CreateFormFile applies. The request's Content-Type header
// carries the boundary generated by the multipart writer.
//
// t is accepted for call-site symmetry with other test helpers and is not
// used; any error from assembling the body is returned to the caller.
func createMultipartRequest(t *testing.T, filename string, content []byte, contentType string) (*http.Request, error) {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	var (
		part io.Writer
		err  error
	)
	if contentType == "" {
		part, err = writer.CreateFormFile("file", filename)
	} else {
		// CreateFormFile hard-codes application/octet-stream, so the part
		// header is built by hand. Backslashes and quotes in the filename are
		// escaped the same way CreateFormFile escapes them. The untyped map
		// literal is assignable to textproto.MIMEHeader.
		escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
		part, err = writer.CreatePart(map[string][]string{
			"Content-Disposition": {`form-data; name="file"; filename="` + escaper.Replace(filename) + `"`},
			"Content-Type":        {contentType},
		})
	}
	if err != nil {
		return nil, err
	}
	if _, err := part.Write(content); err != nil {
		return nil, err
	}
	// Close emits the terminating boundary; without it the body is not a
	// complete multipart message.
	if err := writer.Close(); err != nil {
		return nil, err
	}

	req := httptest.NewRequest("POST", "/upload", body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	return req, nil
}
// TestValidateUpload runs the ValidateUpload middleware against a table of
// requests. For each case it checks the status code written to the recorder,
// and when an error message is expected, that the decoded JSON "error" field
// contains it.
func TestValidateUpload(t *testing.T) {
	type uploadCase struct {
		name     string
		config   *types.UploadConfig
		filename string
		content  []byte
		// setupRequest, when non-nil, replaces the default multipart request
		// built from filename/content.
		setupRequest   func(*testing.T) *http.Request
		expectedStatus int
		expectedError  string
	}

	jpegHeader := []byte{
		0xFF, 0xD8, 0xFF, 0xE0, // JPEG magic numbers
		0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00,
	}

	cases := []uploadCase{
		{
			name: "Valid image upload",
			config: &types.UploadConfig{
				MaxSize:           5, // 5MB
				AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
				AllowedTypes:      []string{"image/jpeg", "image/png"},
			},
			filename:       "test.jpg",
			content:        jpegHeader,
			expectedStatus: http.StatusOK,
		},
		{
			name: "Invalid file extension",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg", ".jpeg", ".png"},
				AllowedTypes:      []string{"image/jpeg", "image/png"},
			},
			filename:       "test.txt",
			content:        []byte("test content"),
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File extension .txt is not allowed",
		},
		{
			name: "File too large",
			config: &types.UploadConfig{
				MaxSize:           1, // 1MB
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			filename:       "large.jpg",
			content:        make([]byte, 2<<20), // 2MB
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File large.jpg exceeds maximum size of 1 MB",
		},
		{
			name: "Invalid content type",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			filename:       "fake.jpg",
			content:        []byte("not a real image"),
			expectedStatus: http.StatusBadRequest,
			expectedError:  "File type text/plain; charset=utf-8 is not allowed",
		},
		{
			name: "Missing file",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			// Multipart content type with no boundary and an empty body.
			setupRequest: func(t *testing.T) *http.Request {
				r := httptest.NewRequest("POST", "/upload", strings.NewReader(""))
				r.Header.Set("Content-Type", "multipart/form-data")
				return r
			},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Failed to parse form",
		},
		{
			name: "Invalid content type header",
			config: &types.UploadConfig{
				MaxSize:           5,
				AllowedExtensions: []string{".jpg"},
				AllowedTypes:      []string{"image/jpeg"},
			},
			// No Content-Type header at all.
			setupRequest: func(t *testing.T) *http.Request {
				return httptest.NewRequest("POST", "/upload", nil)
			},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Content-Type must be multipart/form-data",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gin.SetMode(gin.TestMode)
			rec := httptest.NewRecorder()
			gctx, _ := gin.CreateTestContext(rec)

			if tc.setupRequest != nil {
				gctx.Request = tc.setupRequest(t)
			} else {
				r, buildErr := createMultipartRequest(t, tc.filename, tc.content, "")
				if buildErr != nil {
					t.Fatalf("Failed to create request: %v", buildErr)
				}
				gctx.Request = r
			}

			ValidateUpload(tc.config)(gctx)

			assert.Equal(t, tc.expectedStatus, rec.Code)
			if tc.expectedError == "" {
				return
			}

			var payload map[string]string
			decodeErr := json.NewDecoder(rec.Body).Decode(&payload)
			assert.NoError(t, decodeErr)
			assert.Contains(t, payload["error"], tc.expectedError)
		})
	}
}
// TestGetValidatedFile checks that GetValidatedFile reports a file as found
// (with a non-nil buffer and non-empty content type) when the expected
// context keys are present. It also checks that it returns a nil buffer, an
// empty content type and found == false when nothing was stored.
func TestGetValidatedFile(t *testing.T) {
	tests := []struct {
		name          string
		setupContext  func(*gin.Context)
		filename      string
		expectedFound bool
	}{
		{
			name: "Get existing file",
			setupContext: func(c *gin.Context) {
				// Build the test file content.
				content := []byte("test content")
				buf := bytes.NewBuffer(content)
				// Store the validated file and its content type. The key names
				// presumably mirror what ValidateUpload stores per filename —
				// confirm against the middleware if they change.
				c.Set("validated_file_test.txt", buf)
				c.Set("validated_content_type_test.txt", "text/plain")
			},
			filename:      "test.txt",
			expectedFound: true,
		},
		{
			name: "File not found",
			setupContext: func(c *gin.Context) {
				// Deliberately store nothing in the context.
			},
			filename:      "nonexistent.txt",
			expectedFound: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gin.SetMode(gin.TestMode)
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)
			if tt.setupContext != nil {
				tt.setupContext(c)
			}

			buffer, contentType, found := GetValidatedFile(c, tt.filename)

			assert.Equal(t, tt.expectedFound, found)
			if !tt.expectedFound {
				assert.Nil(t, buffer)
				assert.Empty(t, contentType)
				return
			}
			assert.NotNil(t, buffer)
			assert.NotEmpty(t, contentType)
		})
	}
}
// TestContains exercises the contains helper on a hit, a miss, an empty
// slice, and an empty needle that does not appear in the slice.
func TestContains(t *testing.T) {
	cases := []struct {
		name     string
		slice    []string
		str      string
		expected bool
	}{
		{"String found in slice", []string{"a", "b", "c"}, "b", true},
		{"String not found in slice", []string{"a", "b", "c"}, "d", false},
		{"Empty slice", []string{}, "a", false},
		{"Empty string", []string{"a", "b", "c"}, "", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := contains(tc.slice, tc.str)
			assert.Equal(t, tc.expected, got)
		})
	}
}