[feature/backend] implement /users handler + switch to username + add display name + user management cli
This commit is contained in:
parent
1d712d4e6c
commit
86ab334bc9
38 changed files with 1851 additions and 506 deletions
|
@ -7,16 +7,18 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Role string `json:"role" binding:"required,oneof=admin editor contributor"`
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Username string `json:"username" binding:"required,min=3,max=32"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
|
@ -31,7 +33,7 @@ func (h *Handler) Register(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role)
|
||||
user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
|
@ -76,14 +78,20 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email)
|
||||
user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid username or password",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
|
||||
// 验证密码
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Invalid username or password",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -91,7 +99,9 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user roles")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to get user roles",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -111,7 +121,9 @@ func (h *Handler) Login(c *gin.Context) {
|
|||
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to generate token")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to generate token",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -3,10 +3,11 @@ package handler
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/service/mock"
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AuthHandlerTestSuite struct {
|
||||
|
@ -54,20 +56,21 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
{
|
||||
name: "成功注册",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
},
|
||||
setupMock: func() {
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
}
|
||||
s.service.EXPECT().
|
||||
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor").
|
||||
Return(user, nil)
|
||||
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
|
||||
Return(&ent.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
}, nil)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
GetUserRoles(gomock.Any(), 1).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusCreated,
|
||||
|
@ -75,6 +78,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
{
|
||||
name: "无效的邮箱格式",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "invalid-email",
|
||||
Password: "password123",
|
||||
Role: "contributor",
|
||||
|
@ -86,6 +90,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
{
|
||||
name: "密码太短",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "short",
|
||||
Role: "contributor",
|
||||
|
@ -97,6 +102,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
|
|||
{
|
||||
name: "无效的角色",
|
||||
request: RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Password: "password123",
|
||||
Role: "invalid-role",
|
||||
|
@ -151,91 +157,95 @@ func (s *AuthHandlerTestSuite) TestLogin() {
|
|||
{
|
||||
name: "成功登录",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "password123").
|
||||
Return(true)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
|
||||
Return([]*ent.Role{{Name: "admin"}}, nil)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
name: "无效的用户名",
|
||||
request: LoginRequest{
|
||||
Email: "invalid-email",
|
||||
Username: "invalid",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByUsername(gomock.Any(), "invalid").
|
||||
Return(nil, fmt.Errorf("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
request: LoginRequest{
|
||||
Email: "nonexistent@example.com",
|
||||
Username: "nonexistent",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "nonexistent@example.com").
|
||||
Return(nil, errors.New("user not found"))
|
||||
GetUserByUsername(gomock.Any(), "nonexistent").
|
||||
Return(nil, fmt.Errorf("user not found"))
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid credentials",
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "密码错误",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Username: "testuser",
|
||||
Password: "wrong-password",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "wrong-password").
|
||||
Return(false)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedError: "Invalid credentials",
|
||||
expectedError: "Invalid username or password",
|
||||
},
|
||||
{
|
||||
name: "获取用户角色失败",
|
||||
request: LoginRequest{
|
||||
Email: "test@example.com",
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
},
|
||||
setupMock: func() {
|
||||
// 使用 bcrypt 生成正确的密码哈希
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user := &ent.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
PasswordHash: string(hashedPassword),
|
||||
}
|
||||
s.service.EXPECT().
|
||||
GetUserByEmail(gomock.Any(), "test@example.com").
|
||||
GetUserByUsername(gomock.Any(), "testuser").
|
||||
Return(user, nil)
|
||||
s.service.EXPECT().
|
||||
ValidatePassword(gomock.Any(), user, "password123").
|
||||
Return(true)
|
||||
s.service.EXPECT().
|
||||
GetUserRoles(gomock.Any(), user.ID).
|
||||
Return(nil, errors.New("failed to get roles"))
|
||||
Return(nil, fmt.Errorf("failed to get roles"))
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedError: "Failed to get user roles",
|
||||
|
|
|
@ -34,6 +34,18 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
|
|||
auth.POST("/login", h.Login)
|
||||
}
|
||||
|
||||
// User routes
|
||||
users := api.Group("/users")
|
||||
{
|
||||
users.GET("", h.ListUsers)
|
||||
users.POST("", h.CreateUser)
|
||||
users.GET("/:id", h.GetUser)
|
||||
users.PUT("/:id", h.UpdateUser)
|
||||
users.DELETE("/:id", h.DeleteUser)
|
||||
users.GET("/me", h.GetCurrentUser)
|
||||
users.PUT("/me", h.UpdateCurrentUser)
|
||||
}
|
||||
|
||||
// Category routes
|
||||
categories := api.Group("/categories")
|
||||
{
|
||||
|
|
227
backend/internal/handler/user.go
Normal file
227
backend/internal/handler/user.go
Normal file
|
@ -0,0 +1,227 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
type UpdateCurrentUserRequest struct {
|
||||
Email string `json:"email,omitempty" binding:"omitempty,email"`
|
||||
CurrentPassword string `json:"current_password,omitempty"`
|
||||
NewPassword string `json:"new_password,omitempty" binding:"omitempty,min=8"`
|
||||
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Role string `json:"role" binding:"required,oneof=admin editor"`
|
||||
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
|
||||
}
|
||||
|
||||
type UpdateUserRequest struct {
|
||||
Email string `json:"email,omitempty" binding:"omitempty,email"`
|
||||
Password string `json:"password,omitempty" binding:"omitempty,min=8"`
|
||||
Role string `json:"role,omitempty" binding:"omitempty,oneof=admin editor"`
|
||||
Status string `json:"status,omitempty" binding:"omitempty,oneof=active inactive"`
|
||||
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
|
||||
}
|
||||
|
||||
// ListUsers returns a list of users
|
||||
func (h *Handler) ListUsers(c *gin.Context) {
|
||||
// Parse query parameters
|
||||
params := &types.ListUsersParams{
|
||||
Page: 1,
|
||||
PerPage: 10,
|
||||
}
|
||||
|
||||
if page := c.Query("page"); page != "" {
|
||||
if p, err := strconv.Atoi(page); err == nil && p > 0 {
|
||||
params.Page = p
|
||||
}
|
||||
}
|
||||
|
||||
if perPage := c.Query("per_page"); perPage != "" {
|
||||
if pp, err := strconv.Atoi(perPage); err == nil && pp > 0 {
|
||||
params.PerPage = pp
|
||||
}
|
||||
}
|
||||
|
||||
params.Sort = c.Query("sort")
|
||||
params.Role = c.Query("role")
|
||||
params.Status = c.Query("status")
|
||||
params.Email = c.Query("email")
|
||||
|
||||
// Get users
|
||||
users, err := h.service.ListUsers(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to list users")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": users,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUser creates a new user
|
||||
func (h *Handler) CreateUser(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Email, req.Password, req.Role)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUser returns user details
|
||||
func (h *Handler) GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.GetUser(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUser updates user information
|
||||
func (h *Handler) UpdateUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.service.UpdateUser(c.Request.Context(), id, &types.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Role: req.Role,
|
||||
Status: req.Status,
|
||||
DisplayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user
|
||||
func (h *Handler) DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteUser(c.Request.Context(), id); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetCurrentUser returns the current user's information
|
||||
func (h *Handler) GetCurrentUser(c *gin.Context) {
|
||||
// 从上下文中获取用户ID(由认证中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := h.service.GetUser(c.Request.Context(), userID.(int))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user information"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": user,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCurrentUser updates the current user's information
|
||||
func (h *Handler) UpdateCurrentUser(c *gin.Context) {
|
||||
// 从上下文中获取用户ID(由认证中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateCurrentUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果要更新密码,需要验证当前密码
|
||||
if req.NewPassword != "" {
|
||||
if req.CurrentPassword == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is required to update password"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证当前密码
|
||||
if err := h.service.VerifyPassword(c.Request.Context(), userID.(int), req.CurrentPassword); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current password"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
user, err := h.service.UpdateUser(c.Request.Context(), userID.(int), &types.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.NewPassword,
|
||||
DisplayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update user")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user information"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": user,
|
||||
})
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -78,26 +79,41 @@ func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
|
|||
|
||||
// 配置文件日志
|
||||
if config.EnableFile {
|
||||
// 确保日志目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
|
||||
// 验证文件路径
|
||||
if config.FilePath == "" {
|
||||
return nil, fmt.Errorf("file path cannot be empty")
|
||||
}
|
||||
|
||||
// 验证路径是否包含无效字符
|
||||
if strings.ContainsAny(config.FilePath, "\x00") {
|
||||
return nil, fmt.Errorf("file path contains invalid characters")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(config.FilePath)
|
||||
|
||||
// 检查目录是否存在或是否可以创建
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
// 配置日志轮转
|
||||
// 尝试打开或创建文件,验证路径是否有效且有写入权限
|
||||
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open or create log file: %w", err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// 配置文件日志
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: config.FilePath,
|
||||
MaxSize: config.Rotation.MaxSize, // MB
|
||||
MaxAge: config.Rotation.MaxAge, // days
|
||||
MaxBackups: config.Rotation.MaxBackups, // files
|
||||
MaxBackups: config.Rotation.MaxBackups, // 文件个数
|
||||
MaxAge: config.Rotation.MaxAge, // 天数
|
||||
Compress: config.Rotation.Compress, // 是否压缩
|
||||
LocalTime: config.Rotation.LocalTime, // 使用本地时间
|
||||
}
|
||||
|
||||
logger := zerolog.New(logWriter).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
logger := zerolog.New(logWriter).With().Timestamp().Logger()
|
||||
fileLogger = &logger
|
||||
}
|
||||
|
||||
|
|
|
@ -219,7 +219,7 @@ func TestAccessLogInvalidConfig(t *testing.T) {
|
|||
name: "Invalid file path",
|
||||
config: &types.AccessLogConfig{
|
||||
EnableFile: true,
|
||||
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
|
||||
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/permission"
|
||||
"tss-rocks-be/ent/role"
|
||||
)
|
||||
|
||||
|
@ -38,37 +39,69 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
|
|||
permissionMap := make(map[string]*ent.Permission)
|
||||
for resource, actions := range DefaultPermissions {
|
||||
for _, action := range actions {
|
||||
permission, err := client.Permission.Create().
|
||||
SetResource(resource).
|
||||
SetAction(action).
|
||||
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating permission: %w", err)
|
||||
}
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
permission, err := client.Permission.Query().
|
||||
Where(
|
||||
permission.ResourceEQ(resource),
|
||||
permission.ActionEQ(action),
|
||||
).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
permission, err = client.Permission.Create().
|
||||
SetResource(resource).
|
||||
SetAction(action).
|
||||
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating permission: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed querying permission: %w", err)
|
||||
}
|
||||
permissionMap[key] = permission
|
||||
}
|
||||
}
|
||||
|
||||
// Create roles with permissions
|
||||
for roleName, permissions := range DefaultRoles {
|
||||
roleCreate := client.Role.Create().
|
||||
SetName(roleName).
|
||||
SetDescription(fmt.Sprintf("Role for %s users", roleName))
|
||||
role, err := client.Role.Query().
|
||||
Where(role.NameEQ(roleName)).
|
||||
Only(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
roleCreate := client.Role.Create().
|
||||
SetName(roleName).
|
||||
SetDescription(fmt.Sprintf("Role for %s users", roleName))
|
||||
|
||||
// Add permissions to role
|
||||
for resource, actions := range permissions {
|
||||
for _, action := range actions {
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
if permission, exists := permissionMap[key]; exists {
|
||||
roleCreate.AddPermissions(permission)
|
||||
// Add permissions to role
|
||||
for resource, actions := range permissions {
|
||||
for _, action := range actions {
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
if permission, exists := permissionMap[key]; exists {
|
||||
roleCreate.AddPermissions(permission)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := roleCreate.Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed creating role %s: %w", roleName, err)
|
||||
if _, err := roleCreate.Save(ctx); err != nil {
|
||||
return fmt.Errorf("failed creating role %s: %w", roleName, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed querying role: %w", err)
|
||||
} else {
|
||||
// Update existing role's permissions
|
||||
for resource, actions := range permissions {
|
||||
for _, action := range actions {
|
||||
key := fmt.Sprintf("%s:%s", resource, action)
|
||||
if permission, exists := permissionMap[key]; exists {
|
||||
err = client.Role.UpdateOne(role).
|
||||
AddPermissions(permission).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed updating role %s permissions: %w", roleName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ func TestAssignRoleToUser(t *testing.T) {
|
|||
// Create a test user
|
||||
user, err := client.User.Create().
|
||||
SetEmail("test@example.com").
|
||||
SetUsername("testuser").
|
||||
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,31 +1,27 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"context"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/config"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// NewEntClient creates a new ent client
|
||||
func NewEntClient(cfg *config.Config) *ent.Client {
|
||||
// TODO: Implement database connection based on config
|
||||
// For now, we'll use SQLite for development
|
||||
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect to database")
|
||||
}
|
||||
// 使用配置文件中的数据库设置
|
||||
client, err := ent.Open(cfg.Database.Driver, cfg.Database.DSN)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect to database")
|
||||
}
|
||||
|
||||
// Create ent client
|
||||
client := ent.NewClient(ent.Driver(db))
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(context.Background()); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create schema resources")
|
||||
}
|
||||
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(context.Background()); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create schema resources")
|
||||
}
|
||||
|
||||
return client
|
||||
return client
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"tss-rocks-be/ent/role"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -54,59 +53,74 @@ func NewService(client *ent.Client, storage storage.Storage) Service {
|
|||
}
|
||||
|
||||
// User operations
|
||||
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
|
||||
// Hash the password
|
||||
func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
|
||||
// 验证邮箱格式
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
if !emailRegex.MatchString(email) {
|
||||
return nil, fmt.Errorf("invalid email format")
|
||||
}
|
||||
|
||||
// 验证密码长度
|
||||
if len(password) < 8 {
|
||||
return nil, fmt.Errorf("password must be at least 8 characters")
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error checking username: %v", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("username '%s' already exists", username)
|
||||
}
|
||||
|
||||
// 生成密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
return nil, fmt.Errorf("error hashing password: %v", err)
|
||||
}
|
||||
|
||||
// Add the user role by default
|
||||
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user role: %w", err)
|
||||
}
|
||||
|
||||
// If a specific role is requested and it's not "user", get that role too
|
||||
var additionalRole *ent.Role
|
||||
if roleStr != "" && roleStr != "user" {
|
||||
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create user with password and user role
|
||||
userCreate := s.client.User.Create().
|
||||
// 创建用户
|
||||
u, err := s.client.User.Create().
|
||||
SetUsername(username).
|
||||
SetEmail(email).
|
||||
SetPasswordHash(string(hashedPassword)).
|
||||
AddRoles(userRole)
|
||||
SetStatus("active").
|
||||
Save(ctx)
|
||||
|
||||
// Add the additional role if specified
|
||||
if additionalRole != nil {
|
||||
userCreate.AddRoles(additionalRole)
|
||||
}
|
||||
|
||||
// Save the user
|
||||
user, err := userCreate.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
return nil, fmt.Errorf("error creating user: %v", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
// 分配角色
|
||||
err = s.AssignRole(ctx, u.ID, roleStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error assigning role: %v", err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
|
||||
u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
return nil, fmt.Errorf("error getting user: %v", err)
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
|
||||
user, err := s.client.User.Query().
|
||||
Where(user.EmailEQ(email)).
|
||||
Only(ctx)
|
||||
u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, fmt.Errorf("user not found: %s", email)
|
||||
return nil, fmt.Errorf("user with email '%s' not found", email)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
return nil, fmt.Errorf("error getting user: %v", err)
|
||||
}
|
||||
return user, nil
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
|
||||
|
|
|
@ -103,52 +103,50 @@ func newMockMultipartFile(data []byte) *mockMultipartFile {
|
|||
|
||||
func (s *ServiceImplTestSuite) TestCreateUser() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
role string
|
||||
wantError bool
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
role string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid user creation",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
role: "admin",
|
||||
wantError: false,
|
||||
name: "有效的用户",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
role: "user",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty email",
|
||||
email: "",
|
||||
password: "password123",
|
||||
role: "user",
|
||||
wantError: true,
|
||||
name: "无效的邮箱",
|
||||
username: "testuser2",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
role: "user",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Empty password",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
role: "user",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid role",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
role: "invalid_role",
|
||||
wantError: true,
|
||||
name: "空密码",
|
||||
username: "testuser3",
|
||||
email: "test3@example.com",
|
||||
password: "",
|
||||
role: "user",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role)
|
||||
if tc.wantError {
|
||||
assert.Error(s.T(), err)
|
||||
assert.Nil(s.T(), user)
|
||||
user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
|
||||
if tc.wantErr {
|
||||
s.Error(err)
|
||||
s.Nil(user)
|
||||
} else {
|
||||
assert.NoError(s.T(), err)
|
||||
assert.NotNil(s.T(), user)
|
||||
assert.Equal(s.T(), tc.email, user.Email)
|
||||
s.NoError(err)
|
||||
s.NotNil(user)
|
||||
s.Equal(tc.email, user.Email)
|
||||
s.Equal(tc.username, user.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() {
|
|||
password := "password123"
|
||||
role := "user"
|
||||
|
||||
user, err := s.svc.CreateUser(s.ctx, email, password, role)
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
|
||||
require.NoError(s.T(), err)
|
||||
require.NotNil(s.T(), user)
|
||||
|
||||
|
@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
|
|||
password := "password123"
|
||||
role := "user"
|
||||
|
||||
user, err := s.svc.CreateUser(s.ctx, email, password, role)
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
|
||||
require.NoError(s.T(), err)
|
||||
require.NotNil(s.T(), user)
|
||||
|
||||
|
@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
|
|||
|
||||
func (s *ServiceImplTestSuite) TestRBAC() {
|
||||
s.Run("AssignRole", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
err = s.svc.AssignRole(s.ctx, user.ID, "user")
|
||||
|
@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("RemoveRole", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
|
||||
|
@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
|
||||
s.Run("HasPermission", func() {
|
||||
s.Run("Admin can create users", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||||
|
@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("Editor cannot create users", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||||
|
@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("User cannot create users", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
|
||||
|
@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("Editor can create posts", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
|
||||
|
@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("User can read posts", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
|
||||
|
@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("User cannot create posts", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
|
||||
|
@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
})
|
||||
|
||||
s.Run("Invalid permission format", func() {
|
||||
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
|
||||
|
@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
|
|||
|
||||
func (s *ServiceImplTestSuite) TestCategory() {
|
||||
// Create a test user with admin role for testing
|
||||
adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin")
|
||||
adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
|
||||
require.NoError(s.T(), err)
|
||||
require.NotNil(s.T(), adminUser)
|
||||
|
||||
|
@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() {
|
|||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户,默认会有 "user" 角色
|
||||
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user")
|
||||
user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
|
||||
s.Require().NoError(err)
|
||||
|
||||
// 测试新用户有默认的 "user" 角色
|
||||
|
@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() {
|
|||
func (s *ServiceImplTestSuite) TestMedia() {
|
||||
s.Run("Upload Media", func() {
|
||||
// Create a user first
|
||||
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "")
|
||||
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
|
||||
require.NoError(s.T(), err)
|
||||
require.NotNil(s.T(), user)
|
||||
|
||||
|
@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
|
|||
|
||||
s.Run("Delete Media - Unauthorized", func() {
|
||||
// Create a user
|
||||
user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "")
|
||||
user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Mock file content
|
||||
|
@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
|
|||
require.NoError(s.T(), err)
|
||||
|
||||
// Try to delete with different user
|
||||
anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "")
|
||||
anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)
|
||||
|
|
|
@ -9,14 +9,22 @@ import (
|
|||
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/internal/storage"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// Service interface defines all business logic operations
|
||||
type Service interface {
|
||||
// User operations
|
||||
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
|
||||
CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error)
|
||||
GetUser(ctx context.Context, id int) (*ent.User, error)
|
||||
GetUserByUsername(ctx context.Context, username string) (*ent.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
|
||||
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
|
||||
VerifyPassword(ctx context.Context, userID int, password string) error
|
||||
UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error)
|
||||
DeleteUser(ctx context.Context, userID int) error
|
||||
ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error)
|
||||
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
|
||||
|
||||
// Category operations
|
||||
CreateCategory(ctx context.Context) (*ent.Category, error)
|
||||
|
@ -51,9 +59,8 @@ type Service interface {
|
|||
DeleteMedia(ctx context.Context, id int, userID int) error
|
||||
|
||||
// RBAC operations
|
||||
InitializeRBAC(ctx context.Context) error
|
||||
AssignRole(ctx context.Context, userID int, role string) error
|
||||
RemoveRole(ctx context.Context, userID int, role string) error
|
||||
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
|
||||
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
|
||||
InitializeRBAC(ctx context.Context) error
|
||||
}
|
||||
|
|
154
backend/internal/service/user.go
Normal file
154
backend/internal/service/user.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"tss-rocks-be/ent"
|
||||
"tss-rocks-be/ent/role"
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/internal/types"
|
||||
)
|
||||
|
||||
// GetUser gets a user by ID
|
||||
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
|
||||
user, err := s.client.User.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies the user's current password
|
||||
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
|
||||
user, err := s.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.ValidatePassword(ctx, user, password) {
|
||||
return fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUser updates user information
|
||||
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
|
||||
// Start building the update
|
||||
update := s.client.User.UpdateOneID(userID)
|
||||
|
||||
// Update email if provided
|
||||
if input.Email != "" {
|
||||
update.SetEmail(input.Email)
|
||||
}
|
||||
|
||||
// Update display name if provided
|
||||
if input.DisplayName != "" {
|
||||
update.SetDisplayName(input.DisplayName)
|
||||
}
|
||||
|
||||
// Update password if provided
|
||||
if input.Password != "" {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to hash password")
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
update.SetPasswordHash(string(hashedPassword))
|
||||
}
|
||||
|
||||
// Update status if provided
|
||||
if input.Status != "" {
|
||||
update.SetStatus(user.Status(input.Status))
|
||||
}
|
||||
|
||||
// Execute the update
|
||||
user, err := update.Save(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update user")
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
// Update role if provided
|
||||
if input.Role != "" {
|
||||
// Clear existing roles
|
||||
_, err = user.Update().ClearRoles().Save(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to clear user roles")
|
||||
return nil, fmt.Errorf("failed to update user roles: %w", err)
|
||||
}
|
||||
|
||||
// Add new role
|
||||
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to find role")
|
||||
return nil, fmt.Errorf("failed to find role: %w", err)
|
||||
}
|
||||
|
||||
_, err = user.Update().AddRoles(role).Save(ctx)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to add user role")
|
||||
return nil, fmt.Errorf("failed to update user roles: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user by ID
|
||||
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
|
||||
err := s.client.User.DeleteOneID(userID).Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListUsers lists users with filters and pagination
|
||||
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
|
||||
query := s.client.User.Query()
|
||||
|
||||
// Apply filters
|
||||
if params.Role != "" {
|
||||
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
|
||||
}
|
||||
if params.Status != "" {
|
||||
query.Where(user.StatusEQ(user.Status(params.Status)))
|
||||
}
|
||||
if params.Email != "" {
|
||||
query.Where(user.EmailContains(params.Email))
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if params.PerPage > 0 {
|
||||
query.Limit(params.PerPage)
|
||||
if params.Page > 0 {
|
||||
query.Offset((params.Page - 1) * params.PerPage)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
if params.Sort != "" {
|
||||
switch params.Sort {
|
||||
case "email_asc":
|
||||
query.Order(ent.Asc(user.FieldEmail))
|
||||
case "email_desc":
|
||||
query.Order(ent.Desc(user.FieldEmail))
|
||||
case "created_at_asc":
|
||||
query.Order(ent.Asc(user.FieldCreatedAt))
|
||||
case "created_at_desc":
|
||||
query.Order(ent.Desc(user.FieldCreatedAt))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
users, err := query.All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
28
backend/internal/types/user.go
Normal file
28
backend/internal/types/user.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package types
|
||||
|
||||
// UpdateUserInput defines the input for updating a user
|
||||
type UpdateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Role string
|
||||
Status string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
// ListUsersParams defines the parameters for listing users
|
||||
type ListUsersParams struct {
|
||||
Page int
|
||||
PerPage int
|
||||
Sort string
|
||||
Role string
|
||||
Status string
|
||||
Email string
|
||||
}
|
||||
|
||||
// CreateUserInput defines the input for creating a user
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
Role string
|
||||
DisplayName string
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue