Compare commits

..

8 commits

Author SHA1 Message Date
CDN
1526c27b49
[feature/frontend] admin panel (wip)
Some checks failed
Build Backend / Build Docker Image (push) Failing after 2m38s
Test Backend / test (push) Successful in 4m19s
2025-02-21 07:55:26 +08:00
CDN
34ebb05808
[bugfix/backend] use int64 in userid 2025-02-21 06:01:26 +08:00
CDN
287895347b
[bugfix/backend] use roles in middleware 2025-02-21 05:58:42 +08:00
CDN
e5fc8691bf
[feature/backend] implement /auth/logout handling + overall enhancement 2025-02-21 05:44:18 +08:00
CDN
d8d8e4b0d7
[bugfix/backend] /users/me handling 2025-02-21 04:55:42 +08:00
CDN
823bedd1fa
[feature/backend] implement /users auth middleware 2025-02-21 04:42:53 +08:00
CDN
a853374009
[feature/backend] registration control 2025-02-21 04:38:47 +08:00
CDN
86ab334bc9
[feature/backend] implement /users handler + switch to username + add display name + user management cli 2025-02-21 04:30:07 +08:00
64 changed files with 3686 additions and 680 deletions

View file

@ -60,12 +60,20 @@ User:
type: object type: object
required: required:
- id - id
- email - username
- role - role
- status - status
properties: properties:
id: id:
type: integer type: integer
username:
type: string
minLength: 3
maxLength: 32
display_name:
type: string
maxLength: 64
description: 用户显示名称
email: email:
type: string type: string
format: email format: email
@ -74,11 +82,13 @@ User:
enum: enum:
- admin - admin
- editor - editor
- contributor
status: status:
type: string type: string
enum: enum:
- active - active
- inactive - inactive
- banned
created_at: created_at:
type: string type: string
format: date-time format: date-time

View file

@ -70,6 +70,8 @@ components:
Daily: Daily:
$ref: './components/schemas.yaml#/Daily' $ref: './components/schemas.yaml#/Daily'
paths: paths:
/auth/register:
$ref: './paths/auth.yaml#/register'
/auth/login: /auth/login:
$ref: './paths/auth.yaml#/login' $ref: './paths/auth.yaml#/login'
/auth/logout: /auth/logout:

View file

@ -1,3 +1,63 @@
register:
post:
tags:
- auth
summary: 用户注册
operationId: register
security: [] # 注册接口不需要认证
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- username
- email
- password
- role
properties:
username:
type: string
minLength: 3
maxLength: 32
email:
type: string
format: email
password:
type: string
format: password
minLength: 8
role:
type: string
enum:
- admin
- editor
- contributor
responses:
'200':
description: 注册成功
content:
application/json:
schema:
type: object
required:
- token
- user
properties:
token:
type: string
user:
$ref: '../components/schemas.yaml#/User'
'400':
description: 用户名已存在
content:
application/json:
schema:
$ref: '../components/schemas.yaml#/Error'
'422':
$ref: '../components/responses.yaml#/ValidationError'
login: login:
post: post:
tags: tags:
@ -12,12 +72,13 @@ login:
schema: schema:
type: object type: object
required: required:
- email - username
- password - password
properties: properties:
email: username:
type: string type: string
format: email minLength: 3
maxLength: 32
password: password:
type: string type: string
format: password format: password

3
backend/.gitignore vendored
View file

@ -12,6 +12,9 @@ vendor/
# Build output # Build output
*.exe *.exe
# Config
config/config.yaml
# Database files # Database files
*.db *.db
*.db-journal *.db-journal

View file

@ -5,7 +5,7 @@ RUN apk add --no-cache gcc musl-dev libwebp-dev
COPY go.mod go.sum ./ COPY go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
RUN go build -o tss-rocks-be ./cmd/server RUN go build -o tss-rocks-be
FROM alpine:latest FROM alpine:latest
@ -13,9 +13,13 @@ RUN apk add --no-cache libwebp
RUN adduser -u 1000 -D tss-rocks RUN adduser -u 1000 -D tss-rocks
USER tss-rocks USER tss-rocks
WORKDIR /app WORKDIR /app
# 复制二进制文件和配置
COPY --from=builder /app/tss-rocks-be . COPY --from=builder /app/tss-rocks-be .
COPY --from=builder /app/config/config.yaml ./config/
EXPOSE 8080 EXPOSE 8080
ENV GIN_MODE=release ENV GIN_MODE=release
CMD ["./tss-rocks-be"] # 启动服务器
CMD ["./tss-rocks-be", "server"]

89
backend/cmd/server.go Normal file
View file

@ -0,0 +1,89 @@
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func GetServerCmd() *cobra.Command {
var configPath string
cmd := &cobra.Command{
Use: "server",
Short: "Start the server",
Long: `Start the server with the specified configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
// Load configuration
cfg, err := config.Load(configPath)
if err != nil {
return err
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
return nil
},
}
// Add flags
cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file")
return cmd
}

View file

@ -1,79 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func main() {
// Parse command line flags
configPath := flag.String("config", "config/config.yaml", "path to config file")
flag.Parse()
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
}

View file

@ -1,127 +0,0 @@
package main
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
)
func TestConfigLoad(t *testing.T) {
// Create a temporary config file for testing
tmpConfig := `
database:
driver: sqlite3
dsn: ":memory:"
server:
port: 8080
host: localhost
storage:
type: local
local:
root_dir: "./testdata"
`
tmpFile, err := os.CreateTemp("", "config-*.yaml")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.WriteString(tmpConfig)
require.NoError(t, err)
err = tmpFile.Close()
require.NoError(t, err)
// Test config loading
cfg, err := config.Load(tmpFile.Name())
require.NoError(t, err)
assert.Equal(t, "sqlite3", cfg.Database.Driver)
assert.Equal(t, ":memory:", cfg.Database.DSN)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, "localhost", cfg.Server.Host)
assert.Equal(t, "local", cfg.Storage.Type)
assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir)
}
func TestServerCreation(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 8080,
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
// Create ent client
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
// Test schema creation
err := client.Schema.Create(context.Background())
require.NoError(t, err)
// Test server creation
srv, err := server.New(cfg, client)
require.NoError(t, err)
require.NotNil(t, srv)
}
func TestServerStartAndShutdown(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 0, // Use random available port
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
err := client.Schema.Create(context.Background())
require.NoError(t, err)
srv, err := server.New(cfg, client)
require.NoError(t, err)
// Start server in goroutine
go func() {
err := srv.Start()
if err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Test graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
assert.NoError(t, err)
}

411
backend/cmd/user.go Normal file
View file

@ -0,0 +1,411 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"tss-rocks-be/ent/user"
"tss-rocks-be/ent/role"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
func GetUserCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "user",
Short: "User management commands",
Long: `Commands for managing users, including create, delete, password, and role management.`,
}
configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file")
cmd.AddCommand(getCreateUserCmd(*configPath))
cmd.AddCommand(getPasswordCmd(*configPath))
cmd.AddCommand(getListUsersCmd(*configPath))
cmd.AddCommand(getDeleteUserCmd(*configPath))
cmd.AddCommand(getRoleCmd(*configPath))
return cmd
}
func getCreateUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "create",
Short: "Create a new user",
Long: `Create a new user with specified username, email and password.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
email, _ := cmd.Flags().GetString("email")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 检查用户名是否已存在
exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking username: %v", err)
}
if exists {
return fmt.Errorf("username '%s' already exists", username)
}
// 检查邮箱是否已存在
exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking email: %v", err)
}
if exists {
return fmt.Errorf("email '%s' already exists", email)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 创建用户
u, err := client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
SetStatus("active").
Save(cmd.Context())
if err != nil {
return fmt.Errorf("error creating user: %v", err)
}
fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username for the new user")
cmd.Flags().String("email", "", "email for the new user")
cmd.Flags().String("password", "", "password for the new user")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("email")
cmd.MarkFlagRequired("password")
return cmd
}
func getPasswordCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "password",
Short: "Change user password",
Long: `Change the password for an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 生成新的密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 更新密码
_, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context())
if err != nil {
return fmt.Errorf("error updating password: %v", err)
}
fmt.Printf("Successfully updated password for user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("password", "", "new password")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("password")
return cmd
}
func getListUsersCmd(configPath string) *cobra.Command {
return &cobra.Command{
Use: "list",
Short: "List all users",
Long: `List all users in the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 获取所有用户
users, err := client.User.Query().All(cmd.Context())
if err != nil {
return fmt.Errorf("error querying users: %v", err)
}
// 创建表格写入器
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT")
// 输出用户信息
for _, u := range users {
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
u.Username,
u.Email,
u.Status,
u.CreatedAt.Format("2006-01-02 15:04:05"))
}
w.Flush()
return nil
},
}
}
func getDeleteUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Long: `Delete an existing user from the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
force, _ := cmd.Flags().GetBool("force")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查用户的角色
hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRoles && !force {
return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username)
}
// 删除用户
err = client.User.DeleteOne(u).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error deleting user: %v", err)
}
fmt.Printf("Successfully deleted user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user to delete")
cmd.Flags().Bool("force", false, "force delete even if user has roles")
// 设置必需的参数
cmd.MarkFlagRequired("username")
return cmd
}
func getRoleCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "role",
Short: "User role management commands",
Long: `Commands for managing user roles, including adding and removing roles.`,
}
cmd.AddCommand(getRoleAddCmd(configPath))
cmd.AddCommand(getRoleRemoveCmd(configPath))
return cmd
}
func getRoleAddCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "add",
Short: "Add a role to a user",
Long: `Add a role to an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否已存在
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRole {
return fmt.Errorf("user '%s' already has role '%s'", username, roleName)
}
// 获取要添加的角色
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 添加角色
err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error adding role: %v", err)
}
fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to add (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}
func getRoleRemoveCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "remove",
Short: "Remove a role from a user",
Long: `Remove a role from an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否存在
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 检查用户是否有该角色
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if !hasRole {
return fmt.Errorf("user '%s' does not have role '%s'", username, roleName)
}
// 移除角色
err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error removing role: %v", err)
}
fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}

View file

@ -1,91 +0,0 @@
database:
driver: sqlite3
dsn: "file:tss-rocks.db?_fk=1&cache=shared"
server:
port: 8080
host: localhost
jwt:
secret: "your-secret-key-here" # 在生产环境中应该使用环境变量
expiration: 24h # token 过期时间
logging:
level: debug # debug, info, warn, error
format: json # json, console
storage:
type: s3 # local or s3
local:
root_dir: "./storage/media"
s3:
region: "us-east-1"
bucket: "your-bucket-name"
access_key_id: "your-access-key-id"
secret_access_key: "your-secret-access-key"
endpoint: "" # Optional, for MinIO or other S3-compatible services
custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media)
proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting
upload:
max_size: 10 # 最大文件大小MB
allowed_types: # 允许的文件类型
- image/jpeg
- image/png
- image/gif
- image/webp
- image/svg+xml
- application/pdf
- application/msword
- application/vnd.openxmlformats-officedocument.wordprocessingml.document
- application/vnd.ms-excel
- application/vnd.openxmlformats-officedocument.spreadsheetml.sheet
- application/zip
- application/x-rar-compressed
- text/plain
- text/csv
allowed_extensions: # 允许的文件扩展名(小写)
- .jpg
- .jpeg
- .png
- .gif
- .webp
- .svg
- .pdf
- .doc
- .docx
- .xls
- .xlsx
- .zip
- .rar
- .txt
- .csv
rate_limit:
# IP限流配置
ip_rate: 50 # 每秒请求数
ip_burst: 25 # 突发请求数
# 路由限流配置
route_rates:
"/api/v1/auth/login":
rate: 5 # 每秒5个请求
burst: 10 # 突发10个请求
"/api/v1/auth/register":
rate: 2 # 每秒2个请求
burst: 5 # 突发5个请求
"/api/v1/media/upload":
rate: 10 # 每秒10个请求
burst: 20 # 突发20个请求
access_log:
enable_console: true # 启用控制台输出
enable_file: true # 启用文件日志
file_path: "./logs/access.log" # 日志文件路径
format: "json" # 日志格式 (json 或 text)
level: "info" # 日志级别
rotation:
max_size: 100 # 每个日志文件的最大大小MB
max_age: 30 # 保留旧日志文件的最大天数
max_backups: 10 # 保留的旧日志文件的最大数量
compress: true # 是否压缩旧日志文件
local_time: true # 使用本地时间作为轮转时间

View file

@ -0,0 +1,41 @@
database:
driver: sqlite3
# SQLite DSN 说明:
# - file:tss.db => 相对路径,在当前目录创建 tss.db
# - cache=shared => 启用共享缓存,提高性能
# - _fk=1 => 启用外键约束
# - mode=rwc => 如果数据库不存在则创建read-write-create
dsn: "file:tss.db?cache=shared&_fk=1&mode=rwc"
server:
port: 8080
host: localhost
jwt:
secret: your-jwt-secret-here # 在生产环境中应该使用环境变量
expiration: 24h
auth:
registration:
enabled: false # 是否允许注册
message: "Registration is currently disabled. Please contact administrator." # 禁用时的提示信息
storage:
driver: local
local:
root: storage
base_url: http://localhost:8080/storage
logging:
level: debug
format: console
rate_limit:
enabled: true
requests: 100
duration: 1m
access_log:
enabled: true
format: combined
output: stdout

View file

@ -373,7 +373,9 @@ var (
// UsersColumns holds the columns for the "users" table. // UsersColumns holds the columns for the "users" table.
UsersColumns = []*schema.Column{ UsersColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt, Increment: true}, {Name: "id", Type: field.TypeInt, Increment: true},
{Name: "email", Type: field.TypeString, Unique: true}, {Name: "username", Type: field.TypeString, Unique: true},
{Name: "display_name", Type: field.TypeString, Nullable: true, Size: 64},
{Name: "email", Type: field.TypeString},
{Name: "password_hash", Type: field.TypeString}, {Name: "password_hash", Type: field.TypeString},
{Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"}, {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"},
{Name: "created_at", Type: field.TypeTime}, {Name: "created_at", Type: field.TypeTime},

View file

@ -9321,6 +9321,8 @@ type UserMutation struct {
op Op op Op
typ string typ string
id *int id *int
username *string
display_name *string
email *string email *string
password_hash *string password_hash *string
status *user.Status status *user.Status
@ -9439,6 +9441,91 @@ func (m *UserMutation) IDs(ctx context.Context) ([]int, error) {
} }
} }
// SetUsername sets the "username" field.
func (m *UserMutation) SetUsername(s string) {
m.username = &s
}
// Username returns the value of the "username" field in the mutation.
func (m *UserMutation) Username() (r string, exists bool) {
v := m.username
if v == nil {
return
}
return *v, true
}
// OldUsername returns the old "username" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldUsername(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUsername is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUsername requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUsername: %w", err)
}
return oldValue.Username, nil
}
// ResetUsername resets all changes to the "username" field.
func (m *UserMutation) ResetUsername() {
m.username = nil
}
// SetDisplayName sets the "display_name" field.
func (m *UserMutation) SetDisplayName(s string) {
m.display_name = &s
}
// DisplayName returns the value of the "display_name" field in the mutation.
func (m *UserMutation) DisplayName() (r string, exists bool) {
v := m.display_name
if v == nil {
return
}
return *v, true
}
// OldDisplayName returns the old "display_name" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldDisplayName(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldDisplayName is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldDisplayName requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldDisplayName: %w", err)
}
return oldValue.DisplayName, nil
}
// ClearDisplayName clears the value of the "display_name" field.
func (m *UserMutation) ClearDisplayName() {
m.display_name = nil
m.clearedFields[user.FieldDisplayName] = struct{}{}
}
// DisplayNameCleared returns if the "display_name" field was cleared in this mutation.
func (m *UserMutation) DisplayNameCleared() bool {
_, ok := m.clearedFields[user.FieldDisplayName]
return ok
}
// ResetDisplayName resets all changes to the "display_name" field.
func (m *UserMutation) ResetDisplayName() {
m.display_name = nil
delete(m.clearedFields, user.FieldDisplayName)
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (m *UserMutation) SetEmail(s string) { func (m *UserMutation) SetEmail(s string) {
m.email = &s m.email = &s
@ -9815,7 +9902,13 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UserMutation) Fields() []string { func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 5) fields := make([]string, 0, 7)
if m.username != nil {
fields = append(fields, user.FieldUsername)
}
if m.display_name != nil {
fields = append(fields, user.FieldDisplayName)
}
if m.email != nil { if m.email != nil {
fields = append(fields, user.FieldEmail) fields = append(fields, user.FieldEmail)
} }
@ -9839,6 +9932,10 @@ func (m *UserMutation) Fields() []string {
// schema. // schema.
func (m *UserMutation) Field(name string) (ent.Value, bool) { func (m *UserMutation) Field(name string) (ent.Value, bool) {
switch name { switch name {
case user.FieldUsername:
return m.Username()
case user.FieldDisplayName:
return m.DisplayName()
case user.FieldEmail: case user.FieldEmail:
return m.Email() return m.Email()
case user.FieldPasswordHash: case user.FieldPasswordHash:
@ -9858,6 +9955,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
// database failed. // database failed.
func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name { switch name {
case user.FieldUsername:
return m.OldUsername(ctx)
case user.FieldDisplayName:
return m.OldDisplayName(ctx)
case user.FieldEmail: case user.FieldEmail:
return m.OldEmail(ctx) return m.OldEmail(ctx)
case user.FieldPasswordHash: case user.FieldPasswordHash:
@ -9877,6 +9978,20 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
// type. // type.
func (m *UserMutation) SetField(name string, value ent.Value) error { func (m *UserMutation) SetField(name string, value ent.Value) error {
switch name { switch name {
case user.FieldUsername:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUsername(v)
return nil
case user.FieldDisplayName:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetDisplayName(v)
return nil
case user.FieldEmail: case user.FieldEmail:
v, ok := value.(string) v, ok := value.(string)
if !ok { if !ok {
@ -9941,7 +10056,11 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this // ClearedFields returns all nullable fields that were cleared during this
// mutation. // mutation.
func (m *UserMutation) ClearedFields() []string { func (m *UserMutation) ClearedFields() []string {
return nil var fields []string
if m.FieldCleared(user.FieldDisplayName) {
fields = append(fields, user.FieldDisplayName)
}
return fields
} }
// FieldCleared returns a boolean indicating if a field with the given name was // FieldCleared returns a boolean indicating if a field with the given name was
@ -9954,6 +10073,11 @@ func (m *UserMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an // ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema. // error if the field is not defined in the schema.
func (m *UserMutation) ClearField(name string) error { func (m *UserMutation) ClearField(name string) error {
switch name {
case user.FieldDisplayName:
m.ClearDisplayName()
return nil
}
return fmt.Errorf("unknown User nullable field %s", name) return fmt.Errorf("unknown User nullable field %s", name)
} }
@ -9961,6 +10085,12 @@ func (m *UserMutation) ClearField(name string) error {
// It returns an error if the field is not defined in the schema. // It returns an error if the field is not defined in the schema.
func (m *UserMutation) ResetField(name string) error { func (m *UserMutation) ResetField(name string) error {
switch name { switch name {
case user.FieldUsername:
m.ResetUsername()
return nil
case user.FieldDisplayName:
m.ResetDisplayName()
return nil
case user.FieldEmail: case user.FieldEmail:
m.ResetEmail() m.ResetEmail()
return nil return nil

View file

@ -261,20 +261,28 @@ func init() {
role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time) role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time)
userFields := schema.User{}.Fields() userFields := schema.User{}.Fields()
_ = userFields _ = userFields
// userDescUsername is the schema descriptor for username field.
userDescUsername := userFields[0].Descriptor()
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescDisplayName is the schema descriptor for display_name field.
userDescDisplayName := userFields[1].Descriptor()
// user.DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
user.DisplayNameValidator = userDescDisplayName.Validators[0].(func(string) error)
// userDescEmail is the schema descriptor for email field. // userDescEmail is the schema descriptor for email field.
userDescEmail := userFields[0].Descriptor() userDescEmail := userFields[2].Descriptor()
// user.EmailValidator is a validator for the "email" field. It is called by the builders before save. // user.EmailValidator is a validator for the "email" field. It is called by the builders before save.
user.EmailValidator = userDescEmail.Validators[0].(func(string) error) user.EmailValidator = userDescEmail.Validators[0].(func(string) error)
// userDescPasswordHash is the schema descriptor for password_hash field. // userDescPasswordHash is the schema descriptor for password_hash field.
userDescPasswordHash := userFields[1].Descriptor() userDescPasswordHash := userFields[3].Descriptor()
// user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. // user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error) user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error)
// userDescCreatedAt is the schema descriptor for created_at field. // userDescCreatedAt is the schema descriptor for created_at field.
userDescCreatedAt := userFields[3].Descriptor() userDescCreatedAt := userFields[5].Descriptor()
// user.DefaultCreatedAt holds the default value on creation for the created_at field. // user.DefaultCreatedAt holds the default value on creation for the created_at field.
user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time)
// userDescUpdatedAt is the schema descriptor for updated_at field. // userDescUpdatedAt is the schema descriptor for updated_at field.
userDescUpdatedAt := userFields[4].Descriptor() userDescUpdatedAt := userFields[6].Descriptor()
// user.DefaultUpdatedAt holds the default value on creation for the updated_at field. // user.DefaultUpdatedAt holds the default value on creation for the updated_at field.
user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time) user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time)
// user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. // user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.

View file

@ -15,9 +15,14 @@ type User struct {
// Fields of the User. // Fields of the User.
func (User) Fields() []ent.Field { func (User) Fields() []ent.Field {
return []ent.Field{ return []ent.Field{
field.String("email"). field.String("username").
Unique(). Unique().
NotEmpty(), NotEmpty(),
field.String("display_name").
Optional().
MaxLen(64),
field.String("email").
NotEmpty(),
field.String("password_hash"). field.String("password_hash").
Sensitive(). Sensitive().
NotEmpty(), NotEmpty(),

View file

@ -17,6 +17,10 @@ type User struct {
config `json:"-"` config `json:"-"`
// ID of the ent. // ID of the ent.
ID int `json:"id,omitempty"` ID int `json:"id,omitempty"`
// Username holds the value of the "username" field.
Username string `json:"username,omitempty"`
// DisplayName holds the value of the "display_name" field.
DisplayName string `json:"display_name,omitempty"`
// Email holds the value of the "email" field. // Email holds the value of the "email" field.
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
// PasswordHash holds the value of the "password_hash" field. // PasswordHash holds the value of the "password_hash" field.
@ -80,7 +84,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case user.FieldID: case user.FieldID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldStatus: case user.FieldUsername, user.FieldDisplayName, user.FieldEmail, user.FieldPasswordHash, user.FieldStatus:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt: case user.FieldCreatedAt, user.FieldUpdatedAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -105,6 +109,18 @@ func (u *User) assignValues(columns []string, values []any) error {
return fmt.Errorf("unexpected type %T for field id", value) return fmt.Errorf("unexpected type %T for field id", value)
} }
u.ID = int(value.Int64) u.ID = int(value.Int64)
case user.FieldUsername:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field username", values[i])
} else if value.Valid {
u.Username = value.String
}
case user.FieldDisplayName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field display_name", values[i])
} else if value.Valid {
u.DisplayName = value.String
}
case user.FieldEmail: case user.FieldEmail:
if value, ok := values[i].(*sql.NullString); !ok { if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field email", values[i]) return fmt.Errorf("unexpected type %T for field email", values[i])
@ -186,6 +202,12 @@ func (u *User) String() string {
var builder strings.Builder var builder strings.Builder
builder.WriteString("User(") builder.WriteString("User(")
builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) builder.WriteString(fmt.Sprintf("id=%v, ", u.ID))
builder.WriteString("username=")
builder.WriteString(u.Username)
builder.WriteString(", ")
builder.WriteString("display_name=")
builder.WriteString(u.DisplayName)
builder.WriteString(", ")
builder.WriteString("email=") builder.WriteString("email=")
builder.WriteString(u.Email) builder.WriteString(u.Email)
builder.WriteString(", ") builder.WriteString(", ")

View file

@ -15,6 +15,10 @@ const (
Label = "user" Label = "user"
// FieldID holds the string denoting the id field in the database. // FieldID holds the string denoting the id field in the database.
FieldID = "id" FieldID = "id"
// FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username"
// FieldDisplayName holds the string denoting the display_name field in the database.
FieldDisplayName = "display_name"
// FieldEmail holds the string denoting the email field in the database. // FieldEmail holds the string denoting the email field in the database.
FieldEmail = "email" FieldEmail = "email"
// FieldPasswordHash holds the string denoting the password_hash field in the database. // FieldPasswordHash holds the string denoting the password_hash field in the database.
@ -57,6 +61,8 @@ const (
// Columns holds all SQL columns for user fields. // Columns holds all SQL columns for user fields.
var Columns = []string{ var Columns = []string{
FieldID, FieldID,
FieldUsername,
FieldDisplayName,
FieldEmail, FieldEmail,
FieldPasswordHash, FieldPasswordHash,
FieldStatus, FieldStatus,
@ -81,6 +87,10 @@ func ValidColumn(column string) bool {
} }
var ( var (
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error
// DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save.
DisplayNameValidator func(string) error
// EmailValidator is a validator for the "email" field. It is called by the builders before save. // EmailValidator is a validator for the "email" field. It is called by the builders before save.
EmailValidator func(string) error EmailValidator func(string) error
// PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. // PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save.
@ -128,6 +138,16 @@ func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc() return sql.OrderByField(FieldID, opts...).ToFunc()
} }
// ByUsername orders the results by the username field.
func ByUsername(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUsername, opts...).ToFunc()
}
// ByDisplayName orders the results by the display_name field.
func ByDisplayName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDisplayName, opts...).ToFunc()
}
// ByEmail orders the results by the email field. // ByEmail orders the results by the email field.
func ByEmail(opts ...sql.OrderTermOption) OrderOption { func ByEmail(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEmail, opts...).ToFunc() return sql.OrderByField(FieldEmail, opts...).ToFunc()

View file

@ -55,6 +55,16 @@ func IDLTE(id int) predicate.User {
return predicate.User(sql.FieldLTE(FieldID, id)) return predicate.User(sql.FieldLTE(FieldID, id))
} }
// Username applies equality check predicate on the "username" field. It's identical to UsernameEQ.
func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// DisplayName applies equality check predicate on the "display_name" field. It's identical to DisplayNameEQ.
func DisplayName(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// Email applies equality check predicate on the "email" field. It's identical to EmailEQ. // Email applies equality check predicate on the "email" field. It's identical to EmailEQ.
func Email(v string) predicate.User { func Email(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v)) return predicate.User(sql.FieldEQ(FieldEmail, v))
@ -75,6 +85,146 @@ func UpdatedAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) return predicate.User(sql.FieldEQ(FieldUpdatedAt, v))
} }
// UsernameEQ applies the EQ predicate on the "username" field.
func UsernameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// UsernameNEQ applies the NEQ predicate on the "username" field.
func UsernameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldUsername, v))
}
// UsernameIn applies the In predicate on the "username" field.
func UsernameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldUsername, vs...))
}
// UsernameNotIn applies the NotIn predicate on the "username" field.
func UsernameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldUsername, vs...))
}
// UsernameGT applies the GT predicate on the "username" field.
func UsernameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldUsername, v))
}
// UsernameGTE applies the GTE predicate on the "username" field.
func UsernameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldUsername, v))
}
// UsernameLT applies the LT predicate on the "username" field.
func UsernameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldUsername, v))
}
// UsernameLTE applies the LTE predicate on the "username" field.
func UsernameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldUsername, v))
}
// UsernameContains applies the Contains predicate on the "username" field.
func UsernameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldUsername, v))
}
// UsernameHasPrefix applies the HasPrefix predicate on the "username" field.
func UsernameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldUsername, v))
}
// UsernameHasSuffix applies the HasSuffix predicate on the "username" field.
func UsernameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldUsername, v))
}
// UsernameEqualFold applies the EqualFold predicate on the "username" field.
func UsernameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldUsername, v))
}
// UsernameContainsFold applies the ContainsFold predicate on the "username" field.
func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
}
// DisplayNameEQ applies the EQ predicate on the "display_name" field.
func DisplayNameEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldDisplayName, v))
}
// DisplayNameNEQ applies the NEQ predicate on the "display_name" field.
func DisplayNameNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldDisplayName, v))
}
// DisplayNameIn applies the In predicate on the "display_name" field.
func DisplayNameIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldDisplayName, vs...))
}
// DisplayNameNotIn applies the NotIn predicate on the "display_name" field.
func DisplayNameNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldDisplayName, vs...))
}
// DisplayNameGT applies the GT predicate on the "display_name" field.
func DisplayNameGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldDisplayName, v))
}
// DisplayNameGTE applies the GTE predicate on the "display_name" field.
func DisplayNameGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldDisplayName, v))
}
// DisplayNameLT applies the LT predicate on the "display_name" field.
func DisplayNameLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldDisplayName, v))
}
// DisplayNameLTE applies the LTE predicate on the "display_name" field.
func DisplayNameLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldDisplayName, v))
}
// DisplayNameContains applies the Contains predicate on the "display_name" field.
func DisplayNameContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldDisplayName, v))
}
// DisplayNameHasPrefix applies the HasPrefix predicate on the "display_name" field.
func DisplayNameHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldDisplayName, v))
}
// DisplayNameHasSuffix applies the HasSuffix predicate on the "display_name" field.
func DisplayNameHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldDisplayName, v))
}
// DisplayNameIsNil applies the IsNil predicate on the "display_name" field.
func DisplayNameIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldDisplayName))
}
// DisplayNameNotNil applies the NotNil predicate on the "display_name" field.
func DisplayNameNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldDisplayName))
}
// DisplayNameEqualFold applies the EqualFold predicate on the "display_name" field.
func DisplayNameEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldDisplayName, v))
}
// DisplayNameContainsFold applies the ContainsFold predicate on the "display_name" field.
func DisplayNameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldDisplayName, v))
}
// EmailEQ applies the EQ predicate on the "email" field. // EmailEQ applies the EQ predicate on the "email" field.
func EmailEQ(v string) predicate.User { func EmailEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v)) return predicate.User(sql.FieldEQ(FieldEmail, v))

View file

@ -23,6 +23,26 @@ type UserCreate struct {
hooks []Hook hooks []Hook
} }
// SetUsername sets the "username" field.
func (uc *UserCreate) SetUsername(s string) *UserCreate {
uc.mutation.SetUsername(s)
return uc
}
// SetDisplayName sets the "display_name" field.
func (uc *UserCreate) SetDisplayName(s string) *UserCreate {
uc.mutation.SetDisplayName(s)
return uc
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uc *UserCreate) SetNillableDisplayName(s *string) *UserCreate {
if s != nil {
uc.SetDisplayName(*s)
}
return uc
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uc *UserCreate) SetEmail(s string) *UserCreate { func (uc *UserCreate) SetEmail(s string) *UserCreate {
uc.mutation.SetEmail(s) uc.mutation.SetEmail(s)
@ -173,6 +193,19 @@ func (uc *UserCreate) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uc *UserCreate) check() error { func (uc *UserCreate) check() error {
if _, ok := uc.mutation.Username(); !ok {
return &ValidationError{Name: "username", err: errors.New(`ent: missing required field "User.username"`)}
}
if v, ok := uc.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uc.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if _, ok := uc.mutation.Email(); !ok { if _, ok := uc.mutation.Email(); !ok {
return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)} return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)}
} }
@ -229,6 +262,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_node = &User{config: uc.config} _node = &User{config: uc.config}
_spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) _spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt))
) )
if value, ok := uc.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value
}
if value, ok := uc.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
_node.DisplayName = value
}
if value, ok := uc.mutation.Email(); ok { if value, ok := uc.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
_node.Email = value _node.Email = value

View file

@ -371,12 +371,12 @@ func (uq *UserQuery) WithMedia(opts ...func(*MediaQuery)) *UserQuery {
// Example: // Example:
// //
// var v []struct { // var v []struct {
// Email string `json:"email,omitempty"` // Username string `json:"username,omitempty"`
// Count int `json:"count,omitempty"` // Count int `json:"count,omitempty"`
// } // }
// //
// client.User.Query(). // client.User.Query().
// GroupBy(user.FieldEmail). // GroupBy(user.FieldUsername).
// Aggregate(ent.Count()). // Aggregate(ent.Count()).
// Scan(ctx, &v) // Scan(ctx, &v)
func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
@ -394,11 +394,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy {
// Example: // Example:
// //
// var v []struct { // var v []struct {
// Email string `json:"email,omitempty"` // Username string `json:"username,omitempty"`
// } // }
// //
// client.User.Query(). // client.User.Query().
// Select(user.FieldEmail). // Select(user.FieldUsername).
// Scan(ctx, &v) // Scan(ctx, &v)
func (uq *UserQuery) Select(fields ...string) *UserSelect { func (uq *UserQuery) Select(fields ...string) *UserSelect {
uq.ctx.Fields = append(uq.ctx.Fields, fields...) uq.ctx.Fields = append(uq.ctx.Fields, fields...)

View file

@ -31,6 +31,40 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate {
return uu return uu
} }
// SetUsername sets the "username" field.
func (uu *UserUpdate) SetUsername(s string) *UserUpdate {
uu.mutation.SetUsername(s)
return uu
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uu *UserUpdate) SetNillableUsername(s *string) *UserUpdate {
if s != nil {
uu.SetUsername(*s)
}
return uu
}
// SetDisplayName sets the "display_name" field.
func (uu *UserUpdate) SetDisplayName(s string) *UserUpdate {
uu.mutation.SetDisplayName(s)
return uu
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uu *UserUpdate) SetNillableDisplayName(s *string) *UserUpdate {
if s != nil {
uu.SetDisplayName(*s)
}
return uu
}
// ClearDisplayName clears the value of the "display_name" field.
func (uu *UserUpdate) ClearDisplayName() *UserUpdate {
uu.mutation.ClearDisplayName()
return uu
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uu *UserUpdate) SetEmail(s string) *UserUpdate { func (uu *UserUpdate) SetEmail(s string) *UserUpdate {
uu.mutation.SetEmail(s) uu.mutation.SetEmail(s)
@ -244,6 +278,16 @@ func (uu *UserUpdate) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uu *UserUpdate) check() error { func (uu *UserUpdate) check() error {
if v, ok := uu.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uu.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uu.mutation.Email(); ok { if v, ok := uu.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil { if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -274,6 +318,15 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
} }
} }
} }
if value, ok := uu.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uu.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uu.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uu.mutation.Email(); ok { if value, ok := uu.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
} }
@ -444,6 +497,40 @@ type UserUpdateOne struct {
mutation *UserMutation mutation *UserMutation
} }
// SetUsername sets the "username" field.
func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne {
uuo.mutation.SetUsername(s)
return uuo
}
// SetNillableUsername sets the "username" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableUsername(s *string) *UserUpdateOne {
if s != nil {
uuo.SetUsername(*s)
}
return uuo
}
// SetDisplayName sets the "display_name" field.
func (uuo *UserUpdateOne) SetDisplayName(s string) *UserUpdateOne {
uuo.mutation.SetDisplayName(s)
return uuo
}
// SetNillableDisplayName sets the "display_name" field if the given value is not nil.
func (uuo *UserUpdateOne) SetNillableDisplayName(s *string) *UserUpdateOne {
if s != nil {
uuo.SetDisplayName(*s)
}
return uuo
}
// ClearDisplayName clears the value of the "display_name" field.
func (uuo *UserUpdateOne) ClearDisplayName() *UserUpdateOne {
uuo.mutation.ClearDisplayName()
return uuo
}
// SetEmail sets the "email" field. // SetEmail sets the "email" field.
func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne { func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne {
uuo.mutation.SetEmail(s) uuo.mutation.SetEmail(s)
@ -670,6 +757,16 @@ func (uuo *UserUpdateOne) defaults() {
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
func (uuo *UserUpdateOne) check() error { func (uuo *UserUpdateOne) check() error {
if v, ok := uuo.mutation.Username(); ok {
if err := user.UsernameValidator(v); err != nil {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uuo.mutation.DisplayName(); ok {
if err := user.DisplayNameValidator(v); err != nil {
return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)}
}
}
if v, ok := uuo.mutation.Email(); ok { if v, ok := uuo.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil { if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
@ -717,6 +814,15 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
} }
} }
} }
if value, ok := uuo.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uuo.mutation.DisplayName(); ok {
_spec.SetField(user.FieldDisplayName, field.TypeString, value)
}
if uuo.mutation.DisplayNameCleared() {
_spec.ClearField(user.FieldDisplayName, field.TypeString)
}
if value, ok := uuo.mutation.Email(); ok { if value, ok := uuo.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value) _spec.SetField(user.FieldEmail, field.TypeString, value)
} }

View file

@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.24 github.com/mattn/go-sqlite3 v1.14.24
github.com/rs/zerolog v1.33.0 github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
go.uber.org/mock v0.5.0 go.uber.org/mock v0.5.0
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
@ -55,6 +56,7 @@ require (
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.5 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/hashicorp/hcl/v2 v2.23.0 // indirect github.com/hashicorp/hcl/v2 v2.23.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
@ -65,6 +67,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect

View file

@ -59,6 +59,7 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -130,6 +131,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View file

@ -11,6 +11,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Server ServerConfig `yaml:"server"` Server ServerConfig `yaml:"server"`
JWT JWTConfig `yaml:"jwt"` JWT JWTConfig `yaml:"jwt"`
Auth AuthConfig `yaml:"auth"`
Storage StorageConfig `yaml:"storage"` Storage StorageConfig `yaml:"storage"`
Logging LoggingConfig `yaml:"logging"` Logging LoggingConfig `yaml:"logging"`
RateLimit types.RateLimitConfig `yaml:"rate_limit"` RateLimit types.RateLimitConfig `yaml:"rate_limit"`
@ -32,6 +33,13 @@ type JWTConfig struct {
Expiration string `yaml:"expiration"` Expiration string `yaml:"expiration"`
} }
type AuthConfig struct {
Registration struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
} `yaml:"registration"`
}
type LoggingConfig struct { type LoggingConfig struct {
Level string `yaml:"level"` Level string `yaml:"level"`
Format string `yaml:"format"` Format string `yaml:"format"`

View file

@ -1,119 +1,253 @@
package handler package handler
import ( import (
"net/http" "net/http"
"time" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
) )
type RegisterRequest struct { type RegisterRequest struct {
Email string `json:"email" binding:"required,email"` Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required,min=8"` Email string `json:"email" binding:"required,email"`
Role string `json:"role" binding:"required,oneof=admin editor contributor"` Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor contributor"`
} }
type LoginRequest struct { type LoginRequest struct {
Email string `json:"email" binding:"required,email"` Username string `json:"username" binding:"required,min=3,max=32"`
Password string `json:"password" binding:"required"` Password string `json:"password" binding:"required"`
} }
type AuthResponse struct { type AuthResponse struct {
Token string `json:"token"` Token string `json:"token"`
} }
func (h *Handler) Register(c *gin.Context) { func (h *Handler) Register(c *gin.Context) {
var req RegisterRequest // 检查是否启用注册功能
if err := c.ShouldBindJSON(&req); err != nil { if !h.cfg.Auth.Registration.Enabled {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) message := h.cfg.Auth.Registration.Message
return if message == "" {
} message = "Registration is currently disabled"
}
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"code": "REGISTRATION_DISABLED",
"message": message,
},
})
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role) var req RegisterRequest
if err != nil { if err := c.ShouldBindJSON(&req); err != nil {
log.Error().Err(err).Msg("Failed to create user") c.JSON(http.StatusBadRequest, gin.H{
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) "error": gin.H{
return "code": "INVALID_REQUEST",
} "message": err.Error(),
},
})
return
}
// Get user roles user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role)
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) if err != nil {
if err != nil { log.Error().Err(err).Msg("Failed to create user")
log.Error().Err(err).Msg("Failed to get user roles") c.JSON(http.StatusInternalServerError, gin.H{
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) "error": gin.H{
return "code": "CREATE_USER_FAILED",
} "message": "Failed to create user",
},
})
return
}
// Extract role names for JWT // Get user roles
roleNames := make([]string, len(roles)) roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
for i, r := range roles { if err != nil {
roleNames[i] = r.Name log.Error().Err(err).Msg("Failed to get user roles")
} c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"code": "GET_ROLES_FAILED",
"message": "Failed to get user roles",
},
})
return
}
// Generate JWT token // Extract role names for JWT
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ roleNames := make([]string, len(roles))
"sub": user.ID, for i, r := range roles {
"roles": roleNames, roleNames[i] = r.Name
"exp": time.Now().Add(24 * time.Hour).Unix(), }
})
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) // Generate JWT token
if err != nil { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
log.Error().Err(err).Msg("Failed to generate token") "sub": int64(user.ID), // 将用户 ID 转换为 int64
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) "roles": roleNames,
return "exp": time.Now().Add(24 * time.Hour).Unix(),
} })
c.JSON(http.StatusCreated, AuthResponse{Token: tokenString}) tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil {
log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"code": "GENERATE_TOKEN_FAILED",
"message": "Failed to generate token",
},
})
return
}
c.JSON(http.StatusCreated, AuthResponse{Token: tokenString})
} }
func (h *Handler) Login(c *gin.Context) { func (h *Handler) Login(c *gin.Context) {
var req LoginRequest var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{
return "error": gin.H{
} "code": "INVALID_REQUEST",
"message": err.Error(),
},
})
return
}
user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email) user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) c.JSON(http.StatusUnauthorized, gin.H{
return "error": gin.H{
} "code": "INVALID_CREDENTIALS",
"message": "Invalid username or password",
},
})
return
}
if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) { // 验证密码
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
return if err != nil {
} c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "INVALID_CREDENTIALS",
"message": "Invalid username or password",
},
})
return
}
// Get user roles // Get user roles
roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get user roles") log.Error().Err(err).Msg("Failed to get user roles")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) c.JSON(http.StatusInternalServerError, gin.H{
return "error": gin.H{
} "code": "GET_ROLES_FAILED",
"message": "Failed to get user roles",
},
})
return
}
// Extract role names for JWT // Extract role names for JWT
roleNames := make([]string, len(roles)) roleNames := make([]string, len(roles))
for i, r := range roles { for i, r := range roles {
roleNames[i] = r.Name roleNames[i] = r.Name
} }
// Generate JWT token // Generate JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": user.ID, "sub": int64(user.ID), // 将用户 ID 转换为 int64
"roles": roleNames, "roles": roleNames,
"exp": time.Now().Add(24 * time.Hour).Unix(), "exp": time.Now().Add(24 * time.Hour).Unix(),
}) })
tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret))
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to generate token") log.Error().Err(err).Msg("Failed to generate token")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) c.JSON(http.StatusInternalServerError, gin.H{
return "error": gin.H{
} "code": "GENERATE_TOKEN_FAILED",
"message": "Failed to generate token",
},
})
return
}
c.JSON(http.StatusOK, AuthResponse{Token: tokenString}) c.JSON(http.StatusOK, AuthResponse{Token: tokenString})
}
func (h *Handler) Logout(c *gin.Context) {
// 获取当前用户ID
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "User not authenticated",
},
})
return
}
// 获取 token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "Authorization header is required",
},
})
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "UNAUTHORIZED",
"message": "Authorization header format must be Bearer {token}",
},
})
return
}
// 解析 token 以获取过期时间
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(h.cfg.JWT.Secret), nil
})
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"code": "INVALID_TOKEN",
"message": "Invalid token",
},
})
return
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// 将 token 添加到黑名单
h.service.GetTokenBlacklist().AddToBlacklist(parts[1], claims)
}
// 记录日志
log.Info().
Interface("user_id", userID).
Msg("User logged out")
c.JSON(http.StatusOK, gin.H{
"message": "Successfully logged out",
})
} }

View file

@ -3,10 +3,11 @@ package handler
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
@ -14,6 +15,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/crypto/bcrypt"
) )
type AuthHandlerTestSuite struct { type AuthHandlerTestSuite struct {
@ -31,6 +33,15 @@ func (s *AuthHandlerTestSuite) SetupTest() {
JWT: config.JWTConfig{ JWT: config.JWTConfig{
Secret: "test-secret", Secret: "test-secret",
}, },
Auth: config.AuthConfig{
Registration: struct {
Enabled bool `yaml:"enabled"`
Message string `yaml:"message"`
}{
Enabled: true,
Message: "Registration is disabled",
},
},
}, s.service) }, s.service)
s.router = gin.New() s.router = gin.New()
} }
@ -43,6 +54,13 @@ func TestAuthHandlerSuite(t *testing.T) {
suite.Run(t, new(AuthHandlerTestSuite)) suite.Run(t, new(AuthHandlerTestSuite))
} }
type ErrorResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func (s *AuthHandlerTestSuite) TestRegister() { func (s *AuthHandlerTestSuite) TestRegister() {
testCases := []struct { testCases := []struct {
name string name string
@ -50,31 +68,48 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock func() setupMock func()
expectedStatus int expectedStatus int
expectedError string expectedError string
registration bool
}{ }{
{ {
name: "成功注册", name: "成功注册",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "password123", Password: "password123",
Role: "contributor", Role: "contributor",
}, },
setupMock: func() { setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT(). s.service.EXPECT().
CreateUser(gomock.Any(), "test@example.com", "password123", "contributor"). CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
}, nil)
s.service.EXPECT(). s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID). GetUserRoles(gomock.Any(), 1).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
}, },
expectedStatus: http.StatusCreated, expectedStatus: http.StatusCreated,
registration: true,
},
{
name: "注册功能已禁用",
request: RegisterRequest{
Username: "testuser",
Email: "test@example.com",
Password: "password123",
Role: "contributor",
},
setupMock: func() {},
expectedStatus: http.StatusForbidden,
expectedError: "Registration is disabled",
registration: false,
}, },
{ {
name: "无效的邮箱格式", name: "无效的邮箱格式",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "invalid-email", Email: "invalid-email",
Password: "password123", Password: "password123",
Role: "contributor", Role: "contributor",
@ -82,10 +117,12 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", expectedError: "Key: 'RegisterRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag",
registration: true,
}, },
{ {
name: "密码太短", name: "密码太短",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "short", Password: "short",
Role: "contributor", Role: "contributor",
@ -93,10 +130,12 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag", expectedError: "Key: 'RegisterRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
registration: true,
}, },
{ {
name: "无效的角色", name: "无效的角色",
request: RegisterRequest{ request: RegisterRequest{
Username: "testuser",
Email: "test@example.com", Email: "test@example.com",
Password: "password123", Password: "password123",
Role: "invalid-role", Role: "invalid-role",
@ -104,11 +143,15 @@ func (s *AuthHandlerTestSuite) TestRegister() {
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag", expectedError: "Key: 'RegisterRequest.Role' Error:Field validation for 'Role' failed on the 'oneof' tag",
registration: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
// 设置注册功能状态
s.handler.cfg.Auth.Registration.Enabled = tc.registration
// 设置 mock // 设置 mock
tc.setupMock() tc.setupMock()
@ -126,10 +169,10 @@ func (s *AuthHandlerTestSuite) TestRegister() {
// 验证响应 // 验证响应
s.Equal(tc.expectedStatus, w.Code) s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" { if tc.expectedError != "" {
var response map[string]string var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err) s.NoError(err)
s.Contains(response["error"], tc.expectedError) s.Contains(response.Error.Message, tc.expectedError)
} else { } else {
var response AuthResponse var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
@ -141,6 +184,8 @@ func (s *AuthHandlerTestSuite) TestRegister() {
} }
func (s *AuthHandlerTestSuite) TestLogin() { func (s *AuthHandlerTestSuite) TestLogin() {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
testCases := []struct { testCases := []struct {
name string name string
request LoginRequest request LoginRequest
@ -151,91 +196,82 @@ func (s *AuthHandlerTestSuite) TestLogin() {
{ {
name: "成功登录", name: "成功登录",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT(). s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123"). GetUserRoles(gomock.Any(), 1).
Return(true)
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "无效的邮箱格式", name: "无效的用户名",
request: LoginRequest{ request: LoginRequest{
Email: "invalid-email", Username: "te",
Password: "password123", Password: "password123",
}, },
setupMock: func() {}, setupMock: func() {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", expectedError: "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'min' tag",
}, },
{ {
name: "用户不存在", name: "用户不存在",
request: LoginRequest{ request: LoginRequest{
Email: "nonexistent@example.com", Username: "nonexistent",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "nonexistent@example.com"). GetUserByUsername(gomock.Any(), "nonexistent").
Return(nil, errors.New("user not found")) Return(nil, fmt.Errorf("user not found"))
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials", expectedError: "Invalid username or password",
}, },
{ {
name: "密码错误", name: "密码错误",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "wrong-password", Password: "wrongpassword",
}, },
setupMock: func() { setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
s.service.EXPECT(). ID: 1,
ValidatePassword(gomock.Any(), user, "wrong-password"). Username: "testuser",
Return(false) PasswordHash: string(hashedPassword),
}, nil)
}, },
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedError: "Invalid credentials", expectedError: "Invalid username or password",
}, },
{ {
name: "获取用户角色失败", name: "获取用户角色失败",
request: LoginRequest{ request: LoginRequest{
Email: "test@example.com", Username: "testuser",
Password: "password123", Password: "password123",
}, },
setupMock: func() { setupMock: func() {
user := &ent.User{
ID: 1,
Email: "test@example.com",
}
s.service.EXPECT(). s.service.EXPECT().
GetUserByEmail(gomock.Any(), "test@example.com"). GetUserByUsername(gomock.Any(), "testuser").
Return(user, nil) Return(&ent.User{
ID: 1,
Username: "testuser",
PasswordHash: string(hashedPassword),
}, nil)
s.service.EXPECT(). s.service.EXPECT().
ValidatePassword(gomock.Any(), user, "password123"). GetUserRoles(gomock.Any(), 1).
Return(true) Return(nil, fmt.Errorf("failed to get roles"))
s.service.EXPECT().
GetUserRoles(gomock.Any(), user.ID).
Return(nil, errors.New("failed to get roles"))
}, },
expectedStatus: http.StatusInternalServerError, expectedStatus: http.StatusInternalServerError,
expectedError: "Failed to get user roles", expectedError: "Failed to get user roles",
@ -261,10 +297,10 @@ func (s *AuthHandlerTestSuite) TestLogin() {
// 验证响应 // 验证响应
s.Equal(tc.expectedStatus, w.Code) s.Equal(tc.expectedStatus, w.Code)
if tc.expectedError != "" { if tc.expectedError != "" {
var response map[string]string var response ErrorResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
s.NoError(err) s.NoError(err)
s.Contains(response["error"], tc.expectedError) s.Contains(response.Error.Message, tc.expectedError)
} else { } else {
var response AuthResponse var response AuthResponse
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)

View file

@ -9,6 +9,7 @@ import (
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/ent/categorycontent" "tss-rocks-be/ent/categorycontent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"tss-rocks-be/internal/types" "tss-rocks-be/internal/types"
@ -59,12 +60,24 @@ type CategoryHandlerTestSuite struct {
func (s *CategoryHandlerTestSuite) SetupTest() { func (s *CategoryHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type ContributorHandlerTestSuite struct {
func (s *ContributorHandlerTestSuite) SetupTest() { func (s *ContributorHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type DailyHandlerTestSuite struct {
func (s *DailyHandlerTestSuite) SetupTest() { func (s *DailyHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/middleware"
"tss-rocks-be/internal/service" "tss-rocks-be/internal/service"
"tss-rocks-be/internal/types" "tss-rocks-be/internal/types"
@ -32,6 +33,19 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
{ {
auth.POST("/register", h.Register) auth.POST("/register", h.Register)
auth.POST("/login", h.Login) auth.POST("/login", h.Login)
auth.POST("/logout", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()), h.Logout)
}
// User routes
users := api.Group("/users", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
{
users.GET("", h.ListUsers)
users.POST("", h.CreateUser)
users.GET("/:id", h.GetUser)
users.PUT("/:id", h.UpdateUser)
users.DELETE("/:id", h.DeleteUser)
users.GET("/me", h.GetCurrentUser)
users.PUT("/me", h.UpdateCurrentUser)
} }
// Category routes // Category routes
@ -71,7 +85,7 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
} }
// Media routes // Media routes
media := api.Group("/media") media := api.Group("/media", middleware.AuthMiddleware(h.cfg.JWT.Secret, h.service.GetTokenBlacklist()))
{ {
media.GET("", h.ListMedia) media.GET("", h.ListMedia)
media.POST("", h.UploadMedia) media.POST("", h.UploadMedia)

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/config" "tss-rocks-be/internal/config"
"tss-rocks-be/internal/service"
"tss-rocks-be/internal/service/mock" "tss-rocks-be/internal/service/mock"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -29,12 +30,24 @@ type PostHandlerTestSuite struct {
func (s *PostHandlerTestSuite) SetupTest() { func (s *PostHandlerTestSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.ctrl = gomock.NewController(s.T())
s.service = mock.NewMockService(s.ctrl) s.service = mock.NewMockService(s.ctrl)
cfg := &config.Config{} cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
},
}
s.handler = NewHandler(cfg, s.service) s.handler = NewHandler(cfg, s.service)
// Setup Gin router // Setup Gin router
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
s.router = gin.New() s.router = gin.New()
// Setup mock for GetTokenBlacklist
tokenBlacklist := &service.TokenBlacklist{}
s.service.EXPECT().
GetTokenBlacklist().
Return(tokenBlacklist).
AnyTimes()
s.handler.RegisterRoutes(s.router) s.handler.RegisterRoutes(s.router)
} }

View file

@ -0,0 +1,255 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/types"
"fmt"
)
type UpdateCurrentUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
CurrentPassword string `json:"current_password,omitempty"`
NewPassword string `json:"new_password,omitempty" binding:"omitempty,min=8"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=8"`
Role string `json:"role" binding:"required,oneof=admin editor"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
type UpdateUserRequest struct {
Email string `json:"email,omitempty" binding:"omitempty,email"`
Password string `json:"password,omitempty" binding:"omitempty,min=8"`
Role string `json:"role,omitempty" binding:"omitempty,oneof=admin editor"`
Status string `json:"status,omitempty" binding:"omitempty,oneof=active inactive"`
DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"`
}
// ListUsers returns a list of users
func (h *Handler) ListUsers(c *gin.Context) {
// Parse query parameters
params := &types.ListUsersParams{
Page: 1,
PerPage: 10,
}
if page := c.Query("page"); page != "" {
if p, err := strconv.Atoi(page); err == nil && p > 0 {
params.Page = p
}
}
if perPage := c.Query("per_page"); perPage != "" {
if pp, err := strconv.Atoi(perPage); err == nil && pp > 0 {
params.PerPage = pp
}
}
params.Sort = c.Query("sort")
params.Role = c.Query("role")
params.Status = c.Query("status")
params.Email = c.Query("email")
// Get users
users, err := h.service.ListUsers(c.Request.Context(), params)
if err != nil {
log.Error().Err(err).Msg("Failed to list users")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": users,
})
}
// CreateUser creates a new user
func (h *Handler) CreateUser(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Email, req.Password, req.Role)
if err != nil {
log.Error().Err(err).Msg("Failed to create user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
c.JSON(http.StatusCreated, gin.H{
"data": user,
})
}
// GetUser returns user details
func (h *Handler) GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
user, err := h.service.GetUser(c.Request.Context(), id)
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// UpdateUser updates user information
func (h *Handler) UpdateUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, err := h.service.UpdateUser(c.Request.Context(), id, &types.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Role: req.Role,
Status: req.Status,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}
// DeleteUser deletes a user
func (h *Handler) DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
if err := h.service.DeleteUser(c.Request.Context(), id); err != nil {
log.Error().Err(err).Msg("Failed to delete user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"})
return
}
c.Status(http.StatusNoContent)
}
// GetCurrentUser returns the current user's information
func (h *Handler) GetCurrentUser(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// 将用户ID转换为int64
var id int64
switch v := userID.(type) {
case int64:
id = v
case int:
id = int64(v)
case float64:
id = int64(v)
case string:
// 尝试将字符串转换为int64
parsedID, err := strconv.ParseInt(v, 10, 64)
if err != nil {
log.Error().
Err(err).
Str("user_id", v).
Msg("Failed to parse user_id string")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user_id format"})
return
}
id = parsedID
default:
log.Error().
Str("type", fmt.Sprintf("%T", userID)).
Interface("value", userID).
Msg("Invalid user_id type")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user_id type"})
return
}
// 获取用户信息
user, err := h.service.GetUser(c.Request.Context(), int(id))
if err != nil {
log.Error().Err(err).Msg("Failed to get user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user information"})
return
}
c.JSON(http.StatusOK, user)
}
// UpdateCurrentUser updates the current user's information
func (h *Handler) UpdateCurrentUser(c *gin.Context) {
// 从上下文中获取用户ID由认证中间件设置
userID, exists := c.Get("user_id")
if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
var req UpdateCurrentUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 如果要更新密码,需要验证当前密码
if req.NewPassword != "" {
if req.CurrentPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is required to update password"})
return
}
// 验证当前密码
if err := h.service.VerifyPassword(c.Request.Context(), userID.(int), req.CurrentPassword); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current password"})
return
}
}
// 更新用户信息
user, err := h.service.UpdateUser(c.Request.Context(), userID.(int), &types.UpdateUserInput{
Email: req.Email,
Password: req.NewPassword,
DisplayName: req.DisplayName,
})
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user information"})
return
}
c.JSON(http.StatusOK, gin.H{
"data": user,
})
}

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -78,26 +79,41 @@ func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) {
// 配置文件日志 // 配置文件日志
if config.EnableFile { if config.EnableFile {
// 确保日志目录存在 // 验证文件路径
if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil { if config.FilePath == "" {
return nil, fmt.Errorf("file path cannot be empty")
}
// 验证路径是否包含无效字符
if strings.ContainsAny(config.FilePath, "\x00") {
return nil, fmt.Errorf("file path contains invalid characters")
}
dir := filepath.Dir(config.FilePath)
// 检查目录是否存在或是否可以创建
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err) return nil, fmt.Errorf("failed to create log directory: %w", err)
} }
// 配置日志轮转 // 尝试打开或创建文件,验证路径是否有效且有写入权限
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, fmt.Errorf("failed to open or create log file: %w", err)
}
file.Close()
// 配置文件日志
logWriter = &lumberjack.Logger{ logWriter = &lumberjack.Logger{
Filename: config.FilePath, Filename: config.FilePath,
MaxSize: config.Rotation.MaxSize, // MB MaxSize: config.Rotation.MaxSize, // MB
MaxAge: config.Rotation.MaxAge, // days MaxBackups: config.Rotation.MaxBackups, // 文件个数
MaxBackups: config.Rotation.MaxBackups, // files MaxAge: config.Rotation.MaxAge, // 天数
Compress: config.Rotation.Compress, // 是否压缩 Compress: config.Rotation.Compress, // 是否压缩
LocalTime: config.Rotation.LocalTime, // 使用本地时间 LocalTime: config.Rotation.LocalTime, // 使用本地时间
} }
logger := zerolog.New(logWriter). logger := zerolog.New(logWriter).With().Timestamp().Logger()
With().
Timestamp().
Logger()
fileLogger = &logger fileLogger = &logger
} }

View file

@ -219,7 +219,7 @@ func TestAccessLogInvalidConfig(t *testing.T) {
name: "Invalid file path", name: "Invalid file path",
config: &types.AccessLogConfig{ config: &types.AccessLogConfig{
EnableFile: true, EnableFile: true,
FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径 FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的
}, },
expectedError: true, expectedError: true,
}, },

View file

@ -1,16 +1,20 @@
package middleware package middleware
import ( import (
"encoding/json"
"fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tss-rocks-be/internal/service"
) )
// AuthMiddleware creates a middleware for JWT authentication // AuthMiddleware creates a middleware for JWT authentication
func AuthMiddleware(jwtSecret string) gin.HandlerFunc { func AuthMiddleware(jwtSecret string, tokenBlacklist *service.TokenBlacklist) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
@ -26,8 +30,20 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return return
} }
token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) { tokenStr := parts[1]
// 检查 token 是否在黑名单中
if tokenBlacklist.IsBlacklisted(tokenStr) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token has been revoked"})
c.Abort()
return
}
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
// 添加调试输出
log.Debug().Str("token", tokenStr).Msg("Parsing token")
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
log.Error().Str("method", token.Method.Alg()).Msg("Invalid signing method")
return nil, jwt.ErrSignatureInvalid return nil, jwt.ErrSignatureInvalid
} }
return []byte(jwtSecret), nil return []byte(jwtSecret), nil
@ -41,9 +57,73 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
} }
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
c.Set("user_id", claims["sub"]) // 添加调试信息
c.Set("user_role", claims["role"]) log.Debug().Interface("claims", claims).Msg("Token claims")
// 获取用户ID
sub, exists := claims["sub"]
if !exists {
log.Error().Msg("Token does not contain sub claim")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format"})
c.Abort()
return
}
// 打印类型信息
log.Debug().
Str("type", fmt.Sprintf("%T", sub)).
Interface("value", sub).
Msg("User ID from token")
// 将用户 ID 转换为字符串
var userIDStr string
switch v := sub.(type) {
case string:
userIDStr = v
case float64:
userIDStr = strconv.FormatFloat(v, 'f', 0, 64)
case json.Number:
userIDStr = v.String()
default:
userIDStr = fmt.Sprintf("%v", v)
}
// 验证用户 ID 是否为有效的数字字符串
_, err := strconv.Atoi(userIDStr)
if err != nil {
log.Error().
Err(err).
Str("user_id", userIDStr).
Msg("Invalid user ID format")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid user ID format"})
c.Abort()
return
}
// 获取用户角色
roles, exists := claims["roles"]
if !exists {
log.Error().Msg("Token does not contain roles claim")
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token format: missing roles"})
c.Abort()
return
}
// 将角色转换为字符串数组
var roleNames []string
if rolesArray, ok := roles.([]interface{}); ok {
for _, r := range rolesArray {
if roleStr, ok := r.(string); ok {
roleNames = append(roleNames, roleStr)
}
}
}
// 设置上下文
c.Set("user_id", userIDStr)
c.Set("user_roles", roleNames) // 存储角色数组
c.Next() c.Next()
return
} else { } else {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
c.Abort() c.Abort()
@ -55,27 +135,37 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
// RoleMiddleware creates a middleware for role-based authorization // RoleMiddleware creates a middleware for role-based authorization
func RoleMiddleware(roles ...string) gin.HandlerFunc { func RoleMiddleware(roles ...string) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
userRole, exists := c.Get("user_role") userRoles, exists := c.Get("user_roles")
if !exists { if !exists {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User role not found"}) log.Error().Msg("User roles not found in context")
c.JSON(http.StatusUnauthorized, gin.H{"error": "User roles not found"})
c.Abort() c.Abort()
return return
} }
roleStr, ok := userRole.(string) roleNames, ok := userRoles.([]string)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user role type"}) log.Error().Msg("Invalid user roles type")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user roles type"})
c.Abort() c.Abort()
return return
} }
for _, role := range roles { // 检查用户是否拥有任一所需角色
if role == roleStr { for _, requiredRole := range roles {
c.Next() for _, userRole := range roleNames {
return if requiredRole == userRole {
c.Next()
return
}
} }
} }
log.Warn().
Strs("required_roles", roles).
Strs("user_roles", roleNames).
Msg("Insufficient permissions")
c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"}) c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
c.Abort() c.Abort()
} }

View file

@ -2,6 +2,7 @@ package middleware
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -9,16 +10,22 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"tss-rocks-be/internal/service"
) )
func createTestToken(secret string, claims jwt.MapClaims) string { func createTestToken(secret string, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signedToken, _ := token.SignedString([]byte(secret)) signedToken, err := token.SignedString([]byte(secret))
if err != nil {
panic(fmt.Sprintf("Failed to sign token: %v", err))
}
return signedToken return signedToken
} }
func TestAuthMiddleware(t *testing.T) { func TestAuthMiddleware(t *testing.T) {
jwtSecret := "test-secret" jwtSecret := "test-secret"
tokenBlacklist := service.NewTokenBlacklist()
testCases := []struct { testCases := []struct {
name string name string
@ -27,7 +34,7 @@ func TestAuthMiddleware(t *testing.T) {
expectedBody map[string]string expectedBody map[string]string
checkUserData bool checkUserData bool
expectedUserID string expectedUserID string
expectedRole string expectedRoles []string
}{ }{
{ {
name: "No Authorization header", name: "No Authorization header",
@ -55,25 +62,25 @@ func TestAuthMiddleware(t *testing.T) {
name: "Valid token", name: "Valid token",
setupAuth: func(req *http.Request) { setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"sub": "user123", "sub": "123",
"role": "user", "roles": []string{"admin", "editor"},
"exp": time.Now().Add(time.Hour).Unix(), "exp": time.Now().Add(time.Hour).Unix(),
} }
token := createTestToken(jwtSecret, claims) token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
checkUserData: true, checkUserData: true,
expectedUserID: "user123", expectedUserID: "123",
expectedRole: "user", expectedRoles: []string{"admin", "editor"},
}, },
{ {
name: "Expired token", name: "Expired token",
setupAuth: func(req *http.Request) { setupAuth: func(req *http.Request) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"sub": "user123", "sub": "123",
"role": "user", "roles": []string{"user"},
"exp": time.Now().Add(-time.Hour).Unix(), "exp": time.Now().Add(-time.Hour).Unix(),
} }
token := createTestToken(jwtSecret, claims) token := createTestToken(jwtSecret, claims)
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
@ -89,18 +96,22 @@ func TestAuthMiddleware(t *testing.T) {
router := gin.New() router := gin.New()
// 添加认证中间件 // 添加认证中间件
router.Use(AuthMiddleware(jwtSecret)) router.Use(func(c *gin.Context) {
// 设置日志级别为 debug
gin.SetMode(gin.DebugMode)
c.Next()
}, AuthMiddleware(jwtSecret, tokenBlacklist))
// 测试路由 // 测试路由
router.GET("/test", func(c *gin.Context) { router.GET("/test", func(c *gin.Context) {
if tc.checkUserData { if tc.checkUserData {
userID, exists := c.Get("user_id") userID, exists := c.Get("user_id")
assert.True(t, exists) assert.True(t, exists, "user_id should exist in context")
assert.Equal(t, tc.expectedUserID, userID) assert.Equal(t, tc.expectedUserID, userID, "user_id should match")
role, exists := c.Get("user_role") roles, exists := c.Get("user_roles")
assert.True(t, exists) assert.True(t, exists, "user_roles should exist in context")
assert.Equal(t, tc.expectedRole, role) assert.Equal(t, tc.expectedRoles, roles, "user_roles should match")
} }
c.Status(http.StatusOK) c.Status(http.StatusOK)
}) })
@ -114,13 +125,13 @@ func TestAuthMiddleware(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
// 验证响应 // 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code) assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil { if tc.expectedBody != nil {
var response map[string]string var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response) err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err) assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response) assert.Equal(t, tc.expectedBody, response, "Response body should match")
} }
}) })
} }
@ -135,27 +146,27 @@ func TestRoleMiddleware(t *testing.T) {
expectedBody map[string]string expectedBody map[string]string
}{ }{
{ {
name: "No user role", name: "No user roles",
setupContext: func(c *gin.Context) { setupContext: func(c *gin.Context) {
// 不设置用户角色 // 不设置用户角色
}, },
allowedRoles: []string{"admin"}, allowedRoles: []string{"admin"},
expectedStatus: http.StatusUnauthorized, expectedStatus: http.StatusUnauthorized,
expectedBody: map[string]string{"error": "User role not found"}, expectedBody: map[string]string{"error": "User roles not found"},
}, },
{ {
name: "Invalid role type", name: "Invalid roles type",
setupContext: func(c *gin.Context) { setupContext: func(c *gin.Context) {
c.Set("user_role", 123) // 设置错误类型的角色 c.Set("user_roles", 123) // 设置错误类型的角色
}, },
allowedRoles: []string{"admin"}, allowedRoles: []string{"admin"},
expectedStatus: http.StatusInternalServerError, expectedStatus: http.StatusInternalServerError,
expectedBody: map[string]string{"error": "Invalid user role type"}, expectedBody: map[string]string{"error": "Invalid user roles type"},
}, },
{ {
name: "Insufficient permissions", name: "Insufficient permissions",
setupContext: func(c *gin.Context) { setupContext: func(c *gin.Context) {
c.Set("user_role", "user") c.Set("user_roles", []string{"user"})
}, },
allowedRoles: []string{"admin"}, allowedRoles: []string{"admin"},
expectedStatus: http.StatusForbidden, expectedStatus: http.StatusForbidden,
@ -164,7 +175,7 @@ func TestRoleMiddleware(t *testing.T) {
{ {
name: "Allowed role", name: "Allowed role",
setupContext: func(c *gin.Context) { setupContext: func(c *gin.Context) {
c.Set("user_role", "admin") c.Set("user_roles", []string{"admin"})
}, },
allowedRoles: []string{"admin"}, allowedRoles: []string{"admin"},
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
@ -172,7 +183,7 @@ func TestRoleMiddleware(t *testing.T) {
{ {
name: "One of multiple allowed roles", name: "One of multiple allowed roles",
setupContext: func(c *gin.Context) { setupContext: func(c *gin.Context) {
c.Set("user_role", "editor") c.Set("user_roles", []string{"user", "editor"})
}, },
allowedRoles: []string{"admin", "editor", "moderator"}, allowedRoles: []string{"admin", "editor", "moderator"},
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
@ -188,8 +199,7 @@ func TestRoleMiddleware(t *testing.T) {
router.Use(func(c *gin.Context) { router.Use(func(c *gin.Context) {
tc.setupContext(c) tc.setupContext(c)
c.Next() c.Next()
}) }, RoleMiddleware(tc.allowedRoles...))
router.Use(RoleMiddleware(tc.allowedRoles...))
// 测试路由 // 测试路由
router.GET("/test", func(c *gin.Context) { router.GET("/test", func(c *gin.Context) {
@ -204,13 +214,13 @@ func TestRoleMiddleware(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
// 验证响应 // 验证响应
assert.Equal(t, tc.expectedStatus, rec.Code) assert.Equal(t, tc.expectedStatus, rec.Code, "HTTP status code should match")
if tc.expectedBody != nil { if tc.expectedBody != nil {
var response map[string]string var response map[string]string
err := json.NewDecoder(rec.Body).Decode(&response) err := json.NewDecoder(rec.Body).Decode(&response)
assert.NoError(t, err) assert.NoError(t, err, "Response body should be valid JSON")
assert.Equal(t, tc.expectedBody, response) assert.Equal(t, tc.expectedBody, response, "Response body should match")
} }
}) })
} }

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/ent/permission"
"tss-rocks-be/ent/role" "tss-rocks-be/ent/role"
) )
@ -38,37 +39,69 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error {
permissionMap := make(map[string]*ent.Permission) permissionMap := make(map[string]*ent.Permission)
for resource, actions := range DefaultPermissions { for resource, actions := range DefaultPermissions {
for _, action := range actions { for _, action := range actions {
permission, err := client.Permission.Create().
SetResource(resource).
SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
Save(ctx)
if err != nil {
return fmt.Errorf("failed creating permission: %w", err)
}
key := fmt.Sprintf("%s:%s", resource, action) key := fmt.Sprintf("%s:%s", resource, action)
permission, err := client.Permission.Query().
Where(
permission.ResourceEQ(resource),
permission.ActionEQ(action),
).
Only(ctx)
if ent.IsNotFound(err) {
permission, err = client.Permission.Create().
SetResource(resource).
SetAction(action).
SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)).
Save(ctx)
if err != nil {
return fmt.Errorf("failed creating permission: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed querying permission: %w", err)
}
permissionMap[key] = permission permissionMap[key] = permission
} }
} }
// Create roles with permissions // Create roles with permissions
for roleName, permissions := range DefaultRoles { for roleName, permissions := range DefaultRoles {
roleCreate := client.Role.Create(). role, err := client.Role.Query().
SetName(roleName). Where(role.NameEQ(roleName)).
SetDescription(fmt.Sprintf("Role for %s users", roleName)) Only(ctx)
if ent.IsNotFound(err) {
roleCreate := client.Role.Create().
SetName(roleName).
SetDescription(fmt.Sprintf("Role for %s users", roleName))
// Add permissions to role // Add permissions to role
for resource, actions := range permissions { for resource, actions := range permissions {
for _, action := range actions { for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action) key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists { if permission, exists := permissionMap[key]; exists {
roleCreate.AddPermissions(permission) roleCreate.AddPermissions(permission)
}
} }
} }
}
if _, err := roleCreate.Save(ctx); err != nil { if _, err := roleCreate.Save(ctx); err != nil {
return fmt.Errorf("failed creating role %s: %w", roleName, err) return fmt.Errorf("failed creating role %s: %w", roleName, err)
}
} else if err != nil {
return fmt.Errorf("failed querying role: %w", err)
} else {
// Update existing role's permissions
for resource, actions := range permissions {
for _, action := range actions {
key := fmt.Sprintf("%s:%s", resource, action)
if permission, exists := permissionMap[key]; exists {
err = client.Role.UpdateOne(role).
AddPermissions(permission).
Exec(ctx)
if err != nil {
return fmt.Errorf("failed updating role %s permissions: %w", roleName, err)
}
}
}
}
} }
} }

View file

@ -64,6 +64,7 @@ func TestAssignRoleToUser(t *testing.T) {
// Create a test user // Create a test user
user, err := client.User.Create(). user, err := client.User.Create().
SetEmail("test@example.com"). SetEmail("test@example.com").
SetUsername("testuser").
SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy"). SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy").
Save(ctx) Save(ctx)
if err != nil { if err != nil {

View file

@ -1,31 +1,27 @@
package server package server
import ( import (
"context" "context"
"entgo.io/ent/dialect/sql" "tss-rocks-be/ent"
_ "github.com/mattn/go-sqlite3" "tss-rocks-be/internal/config"
"github.com/rs/zerolog/log"
"tss-rocks-be/ent" _ "github.com/mattn/go-sqlite3"
"tss-rocks-be/internal/config" "github.com/rs/zerolog/log"
) )
// NewEntClient creates a new ent client // NewEntClient creates a new ent client
func NewEntClient(cfg *config.Config) *ent.Client { func NewEntClient(cfg *config.Config) *ent.Client {
// TODO: Implement database connection based on config // 使用配置文件中的数据库设置
// For now, we'll use SQLite for development client, err := ent.Open(cfg.Database.Driver, cfg.Database.DSN)
db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil {
if err != nil { log.Fatal().Err(err).Msg("Failed to connect to database")
log.Fatal().Err(err).Msg("Failed to connect to database") }
}
// Create ent client // Run the auto migration tool
client := ent.NewClient(ent.Driver(db)) if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Run the auto migration tool return client
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
return client
} }

View file

@ -0,0 +1,57 @@
package service
import (
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/zerolog/log"
)
// TokenBlacklist 用于存储已失效的 token
type TokenBlacklist struct {
tokens sync.Map
}
// NewTokenBlacklist 创建一个新的 token 黑名单
func NewTokenBlacklist() *TokenBlacklist {
bl := &TokenBlacklist{}
// 启动定期清理过期 token 的 goroutine
go bl.cleanupExpiredTokens()
return bl
}
// AddToBlacklist 将 token 添加到黑名单
func (bl *TokenBlacklist) AddToBlacklist(tokenStr string, claims jwt.MapClaims) {
// 获取 token 的过期时间
exp, ok := claims["exp"].(float64)
if !ok {
log.Error().Msg("Failed to get token expiration time")
return
}
// 存储 token 和其过期时间
bl.tokens.Store(tokenStr, time.Unix(int64(exp), 0))
}
// IsBlacklisted 检查 token 是否在黑名单中
func (bl *TokenBlacklist) IsBlacklisted(tokenStr string) bool {
_, exists := bl.tokens.Load(tokenStr)
return exists
}
// cleanupExpiredTokens 定期清理过期的 token
func (bl *TokenBlacklist) cleanupExpiredTokens() {
ticker := time.NewTicker(1 * time.Hour)
for range ticker.C {
now := time.Now()
bl.tokens.Range(func(key, value interface{}) bool {
if expTime, ok := value.(time.Time); ok {
if now.After(expTime) {
bl.tokens.Delete(key)
}
}
return true
})
}
}

View file

@ -18,6 +18,7 @@ import (
"tss-rocks-be/ent/contributorsociallink" "tss-rocks-be/ent/contributorsociallink"
"tss-rocks-be/ent/daily" "tss-rocks-be/ent/daily"
"tss-rocks-be/ent/dailycontent" "tss-rocks-be/ent/dailycontent"
"tss-rocks-be/ent/media"
"tss-rocks-be/ent/permission" "tss-rocks-be/ent/permission"
"tss-rocks-be/ent/post" "tss-rocks-be/ent/post"
"tss-rocks-be/ent/postcontent" "tss-rocks-be/ent/postcontent"
@ -41,72 +42,94 @@ var openFile func(fh *multipart.FileHeader) (multipart.File, error) = func(fh *m
} }
type serviceImpl struct { type serviceImpl struct {
client *ent.Client client *ent.Client
storage storage.Storage storage storage.Storage
tokenBlacklist *TokenBlacklist
} }
// NewService creates a new Service instance // NewService creates a new Service instance
func NewService(client *ent.Client, storage storage.Storage) Service { func NewService(client *ent.Client, storage storage.Storage) Service {
return &serviceImpl{ return &serviceImpl{
client: client, client: client,
storage: storage, storage: storage,
tokenBlacklist: NewTokenBlacklist(),
} }
} }
// GetTokenBlacklist returns the token blacklist
func (s *serviceImpl) GetTokenBlacklist() *TokenBlacklist {
return s.tokenBlacklist
}
// User operations // User operations
func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) { func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) {
// Hash the password // 验证邮箱格式
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if !emailRegex.MatchString(email) {
return nil, fmt.Errorf("invalid email format")
}
// 验证密码长度
if len(password) < 8 {
return nil, fmt.Errorf("password must be at least 8 characters")
}
// 检查用户名是否已存在
exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx)
if err != nil {
return nil, fmt.Errorf("error checking username: %v", err)
}
if exists {
return nil, fmt.Errorf("username '%s' already exists", username)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err) return nil, fmt.Errorf("error hashing password: %v", err)
} }
// Add the user role by default // 创建用户
userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx) u, err := s.client.User.Create().
if err != nil { SetUsername(username).
return nil, fmt.Errorf("failed to get user role: %w", err)
}
// If a specific role is requested and it's not "user", get that role too
var additionalRole *ent.Role
if roleStr != "" && roleStr != "user" {
additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get role: %w", err)
}
}
// Create user with password and user role
userCreate := s.client.User.Create().
SetEmail(email). SetEmail(email).
SetPasswordHash(string(hashedPassword)). SetPasswordHash(string(hashedPassword)).
AddRoles(userRole) SetStatus("active").
Save(ctx)
// Add the additional role if specified
if additionalRole != nil {
userCreate.AddRoles(additionalRole)
}
// Save the user
user, err := userCreate.Save(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err) return nil, fmt.Errorf("error creating user: %v", err)
} }
return user, nil // 分配角色
err = s.AssignRole(ctx, u.ID, roleStr)
if err != nil {
return nil, fmt.Errorf("error assigning role: %v", err)
}
return u, nil
}
func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) {
u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, fmt.Errorf("user with username '%s' not found", username)
}
return nil, fmt.Errorf("error getting user: %v", err)
}
return u, nil
} }
func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) { func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) {
user, err := s.client.User.Query(). u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx)
Where(user.EmailEQ(email)).
Only(ctx)
if err != nil { if err != nil {
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
return nil, fmt.Errorf("user not found: %s", email) return nil, fmt.Errorf("user with email '%s' not found", email)
} }
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("error getting user: %v", err)
} }
return user, nil return u, nil
} }
func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool { func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool {
@ -437,12 +460,14 @@ func (s *serviceImpl) DeleteMedia(ctx context.Context, id int, userID int) error
} }
// Check ownership // Check ownership
if media.CreatedBy != strconv.Itoa(userID) { isOwner := media.CreatedBy == strconv.Itoa(userID)
if !isOwner {
return ErrUnauthorized return ErrUnauthorized
} }
// Delete from storage // Delete from storage
if err := s.storage.Delete(ctx, media.StorageID); err != nil { err = s.storage.Delete(ctx, media.StorageID)
if err != nil {
return err return err
} }
@ -890,3 +915,138 @@ func (s *serviceImpl) HasPermission(ctx context.Context, userID int, permission
return false, nil return false, nil
} }
func (s *serviceImpl) Delete(ctx context.Context, id int, currentUserID int) error {
// Check if the entity exists and get its type
var entityExists bool
var err error
// Try to find the entity in different tables
if entityExists, err = s.client.User.Query().Where(user.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete users
hasPermission, err := s.HasPermission(ctx, currentUserID, "users:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
// Cannot delete yourself
if id == currentUserID {
return fmt.Errorf("cannot delete your own account")
}
return s.client.User.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Post.Query().Where(post.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete posts
hasPermission, err := s.HasPermission(ctx, currentUserID, "posts:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the author of the post
isAuthor, err := s.client.Post.Query().
Where(post.ID(id)).
QueryContributors().
QueryContributor().
QueryUser().
Where(user.ID(currentUserID)).
Exist(ctx)
if err != nil {
return fmt.Errorf("failed to check post author: %v", err)
}
if !isAuthor {
return ErrUnauthorized
}
}
return s.client.Post.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Category.Query().Where(category.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete categories
hasPermission, err := s.HasPermission(ctx, currentUserID, "categories:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Category.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Contributor.Query().Where(contributor.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete contributors
hasPermission, err := s.HasPermission(ctx, currentUserID, "contributors:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
return s.client.Contributor.DeleteOneID(id).Exec(ctx)
}
if entityExists, err = s.client.Media.Query().Where(media.ID(id)).Exist(ctx); err == nil && entityExists {
// Check if user has permission to delete media
hasPermission, err := s.HasPermission(ctx, currentUserID, "media:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
// Check if the user is the uploader of the media
mediaItem, err := s.client.Media.Query().
Where(media.ID(id)).
Only(ctx)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
isOwner := mediaItem.CreatedBy == strconv.Itoa(currentUserID)
if !isOwner {
return ErrUnauthorized
}
}
// Get media item for path
mediaItem, err := s.client.Media.Get(ctx, id)
if err != nil {
return fmt.Errorf("failed to get media: %v", err)
}
// Delete from storage first
if err := s.storage.Delete(ctx, mediaItem.StorageID); err != nil {
return fmt.Errorf("failed to delete media file: %v", err)
}
// Then delete from database
return s.client.Media.DeleteOneID(id).Exec(ctx)
}
return fmt.Errorf("entity with id %d not found or delete operation not supported for this entity type", id)
}
func (s *serviceImpl) DeleteDaily(ctx context.Context, id string, currentUserID int) error {
// Check if user has permission to delete daily content
hasPermission, err := s.HasPermission(ctx, currentUserID, "daily:delete")
if err != nil {
return fmt.Errorf("failed to check permission: %v", err)
}
if !hasPermission {
return ErrUnauthorized
}
exists, err := s.client.Daily.Query().Where(daily.ID(id)).Exist(ctx)
if err != nil {
return fmt.Errorf("failed to check daily existence: %v", err)
}
if !exists {
return fmt.Errorf("daily with id %s not found", id)
}
return s.client.Daily.DeleteOneID(id).Exec(ctx)
}

View file

@ -103,52 +103,50 @@ func newMockMultipartFile(data []byte) *mockMultipartFile {
func (s *ServiceImplTestSuite) TestCreateUser() { func (s *ServiceImplTestSuite) TestCreateUser() {
testCases := []struct { testCases := []struct {
name string name string
email string username string
password string email string
role string password string
wantError bool role string
wantErr bool
}{ }{
{ {
name: "Valid user creation", name: "有效的用户",
email: "test@example.com", username: "testuser",
password: "password123", email: "test@example.com",
role: "admin", password: "password123",
wantError: false, role: "user",
wantErr: false,
}, },
{ {
name: "Empty email", name: "无效的邮箱",
email: "", username: "testuser2",
password: "password123", email: "invalid-email",
role: "user", password: "password123",
wantError: true, role: "user",
wantErr: true,
}, },
{ {
name: "Empty password", name: "空密码",
email: "test@example.com", username: "testuser3",
password: "", email: "test3@example.com",
role: "user", password: "",
wantError: true, role: "user",
}, wantErr: true,
{
name: "Invalid role",
email: "test@example.com",
password: "password123",
role: "invalid_role",
wantError: true,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
s.Run(tc.name, func() { s.Run(tc.name, func() {
user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role) user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role)
if tc.wantError { if tc.wantErr {
assert.Error(s.T(), err) s.Error(err)
assert.Nil(s.T(), user) s.Nil(user)
} else { } else {
assert.NoError(s.T(), err) s.NoError(err)
assert.NotNil(s.T(), user) s.NotNil(user)
assert.Equal(s.T(), tc.email, user.Email) s.Equal(tc.email, user.Email)
s.Equal(tc.username, user.Username)
} }
}) })
} }
@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() {
password := "password123" password := "password123"
role := "user" role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role) user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
password := "password123" password := "password123"
role := "user" role := "user"
user, err := s.svc.CreateUser(s.ctx, email, password, role) user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role)
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() {
func (s *ServiceImplTestSuite) TestRBAC() { func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("AssignRole", func() { s.Run("AssignRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.AssignRole(s.ctx, user.ID, "user") err = s.svc.AssignRole(s.ctx, user.ID, "user")
@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("RemoveRole", func() { s.Run("RemoveRole", func() {
user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.RemoveRole(s.ctx, user.ID, "admin") err = s.svc.RemoveRole(s.ctx, user.ID, "admin")
@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
s.Run("HasPermission", func() { s.Run("HasPermission", func() {
s.Run("Admin can create users", func() { s.Run("Admin can create users", func() {
user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin") user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Editor cannot create users", func() { s.Run("Editor cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor") user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User cannot create users", func() { s.Run("User cannot create users", func() {
user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create")
@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Editor can create posts", func() { s.Run("Editor can create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor") user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User can read posts", func() { s.Run("User can read posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read")
@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("User cannot create posts", func() { s.Run("User cannot create posts", func() {
user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create")
@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
}) })
s.Run("Invalid permission format", func() { s.Run("Invalid permission format", func() {
user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user") user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
_, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission") _, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission")
@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() {
func (s *ServiceImplTestSuite) TestCategory() { func (s *ServiceImplTestSuite) TestCategory() {
// Create a test user with admin role for testing // Create a test user with admin role for testing
adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin") adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin")
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), adminUser) require.NotNil(s.T(), adminUser)
@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() {
ctx := context.Background() ctx := context.Background()
// 创建测试用户,默认会有 "user" 角色 // 创建测试用户,默认会有 "user" 角色
user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user") user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user")
s.Require().NoError(err) s.Require().NoError(err)
// 测试新用户有默认的 "user" 角色 // 测试新用户有默认的 "user" 角色
@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() {
func (s *ServiceImplTestSuite) TestMedia() { func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Upload Media", func() { s.Run("Upload Media", func() {
// Create a user first // Create a user first
user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "") user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
require.NotNil(s.T(), user) require.NotNil(s.T(), user)
@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
s.Run("Delete Media - Unauthorized", func() { s.Run("Delete Media - Unauthorized", func() {
// Create a user // Create a user
user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "") user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
// Mock file content // Mock file content
@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() {
require.NoError(s.T(), err) require.NoError(s.T(), err)
// Try to delete with different user // Try to delete with different user
anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "") anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID) err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID)

View file

@ -9,14 +9,22 @@ import (
"tss-rocks-be/ent" "tss-rocks-be/ent"
"tss-rocks-be/internal/storage" "tss-rocks-be/internal/storage"
"tss-rocks-be/internal/types"
) )
// Service interface defines all business logic operations // Service interface defines all business logic operations
type Service interface { type Service interface {
// User operations // User operations
CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error) CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error)
GetUser(ctx context.Context, id int) (*ent.User, error)
GetUserByUsername(ctx context.Context, username string) (*ent.User, error)
GetUserByEmail(ctx context.Context, email string) (*ent.User, error) GetUserByEmail(ctx context.Context, email string) (*ent.User, error)
ValidatePassword(ctx context.Context, user *ent.User, password string) bool ValidatePassword(ctx context.Context, user *ent.User, password string) bool
VerifyPassword(ctx context.Context, userID int, password string) error
UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error)
DeleteUser(ctx context.Context, userID int) error
ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error)
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
// Category operations // Category operations
CreateCategory(ctx context.Context) (*ent.Category, error) CreateCategory(ctx context.Context) (*ent.Category, error)
@ -31,6 +39,13 @@ type Service interface {
GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error) GetPostBySlug(ctx context.Context, langCode, slug string) (*ent.Post, error)
ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error) ListPosts(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Post, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// Contributor operations // Contributor operations
CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error) CreateContributor(ctx context.Context, name string, avatarURL, bio *string) (*ent.Contributor, error)
AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error) AddContributorSocialLink(ctx context.Context, contributorID int, linkType, name, value string) (*ent.ContributorSocialLink, error)
@ -43,17 +58,16 @@ type Service interface {
GetDailyByID(ctx context.Context, id string) (*ent.Daily, error) GetDailyByID(ctx context.Context, id string) (*ent.Daily, error)
ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error) ListDailies(ctx context.Context, langCode string, categoryID *int, limit, offset int) ([]*ent.Daily, error)
// Media operations
ListMedia(ctx context.Context, limit, offset int) ([]*ent.Media, error)
Upload(ctx context.Context, file *multipart.FileHeader, userID int) (*ent.Media, error)
GetMedia(ctx context.Context, id int) (*ent.Media, error)
GetFile(ctx context.Context, id int) (io.ReadCloser, *storage.FileInfo, error)
DeleteMedia(ctx context.Context, id int, userID int) error
// RBAC operations // RBAC operations
InitializeRBAC(ctx context.Context) error
AssignRole(ctx context.Context, userID int, role string) error AssignRole(ctx context.Context, userID int, role string) error
RemoveRole(ctx context.Context, userID int, role string) error RemoveRole(ctx context.Context, userID int, role string) error
GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error)
HasPermission(ctx context.Context, userID int, permission string) (bool, error) HasPermission(ctx context.Context, userID int, permission string) (bool, error)
InitializeRBAC(ctx context.Context) error
// Token blacklist
GetTokenBlacklist() *TokenBlacklist
// Generic operations
Delete(ctx context.Context, id int, currentUserID int) error
DeleteDaily(ctx context.Context, id string, currentUserID int) error
} }

View file

@ -0,0 +1,154 @@
package service
import (
"context"
"fmt"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"tss-rocks-be/ent"
"tss-rocks-be/ent/role"
"tss-rocks-be/ent/user"
"tss-rocks-be/internal/types"
)
// GetUser gets a user by ID
func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) {
user, err := s.client.User.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// VerifyPassword verifies the user's current password
func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error {
user, err := s.GetUser(ctx, userID)
if err != nil {
return err
}
if !s.ValidatePassword(ctx, user, password) {
return fmt.Errorf("invalid password")
}
return nil
}
// UpdateUser updates user information
func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) {
// Start building the update
update := s.client.User.UpdateOneID(userID)
// Update email if provided
if input.Email != "" {
update.SetEmail(input.Email)
}
// Update display name if provided
if input.DisplayName != "" {
update.SetDisplayName(input.DisplayName)
}
// Update password if provided
if input.Password != "" {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
if err != nil {
log.Error().Err(err).Msg("Failed to hash password")
return nil, fmt.Errorf("failed to hash password: %w", err)
}
update.SetPasswordHash(string(hashedPassword))
}
// Update status if provided
if input.Status != "" {
update.SetStatus(user.Status(input.Status))
}
// Execute the update
user, err := update.Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to update user")
return nil, fmt.Errorf("failed to update user: %w", err)
}
// Update role if provided
if input.Role != "" {
// Clear existing roles
_, err = user.Update().ClearRoles().Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to clear user roles")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
// Add new role
role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to find role")
return nil, fmt.Errorf("failed to find role: %w", err)
}
_, err = user.Update().AddRoles(role).Save(ctx)
if err != nil {
log.Error().Err(err).Msg("Failed to add user role")
return nil, fmt.Errorf("failed to update user roles: %w", err)
}
}
return user, nil
}
// DeleteUser deletes a user by ID
func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error {
err := s.client.User.DeleteOneID(userID).Exec(ctx)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}
// ListUsers lists users with filters and pagination
func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) {
query := s.client.User.Query()
// Apply filters
if params.Role != "" {
query.Where(user.HasRolesWith(role.NameEQ(params.Role)))
}
if params.Status != "" {
query.Where(user.StatusEQ(user.Status(params.Status)))
}
if params.Email != "" {
query.Where(user.EmailContains(params.Email))
}
// Apply pagination
if params.PerPage > 0 {
query.Limit(params.PerPage)
if params.Page > 0 {
query.Offset((params.Page - 1) * params.PerPage)
}
}
// Apply sorting
if params.Sort != "" {
switch params.Sort {
case "email_asc":
query.Order(ent.Asc(user.FieldEmail))
case "email_desc":
query.Order(ent.Desc(user.FieldEmail))
case "created_at_asc":
query.Order(ent.Asc(user.FieldCreatedAt))
case "created_at_desc":
query.Order(ent.Desc(user.FieldCreatedAt))
}
}
// Execute query
users, err := query.All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
return users, nil
}

View file

@ -0,0 +1,28 @@
package types
// UpdateUserInput defines the input for updating a user
type UpdateUserInput struct {
Email string
Password string
Role string
Status string
DisplayName string
}
// ListUsersParams defines the parameters for listing users
type ListUsersParams struct {
Page int
PerPage int
Sort string
Role string
Status string
Email string
}
// CreateUserInput defines the input for creating a user
type CreateUserInput struct {
Email string
Password string
Role string
DisplayName string
}

29
backend/main.go Normal file
View file

@ -0,0 +1,29 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"tss-rocks-be/cmd"
)
var rootCmd = &cobra.Command{
Use: "app",
Short: "TSS backend application",
Long: `TSS (The Starset Society) backend application.
It provides both API endpoints and admin tools for content management.`,
}
func init() {
// 添加子命令
rootCmd.AddCommand(cmd.GetUserCmd())
rootCmd.AddCommand(cmd.GetServerCmd())
}
func main() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

View file

@ -20,5 +20,42 @@
}, },
"footer": { "footer": {
"copyright": "TSS Rocks. All rights reserved." "copyright": "TSS Rocks. All rights reserved."
},
"admin": {
"common": {
"search": "Search",
"create": "Create",
"edit": "Edit",
"delete": "Delete",
"status": "Status",
"actions": "Actions",
"published": "Published",
"draft": "Draft",
"author": "Author",
"date": "Date",
"title": "Title",
"description": "Description",
"name": "Name",
"email": "Email",
"role": "Role",
"bio": "Bio",
"articles": "Articles",
"lastLogin": "Last Login",
"joinDate": "Join Date",
"username": "Username",
"logout": "Logout"
},
"nav": {
"dashboard": "Dashboard",
"posts": "Posts",
"categories": "Categories",
"users": "Users",
"contributors": "Contributors"
},
"roles": {
"admin": "Administrator",
"user": "User",
"contributor": "Contributor"
}
} }
} }

View file

@ -19,6 +19,43 @@
"system": "跟随系统" "system": "跟随系统"
}, },
"footer": { "footer": {
"copyright": "TSS.Rocks. 版权所有。" "copyright": "TSS Rocks. 保留所有权利。"
},
"admin": {
"common": {
"search": "搜索",
"create": "新建",
"edit": "编辑",
"delete": "删除",
"status": "状态",
"actions": "操作",
"published": "已发布",
"draft": "草稿",
"author": "作者",
"date": "日期",
"title": "标题",
"description": "描述",
"name": "名称",
"email": "邮箱",
"role": "角色",
"bio": "简介",
"articles": "文章数",
"lastLogin": "最后登录",
"joinDate": "加入时间",
"username": "用户名",
"logout": "退出"
},
"nav": {
"dashboard": "仪表盘",
"posts": "文章",
"categories": "分类",
"users": "用户",
"contributors": "作者"
},
"roles": {
"admin": "管理员",
"user": "普通用户",
"contributor": "作者"
}
} }
} }

View file

@ -19,6 +19,43 @@
"system": "跟隨系統" "system": "跟隨系統"
}, },
"footer": { "footer": {
"copyright": "TSS.Rocks. 版權所有。" "copyright": "TSS Rocks. 保留所有權利。"
},
"admin": {
"common": {
"search": "搜尋",
"create": "新建",
"edit": "編輯",
"delete": "刪除",
"status": "狀態",
"actions": "操作",
"published": "已發布",
"draft": "草稿",
"author": "作者",
"date": "日期",
"title": "標題",
"description": "描述",
"name": "名稱",
"email": "郵箱",
"role": "角色",
"bio": "簡介",
"articles": "文章數",
"lastLogin": "最後登入",
"joinDate": "加入時間",
"username": "用戶名",
"logout": "退出"
},
"nav": {
"dashboard": "儀表板",
"posts": "文章",
"categories": "分類",
"users": "用戶",
"contributors": "作者"
},
"roles": {
"admin": "管理員",
"user": "普通用戶",
"contributor": "作者"
}
} }
} }

View file

@ -10,10 +10,11 @@
"preview": "vite preview" "preview": "vite preview"
}, },
"dependencies": { "dependencies": {
"@tss-rocks/api": "workspace:*",
"@headlessui/react": "^2.2.0", "@headlessui/react": "^2.2.0",
"@tss-rocks/api": "workspace:*",
"@types/markdown-it": "^14.1.2", "@types/markdown-it": "^14.1.2",
"i18next": "^24.2.2", "i18next": "^24.2.2",
"i18next-browser-languagedetector": "^8.0.4",
"lucide-react": "^0.474.0", "lucide-react": "^0.474.0",
"markdown-it": "^14.1.0", "markdown-it": "^14.1.0",
"react": "^19.0.0", "react": "^19.0.0",
@ -27,6 +28,7 @@
"@tailwindcss/postcss": "^4.0.3", "@tailwindcss/postcss": "^4.0.3",
"@tailwindcss/typography": "^0.5.16", "@tailwindcss/typography": "^0.5.16",
"@tailwindcss/vite": "^4.0.3", "@tailwindcss/vite": "^4.0.3",
"@types/node": "^22.13.4",
"@types/react": "^19.0.8", "@types/react": "^19.0.8",
"@types/react-dom": "^19.0.3", "@types/react-dom": "^19.0.3",
"@vitejs/plugin-react": "^4.3.1", "@vitejs/plugin-react": "^4.3.1",

View file

@ -3,35 +3,61 @@ import { BrowserRouter, Routes, Route } from 'react-router-dom';
import { Header } from './components/layout/Header'; import { Header } from './components/layout/Header';
import { Suspense, lazy } from 'react'; import { Suspense, lazy } from 'react';
import Footer from './components/Footer'; import Footer from './components/Footer';
import { AuthProvider } from './contexts/AuthContext';
// Lazy load pages // Lazy load pages
const Home = lazy(() => import('./pages/Home')); const Home = lazy(() => import('./pages/Home'));
const Daily = lazy(() => import('./pages/Daily')); const Daily = lazy(() => import('./pages/Daily'));
const Article = lazy(() => import('./pages/Article')); const Article = lazy(() => import('./pages/Article'));
const AdminLayout = lazy(() => import('./pages/admin/layout/AdminLayout'));
// 管理页面懒加载
const PostsManagement = lazy(() => import('./pages/admin/posts/PostsManagement'));
const CategoriesManagement = lazy(() => import('./pages/admin/categories/CategoriesManagement'));
const UsersManagement = lazy(() => import('./pages/admin/users/UsersManagement'));
const ContributorsManagement = lazy(() => import('./pages/admin/contributors/ContributorsManagement'));
function App() { function App() {
return ( return (
<BrowserRouter> <AuthProvider>
<div className="flex flex-col min-h-screen bg-white dark:bg-neutral-900 text-gray-900 dark:text-gray-100"> <BrowserRouter>
<Header /> <div className="flex flex-col min-h-screen bg-white dark:bg-neutral-900 text-gray-900 dark:text-gray-100">
{/* 页眉分隔线 */}
<div className="w-[95%] mx-auto">
<div className="border-t-2 border-gray-900 dark:border-gray-100 w-full mb-2" />
</div>
<main className="flex-1 w-[95%] mx-auto py-8">
<Suspense fallback={<div>Loading...</div>}> <Suspense fallback={<div>Loading...</div>}>
<Routes> <Routes>
<Route path="/" element={<Home />} /> {/* Admin routes */}
<Route path="/daily" element={<Daily />} /> <Route path="/admin" element={<AdminLayout />}>
<Route path="/posts/:articleId" element={<Article />} /> <Route path="posts" element={<PostsManagement />} />
<Route path="categories" element={<CategoriesManagement />} />
<Route path="users" element={<UsersManagement />} />
<Route path="contributors" element={<ContributorsManagement />} />
</Route>
{/* Public routes */}
<Route
path="/"
element={
<>
<Header />
{/* 页眉分隔线 */}
<div className="w-[95%] mx-auto">
<div className="border-t-2 border-gray-900 dark:border-gray-100 w-full mb-2" />
</div>
<main className="flex-1 w-[95%] mx-auto py-8">
<Routes>
<Route index element={<Home />} />
<Route path="/daily" element={<Daily />} />
<Route path="/posts/:articleId" element={<Article />} />
</Routes>
</main>
<Footer />
</>
}
/>
</Routes> </Routes>
</Suspense> </Suspense>
</main> </div>
</BrowserRouter>
<Footer /> </AuthProvider>
</div>
</BrowserRouter>
); );
} }

View file

@ -0,0 +1,132 @@
import { ReactNode } from 'react';
import { useTranslation } from 'react-i18next';
interface Column<T> {
key: keyof T;
title: string;
render?: (value: T[keyof T], item: T) => ReactNode;
}
interface TableProps<T> {
columns: Column<T>[];
data: T[];
loading?: boolean;
onEdit?: (item: T) => void;
onDelete?: (item: T) => void;
}
const Table = <T extends Record<string, any>>({
columns,
data,
loading = false,
onEdit,
onDelete,
}: TableProps<T>) => {
const { t } = useTranslation();
if (loading) {
return (
<div className="p-6">
{/* Loading header */}
<div className="flex border-b border-slate-200 dark:border-slate-700 pb-4">
{columns.map((_, index) => (
<div
key={index}
className="flex-1 h-6 bg-slate-200 dark:bg-slate-700 rounded animate-pulse"
style={{ marginRight: index !== columns.length - 1 ? '2rem' : 0 }}
/>
))}
</div>
{/* Loading rows */}
<div className="space-y-4 pt-4">
{[1, 2, 3].map((row) => (
<div key={row} className="flex items-center">
{columns.map((_, index) => (
<div
key={index}
className="flex-1"
style={{ marginRight: index !== columns.length - 1 ? '2rem' : 0 }}
>
<div className="h-4 bg-slate-200 dark:bg-slate-700 rounded animate-pulse"
style={{ width: index === columns.length - 1 ? '30%' : '60%' }}
/>
</div>
))}
</div>
))}
</div>
</div>
);
}
return (
<div className="overflow-hidden">
<table className="min-w-full">
<thead>
<tr className="border-b border-slate-200 dark:border-slate-700">
{columns.map((column) => (
<th
key={String(column.key)}
scope="col"
className="px-6 py-4 text-left text-sm font-bold text-slate-900 dark:text-white uppercase tracking-wider"
>
{column.title}
</th>
))}
{(onEdit || onDelete) && (
<th scope="col" className="relative px-6 py-4">
<span className="sr-only">{t('admin.common.actions')}</span>
</th>
)}
</tr>
</thead>
<tbody className="divide-y divide-slate-200 dark:divide-slate-700">
{data.map((item, itemIdx) => (
<tr
key={itemIdx}
className="transition-colors hover:bg-slate-50 dark:hover:bg-slate-800/50"
>
{columns.map((column) => (
<td
key={String(column.key)}
className="px-6 py-4 text-sm text-slate-800 dark:text-white"
>
{column.render
? column.render(item[column.key], item)
: String(item[column.key])}
</td>
))}
{(onEdit || onDelete) && (
<td className="px-6 py-4 text-right text-sm space-x-2">
{onEdit && (
<button
onClick={() => onEdit(item)}
className="text-slate-600 dark:text-white hover:text-slate-900 dark:hover:text-slate-200 px-2 py-1 rounded-md transition-colors"
>
{t('admin.common.edit')}
</button>
)}
{onDelete && (
<button
onClick={() => onDelete(item)}
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
>
{t('admin.common.delete')}
</button>
)}
</td>
)}
</tr>
))}
</tbody>
</table>
{data.length === 0 && !loading && (
<div className="text-center py-8">
<p className="text-slate-500 dark:text-white">{t('admin.common.noData')}</p>
</div>
)}
</div>
);
};
export default Table;

View file

@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
import { FiSun, FiMoon, FiSearch, FiGlobe, FiMonitor } from 'react-icons/fi'; import { FiSun, FiMoon, FiSearch, FiGlobe, FiMonitor } from 'react-icons/fi';
import { Menu } from '@headlessui/react'; import { Menu } from '@headlessui/react';
import { useTheme } from '../../hooks/useTheme'; import { useTheme } from '../../hooks/useTheme';
import type { Theme } from '../../hooks/useTheme';
const LANGUAGES = [ const LANGUAGES = [
{ code: 'en', nativeName: 'English' }, { code: 'en', nativeName: 'English' },

View file

@ -0,0 +1,82 @@
import { createContext, useContext, useState, useCallback, ReactNode } from 'react';
import { AuthContextType, AuthState, User } from '../types/auth';
const API_URL = import.meta.env.VITE_API_URL || '/api';
const initialState: AuthState = {
user: null,
token: null,
loading: false,
error: null,
};
const AuthContext = createContext<AuthContextType | null>(null);
export const useAuth = () => {
const context = useContext(AuthContext);
if (!context) {
throw new Error('useAuth must be used within an AuthProvider');
}
return context;
};
interface AuthProviderProps {
children: ReactNode;
}
export const AuthProvider = ({ children }: AuthProviderProps) => {
const [state, setState] = useState<AuthState>(initialState);
const login = useCallback(async (email: string, password: string) => {
setState(prev => ({ ...prev, loading: true, error: null }));
try {
const response = await fetch(`${API_URL}/auth/login`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ email, password }),
});
if (!response.ok) {
throw new Error('登录失败');
}
const data = await response.json();
setState(prev => ({
...prev,
user: data.user,
token: data.token,
loading: false,
}));
localStorage.setItem('token', data.token);
} catch (error) {
setState(prev => ({
...prev,
loading: false,
error: error instanceof Error ? error.message : '登录失败',
}));
}
}, []);
const logout = useCallback(() => {
localStorage.removeItem('token');
setState(initialState);
}, []);
const clearError = useCallback(() => {
setState(prev => ({ ...prev, error: null }));
}, []);
const value = {
...state,
login,
logout,
clearError,
};
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
};
export default AuthContext;

View file

@ -1,6 +1,6 @@
import { useState, useEffect } from 'react'; import { useState, useEffect } from 'react';
type Theme = 'light' | 'dark' | 'system'; export type Theme = 'light' | 'dark' | 'system';
export function useTheme() { export function useTheme() {
const [theme, setTheme] = useState<Theme>( const [theme, setTheme] = useState<Theme>(

View file

@ -1,11 +1,27 @@
import i18n from 'i18next'; import i18n from 'i18next';
import { initReactI18next } from 'react-i18next'; import { initReactI18next } from 'react-i18next';
import LanguageDetector from 'i18next-browser-languagedetector';
import en from '../data/i18n/en.json'; import en from '../data/i18n/en.json';
import zhHans from '../data/i18n/zh-Hans.json'; import zhHans from '../data/i18n/zh-Hans.json';
import zhHant from '../data/i18n/zh-Hant.json'; import zhHant from '../data/i18n/zh-Hant.json';
// 获取浏览器默认语言
const getBrowserLanguage = () => {
const lang = navigator.language;
if (lang.startsWith('zh')) {
return lang === 'zh-TW' || lang === 'zh-HK' ? 'zh-Hant' : 'zh-Hans';
}
return 'en';
};
// 获取存储的语言设置
const getStoredLanguage = () => {
return localStorage.getItem('language') || getBrowserLanguage();
};
i18n i18n
.use(LanguageDetector)
.use(initReactI18next) .use(initReactI18next)
.init({ .init({
resources: { resources: {
@ -13,11 +29,21 @@ i18n
'zh-Hans': { translation: zhHans }, 'zh-Hans': { translation: zhHans },
'zh-Hant': { translation: zhHant }, 'zh-Hant': { translation: zhHant },
}, },
lng: 'zh-Hans', lng: getStoredLanguage(),
fallbackLng: 'en', fallbackLng: 'en',
detection: {
order: ['localStorage', 'navigator'],
lookupLocalStorage: 'language',
caches: ['localStorage'],
},
interpolation: { interpolation: {
escapeValue: false, escapeValue: false,
}, },
}); });
// 监听语言变化并保存到 localStorage
i18n.on('languageChanged', (lng) => {
localStorage.setItem('language', lng);
});
export default i18n; export default i18n;

View file

@ -0,0 +1,129 @@
import { FC, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
import Table from '../../../components/admin/Table';
interface Category {
id: string;
name: string;
description: string;
}
const CategoriesManagement: FC = () => {
const [searchTerm, setSearchTerm] = useState('');
const [loading, setLoading] = useState(false);
const { t } = useTranslation();
// 这里后续会通过 API 获取数据
const categories: Category[] = [
{
id: '1',
name: '新闻',
description: '新闻分类',
},
];
const handleEdit = (category: Category) => {
console.log('Edit category:', category);
};
const handleDelete = (category: Category) => {
console.log('Delete category:', category);
};
const columns = [
{
key: 'name',
title: t('admin.common.name'),
},
{
key: 'description',
title: t('admin.common.description'),
},
];
return (
<div className="divide-y divide-stone-200 dark:divide-stone-700">
<div className="p-6">
<div className="flex items-center justify-between">
<div className="relative">
<input
type="text"
placeholder={t('admin.common.search')}
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="pl-10 pr-4 py-2 border-b-2 border-slate-200 dark:border-slate-700 focus:border-slate-900 dark:focus:border-slate-500 bg-transparent outline-none w-64 text-slate-800 dark:text-white placeholder-slate-400 dark:placeholder-slate-500"
/>
<RiSearchLine className="absolute left-0 top-1/2 -translate-y-1/2 text-slate-400 dark:text-slate-500 text-xl" />
</div>
<button className="flex items-center gap-2 px-6 py-2 bg-blue-600 hover:bg-blue-700 dark:bg-blue-500 dark:hover:bg-blue-600 text-white transition-colors rounded-md">
<RiAddLine className="text-xl" />
<span>{t('admin.common.create')}</span>
</button>
</div>
</div>
<div>
<Table
columns={columns}
data={categories}
loading={loading}
onEdit={handleEdit}
onDelete={handleDelete}
>
{({ data, onEdit, onDelete }) => (
<table className="w-full">
<thead>
<tr className="border-b border-stone-200 dark:border-stone-700">
{columns.map((column, index) => (
<th
key={column.key}
className={`text-left py-4 px-6 text-stone-600 dark:text-stone-400 font-medium ${
index === 0 ? 'rounded-tl-md' : ''
} ${index === columns.length - 1 ? 'rounded-tr-md' : ''}`}
>
{column.title}
</th>
))}
<th className="text-right py-4 px-6 text-stone-600 dark:text-stone-400 font-medium rounded-tr-md">{t('admin.common.actions')}</th>
</tr>
</thead>
<tbody>
{data.map((category) => (
<tr key={category.id} className="border-b border-stone-200 dark:border-stone-700 last:border-0">
{columns.map((column) => (
<td
key={column.key}
className={`py-4 px-6 ${
column.key === 'description' ? 'text-left' : 'text-left'
} text-stone-800 dark:text-stone-200`}
>
{category[column.key]}
</td>
))}
<td className="py-4 px-6 text-right">
<button
className="text-stone-600 dark:text-stone-400 hover:text-stone-900 dark:hover:text-stone-200 px-2 py-1 rounded-md transition-colors"
onClick={() => onEdit(category)}
>
{t('admin.common.edit')}
</button>
<button
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
onClick={() => onDelete(category)}
>
{t('admin.common.delete')}
</button>
</td>
</tr>
))}
</tbody>
</table>
)}
</Table>
</div>
</div>
);
};
export default CategoriesManagement;

View file

@ -0,0 +1,138 @@
import { FC, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
import Table from '../../../components/admin/Table';
interface Contributor {
id: string;
name: string;
bio: string;
articles: number;
joinDate: string;
}
const ContributorsManagement: FC = () => {
const [searchTerm, setSearchTerm] = useState('');
const [loading, setLoading] = useState(false);
const { t } = useTranslation();
// 这里后续会通过 API 获取数据
const contributors: Contributor[] = [
{
id: '1',
name: '李四',
bio: '这是李四的简介',
articles: 5,
joinDate: '2024-02-20',
},
];
const handleEdit = (contributor: Contributor) => {
console.log('Edit contributor:', contributor);
};
const handleDelete = (contributor: Contributor) => {
console.log('Delete contributor:', contributor);
};
const columns = [
{
key: 'name',
title: t('admin.common.name'),
},
{
key: 'bio',
title: t('admin.common.bio'),
},
{
key: 'articles',
title: t('admin.common.articles'),
render: (value: number) => (
<span className="px-3 py-1 bg-stone-100 dark:bg-stone-700 text-stone-600 dark:text-stone-300 rounded-md text-sm">
{value}
</span>
),
},
{
key: 'joinDate',
title: t('admin.common.joinDate'),
},
];
return (
<div className="divide-y divide-stone-200 dark:divide-stone-700">
<div className="p-6">
<div className="flex items-center justify-between">
<div className="relative">
<input
type="text"
placeholder={t('admin.common.search')}
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="pl-10 pr-4 py-2 border-b-2 border-slate-200 dark:border-slate-700 focus:border-slate-900 dark:focus:border-slate-500 bg-transparent outline-none w-64 text-slate-800 dark:text-white placeholder-slate-400 dark:placeholder-slate-500"
/>
<RiSearchLine className="absolute left-0 top-1/2 -translate-y-1/2 text-slate-400 dark:text-slate-500 text-xl" />
</div>
<button className="flex items-center gap-2 px-6 py-2 bg-blue-600 hover:bg-blue-700 dark:bg-blue-500 dark:hover:bg-blue-600 text-white transition-colors rounded-md">
<RiAddLine className="text-xl" />
<span>{t('admin.common.create')}</span>
</button>
</div>
</div>
<div>
<Table
columns={columns}
data={contributors}
loading={loading}
onEdit={handleEdit}
onDelete={handleDelete}
>
{({ data, onEdit, onDelete }) => (
<table className="w-full">
<thead>
<tr className="border-b border-stone-200 dark:border-stone-700">
<th className="text-left py-4 px-6 text-stone-600 dark:text-stone-400 font-medium rounded-tl-md">{t('admin.common.name')}</th>
<th className="text-left py-4 px-6 text-stone-600 dark:text-stone-400 font-medium">{t('admin.common.bio')}</th>
<th className="text-left py-4 px-6 text-stone-600 dark:text-stone-400 font-medium">{t('admin.common.articles')}</th>
<th className="text-left py-4 px-6 text-stone-600 dark:text-stone-400 font-medium">{t('admin.common.joinDate')}</th>
<th className="text-right py-4 px-6 text-stone-600 dark:text-stone-400 font-medium rounded-tr-md">{t('admin.common.actions')}</th>
</tr>
</thead>
<tbody>
{data.map((contributor) => (
<tr key={contributor.id} className="border-b border-stone-200 dark:border-stone-700 last:border-0">
<td className="py-4 px-6 text-stone-800 dark:text-stone-200">{contributor.name}</td>
<td className="py-4 px-6 text-stone-800 dark:text-stone-200">{contributor.bio}</td>
<td className="py-4 px-6">
<span className="px-3 py-1 bg-stone-100 dark:bg-stone-700 text-stone-600 dark:text-stone-300 rounded-md text-sm">
{contributor.articles}
</span>
</td>
<td className="py-4 px-6 text-stone-800 dark:text-stone-200">{contributor.joinDate}</td>
<td className="py-4 px-6 text-right">
<button
className="text-stone-600 dark:text-stone-400 hover:text-stone-900 dark:hover:text-stone-200 px-2 py-1 rounded-md transition-colors"
onClick={() => onEdit(contributor)}
>
{t('admin.common.edit')}
</button>
<button
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
onClick={() => onDelete(contributor)}
>
{t('admin.common.delete')}
</button>
</td>
</tr>
))}
</tbody>
</table>
)}
</Table>
</div>
</div>
);
};
export default ContributorsManagement;

View file

@ -0,0 +1,168 @@
import { FC } from 'react';
import { Link, Outlet, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import {
RiFileTextLine,
RiFolderLine,
RiUserLine,
RiTeamLine,
RiLogoutBoxRLine,
RiSunLine,
RiMoonLine,
RiComputerLine,
RiGlobalLine
} from 'react-icons/ri';
import { useTheme } from '../../../hooks/useTheme';
import type { Theme } from '../../../hooks/useTheme';
interface AdminLayoutProps {}
const menuItems = [
{ path: '/admin/posts', icon: RiFileTextLine, label: 'admin.nav.posts' },
{ path: '/admin/categories', icon: RiFolderLine, label: 'admin.nav.categories' },
{ path: '/admin/users', icon: RiUserLine, label: 'admin.nav.users' },
{ path: '/admin/contributors', icon: RiTeamLine, label: 'admin.nav.contributors' },
];
const themeOptions = [
{ value: 'light' as const, icon: RiSunLine, label: 'theme.light' },
{ value: 'dark' as const, icon: RiMoonLine, label: 'theme.dark' },
{ value: 'system' as const, icon: RiComputerLine, label: 'theme.system' }
];
const languageOptions = [
{ value: 'en', label: 'English' },
{ value: 'zh-Hans', label: '简体中文' },
{ value: 'zh-Hant', label: '繁體中文' }
];
type LanguageMap = {
'en': 'zh-Hans';
'zh-Hans': 'zh-Hant';
'zh-Hant': 'en';
};
const languageMap: LanguageMap = {
'en': 'zh-Hans',
'zh-Hans': 'zh-Hant',
'zh-Hant': 'en'
};
const AdminLayout: FC<AdminLayoutProps> = () => {
const location = useLocation();
const { t, i18n } = useTranslation();
const { theme, setTheme } = useTheme();
return (
<div className="min-h-screen bg-slate-100 dark:bg-slate-900 py-6 flex">
{/* Background Overlay */}
<div className="fixed inset-0 bg-gradient-to-br from-slate-200 to-slate-300 dark:from-slate-800 dark:to-slate-900 backdrop-blur-xl -z-10" />
<div className="w-full max-w-[98%] mx-auto flex gap-4">
{/* Sidebar */}
<aside className="w-64 bg-white/80 dark:bg-slate-800/80 backdrop-blur-lg rounded-lg shadow-lg flex flex-col">
<div className="h-16 px-6 border-b border-slate-200/80 dark:border-slate-700/80 flex items-center justify-center">
<h1 className="text-2xl font-bold text-slate-800 dark:text-white tracking-tight">
TSS Rocks
</h1>
</div>
<nav className="flex-1 p-4">
{menuItems.map((item) => {
const Icon = item.icon;
const isActive = location.pathname === item.path;
return (
<Link
key={item.path}
to={item.path}
className={`flex items-center gap-3 px-4 py-3 mb-2 rounded-lg transition-colors ${
isActive
? 'bg-slate-900 dark:bg-slate-700 text-white'
: 'text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-slate-700/50'
}`}
>
<Icon className="text-xl" />
<span className="tracking-wide">{t(item.label)}</span>
</Link>
);
})}
</nav>
<div className="p-4 space-y-2">
{/* Language and Logout Buttons */}
<div className="flex gap-2">
<button
onClick={() => {
const currentLang = i18n.language as keyof LanguageMap;
const nextLang = languageMap[currentLang] || 'en';
i18n.changeLanguage(nextLang);
}}
className="flex-1 flex items-center justify-center gap-2 px-3 py-2 rounded-md text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-slate-700/50 transition-colors whitespace-nowrap min-w-[100px]"
title={t('admin.common.switchLanguage')}
>
<RiGlobalLine className="text-xl flex-shrink-0" />
<span className="text-sm truncate">{languageOptions.find(lang => lang.value === i18n.language)?.label || 'English'}</span>
</button>
<button
className="flex-1 flex items-center justify-center gap-2 px-3 py-2 rounded-md text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-slate-700/50 transition-colors whitespace-nowrap"
onClick={() => {/* TODO: Implement logout */}}
>
<RiLogoutBoxRLine className="text-xl flex-shrink-0" />
<span className="text-sm">{t('admin.common.logout')}</span>
</button>
</div>
{/* Theme Buttons */}
<div className="border-t border-slate-200/80 dark:border-slate-700/80 pt-2">
<div className="flex items-center gap-1">
{themeOptions.map((option) => {
const Icon = option.icon;
return (
<button
key={option.value}
onClick={() => setTheme(option.value)}
className={`flex-1 p-2 rounded-md transition-colors ${
theme === option.value
? 'bg-slate-900 dark:bg-slate-700 text-white'
: 'text-slate-600 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-slate-700/50'
}`}
title={t(option.label)}
>
<Icon className="text-xl mx-auto" />
</button>
);
})}
</div>
</div>
</div>
</aside>
{/* Main Content */}
<main className="flex-1 flex flex-col rounded-lg overflow-hidden bg-white/80 dark:bg-slate-800/80 backdrop-blur-lg shadow-lg">
<header className="border-b border-slate-200/80 dark:border-slate-700/80">
<div className="h-16 px-8 flex items-center justify-between">
<div>
<h2 className="text-2xl font-bold text-slate-800 dark:text-white">
{t(menuItems.find(item => item.path === location.pathname)?.label || 'admin.nav.dashboard')}
</h2>
</div>
<div className="flex items-center gap-3">
<div className="w-10 h-10 bg-slate-200 dark:bg-slate-700 rounded-full flex items-center justify-center text-slate-600 dark:text-slate-300">
<span className="text-lg">A</span>
</div>
<div>
<div className="text-slate-800 dark:text-white"></div>
<div className="text-sm text-slate-500 dark:text-slate-400">Administrator</div>
</div>
</div>
</div>
</header>
<div className="flex-1 p-6">
<div className="h-full bg-white dark:bg-slate-800 rounded-lg shadow-sm border border-slate-200/60 dark:border-slate-700/60">
<Outlet />
</div>
</div>
</main>
</div>
</div>
);
};
export default AdminLayout;

View file

@ -0,0 +1,148 @@
import { FC, useState } from 'react';
import { useTranslation } from 'react-i18next';
import Table from '../../../components/admin/Table';
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
interface Post {
id: string;
title: string;
author: string;
status: 'draft' | 'published';
publishDate: string;
}
const PostsManagement: FC = () => {
const [searchTerm, setSearchTerm] = useState('');
const [loading, setLoading] = useState(false);
const { t } = useTranslation();
// 这里后续会通过 API 获取数据
const posts: Post[] = [
{
id: '1',
title: '示例文章标题',
author: '张三',
status: 'published',
publishDate: '2024-02-20',
},
];
const handleEdit = (post: Post) => {
console.log('Edit post:', post);
};
const handleDelete = (post: Post) => {
console.log('Delete post:', post);
};
const columns = [
{
key: 'title',
title: t('admin.common.title'),
},
{
key: 'author',
title: t('admin.common.author'),
},
{
key: 'status',
title: t('admin.common.status'),
render: (value: Post['status']) => (
<span
className={`inline-block px-3 py-1 text-xs rounded-md ${
value === 'published'
? 'bg-slate-900 dark:bg-slate-700 text-white'
: 'bg-slate-200 dark:bg-slate-600 text-slate-700 dark:text-slate-200'
}`}
>
{t(`admin.common.${value}`)}
</span>
),
},
{
key: 'publishDate',
title: t('admin.common.date'),
},
];
return (
<div className="divide-y divide-slate-200 dark:divide-slate-700">
<div className="p-6">
<div className="flex items-center justify-between">
<div className="relative">
<input
type="text"
placeholder={t('admin.common.search')}
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="pl-10 pr-4 py-2 border-b-2 border-slate-200 dark:border-slate-700 focus:border-slate-900 dark:focus:border-slate-500 bg-transparent outline-none w-64 text-slate-800 dark:text-white placeholder-slate-400 dark:placeholder-slate-500"
/>
<RiSearchLine className="absolute left-0 top-1/2 -translate-y-1/2 text-slate-400 dark:text-slate-500 text-xl" />
</div>
<button className="flex items-center gap-2 px-6 py-2 bg-blue-600 hover:bg-blue-700 dark:bg-blue-500 dark:hover:bg-blue-600 text-white transition-colors rounded-md">
<RiAddLine className="text-xl" />
<span>{t('admin.common.create')}</span>
</button>
</div>
</div>
<div>
<Table
columns={columns}
data={posts}
loading={loading}
onEdit={handleEdit}
onDelete={handleDelete}
>
{({ data, onEdit, onDelete }) => (
<table className="w-full">
<thead>
<tr className="border-b border-slate-200 dark:border-slate-700">
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tl-md">{t('admin.common.title')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.author')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.status')}</th>
<th className="text-left py-4 px-6 text-slate-600 dark:text-slate-400 font-medium">{t('admin.common.date')}</th>
<th className="text-right py-4 px-6 text-slate-600 dark:text-slate-400 font-medium rounded-tr-md">{t('admin.common.actions')}</th>
</tr>
</thead>
<tbody>
{data.map((post) => (
<tr key={post.id} className="border-b border-slate-200 dark:border-slate-700 last:border-0">
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.title}</td>
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.author}</td>
<td className="py-4 px-6">
<span className={`px-3 py-1 text-sm rounded-md ${
post.status === 'published'
? 'bg-slate-900 dark:bg-slate-700 text-white'
: 'bg-slate-200 dark:bg-slate-600 text-slate-700 dark:text-slate-200'
}`}>
{t(`admin.common.${post.status}`)}
</span>
</td>
<td className="py-4 px-6 text-slate-800 dark:text-slate-200">{post.publishDate}</td>
<td className="py-4 px-6 text-right">
<button
className="text-slate-600 dark:text-slate-400 hover:text-slate-900 dark:hover:text-slate-200 px-2 py-1 rounded-md transition-colors"
onClick={() => onEdit(post)}
>
{t('admin.common.edit')}
</button>
<button
className="text-red-600 dark:text-red-400 hover:text-red-700 dark:hover:text-red-300 px-2 py-1 rounded-md transition-colors"
onClick={() => onDelete(post)}
>
{t('admin.common.delete')}
</button>
</td>
</tr>
))}
</tbody>
</table>
)}
</Table>
</div>
</div>
);
};
export default PostsManagement;

View file

@ -0,0 +1,100 @@
import { FC, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { RiAddLine, RiSearchLine } from 'react-icons/ri';
import Table from '../../../components/admin/Table';
interface User {
id: string;
username: string;
email: string;
role: 'admin' | 'user';
lastLogin: string;
}
const UsersManagement: FC = () => {
const [searchTerm, setSearchTerm] = useState('');
const [loading, setLoading] = useState(true);
const { t } = useTranslation();
// 这里后续会通过 API 获取数据
const users: User[] = [
{
id: '1',
username: '张三',
email: 'zhangsan@example.com',
role: 'admin',
lastLogin: '2024-02-20',
},
];
const handleEdit = (user: User) => {
console.log('Edit user:', user);
};
const handleDelete = (user: User) => {
console.log('Delete user:', user);
};
const columns = [
{
key: 'username',
title: t('admin.common.username'),
},
{
key: 'role',
title: t('admin.common.role'),
render: (value: User['role']) => (
<span
className={`inline-block px-3 py-1 text-xs rounded-md ${
value === 'admin'
? 'bg-slate-900 dark:bg-slate-700 text-white'
: 'bg-slate-200 dark:bg-slate-600 text-slate-700 dark:text-slate-200'
}`}
>
{t(`admin.roles.${value}`)}
</span>
),
},
{
key: 'joinDate',
title: t('admin.common.joinDate'),
},
{
key: 'lastLogin',
title: t('admin.common.lastLogin'),
},
];
return (
<div className="divide-y divide-slate-200 dark:divide-slate-700">
<div className="p-6">
<div className="flex items-center justify-between">
<div className="relative">
<input
type="text"
placeholder={t('admin.common.search')}
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="pl-10 pr-4 py-2 border-b-2 border-slate-200 dark:border-slate-700 focus:border-slate-900 dark:focus:border-slate-500 bg-transparent outline-none w-64 text-slate-800 dark:text-white placeholder-slate-400 dark:placeholder-slate-500"
/>
<RiSearchLine className="absolute left-0 top-1/2 -translate-y-1/2 text-slate-400 dark:text-slate-500 text-xl" />
</div>
<button className="flex items-center gap-2 px-6 py-2 bg-blue-600 hover:bg-blue-700 dark:bg-blue-500 dark:hover:bg-blue-600 text-white transition-colors rounded-md">
<RiAddLine className="text-xl" />
<span>{t('admin.common.create')}</span>
</button>
</div>
</div>
<Table
columns={columns}
data={users}
loading={loading}
onEdit={handleEdit}
onDelete={handleDelete}
/>
</div>
);
};
export default UsersManagement;

13
frontend/src/router.tsx Normal file
View file

@ -0,0 +1,13 @@
import { createBrowserRouter } from 'react-router-dom';
import AdminLayout from './pages/admin/layout/AdminLayout';
import Dashboard from './pages/admin/Dashboard';
const router = createBrowserRouter([
{
path: '/admin',
element: <AdminLayout><Dashboard /></AdminLayout>,
},
// Add more routes here as we develop them
]);
export default router;

View file

@ -0,0 +1,19 @@
export interface User {
id: string;
username: string;
email: string;
role: 'admin' | 'contributor' | 'user';
}
export interface AuthState {
user: User | null;
token: string | null;
loading: boolean;
error: string | null;
}
export interface AuthContextType extends AuthState {
login: (email: string, password: string) => Promise<void>;
logout: () => void;
clearError: () => void;
}

13
pnpm-lock.yaml generated
View file

@ -28,6 +28,9 @@ importers:
i18next: i18next:
specifier: ^24.2.2 specifier: ^24.2.2
version: 24.2.2(typescript@5.7.3) version: 24.2.2(typescript@5.7.3)
i18next-browser-languagedetector:
specifier: ^8.0.4
version: 8.0.4
lucide-react: lucide-react:
specifier: ^0.474.0 specifier: ^0.474.0
version: 0.474.0(react@19.0.0) version: 0.474.0(react@19.0.0)
@ -62,6 +65,9 @@ importers:
'@tailwindcss/vite': '@tailwindcss/vite':
specifier: ^4.0.3 specifier: ^4.0.3
version: 4.0.7(vite@6.1.1(@types/node@22.13.4)(jiti@2.4.2)(lightningcss@1.29.1)) version: 4.0.7(vite@6.1.1(@types/node@22.13.4)(jiti@2.4.2)(lightningcss@1.29.1))
'@types/node':
specifier: ^22.13.4
version: 22.13.4
'@types/react': '@types/react':
specifier: ^19.0.8 specifier: ^19.0.8
version: 19.0.10 version: 19.0.10
@ -1358,6 +1364,9 @@ packages:
resolution: {integrity: sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==} resolution: {integrity: sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==}
engines: {node: '>= 14'} engines: {node: '>= 14'}
i18next-browser-languagedetector@8.0.4:
resolution: {integrity: sha512-f3frU3pIxD50/Tz20zx9TD9HobKYg47fmAETb117GKGPrhwcSSPJDoCposXlVycVebQ9GQohC3Efbpq7/nnJ5w==}
i18next@24.2.2: i18next@24.2.2:
resolution: {integrity: sha512-NE6i86lBCKRYZa5TaUDkU5S4HFgLIEJRLr3Whf2psgaxBleQ2LC1YW1Vc+SCgkAW7VEzndT6al6+CzegSUHcTQ==} resolution: {integrity: sha512-NE6i86lBCKRYZa5TaUDkU5S4HFgLIEJRLr3Whf2psgaxBleQ2LC1YW1Vc+SCgkAW7VEzndT6al6+CzegSUHcTQ==}
peerDependencies: peerDependencies:
@ -3466,6 +3475,10 @@ snapshots:
transitivePeerDependencies: transitivePeerDependencies:
- supports-color - supports-color
i18next-browser-languagedetector@8.0.4:
dependencies:
'@babel/runtime': 7.26.9
i18next@24.2.2(typescript@5.7.3): i18next@24.2.2(typescript@5.7.3):
dependencies: dependencies:
'@babel/runtime': 7.26.9 '@babel/runtime': 7.26.9