[feature/backend] registration control
This commit is contained in:
parent
86ab334bc9
commit
a853374009
5 changed files with 142 additions and 49 deletions
|
@ -15,6 +15,11 @@ jwt:
|
|||
secret: your-jwt-secret-here # 在生产环境中应该使用环境变量
|
||||
expiration: 24h
|
||||
|
||||
auth:
|
||||
registration:
|
||||
enabled: false # 是否允许注册
|
||||
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
|
||||
|
||||
storage:
|
||||
driver: local
|
||||
local:
|
||||
|
|
|
@ -11,6 +11,7 @@ type Config struct {
|
|||
Database DatabaseConfig `yaml:"database"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
JWT JWTConfig `yaml:"jwt"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
RateLimit types.RateLimitConfig `yaml:"rate_limit"`
|
||||
|
@ -32,6 +33,13 @@ type JWTConfig struct {
|
|||
Expiration string `yaml:"expiration"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Registration struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Message string `yaml:"message"`
|
||||
} `yaml:"registration"`
|
||||
}
|
||||
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Format string `yaml:"format"`
|
||||
|
|
|
@ -27,16 +27,41 @@ type AuthResponse struct {
|
|||
}
|
||||
|
||||
func (h *Handler) Register(c *gin.Context) {
|
||||
// 检查是否启用注册功能
|
||||
if !h.config.Auth.Registration.Enabled {
|
||||
message := h.config.Auth.Registration.Message
|
||||
if message == "" {
|
||||
message = "Registration is currently disabled"
|
||||
}
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "REGISTRATION_DISABLED",
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "INVALID_REQUEST",
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "CREATE_USER_FAILED",
|
||||
"message": "Failed to create user",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -44,7 +69,12 @@ func (h *Handler) Register(c *gin.Context) {
|
|||
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user roles")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "GET_ROLES_FAILED",
|
||||
"message": "Failed to get user roles",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -64,7 +94,12 @@ func (h *Handler) Register(c *gin.Context) {
|
|||
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate token")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "GENERATE_TOKEN_FAILED",
|
||||
"message": "Failed to generate token",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -74,14 +109,22 @@ func (h *Handler) Register(c *gin.Context) {
|
|||
func (h *Handler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"code": "INVALID_REQUEST",
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid username or password",
|
||||
"error": gin.H{
|
||||
"code": "INVALID_CREDENTIALS",
|
||||
"message": "Invalid username or password",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -90,7 +133,10 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid username or password",
|
||||
"error": gin.H{
|
||||
"code": "INVALID_CREDENTIALS",
|
||||
"message": "Invalid username or password",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -100,7 +146,10 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user roles")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to get user roles",
|
||||
"error": gin.H{
|
||||
"code": "GET_ROLES_FAILED",
|
||||
"message": "Failed to get user roles",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -122,7 +171,10 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate token")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to generate token",
|
||||
"error": gin.H{
|
||||
"code": "GENERATE_TOKEN_FAILED",
|
||||
"message": "Failed to generate token",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -33,6 +33,15 @@ func (s *AuthHandlerTestSuite) SetupTest() {
|
|||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
Auth: config.AuthConfig{
|
||||
Registration: struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Message string `yaml:"message"`
|
||||
}{
|
||||
Enabled: true,
|
||||
Message: "Registration is disabled",
|
||||
},
|
||||
},
|
||||
}, s.service)
|
||||
s.router = gin.New()
|
||||
}
|
||||
|
@ -45,6 +54,13 @@ func TestAuthHandlerSuite(t *testing.T) {
|
|||
suite.Run(t, new(AuthHandlerTestSuite))
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestRegister() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
@ -52,6 +68,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
setupMock func()
|
||||
expectedStatus int
|
||||
expectedError string
|
||||
registration bool
|
||||
}{
|
||||
{
|
||||
name: "成功注册",
|
||||
|
@ -74,6 +91,20 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "注册功能已禁用",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedError: "Registration is disabled",
|
||||
registration: false,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
|
@ -86,6 +117,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "密码太短",
|
||||
|
@ -98,6 +130,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
|
||||
registration: true,
|
||||
},
|
||||
{
|
||||
name: "无效的角色",
|
||||
|
@ -110,11 +143,15 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
|
||||
registration: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
// 设置注册功能状态
|
||||
s.handler.config.Auth.Registration.Enabled = tc.registration
|
||||
|
||||
// 设置 mock
|
||||
tc.setupMock()
|
||||
|
||||
|
@ -132,10 +169,10 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response["error"], tc.expectedError)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
@ -147,6 +184,8 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
}
|
||||
|
||||
func (s *AuthHandlerTestSuite) TestLogin() {
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
request LoginRequest
|
||||
|
@ -161,35 +200,28 @@ func (s *AuthHandlerTestSuite) TestLogin() {
|
|||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return([]*ent.Role{{Name: "admin"}}, nil)
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的用户名",
|
||||
request: LoginRequest{
|
||||
Username: "invalid",
|
||||
Username: "te",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "invalid").
|
||||
Return(nil, fmt.Errorf("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
|
@ -209,19 +241,16 @@ func (s *AuthHandlerTestSuite) TestLogin() {
|
|||
name: "密码错误",
|
||||
request: LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrong-password",
|
||||
Password: "wrongpassword",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
|
@ -233,18 +262,15 @@ func (s *AuthHandlerTestSuite) TestLogin() {
|
|||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return(nil, fmt.Errorf("failed to get roles"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
|
@ -271,10 +297,10 @@ func (s *AuthHandlerTestSuite) TestLogin() {
|
|||
// 验证响应
|
||||
s.Equal(tc.expectedStatus, w.Code)
|
||||
if tc.expectedError != "" {
|
||||
var response map[string]string
|
||||
var response ErrorResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
s.NoError(err)
|
||||
s.Contains(response["error"], tc.expectedError)
|
||||
s.Contains(response.Error.Message, tc.expectedError)
|
||||
} else {
|
||||
var response AuthResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
|
|
|
@ -13,12 +13,14 @@ import (
|
|||
|
||||
type Handler struct {
|
||||
cfg *config.Config
|
||||
config *config.Config
|
||||
service service.Service
|
||||
}
|
||||
|
||||
func NewHandler(cfg *config.Config, service service.Service) *Handler {
|
||||
return &Handler{
|
||||
cfg: cfg,
|
||||
config: cfg,
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue