package middleware

import (
	"bytes"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"strings"

	"github.com/gin-gonic/gin"
	"tss-rocks-be/internal/types"
)

const (
	defaultMaxMemory = 32 << 20 // 32 MB: file data ParseMultipartForm keeps in memory; the remainder goes to temp files
	maxHeaderBytes   = 512      // number of leading bytes used for MIME type detection
)

// rejection describes why an uploaded file failed validation: the HTTP
// status to respond with and the message placed in the JSON "error" field.
type rejection struct {
	status int
	msg    string
}

// ValidateUpload returns a middleware that validates every file in a
// multipart/form-data request against cfg.
//
// Each file must not exceed cfg.MaxSize megabytes, must have an extension
// (lower-cased, including the leading dot) listed in cfg.AllowedExtensions,
// and must have a sniffed MIME type permitted by cfg.AllowedTypes. On
// success, each file's contents and detected MIME type are stored in the gin
// context (retrieve them with GetValidatedFile) and the next handler runs.
// On the first failure the request is aborted with a JSON body of the form
// {"error": "..."}: 400 for invalid input, 500 for I/O failures.
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Parse the media type rather than prefix-matching the raw header: media
		// types are case-insensitive and carry parameters such as the boundary.
		mediaType, _, err := mime.ParseMediaType(c.GetHeader("Content-Type"))
		if err != nil || mediaType != "multipart/form-data" {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error": "Content-Type must be multipart/form-data",
			})
			return
		}

		if err = c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error": fmt.Sprintf("Failed to parse form: %v", err),
			})
			return
		}

		// ParseMultipartForm always allocates form.File, so a nil check alone
		// never detects a request without any file parts; check emptiness.
		form := c.Request.MultipartForm
		if form == nil || len(form.File) == 0 {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error": "No file uploaded",
			})
			return
		}

		for _, files := range form.File {
			for _, file := range files {
				buf, contentType, rej := validateFile(cfg, file)
				if rej != nil {
					c.AbortWithStatusJSON(rej.status, gin.H{"error": rej.msg})
					return
				}
				// Keys derive from the client-supplied filename, so two uploaded
				// files with the same name overwrite each other here.
				c.Set("validated_file_"+file.Filename, buf)
				c.Set("validated_content_type_"+file.Filename, contentType)
			}
		}

		c.Next()
	}
}

// validateFile checks a single uploaded file against cfg. On success it
// returns the file's full contents and its sniffed MIME type with a nil
// rejection; otherwise it returns a rejection describing the failure.
//
// It is a separate function so the deferred Close runs as soon as this file
// has been processed, instead of keeping every uploaded file open until the
// whole handler chain (including c.Next()) has returned.
func validateFile(cfg *types.UploadConfig, file *multipart.FileHeader) (*bytes.Buffer, string, *rejection) {
	// cfg.MaxSize is in MB; shift to compare in bytes.
	if file.Size > int64(cfg.MaxSize)<<20 {
		return nil, "", &rejection{
			status: http.StatusBadRequest,
			msg:    fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
		}
	}

	ext := strings.ToLower(filepath.Ext(file.Filename))
	if !contains(cfg.AllowedExtensions, ext) {
		return nil, "", &rejection{
			status: http.StatusBadRequest,
			msg:    fmt.Sprintf("File extension %s is not allowed", ext),
		}
	}

	src, err := file.Open()
	if err != nil {
		return nil, "", &rejection{
			status: http.StatusInternalServerError,
			msg:    fmt.Sprintf("Failed to open file: %v", err),
		}
	}
	defer src.Close()

	// A single Read may legally return fewer bytes than requested; ReadFull
	// keeps reading until the buffer is full. Files shorter than the buffer
	// yield io.EOF (empty file) or io.ErrUnexpectedEOF, which are fine here.
	header := make([]byte, maxHeaderBytes)
	n, err := io.ReadFull(src, header)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, "", &rejection{
			status: http.StatusInternalServerError,
			msg:    fmt.Sprintf("Failed to read file: %v", err),
		}
	}

	contentType := http.DetectContentType(header[:n])
	if !isAllowedType(cfg.AllowedTypes, contentType) {
		return nil, "", &rejection{
			status: http.StatusBadRequest,
			msg:    fmt.Sprintf("File type %s is not allowed", contentType),
		}
	}

	// Rewind so the buffered copy includes the bytes consumed for sniffing.
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		return nil, "", &rejection{
			status: http.StatusInternalServerError,
			msg:    fmt.Sprintf("Failed to read file: %v", err),
		}
	}

	buf := &bytes.Buffer{}
	if _, err := io.Copy(buf, src); err != nil {
		return nil, "", &rejection{
			status: http.StatusInternalServerError,
			msg:    fmt.Sprintf("Failed to read file: %v", err),
		}
	}

	return buf, contentType, nil
}

// isAllowedType reports whether contentType, as returned by
// http.DetectContentType, is permitted by allowed. An entry matches either
// the full detected value (e.g. "text/plain; charset=utf-8") or just its
// media type without parameters (e.g. "text/plain"), because the sniffer
// appends a charset parameter to text types.
func isAllowedType(allowed []string, contentType string) bool {
	if contains(allowed, contentType) {
		return true
	}
	mediaType, _, err := mime.ParseMediaType(contentType)
	return err == nil && contains(allowed, mediaType)
}

// contains reports whether slice contains str (exact, case-sensitive match).
func contains(slice []string, str string) bool {
	for _, s := range slice {
		if s == str {
			return true
		}
	}
	return false
}

// GetValidatedFile returns the contents and detected MIME type that
// ValidateUpload stored in the context for the uploaded file named filename.
// The boolean is false if no validated file with that name is present (or
// the stored values do not have the expected types).
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
	rawFile, exists := c.Get("validated_file_" + filename)
	if !exists {
		return nil, "", false
	}
	rawType, exists := c.Get("validated_content_type_" + filename)
	if !exists {
		return nil, "", false
	}

	// Comma-ok assertions: a foreign value under these keys must not panic.
	buf, ok := rawFile.(*bytes.Buffer)
	if !ok {
		return nil, "", false
	}
	contentType, ok := rawType.(string)
	if !ok {
		return nil, "", false
	}
	return buf, contentType, true
}