[feature/backend] implement /users handler + switch to username + add display name + user management cli

This commit is contained in:
CDN 2025-02-21 04:30:07 +08:00
parent 1d712d4e6c
commit 86ab334bc9
Signed by: CDN
GPG key ID: 0C656827F9F80080
38 changed files with 1851 additions and 506 deletions

89
backend/cmd/server.go Normal file
View file

@ -0,0 +1,89 @@
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func GetServerCmd() *cobra.Command {
var configPath string
cmd := &cobra.Command{
Use: "server",
Short: "Start the server",
Long: `Start the server with the specified configuration.`,
RunE: func(cmd *cobra.Command, args []string) error {
// Load configuration
cfg, err := config.Load(configPath)
if err != nil {
return err
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
return nil
},
}
// Add flags
cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file")
return cmd
}

View file

@ -1,79 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/rs/zerolog/log"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"tss-rocks-be/pkg/logger"
)
func main() {
// Parse command line flags
configPath := flag.String("config", "config/config.yaml", "path to config file")
flag.Parse()
// Load configuration
cfg, err := config.Load(*configPath)
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
// Setup logger
logger.Setup(cfg)
// Create a context that can be cancelled
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set up signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create ent client
client := server.NewEntClient(cfg)
if client == nil {
log.Fatal().Msg("Failed to create database client")
}
defer client.Close()
// Run the auto migration tool
if err := client.Schema.Create(ctx); err != nil {
log.Fatal().Err(err).Msg("Failed to create schema resources")
}
// Create and start server
srv, err := server.New(cfg, client)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create server")
}
// Start server in a goroutine
go func() {
if err := srv.Start(); err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
}()
// Wait for interrupt signal
sig := <-sigChan
log.Info().Msgf("Received signal: %v", sig)
// Attempt graceful shutdown with timeout
log.Info().Msg("Shutting down server...")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Error().Err(err).Msg("Server forced to shutdown")
os.Exit(1)
}
}

View file

@ -1,127 +0,0 @@
package main
import (
"context"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
)
func TestConfigLoad(t *testing.T) {
// Create a temporary config file for testing
tmpConfig := `
database:
driver: sqlite3
dsn: ":memory:"
server:
port: 8080
host: localhost
storage:
type: local
local:
root_dir: "./testdata"
`
tmpFile, err := os.CreateTemp("", "config-*.yaml")
require.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, err = tmpFile.WriteString(tmpConfig)
require.NoError(t, err)
err = tmpFile.Close()
require.NoError(t, err)
// Test config loading
cfg, err := config.Load(tmpFile.Name())
require.NoError(t, err)
assert.Equal(t, "sqlite3", cfg.Database.Driver)
assert.Equal(t, ":memory:", cfg.Database.DSN)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, "localhost", cfg.Server.Host)
assert.Equal(t, "local", cfg.Storage.Type)
assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir)
}
func TestServerCreation(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 8080,
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
// Create ent client
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
// Test schema creation
err := client.Schema.Create(context.Background())
require.NoError(t, err)
// Test server creation
srv, err := server.New(cfg, client)
require.NoError(t, err)
require.NotNil(t, srv)
}
func TestServerStartAndShutdown(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
Driver: "sqlite3",
DSN: ":memory:",
},
Server: config.ServerConfig{
Port: 0, // Use random available port
Host: "localhost",
},
Storage: config.StorageConfig{
Type: "local",
Local: config.LocalStorage{
RootDir: "./testdata",
},
},
}
client := server.NewEntClient(cfg)
require.NotNil(t, client)
defer client.Close()
err := client.Schema.Create(context.Background())
require.NoError(t, err)
srv, err := server.New(cfg, client)
require.NoError(t, err)
// Start server in goroutine
go func() {
err := srv.Start()
if err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Test graceful shutdown
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
assert.NoError(t, err)
}

411
backend/cmd/user.go Normal file
View file

@ -0,0 +1,411 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"tss-rocks-be/ent/user"
"tss-rocks-be/ent/role"
"tss-rocks-be/internal/config"
"tss-rocks-be/internal/server"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
)
func GetUserCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "user",
Short: "User management commands",
Long: `Commands for managing users, including create, delete, password, and role management.`,
}
configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file")
cmd.AddCommand(getCreateUserCmd(*configPath))
cmd.AddCommand(getPasswordCmd(*configPath))
cmd.AddCommand(getListUsersCmd(*configPath))
cmd.AddCommand(getDeleteUserCmd(*configPath))
cmd.AddCommand(getRoleCmd(*configPath))
return cmd
}
func getCreateUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "create",
Short: "Create a new user",
Long: `Create a new user with specified username, email and password.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
email, _ := cmd.Flags().GetString("email")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 检查用户名是否已存在
exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking username: %v", err)
}
if exists {
return fmt.Errorf("username '%s' already exists", username)
}
// 检查邮箱是否已存在
exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking email: %v", err)
}
if exists {
return fmt.Errorf("email '%s' already exists", email)
}
// 生成密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 创建用户
u, err := client.User.Create().
SetUsername(username).
SetEmail(email).
SetPasswordHash(string(hashedPassword)).
SetStatus("active").
Save(cmd.Context())
if err != nil {
return fmt.Errorf("error creating user: %v", err)
}
fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username for the new user")
cmd.Flags().String("email", "", "email for the new user")
cmd.Flags().String("password", "", "password for the new user")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("email")
cmd.MarkFlagRequired("password")
return cmd
}
func getPasswordCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "password",
Short: "Change user password",
Long: `Change the password for an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
password, _ := cmd.Flags().GetString("password")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 生成新的密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("error hashing password: %v", err)
}
// 更新密码
_, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context())
if err != nil {
return fmt.Errorf("error updating password: %v", err)
}
fmt.Printf("Successfully updated password for user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("password", "", "new password")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("password")
return cmd
}
func getListUsersCmd(configPath string) *cobra.Command {
return &cobra.Command{
Use: "list",
Short: "List all users",
Long: `List all users in the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 获取所有用户
users, err := client.User.Query().All(cmd.Context())
if err != nil {
return fmt.Errorf("error querying users: %v", err)
}
// 创建表格写入器
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT")
// 输出用户信息
for _, u := range users {
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
u.Username,
u.Email,
u.Status,
u.CreatedAt.Format("2006-01-02 15:04:05"))
}
w.Flush()
return nil
},
}
}
func getDeleteUserCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Long: `Delete an existing user from the system.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
force, _ := cmd.Flags().GetBool("force")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查用户的角色
hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRoles && !force {
return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username)
}
// 删除用户
err = client.User.DeleteOne(u).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error deleting user: %v", err)
}
fmt.Printf("Successfully deleted user '%s'\n", username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user to delete")
cmd.Flags().Bool("force", false, "force delete even if user has roles")
// 设置必需的参数
cmd.MarkFlagRequired("username")
return cmd
}
func getRoleCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "role",
Short: "User role management commands",
Long: `Commands for managing user roles, including adding and removing roles.`,
}
cmd.AddCommand(getRoleAddCmd(configPath))
cmd.AddCommand(getRoleRemoveCmd(configPath))
return cmd
}
func getRoleAddCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "add",
Short: "Add a role to a user",
Long: `Add a role to an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否已存在
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if hasRole {
return fmt.Errorf("user '%s' already has role '%s'", username, roleName)
}
// 获取要添加的角色
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 添加角色
err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error adding role: %v", err)
}
fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to add (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}
func getRoleRemoveCmd(configPath string) *cobra.Command {
cmd := &cobra.Command{
Use: "remove",
Short: "Remove a role from a user",
Long: `Remove a role from an existing user.`,
RunE: func(cmd *cobra.Command, args []string) error {
username, _ := cmd.Flags().GetString("username")
roleName, _ := cmd.Flags().GetString("role")
// 加载配置
cfg, err := config.Load(configPath)
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
// 初始化数据库连接
client := server.NewEntClient(cfg)
if client == nil {
return fmt.Errorf("failed to create database client")
}
defer client.Close()
// 查找用户
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("user with username '%s' not found", username)
}
// 检查角色是否存在
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
if err != nil {
return fmt.Errorf("role '%s' not found", roleName)
}
// 检查用户是否有该角色
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
if err != nil {
return fmt.Errorf("error checking user roles: %v", err)
}
if !hasRole {
return fmt.Errorf("user '%s' does not have role '%s'", username, roleName)
}
// 移除角色
err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context())
if err != nil {
return fmt.Errorf("error removing role: %v", err)
}
fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username)
return nil
},
}
// 添加命令行参数
cmd.Flags().String("username", "", "username of the user")
cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)")
// 设置必需的参数
cmd.MarkFlagRequired("username")
cmd.MarkFlagRequired("role")
return cmd
}