[feature/backend] registration control

This commit is contained in:
CDN 2025-02-21 04:38:47 +08:00
parent 86ab334bc9
commit a853374009
Signed by: CDN
GPG key ID: 0C656827F9F80080
5 changed files with 142 additions and 49 deletions

View file

@ -15,6 +15,11 @@ jwt:
secret: your-jwt-secret-here # 在生产环境中应该使用环境变量 secret: your-jwt-secret-here # 在生产环境中应该使用环境变量
expiration: 24h expiration: 24h
auth:
registration:
enabled: false # 是否允许注册
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
storage: storage:
driver: local driver: local
local: local:

View file

@ -11,6 +11,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Server ServerConfig `yaml:"server"` Server ServerConfig `yaml:"server"`
JWT JWTConfig `yaml:"jwt"` JWT JWTConfig `yaml:"jwt"`
Auth AuthConfig `yaml:"auth"`
Storage StorageConfig `yaml:"storage"` Storage StorageConfig `yaml:"storage"`
Logging LoggingConfig `yaml:"logging"` Logging LoggingConfig `yaml:"logging"`
RateLimit types.RateLimitConfig `yaml:"rate_limit"` RateLimit types.RateLimitConfig `yaml:"rate_limit"`
@ -32,6 +33,13 @@ type JWTConfig struct {
Expiration string `yaml:"expiration"` Expiration string `yaml:"expiration"`
} }
type AuthConfig struct {
Registration struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
} `yaml:"registration"`
}
type LoggingConfig struct { type LoggingConfig struct {
Level string `yaml:"level"` Level string `yaml:"level"`
Format string `yaml:"format"` Format string `yaml:"format"`

View file

@ -27,16 +27,41 @@ type AuthResponse struct {
} }
func (h *Handler) Register(c *gin.Context) { func (h *Handler) Register(c *gin.Context) {
// 检查是否启用注册功能
if !h.config.Auth.Registration.Enabled {
message := h.config.Auth.Registration.Message
if message == "" {
message = "Registration is currently disabled"
}
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"code": "REGISTRATION_DISABLED",
"message": message,
},
})
return
}
var req RegisterRequest var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"code": "INVALID_REQUEST",
"message": err.Error(),
},
})
return return
} }
user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role) user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to create user") log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"code": "CREATE_USER_FAILED",
"message": "Failed to create user",
},
})
return return
} }
@ -44,7 +69,12 @@ func (h *Handler) Register(c *gin.Context) {
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get user roles") log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"code": "GET_ROLES_FAILED",
"message": "Failed to get user roles",
},
})
return return
} }
@ -64,7 +94,12 @@ func (h *Handler) Register(c *gin.Context) {
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to generate token") log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"code": "GENERATE_TOKEN_FAILED",
"message": "Failed to generate token",
},
})
return return
} }
@ -74,14 +109,22 @@ func (h *Handler) Register(c *gin.Context) {
func (h *Handler) Login(c *gin.Context) { func (h *Handler) Login(c *gin.Context) {
var req LoginRequest var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"code": "INVALID_REQUEST",
"message": err.Error(),
},
})
return return
} }
user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username) user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password", "error": gin.H{
"code": "INVALID_CREDENTIALS",
"message": "Invalid username or password",
},
}) })
return return
} }
@ -90,7 +133,10 @@ func (h *Handler) Login(c *gin.Context) {
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)) err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password", "error": gin.H{
"code": "INVALID_CREDENTIALS",
"message": "Invalid username or password",
},
}) })
return return
} }
@ -100,7 +146,10 @@ func (h *Handler) Login(c *gin.Context) {
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get user roles") log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get user roles", "error": gin.H{
"code": "GET_ROLES_FAILED",
"message": "Failed to get user roles",
},
}) })
return return
} }
@ -122,7 +171,10 @@ func (h *Handler) Login(c *gin.Context) {
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to generate token") log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to generate token", "error": gin.H{
"code": "GENERATE_TOKEN_FAILED",
"message": "Failed to generate token",
},
}) })
return return
} }

View file

@ -33,6 +33,15 @@ func (s *AuthHandlerTestSuite) SetupTest() {
JWT: config.JWTConfig{ JWT: config.JWTConfig{
Secret: "test-secret", Secret: "test-secret",
}, },
Auth: config.AuthConfig{
Registration: struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
}{
Enabled: true,
Message: "Registration is disabled",
},
},
}, s.service) }, s.service)
s.router = gin.New() s.router = gin.New()
} }
@ -45,6 +54,13 @@ func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite)) suite.Run(t, new(AuthHandlerTestSuite))
} }
type ErrorResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func (s *AuthHandlerTestSuite) TestRegister() { func (s *AuthHandlerTestSuite) TestRegister() {
testCases := []struct { testCases := []struct {
name string name string
@ -52,6 +68,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock func() setupMock func()
expectedStatus int expectedStatus int
expectedError string expectedError string
registration bool
}{ }{
{ {
name: "成功注册", name: "成功注册",
@ -74,6 +91,20 @@ func (s *AuthHandlerTestSuite) TestRegister() {
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
}, },
expectedStatus: http.StatusCreated, expectedStatus: http.StatusCreated,
registration: true,
},
{
name: "注册功能已禁用",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusForbidden,
expectedError: "Registration is disabled",
registration: false,
}, },
{ {
name: "无效的邮箱格式", name: "无效的邮箱格式",
@ -86,6 +117,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
registration: true,
}, },
{ {
name: "密码太短", name: "密码太短",
@ -98,6 +130,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag", expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
registration: true,
}, },
{ {
name: "无效的角色", name: "无效的角色",
@ -110,11 +143,15 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag", expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
registration: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
// 设置注册功能状态
s.handler.config.Auth.Registration.Enabled = tc.registration
// 设置 mock // 设置 mock
tc.setupMock() tc.setupMock()
@ -132,10 +169,10 @@ func (s *AuthHandlerTestSuite) TestRegister() {
// 验证响应 // 验证响应
s.Equal(tc.expectedStatus, w.Code) s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" { if tc.expectedError != "" {
var response map[string]string var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err) s.NoError(err)
s.Contains(response["error"], tc.expectedError) s.Contains(response.Error.Message, tc.expectedError)
} else { } else {
var response AuthResponse var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
@ -147,6 +184,8 @@ func (s *AuthHandlerTestSuite) TestRegister() {
} }
func (s *AuthHandlerTestSuite) TestLogin() { func (s *AuthHandlerTestSuite) TestLogin() {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
testCases := []struct { testCases := []struct {
name string name string
request LoginRequest request LoginRequest
@ -161,35 +200,28 @@ func (s *AuthHandlerTestSuite) TestLogin() {
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT(). s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID). GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{Name: "admin"}}, nil) Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "无效的用户名", name: "无效的用户名",
request: LoginRequest{ request: LoginRequest{
Username: "invalid", Username: "te",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {},
s.service.EXPECT(). expectedStatus: http.StatusBadRequest,
GetUserByUsername(gomock.Any(), "invalid"). expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
}, },
{ {
name: "用户不存在", name: "用户不存在",
@ -209,19 +241,16 @@ func (s *AuthHandlerTestSuite) TestLogin() {
name: "密码错误", name: "密码错误",
request: LoginRequest{ request: LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrong-password", Password: "wrongpassword",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password", expectedError: "Invalid username or password",
@ -233,18 +262,15 @@ func (s *AuthHandlerTestSuite) TestLogin() {
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByUsername(gomock.Any(), "testuser"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT(). s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID). GetUserRoles(gomock.Any(), 1).
Return(nil, fmt.Errorf("failed to get roles")) Return(nil, fmt.Errorf("failed to get roles"))
}, },
expectedStatus: http.StatusInternalServerError, expectedStatus: http.StatusInternalServerError,
@ -271,10 +297,10 @@ func (s *AuthHandlerTestSuite) TestLogin() {
// 验证响应 // 验证响应
s.Equal(tc.expectedStatus, w.Code) s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" { if tc.expectedError != "" {
var response map[string]string var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err) s.NoError(err)
s.Contains(response["error"], tc.expectedError) s.Contains(response.Error.Message, tc.expectedError)
} else { } else {
var response AuthResponse var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)

View file

@ -13,12 +13,14 @@ import (
type Handler struct { type Handler struct {
cfg *config.Config cfg *config.Config
config *config.Config
service service.Service service service.Service
} }
func NewHandler(cfg *config.Config, service service.Service) *Handler { func NewHandler(cfg *config.Config, service service.Service) *Handler {
return &Handler{ return &Handler{
cfg: cfg, cfg: cfg,
config: cfg,
service: service, service: service,
} }
} }