package middleware

import (
	"context"
	"fmt"
	"net/http"

	"tss-rocks-be/ent"
	"tss-rocks-be/ent/user"
	"tss-rocks-be/internal/auth"

	"github.com/gin-gonic/gin"
)

// RequirePermission creates a middleware that lets the request continue only
// if the authenticated user holds the permission identified by resource and
// action through at least one of their roles.
//
// The user ID is read from the gin context under auth.UserIDKey. On rejection
// the request is aborted with a JSON body of the form {"error": "..."}:
//   - 401 "Unauthorized" when no int user ID is present in the context,
//   - 401 "User not found" when the ID matches no user,
//   - 500 "Internal server error" when the user lookup fails for another reason,
//   - 403 "Missing required permission: <resource>:<action>" when none of the
//     user's roles grants the permission.
func RequirePermission(client *ent.Client, resource, action string) gin.HandlerFunc {
	return func(c *gin.Context) {
		u, ok := loadUserWithRoles(c, client, true)
		if !ok {
			// loadUserWithRoles has already aborted the request.
			return
		}

		if !userHasPermission(u, resource, action) {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error": fmt.Sprintf("Missing required permission: %s:%s", resource, action),
			})
			return
		}

		c.Next()
	}
}

// RequireRole creates a middleware that lets the request continue only if the
// authenticated user has a role whose Name equals roleName.
//
// The user ID is read from the gin context under auth.UserIDKey. On rejection
// the request is aborted with a JSON body of the form {"error": "..."}:
//   - 401 "Unauthorized" when no int user ID is present in the context,
//   - 401 "User not found" when the ID matches no user,
//   - 500 "Internal server error" when the user lookup fails for another reason,
//   - 403 "Required role: <roleName>" when the user lacks the role.
func RequireRole(client *ent.Client, roleName string) gin.HandlerFunc {
	return func(c *gin.Context) {
		u, ok := loadUserWithRoles(c, client, false)
		if !ok {
			// loadUserWithRoles has already aborted the request.
			return
		}

		if !userHasRole(u, roleName) {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error": fmt.Sprintf("Required role: %s", roleName),
			})
			return
		}

		c.Next()
	}
}

// loadUserWithRoles fetches the authenticated user, identified by the int
// stored under auth.UserIDKey, with their roles eager-loaded. When
// withPermissions is true, each role's permissions are eager-loaded as well.
//
// If the user cannot be loaded, it aborts the request with a JSON error and
// returns ok == false. The caller must then return without calling c.Next.
func loadUserWithRoles(c *gin.Context, client *ent.Client, withPermissions bool) (*ent.User, bool) {
	rawID, exists := c.Get(auth.UserIDKey)
	// Use the comma-ok form so an unexpected value type (or a missing key,
	// which yields nil) is rejected instead of panicking.
	userID, isInt := rawID.(int)
	if !exists || !isInt {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"error": "Unauthorized",
		})
		return nil, false
	}

	// Tie the query to the request so it is cancelled when the client goes
	// away or an upstream deadline expires. Fall back to Background for
	// contexts built without a request (e.g. some test setups).
	ctx := context.Background()
	if c.Request != nil {
		ctx = c.Request.Context()
	}

	query := client.User.Query().Where(user.ID(userID))
	if withPermissions {
		query = query.WithRoles(func(q *ent.RoleQuery) {
			q.WithPermissions()
		})
	} else {
		query = query.WithRoles()
	}

	u, err := query.Only(ctx)
	switch {
	case err == nil:
		return u, true
	case ent.IsNotFound(err):
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"error": "User not found",
		})
	default:
		// A database failure is not an authentication failure. Report it
		// as a server error (a 401 could make clients drop a valid session),
		// and record it on the context so error-handling middleware can
		// log it.
		_ = c.Error(err)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
			"error": "Internal server error",
		})
	}
	return nil, false
}

// userHasPermission reports whether any of u's eager-loaded roles carries a
// permission with the given resource and action.
func userHasPermission(u *ent.User, resource, action string) bool {
	for _, r := range u.Edges.Roles {
		for _, p := range r.Edges.Permissions {
			if p.Resource == resource && p.Action == action {
				return true
			}
		}
	}
	return false
}

// userHasRole reports whether any of u's eager-loaded roles is named roleName.
func userHasRole(u *ent.User, roleName string) bool {
	for _, r := range u.Edges.Roles {
		if r.Name == roleName {
			return true
		}
	}
	return false
}