// tss-rocks/backend/internal/handler/auth_handler_test.go
//
// 312 lines
// 7.7 KiB
// Go

package handler
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
)
// AuthHandlerTestSuite groups the tests for the authentication handlers
// (Register and Login) and carries the fixtures they share.
type AuthHandlerTestSuite struct {
	suite.Suite

	// ctrl is the gomock controller that owns the expectations recorded on service.
	ctrl *gomock.Controller
	// service is the mocked service handed to the handler under test.
	service *mock.MockService
	// handler is the Handler whose Register and Login methods are exercised.
	handler *Handler
	// router is a gin engine created in SetupTest.
	router *gin.Engine
}
// SetupTest builds fresh fixtures: a gomock controller bound to the current
// test, a mock service, a Handler wired to a test configuration, and a new
// gin engine.
func (s *AuthHandlerTestSuite) SetupTest() {
	s.ctrl = gomock.NewController(s.T())
	s.service = mock.NewMockService(s.ctrl)

	// Start from the zero-value config and fill in only what the auth
	// handlers are configured with here.
	cfg := &config.Config{}
	cfg.JWT.Secret = "test-secret"
	// Registration starts enabled; TestRegister flips the flag per case.
	cfg.Auth.Registration.Enabled = true
	cfg.Auth.Registration.Message = "Registration is disabled"

	s.handler = NewHandler(cfg, s.service)
	s.router = gin.New()
}
// TearDownTest finishes the gomock controller created in SetupTest.
func (s *AuthHandlerTestSuite) TearDownTest() {
	s.ctrl.Finish()
}
func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite))
}
// ErrorResponse mirrors the JSON error envelope the tests decode from handler
// responses: {"error": {"code": "...", "message": "..."}}.
type ErrorResponse struct {
	// Error holds the error details nested under the "error" key.
	Error struct {
		// Code is decoded from the "code" key.
		Code string `json:"code"`
		// Message is decoded from the "message" key; the tests match against it.
		Message string `json:"message"`
	} `json:"error"`
}
// TestRegister drives Handler.Register through a table of requests.
//
// Each case sets the registration switch on the handler's config, records its
// mock expectations, posts the JSON-encoded request to the handler through a
// gin test context, and then checks the HTTP status. Cases with a non-empty
// expectedError additionally require the decoded error message to contain that
// text; the remaining cases require a non-empty token in the AuthResponse.
func (s *AuthHandlerTestSuite) TestRegister() {
	testCases := []struct {
		name           string
		request        RegisterRequest
		setupMock      func()
		expectedStatus int
		expectedError  string
		registration   bool
	}{
		{
			name: "成功注册",
			request: RegisterRequest{
				Username: "testuser",
				Email:    "test@example.com",
				Password: "password123",
				Role:     "contributor",
			},
			setupMock: func() {
				s.service.EXPECT().
					CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
					Return(&ent.User{
						ID:       1,
						Username: "testuser",
						Email:    "test@example.com",
					}, nil)
				s.service.EXPECT().
					GetUserRoles(gomock.Any(), 1).
					Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
			},
			expectedStatus: http.StatusCreated,
			registration:   true,
		},
		{
			name: "注册功能已禁用",
			request: RegisterRequest{
				Username: "testuser",
				Email:    "test@example.com",
				Password: "password123",
				Role:     "contributor",
			},
			setupMock:      func() {},
			expectedStatus: http.StatusForbidden,
			expectedError:  "Registration is disabled",
			registration:   false,
		},
		{
			name: "无效的邮箱格式",
			request: RegisterRequest{
				Username: "testuser",
				Email:    "invalid-email",
				Password: "password123",
				Role:     "contributor",
			},
			setupMock:      func() {},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
			registration:   true,
		},
		{
			name: "密码太短",
			request: RegisterRequest{
				Username: "testuser",
				Email:    "test@example.com",
				Password: "short",
				Role:     "contributor",
			},
			setupMock:      func() {},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
			registration:   true,
		},
		{
			name: "无效的角色",
			request: RegisterRequest{
				Username: "testuser",
				Email:    "test@example.com",
				Password: "password123",
				Role:     "invalid-role",
			},
			setupMock:      func() {},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
			registration:   true,
		},
	}

	for _, tc := range testCases {
		s.Run(tc.name, func() {
			// Set the registration switch for this case.
			s.handler.cfg.Auth.Registration.Enabled = tc.registration

			// Record the mock expectations.
			tc.setupMock()

			// Build the request. A failure here means the fixture itself is
			// broken, so stop the subtest instead of calling the handler with
			// a nil request or an empty body.
			reqBody, err := json.Marshal(tc.request)
			s.Require().NoError(err)
			req, err := http.NewRequest(http.MethodPost, "/register", bytes.NewBuffer(reqBody))
			s.Require().NoError(err)
			req.Header.Set("Content-Type", "application/json")

			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w) // the engine return value is not needed
			c.Request = req

			// Invoke the handler directly.
			s.handler.Register(c)

			// Verify the response.
			s.Equal(tc.expectedStatus, w.Code)
			if tc.expectedError != "" {
				var response ErrorResponse
				// Asserting on a zero-value struct after a failed decode would
				// only add noise, so a decode error ends the subtest.
				s.Require().NoError(json.Unmarshal(w.Body.Bytes(), &response))
				s.Contains(response.Error.Message, tc.expectedError)
				return
			}

			var response AuthResponse
			s.Require().NoError(json.Unmarshal(w.Body.Bytes(), &response))
			s.NotEmpty(response.Token)
		})
	}
}
// TestLogin drives Handler.Login through a table of requests.
//
// A bcrypt hash of "password123" is generated once and reused as the stored
// password hash of the mocked user. Each case records its mock expectations,
// posts the JSON-encoded request to the handler through a gin test context,
// and then checks the HTTP status. Cases with a non-empty expectedError
// additionally require the decoded error message to contain that text; the
// remaining cases require a non-empty token in the AuthResponse.
func (s *AuthHandlerTestSuite) TestLogin() {
	// bcrypt.MinCost keeps the test fast: the hash only has to verify against
	// the plaintext, it does not need production-strength work.
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.MinCost)
	// Without a valid hash every password comparison below would be
	// meaningless, so abort the whole test on failure.
	s.Require().NoError(err)

	testCases := []struct {
		name           string
		request        LoginRequest
		setupMock      func()
		expectedStatus int
		expectedError  string
	}{
		{
			name: "成功登录",
			request: LoginRequest{
				Username: "testuser",
				Password: "password123",
			},
			setupMock: func() {
				s.service.EXPECT().
					GetUserByUsername(gomock.Any(), "testuser").
					Return(&ent.User{
						ID:           1,
						Username:     "testuser",
						PasswordHash: string(hashedPassword),
					}, nil)
				s.service.EXPECT().
					GetUserRoles(gomock.Any(), 1).
					Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
			},
			expectedStatus: http.StatusOK,
		},
		{
			name: "无效的用户名",
			request: LoginRequest{
				Username: "te",
				Password: "password123",
			},
			setupMock:      func() {},
			expectedStatus: http.StatusBadRequest,
			expectedError:  "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
		},
		{
			name: "用户不存在",
			request: LoginRequest{
				Username: "nonexistent",
				Password: "password123",
			},
			setupMock: func() {
				s.service.EXPECT().
					GetUserByUsername(gomock.Any(), "nonexistent").
					Return(nil, fmt.Errorf("user not found"))
			},
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Invalid username or password",
		},
		{
			name: "密码错误",
			request: LoginRequest{
				Username: "testuser",
				Password: "wrongpassword",
			},
			setupMock: func() {
				s.service.EXPECT().
					GetUserByUsername(gomock.Any(), "testuser").
					Return(&ent.User{
						ID:           1,
						Username:     "testuser",
						PasswordHash: string(hashedPassword),
					}, nil)
			},
			expectedStatus: http.StatusUnauthorized,
			expectedError:  "Invalid username or password",
		},
		{
			name: "获取用户角色失败",
			request: LoginRequest{
				Username: "testuser",
				Password: "password123",
			},
			setupMock: func() {
				s.service.EXPECT().
					GetUserByUsername(gomock.Any(), "testuser").
					Return(&ent.User{
						ID:           1,
						Username:     "testuser",
						PasswordHash: string(hashedPassword),
					}, nil)
				s.service.EXPECT().
					GetUserRoles(gomock.Any(), 1).
					Return(nil, fmt.Errorf("failed to get roles"))
			},
			expectedStatus: http.StatusInternalServerError,
			expectedError:  "Failed to get user roles",
		},
	}

	for _, tc := range testCases {
		s.Run(tc.name, func() {
			// Record the mock expectations.
			tc.setupMock()

			// Build the request. A failure here means the fixture itself is
			// broken, so stop the subtest instead of calling the handler with
			// a nil request or an empty body.
			reqBody, err := json.Marshal(tc.request)
			s.Require().NoError(err)
			req, err := http.NewRequest(http.MethodPost, "/login", bytes.NewBuffer(reqBody))
			s.Require().NoError(err)
			req.Header.Set("Content-Type", "application/json")

			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w) // the engine return value is not needed
			c.Request = req

			// Invoke the handler directly.
			s.handler.Login(c)

			// Verify the response.
			s.Equal(tc.expectedStatus, w.Code)
			if tc.expectedError != "" {
				var response ErrorResponse
				// Asserting on a zero-value struct after a failed decode would
				// only add noise, so a decode error ends the subtest.
				s.Require().NoError(json.Unmarshal(w.Body.Bytes(), &response))
				s.Contains(response.Error.Message, tc.expectedError)
				return
			}

			var response AuthResponse
			s.Require().NoError(json.Unmarshal(w.Body.Bytes(), &response))
			s.NotEmpty(response.Token)
		})
	}
}