diff --git a/backend/config/config.yaml.example b/backend/config/config.yaml.example index 2b45cd9..220ace8 100644 --- a/backend/config/config.yaml.example +++ b/backend/config/config.yaml.example @@ -15,6 +15,11 @@ jwt: secret: your-jwt-secret-here # 在生产环境中应该使用环境变量 expiration: 24h +auth: + registration: + enabled: false # 是否允许注册 + message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息 + storage: driver: local local: diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3b4864c..faa6bdd 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -11,6 +11,7 @@ type Config struct { Database DatabaseConfig `yaml:"database"` Server ServerConfig `yaml:"server"` JWT JWTConfig `yaml:"jwt"` + Auth AuthConfig `yaml:"auth"` Storage StorageConfig `yaml:"storage"` Logging LoggingConfig `yaml:"logging"` RateLimit types.RateLimitConfig `yaml:"rate_limit"` @@ -32,6 +33,13 @@ type JWTConfig struct { Expiration string `yaml:"expiration"` } +type AuthConfig struct { + Registration struct { + Enabled bool `yaml:"enabled"` + Message string `yaml:"message"` + } `yaml:"registration"` +} + type LoggingConfig struct { Level string `yaml:"level"` Format string `yaml:"format"` diff --git a/backend/internal/handler/auth.go b/backend/internal/handler/auth.go index 9503c6b..4e0bed8 100644 --- a/backend/internal/handler/auth.go +++ b/backend/internal/handler/auth.go @@ -27,16 +27,41 @@ type AuthResponse struct { } func (h *Handler) Register(c *gin.Context) { + // 检查是否启用注册功能 + if !h.config.Auth.Registration.Enabled { + message := h.config.Auth.Registration.Message + if message == "" { + message = "Registration is currently disabled" + } + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "code": "REGISTRATION_DISABLED", + "message": message, + }, + }) + return + } + var req RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "code": "INVALID_REQUEST", + "message": err.Error(), + }, + }) return } user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role) if err != nil { log.Error().Err(err).Msg("Failed to create user") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "code": "CREATE_USER_FAILED", + "message": "Failed to create user", + }, + }) return } @@ -44,7 +69,12 @@ func (h *Handler) Register(c *gin.Context) { roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) if err != nil { log.Error().Err(err).Msg("Failed to get user roles") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "code": "GET_ROLES_FAILED", + "message": "Failed to get user roles", + }, + }) return } @@ -64,7 +94,12 @@ func (h *Handler) Register(c *gin.Context) { tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) if err != nil { log.Error().Err(err).Msg("Failed to generate token") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "code": "GENERATE_TOKEN_FAILED", + "message": "Failed to generate token", + }, + }) return } @@ -74,14 +109,22 @@ func (h *Handler) Register(c *gin.Context) { func (h *Handler) Login(c *gin.Context) { var req LoginRequest if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "code": "INVALID_REQUEST", + "message": err.Error(), + }, + }) return } user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid username or password", + "error": gin.H{ + "code": "INVALID_CREDENTIALS", + "message": "Invalid username or password", + }, }) return } @@ -90,7 +133,10 @@ func (h *Handler) Login(c *gin.Context) { err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{ - "error": "Invalid username or password", + "error": gin.H{ + "code": "INVALID_CREDENTIALS", + "message": "Invalid username or password", + }, }) return } @@ -100,7 +146,10 @@ func (h *Handler) Login(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to get user roles") c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to get user roles", + "error": gin.H{ + "code": "GET_ROLES_FAILED", + "message": "Failed to get user roles", + }, }) return } @@ -122,7 +171,10 @@ func (h *Handler) Login(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to generate token") c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to generate token", + "error": gin.H{ + "code": "GENERATE_TOKEN_FAILED", + "message": "Failed to generate token", + }, }) return } diff --git a/backend/internal/handler/auth_handler_test.go b/backend/internal/handler/auth_handler_test.go index dfa4458..ed23a86 100644 --- a/backend/internal/handler/auth_handler_test.go +++ b/backend/internal/handler/auth_handler_test.go @@ -33,6 +33,15 @@ func (s *AuthHandlerTestSuite) SetupTest() { JWT: config.JWTConfig{ Secret: "test-secret", }, + Auth: config.AuthConfig{ + Registration: struct { + Enabled bool `yaml:"enabled"` + Message string `yaml:"message"` + }{ + Enabled: true, + Message: "Registration is disabled", + }, + }, }, s.service) s.router = gin.New() } @@ -45,6 +54,13 @@ func TestAuthHandlerSuite(t *testing.T) { suite.Run(t, new(AuthHandlerTestSuite)) } +type ErrorResponse struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + func (s *AuthHandlerTestSuite) TestRegister() { testCases := []struct { name string @@ -52,6 +68,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { setupMock func() expectedStatus int expectedError string + registration bool }{ { name: "成功注册", @@ -74,6 +91,20 @@ func (s *AuthHandlerTestSuite) TestRegister() { Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) }, expectedStatus: http.StatusCreated, + registration: true, + }, + { + name: "注册功能已禁用", + request: RegisterRequest{ + Username: "testuser", + Email: "test@example.com", + Password: "password123", + Role: "contributor", + }, + setupMock: func() {}, + expectedStatus: http.StatusForbidden, + expectedError: "Registration is disabled", + registration: false, }, { name: "无效的邮箱格式", @@ -86,6 +117,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { setupMock: func() {}, expectedStatus: http.StatusBadRequest, expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", + registration: true, }, { name: "密码太短", @@ -98,6 +130,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { setupMock: func() {}, expectedStatus: http.StatusBadRequest, expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag", + registration: true, }, { name: "无效的角色", @@ -110,11 +143,15 @@ func (s *AuthHandlerTestSuite) TestRegister() { setupMock: func() {}, expectedStatus: http.StatusBadRequest, expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag", + registration: true, }, } for _, tc := range testCases { s.Run(tc.name, func() { + // 设置注册功能状态 + s.handler.config.Auth.Registration.Enabled = tc.registration + // 设置 mock tc.setupMock() @@ -132,10 +169,10 @@ func (s *AuthHandlerTestSuite) TestRegister() { // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { - var response map[string]string + var response ErrorResponse err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) - s.Contains(response["error"], tc.expectedError) + s.Contains(response.Error.Message, tc.expectedError) } else { var response AuthResponse err := json.Unmarshal(w.Body.Bytes(), &response) @@ -147,6 +184,8 @@ func (s *AuthHandlerTestSuite) TestRegister() { } func (s *AuthHandlerTestSuite) TestLogin() { + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) + testCases := []struct { name string request LoginRequest @@ -161,35 +200,28 @@ func (s *AuthHandlerTestSuite) TestLogin() { Password: "password123", }, setupMock: func() { - // 使用 bcrypt 生成正确的密码哈希 - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) - user := &ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - } s.service.EXPECT(). GetUserByUsername(gomock.Any(), "testuser"). - Return(user, nil) + Return(&ent.User{ + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), + }, nil) s.service.EXPECT(). - GetUserRoles(gomock.Any(), user.ID). - Return([]*ent.Role{{Name: "admin"}}, nil) + GetUserRoles(gomock.Any(), 1). + Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) }, expectedStatus: http.StatusOK, }, { name: "无效的用户名", request: LoginRequest{ - Username: "invalid", + Username: "te", Password: "password123", }, - setupMock: func() { - s.service.EXPECT(). - GetUserByUsername(gomock.Any(), "invalid"). - Return(nil, fmt.Errorf("user not found")) - }, - expectedStatus: http.StatusUnauthorized, - expectedError: "Invalid username or password", + setupMock: func() {}, + expectedStatus: http.StatusBadRequest, + expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag", }, { name: "用户不存在", @@ -209,19 +241,16 @@ func (s *AuthHandlerTestSuite) TestLogin() { name: "密码错误", request: LoginRequest{ Username: "testuser", - Password: "wrong-password", + Password: "wrongpassword", }, setupMock: func() { - // 使用 bcrypt 生成正确的密码哈希 - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) - user := &ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - } s.service.EXPECT(). GetUserByUsername(gomock.Any(), "testuser"). - Return(user, nil) + Return(&ent.User{ + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), + }, nil) }, expectedStatus: http.StatusUnauthorized, expectedError: "Invalid username or password", @@ -233,18 +262,15 @@ func (s *AuthHandlerTestSuite) TestLogin() { Password: "password123", }, setupMock: func() { - // 使用 bcrypt 生成正确的密码哈希 - hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) - user := &ent.User{ - ID: 1, - Username: "testuser", - PasswordHash: string(hashedPassword), - } s.service.EXPECT(). GetUserByUsername(gomock.Any(), "testuser"). - Return(user, nil) + Return(&ent.User{ + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), + }, nil) s.service.EXPECT(). - GetUserRoles(gomock.Any(), user.ID). + GetUserRoles(gomock.Any(), 1). Return(nil, fmt.Errorf("failed to get roles")) }, expectedStatus: http.StatusInternalServerError, @@ -271,10 +297,10 @@ func (s *AuthHandlerTestSuite) TestLogin() { // 验证响应 s.Equal(tc.expectedStatus, w.Code) if tc.expectedError != "" { - var response map[string]string + var response ErrorResponse err := json.Unmarshal(w.Body.Bytes(), &response) s.NoError(err) - s.Contains(response["error"], tc.expectedError) + s.Contains(response.Error.Message, tc.expectedError) } else { var response AuthResponse err := json.Unmarshal(w.Body.Bytes(), &response) diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 4c821aa..3a4f4e9 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -13,12 +13,14 @@ import ( type Handler struct { cfg *config.Config + config *config.Config service service.Service } func NewHandler(cfg *config.Config, service service.Service) *Handler { return &Handler{ cfg: cfg, + config: cfg, service: service, } }