[feature/backend] implement /users handler + switch to username + add display name + user management cli

This commit is contained in:
CDN 2025-02-21 04:30:07 +08:00
parent 1d712d4e6c
commit 86ab334bc9
Signed by: CDN
GPG key ID: 0C656827F9F80080
38 changed files with 1851 additions and 506 deletions

View file

@ -60,12 +60,20 @@ User:
type: object type: object
required: required:
- id - id
- email - username
- role - role
- status - status
properties: properties:
id: id:
type: integer type: integer
username:
type: string
minLength: 3
maxLength: 32
display_name:
type: string
maxLength: 64
description: 用户显示名称
email: email:
type: string type: string
format: email format: email
@ -74,11 +82,13 @@ User:
enum: enum:
- admin - admin
- editor - editor
- contributor
status: status:
type: string type: string
enum: enum:
- active - active
- inactive - inactive
- banned
created_at: created_at:
type: string type: string
format: date-time format: date-time

View file

@ -70,6 +70,8 @@ components:
Daily: Daily:
$ref: './components/schemas.yaml#/Daily' $ref: './components/schemas.yaml#/Daily'
paths: paths:
/auth/register:
$ref: './paths/auth.yaml#/register'
/auth/login: /auth/login:
$ref: './paths/auth.yaml#/login' $ref: './paths/auth.yaml#/login'
/auth/logout: /auth/logout:

View file

@ -1,3 +1,63 @@
register:
post:
tags:
- auth
summary: 用户注册
operationId: register
security: [] # 注册接口不需要认证
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- username
- email
- password
- role
properties:
username:
type: string
minLength: 3
maxLength: 32
email:
type: string
format: email
password:
type: string
format: password
minLength: 8
role:
type: string
enum:
- admin
- editor
- contributor
responses:
'200':
description: 注册成功
content:
application/json:
schema:
type: object
required:
- token
- user
properties:
token:
type: string
user:
$ref: '../components/schemas.yaml#/User'
'400':
description: 用户名已存在
content:
application/json:
schema:
$ref: '../components/schemas.yaml#/Error'
'422':
$ref: '../components/responses.yaml#/ValidationError'
login: login:
post: post:
tags: tags:
@ -12,12 +72,13 @@ login:
schema: schema:
type: object type: object
required: required:
- email - username
- password - password
properties: properties:
email: username:
type: string type: string
format: email minLength: 3
maxLength: 32
password: password:
type: string type: string
format: password format: password

3
backend/.gitignore vendored
View file

@ -12,6 +12,9 @@ vendor/
# Build output # Build output
*.exe *.exe
# Config
config/config.yaml
# Database files # Database files
*.db *.db
*.db-journal *.db-journal

View file

@ -5,7 +5,7 @@ RUN apk add --no-cache gcc musl-dev libwebp-dev
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
RUN go build -o tss-rocks-be ./cmd/server RUN go build -o tss-rocks-be
FROM alpine:latest FROM alpine:latest
@ -13,9 +13,13 @@ RUN apk add --no-cache libwebp
RUN adduser -u 1000 -D tss-rocks RUN adduser -u 1000 -D tss-rocks
USER tss-rocks USER tss-rocks
WORKDIR /app WORKDIR /app
# 复制二进制文件和配置
COPY --from=builder /app/tss-rocks-be . COPY --from=builder /app/tss-rocks-be .
COPY --from=builder /app/config/config.yaml ./config/
EXPOSE 8080 EXPOSE 8080
ENV GIN_MODE=release ENV GIN_MODE=release
CMD ["./tss-rocks-be"] # 启动服务器
CMD ["./tss-rocks-be", "server"]

89
backend/cmd/server.go Normal file
View file

@ -0,0 +1,89 @@
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func GetServerCmd() *cobra.Command {
var configPath string
cmd := &cobra.Command{
Use: "server",
Short: "Start the server",
Long: `Start the server with the specified configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
// Load configuration
cfg, err := config.Load(configPath)
if err != nil {
return err
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
return nil
},
}
// Add flags
cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file")
return cmd
}

View file

@ -1,79 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func main() {
// Parse command line flags
configPath := flag.String("config", "config/config.yaml", "path to config file")
flag.Parse()
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
}

View file

@ -1,127 +0,0 @@
package main
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
)
func TestConfigLoad(t *testing.T) {
// Create a temporary config file for testing
tmpConfig := `
database:
driver: sqlite3
dsn: ":memory:"
server:
port: 8080
host: localhost
storage:
type: local
local:
root_dir: "./testdata"
`
tmpFile, err := os.CreateTemp("", "config-*.yaml")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.WriteString(tmpConfig)
require.NoError(t, err)
err = tmpFile.Close()
require.NoError(t, err)
// Test config loading
cfg, err := config.Load(tmpFile.Name())
require.NoError(t, err)
assert.Equal(t, "sqlite3", cfg.Database.Driver)
assert.Equal(t, ":memory:", cfg.Database.DSN)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, "localhost", cfg.Server.Host)
assert.Equal(t, "local", cfg.Storage.Type)
assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir)
}
func TestServerCreation(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 8080,
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
// Create ent client
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
// Test schema creation
err := client.Schema.Create(context.Background())
require.NoError(t, err)
// Test server creation
srv, err := server.New(cfg, client)
require.NoError(t, err)
require.NotNil(t, srv)
}
func TestServerStartAndShutdown(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 0, // Use random available port
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
err := client.Schema.Create(context.Background())
require.NoError(t, err)
srv, err := server.New(cfg, client)
require.NoError(t, err)
// Start server in goroutine
go func() {
err := srv.Start()
if err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Test graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
assert.NoError(t, err)
}

411
backend/cmd/user.go Normal file
View file

@ -0,0 +1,411 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"tss-rocks-be/ent/user"
"tss-rocks-be/ent/role"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
func GetUserCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "user",
Short: "User management commands",
Long: `Commands for managing users, including create, delete, password, and role management.`,
}
configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file")
cmd.AddCommand(getCreateUserCmd(*configPath))
cmd.AddCommand(getPasswordCmd(*configPath))
cmd.AddCommand(getListUsersCmd(*configPath))
cmd.AddCommand(getDeleteUserCmd(*configPath))
cmd.AddCommand(getRoleCmd(*configPath))
return cmd
}
func getCreateUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "create",
Short: "Create a new user",
Long: `Create a new user with specified username, email and password.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
email, _ := cmd.Flags().GetString("email")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 检查用户名是否已存在
exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking username: %v", err)
}
if exists {
return fmt.Errorf("username '%s' already exists", username)
}
// 检查邮箱是否已存在
exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking email: %v", err)
}
if exists {
return fmt.Errorf("email '%s' already exists", email)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 创建用户
u, err := client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
SetStatus("active").
Save(cmd.Context())
if err != nil {
return fmt.Errorf("error creating user: %v", err)
}
fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username for the new user")
cmd.Flags().String("email", "", "email for the new user")
cmd.Flags().String("password", "", "password for the new user")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("email")
cmd.MarkFlagRequired("password")
return cmd
}
func getPasswordCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "password",
Short: "Change user password",
Long: `Change the password for an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 生成新的密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 更新密码
_, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context())
if err != nil {
return fmt.Errorf("error updating password: %v", err)
}
fmt.Printf("Successfully updated password for user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("password", "", "new password")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("password")
return cmd
}
func getListUsersCmd(configPath string) *cobra.Command {
return &cobra.Command{
Use: "list",
Short: "List all users",
Long: `List all users in the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 获取所有用户
users, err := client.User.Query().All(cmd.Context())
if err != nil {
return fmt.Errorf("error querying users: %v", err)
}
// 创建表格写入器
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT")
// 输出用户信息
for _, u := range users {
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
u.Username,
u.Email,
u.Status,
u.CreatedAt.Format("2006-01-02 15:04:05"))
}
w.Flush()
return nil
},
}
}
func getDeleteUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Long: `Delete an existing user from the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
force, _ := cmd.Flags().GetBool("force")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查用户的角色
hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRoles && !force {
return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username)
}
// 删除用户
err = client.User.DeleteOne(u).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error deleting user: %v", err)
}
fmt.Printf("Successfully deleted user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user to delete")
cmd.Flags().Bool("force", false, "force delete even if user has roles")
// 设置必需的参数
cmd.MarkFlagRequired("username")
return cmd
}
func getRoleCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "role",
Short: "User role management commands",
Long: `Commands for managing user roles, including adding and removing roles.`,
}
cmd.AddCommand(getRoleAddCmd(configPath))
cmd.AddCommand(getRoleRemoveCmd(configPath))
return cmd
}
func getRoleAddCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "add",
Short: "Add a role to a user",
Long: `Add a role to an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否已存在
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRole {
return fmt.Errorf("user '%s' already has role '%s'", username, roleName)
}
// 获取要添加的角色
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 添加角色
err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error adding role: %v", err)
}
fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to add (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}
func getRoleRemoveCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "remove",
Short: "Remove a role from a user",
Long: `Remove a role from an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否存在
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 检查用户是否有该角色
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if !hasRole {
return fmt.Errorf("user '%s' does not have role '%s'", username, roleName)
}
// 移除角色
err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error removing role: %v", err)
}
fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}

View file

@ -1,91 +0,0 @@
database:
driver: sqlite3
dsn: "file:tss-rocks.db?_fk=1&cache=shared"
server:
port: 8080
host: localhost
jwt:
secret: "your-secret-key-here" # 在生产环境中应该使用环境变量
expiration: 24h # token 过期时间
logging:
level: debug # debug, info, warn, error
format: json # json, console
storage:
type: s3 # local or s3
local:
root_dir: "./storage/media"
s3:
region: "us-east-1"
bucket: "your-bucket-name"
access_key_id: "your-access-key-id"
secret_access_key: "your-secret-access-key"
endpoint: "" # Optional, for MinIO or other S3-compatible services
custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
upload:
max_size: 10 # 最大文件大小MB
allowed_types: # 允许的文件类型
- image/jpeg
- image/png
- image/gif
- image/webp
- image/svg+xml
- application/pdf
- application/msword
- application/vnd.openxmlformats-officedocument.wordprocessingml.document
- application/vnd.ms-excel
- application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
- application/zip
- application/x-rar-compressed
- text/plain
- text/csv
allowed_extensions: # 允许的文件扩展名(小写)
- .jpg
- .jpeg
- .png
- .gif
- .webp
- .svg
- .pdf
- .doc
- .docx
- .xls
- .xlsx
- .zip
- .rar
- .txt
- .csv
rate_limit:
# IP限流配置
ip_rate: 50 # 每秒请求数
ip_burst: 25 # 突发请求数
# 路由限流配置
route_rates:
"/api/v1/auth/login":
rate: 5 # 每秒5个请求
burst: 10 # 突发10个请求
"/api/v1/auth/register":
rate: 2 # 每秒2个请求
burst: 5 # 突发5个请求
"/api/v1/media/upload":
rate: 10 # 每秒10个请求
burst: 20 # 突发20个请求
access_log:
enable_console: true # 启用控制台输出
enable_file: true # 启用文件日志
file_path: "./logs/access.log" # 日志文件路径
format: "json" # 日志格式 (json 或 text)
level: "info" # 日志级别
rotation:
max_size: 100 # 每个日志文件的最大大小MB
max_age: 30 # 保留旧日志文件的最大天数
max_backups: 10 # 保留的旧日志文件的最大数量
compress: true # 是否压缩旧日志文件
local_time: true # 使用本地时间作为轮转时间

View file

@ -0,0 +1,36 @@
database:
driver: sqlite3
# SQLite DSN 说明:
# - file:tss.db => 相对路径,在当前目录创建 tss.db
# - cache=shared => 启用共享缓存,提高性能
# - _fk=1 => 启用外键约束
# - mode=rwc => 如果数据库不存在则创建read-write-create
dsn: "file:tss.db?cache=shared&_fk=1&mode=rwc"
server:
port: 8080
host: localhost
jwt:
secret: your-jwt-secret-here # 在生产环境中应该使用环境变量
expiration: 24h
storage:
driver: local
local:
root: storage
base_url: http://localhost:8080/storage
logging:
level: debug
format: console
rate_limit:
enabled: true
requests: 100
duration: 1m
access_log:
enabled: true
format: combined
output: stdout

View file

@ -373,7 +373,9 @@ var (
// UsersColumns holds the columns for the "users" table. // UsersColumns holds the columns for the "users" table.
UsersColumns = []*schema.Column{ UsersColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true}, {Name: "id", Type: field.TypeInt, Increment: true},
{Name: "email", Type: field.TypeString, Unique: true}, {Name: "username", Type: field.TypeString, Unique: true},
{Name: "display_name", Type: field.TypeString, Nullable: true, Size: 64},
{Name: "email", Type: field.TypeString},
{Name: "password_hash", Type: field.TypeString}, {Name: "password_hash", Type: field.TypeString},
{Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"}, {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"},
{Name: "created_at", Type: field.TypeTime}, {Name: "created_at", Type: field.TypeTime},

View file

@ -9321,6 +9321,8 @@ type UserMutation struct {
op Op op Op
typ string typ string
id *int id *int
username *string
display_name *string
email *string email *string
password_hash *string password_hash *string
status *user.Status status *user.Status
@ -9439,6 +9441,91 @@ func (m *UserMutation) IDs(ctx context.Context) ([]int, error) {
} }
} }
// SetUsername sets the "username" field.
func (m *UserMutation) SetUsername(s string) {
m.username = &s
}
// Username returns the value of the "username" field in the mutation.
func (m *UserMutation) Username() (r string, exists bool) {
v := m.username
if v == nil {
return
}
return *v, true
}
// OldUsername returns the old "username" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldUsername(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsername is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsername requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsername: %w", err)
}
return oldValue.Username, nil
}
// ResetUsername resets all changes to the "username" field.
func (m *UserMutation) ResetUsername() {
m.username = nil
}
// SetDisplayName sets the "display_name" field.
func (m *UserMutation) SetDisplayName(s string) {
m.display_name = &s
}
// DisplayName returns the value of the "display_name" field in the mutation.
func (m *UserMutation) DisplayName() (r string, exists bool) {
v := m.display_name
if v == nil {
return
}
return *v, true
}
// OldDisplayName returns the old "display_name" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldDisplayName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldDisplayName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldDisplayName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldDisplayName: %w", err)
}
return oldValue.DisplayName, nil
}
// ClearDisplayName clears the value of the "display_name" field.
func (m *UserMutation) ClearDisplayName() {
m.display_name = nil
m.clearedFields[user.FieldDisplayName] = struct{}{}
}
// DisplayNameCleared returns if the "display_name" field was cleared in this mutation.
func (m *UserMutation) DisplayNameCleared() bool {
_, ok := m.clearedFields[user.FieldDisplayName]
return ok
}
// ResetDisplayName resets all changes to the "display_name" field.
func (m *UserMutation) ResetDisplayName() {
m.display_name = nil
delete(m.clearedFields, user.FieldDisplayName)
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (m *UserMutation) SetEmail(s string) { func (m *UserMutation) SetEmail(s string) {
m.email = &s m.email = &s
@ -9815,7 +9902,13 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UserMutation) Fields() []string { func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 5) fields := make([]string, 0, 7)
if m.username != nil {
fields = append(fields, user.FieldUsername)
}
if m.display_name != nil {
fields = append(fields, user.FieldDisplayName)
}
if m.email != nil { if m.email != nil {
fields = append(fields, user.FieldEmail) fields = append(fields, user.FieldEmail)
} }
@ -9839,6 +9932,10 @@ func (m *UserMutation) Fields() []string {
// schema. // schema.
func (m *UserMutation) Field(name string) (ent.Value, bool) { func (m *UserMutation) Field(name string) (ent.Value, bool) {
switch name { switch name {
case user.FieldUsername:
return m.Username()
case user.FieldDisplayName:
return m.DisplayName()
case user.FieldEmail: case user.FieldEmail:
return m.Email() return m.Email()
case user.FieldPasswordHash: case user.FieldPasswordHash:
@ -9858,6 +9955,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
// database failed. // database failed.
func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name { switch name {
case user.FieldUsername:
return m.OldUsername(ctx)
case user.FieldDisplayName:
return m.OldDisplayName(ctx)
case user.FieldEmail: case user.FieldEmail:
return m.OldEmail(ctx) return m.OldEmail(ctx)
case user.FieldPasswordHash: case user.FieldPasswordHash:
@ -9877,6 +9978,20 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
// type. // type.
func (m *UserMutation) SetField(name string, value ent.Value) error { func (m *UserMutation) SetField(name string, value ent.Value) error {
switch name { switch name {
case user.FieldUsername:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsername(v)
return nil
case user.FieldDisplayName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetDisplayName(v)
return nil
case user.FieldEmail: case user.FieldEmail:
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
@ -9941,7 +10056,11 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this // ClearedFields returns all nullable fields that were cleared during this
// mutation. // mutation.
func (m *UserMutation) ClearedFields() []string { func (m *UserMutation) ClearedFields() []string {
return nil var fields []string
if m.FieldCleared(user.FieldDisplayName) {
fields = append(fields, user.FieldDisplayName)
}
return fields
} }
// FieldCleared returns a boolean indicating if a field with the given name was // FieldCleared returns a boolean indicating if a field with the given name was
@ -9954,6 +10073,11 @@ func (m *UserMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an // ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema. // error if the field is not defined in the schema.
func (m *UserMutation) ClearField(name string) error { func (m *UserMutation) ClearField(name string) error {
switch name {
case user.FieldDisplayName:
m.ClearDisplayName()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name) return fmt.Errorf("unknown User nullable field %s", name)
} }
@ -9961,6 +10085,12 @@ func (m *UserMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema. // It returns an error if the field is not defined in the schema.
func (m *UserMutation) ResetField(name string) error { func (m *UserMutation) ResetField(name string) error {
switch name { switch name {
case user.FieldUsername:
m.ResetUsername()
return nil
case user.FieldDisplayName:
m.ResetDisplayName()
return nil
case user.FieldEmail: case user.FieldEmail:
m.ResetEmail() m.ResetEmail()
return nil return nil

View file

@ -261,20 +261,28 @@ func init() {
role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time) role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time)
userFields := schema.User{}.Fields() userFields := schema.User{}.Fields()
_ = userFields _ = userFields
// userDescUsername is the schema descriptor for username field.
userDescUsername := userFields[0].Descriptor()
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescDisplayName is the schema descriptor for display_name field.
userDescDisplayName := userFields[1].Descriptor()
// user.DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
user.DisplayNameValidator = userDescDisplayName.Validators[0].(func(string) error)
// userDescEmail is the schema descriptor for email field. // userDescEmail is the schema descriptor for email field.
userDescEmail := userFields[0].Descriptor() userDescEmail := userFields[2].Descriptor()
// user.EmailValidator is a validator for the "email" field. It is called by the builders before save. // user.EmailValidator is a validator for the "email" field. It is called by the builders before save.
user.EmailValidator = userDescEmail.Validators[0].(func(string) error) user.EmailValidator = userDescEmail.Validators[0].(func(string) error)
// userDescPasswordHash is the schema descriptor for password_hash field. // userDescPasswordHash is the schema descriptor for password_hash field.
userDescPasswordHash := userFields[1].Descriptor() userDescPasswordHash := userFields[3].Descriptor()
// user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. // user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error) user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error)
// userDescCreatedAt is the schema descriptor for created_at field. // userDescCreatedAt is the schema descriptor for created_at field.
userDescCreatedAt := userFields[3].Descriptor() userDescCreatedAt := userFields[5].Descriptor()
// user.DefaultCreatedAt holds the default value on creation for the created_at field. // user.DefaultCreatedAt holds the default value on creation for the created_at field.
user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time)
// userDescUpdatedAt is the schema descriptor for updated_at field. // userDescUpdatedAt is the schema descriptor for updated_at field.
userDescUpdatedAt := userFields[4].Descriptor() userDescUpdatedAt := userFields[6].Descriptor()
// user.DefaultUpdatedAt holds the default value on creation for the updated_at field. // user.DefaultUpdatedAt holds the default value on creation for the updated_at field.
user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time) user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time)
// user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. // user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.

View file

@ -15,9 +15,14 @@ type User struct {
// Fields of the User. // Fields of the User.
func (User) Fields() []ent.Field { func (User) Fields() []ent.Field {
return []ent.Field{ return []ent.Field{
field.String("email"). field.String("username").
Unique(). Unique().
NotEmpty(), NotEmpty(),
field.String("display_name").
Optional().
MaxLen(64),
field.String("email").
NotEmpty(),
field.String("password_hash"). field.String("password_hash").
Sensitive(). Sensitive().
NotEmpty(), NotEmpty(),

View file

@ -17,6 +17,10 @@ type User struct {
config `json:"-"` config `json:"-"`
// ID of the ent. // ID of the ent.
ID int `json:"id,omitempty"` ID int `json:"id,omitempty"`
// Username holds the value of the "username" field.
Username string `json:"username,omitempty"`
// DisplayName holds the value of the "display_name" field.
DisplayName string `json:"display_name,omitempty"`
// Email holds the value of the "email" field. // Email holds the value of the "email" field.
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
// PasswordHash holds the value of the "password_hash" field. // PasswordHash holds the value of the "password_hash" field.
@ -80,7 +84,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case user.FieldID: case user.FieldID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldStatus: case user.FieldUsername, user.FieldDisplayName, user.FieldEmail, user.FieldPasswordHash, user.FieldStatus:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt: case user.FieldCreatedAt, user.FieldUpdatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -105,6 +109,18 @@ func (u *User) assignValues(columns []string, values []any) error {
return fmt.Errorf("unexpected type %T for field id", value) return fmt.Errorf("unexpected type %T for field id", value)
} }
u.ID = int(value.Int64) u.ID = int(value.Int64)
case user.FieldUsername:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field username", values[i])
} else if value.Valid {
u.Username = value.String
}
case user.FieldDisplayName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field display_name", values[i])
} else if value.Valid {
u.DisplayName = value.String
}
case user.FieldEmail: case user.FieldEmail:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field email", values[i]) return fmt.Errorf("unexpected type %T for field email", values[i])
@ -186,6 +202,12 @@ func (u *User) String() string {
var builder strings.Builder var builder strings.Builder
builder.WriteString("User(") builder.WriteString("User(")
builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) builder.WriteString(fmt.Sprintf("id=%v, ", u.ID))
builder.WriteString("username=")
builder.WriteString(u.Username)
builder.WriteString(", ")
builder.WriteString("display_name=")
builder.WriteString(u.DisplayName)
builder.WriteString(", ")
builder.WriteString("email=") builder.WriteString("email=")
builder.WriteString(u.Email) builder.WriteString(u.Email)
builder.WriteString(", ") builder.WriteString(", ")

View file

@ -15,6 +15,10 @@ const (
Label = "user" Label = "user"
// FieldID holds the string denoting the id field in the database. // FieldID holds the string denoting the id field in the database.
FieldID = "id" FieldID = "id"
// FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username"
// FieldDisplayName holds the string denoting the display_name field in the database.
FieldDisplayName = "display_name"
// FieldEmail holds the string denoting the email field in the database. // FieldEmail holds the string denoting the email field in the database.
FieldEmail = "email" FieldEmail = "email"
// FieldPasswordHash holds the string denoting the password_hash field in the database. // FieldPasswordHash holds the string denoting the password_hash field in the database.
@ -57,6 +61,8 @@ const (
// Columns holds all SQL columns for user fields. // Columns holds all SQL columns for user fields.
var Columns = []string{ var Columns = []string{
FieldID, FieldID,
FieldUsername,
FieldDisplayName,
FieldEmail, FieldEmail,
FieldPasswordHash, FieldPasswordHash,
FieldStatus, FieldStatus,
@ -81,6 +87,10 @@ func ValidColumn(column string) bool {
} }
var ( var (
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error
// DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
DisplayNameValidator func(string) error
// EmailValidator is a validator for the "email" field. It is called by the builders before save. // EmailValidator is a validator for the "email" field. It is called by the builders before save.
EmailValidator func(string) error EmailValidator func(string) error
// PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. // PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
@ -128,6 +138,16 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc() return sql.OrderByField(FieldID, opts...).ToFunc()
} }
// ByUsername orders the results by the username field.
func ByUsername(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsername, opts...).ToFunc()
}
// ByDisplayName orders the results by the display_name field.
func ByDisplayName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDisplayName, opts...).ToFunc()
}
// ByEmail orders the results by the email field. // ByEmail orders the results by the email field.
func ByEmail(opts ...sql.OrderTermOption) OrderOption { func ByEmail(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEmail, opts...).ToFunc() return sql.OrderByField(FieldEmail, opts...).ToFunc()

View file

@ -55,6 +55,16 @@ func IDLTE(id int) predicate.User {
return predicate.User(sql.FieldLTE(FieldID, id)) return predicate.User(sql.FieldLTE(FieldID, id))
} }
// Username applies equality check predicate on the "username" field. It's identical to UsernameEQ.
func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// DisplayName applies equality check predicate on the "display_name" field. It's identical to DisplayNameEQ.
func DisplayName(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// Email applies equality check predicate on the "email" field. It's identical to EmailEQ. // Email applies equality check predicate on the "email" field. It's identical to EmailEQ.
func Email(v string) predicate.User { func Email(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v)) return predicate.User(sql.FieldEQ(FieldEmail, v))
@ -75,6 +85,146 @@ func UpdatedAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) return predicate.User(sql.FieldEQ(FieldUpdatedAt, v))
} }
// UsernameEQ applies the EQ predicate on the "username" field.
func UsernameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// UsernameNEQ applies the NEQ predicate on the "username" field.
func UsernameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldUsername, v))
}
// UsernameIn applies the In predicate on the "username" field.
func UsernameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldUsername, vs...))
}
// UsernameNotIn applies the NotIn predicate on the "username" field.
func UsernameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldUsername, vs...))
}
// UsernameGT applies the GT predicate on the "username" field.
func UsernameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldUsername, v))
}
// UsernameGTE applies the GTE predicate on the "username" field.
func UsernameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldUsername, v))
}
// UsernameLT applies the LT predicate on the "username" field.
func UsernameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldUsername, v))
}
// UsernameLTE applies the LTE predicate on the "username" field.
func UsernameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldUsername, v))
}
// UsernameContains applies the Contains predicate on the "username" field.
func UsernameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldUsername, v))
}
// UsernameHasPrefix applies the HasPrefix predicate on the "username" field.
func UsernameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldUsername, v))
}
// UsernameHasSuffix applies the HasSuffix predicate on the "username" field.
func UsernameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldUsername, v))
}
// UsernameEqualFold applies the EqualFold predicate on the "username" field.
func UsernameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldUsername, v))
}
// UsernameContainsFold applies the ContainsFold predicate on the "username" field.
func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
}
// DisplayNameEQ applies the EQ predicate on the "display_name" field.
func DisplayNameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// DisplayNameNEQ applies the NEQ predicate on the "display_name" field.
func DisplayNameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldDisplayName, v))
}
// DisplayNameIn applies the In predicate on the "display_name" field.
func DisplayNameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldDisplayName, vs...))
}
// DisplayNameNotIn applies the NotIn predicate on the "display_name" field.
func DisplayNameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldDisplayName, vs...))
}
// DisplayNameGT applies the GT predicate on the "display_name" field.
func DisplayNameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldDisplayName, v))
}
// DisplayNameGTE applies the GTE predicate on the "display_name" field.
func DisplayNameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldDisplayName, v))
}
// DisplayNameLT applies the LT predicate on the "display_name" field.
func DisplayNameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldDisplayName, v))
}
// DisplayNameLTE applies the LTE predicate on the "display_name" field.
func DisplayNameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldDisplayName, v))
}
// DisplayNameContains applies the Contains predicate on the "display_name" field.
func DisplayNameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldDisplayName, v))
}
// DisplayNameHasPrefix applies the HasPrefix predicate on the "display_name" field.
func DisplayNameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldDisplayName, v))
}
// DisplayNameHasSuffix applies the HasSuffix predicate on the "display_name" field.
func DisplayNameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldDisplayName, v))
}
// DisplayNameIsNil applies the IsNil predicate on the "display_name" field.
func DisplayNameIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldDisplayName))
}
// DisplayNameNotNil applies the NotNil predicate on the "display_name" field.
func DisplayNameNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldDisplayName))
}
// DisplayNameEqualFold applies the EqualFold predicate on the "display_name" field.
func DisplayNameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldDisplayName, v))
}
// DisplayNameContainsFold applies the ContainsFold predicate on the "display_name" field.
func DisplayNameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldDisplayName, v))
}
// EmailEQ applies the EQ predicate on the "email" field. // EmailEQ applies the EQ predicate on the "email" field.
func EmailEQ(v string) predicate.User { func EmailEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v)) return predicate.User(sql.FieldEQ(FieldEmail, v))

View file

@ -23,6 +23,26 @@ type UserCreate struct {
hooks []Hook hooks []Hook
} }
// SetUsername sets the "username" field.
func (uc *UserCreate) SetUsername(s string) *UserCreate {
uc.mutation.SetUsername(s)
return uc
}
// SetDisplayName sets the "display_name" field.
func (uc *UserCreate) SetDisplayName(s string) *UserCreate {
uc.mutation.SetDisplayName(s)
return uc
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uc *UserCreate) SetNillableDisplayName(s *string) *UserCreate {
if s != nil {
uc.SetDisplayName(*s)
}
return uc
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uc *UserCreate) SetEmail(s string) *UserCreate { func (uc *UserCreate) SetEmail(s string) *UserCreate {
uc.mutation.SetEmail(s) uc.mutation.SetEmail(s)
@ -173,6 +193,19 @@ func (uc *UserCreate) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uc *UserCreate) check() error { func (uc *UserCreate) check() error {
if _, ok := uc.mutation.Username(); !ok {
return &ValidationError{Name: "username", err: errors.New(`ent: missing required field "User.username"`)}
}
if v, ok := uc.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uc.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if _, ok := uc.mutation.Email(); !ok { if _, ok := uc.mutation.Email(); !ok {
return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)} return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)}
} }
@ -229,6 +262,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_node = &User{config: uc.config} _node = &User{config: uc.config}
_spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) _spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
) )
if value, ok := uc.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value
}
if value, ok := uc.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
_node.DisplayName = value
}
if value, ok := uc.mutation.Email(); ok { if value, ok := uc.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
_node.Email = value _node.Email = value

View file

@ -371,12 +371,12 @@ func (uq *UserQuery) WithMedia(opts ...func(*MediaQuery)) *UserQuery {
// Example: // Example:
// //
// var v []struct { // var v []struct {
// Email string `json:"email,omitempty"` // Username string `json:"username,omitempty"`
// Count int `json:"count,omitempty"` // Count int `json:"count,omitempty"`
// } // }
// //
// client.User.Query(). // client.User.Query().
// GroupBy(user.FieldEmail). // GroupBy(user.FieldUsername).
// Aggregate(ent.Count()). // Aggregate(ent.Count()).
// Scan(ctx, &v) // Scan(ctx, &v)
func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
@ -394,11 +394,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
// Example: // Example:
// //
// var v []struct { // var v []struct {
// Email string `json:"email,omitempty"` // Username string `json:"username,omitempty"`
// } // }
// //
// client.User.Query(). // client.User.Query().
// Select(user.FieldEmail). // Select(user.FieldUsername).
// Scan(ctx, &v) // Scan(ctx, &v)
func (uq *UserQuery) Select(fields ...string) *UserSelect { func (uq *UserQuery) Select(fields ...string) *UserSelect {
uq.ctx.Fields = append(uq.ctx.Fields, fields...) uq.ctx.Fields = append(uq.ctx.Fields, fields...)

View file

@ -31,6 +31,40 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate {
return uu return uu
} }
// SetUsername sets the "username" field.
func (uu *UserUpdate) SetUsername(s string) *UserUpdate {
uu.mutation.SetUsername(s)
return uu
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uu *UserUpdate) SetNillableUsername(s *string) *UserUpdate {
if s != nil {
uu.SetUsername(*s)
}
return uu
}
// SetDisplayName sets the "display_name" field.
func (uu *UserUpdate) SetDisplayName(s string) *UserUpdate {
uu.mutation.SetDisplayName(s)
return uu
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uu *UserUpdate) SetNillableDisplayName(s *string) *UserUpdate {
if s != nil {
uu.SetDisplayName(*s)
}
return uu
}
// ClearDisplayName clears the value of the "display_name" field.
func (uu *UserUpdate) ClearDisplayName() *UserUpdate {
uu.mutation.ClearDisplayName()
return uu
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uu *UserUpdate) SetEmail(s string) *UserUpdate { func (uu *UserUpdate) SetEmail(s string) *UserUpdate {
uu.mutation.SetEmail(s) uu.mutation.SetEmail(s)
@ -244,6 +278,16 @@ func (uu *UserUpdate) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uu *UserUpdate) check() error { func (uu *UserUpdate) check() error {
if v, ok := uu.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uu.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uu.mutation.Email(); ok { if v, ok := uu.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil { if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -274,6 +318,15 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
} }
} }
} }
if value, ok := uu.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uu.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uu.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uu.mutation.Email(); ok { if value, ok := uu.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
} }
@ -444,6 +497,40 @@ type UserUpdateOne struct {
mutation *UserMutation mutation *UserMutation
} }
// SetUsername sets the "username" field.
func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne {
uuo.mutation.SetUsername(s)
return uuo
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableUsername(s *string) *UserUpdateOne {
if s != nil {
uuo.SetUsername(*s)
}
return uuo
}
// SetDisplayName sets the "display_name" field.
func (uuo *UserUpdateOne) SetDisplayName(s string) *UserUpdateOne {
uuo.mutation.SetDisplayName(s)
return uuo
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableDisplayName(s *string) *UserUpdateOne {
if s != nil {
uuo.SetDisplayName(*s)
}
return uuo
}
// ClearDisplayName clears the value of the "display_name" field.
func (uuo *UserUpdateOne) ClearDisplayName() *UserUpdateOne {
uuo.mutation.ClearDisplayName()
return uuo
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne { func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne {
uuo.mutation.SetEmail(s) uuo.mutation.SetEmail(s)
@ -670,6 +757,16 @@ func (uuo *UserUpdateOne) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uuo *UserUpdateOne) check() error { func (uuo *UserUpdateOne) check() error {
if v, ok := uuo.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uuo.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uuo.mutation.Email(); ok { if v, ok := uuo.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil { if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -717,6 +814,15 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
} }
} }
} }
if value, ok := uuo.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uuo.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uuo.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uuo.mutation.Email(); ok { if value, ok := uuo.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
} }

View file

@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.24 github.com/mattn/go-sqlite3 v1.14.24
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
go.uber.org/mock v0.5.0 go.uber.org/mock v0.5.0
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
@ -55,6 +56,7 @@ require (
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.5 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/hashicorp/hcl/v2 v2.23.0 // indirect github.com/hashicorp/hcl/v2 v2.23.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
@ -65,6 +67,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect

View file

@ -59,6 +59,7 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -130,6 +131,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View file

@ -7,16 +7,18 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
) )
type RegisterRequest struct { type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"` Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor contributor"` Role string `json:"role" binding:"required,oneof=admin editor contributor"`
} }
type LoginRequest struct { type LoginRequest struct {
Email string `json:"email" binding:"required,email"` Username string `json:"username" binding:"required,min=3,max=32"`
Password string `json:"password" binding:"required"` Password string `json:"password" binding:"required"`
} }
@ -31,7 +33,7 @@ func (h *Handler) Register(c *gin.Context) {
return return
} }
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role) user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to create user") log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
@ -76,14 +78,20 @@ func (h *Handler) Login(c *gin.Context) {
return return
} }
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email) user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password",
})
return return
} }
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) { // 验证密码
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid username or password",
})
return return
} }
@ -91,7 +99,9 @@ func (h *Handler) Login(c *gin.Context) {
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get user roles") log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get user roles",
})
return return
} }
@ -111,7 +121,9 @@ func (h *Handler) Login(c *gin.Context) {
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to generate token") log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to generate token",
})
return return
} }

View file

@ -3,10 +3,11 @@ package handler
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
@ -14,6 +15,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
) )
type AuthHandlerTestSuite struct { type AuthHandlerTestSuite struct {
@ -54,20 +56,21 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{ {
name: "成功注册", name: "成功注册",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "password123", Password: "password123",
Role: "contributor", Role: "contributor",
}, },
setupMock: func() { setupMock: func() {
user := &ent.User{ s.service.EXPECT().
CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
Return(&ent.User{
ID: 1, ID: 1,
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
} }, nil)
s.service.EXPECT(). s.service.EXPECT().
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor"). GetUserRoles(gomock.Any(), 1).
Return(user, nil)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
}, },
expectedStatus: http.StatusCreated, expectedStatus: http.StatusCreated,
@ -75,6 +78,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{ {
name: "无效的邮箱格式", name: "无效的邮箱格式",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "invalid-email", Email: "invalid-email",
Password: "password123", Password: "password123",
Role: "contributor", Role: "contributor",
@ -86,6 +90,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{ {
name: "密码太短", name: "密码太短",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "short", Password: "short",
Role: "contributor", Role: "contributor",
@ -97,6 +102,7 @@ func (s *AuthHandlerTestSuite) TestRegister() {
{ {
name: "无效的角色", name: "无效的角色",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "password123", Password: "password123",
Role: "invalid-role", Role: "invalid-role",
@ -151,91 +157,95 @@ func (s *AuthHandlerTestSuite) TestLogin() {
{ {
name: "成功登录", name: "成功登录",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{ user := &ent.User{
ID: 1, ID: 1,
Email: "test@example.com", Username: "testuser",
PasswordHash: string(hashedPassword),
} }
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT(). s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID). GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) Return([]*ent.Role{{Name: "admin"}}, nil)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "无效的邮箱格式", name: "无效的用户名",
request: LoginRequest{ request: LoginRequest{
Email: "invalid-email", Username: "invalid",
Password: "password123", Password: "password123",
}, },
setupMock: func() {}, setupMock: func() {
expectedStatus: http.StatusBadRequest, s.service.EXPECT().
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", GetUserByUsername(gomock.Any(), "invalid").
Return(nil, fmt.Errorf("user not found"))
},
expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid username or password",
}, },
{ {
name: "用户不存在", name: "用户不存在",
request: LoginRequest{ request: LoginRequest{
Email: "nonexistent@example.com", Username: "nonexistent",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "nonexistent@example.com"). GetUserByUsername(gomock.Any(), "nonexistent").
Return(nil, errors.New("user not found")) Return(nil, fmt.Errorf("user not found"))
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials", expectedError: "Invalid username or password",
}, },
{ {
name: "密码错误", name: "密码错误",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "wrong-password", Password: "wrong-password",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{ user := &ent.User{
ID: 1, ID: 1,
Email: "test@example.com", Username: "testuser",
PasswordHash: string(hashedPassword),
} }
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "wrong-password").
Return(false)
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials", expectedError: "Invalid username or password",
}, },
{ {
name: "获取用户角色失败", name: "获取用户角色失败",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
// 使用 bcrypt 生成正确的密码哈希
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
user := &ent.User{ user := &ent.User{
ID: 1, ID: 1,
Email: "test@example.com", Username: "testuser",
PasswordHash: string(hashedPassword),
} }
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(user, nil)
s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123").
Return(true)
s.service.EXPECT(). s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID). GetUserRoles(gomock.Any(), user.ID).
Return(nil, errors.New("failed to get roles")) Return(nil, fmt.Errorf("failed to get roles"))
}, },
expectedStatus: http.StatusInternalServerError, expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles", expectedError: "Failed to get user roles",

View file

@ -34,6 +34,18 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
auth.POST("/login", h.Login) auth.POST("/login", h.Login)
} }
// User routes
users := api.Group("/users")
{
users.GET("", h.ListUsers)
users.POST("", h.CreateUser)
users.GET("/:id", h.GetUser)
users.PUT("/:id", h.UpdateUser)
users.DELETE("/:id", h.DeleteUser)
users.GET("/me", h.GetCurrentUser)
users.PUT("/me", h.UpdateCurrentUser)
}
// Category routes // Category routes
categories := api.Group("/categories") categories := api.Group("/categories")
{ {

View file

@ -0,0 +1,227 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/types"
)
type UpdateCurrentUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
CurrentPassword string `json:"current_password,omitempty"`
NewPassword string `json:"new_password,omitempty" binding:"omitempty,min=8"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type UpdateUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
Password string `json:"password,omitempty" binding:"omitempty,min=8"`
Role string `json:"role,omitempty" binding:"omitempty,oneof=admin editor"`
Status string `json:"status,omitempty" binding:"omitempty,oneof=active inactive"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
// ListUsers returns a list of users
func (h *Handler) ListUsers(c *gin.Context) {
// Parse query parameters
params := &types.ListUsersParams{
Page: 1,
PerPage: 10,
}
if page := c.Query("page"); page != "" {
if p, err := strconv.Atoi(page); err == nil && p > 0 {
params.Page = p
}
}
if perPage := c.Query("per_page"); perPage != "" {
if pp, err := strconv.Atoi(perPage); err == nil && pp > 0 {
params.PerPage = pp
}
}
params.Sort = c.Query("sort")
params.Role = c.Query("role")
params.Status = c.Query("status")
params.Email = c.Query("email")
// Get users
users, err := h.service.ListUsers(c.Request.Context(), params)
if err != nil {
log.Error().Err(err).Msg("Failed to list users")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
})
}
// CreateUser creates a new user
func (h *Handler) CreateUser(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Email, req.Password, req.Role)
if err != nil {
log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
c.JSON(http.StatusCreated, gin.H{
"data": user,
})
}
// GetUser returns user details
func (h *Handler) GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
user, err := h.service.GetUser(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// UpdateUser updates user information
func (h *Handler) UpdateUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.UpdateUser(c.Request.Context(), id, &types.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Role: req.Role,
Status: req.Status,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// DeleteUser deletes a user
func (h *Handler) DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
if err := h.service.DeleteUser(c.Request.Context(), id); err != nil {
log.Error().Err(err).Msg("Failed to delete user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"})
return
}
c.Status(http.StatusNoContent)
}
// GetCurrentUser returns the current user's information
func (h *Handler) GetCurrentUser(c *gin.Context) {
// 从上下文中获取用户ID由认证中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// 获取用户信息
user, err := h.service.GetUser(c.Request.Context(), userID.(int))
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user information"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// UpdateCurrentUser updates the current user's information
func (h *Handler) UpdateCurrentUser(c *gin.Context) {
// 从上下文中获取用户ID由认证中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
var req UpdateCurrentUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 如果要更新密码,需要验证当前密码
if req.NewPassword != "" {
if req.CurrentPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is required to update password"})
return
}
// 验证当前密码
if err := h.service.VerifyPassword(c.Request.Context(), userID.(int), req.CurrentPassword); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current password"})
return
}
}
// 更新用户信息
user, err := h.service.UpdateUser(c.Request.Context(), userID.(int), &types.UpdateUserInput{
Email: req.Email,
Password: req.NewPassword,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user information"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -78,26 +79,41 @@ func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
// 配置文件日志 // 配置文件日志
if config.EnableFile { if config.EnableFile {
// 确保日志目录存在 // 验证文件路径
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil { if config.FilePath == "" {
return nil, fmt.Errorf("file path cannot be empty")
}
// 验证路径是否包含无效字符
if strings.ContainsAny(config.FilePath, "\x00") {
return nil, fmt.Errorf("file path contains invalid characters")
}
dir := filepath.Dir(config.FilePath)
// 检查目录是否存在或是否可以创建
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err) return nil, fmt.Errorf("failed to create log directory: %w", err)
} }
// 配置日志轮转 // 尝试打开或创建文件,验证路径是否有效且有写入权限
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, fmt.Errorf("failed to open or create log file: %w", err)
}
file.Close()
// 配置文件日志
logWriter = &lumberjack.Logger{ logWriter = &lumberjack.Logger{
Filename: config.FilePath, Filename: config.FilePath,
MaxSize: config.Rotation.MaxSize, // MB MaxSize: config.Rotation.MaxSize, // MB
MaxAge: config.Rotation.MaxAge, // days MaxBackups: config.Rotation.MaxBackups, // 文件个数
MaxBackups: config.Rotation.MaxBackups, // files MaxAge: config.Rotation.MaxAge, // 天数
Compress: config.Rotation.Compress, // 是否压缩 Compress: config.Rotation.Compress, // 是否压缩
LocalTime: config.Rotation.LocalTime, // 使用本地时间 LocalTime: config.Rotation.LocalTime, // 使用本地时间
} }
logger := zerolog.New(logWriter). logger := zerolog.New(logWriter).With().Timestamp().Logger()
With().
Timestamp().
Logger()
fileLogger = &logger fileLogger = &logger
} }

View file

@ -219,7 +219,7 @@ func TestAccessLogInvalidConfig(t *testing.T) {
name: "Invalid file path", name: "Invalid file path",
config: &types.AccessLogConfig{ config: &types.AccessLogConfig{
EnableFile: true, EnableFile: true,
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径 FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
}, },
expectedError: true, expectedError: true,
}, },

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/role" "tss-rocks-be/ent/role"
) )
@ -38,7 +39,15 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
permissionMap := make(map[string]*ent.Permission) permissionMap := make(map[string]*ent.Permission)
for resource, actions := range DefaultPermissions { for resource, actions := range DefaultPermissions {
for _, action := range actions { for _, action := range actions {
permission, err := client.Permission.Create(). key := fmt.Sprintf("%s:%s", resource, action)
permission, err := client.Permission.Query().
Where(
permission.ResourceEQ(resource),
permission.ActionEQ(action),
).
Only(ctx)
if ent.IsNotFound(err) {
permission, err = client.Permission.Create().
SetResource(resource). SetResource(resource).
SetAction(action). SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)). SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
@ -46,13 +55,19 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
if err != nil { if err != nil {
return fmt.Errorf("failed creating permission: %w", err) return fmt.Errorf("failed creating permission: %w", err)
} }
key := fmt.Sprintf("%s:%s", resource, action) } else if err != nil {
return fmt.Errorf("failed querying permission: %w", err)
}
permissionMap[key] = permission permissionMap[key] = permission
} }
} }
// Create roles with permissions // Create roles with permissions
for roleName, permissions := range DefaultRoles { for roleName, permissions := range DefaultRoles {
role, err := client.Role.Query().
Where(role.NameEQ(roleName)).
Only(ctx)
if ent.IsNotFound(err) {
roleCreate := client.Role.Create(). roleCreate := client.Role.Create().
SetName(roleName). SetName(roleName).
SetDescription(fmt.Sprintf("Role for %s users", roleName)) SetDescription(fmt.Sprintf("Role for %s users", roleName))
@ -70,6 +85,24 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
if _, err := roleCreate.Save(ctx); err != nil { if _, err := roleCreate.Save(ctx); err != nil {
return fmt.Errorf("failed creating role %s: %w", roleName, err) return fmt.Errorf("failed creating role %s: %w", roleName, err)
} }
} else if err != nil {
return fmt.Errorf("failed querying role: %w", err)
} else {
// Update existing role's permissions
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
err = client.Role.UpdateOne(role).
AddPermissions(permission).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed updating role %s permissions: %w", roleName, err)
}
}
}
}
}
} }
return nil return nil

View file

@ -64,6 +64,7 @@ func TestAssignRoleToUser(t *testing.T) {
// Create a test user // Create a test user
user, err := client.User.Create(). user, err := client.User.Create().
SetEmail("test@example.com"). SetEmail("test@example.com").
SetUsername("testuser").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy"). SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx) Save(ctx)
if err != nil { if err != nil {

View file

@ -3,25 +3,21 @@ package server
import ( import (
"context" "context"
"entgo.io/ent/dialect/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
) )
// NewEntClient creates a new ent client // NewEntClient creates a new ent client
func NewEntClient(cfg *config.Config) *ent.Client { func NewEntClient(cfg *config.Config) *ent.Client {
// TODO: Implement database connection based on config // 使用配置文件中的数据库设置
// For now, we'll use SQLite for development client, err := ent.Open(cfg.Database.Driver, cfg.Database.DSN)
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to connect to database") log.Fatal().Err(err).Msg("Failed to connect to database")
} }
// Create ent client
client := ent.NewClient(ent.Driver(db))
// Run the auto migration tool // Run the auto migration tool
if err := client.Schema.Create(context.Background()); err != nil { if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources") log.Fatal().Err(err).Msg("Failed to create schema resources")

View file

@ -24,7 +24,6 @@ import (
"tss-rocks-be/ent/role" "tss-rocks-be/ent/role"
"tss-rocks-be/ent/user" "tss-rocks-be/ent/user"
"tss-rocks-be/internal/storage" "tss-rocks-be/internal/storage"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -54,59 +53,74 @@ func NewService(client *ent.Client, storage storage.Storage) Service {
} }
// User operations // User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) { func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
// Hash the password // 验证邮箱格式
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if !emailRegex.MatchString(email) {
return nil, fmt.Errorf("invalid email format")
}
// 验证密码长度
if len(password) < 8 {
return nil, fmt.Errorf("password must be at least 8 characters")
}
// 检查用户名是否已存在
exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
if err != nil {
return nil, fmt.Errorf("error checking username: %v", err)
}
if exists {
return nil, fmt.Errorf("username '%s' already exists", username)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err) return nil, fmt.Errorf("error hashing password: %v", err)
} }
// Add the user role by default // 创建用户
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx) u, err := s.client.User.Create().
if err != nil { SetUsername(username).
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
SetEmail(email). SetEmail(email).
SetPasswordHash(string(hashedPassword)). SetPasswordHash(string(hashedPassword)).
AddRoles(userRole) SetStatus("active").
Save(ctx)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err) return nil, fmt.Errorf("error creating user: %v", err)
} }
return user, nil // 分配角色
err = s.AssignRole(ctx, u.ID, roleStr)
if err != nil {
return nil, fmt.Errorf("error assigning role: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user with username '%s' not found", username)
}
return nil, fmt.Errorf("error getting user: %v", err)
}
return u, nil
} }
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) { func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query(). u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
Where(user.EmailEQ(email)).
Only(ctx)
if err != nil { if err != nil {
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email) return nil, fmt.Errorf("user with email '%s' not found", email)
} }
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("error getting user: %v", err)
} }
return user, nil return u, nil
} }
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool { func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {

View file

@ -104,51 +104,49 @@ func newMockMultipartFile(data []byte) *mockMultipartFile {
func (s *ServiceImplTestSuite) TestCreateUser() { func (s *ServiceImplTestSuite) TestCreateUser() {
testCases := []struct { testCases := []struct {
name string name string
username string
email string email string
password string password string
role string role string
wantError bool wantErr bool
}{ }{
{ {
name: "Valid user creation", name: "有效的用户",
username: "testuser",
email: "test@example.com", email: "test@example.com",
password: "password123", password: "password123",
role: "admin",
wantError: false,
},
{
name: "Empty email",
email: "",
password: "password123",
role: "user", role: "user",
wantError: true, wantErr: false,
}, },
{ {
name: "Empty password", name: "无效的邮箱",
email: "test@example.com", username: "testuser2",
email: "invalid-email",
password: "password123",
role: "user",
wantErr: true,
},
{
name: "空密码",
username: "testuser3",
email: "test3@example.com",
password: "", password: "",
role: "user", role: "user",
wantError: true, wantErr: true,
},
{
name: "Invalid role",
email: "test@example.com",
password: "password123",
role: "invalid_role",
wantError: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role) user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
if tc.wantError { if tc.wantErr {
assert.Error(s.T(), err) s.Error(err)
assert.Nil(s.T(), user) s.Nil(user)
} else { } else {
assert.NoError(s.T(), err) s.NoError(err)
assert.NotNil(s.T(), user) s.NotNil(user)
assert.Equal(s.T(), tc.email, user.Email) s.Equal(tc.email, user.Email)
s.Equal(tc.username, user.Username)
} }
}) })
} }
@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() {
password := "password123" password := "password123"
role := "user" role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role) user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
password := "password123" password := "password123"
role := "user" role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role) user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
func (s *ServiceImplTestSuite) TestRBAC() { func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("AssignRole", func() { s.Run("AssignRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.AssignRole(s.ctx, user.ID, "user") err = s.svc.AssignRole(s.ctx, user.ID, "user")
@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("RemoveRole", func() { s.Run("RemoveRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.RemoveRole(s.ctx, user.ID, "admin") err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("HasPermission", func() { s.Run("HasPermission", func() {
s.Run("Admin can create users", func() { s.Run("Admin can create users", func() {
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Editor cannot create users", func() { s.Run("Editor cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor") user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User cannot create users", func() { s.Run("User cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Editor can create posts", func() { s.Run("Editor can create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor") user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User can read posts", func() { s.Run("User can read posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User cannot create posts", func() { s.Run("User cannot create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Invalid permission format", func() { s.Run("Invalid permission format", func() {
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission") _, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
func (s *ServiceImplTestSuite) TestCategory() { func (s *ServiceImplTestSuite) TestCategory() {
// Create a test user with admin role for testing // Create a test user with admin role for testing
adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin") adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), adminUser) require.NotNil(s.T(), adminUser)
@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() {
ctx := context.Background() ctx := context.Background()
// 创建测试用户,默认会有 "user" 角色 // 创建测试用户,默认会有 "user" 角色
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user") user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
s.Require().NoError(err) s.Require().NoError(err)
// 测试新用户有默认的 "user" 角色 // 测试新用户有默认的 "user" 角色
@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() {
func (s *ServiceImplTestSuite) TestMedia() { func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Upload Media", func() { s.Run("Upload Media", func() {
// Create a user first // Create a user first
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "") user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Delete Media - Unauthorized", func() { s.Run("Delete Media - Unauthorized", func() {
// Create a user // Create a user
user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "") user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
// Mock file content // Mock file content
@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
// Try to delete with different user // Try to delete with different user
anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "") anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID) err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)

View file

@ -9,14 +9,22 @@ import (
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/storage" "tss-rocks-be/internal/storage"
"tss-rocks-be/internal/types"
) )
// Service interface defines all business logic operations // Service interface defines all business logic operations
type Service interface { type Service interface {
// User operations // User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error) CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error)
GetUser(ctx context.Context, id int) (*ent.User, error)
GetUserByUsername(ctx context.Context, username string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error) GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool ValidatePassword(ctx context.Context, user *ent.User, password string) bool
VerifyPassword(ctx context.Context, userID int, password string) error
UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error)
DeleteUser(ctx context.Context, userID int) error
ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error)
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
// Category operations // Category operations
CreateCategory(ctx context.Context) (*ent.Category, error) CreateCategory(ctx context.Context) (*ent.Category, error)
@ -51,9 +59,8 @@ type Service interface {
DeleteMedia(ctx context.Context, id int, userID int) error DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations // RBAC operations
InitializeRBAC(ctx context.Context) error
AssignRole(ctx context.Context, userID int, role string) error AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error) HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
} }

View file

@ -0,0 +1,154 @@
package service
import (
"context"
"fmt"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/types"
)
// GetUser gets a user by ID
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
user, err := s.client.User.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// VerifyPassword verifies the user's current password
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
user, err := s.GetUser(ctx, userID)
if err != nil {
return err
}
if !s.ValidatePassword(ctx, user, password) {
return fmt.Errorf("invalid password")
}
return nil
}
// UpdateUser updates user information
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
// Start building the update
update := s.client.User.UpdateOneID(userID)
// Update email if provided
if input.Email != "" {
update.SetEmail(input.Email)
}
// Update display name if provided
if input.DisplayName != "" {
update.SetDisplayName(input.DisplayName)
}
// Update password if provided
if input.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
log.Error().Err(err).Msg("Failed to hash password")
return nil, fmt.Errorf("failed to hash password: %w", err)
}
update.SetPasswordHash(string(hashedPassword))
}
// Update status if provided
if input.Status != "" {
update.SetStatus(user.Status(input.Status))
}
// Execute the update
user, err := update.Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
return nil, fmt.Errorf("failed to update user: %w", err)
}
// Update role if provided
if input.Role != "" {
// Clear existing roles
_, err = user.Update().ClearRoles().Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to clear user roles")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
// Add new role
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to find role")
return nil, fmt.Errorf("failed to find role: %w", err)
}
_, err = user.Update().AddRoles(role).Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to add user role")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
}
return user, nil
}
// DeleteUser deletes a user by ID
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
err := s.client.User.DeleteOneID(userID).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// ListUsers lists users with filters and pagination
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
query := s.client.User.Query()
// Apply filters
if params.Role != "" {
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
}
if params.Status != "" {
query.Where(user.StatusEQ(user.Status(params.Status)))
}
if params.Email != "" {
query.Where(user.EmailContains(params.Email))
}
// Apply pagination
if params.PerPage > 0 {
query.Limit(params.PerPage)
if params.Page > 0 {
query.Offset((params.Page - 1) * params.PerPage)
}
}
// Apply sorting
if params.Sort != "" {
switch params.Sort {
case "email_asc":
query.Order(ent.Asc(user.FieldEmail))
case "email_desc":
query.Order(ent.Desc(user.FieldEmail))
case "created_at_asc":
query.Order(ent.Asc(user.FieldCreatedAt))
case "created_at_desc":
query.Order(ent.Desc(user.FieldCreatedAt))
}
}
// Execute query
users, err := query.All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
return users, nil
}

View file

@ -0,0 +1,28 @@
package types
// UpdateUserInput defines the input for updating a user
type UpdateUserInput struct {
Email string
Password string
Role string
Status string
DisplayName string
}
// ListUsersParams defines the parameters for listing users
type ListUsersParams struct {
Page int
PerPage int
Sort string
Role string
Status string
Email string
}
// CreateUserInput defines the input for creating a user
type CreateUserInput struct {
Email string
Password string
Role string
DisplayName string
}

29
backend/main.go Normal file
View file

@ -0,0 +1,29 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"tss-rocks-be/cmd"
)
var rootCmd = &cobra.Command{
Use: "app",
Short: "TSS backend application",
Long: `TSS (The Starset Society) backend application.
It provides both API endpoints and admin tools for content management.`,
}
func init() {
// 添加子命令
rootCmd.AddCommand(cmd.GetUserCmd())
rootCmd.AddCommand(cmd.GetServerCmd())
}
func main() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}