173 lines
4.3 KiB
Go
173 lines
4.3 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"tss-rocks-be/ent"
|
|
"tss-rocks-be/ent/role"
|
|
"tss-rocks-be/ent/user"
|
|
"tss-rocks-be/internal/types"
|
|
)
|
|
|
|
// GetUser gets a user by ID
|
|
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
|
|
user, err := s.client.User.Get(ctx, id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// VerifyPassword verifies the user's current password
|
|
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
|
|
user, err := s.GetUser(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !s.ValidatePassword(ctx, user, password) {
|
|
return fmt.Errorf("invalid password")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateUser updates user information
|
|
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
|
|
// Start building the update
|
|
update := s.client.User.UpdateOneID(userID)
|
|
|
|
// Update username if provided
|
|
if input.Username != "" {
|
|
// Check if username is already taken
|
|
exists, err := s.client.User.Query().
|
|
Where(user.And(
|
|
user.UsernameEQ(input.Username),
|
|
user.IDNEQ(userID),
|
|
)).
|
|
Exist(ctx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to check username availability")
|
|
return nil, fmt.Errorf("failed to check username availability: %w", err)
|
|
}
|
|
if exists {
|
|
return nil, fmt.Errorf("username already taken")
|
|
}
|
|
update.SetUsername(input.Username)
|
|
}
|
|
|
|
// Update email if provided
|
|
if input.Email != "" {
|
|
update.SetEmail(input.Email)
|
|
}
|
|
|
|
// Update display name if provided
|
|
if input.DisplayName != "" {
|
|
update.SetDisplayName(input.DisplayName)
|
|
}
|
|
|
|
// Update password if provided
|
|
if input.Password != "" {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to hash password")
|
|
return nil, fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
update.SetPasswordHash(string(hashedPassword))
|
|
}
|
|
|
|
// Update status if provided
|
|
if input.Status != "" {
|
|
update.SetStatus(user.Status(input.Status))
|
|
}
|
|
|
|
// Execute the update
|
|
user, err := update.Save(ctx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to update user")
|
|
return nil, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
// Update role if provided
|
|
if input.Role != "" {
|
|
// Clear existing roles
|
|
_, err = user.Update().ClearRoles().Save(ctx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to clear user roles")
|
|
return nil, fmt.Errorf("failed to update user roles: %w", err)
|
|
}
|
|
|
|
// Add new role
|
|
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to find role")
|
|
return nil, fmt.Errorf("failed to find role: %w", err)
|
|
}
|
|
|
|
_, err = user.Update().AddRoles(role).Save(ctx)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to add user role")
|
|
return nil, fmt.Errorf("failed to update user roles: %w", err)
|
|
}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// DeleteUser deletes a user by ID
|
|
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
|
|
err := s.client.User.DeleteOneID(userID).Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListUsers lists users with filters and pagination
|
|
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
|
|
query := s.client.User.Query()
|
|
|
|
// Apply filters
|
|
if params.Role != "" {
|
|
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
|
|
}
|
|
if params.Status != "" {
|
|
query.Where(user.StatusEQ(user.Status(params.Status)))
|
|
}
|
|
if params.Email != "" {
|
|
query.Where(user.EmailContains(params.Email))
|
|
}
|
|
|
|
// Apply pagination
|
|
if params.PerPage > 0 {
|
|
query.Limit(params.PerPage)
|
|
if params.Page > 0 {
|
|
query.Offset((params.Page - 1) * params.PerPage)
|
|
}
|
|
}
|
|
|
|
// Apply sorting
|
|
if params.Sort != "" {
|
|
switch params.Sort {
|
|
case "email_asc":
|
|
query.Order(ent.Asc(user.FieldEmail))
|
|
case "email_desc":
|
|
query.Order(ent.Desc(user.FieldEmail))
|
|
case "created_at_asc":
|
|
query.Order(ent.Asc(user.FieldCreatedAt))
|
|
case "created_at_desc":
|
|
query.Order(ent.Desc(user.FieldCreatedAt))
|
|
}
|
|
}
|
|
|
|
// Execute query
|
|
users, err := query.All(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list users: %w", err)
|
|
}
|
|
|
|
return users, nil
|
|
}
|