[feature/backend] implement /auth/logout handling + overall enhancement
This commit is contained in:
parent
d8d8e4b0d7
commit
e5fc8691bf
12 changed files with 420 additions and 64 deletions
|
@ -2,6 +2,7 @@ package handler
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -28,8 +29,8 @@ type AuthResponse struct {
|
|||
|
||||
func (h *Handler) Register(c *gin.Context) {
|
||||
// 检查是否启用注册功能
|
||||
if !h.config.Auth.Registration.Enabled {
|
||||
message := h.config.Auth.Registration.Message
|
||||
if !h.cfg.Auth.Registration.Enabled {
|
||||
message := h.cfg.Auth.Registration.Message
|
||||
if message == "" {
|
||||
message = "Registration is currently disabled"
|
||||
}
|
||||
|
@ -181,3 +182,72 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
|
||||
c.JSON(http.StatusOK, AuthResponse{Token: tokenString})
|
||||
}
|
||||
|
||||
func (h *Handler) Logout(c *gin.Context) {
|
||||
// 获取当前用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "UNAUTHORIZED",
|
||||
"message": "User not authenticated",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取 token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "UNAUTHORIZED",
|
||||
"message": "Authorization header is required",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "UNAUTHORIZED",
|
||||
"message": "Authorization header format must be Bearer {token}",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 token 以获取过期时间
|
||||
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(h.cfg.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "INVALID_TOKEN",
|
||||
"message": "Invalid token",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// 将 token 添加到黑名单
|
||||
h.service.GetTokenBlacklist().AddToBlacklist(parts[1], claims)
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
log.Info().
|
||||
Interface("user_id", userID).
|
||||
Msg("User logged out")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "Successfully logged out",
|
||||
})
|
||||
}
|
||||
|
|
|
@ -150,7 +150,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置注册功能状态
|
||||
s.handler.config.Auth.Registration.Enabled = tc.registration
|
||||
s.handler.cfg.Auth.Registration.Enabled = tc.registration
|
||||
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/categorycontent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
"tss-rocks-be/internal/types"
|
||||
|
||||
|
@ -59,12 +60,24 @@ type CategoryHandlerTestSuite struct {
|
|||
func (s *CategoryHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -29,12 +30,24 @@ type ContributorHandlerTestSuite struct {
|
|||
func (s *ContributorHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -29,12 +30,24 @@ type DailyHandlerTestSuite struct {
|
|||
func (s *DailyHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,14 +14,12 @@ import (
|
|||
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
config *config.Config
|
||||
service service.Service
|
||||
}
|
||||
|
||||
func NewHandler(cfg *config.Config, service service.Service) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
config: cfg,
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
@ -35,10 +33,11 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
|||
{
|
||||
auth.POST("/register", h.Register)
|
||||
auth.POST("/login", h.Login)
|
||||
auth.POST("/logout", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()), h.Logout)
|
||||
}
|
||||
|
||||
// User routes
|
||||
users := api.Group("/users", middleware.AuthMiddleware(h.config.JWT.Secret))
|
||||
users := api.Group("/users", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
|
||||
{
|
||||
users.GET("", h.ListUsers)
|
||||
users.POST("", h.CreateUser)
|
||||
|
@ -86,7 +85,7 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
|||
}
|
||||
|
||||
// Media routes
|
||||
media := api.Group("/media")
|
||||
media := api.Group("/media", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
|
||||
{
|
||||
media.GET("", h.ListMedia)
|
||||
media.POST("", h.UploadMedia)
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -29,12 +30,24 @@ type PostHandlerTestSuite struct {
|
|||
func (s *PostHandlerTestSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.service = mock.NewMockService(s.ctrl)
|
||||
cfg := &config.Config{}
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
s.handler = NewHandler(cfg, s.service)
|
||||
|
||||
// Setup Gin router
|
||||
gin.SetMode(gin.TestMode)
|
||||
s.router = gin.New()
|
||||
|
||||
// Setup mock for GetTokenBlacklist
|
||||
tokenBlacklist := &service.TokenBlacklist{}
|
||||
s.service.EXPECT().
|
||||
GetTokenBlacklist().
|
||||
Return(tokenBlacklist).
|
||||
AnyTimes()
|
||||
|
||||
s.handler.RegisterRoutes(s.router)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue