[feature/backend] implement /users handler + switch to username + add display name + user management cli

This commit is contained in:
CDN 2025-02-21 04:30:07 +08:00
parent 1d712d4e6c
commit 86ab334bc9
Signed by: CDN
GPG key ID: 0C656827F9F80080
38 changed files with 1851 additions and 506 deletions

View file

@ -24,7 +24,6 @@ import (
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@ -54,59 +53,74 @@ func NewService(client *ent.Client, storage storage.Storage) Service {
}
// User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
// Hash the password
func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
// 验证邮箱格式
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if !emailRegex.MatchString(email) {
return nil, fmt.Errorf("invalid email format")
}
// 验证密码长度
if len(password) < 8 {
return nil, fmt.Errorf("password must be at least 8 characters")
}
// 检查用户名是否已存在
exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
if err != nil {
return nil, fmt.Errorf("error checking username: %v", err)
}
if exists {
return nil, fmt.Errorf("username '%s' already exists", username)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
return nil, fmt.Errorf("error hashing password: %v", err)
}
// Add the user role by default
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
// 创建用户
u, err := s.client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
AddRoles(userRole)
SetStatus("active").
Save(ctx)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
return nil, fmt.Errorf("error creating user: %v", err)
}
return user, nil
// 分配角色
err = s.AssignRole(ctx, u.ID, roleStr)
if err != nil {
return nil, fmt.Errorf("error assigning role: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user with username '%s' not found", username)
}
return nil, fmt.Errorf("error getting user: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query().
Where(user.EmailEQ(email)).
Only(ctx)
u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email)
return nil, fmt.Errorf("user with email '%s' not found", email)
}
return nil, fmt.Errorf("failed to get user: %w", err)
return nil, fmt.Errorf("error getting user: %v", err)
}
return user, nil
return u, nil
}
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {

View file

@ -103,52 +103,50 @@ func newMockMultipartFile(data []byte) *mockMultipartFile {
func (s *ServiceImplTestSuite) TestCreateUser() {
testCases := []struct {
name string
email string
password string
role string
wantError bool
name string
username string
email string
password string
role string
wantErr bool
}{
{
name: "Valid user creation",
email: "test@example.com",
password: "password123",
role: "admin",
wantError: false,
name: "有效的用户",
username: "testuser",
email: "test@example.com",
password: "password123",
role: "user",
wantErr: false,
},
{
name: "Empty email",
email: "",
password: "password123",
role: "user",
wantError: true,
name: "无效的邮箱",
username: "testuser2",
email: "invalid-email",
password: "password123",
role: "user",
wantErr: true,
},
{
name: "Empty password",
email: "test@example.com",
password: "",
role: "user",
wantError: true,
},
{
name: "Invalid role",
email: "test@example.com",
password: "password123",
role: "invalid_role",
wantError: true,
name: "空密码",
username: "testuser3",
email: "test3@example.com",
password: "",
role: "user",
wantErr: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role)
if tc.wantError {
assert.Error(s.T(), err)
assert.Nil(s.T(), user)
user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
if tc.wantErr {
s.Error(err)
s.Nil(user)
} else {
assert.NoError(s.T(), err)
assert.NotNil(s.T(), user)
assert.Equal(s.T(), tc.email, user.Email)
s.NoError(err)
s.NotNil(user)
s.Equal(tc.email, user.Email)
s.Equal(tc.username, user.Username)
}
})
}
@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() {
password := "password123"
role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role)
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
password := "password123"
role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role)
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("AssignRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
require.NoError(s.T(), err)
err = s.svc.AssignRole(s.ctx, user.ID, "user")
@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("RemoveRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
require.NoError(s.T(), err)
err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("HasPermission", func() {
s.Run("Admin can create users", func() {
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Editor cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor")
user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Editor can create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor")
user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User can read posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User cannot create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Invalid permission format", func() {
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
require.NoError(s.T(), err)
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
func (s *ServiceImplTestSuite) TestCategory() {
// Create a test user with admin role for testing
adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin")
adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
require.NoError(s.T(), err)
require.NotNil(s.T(), adminUser)
@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() {
ctx := context.Background()
// 创建测试用户,默认会有 "user" 角色
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user")
user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
s.Require().NoError(err)
// 测试新用户有默认的 "user" 角色
@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() {
func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Upload Media", func() {
// Create a user first
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "")
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Delete Media - Unauthorized", func() {
// Create a user
user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "")
user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
require.NoError(s.T(), err)
// Mock file content
@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
require.NoError(s.T(), err)
// Try to delete with different user
anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "")
anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
require.NoError(s.T(), err)
err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)

View file

@ -9,14 +9,22 @@ import (
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/types"
)
// Service interface defines all business logic operations
type Service interface {
// User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error)
GetUser(ctx context.Context, id int) (*ent.User, error)
GetUserByUsername(ctx context.Context, username string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
VerifyPassword(ctx context.Context, userID int, password string) error
UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error)
DeleteUser(ctx context.Context, userID int) error
ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error)
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
// Category operations
CreateCategory(ctx context.Context) (*ent.Category, error)
@ -51,9 +59,8 @@ type Service interface {
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations
InitializeRBAC(ctx context.Context) error
AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
}

View file

@ -0,0 +1,154 @@
package service
import (
"context"
"fmt"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/types"
)
// GetUser gets a user by ID
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
user, err := s.client.User.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// VerifyPassword verifies the user's current password
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
user, err := s.GetUser(ctx, userID)
if err != nil {
return err
}
if !s.ValidatePassword(ctx, user, password) {
return fmt.Errorf("invalid password")
}
return nil
}
// UpdateUser updates user information
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
// Start building the update
update := s.client.User.UpdateOneID(userID)
// Update email if provided
if input.Email != "" {
update.SetEmail(input.Email)
}
// Update display name if provided
if input.DisplayName != "" {
update.SetDisplayName(input.DisplayName)
}
// Update password if provided
if input.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
log.Error().Err(err).Msg("Failed to hash password")
return nil, fmt.Errorf("failed to hash password: %w", err)
}
update.SetPasswordHash(string(hashedPassword))
}
// Update status if provided
if input.Status != "" {
update.SetStatus(user.Status(input.Status))
}
// Execute the update
user, err := update.Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
return nil, fmt.Errorf("failed to update user: %w", err)
}
// Update role if provided
if input.Role != "" {
// Clear existing roles
_, err = user.Update().ClearRoles().Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to clear user roles")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
// Add new role
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to find role")
return nil, fmt.Errorf("failed to find role: %w", err)
}
_, err = user.Update().AddRoles(role).Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to add user role")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
}
return user, nil
}
// DeleteUser deletes a user by ID
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
err := s.client.User.DeleteOneID(userID).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// ListUsers lists users with filters and pagination
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
query := s.client.User.Query()
// Apply filters
if params.Role != "" {
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
}
if params.Status != "" {
query.Where(user.StatusEQ(user.Status(params.Status)))
}
if params.Email != "" {
query.Where(user.EmailContains(params.Email))
}
// Apply pagination
if params.PerPage > 0 {
query.Limit(params.PerPage)
if params.Page > 0 {
query.Offset((params.Page - 1) * params.PerPage)
}
}
// Apply sorting
if params.Sort != "" {
switch params.Sort {
case "email_asc":
query.Order(ent.Asc(user.FieldEmail))
case "email_desc":
query.Order(ent.Desc(user.FieldEmail))
case "created_at_asc":
query.Order(ent.Asc(user.FieldCreatedAt))
case "created_at_desc":
query.Order(ent.Desc(user.FieldCreatedAt))
}
}
// Execute query
users, err := query.All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
return users, nil
}