package middleware import ( "bytes" "fmt" "net/http" "path/filepath" "strings" "tss-rocks-be/internal/config" "github.com/gin-gonic/gin" ) // ValidateUpload 验证上传的文件 func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc { return func(c *gin.Context) { // Get file from form file, err := c.FormFile("file") if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"}) c.Abort() return } // 获取文件类型和扩展名 contentType := file.Header.Get("Content-Type") ext := strings.ToLower(filepath.Ext(file.Filename)) // 如果 Content-Type 为空,尝试从文件扩展名判断 if contentType == "" { switch ext { case ".jpg", ".jpeg": contentType = "image/jpeg" case ".png": contentType = "image/png" case ".gif": contentType = "image/gif" case ".webp": contentType = "image/webp" case ".mp4": contentType = "video/mp4" case ".webm": contentType = "video/webm" case ".mp3": contentType = "audio/mpeg" case ".ogg": contentType = "audio/ogg" case ".wav": contentType = "audio/wav" case ".pdf": contentType = "application/pdf" case ".doc": contentType = "application/msword" case ".docx": contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" } } // 根据 Content-Type 确定文件类型和限制 var maxSize int64 var allowedTypes []string var fileType string limits := cfg.Limits switch { case strings.HasPrefix(contentType, "image/"): maxSize = int64(limits.Image.MaxSize) * 1024 * 1024 allowedTypes = limits.Image.AllowedTypes fileType = "image" case strings.HasPrefix(contentType, "video/"): maxSize = int64(limits.Video.MaxSize) * 1024 * 1024 allowedTypes = limits.Video.AllowedTypes fileType = "video" case strings.HasPrefix(contentType, "audio/"): maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024 allowedTypes = limits.Audio.AllowedTypes fileType = "audio" case strings.HasPrefix(contentType, "application/"): maxSize = int64(limits.Document.MaxSize) * 1024 * 1024 allowedTypes = limits.Document.AllowedTypes fileType = "document" default: c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("Unsupported file type: %s", contentType), }) c.Abort() return } // 检查文件类型是否允许 typeAllowed := false for _, allowed := range allowedTypes { if contentType == allowed { typeAllowed = true break } } if !typeAllowed { c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType), }) c.Abort() return } // 检查文件大小 if file.Size > maxSize { c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType), }) c.Abort() return } c.Next() } } // GetValidatedFile 从上下文中获取验证过的文件内容 func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) { file, exists := c.Get("validated_file_" + filename) if !exists { return nil, "", false } contentType, exists := c.Get("validated_content_type_" + filename) if !exists { return nil, "", false } return file.(*bytes.Buffer), contentType.(string), true }