package middleware import ( "encoding/json" "fmt" "net/http" "strconv" "strings" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog/log" ) // AuthMiddleware creates a middleware for JWT authentication func AuthMiddleware(jwtSecret string) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"}) c.Abort() return } parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"}) c.Abort() return } token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, jwt.ErrSignatureInvalid } return []byte(jwtSecret), nil }) if err != nil { log.Error().Err(err).Msg("Failed to parse token") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() return } if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { // 添加调试信息 log.Debug().Interface("claims", claims).Msg("Token claims") // 获取用户ID sub, exists := claims["sub"] if !exists { log.Error().Msg("Token does not contain sub claim") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"}) c.Abort() return } // 打印类型信息 log.Debug(). Str("type", fmt.Sprintf("%T", sub)). Interface("value", sub). Msg("User ID from token") var userID int switch v := sub.(type) { case string: var err error userID, err = strconv.Atoi(v) if err != nil { log.Error().Err(err).Str("sub", v).Msg("Failed to convert string user ID to int") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"}) c.Abort() return } case float64: userID = int(v) case json.Number: var err error userID, err = strconv.Atoi(v.String()) if err != nil { log.Error().Err(err).Str("sub", v.String()).Msg("Failed to convert json.Number user ID to int") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"}) c.Abort() return } default: log.Error(). Str("type", fmt.Sprintf("%T", sub)). Interface("value", sub). Msg("Unexpected user ID type") c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID type"}) c.Abort() return } // 将 userID 转换为 int64 以确保类型一致性 c.Set("user_id", int64(userID)) if roles, ok := claims["roles"].([]interface{}); ok { c.Set("user_roles", roles) } c.Next() } else { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.Abort() return } } } // RoleMiddleware creates a middleware for role-based authorization func RoleMiddleware(roles ...string) gin.HandlerFunc { return func(c *gin.Context) { userRole, exists := c.Get("user_role") if !exists { c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"}) c.Abort() return } roleStr, ok := userRole.(string) if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"}) c.Abort() return } for _, role := range roles { if role == roleStr { c.Next() return } } c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"}) c.Abort() } }