[feature/backend] implement /users handler + switch to username + add display name + user management cli

This commit is contained in:
CDN 2025-02-21 04:30:07 +08:00
parent 1d712d4e6c
commit 86ab334bc9
Signed by: CDN
GPG key ID: 0C656827F9F80080
38 changed files with 1851 additions and 506 deletions

View file

@ -60,12 +60,20 @@ User:
type: object
required:
- id
- email
- username
- role
- status
properties:
id:
type: integer
username:
type: string
minLength: 3
maxLength: 32
display_name:
type: string
maxLength: 64
description: 用户显示名称
email:
type: string
format: email
@ -74,11 +82,13 @@ User:
enum:
- admin
- editor
- contributor
status:
type: string
enum:
- active
- inactive
- banned
created_at:
type: string
format: date-time

View file

@ -70,6 +70,8 @@ components:
Daily:
$ref: './components/schemas.yaml#/Daily'
paths:
/auth/register:
$ref: './paths/auth.yaml#/register'
/auth/login:
$ref: './paths/auth.yaml#/login'
/auth/logout:

View file

@ -1,3 +1,63 @@
register:
post:
tags:
- auth
summary: 用户注册
operationId: register
security: [] # 注册接口不需要认证
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- username
- email
- password
- role
properties:
username:
type: string
minLength: 3
maxLength: 32
email:
type: string
format: email
password:
type: string
format: password
minLength: 8
role:
type: string
enum:
- admin
- editor
- contributor
responses:
'200':
description: 注册成功
content:
application/json:
schema:
type: object
required:
- token
- user
properties:
token:
type: string
user:
$ref: '../components/schemas.yaml#/User'
'400':
description: 用户名已存在
content:
application/json:
schema:
$ref: '../components/schemas.yaml#/Error'
'422':
$ref: '../components/responses.yaml#/ValidationError'
login:
post:
tags:
@ -12,12 +72,13 @@ login:
schema:
type: object
required:
- email
- username
- password
properties:
email:
username:
type: string
format: email
minLength: 3
maxLength: 32
password:
type: string
format: password

3
backend/.gitignore vendored
View file

@ -12,6 +12,9 @@ vendor/
# Build output
*.exe
# Config
config/config.yaml
# Database files
*.db
*.db-journal

View file

@ -5,7 +5,7 @@ RUN apk add --no-cache gcc musl-dev libwebp-dev
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN go build -o tss-rocks-be ./cmd/server
RUN go build -o tss-rocks-be
FROM alpine:latest
@ -13,9 +13,13 @@ RUN apk add --no-cache libwebp
RUN adduser -u 1000 -D tss-rocks
USER tss-rocks
WORKDIR /app
# 复制二进制文件和配置
COPY --from=builder /app/tss-rocks-be .
COPY --from=builder /app/config/config.yaml ./config/
EXPOSE 8080
ENV GIN_MODE=release
CMD ["./tss-rocks-be"]
# 启动服务器
CMD ["./tss-rocks-be", "server"]

89
backend/cmd/server.go Normal file
View file

@ -0,0 +1,89 @@
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func GetServerCmd() *cobra.Command {
var configPath string
cmd := &cobra.Command{
Use: "server",
Short: "Start the server",
Long: `Start the server with the specified configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
// Load configuration
cfg, err := config.Load(configPath)
if err != nil {
return err
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
return nil
},
}
// Add flags
cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file")
return cmd
}

View file

@ -1,79 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func main() {
// Parse command line flags
configPath := flag.String("config", "config/config.yaml", "path to config file")
flag.Parse()
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
}

View file

@ -1,127 +0,0 @@
package main
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
)
func TestConfigLoad(t *testing.T) {
// Create a temporary config file for testing
tmpConfig := `
database:
driver: sqlite3
dsn: ":memory:"
server:
port: 8080
host: localhost
storage:
type: local
local:
root_dir: "./testdata"
`
tmpFile, err := os.CreateTemp("", "config-*.yaml")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.WriteString(tmpConfig)
require.NoError(t, err)
err = tmpFile.Close()
require.NoError(t, err)
// Test config loading
cfg, err := config.Load(tmpFile.Name())
require.NoError(t, err)
assert.Equal(t, "sqlite3", cfg.Database.Driver)
assert.Equal(t, ":memory:", cfg.Database.DSN)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, "localhost", cfg.Server.Host)
assert.Equal(t, "local", cfg.Storage.Type)
assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir)
}
func TestServerCreation(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 8080,
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
// Create ent client
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
// Test schema creation
err := client.Schema.Create(context.Background())
require.NoError(t, err)
// Test server creation
srv, err := server.New(cfg, client)
require.NoError(t, err)
require.NotNil(t, srv)
}
func TestServerStartAndShutdown(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 0, // Use random available port
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
err := client.Schema.Create(context.Background())
require.NoError(t, err)
srv, err := server.New(cfg, client)
require.NoError(t, err)
// Start server in goroutine
go func() {
err := srv.Start()
if err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Test graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
assert.NoError(t, err)
}

411
backend/cmd/user.go Normal file
View file

@ -0,0 +1,411 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"tss-rocks-be/ent/user"
"tss-rocks-be/ent/role"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
func GetUserCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "user",
Short: "User management commands",
Long: `Commands for managing users, including create, delete, password, and role management.`,
}
configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file")
cmd.AddCommand(getCreateUserCmd(*configPath))
cmd.AddCommand(getPasswordCmd(*configPath))
cmd.AddCommand(getListUsersCmd(*configPath))
cmd.AddCommand(getDeleteUserCmd(*configPath))
cmd.AddCommand(getRoleCmd(*configPath))
return cmd
}
func getCreateUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "create",
Short: "Create a new user",
Long: `Create a new user with specified username, email and password.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
email, _ := cmd.Flags().GetString("email")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 检查用户名是否已存在
exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking username: %v", err)
}
if exists {
return fmt.Errorf("username '%s' already exists", username)
}
// 检查邮箱是否已存在
exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking email: %v", err)
}
if exists {
return fmt.Errorf("email '%s' already exists", email)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 创建用户
u, err := client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
SetStatus("active").
Save(cmd.Context())
if err != nil {
return fmt.Errorf("error creating user: %v", err)
}
fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username for the new user")
cmd.Flags().String("email", "", "email for the new user")
cmd.Flags().String("password", "", "password for the new user")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("email")
cmd.MarkFlagRequired("password")
return cmd
}
func getPasswordCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "password",
Short: "Change user password",
Long: `Change the password for an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 生成新的密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 更新密码
_, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context())
if err != nil {
return fmt.Errorf("error updating password: %v", err)
}
fmt.Printf("Successfully updated password for user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("password", "", "new password")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("password")
return cmd
}
func getListUsersCmd(configPath string) *cobra.Command {
return &cobra.Command{
Use: "list",
Short: "List all users",
Long: `List all users in the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 获取所有用户
users, err := client.User.Query().All(cmd.Context())
if err != nil {
return fmt.Errorf("error querying users: %v", err)
}
// 创建表格写入器
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT")
// 输出用户信息
for _, u := range users {
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
u.Username,
u.Email,
u.Status,
u.CreatedAt.Format("2006-01-02 15:04:05"))
}
w.Flush()
return nil
},
}
}
func getDeleteUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Long: `Delete an existing user from the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
force, _ := cmd.Flags().GetBool("force")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查用户的角色
hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRoles && !force {
return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username)
}
// 删除用户
err = client.User.DeleteOne(u).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error deleting user: %v", err)
}
fmt.Printf("Successfully deleted user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user to delete")
cmd.Flags().Bool("force", false, "force delete even if user has roles")
// 设置必需的参数
cmd.MarkFlagRequired("username")
return cmd
}
func getRoleCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "role",
Short: "User role management commands",
Long: `Commands for managing user roles, including adding and removing roles.`,
}
cmd.AddCommand(getRoleAddCmd(configPath))
cmd.AddCommand(getRoleRemoveCmd(configPath))
return cmd
}
func getRoleAddCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "add",
Short: "Add a role to a user",
Long: `Add a role to an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否已存在
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRole {
return fmt.Errorf("user '%s' already has role '%s'", username, roleName)
}
// 获取要添加的角色
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 添加角色
err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error adding role: %v", err)
}
fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to add (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}
func getRoleRemoveCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "remove",
Short: "Remove a role from a user",
Long: `Remove a role from an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否存在
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 检查用户是否有该角色
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if !hasRole {
return fmt.Errorf("user '%s' does not have role '%s'", username, roleName)
}
// 移除角色
err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error removing role: %v", err)
}
fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}

View file

@ -1,91 +0,0 @@
database:
driver: sqlite3
dsn: "file:tss-rocks.db?_fk=1&cache=shared"
server:
port: 8080
host: localhost
jwt:
secret: "your-secret-key-here" # 在生产环境中应该使用环境变量
expiration: 24h # token 过期时间
logging:
level: debug # debug, info, warn, error
format: json # json, console
storage:
type: s3 # local or s3
local:
root_dir: "./storage/media"
s3:
region: "us-east-1"
bucket: "your-bucket-name"
access_key_id: "your-access-key-id"
secret_access_key: "your-secret-access-key"
endpoint: "" # Optional, for MinIO or other S3-compatible services
custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
upload:
max_size: 10 # 最大文件大小MB
allowed_types: # 允许的文件类型
- image/jpeg
- image/png
- image/gif
- image/webp
- image/svg+xml
- application/pdf
- application/msword
- application/vnd.openxmlformats-officedocument.wordprocessingml.document
- application/vnd.ms-excel
- application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
- application/zip
- application/x-rar-compressed
- text/plain
- text/csv
allowed_extensions: # 允许的文件扩展名(小写)
- .jpg
- .jpeg
- .png
- .gif
- .webp
- .svg
- .pdf
- .doc
- .docx
- .xls
- .xlsx
- .zip
- .rar
- .txt
- .csv
rate_limit:
# IP限流配置
ip_rate: 50 # 每秒请求数
ip_burst: 25 # 突发请求数
# 路由限流配置
route_rates:
"/api/v1/auth/login":
rate: 5 # 每秒5个请求
burst: 10 # 突发10个请求
"/api/v1/auth/register":
rate: 2 # 每秒2个请求
burst: 5 # 突发5个请求
"/api/v1/media/upload":
rate: 10 # 每秒10个请求
burst: 20 # 突发20个请求
access_log:
enable_console: true # 启用控制台输出
enable_file: true # 启用文件日志
file_path: "./logs/access.log" # 日志文件路径
format: "json" # 日志格式 (json 或 text)
level: "info" # 日志级别
rotation:
max_size: 100 # 每个日志文件的最大大小MB
max_age: 30 # 保留旧日志文件的最大天数
max_backups: 10 # 保留的旧日志文件的最大数量
compress: true # 是否压缩旧日志文件
local_time: true # 使用本地时间作为轮转时间

View file

@ -0,0 +1,36 @@
database:
driver: sqlite3
# SQLite DSN 说明:
# - file:tss.db => 相对路径,在当前目录创建 tss.db
# - cache=shared => 启用共享缓存,提高性能
# - _fk=1 => 启用外键约束
# - mode=rwc => 如果数据库不存在则创建read-write-create
dsn: "file:tss.db?cache=shared&_fk=1&mode=rwc"
server:
port: 8080
host: localhost
jwt:
secret: your-jwt-secret-here # 在生产环境中应该使用环境变量
expiration: 24h
storage:
driver: local
local:
root: storage
base_url: http://localhost:8080/storage
logging:
level: debug
format: console
rate_limit:
enabled: true
requests: 100
duration: 1m
access_log:
enabled: true
format: combined
output: stdout

View file

@ -373,7 +373,9 @@ var (
// UsersColumns holds the columns for the "users" table.
UsersColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true},
{Name: "email", Type: field.TypeString, Unique: true},
{Name: "username", Type: field.TypeString, Unique: true},
{Name: "display_name", Type: field.TypeString, Nullable: true, Size: 64},
{Name: "email", Type: field.TypeString},
{Name: "password_hash", Type: field.TypeString},
{Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"},
{Name: "created_at", Type: field.TypeTime},

View file

@ -9321,6 +9321,8 @@ type UserMutation struct {
op Op
typ string
id *int
username *string
display_name *string
email *string
password_hash *string
status *user.Status
@ -9439,6 +9441,91 @@ func (m *UserMutation) IDs(ctx context.Context) ([]int, error) {
}
}
// SetUsername sets the "username" field.
func (m *UserMutation) SetUsername(s string) {
m.username = &s
}
// Username returns the value of the "username" field in the mutation.
func (m *UserMutation) Username() (r string, exists bool) {
v := m.username
if v == nil {
return
}
return *v, true
}
// OldUsername returns the old "username" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldUsername(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsername is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsername requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsername: %w", err)
}
return oldValue.Username, nil
}
// ResetUsername resets all changes to the "username" field.
func (m *UserMutation) ResetUsername() {
m.username = nil
}
// SetDisplayName sets the "display_name" field.
func (m *UserMutation) SetDisplayName(s string) {
m.display_name = &s
}
// DisplayName returns the value of the "display_name" field in the mutation.
func (m *UserMutation) DisplayName() (r string, exists bool) {
v := m.display_name
if v == nil {
return
}
return *v, true
}
// OldDisplayName returns the old "display_name" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldDisplayName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldDisplayName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldDisplayName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldDisplayName: %w", err)
}
return oldValue.DisplayName, nil
}
// ClearDisplayName clears the value of the "display_name" field.
func (m *UserMutation) ClearDisplayName() {
m.display_name = nil
m.clearedFields[user.FieldDisplayName] = struct{}{}
}
// DisplayNameCleared returns if the "display_name" field was cleared in this mutation.
func (m *UserMutation) DisplayNameCleared() bool {
_, ok := m.clearedFields[user.FieldDisplayName]
return ok
}
// ResetDisplayName resets all changes to the "display_name" field.
func (m *UserMutation) ResetDisplayName() {
m.display_name = nil
delete(m.clearedFields, user.FieldDisplayName)
}
// SetEmail sets the "email" field.
func (m *UserMutation) SetEmail(s string) {
m.email = &s
@ -9815,7 +9902,13 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 5)
fields := make([]string, 0, 7)
if m.username != nil {
fields = append(fields, user.FieldUsername)
}
if m.display_name != nil {
fields = append(fields, user.FieldDisplayName)
}
if m.email != nil {
fields = append(fields, user.FieldEmail)
}
@ -9839,6 +9932,10 @@ func (m *UserMutation) Fields() []string {
// schema.
func (m *UserMutation) Field(name string) (ent.Value, bool) {
switch name {
case user.FieldUsername:
return m.Username()
case user.FieldDisplayName:
return m.DisplayName()
case user.FieldEmail:
return m.Email()
case user.FieldPasswordHash:
@ -9858,6 +9955,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
// database failed.
func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
case user.FieldUsername:
return m.OldUsername(ctx)
case user.FieldDisplayName:
return m.OldDisplayName(ctx)
case user.FieldEmail:
return m.OldEmail(ctx)
case user.FieldPasswordHash:
@ -9877,6 +9978,20 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
// type.
func (m *UserMutation) SetField(name string, value ent.Value) error {
switch name {
case user.FieldUsername:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsername(v)
return nil
case user.FieldDisplayName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetDisplayName(v)
return nil
case user.FieldEmail:
v, ok := value.(string)
if !ok {
@ -9941,7 +10056,11 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *UserMutation) ClearedFields() []string {
return nil
var fields []string
if m.FieldCleared(user.FieldDisplayName) {
fields = append(fields, user.FieldDisplayName)
}
return fields
}
// FieldCleared returns a boolean indicating if a field with the given name was
@ -9954,6 +10073,11 @@ func (m *UserMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *UserMutation) ClearField(name string) error {
switch name {
case user.FieldDisplayName:
m.ClearDisplayName()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@ -9961,6 +10085,12 @@ func (m *UserMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema.
func (m *UserMutation) ResetField(name string) error {
switch name {
case user.FieldUsername:
m.ResetUsername()
return nil
case user.FieldDisplayName:
m.ResetDisplayName()
return nil
case user.FieldEmail:
m.ResetEmail()
return nil

View file

@ -261,20 +261,28 @@ func init() {
role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time)
userFields := schema.User{}.Fields()
_ = userFields
// userDescUsername is the schema descriptor for username field.
userDescUsername := userFields[0].Descriptor()
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescDisplayName is the schema descriptor for display_name field.
userDescDisplayName := userFields[1].Descriptor()
// user.DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
user.DisplayNameValidator = userDescDisplayName.Validators[0].(func(string) error)
// userDescEmail is the schema descriptor for email field.
userDescEmail := userFields[0].Descriptor()
userDescEmail := userFields[2].Descriptor()
// user.EmailValidator is a validator for the "email" field. It is called by the builders before save.
user.EmailValidator = userDescEmail.Validators[0].(func(string) error)
// userDescPasswordHash is the schema descriptor for password_hash field.
userDescPasswordHash := userFields[1].Descriptor()
userDescPasswordHash := userFields[3].Descriptor()
// user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error)
// userDescCreatedAt is the schema descriptor for created_at field.
userDescCreatedAt := userFields[3].Descriptor()
userDescCreatedAt := userFields[5].Descriptor()
// user.DefaultCreatedAt holds the default value on creation for the created_at field.
user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time)
// userDescUpdatedAt is the schema descriptor for updated_at field.
userDescUpdatedAt := userFields[4].Descriptor()
userDescUpdatedAt := userFields[6].Descriptor()
// user.DefaultUpdatedAt holds the default value on creation for the updated_at field.
user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time)
// user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.

View file

@ -15,9 +15,14 @@ type User struct {
// Fields of the User.
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("email").
field.String("username").
Unique().
NotEmpty(),
field.String("display_name").
Optional().
MaxLen(64),
field.String("email").
NotEmpty(),
field.String("password_hash").
Sensitive().
NotEmpty(),

View file

@ -17,6 +17,10 @@ type User struct {
config `json:"-"`
// ID of the ent.
ID int `json:"id,omitempty"`
// Username holds the value of the "username" field.
Username string `json:"username,omitempty"`
// DisplayName holds the value of the "display_name" field.
DisplayName string `json:"display_name,omitempty"`
// Email holds the value of the "email" field.
Email string `json:"email,omitempty"`
// PasswordHash holds the value of the "password_hash" field.
@ -80,7 +84,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case user.FieldID:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldStatus:
case user.FieldUsername, user.FieldDisplayName, user.FieldEmail, user.FieldPasswordHash, user.FieldStatus:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@ -105,6 +109,18 @@ func (u *User) assignValues(columns []string, values []any) error {
return fmt.Errorf("unexpected type %T for field id", value)
}
u.ID = int(value.Int64)
case user.FieldUsername:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field username", values[i])
} else if value.Valid {
u.Username = value.String
}
case user.FieldDisplayName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field display_name", values[i])
} else if value.Valid {
u.DisplayName = value.String
}
case user.FieldEmail:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field email", values[i])
@ -186,6 +202,12 @@ func (u *User) String() string {
var builder strings.Builder
builder.WriteString("User(")
builder.WriteString(fmt.Sprintf("id=%v, ", u.ID))
builder.WriteString("username=")
builder.WriteString(u.Username)
builder.WriteString(", ")
builder.WriteString("display_name=")
builder.WriteString(u.DisplayName)
builder.WriteString(", ")
builder.WriteString("email=")
builder.WriteString(u.Email)
builder.WriteString(", ")

View file

@ -15,6 +15,10 @@ const (
Label = "user"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username"
// FieldDisplayName holds the string denoting the display_name field in the database.
FieldDisplayName = "display_name"
// FieldEmail holds the string denoting the email field in the database.
FieldEmail = "email"
// FieldPasswordHash holds the string denoting the password_hash field in the database.
@ -57,6 +61,8 @@ const (
// Columns holds all SQL columns for user fields.
var Columns = []string{
FieldID,
FieldUsername,
FieldDisplayName,
FieldEmail,
FieldPasswordHash,
FieldStatus,
@ -81,6 +87,10 @@ func ValidColumn(column string) bool {
}
var (
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error
// DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
DisplayNameValidator func(string) error
// EmailValidator is a validator for the "email" field. It is called by the builders before save.
EmailValidator func(string) error
// PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
@ -128,6 +138,16 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByUsername orders the results by the username field.
func ByUsername(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsername, opts...).ToFunc()
}
// ByDisplayName orders the results by the display_name field.
func ByDisplayName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDisplayName, opts...).ToFunc()
}
// ByEmail orders the results by the email field.
func ByEmail(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEmail, opts...).ToFunc()

View file

@ -55,6 +55,16 @@ func IDLTE(id int) predicate.User {
return predicate.User(sql.FieldLTE(FieldID, id))
}
// Username applies equality check predicate on the "username" field. It's identical to UsernameEQ.
func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// DisplayName applies equality check predicate on the "display_name" field. It's identical to DisplayNameEQ.
func DisplayName(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// Email applies equality check predicate on the "email" field. It's identical to EmailEQ.
func Email(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v))
@ -75,6 +85,146 @@ func UpdatedAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldUpdatedAt, v))
}
// UsernameEQ applies the EQ predicate on the "username" field.
func UsernameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// UsernameNEQ applies the NEQ predicate on the "username" field.
func UsernameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldUsername, v))
}
// UsernameIn applies the In predicate on the "username" field.
func UsernameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldUsername, vs...))
}
// UsernameNotIn applies the NotIn predicate on the "username" field.
func UsernameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldUsername, vs...))
}
// UsernameGT applies the GT predicate on the "username" field.
func UsernameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldUsername, v))
}
// UsernameGTE applies the GTE predicate on the "username" field.
func UsernameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldUsername, v))
}
// UsernameLT applies the LT predicate on the "username" field.
func UsernameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldUsername, v))
}
// UsernameLTE applies the LTE predicate on the "username" field.
func UsernameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldUsername, v))
}
// UsernameContains applies the Contains predicate on the "username" field.
func UsernameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldUsername, v))
}
// UsernameHasPrefix applies the HasPrefix predicate on the "username" field.
func UsernameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldUsername, v))
}
// UsernameHasSuffix applies the HasSuffix predicate on the "username" field.
func UsernameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldUsername, v))
}
// UsernameEqualFold applies the EqualFold predicate on the "username" field.
func UsernameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldUsername, v))
}
// UsernameContainsFold applies the ContainsFold predicate on the "username" field.
func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
}
// DisplayNameEQ applies the EQ predicate on the "display_name" field.
func DisplayNameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// DisplayNameNEQ applies the NEQ predicate on the "display_name" field.
func DisplayNameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldDisplayName, v))
}
// DisplayNameIn applies the In predicate on the "display_name" field.
func DisplayNameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldDisplayName, vs...))
}
// DisplayNameNotIn applies the NotIn predicate on the "display_name" field.
func DisplayNameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldDisplayName, vs...))
}
// DisplayNameGT applies the GT predicate on the "display_name" field.
func DisplayNameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldDisplayName, v))
}
// DisplayNameGTE applies the GTE predicate on the "display_name" field.
func DisplayNameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldDisplayName, v))
}
// DisplayNameLT applies the LT predicate on the "display_name" field.
func DisplayNameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldDisplayName, v))
}
// DisplayNameLTE applies the LTE predicate on the "display_name" field.
func DisplayNameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldDisplayName, v))
}
// DisplayNameContains applies the Contains predicate on the "display_name" field.
func DisplayNameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldDisplayName, v))
}
// DisplayNameHasPrefix applies the HasPrefix predicate on the "display_name" field.
func DisplayNameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldDisplayName, v))
}
// DisplayNameHasSuffix applies the HasSuffix predicate on the "display_name" field.
func DisplayNameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldDisplayName, v))
}
// DisplayNameIsNil applies the IsNil predicate on the "display_name" field.
func DisplayNameIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldDisplayName))
}
// DisplayNameNotNil applies the NotNil predicate on the "display_name" field.
func DisplayNameNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldDisplayName))
}
// DisplayNameEqualFold applies the EqualFold predicate on the "display_name" field.
func DisplayNameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldDisplayName, v))
}
// DisplayNameContainsFold applies the ContainsFold predicate on the "display_name" field.
func DisplayNameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldDisplayName, v))
}
// EmailEQ applies the EQ predicate on the "email" field.
func EmailEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v))

View file

@ -23,6 +23,26 @@ type UserCreate struct {
hooks []Hook
}
// SetUsername sets the "username" field.
func (uc *UserCreate) SetUsername(s string) *UserCreate {
uc.mutation.SetUsername(s)
return uc
}
// SetDisplayName sets the "display_name" field.
func (uc *UserCreate) SetDisplayName(s string) *UserCreate {
uc.mutation.SetDisplayName(s)
return uc
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uc *UserCreate) SetNillableDisplayName(s *string) *UserCreate {
if s != nil {
uc.SetDisplayName(*s)
}
return uc
}
// SetEmail sets the "email" field.
func (uc *UserCreate) SetEmail(s string) *UserCreate {
uc.mutation.SetEmail(s)
@ -173,6 +193,19 @@ func (uc *UserCreate) defaults() {
// check runs all checks and user-defined validators on the builder.
func (uc *UserCreate) check() error {
if _, ok := uc.mutation.Username(); !ok {
return &ValidationError{Name: "username", err: errors.New(`ent: missing required field "User.username"`)}
}
if v, ok := uc.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uc.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if _, ok := uc.mutation.Email(); !ok {
return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)}
}
@ -229,6 +262,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_node = &User{config: uc.config}
_spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
)
if value, ok := uc.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value
}
if value, ok := uc.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
_node.DisplayName = value
}
if value, ok := uc.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
_node.Email = value

View file

@ -371,12 +371,12 @@ func (uq *UserQuery) WithMedia(opts ...func(*MediaQuery)) *UserQuery {
// Example:
//
// var v []struct {
// Email string `json:"email,omitempty"`
// Username string `json:"username,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.User.Query().
// GroupBy(user.FieldEmail).
// GroupBy(user.FieldUsername).
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
@ -394,11 +394,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
// Example:
//
// var v []struct {
// Email string `json:"email,omitempty"`
// Username string `json:"username,omitempty"`
// }
//
// client.User.Query().
// Select(user.FieldEmail).
// Select(user.FieldUsername).
// Scan(ctx, &v)
func (uq *UserQuery) Select(fields ...string) *UserSelect {
uq.ctx.Fields = append(uq.ctx.Fields, fields...)

View file

@ -31,6 +31,40 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate {
return uu
}
// SetUsername sets the "username" field.
func (uu *UserUpdate) SetUsername(s string) *UserUpdate {
uu.mutation.SetUsername(s)
return uu
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uu *UserUpdate) SetNillableUsername(s *string) *UserUpdate {
if s != nil {
uu.SetUsername(*s)
}
return uu
}
// SetDisplayName sets the "display_name" field.
func (uu *UserUpdate) SetDisplayName(s string) *UserUpdate {
uu.mutation.SetDisplayName(s)
return uu
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uu *UserUpdate) SetNillableDisplayName(s *string) *UserUpdate {
if s != nil {
uu.SetDisplayName(*s)
}
return uu
}
// ClearDisplayName clears the value of the "display_name" field.
func (uu *UserUpdate) ClearDisplayName() *UserUpdate {
uu.mutation.ClearDisplayName()
return uu
}
// SetEmail sets the "email" field.
func (uu *UserUpdate) SetEmail(s string) *UserUpdate {
uu.mutation.SetEmail(s)
@ -244,6 +278,16 @@ func (uu *UserUpdate) defaults() {
// check runs all checks and user-defined validators on the builder.
func (uu *UserUpdate) check() error {
if v, ok := uu.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uu.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uu.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -274,6 +318,15 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
}
}
if value, ok := uu.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uu.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uu.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uu.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
}
@ -444,6 +497,40 @@ type UserUpdateOne struct {
mutation *UserMutation
}
// SetUsername sets the "username" field.
func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne {
uuo.mutation.SetUsername(s)
return uuo
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableUsername(s *string) *UserUpdateOne {
if s != nil {
uuo.SetUsername(*s)
}
return uuo
}
// SetDisplayName sets the "display_name" field.
func (uuo *UserUpdateOne) SetDisplayName(s string) *UserUpdateOne {
uuo.mutation.SetDisplayName(s)
return uuo
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableDisplayName(s *string) *UserUpdateOne {
if s != nil {
uuo.SetDisplayName(*s)
}
return uuo
}
// ClearDisplayName clears the value of the "display_name" field.
func (uuo *UserUpdateOne) ClearDisplayName() *UserUpdateOne {
uuo.mutation.ClearDisplayName()
return uuo
}
// SetEmail sets the "email" field.
func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne {
uuo.mutation.SetEmail(s)
@ -670,6 +757,16 @@ func (uuo *UserUpdateOne) defaults() {
// check runs all checks and user-defined validators on the builder.
func (uuo *UserUpdateOne) check() error {
if v, ok := uuo.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uuo.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uuo.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -717,6 +814,15 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
}
}
}
if value, ok := uuo.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uuo.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uuo.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uuo.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
}

View file

@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.24
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
go.uber.org/mock v0.5.0
golang.org/x/crypto v0.33.0
@ -55,6 +56,7 @@ require (
github.com/goccy/go-json v0.10.5 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/hashicorp/hcl/v2 v2.23.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
@ -65,6 +67,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect

View file

@ -59,6 +59,7 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -130,6 +131,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View file

@ -7,16 +7,18 @@ import (
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor contributor"`
}
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Username string `json:"username" binding:"required,min=3,max=32"`
Password string `json:"password" binding:"required"`
}
@ -31,7 +33,7 @@ func (h *Handler) Register(c *gin.Context) {
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role)
user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
if err != nil {
log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
@ -76,14 +78,20 @@ func (h *Handler) Login(c *gin.Context) {
return
}
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email)
user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password",
})
return
}
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
// 验证密码
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password",
})
return
}
@ -91,7 +99,9 @@ func (h *Handler) Login(c *gin.Context) {
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil {
log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"})
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get user roles",
})
return
}
@ -111,7 +121,9 @@ func (h *Handler) Login(c *gin.Context) {
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil {
log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to generate token",
})
return
}

View file

@ -3,10 +3,11 @@ package handler
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock"
@ -14,6 +15,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
)
type AuthHandlerTestSuite struct {
@ -54,20 +56,21 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{
name: "成功注册",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT().
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor").
Return(user, nil)
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
Return(&ent.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
},
expectedStatus: http.StatusCreated,
@ -75,6 +78,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{
name: "无效的邮箱格式",
request: RegisterRequest{
Username: "testuser",
Email: "invalid-email",
Password: "password123",
Role: "contributor",
@ -86,6 +90,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{
name: "密码太短",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "short",
Role: "contributor",
@ -97,6 +102,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{
name: "无效的角色",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "invalid-role",
@ -151,91 +157,95 @@ func (s *AuthHandlerTestSuite) TestLogin() {
{
name: "成功登录",
request: LoginRequest{
Email: "test@example.com",
Username: "testuser",
Password: "password123",
},
setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Email: "test@example.com",
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
Return([]*ent.Role{{Name: "admin"}}, nil)
},
expectedStatus: http.StatusOK,
},
{
name: "无效的邮箱格式",
name: "无效的用户名",
request: LoginRequest{
Email: "invalid-email",
Username: "invalid",
Password: "password123",
},
setupMock: func() {},
expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
setupMock: func() {
s.service.EXPECT().
GetUserByUsername(gomock.Any(), "invalid").
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
},
{
name: "用户不存在",
request: LoginRequest{
Email: "nonexistent@example.com",
Username: "nonexistent",
Password: "password123",
},
setupMock: func() {
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "nonexistent@example.com").
Return(nil, errors.New("user not found"))
GetUserByUsername(gomock.Any(), "nonexistent").
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials",
expectedError: "Invalid username or password",
},
{
name: "密码错误",
request: LoginRequest{
Email: "test@example.com",
Username: "testuser",
Password: "wrong-password",
},
setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Email: "test@example.com",
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "wrong-password").
Return(false)
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials",
expectedError: "Invalid username or password",
},
{
name: "获取用户角色失败",
request: LoginRequest{
Email: "test@example.com",
Username: "testuser",
Password: "password123",
},
setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{
ID: 1,
Email: "test@example.com",
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}
s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com").
GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return(nil, errors.New("failed to get roles"))
Return(nil, fmt.Errorf("failed to get roles"))
},
expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles",

View file

@ -34,6 +34,18 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
auth.POST("/login", h.Login)
}
// User routes
users := api.Group("/users")
{
users.GET("", h.ListUsers)
users.POST("", h.CreateUser)
users.GET("/:id", h.GetUser)
users.PUT("/:id", h.UpdateUser)
users.DELETE("/:id", h.DeleteUser)
users.GET("/me", h.GetCurrentUser)
users.PUT("/me", h.UpdateCurrentUser)
}
// Category routes
categories := api.Group("/categories")
{

View file

@ -0,0 +1,227 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/types"
)
type UpdateCurrentUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
CurrentPassword string `json:"current_password,omitempty"`
NewPassword string `json:"new_password,omitempty" binding:"omitempty,min=8"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type UpdateUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
Password string `json:"password,omitempty" binding:"omitempty,min=8"`
Role string `json:"role,omitempty" binding:"omitempty,oneof=admin editor"`
Status string `json:"status,omitempty" binding:"omitempty,oneof=active inactive"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
// ListUsers returns a list of users
func (h *Handler) ListUsers(c *gin.Context) {
// Parse query parameters
params := &types.ListUsersParams{
Page: 1,
PerPage: 10,
}
if page := c.Query("page"); page != "" {
if p, err := strconv.Atoi(page); err == nil && p > 0 {
params.Page = p
}
}
if perPage := c.Query("per_page"); perPage != "" {
if pp, err := strconv.Atoi(perPage); err == nil && pp > 0 {
params.PerPage = pp
}
}
params.Sort = c.Query("sort")
params.Role = c.Query("role")
params.Status = c.Query("status")
params.Email = c.Query("email")
// Get users
users, err := h.service.ListUsers(c.Request.Context(), params)
if err != nil {
log.Error().Err(err).Msg("Failed to list users")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
})
}
// CreateUser creates a new user
func (h *Handler) CreateUser(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Email, req.Password, req.Role)
if err != nil {
log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
c.JSON(http.StatusCreated, gin.H{
"data": user,
})
}
// GetUser returns user details
func (h *Handler) GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
user, err := h.service.GetUser(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// UpdateUser updates user information
func (h *Handler) UpdateUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.UpdateUser(c.Request.Context(), id, &types.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Role: req.Role,
Status: req.Status,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// DeleteUser deletes a user
func (h *Handler) DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
if err := h.service.DeleteUser(c.Request.Context(), id); err != nil {
log.Error().Err(err).Msg("Failed to delete user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"})
return
}
c.Status(http.StatusNoContent)
}
// GetCurrentUser returns the current user's information
func (h *Handler) GetCurrentUser(c *gin.Context) {
// 从上下文中获取用户ID由认证中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// 获取用户信息
user, err := h.service.GetUser(c.Request.Context(), userID.(int))
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user information"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// UpdateCurrentUser updates the current user's information
func (h *Handler) UpdateCurrentUser(c *gin.Context) {
// 从上下文中获取用户ID由认证中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
var req UpdateCurrentUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 如果要更新密码,需要验证当前密码
if req.NewPassword != "" {
if req.CurrentPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is required to update password"})
return
}
// 验证当前密码
if err := h.service.VerifyPassword(c.Request.Context(), userID.(int), req.CurrentPassword); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current password"})
return
}
}
// 更新用户信息
user, err := h.service.UpdateUser(c.Request.Context(), userID.(int), &types.UpdateUserInput{
Email: req.Email,
Password: req.NewPassword,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user information"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
@ -78,26 +79,41 @@ func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
// 配置文件日志
if config.EnableFile {
// 确保日志目录存在
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil {
// 验证文件路径
if config.FilePath == "" {
return nil, fmt.Errorf("file path cannot be empty")
}
// 验证路径是否包含无效字符
if strings.ContainsAny(config.FilePath, "\x00") {
return nil, fmt.Errorf("file path contains invalid characters")
}
dir := filepath.Dir(config.FilePath)
// 检查目录是否存在或是否可以创建
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
// 配置日志轮转
// 尝试打开或创建文件,验证路径是否有效且有写入权限
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, fmt.Errorf("failed to open or create log file: %w", err)
}
file.Close()
// 配置文件日志
logWriter = &lumberjack.Logger{
Filename: config.FilePath,
MaxSize: config.Rotation.MaxSize, // MB
MaxAge: config.Rotation.MaxAge, // days
MaxBackups: config.Rotation.MaxBackups, // files
MaxBackups: config.Rotation.MaxBackups, // 文件个数
MaxAge: config.Rotation.MaxAge, // 天数
Compress: config.Rotation.Compress, // 是否压缩
LocalTime: config.Rotation.LocalTime, // 使用本地时间
}
logger := zerolog.New(logWriter).
With().
Timestamp().
Logger()
logger := zerolog.New(logWriter).With().Timestamp().Logger()
fileLogger = &logger
}

View file

@ -219,7 +219,7 @@ func TestAccessLogInvalidConfig(t *testing.T) {
name: "Invalid file path",
config: &types.AccessLogConfig{
EnableFile: true,
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径
FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
},
expectedError: true,
},

View file

@ -5,6 +5,7 @@ import (
"fmt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/role"
)
@ -38,37 +39,69 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
permissionMap := make(map[string]*ent.Permission)
for resource, actions := range DefaultPermissions {
for _, action := range actions {
permission, err := client.Permission.Create().
SetResource(resource).
SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
Save(ctx)
if err != nil {
return fmt.Errorf("failed creating permission: %w", err)
}
key := fmt.Sprintf("%s:%s", resource, action)
permission, err := client.Permission.Query().
Where(
permission.ResourceEQ(resource),
permission.ActionEQ(action),
).
Only(ctx)
if ent.IsNotFound(err) {
permission, err = client.Permission.Create().
SetResource(resource).
SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
Save(ctx)
if err != nil {
return fmt.Errorf("failed creating permission: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed querying permission: %w", err)
}
permissionMap[key] = permission
}
}
// Create roles with permissions
for roleName, permissions := range DefaultRoles {
roleCreate := client.Role.Create().
SetName(roleName).
SetDescription(fmt.Sprintf("Role for %s users", roleName))
role, err := client.Role.Query().
Where(role.NameEQ(roleName)).
Only(ctx)
if ent.IsNotFound(err) {
roleCreate := client.Role.Create().
SetName(roleName).
SetDescription(fmt.Sprintf("Role for %s users", roleName))
// Add permissions to role
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
roleCreate.AddPermissions(permission)
// Add permissions to role
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
roleCreate.AddPermissions(permission)
}
}
}
}
if _, err := roleCreate.Save(ctx); err != nil {
return fmt.Errorf("failed creating role %s: %w", roleName, err)
if _, err := roleCreate.Save(ctx); err != nil {
return fmt.Errorf("failed creating role %s: %w", roleName, err)
}
} else if err != nil {
return fmt.Errorf("failed querying role: %w", err)
} else {
// Update existing role's permissions
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
err = client.Role.UpdateOne(role).
AddPermissions(permission).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed updating role %s permissions: %w", roleName, err)
}
}
}
}
}
}

View file

@ -64,6 +64,7 @@ func TestAssignRoleToUser(t *testing.T) {
// Create a test user
user, err := client.User.Create().
SetEmail("test@example.com").
SetUsername("testuser").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx)
if err != nil {

View file

@ -1,31 +1,27 @@
package server
import (
"context"
"context"
"entgo.io/ent/dialect/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
"tss-rocks-be/ent"
"tss-rocks-be/internal/config"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
)
// NewEntClient creates a new ent client
func NewEntClient(cfg *config.Config) *ent.Client {
// TODO: Implement database connection based on config
// For now, we'll use SQLite for development
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
if err != nil {
log.Fatal().Err(err).Msg("Failed to connect to database")
}
// 使用配置文件中的数据库设置
client, err := ent.Open(cfg.Database.Driver, cfg.Database.DSN)
if err != nil {
log.Fatal().Err(err).Msg("Failed to connect to database")
}
// Create ent client
client := ent.NewClient(ent.Driver(db))
// Run the auto migration tool
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Run the auto migration tool
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
return client
return client
}

View file

@ -24,7 +24,6 @@ import (
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@ -54,59 +53,74 @@ func NewService(client *ent.Client, storage storage.Storage) Service {
}
// User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) {
// Hash the password
func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
// 验证邮箱格式
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if !emailRegex.MatchString(email) {
return nil, fmt.Errorf("invalid email format")
}
// 验证密码长度
if len(password) < 8 {
return nil, fmt.Errorf("password must be at least 8 characters")
}
// 检查用户名是否已存在
exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
if err != nil {
return nil, fmt.Errorf("error checking username: %v", err)
}
if exists {
return nil, fmt.Errorf("username '%s' already exists", username)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
return nil, fmt.Errorf("error hashing password: %v", err)
}
// Add the user role by default
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
// 创建用户
u, err := s.client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
AddRoles(userRole)
SetStatus("active").
Save(ctx)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
return nil, fmt.Errorf("error creating user: %v", err)
}
return user, nil
// 分配角色
err = s.AssignRole(ctx, u.ID, roleStr)
if err != nil {
return nil, fmt.Errorf("error assigning role: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user with username '%s' not found", username)
}
return nil, fmt.Errorf("error getting user: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query().
Where(user.EmailEQ(email)).
Only(ctx)
u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email)
return nil, fmt.Errorf("user with email '%s' not found", email)
}
return nil, fmt.Errorf("failed to get user: %w", err)
return nil, fmt.Errorf("error getting user: %v", err)
}
return user, nil
return u, nil
}
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {

View file

@ -103,52 +103,50 @@ func newMockMultipartFile(data []byte) *mockMultipartFile {
func (s *ServiceImplTestSuite) TestCreateUser() {
testCases := []struct {
name string
email string
password string
role string
wantError bool
name string
username string
email string
password string
role string
wantErr bool
}{
{
name: "Valid user creation",
email: "test@example.com",
password: "password123",
role: "admin",
wantError: false,
name: "有效的用户",
username: "testuser",
email: "test@example.com",
password: "password123",
role: "user",
wantErr: false,
},
{
name: "Empty email",
email: "",
password: "password123",
role: "user",
wantError: true,
name: "无效的邮箱",
username: "testuser2",
email: "invalid-email",
password: "password123",
role: "user",
wantErr: true,
},
{
name: "Empty password",
email: "test@example.com",
password: "",
role: "user",
wantError: true,
},
{
name: "Invalid role",
email: "test@example.com",
password: "password123",
role: "invalid_role",
wantError: true,
name: "空密码",
username: "testuser3",
email: "test3@example.com",
password: "",
role: "user",
wantErr: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role)
if tc.wantError {
assert.Error(s.T(), err)
assert.Nil(s.T(), user)
user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
if tc.wantErr {
s.Error(err)
s.Nil(user)
} else {
assert.NoError(s.T(), err)
assert.NotNil(s.T(), user)
assert.Equal(s.T(), tc.email, user.Email)
s.NoError(err)
s.NotNil(user)
s.Equal(tc.email, user.Email)
s.Equal(tc.username, user.Username)
}
})
}
@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() {
password := "password123"
role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role)
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
password := "password123"
role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role)
user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("AssignRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
require.NoError(s.T(), err)
err = s.svc.AssignRole(s.ctx, user.ID, "user")
@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("RemoveRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
require.NoError(s.T(), err)
err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("HasPermission", func() {
s.Run("Admin can create users", func() {
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin")
user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Editor cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor")
user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Editor can create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor")
user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User can read posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("User cannot create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
})
s.Run("Invalid permission format", func() {
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user")
user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
require.NoError(s.T(), err)
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
func (s *ServiceImplTestSuite) TestCategory() {
// Create a test user with admin role for testing
adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin")
adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
require.NoError(s.T(), err)
require.NotNil(s.T(), adminUser)
@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() {
ctx := context.Background()
// 创建测试用户,默认会有 "user" 角色
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user")
user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
s.Require().NoError(err)
// 测试新用户有默认的 "user" 角色
@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() {
func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Upload Media", func() {
// Create a user first
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "")
user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
require.NoError(s.T(), err)
require.NotNil(s.T(), user)
@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Delete Media - Unauthorized", func() {
// Create a user
user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "")
user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
require.NoError(s.T(), err)
// Mock file content
@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
require.NoError(s.T(), err)
// Try to delete with different user
anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "")
anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
require.NoError(s.T(), err)
err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)

View file

@ -9,14 +9,22 @@ import (
"tss-rocks-be/ent"
"tss-rocks-be/internal/storage"
"tss-rocks-be/internal/types"
)
// Service interface defines all business logic operations
type Service interface {
// User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error)
CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error)
GetUser(ctx context.Context, id int) (*ent.User, error)
GetUserByUsername(ctx context.Context, username string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool
VerifyPassword(ctx context.Context, userID int, password string) error
UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error)
DeleteUser(ctx context.Context, userID int) error
ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error)
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
// Category operations
CreateCategory(ctx context.Context) (*ent.Category, error)
@ -51,9 +59,8 @@ type Service interface {
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations
InitializeRBAC(ctx context.Context) error
AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
}

View file

@ -0,0 +1,154 @@
package service
import (
"context"
"fmt"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/types"
)
// GetUser gets a user by ID
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
user, err := s.client.User.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// VerifyPassword verifies the user's current password
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
user, err := s.GetUser(ctx, userID)
if err != nil {
return err
}
if !s.ValidatePassword(ctx, user, password) {
return fmt.Errorf("invalid password")
}
return nil
}
// UpdateUser updates user information
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
// Start building the update
update := s.client.User.UpdateOneID(userID)
// Update email if provided
if input.Email != "" {
update.SetEmail(input.Email)
}
// Update display name if provided
if input.DisplayName != "" {
update.SetDisplayName(input.DisplayName)
}
// Update password if provided
if input.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
log.Error().Err(err).Msg("Failed to hash password")
return nil, fmt.Errorf("failed to hash password: %w", err)
}
update.SetPasswordHash(string(hashedPassword))
}
// Update status if provided
if input.Status != "" {
update.SetStatus(user.Status(input.Status))
}
// Execute the update
user, err := update.Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
return nil, fmt.Errorf("failed to update user: %w", err)
}
// Update role if provided
if input.Role != "" {
// Clear existing roles
_, err = user.Update().ClearRoles().Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to clear user roles")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
// Add new role
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to find role")
return nil, fmt.Errorf("failed to find role: %w", err)
}
_, err = user.Update().AddRoles(role).Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to add user role")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
}
return user, nil
}
// DeleteUser deletes a user by ID
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
err := s.client.User.DeleteOneID(userID).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// ListUsers lists users with filters and pagination
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
query := s.client.User.Query()
// Apply filters
if params.Role != "" {
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
}
if params.Status != "" {
query.Where(user.StatusEQ(user.Status(params.Status)))
}
if params.Email != "" {
query.Where(user.EmailContains(params.Email))
}
// Apply pagination
if params.PerPage > 0 {
query.Limit(params.PerPage)
if params.Page > 0 {
query.Offset((params.Page - 1) * params.PerPage)
}
}
// Apply sorting
if params.Sort != "" {
switch params.Sort {
case "email_asc":
query.Order(ent.Asc(user.FieldEmail))
case "email_desc":
query.Order(ent.Desc(user.FieldEmail))
case "created_at_asc":
query.Order(ent.Asc(user.FieldCreatedAt))
case "created_at_desc":
query.Order(ent.Desc(user.FieldCreatedAt))
}
}
// Execute query
users, err := query.All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
return users, nil
}

View file

@ -0,0 +1,28 @@
package types
// UpdateUserInput defines the input for updating a user
type UpdateUserInput struct {
Email string
Password string
Role string
Status string
DisplayName string
}
// ListUsersParams defines the parameters for listing users
type ListUsersParams struct {
Page int
PerPage int
Sort string
Role string
Status string
Email string
}
// CreateUserInput defines the input for creating a user
type CreateUserInput struct {
Email string
Password string
Role string
DisplayName string
}

29
backend/main.go Normal file
View file

@ -0,0 +1,29 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"tss-rocks-be/cmd"
)
var rootCmd = &cobra.Command{
Use: "app",
Short: "TSS backend application",
Long: `TSS (The Starset Society) backend application.
It provides both API endpoints and admin tools for content management.`,
}
func init() {
// 添加子命令
rootCmd.AddCommand(cmd.GetUserCmd())
rootCmd.AddCommand(cmd.GetServerCmd())
}
func main() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}