[feature/backend] overall enhancement of image uploading
All checks were successful
Build Backend / Build Docker Image (push) Successful in 5m3s
All checks were successful
Build Backend / Build Docker Image (push) Successful in 5m3s
This commit is contained in:
parent
6e1be3d513
commit
3e6181e578
13 changed files with 740 additions and 314 deletions
|
@ -3,146 +3,120 @@ package middleware
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"tss-rocks-be/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
maxHeaderBytes = 512 // 用于MIME类型检测的最大字节数
|
||||
)
|
||||
|
||||
// ValidateUpload 创建文件上传验证中间件
|
||||
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
|
||||
// ValidateUpload 验证上传的文件
|
||||
func ValidateUpload(cfg *config.UploadConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否是multipart/form-data请求
|
||||
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Content-Type must be multipart/form-data",
|
||||
})
|
||||
// Get file from form
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "No file uploaded"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析multipart表单
|
||||
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Failed to parse form: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// 获取文件类型和扩展名
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil || form.File == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "No file uploaded",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有上传的文件
|
||||
for _, files := range form.File {
|
||||
for _, file := range files {
|
||||
// 检查文件大小
|
||||
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件扩展名
|
||||
ext := strings.ToLower(filepath.Ext(file.Filename))
|
||||
if !contains(cfg.AllowedExtensions, ext) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File extension %s is not allowed", ext),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to open file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 读取文件头部用于MIME类型检测
|
||||
header := make([]byte, maxHeaderBytes)
|
||||
n, err := src.Read(header)
|
||||
if err != nil && err != io.EOF {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
header = header[:n]
|
||||
|
||||
// 检测MIME类型
|
||||
contentType := http.DetectContentType(header)
|
||||
if !contains(cfg.AllowedTypes, contentType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File type %s is not allowed", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件指针重置到开始位置
|
||||
_, err = src.Seek(0, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将文件内容读入缓冲区
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = io.Copy(buf, src)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("Failed to read file: %v", err),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将验证过的文件内容和类型保存到上下文中
|
||||
c.Set("validated_file_"+file.Filename, buf)
|
||||
c.Set("validated_content_type_"+file.Filename, contentType)
|
||||
// 如果 Content-Type 为空,尝试从文件扩展名判断
|
||||
if contentType == "" {
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
contentType = "image/jpeg"
|
||||
case ".png":
|
||||
contentType = "image/png"
|
||||
case ".gif":
|
||||
contentType = "image/gif"
|
||||
case ".webp":
|
||||
contentType = "image/webp"
|
||||
case ".mp4":
|
||||
contentType = "video/mp4"
|
||||
case ".webm":
|
||||
contentType = "video/webm"
|
||||
case ".mp3":
|
||||
contentType = "audio/mpeg"
|
||||
case ".ogg":
|
||||
contentType = "audio/ogg"
|
||||
case ".wav":
|
||||
contentType = "audio/wav"
|
||||
case ".pdf":
|
||||
contentType = "application/pdf"
|
||||
case ".doc":
|
||||
contentType = "application/msword"
|
||||
case ".docx":
|
||||
contentType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 Content-Type 确定文件类型和限制
|
||||
var maxSize int64
|
||||
var allowedTypes []string
|
||||
var fileType string
|
||||
|
||||
limits := cfg.Limits
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "image/"):
|
||||
maxSize = int64(limits.Image.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Image.AllowedTypes
|
||||
fileType = "image"
|
||||
case strings.HasPrefix(contentType, "video/"):
|
||||
maxSize = int64(limits.Video.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Video.AllowedTypes
|
||||
fileType = "video"
|
||||
case strings.HasPrefix(contentType, "audio/"):
|
||||
maxSize = int64(limits.Audio.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Audio.AllowedTypes
|
||||
fileType = "audio"
|
||||
case strings.HasPrefix(contentType, "application/"):
|
||||
maxSize = int64(limits.Document.MaxSize) * 1024 * 1024
|
||||
allowedTypes = limits.Document.AllowedTypes
|
||||
fileType = "document"
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Unsupported file type: %s", contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件类型是否允许
|
||||
typeAllowed := false
|
||||
for _, allowed := range allowedTypes {
|
||||
if contentType == allowed {
|
||||
typeAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !typeAllowed {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("Unsupported %s type: %s", fileType, contentType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件大小
|
||||
if file.Size > maxSize {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("File size exceeds the limit (%d MB) for %s files", limits.Image.MaxSize, fileType),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含指定的字符串
|
||||
func contains(slice []string, str string) bool {
|
||||
for _, s := range slice {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetValidatedFile 从上下文中获取验证过的文件内容
|
||||
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
|
||||
file, exists := c.Get("validated_file_" + filename)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue