diff --git a/api/schemas/components/schemas.yaml b/api/schemas/components/schemas.yaml index dba9082..056ef91 100644 --- a/api/schemas/components/schemas.yaml +++ b/api/schemas/components/schemas.yaml @@ -60,12 +60,20 @@ User: type: object required: - id - - email + - username - role - status properties: id: type: integer + username: + type: string + minLength: 3 + maxLength: 32 + display_name: + type: string + maxLength: 64 + description: 用户显示名称 email: type: string format: email @@ -74,11 +82,13 @@ User: enum: - admin - editor + - contributor status: type: string enum: - active - inactive + - banned created_at: type: string format: date-time diff --git a/api/schemas/openapi.yaml b/api/schemas/openapi.yaml index 693fe54..dfcbd21 100644 --- a/api/schemas/openapi.yaml +++ b/api/schemas/openapi.yaml @@ -70,6 +70,8 @@ components: Daily: $ref: './components/schemas.yaml#/Daily' paths: + /auth/register: + $ref: './paths/auth.yaml#/register' /auth/login: $ref: './paths/auth.yaml#/login' /auth/logout: diff --git a/api/schemas/paths/auth.yaml b/api/schemas/paths/auth.yaml index cbe67cb..cc963e1 100644 --- a/api/schemas/paths/auth.yaml +++ b/api/schemas/paths/auth.yaml @@ -1,3 +1,63 @@ +register: + post: + tags: + - auth + summary: 用户注册 + operationId: register + security: [] # 注册接口不需要认证 + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - username + - email + - password + - role + properties: + username: + type: string + minLength: 3 + maxLength: 32 + email: + type: string + format: email + password: + type: string + format: password + minLength: 8 + role: + type: string + enum: + - admin + - editor + - contributor + responses: + '200': + description: 注册成功 + content: + application/json: + schema: + type: object + required: + - token + - user + properties: + token: + type: string + user: + $ref: '../components/schemas.yaml#/User' + '400': + description: 用户名已存在 + content: + application/json: + schema: + $ref: '../components/schemas.yaml#/Error' + '422': + $ref: '../components/responses.yaml#/ValidationError' + login: post: tags: @@ -12,12 +72,13 @@ login: schema: type: object required: - - email + - username - password properties: - email: + username: type: string - format: email + minLength: 3 + maxLength: 32 password: type: string format: password diff --git a/backend/.gitignore b/backend/.gitignore index 1193d52..920e96e 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -12,6 +12,9 @@ vendor/ # Build output *.exe +# Config +config/config.yaml + # Database files *.db *.db-journal diff --git a/backend/Dockerfile b/backend/Dockerfile index 59d675d..c89062d 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -5,7 +5,7 @@ RUN apk add --no-cache gcc musl-dev libwebp-dev COPY go.mod go.sum ./ RUN go mod download COPY . . -RUN go build -o tss-rocks-be ./cmd/server +RUN go build -o tss-rocks-be FROM alpine:latest @@ -13,9 +13,13 @@ RUN apk add --no-cache libwebp RUN adduser -u 1000 -D tss-rocks USER tss-rocks WORKDIR /app + +# 复制二进制文件和配置 COPY --from=builder /app/tss-rocks-be . +COPY --from=builder /app/config/config.yaml ./config/ EXPOSE 8080 ENV GIN_MODE=release -CMD ["./tss-rocks-be"] +# 启动服务器 +CMD ["./tss-rocks-be", "server"] diff --git a/backend/cmd/server.go b/backend/cmd/server.go new file mode 100644 index 0000000..050a061 --- /dev/null +++ b/backend/cmd/server.go @@ -0,0 +1,89 @@ +package cmd + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "tss-rocks-be/internal/config" + "tss-rocks-be/internal/server" + "tss-rocks-be/pkg/logger" +) + +func GetServerCmd() *cobra.Command { + var configPath string + + cmd := &cobra.Command{ + Use: "server", + Short: "Start the server", + Long: `Start the server with the specified configuration.`, + RunE: func(cmd *cobra.Command, args []string) error { + // Load configuration + cfg, err := config.Load(configPath) + if err != nil { + return err + } + + // Setup logger + logger.Setup(cfg) + + // Create a context that can be cancelled + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create ent client + client := server.NewEntClient(cfg) + if client == nil { + log.Fatal().Msg("Failed to create database client") + } + defer client.Close() + + // Run the auto migration tool + if err := client.Schema.Create(ctx); err != nil { + log.Fatal().Err(err).Msg("Failed to create schema resources") + } + + // Create and start server + srv, err := server.New(cfg, client) + if err != nil { + log.Fatal().Err(err).Msg("Failed to create server") + } + + // Start server in a goroutine + go func() { + if err := srv.Start(); err != nil { + log.Fatal().Err(err).Msg("Failed to start server") + } + }() + + // Wait for interrupt signal + sig := <-sigChan + log.Info().Msgf("Received signal: %v", sig) + + // Attempt graceful shutdown with timeout + log.Info().Msg("Shutting down server...") + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Error().Err(err).Msg("Server forced to shutdown") + os.Exit(1) + } + + return nil + }, + } + + // Add flags + cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file") + + return cmd +} diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go deleted file mode 100644 index f9050cf..0000000 --- a/backend/cmd/server/main.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "os" - "os/signal" - "syscall" - "time" - - "github.com/rs/zerolog/log" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/server" - "tss-rocks-be/pkg/logger" -) - -func main() { - // Parse command line flags - configPath := flag.String("config", "config/config.yaml", "path to config file") - flag.Parse() - - // Load configuration - cfg, err := config.Load(*configPath) - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - // Setup logger - logger.Setup(cfg) - - // Create a context that can be cancelled - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Set up signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // Create ent client - client := server.NewEntClient(cfg) - if client == nil { - log.Fatal().Msg("Failed to create database client") - } - defer client.Close() - - // Run the auto migration tool - if err := client.Schema.Create(ctx); err != nil { - log.Fatal().Err(err).Msg("Failed to create schema resources") - } - - // Create and start server - srv, err := server.New(cfg, client) - if err != nil { - log.Fatal().Err(err).Msg("Failed to create server") - } - - // Start server in a goroutine - go func() { - if err := srv.Start(); err != nil { - log.Fatal().Err(err).Msg("Failed to start server") - } - }() - - // Wait for interrupt signal - sig := <-sigChan - log.Info().Msgf("Received signal: %v", sig) - - // Attempt graceful shutdown with timeout - log.Info().Msg("Shutting down server...") - shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := srv.Shutdown(shutdownCtx); err != nil { - log.Error().Err(err).Msg("Server forced to shutdown") - os.Exit(1) - } -} diff --git a/backend/cmd/server/main_test.go b/backend/cmd/server/main_test.go deleted file mode 100644 index 35eceef..0000000 --- a/backend/cmd/server/main_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package main - -import ( - "context" - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "tss-rocks-be/internal/config" - "tss-rocks-be/internal/server" -) - -func TestConfigLoad(t *testing.T) { - // Create a temporary config file for testing - tmpConfig := ` -database: - driver: sqlite3 - dsn: ":memory:" -server: - port: 8080 - host: localhost -storage: - type: local - local: - root_dir: "./testdata" -` - tmpFile, err := os.CreateTemp("", "config-*.yaml") - require.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - _, err = tmpFile.WriteString(tmpConfig) - require.NoError(t, err) - err = tmpFile.Close() - require.NoError(t, err) - - // Test config loading - cfg, err := config.Load(tmpFile.Name()) - require.NoError(t, err) - assert.Equal(t, "sqlite3", cfg.Database.Driver) - assert.Equal(t, ":memory:", cfg.Database.DSN) - assert.Equal(t, 8080, cfg.Server.Port) - assert.Equal(t, "localhost", cfg.Server.Host) - assert.Equal(t, "local", cfg.Storage.Type) - assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir) -} - -func TestServerCreation(t *testing.T) { - cfg := &config.Config{ - Database: config.DatabaseConfig{ - Driver: "sqlite3", - DSN: ":memory:", - }, - Server: config.ServerConfig{ - Port: 8080, - Host: "localhost", - }, - Storage: config.StorageConfig{ - Type: "local", - Local: config.LocalStorage{ - RootDir: "./testdata", - }, - }, - } - - // Create ent client - client := server.NewEntClient(cfg) - require.NotNil(t, client) - defer client.Close() - - // Test schema creation - err := client.Schema.Create(context.Background()) - require.NoError(t, err) - - // Test server creation - srv, err := server.New(cfg, client) - require.NoError(t, err) - require.NotNil(t, srv) -} - -func TestServerStartAndShutdown(t *testing.T) { - cfg := &config.Config{ - Database: config.DatabaseConfig{ - Driver: "sqlite3", - DSN: ":memory:", - }, - Server: config.ServerConfig{ - Port: 0, // Use random available port - Host: "localhost", - }, - Storage: config.StorageConfig{ - Type: "local", - Local: config.LocalStorage{ - RootDir: "./testdata", - }, - }, - } - - client := server.NewEntClient(cfg) - require.NotNil(t, client) - defer client.Close() - - err := client.Schema.Create(context.Background()) - require.NoError(t, err) - - srv, err := server.New(cfg, client) - require.NoError(t, err) - - // Start server in goroutine - go func() { - err := srv.Start() - if err != nil { - t.Logf("Server stopped: %v", err) - } - }() - - // Give server time to start - time.Sleep(100 * time.Millisecond) - - // Test graceful shutdown - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - err = srv.Shutdown(ctx) - assert.NoError(t, err) -} diff --git a/backend/cmd/user.go b/backend/cmd/user.go new file mode 100644 index 0000000..f68711d --- /dev/null +++ b/backend/cmd/user.go @@ -0,0 +1,411 @@ +package cmd + +import ( + "fmt" + "os" + "text/tabwriter" + + "tss-rocks-be/ent/user" + "tss-rocks-be/ent/role" + "tss-rocks-be/internal/config" + "tss-rocks-be/internal/server" + + "github.com/spf13/cobra" + "golang.org/x/crypto/bcrypt" +) + +func GetUserCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "user", + Short: "User management commands", + Long: `Commands for managing users, including create, delete, password, and role management.`, + } + + configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file") + + cmd.AddCommand(getCreateUserCmd(*configPath)) + cmd.AddCommand(getPasswordCmd(*configPath)) + cmd.AddCommand(getListUsersCmd(*configPath)) + cmd.AddCommand(getDeleteUserCmd(*configPath)) + cmd.AddCommand(getRoleCmd(*configPath)) + + return cmd +} + +func getCreateUserCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "create", + Short: "Create a new user", + Long: `Create a new user with specified username, email and password.`, + RunE: func(cmd *cobra.Command, args []string) error { + username, _ := cmd.Flags().GetString("username") + email, _ := cmd.Flags().GetString("email") + password, _ := cmd.Flags().GetString("password") + + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 检查用户名是否已存在 + exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context()) + if err != nil { + return fmt.Errorf("error checking username: %v", err) + } + if exists { + return fmt.Errorf("username '%s' already exists", username) + } + + // 检查邮箱是否已存在 + exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context()) + if err != nil { + return fmt.Errorf("error checking email: %v", err) + } + if exists { + return fmt.Errorf("email '%s' already exists", email) + } + + // 生成密码哈希 + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("error hashing password: %v", err) + } + + // 创建用户 + u, err := client.User.Create(). + SetUsername(username). + SetEmail(email). + SetPasswordHash(string(hashedPassword)). + SetStatus("active"). + Save(cmd.Context()) + + if err != nil { + return fmt.Errorf("error creating user: %v", err) + } + + fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email) + return nil + }, + } + + // 添加命令行参数 + cmd.Flags().String("username", "", "username for the new user") + cmd.Flags().String("email", "", "email for the new user") + cmd.Flags().String("password", "", "password for the new user") + + // 设置必需的参数 + cmd.MarkFlagRequired("username") + cmd.MarkFlagRequired("email") + cmd.MarkFlagRequired("password") + + return cmd +} + +func getPasswordCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "password", + Short: "Change user password", + Long: `Change the password for an existing user.`, + RunE: func(cmd *cobra.Command, args []string) error { + username, _ := cmd.Flags().GetString("username") + password, _ := cmd.Flags().GetString("password") + + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 查找用户 + u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("user with username '%s' not found", username) + } + + // 生成新的密码哈希 + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("error hashing password: %v", err) + } + + // 更新密码 + _, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context()) + if err != nil { + return fmt.Errorf("error updating password: %v", err) + } + + fmt.Printf("Successfully updated password for user '%s'\n", username) + return nil + }, + } + + // 添加命令行参数 + cmd.Flags().String("username", "", "username of the user") + cmd.Flags().String("password", "", "new password") + + // 设置必需的参数 + cmd.MarkFlagRequired("username") + cmd.MarkFlagRequired("password") + + return cmd +} + +func getListUsersCmd(configPath string) *cobra.Command { + return &cobra.Command{ + Use: "list", + Short: "List all users", + Long: `List all users in the system.`, + RunE: func(cmd *cobra.Command, args []string) error { + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 获取所有用户 + users, err := client.User.Query().All(cmd.Context()) + if err != nil { + return fmt.Errorf("error querying users: %v", err) + } + + // 创建表格写入器 + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT") + + // 输出用户信息 + for _, u := range users { + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", + u.Username, + u.Email, + u.Status, + u.CreatedAt.Format("2006-01-02 15:04:05")) + } + w.Flush() + + return nil + }, + } +} + +func getDeleteUserCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "delete", + Short: "Delete a user", + Long: `Delete an existing user from the system.`, + RunE: func(cmd *cobra.Command, args []string) error { + username, _ := cmd.Flags().GetString("username") + force, _ := cmd.Flags().GetBool("force") + + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 查找用户 + u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("user with username '%s' not found", username) + } + + // 检查用户的角色 + hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context()) + if err != nil { + return fmt.Errorf("error checking user roles: %v", err) + } + if hasRoles && !force { + return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username) + } + + // 删除用户 + err = client.User.DeleteOne(u).Exec(cmd.Context()) + if err != nil { + return fmt.Errorf("error deleting user: %v", err) + } + + fmt.Printf("Successfully deleted user '%s'\n", username) + return nil + }, + } + + // 添加命令行参数 + cmd.Flags().String("username", "", "username of the user to delete") + cmd.Flags().Bool("force", false, "force delete even if user has roles") + + // 设置必需的参数 + cmd.MarkFlagRequired("username") + + return cmd +} + +func getRoleCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "role", + Short: "User role management commands", + Long: `Commands for managing user roles, including adding and removing roles.`, + } + + cmd.AddCommand(getRoleAddCmd(configPath)) + cmd.AddCommand(getRoleRemoveCmd(configPath)) + + return cmd +} + +func getRoleAddCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "add", + Short: "Add a role to a user", + Long: `Add a role to an existing user.`, + RunE: func(cmd *cobra.Command, args []string) error { + username, _ := cmd.Flags().GetString("username") + roleName, _ := cmd.Flags().GetString("role") + + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 查找用户 + u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("user with username '%s' not found", username) + } + + // 检查角色是否已存在 + hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context()) + if err != nil { + return fmt.Errorf("error checking user roles: %v", err) + } + if hasRole { + return fmt.Errorf("user '%s' already has role '%s'", username, roleName) + } + + // 获取要添加的角色 + r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("role '%s' not found", roleName) + } + + // 添加角色 + err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context()) + if err != nil { + return fmt.Errorf("error adding role: %v", err) + } + + fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username) + return nil + }, + } + + // 添加命令行参数 + cmd.Flags().String("username", "", "username of the user") + cmd.Flags().String("role", "", "role to add (admin/editor/contributor)") + + // 设置必需的参数 + cmd.MarkFlagRequired("username") + cmd.MarkFlagRequired("role") + + return cmd +} + +func getRoleRemoveCmd(configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "remove", + Short: "Remove a role from a user", + Long: `Remove a role from an existing user.`, + RunE: func(cmd *cobra.Command, args []string) error { + username, _ := cmd.Flags().GetString("username") + roleName, _ := cmd.Flags().GetString("role") + + // 加载配置 + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("error loading config: %v", err) + } + + // 初始化数据库连接 + client := server.NewEntClient(cfg) + if client == nil { + return fmt.Errorf("failed to create database client") + } + defer client.Close() + + // 查找用户 + u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("user with username '%s' not found", username) + } + + // 检查角色是否存在 + r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context()) + if err != nil { + return fmt.Errorf("role '%s' not found", roleName) + } + + // 检查用户是否有该角色 + hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context()) + if err != nil { + return fmt.Errorf("error checking user roles: %v", err) + } + if !hasRole { + return fmt.Errorf("user '%s' does not have role '%s'", username, roleName) + } + + // 移除角色 + err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context()) + if err != nil { + return fmt.Errorf("error removing role: %v", err) + } + + fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username) + return nil + }, + } + + // 添加命令行参数 + cmd.Flags().String("username", "", "username of the user") + cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)") + + // 设置必需的参数 + cmd.MarkFlagRequired("username") + cmd.MarkFlagRequired("role") + + return cmd +} diff --git a/backend/config/config.yaml b/backend/config/config.yaml deleted file mode 100644 index 08f4c31..0000000 --- a/backend/config/config.yaml +++ /dev/null @@ -1,91 +0,0 @@ -database: - driver: sqlite3 - dsn: "file:tss-rocks.db?_fk=1&cache=shared" - -server: - port: 8080 - host: localhost - -jwt: - secret: "your-secret-key-here" # 在生产环境中应该使用环境变量 - expiration: 24h # token 过期时间 - -logging: - level: debug # debug, info, warn, error - format: json # json, console - -storage: - type: s3 # local or s3 - local: - root_dir: "./storage/media" - s3: - region: "us-east-1" - bucket: "your-bucket-name" - access_key_id: "your-access-key-id" - secret_access_key: "your-secret-access-key" - endpoint: "" # Optional, for MinIO or other S3-compatible services - custom_url: "" # Optional, for CDN or custom domain (e.g., https://cdn.example.com/media) - proxy_s3: false # If true, backend will proxy S3 requests instead of redirecting - upload: - max_size: 10 # 最大文件大小(MB) - allowed_types: # 允许的文件类型 - - image/jpeg - - image/png - - image/gif - - image/webp - - image/svg+xml - - application/pdf - - application/msword - - application/vnd.openxmlformats-officedocument.wordprocessingml.document - - application/vnd.ms-excel - - application/vnd.openxmlformats-officedocument.spreadsheetml.sheet - - application/zip - - application/x-rar-compressed - - text/plain - - text/csv - allowed_extensions: # 允许的文件扩展名(小写) - - .jpg - - .jpeg - - .png - - .gif - - .webp - - .svg - - .pdf - - .doc - - .docx - - .xls - - .xlsx - - .zip - - .rar - - .txt - - .csv - -rate_limit: - # IP限流配置 - ip_rate: 50 # 每秒请求数 - ip_burst: 25 # 突发请求数 - - # 路由限流配置 - route_rates: - "/api/v1/auth/login": - rate: 5 # 每秒5个请求 - burst: 10 # 突发10个请求 - "/api/v1/auth/register": - rate: 2 # 每秒2个请求 - burst: 5 # 突发5个请求 - "/api/v1/media/upload": - rate: 10 # 每秒10个请求 - burst: 20 # 突发20个请求 - -access_log: - enable_console: true # 启用控制台输出 - enable_file: true # 启用文件日志 - file_path: "./logs/access.log" # 日志文件路径 - format: "json" # 日志格式 (json 或 text) - level: "info" # 日志级别 - rotation: - max_size: 100 # 每个日志文件的最大大小(MB) - max_age: 30 # 保留旧日志文件的最大天数 - max_backups: 10 # 保留的旧日志文件的最大数量 - compress: true # 是否压缩旧日志文件 - local_time: true # 使用本地时间作为轮转时间 diff --git a/backend/config/config.yaml.example b/backend/config/config.yaml.example new file mode 100644 index 0000000..2b45cd9 --- /dev/null +++ b/backend/config/config.yaml.example @@ -0,0 +1,36 @@ +database: + driver: sqlite3 + # SQLite DSN 说明: + # - file:tss.db => 相对路径,在当前目录创建 tss.db + # - cache=shared => 启用共享缓存,提高性能 + # - _fk=1 => 启用外键约束 + # - mode=rwc => 如果数据库不存在则创建(read-write-create) + dsn: "file:tss.db?cache=shared&_fk=1&mode=rwc" + +server: + port: 8080 + host: localhost + +jwt: + secret: your-jwt-secret-here # 在生产环境中应该使用环境变量 + expiration: 24h + +storage: + driver: local + local: + root: storage + base_url: http://localhost:8080/storage + +logging: + level: debug + format: console + +rate_limit: + enabled: true + requests: 100 + duration: 1m + +access_log: + enabled: true + format: combined + output: stdout diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index bb6da6d..562eed3 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -373,7 +373,9 @@ var ( // UsersColumns holds the columns for the "users" table. UsersColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt, Increment: true}, - {Name: "email", Type: field.TypeString, Unique: true}, + {Name: "username", Type: field.TypeString, Unique: true}, + {Name: "display_name", Type: field.TypeString, Nullable: true, Size: 64}, + {Name: "email", Type: field.TypeString}, {Name: "password_hash", Type: field.TypeString}, {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "inactive", "banned"}, Default: "active"}, {Name: "created_at", Type: field.TypeTime}, diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 26cb34e..b8ccb5b 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -9321,6 +9321,8 @@ type UserMutation struct { op Op typ string id *int + username *string + display_name *string email *string password_hash *string status *user.Status @@ -9439,6 +9441,91 @@ func (m *UserMutation) IDs(ctx context.Context) ([]int, error) { } } +// SetUsername sets the "username" field. +func (m *UserMutation) SetUsername(s string) { + m.username = &s +} + +// Username returns the value of the "username" field in the mutation. +func (m *UserMutation) Username() (r string, exists bool) { + v := m.username + if v == nil { + return + } + return *v, true +} + +// OldUsername returns the old "username" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldUsername(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsername is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsername requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsername: %w", err) + } + return oldValue.Username, nil +} + +// ResetUsername resets all changes to the "username" field. +func (m *UserMutation) ResetUsername() { + m.username = nil +} + +// SetDisplayName sets the "display_name" field. +func (m *UserMutation) SetDisplayName(s string) { + m.display_name = &s +} + +// DisplayName returns the value of the "display_name" field in the mutation. +func (m *UserMutation) DisplayName() (r string, exists bool) { + v := m.display_name + if v == nil { + return + } + return *v, true +} + +// OldDisplayName returns the old "display_name" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDisplayName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDisplayName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDisplayName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDisplayName: %w", err) + } + return oldValue.DisplayName, nil +} + +// ClearDisplayName clears the value of the "display_name" field. +func (m *UserMutation) ClearDisplayName() { + m.display_name = nil + m.clearedFields[user.FieldDisplayName] = struct{}{} +} + +// DisplayNameCleared returns if the "display_name" field was cleared in this mutation. +func (m *UserMutation) DisplayNameCleared() bool { + _, ok := m.clearedFields[user.FieldDisplayName] + return ok +} + +// ResetDisplayName resets all changes to the "display_name" field. +func (m *UserMutation) ResetDisplayName() { + m.display_name = nil + delete(m.clearedFields, user.FieldDisplayName) +} + // SetEmail sets the "email" field. func (m *UserMutation) SetEmail(s string) { m.email = &s @@ -9815,7 +9902,13 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 5) + fields := make([]string, 0, 7) + if m.username != nil { + fields = append(fields, user.FieldUsername) + } + if m.display_name != nil { + fields = append(fields, user.FieldDisplayName) + } if m.email != nil { fields = append(fields, user.FieldEmail) } @@ -9839,6 +9932,10 @@ func (m *UserMutation) Fields() []string { // schema. func (m *UserMutation) Field(name string) (ent.Value, bool) { switch name { + case user.FieldUsername: + return m.Username() + case user.FieldDisplayName: + return m.DisplayName() case user.FieldEmail: return m.Email() case user.FieldPasswordHash: @@ -9858,6 +9955,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { // database failed. func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { + case user.FieldUsername: + return m.OldUsername(ctx) + case user.FieldDisplayName: + return m.OldDisplayName(ctx) case user.FieldEmail: return m.OldEmail(ctx) case user.FieldPasswordHash: @@ -9877,6 +9978,20 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er // type. func (m *UserMutation) SetField(name string, value ent.Value) error { switch name { + case user.FieldUsername: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsername(v) + return nil + case user.FieldDisplayName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDisplayName(v) + return nil case user.FieldEmail: v, ok := value.(string) if !ok { @@ -9941,7 +10056,11 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *UserMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(user.FieldDisplayName) { + fields = append(fields, user.FieldDisplayName) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -9954,6 +10073,11 @@ func (m *UserMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *UserMutation) ClearField(name string) error { + switch name { + case user.FieldDisplayName: + m.ClearDisplayName() + return nil + } return fmt.Errorf("unknown User nullable field %s", name) } @@ -9961,6 +10085,12 @@ func (m *UserMutation) ClearField(name string) error { // It returns an error if the field is not defined in the schema. func (m *UserMutation) ResetField(name string) error { switch name { + case user.FieldUsername: + m.ResetUsername() + return nil + case user.FieldDisplayName: + m.ResetDisplayName() + return nil case user.FieldEmail: m.ResetEmail() return nil diff --git a/backend/ent/runtime.go b/backend/ent/runtime.go index 7b523e1..f16632c 100644 --- a/backend/ent/runtime.go +++ b/backend/ent/runtime.go @@ -261,20 +261,28 @@ func init() { role.UpdateDefaultUpdatedAt = roleDescUpdatedAt.UpdateDefault.(func() time.Time) userFields := schema.User{}.Fields() _ = userFields + // userDescUsername is the schema descriptor for username field. + userDescUsername := userFields[0].Descriptor() + // user.UsernameValidator is a validator for the "username" field. It is called by the builders before save. + user.UsernameValidator = userDescUsername.Validators[0].(func(string) error) + // userDescDisplayName is the schema descriptor for display_name field. + userDescDisplayName := userFields[1].Descriptor() + // user.DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save. + user.DisplayNameValidator = userDescDisplayName.Validators[0].(func(string) error) // userDescEmail is the schema descriptor for email field. - userDescEmail := userFields[0].Descriptor() + userDescEmail := userFields[2].Descriptor() // user.EmailValidator is a validator for the "email" field. It is called by the builders before save. user.EmailValidator = userDescEmail.Validators[0].(func(string) error) // userDescPasswordHash is the schema descriptor for password_hash field. - userDescPasswordHash := userFields[1].Descriptor() + userDescPasswordHash := userFields[3].Descriptor() // user.PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. user.PasswordHashValidator = userDescPasswordHash.Validators[0].(func(string) error) // userDescCreatedAt is the schema descriptor for created_at field. - userDescCreatedAt := userFields[3].Descriptor() + userDescCreatedAt := userFields[5].Descriptor() // user.DefaultCreatedAt holds the default value on creation for the created_at field. user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time) // userDescUpdatedAt is the schema descriptor for updated_at field. - userDescUpdatedAt := userFields[4].Descriptor() + userDescUpdatedAt := userFields[6].Descriptor() // user.DefaultUpdatedAt holds the default value on creation for the updated_at field. user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time) // user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 75384be..558a753 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -15,9 +15,14 @@ type User struct { // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ - field.String("email"). + field.String("username"). Unique(). NotEmpty(), + field.String("display_name"). + Optional(). + MaxLen(64), + field.String("email"). + NotEmpty(), field.String("password_hash"). Sensitive(). NotEmpty(), diff --git a/backend/ent/user.go b/backend/ent/user.go index f67a040..f476629 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -17,6 +17,10 @@ type User struct { config `json:"-"` // ID of the ent. ID int `json:"id,omitempty"` + // Username holds the value of the "username" field. + Username string `json:"username,omitempty"` + // DisplayName holds the value of the "display_name" field. + DisplayName string `json:"display_name,omitempty"` // Email holds the value of the "email" field. Email string `json:"email,omitempty"` // PasswordHash holds the value of the "password_hash" field. @@ -80,7 +84,7 @@ func (*User) scanValues(columns []string) ([]any, error) { switch columns[i] { case user.FieldID: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldStatus: + case user.FieldUsername, user.FieldDisplayName, user.FieldEmail, user.FieldPasswordHash, user.FieldStatus: values[i] = new(sql.NullString) case user.FieldCreatedAt, user.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -105,6 +109,18 @@ func (u *User) assignValues(columns []string, values []any) error { return fmt.Errorf("unexpected type %T for field id", value) } u.ID = int(value.Int64) + case user.FieldUsername: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field username", values[i]) + } else if value.Valid { + u.Username = value.String + } + case user.FieldDisplayName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field display_name", values[i]) + } else if value.Valid { + u.DisplayName = value.String + } case user.FieldEmail: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field email", values[i]) @@ -186,6 +202,12 @@ func (u *User) String() string { var builder strings.Builder builder.WriteString("User(") builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) + builder.WriteString("username=") + builder.WriteString(u.Username) + builder.WriteString(", ") + builder.WriteString("display_name=") + builder.WriteString(u.DisplayName) + builder.WriteString(", ") builder.WriteString("email=") builder.WriteString(u.Email) builder.WriteString(", ") diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 9fe4c09..3ca7ee8 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -15,6 +15,10 @@ const ( Label = "user" // FieldID holds the string denoting the id field in the database. FieldID = "id" + // FieldUsername holds the string denoting the username field in the database. + FieldUsername = "username" + // FieldDisplayName holds the string denoting the display_name field in the database. + FieldDisplayName = "display_name" // FieldEmail holds the string denoting the email field in the database. FieldEmail = "email" // FieldPasswordHash holds the string denoting the password_hash field in the database. @@ -57,6 +61,8 @@ const ( // Columns holds all SQL columns for user fields. var Columns = []string{ FieldID, + FieldUsername, + FieldDisplayName, FieldEmail, FieldPasswordHash, FieldStatus, @@ -81,6 +87,10 @@ func ValidColumn(column string) bool { } var ( + // UsernameValidator is a validator for the "username" field. It is called by the builders before save. + UsernameValidator func(string) error + // DisplayNameValidator is a validator for the "display_name" field. It is called by the builders before save. + DisplayNameValidator func(string) error // EmailValidator is a validator for the "email" field. It is called by the builders before save. EmailValidator func(string) error // PasswordHashValidator is a validator for the "password_hash" field. It is called by the builders before save. @@ -128,6 +138,16 @@ func ByID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldID, opts...).ToFunc() } +// ByUsername orders the results by the username field. +func ByUsername(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsername, opts...).ToFunc() +} + +// ByDisplayName orders the results by the display_name field. +func ByDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDisplayName, opts...).ToFunc() +} + // ByEmail orders the results by the email field. func ByEmail(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldEmail, opts...).ToFunc() diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 040d8a9..f2b3af6 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -55,6 +55,16 @@ func IDLTE(id int) predicate.User { return predicate.User(sql.FieldLTE(FieldID, id)) } +// Username applies equality check predicate on the "username" field. It's identical to UsernameEQ. +func Username(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldUsername, v)) +} + +// DisplayName applies equality check predicate on the "display_name" field. It's identical to DisplayNameEQ. +func DisplayName(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldDisplayName, v)) +} + // Email applies equality check predicate on the "email" field. It's identical to EmailEQ. func Email(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldEmail, v)) @@ -75,6 +85,146 @@ func UpdatedAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) } +// UsernameEQ applies the EQ predicate on the "username" field. +func UsernameEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldUsername, v)) +} + +// UsernameNEQ applies the NEQ predicate on the "username" field. +func UsernameNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldUsername, v)) +} + +// UsernameIn applies the In predicate on the "username" field. +func UsernameIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldUsername, vs...)) +} + +// UsernameNotIn applies the NotIn predicate on the "username" field. +func UsernameNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldUsername, vs...)) +} + +// UsernameGT applies the GT predicate on the "username" field. +func UsernameGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldUsername, v)) +} + +// UsernameGTE applies the GTE predicate on the "username" field. +func UsernameGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldUsername, v)) +} + +// UsernameLT applies the LT predicate on the "username" field. +func UsernameLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldUsername, v)) +} + +// UsernameLTE applies the LTE predicate on the "username" field. +func UsernameLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldUsername, v)) +} + +// UsernameContains applies the Contains predicate on the "username" field. +func UsernameContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldUsername, v)) +} + +// UsernameHasPrefix applies the HasPrefix predicate on the "username" field. +func UsernameHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldUsername, v)) +} + +// UsernameHasSuffix applies the HasSuffix predicate on the "username" field. +func UsernameHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldUsername, v)) +} + +// UsernameEqualFold applies the EqualFold predicate on the "username" field. +func UsernameEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldUsername, v)) +} + +// UsernameContainsFold applies the ContainsFold predicate on the "username" field. +func UsernameContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldUsername, v)) +} + +// DisplayNameEQ applies the EQ predicate on the "display_name" field. +func DisplayNameEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldDisplayName, v)) +} + +// DisplayNameNEQ applies the NEQ predicate on the "display_name" field. +func DisplayNameNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldDisplayName, v)) +} + +// DisplayNameIn applies the In predicate on the "display_name" field. +func DisplayNameIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldDisplayName, vs...)) +} + +// DisplayNameNotIn applies the NotIn predicate on the "display_name" field. +func DisplayNameNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldDisplayName, vs...)) +} + +// DisplayNameGT applies the GT predicate on the "display_name" field. +func DisplayNameGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldDisplayName, v)) +} + +// DisplayNameGTE applies the GTE predicate on the "display_name" field. +func DisplayNameGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldDisplayName, v)) +} + +// DisplayNameLT applies the LT predicate on the "display_name" field. +func DisplayNameLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldDisplayName, v)) +} + +// DisplayNameLTE applies the LTE predicate on the "display_name" field. +func DisplayNameLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldDisplayName, v)) +} + +// DisplayNameContains applies the Contains predicate on the "display_name" field. +func DisplayNameContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldDisplayName, v)) +} + +// DisplayNameHasPrefix applies the HasPrefix predicate on the "display_name" field. +func DisplayNameHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldDisplayName, v)) +} + +// DisplayNameHasSuffix applies the HasSuffix predicate on the "display_name" field. +func DisplayNameHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldDisplayName, v)) +} + +// DisplayNameIsNil applies the IsNil predicate on the "display_name" field. +func DisplayNameIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldDisplayName)) +} + +// DisplayNameNotNil applies the NotNil predicate on the "display_name" field. +func DisplayNameNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldDisplayName)) +} + +// DisplayNameEqualFold applies the EqualFold predicate on the "display_name" field. +func DisplayNameEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldDisplayName, v)) +} + +// DisplayNameContainsFold applies the ContainsFold predicate on the "display_name" field. +func DisplayNameContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldDisplayName, v)) +} + // EmailEQ applies the EQ predicate on the "email" field. func EmailEQ(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldEmail, v)) diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index 49463c8..27cf87b 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -23,6 +23,26 @@ type UserCreate struct { hooks []Hook } +// SetUsername sets the "username" field. +func (uc *UserCreate) SetUsername(s string) *UserCreate { + uc.mutation.SetUsername(s) + return uc +} + +// SetDisplayName sets the "display_name" field. +func (uc *UserCreate) SetDisplayName(s string) *UserCreate { + uc.mutation.SetDisplayName(s) + return uc +} + +// SetNillableDisplayName sets the "display_name" field if the given value is not nil. +func (uc *UserCreate) SetNillableDisplayName(s *string) *UserCreate { + if s != nil { + uc.SetDisplayName(*s) + } + return uc +} + // SetEmail sets the "email" field. func (uc *UserCreate) SetEmail(s string) *UserCreate { uc.mutation.SetEmail(s) @@ -173,6 +193,19 @@ func (uc *UserCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (uc *UserCreate) check() error { + if _, ok := uc.mutation.Username(); !ok { + return &ValidationError{Name: "username", err: errors.New(`ent: missing required field "User.username"`)} + } + if v, ok := uc.mutation.Username(); ok { + if err := user.UsernameValidator(v); err != nil { + return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} + } + } + if v, ok := uc.mutation.DisplayName(); ok { + if err := user.DisplayNameValidator(v); err != nil { + return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)} + } + } if _, ok := uc.mutation.Email(); !ok { return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)} } @@ -229,6 +262,14 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _node = &User{config: uc.config} _spec = sqlgraph.NewCreateSpec(user.Table, sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt)) ) + if value, ok := uc.mutation.Username(); ok { + _spec.SetField(user.FieldUsername, field.TypeString, value) + _node.Username = value + } + if value, ok := uc.mutation.DisplayName(); ok { + _spec.SetField(user.FieldDisplayName, field.TypeString, value) + _node.DisplayName = value + } if value, ok := uc.mutation.Email(); ok { _spec.SetField(user.FieldEmail, field.TypeString, value) _node.Email = value diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 1cd778d..5ba737b 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -371,12 +371,12 @@ func (uq *UserQuery) WithMedia(opts ...func(*MediaQuery)) *UserQuery { // Example: // // var v []struct { -// Email string `json:"email,omitempty"` +// Username string `json:"username,omitempty"` // Count int `json:"count,omitempty"` // } // // client.User.Query(). -// GroupBy(user.FieldEmail). +// GroupBy(user.FieldUsername). // Aggregate(ent.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { @@ -394,11 +394,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Example: // // var v []struct { -// Email string `json:"email,omitempty"` +// Username string `json:"username,omitempty"` // } // // client.User.Query(). -// Select(user.FieldEmail). +// Select(user.FieldUsername). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { uq.ctx.Fields = append(uq.ctx.Fields, fields...) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 2c9c371..ae22d3d 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -31,6 +31,40 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { return uu } +// SetUsername sets the "username" field. +func (uu *UserUpdate) SetUsername(s string) *UserUpdate { + uu.mutation.SetUsername(s) + return uu +} + +// SetNillableUsername sets the "username" field if the given value is not nil. +func (uu *UserUpdate) SetNillableUsername(s *string) *UserUpdate { + if s != nil { + uu.SetUsername(*s) + } + return uu +} + +// SetDisplayName sets the "display_name" field. +func (uu *UserUpdate) SetDisplayName(s string) *UserUpdate { + uu.mutation.SetDisplayName(s) + return uu +} + +// SetNillableDisplayName sets the "display_name" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDisplayName(s *string) *UserUpdate { + if s != nil { + uu.SetDisplayName(*s) + } + return uu +} + +// ClearDisplayName clears the value of the "display_name" field. +func (uu *UserUpdate) ClearDisplayName() *UserUpdate { + uu.mutation.ClearDisplayName() + return uu +} + // SetEmail sets the "email" field. func (uu *UserUpdate) SetEmail(s string) *UserUpdate { uu.mutation.SetEmail(s) @@ -244,6 +278,16 @@ func (uu *UserUpdate) defaults() { // check runs all checks and user-defined validators on the builder. func (uu *UserUpdate) check() error { + if v, ok := uu.mutation.Username(); ok { + if err := user.UsernameValidator(v); err != nil { + return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} + } + } + if v, ok := uu.mutation.DisplayName(); ok { + if err := user.DisplayNameValidator(v); err != nil { + return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)} + } + } if v, ok := uu.mutation.Email(); ok { if err := user.EmailValidator(v); err != nil { return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} @@ -274,6 +318,15 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } + if value, ok := uu.mutation.Username(); ok { + _spec.SetField(user.FieldUsername, field.TypeString, value) + } + if value, ok := uu.mutation.DisplayName(); ok { + _spec.SetField(user.FieldDisplayName, field.TypeString, value) + } + if uu.mutation.DisplayNameCleared() { + _spec.ClearField(user.FieldDisplayName, field.TypeString) + } if value, ok := uu.mutation.Email(); ok { _spec.SetField(user.FieldEmail, field.TypeString, value) } @@ -444,6 +497,40 @@ type UserUpdateOne struct { mutation *UserMutation } +// SetUsername sets the "username" field. +func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne { + uuo.mutation.SetUsername(s) + return uuo +} + +// SetNillableUsername sets the "username" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableUsername(s *string) *UserUpdateOne { + if s != nil { + uuo.SetUsername(*s) + } + return uuo +} + +// SetDisplayName sets the "display_name" field. +func (uuo *UserUpdateOne) SetDisplayName(s string) *UserUpdateOne { + uuo.mutation.SetDisplayName(s) + return uuo +} + +// SetNillableDisplayName sets the "display_name" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDisplayName(s *string) *UserUpdateOne { + if s != nil { + uuo.SetDisplayName(*s) + } + return uuo +} + +// ClearDisplayName clears the value of the "display_name" field. +func (uuo *UserUpdateOne) ClearDisplayName() *UserUpdateOne { + uuo.mutation.ClearDisplayName() + return uuo +} + // SetEmail sets the "email" field. func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne { uuo.mutation.SetEmail(s) @@ -670,6 +757,16 @@ func (uuo *UserUpdateOne) defaults() { // check runs all checks and user-defined validators on the builder. func (uuo *UserUpdateOne) check() error { + if v, ok := uuo.mutation.Username(); ok { + if err := user.UsernameValidator(v); err != nil { + return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} + } + } + if v, ok := uuo.mutation.DisplayName(); ok { + if err := user.DisplayNameValidator(v); err != nil { + return &ValidationError{Name: "display_name", err: fmt.Errorf(`ent: validator failed for field "User.display_name": %w`, err)} + } + } if v, ok := uuo.mutation.Email(); ok { if err := user.EmailValidator(v); err != nil { return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)} @@ -717,6 +814,15 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) } } } + if value, ok := uuo.mutation.Username(); ok { + _spec.SetField(user.FieldUsername, field.TypeString, value) + } + if value, ok := uuo.mutation.DisplayName(); ok { + _spec.SetField(user.FieldDisplayName, field.TypeString, value) + } + if uuo.mutation.DisplayNameCleared() { + _spec.ClearField(user.FieldDisplayName, field.TypeString) + } if value, ok := uuo.mutation.Email(); ok { _spec.SetField(user.FieldEmail, field.TypeString, value) } diff --git a/backend/go.mod b/backend/go.mod index 5d2440a..cc83baa 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -15,6 +15,7 @@ require ( github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.24 github.com/rs/zerolog v1.33.0 + github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.10.0 go.uber.org/mock v0.5.0 golang.org/x/crypto v0.33.0 @@ -55,6 +56,7 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/hashicorp/hcl/v2 v2.23.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect @@ -65,6 +67,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/backend/go.sum b/backend/go.sum index 2c155ee..f3dba0a 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -59,6 +59,7 @@ github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCy github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -130,6 +131,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/backend/internal/handler/auth.go b/backend/internal/handler/auth.go index d3fa471..9503c6b 100644 --- a/backend/internal/handler/auth.go +++ b/backend/internal/handler/auth.go @@ -7,16 +7,18 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" ) type RegisterRequest struct { + Username string `json:"username" binding:"required"` Email string `json:"email" binding:"required,email"` Password string `json:"password" binding:"required,min=8"` Role string `json:"role" binding:"required,oneof=admin editor contributor"` } type LoginRequest struct { - Email string `json:"email" binding:"required,email"` + Username string `json:"username" binding:"required,min=3,max=32"` Password string `json:"password" binding:"required"` } @@ -31,7 +33,7 @@ func (h *Handler) Register(c *gin.Context) { return } - user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Password, req.Role) + user, err := h.service.CreateUser(c.Request.Context(), req.Username, req.Email, req.Password, req.Role) if err != nil { log.Error().Err(err).Msg("Failed to create user") c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) @@ -76,14 +78,20 @@ func (h *Handler) Login(c *gin.Context) { return } - user, err := h.service.GetUserByEmail(c.Request.Context(), req.Email) + user, err := h.service.GetUserByUsername(c.Request.Context(), req.Username) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid username or password", + }) return } - if !h.service.ValidatePassword(c.Request.Context(), user, req.Password) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) + // 验证密码 + err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid username or password", + }) return } @@ -91,7 +99,9 @@ func (h *Handler) Login(c *gin.Context) { roles, err := h.service.GetUserRoles(c.Request.Context(), user.ID) if err != nil { log.Error().Err(err).Msg("Failed to get user roles") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user roles"}) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to get user roles", + }) return } @@ -111,7 +121,9 @@ func (h *Handler) Login(c *gin.Context) { tokenString, err := token.SignedString([]byte(h.cfg.JWT.Secret)) if err != nil { log.Error().Err(err).Msg("Failed to generate token") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate token", + }) return } diff --git a/backend/internal/handler/auth_handler_test.go b/backend/internal/handler/auth_handler_test.go index 0c366f2..dfa4458 100644 --- a/backend/internal/handler/auth_handler_test.go +++ b/backend/internal/handler/auth_handler_test.go @@ -3,10 +3,11 @@ package handler import ( "bytes" "encoding/json" - "errors" + "fmt" "net/http" "net/http/httptest" "testing" + "tss-rocks-be/ent" "tss-rocks-be/internal/config" "tss-rocks-be/internal/service/mock" @@ -14,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "go.uber.org/mock/gomock" + "golang.org/x/crypto/bcrypt" ) type AuthHandlerTestSuite struct { @@ -54,20 +56,21 @@ func (s *AuthHandlerTestSuite) TestRegister() { { name: "成功注册", request: RegisterRequest{ + Username: "testuser", Email: "test@example.com", Password: "password123", Role: "contributor", }, setupMock: func() { - user := &ent.User{ - ID: 1, - Email: "test@example.com", - } s.service.EXPECT(). - CreateUser(gomock.Any(), "test@example.com", "password123", "contributor"). - Return(user, nil) + CreateUser(gomock.Any(), "testuser", "test@example.com", "password123", "contributor"). + Return(&ent.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + }, nil) s.service.EXPECT(). - GetUserRoles(gomock.Any(), user.ID). + GetUserRoles(gomock.Any(), 1). Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) }, expectedStatus: http.StatusCreated, @@ -75,6 +78,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { { name: "无效的邮箱格式", request: RegisterRequest{ + Username: "testuser", Email: "invalid-email", Password: "password123", Role: "contributor", @@ -86,6 +90,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { { name: "密码太短", request: RegisterRequest{ + Username: "testuser", Email: "test@example.com", Password: "short", Role: "contributor", @@ -97,6 +102,7 @@ func (s *AuthHandlerTestSuite) TestRegister() { { name: "无效的角色", request: RegisterRequest{ + Username: "testuser", Email: "test@example.com", Password: "password123", Role: "invalid-role", @@ -151,91 +157,95 @@ func (s *AuthHandlerTestSuite) TestLogin() { { name: "成功登录", request: LoginRequest{ - Email: "test@example.com", + Username: "testuser", Password: "password123", }, setupMock: func() { + // 使用 bcrypt 生成正确的密码哈希 + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) user := &ent.User{ - ID: 1, - Email: "test@example.com", + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), } s.service.EXPECT(). - GetUserByEmail(gomock.Any(), "test@example.com"). + GetUserByUsername(gomock.Any(), "testuser"). Return(user, nil) - s.service.EXPECT(). - ValidatePassword(gomock.Any(), user, "password123"). - Return(true) s.service.EXPECT(). GetUserRoles(gomock.Any(), user.ID). - Return([]*ent.Role{{ID: 1, Name: "contributor"}}, nil) + Return([]*ent.Role{{Name: "admin"}}, nil) }, expectedStatus: http.StatusOK, }, { - name: "无效的邮箱格式", + name: "无效的用户名", request: LoginRequest{ - Email: "invalid-email", + Username: "invalid", Password: "password123", }, - setupMock: func() {}, - expectedStatus: http.StatusBadRequest, - expectedError: "Key: 'LoginRequest.Email' Error:Field validation for 'Email' failed on the 'email' tag", + setupMock: func() { + s.service.EXPECT(). + GetUserByUsername(gomock.Any(), "invalid"). + Return(nil, fmt.Errorf("user not found")) + }, + expectedStatus: http.StatusUnauthorized, + expectedError: "Invalid username or password", }, { name: "用户不存在", request: LoginRequest{ - Email: "nonexistent@example.com", + Username: "nonexistent", Password: "password123", }, setupMock: func() { s.service.EXPECT(). - GetUserByEmail(gomock.Any(), "nonexistent@example.com"). - Return(nil, errors.New("user not found")) + GetUserByUsername(gomock.Any(), "nonexistent"). + Return(nil, fmt.Errorf("user not found")) }, expectedStatus: http.StatusUnauthorized, - expectedError: "Invalid credentials", + expectedError: "Invalid username or password", }, { name: "密码错误", request: LoginRequest{ - Email: "test@example.com", + Username: "testuser", Password: "wrong-password", }, setupMock: func() { + // 使用 bcrypt 生成正确的密码哈希 + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) user := &ent.User{ - ID: 1, - Email: "test@example.com", + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), } s.service.EXPECT(). - GetUserByEmail(gomock.Any(), "test@example.com"). + GetUserByUsername(gomock.Any(), "testuser"). Return(user, nil) - s.service.EXPECT(). - ValidatePassword(gomock.Any(), user, "wrong-password"). - Return(false) }, expectedStatus: http.StatusUnauthorized, - expectedError: "Invalid credentials", + expectedError: "Invalid username or password", }, { name: "获取用户角色失败", request: LoginRequest{ - Email: "test@example.com", + Username: "testuser", Password: "password123", }, setupMock: func() { + // 使用 bcrypt 生成正确的密码哈希 + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost) user := &ent.User{ - ID: 1, - Email: "test@example.com", + ID: 1, + Username: "testuser", + PasswordHash: string(hashedPassword), } s.service.EXPECT(). - GetUserByEmail(gomock.Any(), "test@example.com"). + GetUserByUsername(gomock.Any(), "testuser"). Return(user, nil) - s.service.EXPECT(). - ValidatePassword(gomock.Any(), user, "password123"). - Return(true) s.service.EXPECT(). GetUserRoles(gomock.Any(), user.ID). - Return(nil, errors.New("failed to get roles")) + Return(nil, fmt.Errorf("failed to get roles")) }, expectedStatus: http.StatusInternalServerError, expectedError: "Failed to get user roles", diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index baa7d8d..4c821aa 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -34,6 +34,18 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) { auth.POST("/login", h.Login) } + // User routes + users := api.Group("/users") + { + users.GET("", h.ListUsers) + users.POST("", h.CreateUser) + users.GET("/:id", h.GetUser) + users.PUT("/:id", h.UpdateUser) + users.DELETE("/:id", h.DeleteUser) + users.GET("/me", h.GetCurrentUser) + users.PUT("/me", h.UpdateCurrentUser) + } + // Category routes categories := api.Group("/categories") { diff --git a/backend/internal/handler/user.go b/backend/internal/handler/user.go new file mode 100644 index 0000000..61637cf --- /dev/null +++ b/backend/internal/handler/user.go @@ -0,0 +1,227 @@ +package handler + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" + "tss-rocks-be/internal/types" +) + +type UpdateCurrentUserRequest struct { + Email string `json:"email,omitempty" binding:"omitempty,email"` + CurrentPassword string `json:"current_password,omitempty"` + NewPassword string `json:"new_password,omitempty" binding:"omitempty,min=8"` + DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"` +} + +type CreateUserRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=8"` + Role string `json:"role" binding:"required,oneof=admin editor"` + DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"` +} + +type UpdateUserRequest struct { + Email string `json:"email,omitempty" binding:"omitempty,email"` + Password string `json:"password,omitempty" binding:"omitempty,min=8"` + Role string `json:"role,omitempty" binding:"omitempty,oneof=admin editor"` + Status string `json:"status,omitempty" binding:"omitempty,oneof=active inactive"` + DisplayName string `json:"display_name,omitempty" binding:"omitempty,max=64"` +} + +// ListUsers returns a list of users +func (h *Handler) ListUsers(c *gin.Context) { + // Parse query parameters + params := &types.ListUsersParams{ + Page: 1, + PerPage: 10, + } + + if page := c.Query("page"); page != "" { + if p, err := strconv.Atoi(page); err == nil && p > 0 { + params.Page = p + } + } + + if perPage := c.Query("per_page"); perPage != "" { + if pp, err := strconv.Atoi(perPage); err == nil && pp > 0 { + params.PerPage = pp + } + } + + params.Sort = c.Query("sort") + params.Role = c.Query("role") + params.Status = c.Query("status") + params.Email = c.Query("email") + + // Get users + users, err := h.service.ListUsers(c.Request.Context(), params) + if err != nil { + log.Error().Err(err).Msg("Failed to list users") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": users, + }) +} + +// CreateUser creates a new user +func (h *Handler) CreateUser(c *gin.Context) { + var req CreateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + user, err := h.service.CreateUser(c.Request.Context(), req.Email, req.Email, req.Password, req.Role) + if err != nil { + log.Error().Err(err).Msg("Failed to create user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + return + } + + c.JSON(http.StatusCreated, gin.H{ + "data": user, + }) +} + +// GetUser returns user details +func (h *Handler) GetUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + user, err := h.service.GetUser(c.Request.Context(), id) + if err != nil { + log.Error().Err(err).Msg("Failed to get user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": user, + }) +} + +// UpdateUser updates user information +func (h *Handler) UpdateUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + var req UpdateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + user, err := h.service.UpdateUser(c.Request.Context(), id, &types.UpdateUserInput{ + Email: req.Email, + Password: req.Password, + Role: req.Role, + Status: req.Status, + DisplayName: req.DisplayName, + }) + if err != nil { + log.Error().Err(err).Msg("Failed to update user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": user, + }) +} + +// DeleteUser deletes a user +func (h *Handler) DeleteUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + if err := h.service.DeleteUser(c.Request.Context(), id); err != nil { + log.Error().Err(err).Msg("Failed to delete user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"}) + return + } + + c.Status(http.StatusNoContent) +} + +// GetCurrentUser returns the current user's information +func (h *Handler) GetCurrentUser(c *gin.Context) { + // 从上下文中获取用户ID(由认证中间件设置) + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + // 获取用户信息 + user, err := h.service.GetUser(c.Request.Context(), userID.(int)) + if err != nil { + log.Error().Err(err).Msg("Failed to get user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user information"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": user, + }) +} + +// UpdateCurrentUser updates the current user's information +func (h *Handler) UpdateCurrentUser(c *gin.Context) { + // 从上下文中获取用户ID(由认证中间件设置) + userID, exists := c.Get("user_id") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + var req UpdateCurrentUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 如果要更新密码,需要验证当前密码 + if req.NewPassword != "" { + if req.CurrentPassword == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Current password is required to update password"}) + return + } + + // 验证当前密码 + if err := h.service.VerifyPassword(c.Request.Context(), userID.(int), req.CurrentPassword); err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current password"}) + return + } + } + + // 更新用户信息 + user, err := h.service.UpdateUser(c.Request.Context(), userID.(int), &types.UpdateUserInput{ + Email: req.Email, + Password: req.NewPassword, + DisplayName: req.DisplayName, + }) + if err != nil { + log.Error().Err(err).Msg("Failed to update user") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user information"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "data": user, + }) +} diff --git a/backend/internal/middleware/accesslog.go b/backend/internal/middleware/accesslog.go index 2495294..32d0f9e 100644 --- a/backend/internal/middleware/accesslog.go +++ b/backend/internal/middleware/accesslog.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" "github.com/gin-gonic/gin" @@ -78,26 +79,41 @@ func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) { // 配置文件日志 if config.EnableFile { - // 确保日志目录存在 - if err := os.MkdirAll(filepath.Dir(config.FilePath), 0755); err != nil { + // 验证文件路径 + if config.FilePath == "" { + return nil, fmt.Errorf("file path cannot be empty") + } + + // 验证路径是否包含无效字符 + if strings.ContainsAny(config.FilePath, "\x00") { + return nil, fmt.Errorf("file path contains invalid characters") + } + + dir := filepath.Dir(config.FilePath) + + // 检查目录是否存在或是否可以创建 + if err := os.MkdirAll(dir, 0755); err != nil { return nil, fmt.Errorf("failed to create log directory: %w", err) } - // 配置日志轮转 + // 尝试打开或创建文件,验证路径是否有效且有写入权限 + file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + return nil, fmt.Errorf("failed to open or create log file: %w", err) + } + file.Close() + + // 配置文件日志 logWriter = &lumberjack.Logger{ Filename: config.FilePath, MaxSize: config.Rotation.MaxSize, // MB - MaxAge: config.Rotation.MaxAge, // days - MaxBackups: config.Rotation.MaxBackups, // files + MaxBackups: config.Rotation.MaxBackups, // 文件个数 + MaxAge: config.Rotation.MaxAge, // 天数 Compress: config.Rotation.Compress, // 是否压缩 LocalTime: config.Rotation.LocalTime, // 使用本地时间 } - logger := zerolog.New(logWriter). - With(). - Timestamp(). - Logger() - + logger := zerolog.New(logWriter).With().Timestamp().Logger() fileLogger = &logger } diff --git a/backend/internal/middleware/accesslog_test.go b/backend/internal/middleware/accesslog_test.go index 085555b..449513a 100644 --- a/backend/internal/middleware/accesslog_test.go +++ b/backend/internal/middleware/accesslog_test.go @@ -219,7 +219,7 @@ func TestAccessLogInvalidConfig(t *testing.T) { name: "Invalid file path", config: &types.AccessLogConfig{ EnableFile: true, - FilePath: "/dev/null/nonexistent/test.log", // 在所有操作系统上都无效的路径 + FilePath: "\x00invalid\x00path", // 使用空字符的路径在所有操作系统上都是无效的 }, expectedError: true, }, diff --git a/backend/internal/rbac/init.go b/backend/internal/rbac/init.go index e7ac706..fcbbb4f 100644 --- a/backend/internal/rbac/init.go +++ b/backend/internal/rbac/init.go @@ -5,6 +5,7 @@ import ( "fmt" "tss-rocks-be/ent" + "tss-rocks-be/ent/permission" "tss-rocks-be/ent/role" ) @@ -38,37 +39,69 @@ func InitializeRBAC(ctx context.Context, client *ent.Client) error { permissionMap := make(map[string]*ent.Permission) for resource, actions := range DefaultPermissions { for _, action := range actions { - permission, err := client.Permission.Create(). - SetResource(resource). - SetAction(action). - SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)). - Save(ctx) - if err != nil { - return fmt.Errorf("failed creating permission: %w", err) - } key := fmt.Sprintf("%s:%s", resource, action) + permission, err := client.Permission.Query(). + Where( + permission.ResourceEQ(resource), + permission.ActionEQ(action), + ). + Only(ctx) + if ent.IsNotFound(err) { + permission, err = client.Permission.Create(). + SetResource(resource). + SetAction(action). + SetDescription(fmt.Sprintf("Permission to %s %s", action, resource)). + Save(ctx) + if err != nil { + return fmt.Errorf("failed creating permission: %w", err) + } + } else if err != nil { + return fmt.Errorf("failed querying permission: %w", err) + } permissionMap[key] = permission } } // Create roles with permissions for roleName, permissions := range DefaultRoles { - roleCreate := client.Role.Create(). - SetName(roleName). - SetDescription(fmt.Sprintf("Role for %s users", roleName)) + role, err := client.Role.Query(). + Where(role.NameEQ(roleName)). + Only(ctx) + if ent.IsNotFound(err) { + roleCreate := client.Role.Create(). + SetName(roleName). + SetDescription(fmt.Sprintf("Role for %s users", roleName)) - // Add permissions to role - for resource, actions := range permissions { - for _, action := range actions { - key := fmt.Sprintf("%s:%s", resource, action) - if permission, exists := permissionMap[key]; exists { - roleCreate.AddPermissions(permission) + // Add permissions to role + for resource, actions := range permissions { + for _, action := range actions { + key := fmt.Sprintf("%s:%s", resource, action) + if permission, exists := permissionMap[key]; exists { + roleCreate.AddPermissions(permission) + } } } - } - if _, err := roleCreate.Save(ctx); err != nil { - return fmt.Errorf("failed creating role %s: %w", roleName, err) + if _, err := roleCreate.Save(ctx); err != nil { + return fmt.Errorf("failed creating role %s: %w", roleName, err) + } + } else if err != nil { + return fmt.Errorf("failed querying role: %w", err) + } else { + // Update existing role's permissions + for resource, actions := range permissions { + for _, action := range actions { + key := fmt.Sprintf("%s:%s", resource, action) + if permission, exists := permissionMap[key]; exists { + err = client.Role.UpdateOne(role). + AddPermissions(permission). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed updating role %s permissions: %w", roleName, err) + } + } + } + } } } diff --git a/backend/internal/rbac/init_test.go b/backend/internal/rbac/init_test.go index cdaaae8..2da0158 100644 --- a/backend/internal/rbac/init_test.go +++ b/backend/internal/rbac/init_test.go @@ -64,6 +64,7 @@ func TestAssignRoleToUser(t *testing.T) { // Create a test user user, err := client.User.Create(). SetEmail("test@example.com"). + SetUsername("testuser"). SetPasswordHash("$2a$10$hzLdXMZEIzgr8eGXL0YoCOIIrQhqEj6N.S3.wY1Jx5.4vWm1ZyHyy"). Save(ctx) if err != nil { diff --git a/backend/internal/server/ent.go b/backend/internal/server/ent.go index 9e289ae..63e444b 100644 --- a/backend/internal/server/ent.go +++ b/backend/internal/server/ent.go @@ -1,31 +1,27 @@ package server import ( - "context" + "context" - "entgo.io/ent/dialect/sql" - _ "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog/log" - "tss-rocks-be/ent" - "tss-rocks-be/internal/config" + "tss-rocks-be/ent" + "tss-rocks-be/internal/config" + + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog/log" ) // NewEntClient creates a new ent client func NewEntClient(cfg *config.Config) *ent.Client { - // TODO: Implement database connection based on config - // For now, we'll use SQLite for development - db, err := sql.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") - if err != nil { - log.Fatal().Err(err).Msg("Failed to connect to database") - } + // 使用配置文件中的数据库设置 + client, err := ent.Open(cfg.Database.Driver, cfg.Database.DSN) + if err != nil { + log.Fatal().Err(err).Msg("Failed to connect to database") + } - // Create ent client - client := ent.NewClient(ent.Driver(db)) + // Run the auto migration tool + if err := client.Schema.Create(context.Background()); err != nil { + log.Fatal().Err(err).Msg("Failed to create schema resources") + } - // Run the auto migration tool - if err := client.Schema.Create(context.Background()); err != nil { - log.Fatal().Err(err).Msg("Failed to create schema resources") - } - - return client + return client } diff --git a/backend/internal/service/impl.go b/backend/internal/service/impl.go index 30f5f57..87e3a8e 100644 --- a/backend/internal/service/impl.go +++ b/backend/internal/service/impl.go @@ -24,7 +24,6 @@ import ( "tss-rocks-be/ent/role" "tss-rocks-be/ent/user" "tss-rocks-be/internal/storage" - "github.com/google/uuid" "golang.org/x/crypto/bcrypt" ) @@ -54,59 +53,74 @@ func NewService(client *ent.Client, storage storage.Storage) Service { } // User operations -func (s *serviceImpl) CreateUser(ctx context.Context, email, password string, roleStr string) (*ent.User, error) { - // Hash the password +func (s *serviceImpl) CreateUser(ctx context.Context, username, email, password string, roleStr string) (*ent.User, error) { + // 验证邮箱格式 + emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + if !emailRegex.MatchString(email) { + return nil, fmt.Errorf("invalid email format") + } + + // 验证密码长度 + if len(password) < 8 { + return nil, fmt.Errorf("password must be at least 8 characters") + } + + // 检查用户名是否已存在 + exists, err := s.client.User.Query().Where(user.Username(username)).Exist(ctx) + if err != nil { + return nil, fmt.Errorf("error checking username: %v", err) + } + if exists { + return nil, fmt.Errorf("username '%s' already exists", username) + } + + // 生成密码哈希 hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return nil, fmt.Errorf("failed to hash password: %w", err) + return nil, fmt.Errorf("error hashing password: %v", err) } - // Add the user role by default - userRole, err := s.client.Role.Query().Where(role.NameEQ("user")).Only(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get user role: %w", err) - } - - // If a specific role is requested and it's not "user", get that role too - var additionalRole *ent.Role - if roleStr != "" && roleStr != "user" { - additionalRole, err = s.client.Role.Query().Where(role.NameEQ(roleStr)).Only(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get role: %w", err) - } - } - - // Create user with password and user role - userCreate := s.client.User.Create(). + // 创建用户 + u, err := s.client.User.Create(). + SetUsername(username). SetEmail(email). SetPasswordHash(string(hashedPassword)). - AddRoles(userRole) + SetStatus("active"). + Save(ctx) - // Add the additional role if specified - if additionalRole != nil { - userCreate.AddRoles(additionalRole) - } - - // Save the user - user, err := userCreate.Save(ctx) if err != nil { - return nil, fmt.Errorf("failed to create user: %w", err) + return nil, fmt.Errorf("error creating user: %v", err) } - return user, nil + // 分配角色 + err = s.AssignRole(ctx, u.ID, roleStr) + if err != nil { + return nil, fmt.Errorf("error assigning role: %v", err) + } + + return u, nil +} + +func (s *serviceImpl) GetUserByUsername(ctx context.Context, username string) (*ent.User, error) { + u, err := s.client.User.Query().Where(user.Username(username)).Only(ctx) + if err != nil { + if ent.IsNotFound(err) { + return nil, fmt.Errorf("user with username '%s' not found", username) + } + return nil, fmt.Errorf("error getting user: %v", err) + } + return u, nil } func (s *serviceImpl) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) { - user, err := s.client.User.Query(). - Where(user.EmailEQ(email)). - Only(ctx) + u, err := s.client.User.Query().Where(user.Email(email)).Only(ctx) if err != nil { if ent.IsNotFound(err) { - return nil, fmt.Errorf("user not found: %s", email) + return nil, fmt.Errorf("user with email '%s' not found", email) } - return nil, fmt.Errorf("failed to get user: %w", err) + return nil, fmt.Errorf("error getting user: %v", err) } - return user, nil + return u, nil } func (s *serviceImpl) ValidatePassword(ctx context.Context, user *ent.User, password string) bool { diff --git a/backend/internal/service/impl_test.go b/backend/internal/service/impl_test.go index 5a00848..abba3cb 100644 --- a/backend/internal/service/impl_test.go +++ b/backend/internal/service/impl_test.go @@ -103,52 +103,50 @@ func newMockMultipartFile(data []byte) *mockMultipartFile { func (s *ServiceImplTestSuite) TestCreateUser() { testCases := []struct { - name string - email string - password string - role string - wantError bool + name string + username string + email string + password string + role string + wantErr bool }{ { - name: "Valid user creation", - email: "test@example.com", - password: "password123", - role: "admin", - wantError: false, + name: "有效的用户", + username: "testuser", + email: "test@example.com", + password: "password123", + role: "user", + wantErr: false, }, { - name: "Empty email", - email: "", - password: "password123", - role: "user", - wantError: true, + name: "无效的邮箱", + username: "testuser2", + email: "invalid-email", + password: "password123", + role: "user", + wantErr: true, }, { - name: "Empty password", - email: "test@example.com", - password: "", - role: "user", - wantError: true, - }, - { - name: "Invalid role", - email: "test@example.com", - password: "password123", - role: "invalid_role", - wantError: true, + name: "空密码", + username: "testuser3", + email: "test3@example.com", + password: "", + role: "user", + wantErr: true, }, } for _, tc := range testCases { s.Run(tc.name, func() { - user, err := s.svc.CreateUser(s.ctx, tc.email, tc.password, tc.role) - if tc.wantError { - assert.Error(s.T(), err) - assert.Nil(s.T(), user) + user, err := s.svc.CreateUser(s.ctx, tc.username, tc.email, tc.password, tc.role) + if tc.wantErr { + s.Error(err) + s.Nil(user) } else { - assert.NoError(s.T(), err) - assert.NotNil(s.T(), user) - assert.Equal(s.T(), tc.email, user.Email) + s.NoError(err) + s.NotNil(user) + s.Equal(tc.email, user.Email) + s.Equal(tc.username, user.Username) } }) } @@ -160,7 +158,7 @@ func (s *ServiceImplTestSuite) TestGetUserByEmail() { password := "password123" role := "user" - user, err := s.svc.CreateUser(s.ctx, email, password, role) + user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role) require.NoError(s.T(), err) require.NotNil(s.T(), user) @@ -184,7 +182,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() { password := "password123" role := "user" - user, err := s.svc.CreateUser(s.ctx, email, password, role) + user, err := s.svc.CreateUser(s.ctx, "testuser", email, password, role) require.NoError(s.T(), err) require.NotNil(s.T(), user) @@ -201,7 +199,7 @@ func (s *ServiceImplTestSuite) TestValidatePassword() { func (s *ServiceImplTestSuite) TestRBAC() { s.Run("AssignRole", func() { - user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password", "admin") + user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password", "admin") require.NoError(s.T(), err) err = s.svc.AssignRole(s.ctx, user.ID, "user") @@ -209,7 +207,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("RemoveRole", func() { - user, err := s.svc.CreateUser(s.ctx, "test2@example.com", "password", "admin") + user, err := s.svc.CreateUser(s.ctx, "testuser2", "test2@example.com", "password", "admin") require.NoError(s.T(), err) err = s.svc.RemoveRole(s.ctx, user.ID, "admin") @@ -218,7 +216,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { s.Run("HasPermission", func() { s.Run("Admin can create users", func() { - user, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password", "admin") + user, err := s.svc.CreateUser(s.ctx, "testuser3", "admin@example.com", "password", "admin") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") @@ -227,7 +225,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("Editor cannot create users", func() { - user, err := s.svc.CreateUser(s.ctx, "editor@example.com", "password", "editor") + user, err := s.svc.CreateUser(s.ctx, "testuser4", "editor@example.com", "password", "editor") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") @@ -236,7 +234,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("User cannot create users", func() { - user, err := s.svc.CreateUser(s.ctx, "user@example.com", "password", "user") + user, err := s.svc.CreateUser(s.ctx, "testuser5", "user@example.com", "password", "user") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "users:create") @@ -245,7 +243,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("Editor can create posts", func() { - user, err := s.svc.CreateUser(s.ctx, "editor2@example.com", "password", "editor") + user, err := s.svc.CreateUser(s.ctx, "testuser6", "editor2@example.com", "password", "editor") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") @@ -254,7 +252,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("User can read posts", func() { - user, err := s.svc.CreateUser(s.ctx, "user2@example.com", "password", "user") + user, err := s.svc.CreateUser(s.ctx, "testuser7", "user2@example.com", "password", "user") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:read") @@ -263,7 +261,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("User cannot create posts", func() { - user, err := s.svc.CreateUser(s.ctx, "user3@example.com", "password", "user") + user, err := s.svc.CreateUser(s.ctx, "testuser8", "user3@example.com", "password", "user") require.NoError(s.T(), err) hasPermission, err := s.svc.HasPermission(s.ctx, user.ID, "posts:create") @@ -272,7 +270,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { }) s.Run("Invalid permission format", func() { - user, err := s.svc.CreateUser(s.ctx, "user4@example.com", "password", "user") + user, err := s.svc.CreateUser(s.ctx, "testuser9", "user4@example.com", "password", "user") require.NoError(s.T(), err) _, err = s.svc.HasPermission(s.ctx, user.ID, "invalid_permission") @@ -284,7 +282,7 @@ func (s *ServiceImplTestSuite) TestRBAC() { func (s *ServiceImplTestSuite) TestCategory() { // Create a test user with admin role for testing - adminUser, err := s.svc.CreateUser(s.ctx, "admin@example.com", "password123", "admin") + adminUser, err := s.svc.CreateUser(s.ctx, "testuser10", "admin@example.com", "password123", "admin") require.NoError(s.T(), err) require.NotNil(s.T(), adminUser) @@ -510,7 +508,7 @@ func (s *ServiceImplTestSuite) TestGetUserRoles() { ctx := context.Background() // 创建测试用户,默认会有 "user" 角色 - user, err := s.svc.CreateUser(ctx, "test@example.com", "password123", "user") + user, err := s.svc.CreateUser(ctx, "testuser", "test@example.com", "password123", "user") s.Require().NoError(err) // 测试新用户有默认的 "user" 角色 @@ -840,7 +838,7 @@ func (s *ServiceImplTestSuite) TestPost() { func (s *ServiceImplTestSuite) TestMedia() { s.Run("Upload Media", func() { // Create a user first - user, err := s.svc.CreateUser(s.ctx, "test@example.com", "password123", "") + user, err := s.svc.CreateUser(s.ctx, "testuser", "test@example.com", "password123", "user") require.NoError(s.T(), err) require.NotNil(s.T(), user) @@ -963,7 +961,7 @@ func (s *ServiceImplTestSuite) TestMedia() { s.Run("Delete Media - Unauthorized", func() { // Create a user - user, err := s.svc.CreateUser(s.ctx, "another@example.com", "password123", "") + user, err := s.svc.CreateUser(s.ctx, "anotheruser", "another@example.com", "password123", "user") require.NoError(s.T(), err) // Mock file content @@ -1010,7 +1008,7 @@ func (s *ServiceImplTestSuite) TestMedia() { require.NoError(s.T(), err) // Try to delete with different user - anotherUser, err := s.svc.CreateUser(s.ctx, "third@example.com", "password123", "") + anotherUser, err := s.svc.CreateUser(s.ctx, "thirduser", "third@example.com", "password123", "user") require.NoError(s.T(), err) err = s.svc.DeleteMedia(s.ctx, media.ID, anotherUser.ID) diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index 3920125..cf1bf8d 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -9,14 +9,22 @@ import ( "tss-rocks-be/ent" "tss-rocks-be/internal/storage" + "tss-rocks-be/internal/types" ) // Service interface defines all business logic operations type Service interface { // User operations - CreateUser(ctx context.Context, email, password string, role string) (*ent.User, error) + CreateUser(ctx context.Context, username, email, password string, role string) (*ent.User, error) + GetUser(ctx context.Context, id int) (*ent.User, error) + GetUserByUsername(ctx context.Context, username string) (*ent.User, error) GetUserByEmail(ctx context.Context, email string) (*ent.User, error) ValidatePassword(ctx context.Context, user *ent.User, password string) bool + VerifyPassword(ctx context.Context, userID int, password string) error + UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) + DeleteUser(ctx context.Context, userID int) error + ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) + GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) // Category operations CreateCategory(ctx context.Context) (*ent.Category, error) @@ -51,9 +59,8 @@ type Service interface { DeleteMedia(ctx context.Context, id int, userID int) error // RBAC operations + InitializeRBAC(ctx context.Context) error AssignRole(ctx context.Context, userID int, role string) error RemoveRole(ctx context.Context, userID int, role string) error - GetUserRoles(ctx context.Context, userID int) ([]*ent.Role, error) HasPermission(ctx context.Context, userID int, permission string) (bool, error) - InitializeRBAC(ctx context.Context) error } diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go new file mode 100644 index 0000000..0ae658b --- /dev/null +++ b/backend/internal/service/user.go @@ -0,0 +1,154 @@ +package service + +import ( + "context" + "fmt" + + "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" + "tss-rocks-be/ent" + "tss-rocks-be/ent/role" + "tss-rocks-be/ent/user" + "tss-rocks-be/internal/types" +) + +// GetUser gets a user by ID +func (s *serviceImpl) GetUser(ctx context.Context, id int) (*ent.User, error) { + user, err := s.client.User.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + return user, nil +} + +// VerifyPassword verifies the user's current password +func (s *serviceImpl) VerifyPassword(ctx context.Context, userID int, password string) error { + user, err := s.GetUser(ctx, userID) + if err != nil { + return err + } + + if !s.ValidatePassword(ctx, user, password) { + return fmt.Errorf("invalid password") + } + + return nil +} + +// UpdateUser updates user information +func (s *serviceImpl) UpdateUser(ctx context.Context, userID int, input *types.UpdateUserInput) (*ent.User, error) { + // Start building the update + update := s.client.User.UpdateOneID(userID) + + // Update email if provided + if input.Email != "" { + update.SetEmail(input.Email) + } + + // Update display name if provided + if input.DisplayName != "" { + update.SetDisplayName(input.DisplayName) + } + + // Update password if provided + if input.Password != "" { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost) + if err != nil { + log.Error().Err(err).Msg("Failed to hash password") + return nil, fmt.Errorf("failed to hash password: %w", err) + } + update.SetPasswordHash(string(hashedPassword)) + } + + // Update status if provided + if input.Status != "" { + update.SetStatus(user.Status(input.Status)) + } + + // Execute the update + user, err := update.Save(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to update user") + return nil, fmt.Errorf("failed to update user: %w", err) + } + + // Update role if provided + if input.Role != "" { + // Clear existing roles + _, err = user.Update().ClearRoles().Save(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to clear user roles") + return nil, fmt.Errorf("failed to update user roles: %w", err) + } + + // Add new role + role, err := s.client.Role.Query().Where(role.NameEQ(input.Role)).Only(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to find role") + return nil, fmt.Errorf("failed to find role: %w", err) + } + + _, err = user.Update().AddRoles(role).Save(ctx) + if err != nil { + log.Error().Err(err).Msg("Failed to add user role") + return nil, fmt.Errorf("failed to update user roles: %w", err) + } + } + + return user, nil +} + +// DeleteUser deletes a user by ID +func (s *serviceImpl) DeleteUser(ctx context.Context, userID int) error { + err := s.client.User.DeleteOneID(userID).Exec(ctx) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + return nil +} + +// ListUsers lists users with filters and pagination +func (s *serviceImpl) ListUsers(ctx context.Context, params *types.ListUsersParams) ([]*ent.User, error) { + query := s.client.User.Query() + + // Apply filters + if params.Role != "" { + query.Where(user.HasRolesWith(role.NameEQ(params.Role))) + } + if params.Status != "" { + query.Where(user.StatusEQ(user.Status(params.Status))) + } + if params.Email != "" { + query.Where(user.EmailContains(params.Email)) + } + + // Apply pagination + if params.PerPage > 0 { + query.Limit(params.PerPage) + if params.Page > 0 { + query.Offset((params.Page - 1) * params.PerPage) + } + } + + // Apply sorting + if params.Sort != "" { + switch params.Sort { + case "email_asc": + query.Order(ent.Asc(user.FieldEmail)) + case "email_desc": + query.Order(ent.Desc(user.FieldEmail)) + case "created_at_asc": + query.Order(ent.Asc(user.FieldCreatedAt)) + case "created_at_desc": + query.Order(ent.Desc(user.FieldCreatedAt)) + } + } + + // Execute query + users, err := query.All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list users: %w", err) + } + + return users, nil +} diff --git a/backend/internal/types/user.go b/backend/internal/types/user.go new file mode 100644 index 0000000..1eef0b6 --- /dev/null +++ b/backend/internal/types/user.go @@ -0,0 +1,28 @@ +package types + +// UpdateUserInput defines the input for updating a user +type UpdateUserInput struct { + Email string + Password string + Role string + Status string + DisplayName string +} + +// ListUsersParams defines the parameters for listing users +type ListUsersParams struct { + Page int + PerPage int + Sort string + Role string + Status string + Email string +} + +// CreateUserInput defines the input for creating a user +type CreateUserInput struct { + Email string + Password string + Role string + DisplayName string +} diff --git a/backend/main.go b/backend/main.go new file mode 100644 index 0000000..3d03640 --- /dev/null +++ b/backend/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + "tss-rocks-be/cmd" +) + +var rootCmd = &cobra.Command{ + Use: "app", + Short: "TSS backend application", + Long: `TSS (The Starset Society) backend application. +It provides both API endpoints and admin tools for content management.`, +} + +func init() { + // 添加子命令 + rootCmd.AddCommand(cmd.GetUserCmd()) + rootCmd.AddCommand(cmd.GetServerCmd()) +} + +func main() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +}