159 lines
3.8 KiB
Go
159 lines
3.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"tss-rocks-be/internal/types"
|
|
)
|
|
|
|
// Limits used by ValidateUpload.
const (
	// defaultMaxMemory is the maxMemory argument handed to
	// http.Request.ParseMultipartForm (32 MB).
	defaultMaxMemory = 32 * 1024 * 1024

	// maxHeaderBytes is the maximum number of leading bytes of each uploaded
	// file that are read for MIME type detection.
	maxHeaderBytes = 512
)
|
|
|
|
// ValidateUpload 创建文件上传验证中间件
|
|
func ValidateUpload(cfg *types.UploadConfig) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// 检查是否是multipart/form-data请求
|
|
if !strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": "Content-Type must be multipart/form-data",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 解析multipart表单
|
|
if err := c.Request.ParseMultipartForm(defaultMaxMemory); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": fmt.Sprintf("Failed to parse form: %v", err),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
form := c.Request.MultipartForm
|
|
if form == nil || form.File == nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": "No file uploaded",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 遍历所有上传的文件
|
|
for _, files := range form.File {
|
|
for _, file := range files {
|
|
// 检查文件大小
|
|
if file.Size > int64(cfg.MaxSize)<<20 { // 转换为字节
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": fmt.Sprintf("File %s exceeds maximum size of %d MB", file.Filename, cfg.MaxSize),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 检查文件扩展名
|
|
ext := strings.ToLower(filepath.Ext(file.Filename))
|
|
if !contains(cfg.AllowedExtensions, ext) {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": fmt.Sprintf("File extension %s is not allowed", ext),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 打开文件
|
|
src, err := file.Open()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": fmt.Sprintf("Failed to open file: %v", err),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
defer src.Close()
|
|
|
|
// 读取文件头部用于MIME类型检测
|
|
header := make([]byte, maxHeaderBytes)
|
|
n, err := src.Read(header)
|
|
if err != nil && err != io.EOF {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": fmt.Sprintf("Failed to read file: %v", err),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
header = header[:n]
|
|
|
|
// 检测MIME类型
|
|
contentType := http.DetectContentType(header)
|
|
if !contains(cfg.AllowedTypes, contentType) {
|
|
c.JSON(http.StatusBadRequest, gin.H{
|
|
"error": fmt.Sprintf("File type %s is not allowed", contentType),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 将文件指针重置到开始位置
|
|
_, err = src.Seek(0, 0)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": fmt.Sprintf("Failed to read file: %v", err),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 将文件内容读入缓冲区
|
|
buf := &bytes.Buffer{}
|
|
_, err = io.Copy(buf, src)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
"error": fmt.Sprintf("Failed to read file: %v", err),
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// 将验证过的文件内容和类型保存到上下文中
|
|
c.Set("validated_file_"+file.Filename, buf)
|
|
c.Set("validated_content_type_"+file.Filename, contentType)
|
|
}
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// contains reports whether str appears in slice, using exact (case-sensitive)
// string equality. A nil or empty slice contains nothing.
func contains(slice []string, str string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == str
	}
	return found
}
|
|
|
|
// GetValidatedFile 从上下文中获取验证过的文件内容
|
|
func GetValidatedFile(c *gin.Context, filename string) (*bytes.Buffer, string, bool) {
|
|
file, exists := c.Get("validated_file_" + filename)
|
|
if !exists {
|
|
return nil, "", false
|
|
}
|
|
|
|
contentType, exists := c.Get("validated_content_type_" + filename)
|
|
if !exists {
|
|
return nil, "", false
|
|
}
|
|
|
|
return file.(*bytes.Buffer), contentType.(string), true
|
|
}
|