[feature/backend] implement /users handler + switch to username + add display name + user management cli
This commit is contained in:
parent
1d712d4e6c
commit
86ab334bc9
38 changed files with 1851 additions and 506 deletions
89
backend/cmd/server.go
Normal file
89
backend/cmd/server.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/cobra"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/server"
|
||||
"tss-rocks-be/pkg/logger"
|
||||
)
|
||||
|
||||
func GetServerCmd() *cobra.Command {
|
||||
var configPath string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "server",
|
||||
Short: "Start the server",
|
||||
Long: `Start the server with the specified configuration.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Load configuration
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setup logger
|
||||
logger.Setup(cfg)
|
||||
|
||||
// Create a context that can be cancelled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Set up signal handling
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create ent client
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
log.Fatal().Msg("Failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(ctx); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create schema resources")
|
||||
}
|
||||
|
||||
// Create and start server
|
||||
srv, err := server.New(cfg, client)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create server")
|
||||
}
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
sig := <-sigChan
|
||||
log.Info().Msgf("Received signal: %v", sig)
|
||||
|
||||
// Attempt graceful shutdown with timeout
|
||||
log.Info().Msg("Shutting down server...")
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Server forced to shutdown")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Add flags
|
||||
cmd.Flags().StringVarP(&configPath, "config", "c", "config/config.yaml", "path to config file")
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -1,79 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/server"
|
||||
"tss-rocks-be/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
configPath := flag.String("config", "config/config.yaml", "path to config file")
|
||||
flag.Parse()
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Setup logger
|
||||
logger.Setup(cfg)
|
||||
|
||||
// Create a context that can be cancelled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Set up signal handling
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Create ent client
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
log.Fatal().Msg("Failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Run the auto migration tool
|
||||
if err := client.Schema.Create(ctx); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create schema resources")
|
||||
}
|
||||
|
||||
// Create and start server
|
||||
srv, err := server.New(cfg, client)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create server")
|
||||
}
|
||||
|
||||
// Start server in a goroutine
|
||||
go func() {
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal
|
||||
sig := <-sigChan
|
||||
log.Info().Msgf("Received signal: %v", sig)
|
||||
|
||||
// Attempt graceful shutdown with timeout
|
||||
log.Info().Msg("Shutting down server...")
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Error().Err(err).Msg("Server forced to shutdown")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/server"
|
||||
)
|
||||
|
||||
func TestConfigLoad(t *testing.T) {
|
||||
// Create a temporary config file for testing
|
||||
tmpConfig := `
|
||||
database:
|
||||
driver: sqlite3
|
||||
dsn: ":memory:"
|
||||
server:
|
||||
port: 8080
|
||||
host: localhost
|
||||
storage:
|
||||
type: local
|
||||
local:
|
||||
root_dir: "./testdata"
|
||||
`
|
||||
tmpFile, err := os.CreateTemp("", "config-*.yaml")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
_, err = tmpFile.WriteString(tmpConfig)
|
||||
require.NoError(t, err)
|
||||
err = tmpFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test config loading
|
||||
cfg, err := config.Load(tmpFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sqlite3", cfg.Database.Driver)
|
||||
assert.Equal(t, ":memory:", cfg.Database.DSN)
|
||||
assert.Equal(t, 8080, cfg.Server.Port)
|
||||
assert.Equal(t, "localhost", cfg.Server.Host)
|
||||
assert.Equal(t, "local", cfg.Storage.Type)
|
||||
assert.Equal(t, "./testdata", cfg.Storage.Local.RootDir)
|
||||
}
|
||||
|
||||
func TestServerCreation(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Driver: "sqlite3",
|
||||
DSN: ":memory:",
|
||||
},
|
||||
Server: config.ServerConfig{
|
||||
Port: 8080,
|
||||
Host: "localhost",
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "./testdata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create ent client
|
||||
client := server.NewEntClient(cfg)
|
||||
require.NotNil(t, client)
|
||||
defer client.Close()
|
||||
|
||||
// Test schema creation
|
||||
err := client.Schema.Create(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test server creation
|
||||
srv, err := server.New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, srv)
|
||||
}
|
||||
|
||||
func TestServerStartAndShutdown(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Driver: "sqlite3",
|
||||
DSN: ":memory:",
|
||||
},
|
||||
Server: config.ServerConfig{
|
||||
Port: 0, // Use random available port
|
||||
Host: "localhost",
|
||||
},
|
||||
Storage: config.StorageConfig{
|
||||
Type: "local",
|
||||
Local: config.LocalStorage{
|
||||
RootDir: "./testdata",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := server.NewEntClient(cfg)
|
||||
require.NotNil(t, client)
|
||||
defer client.Close()
|
||||
|
||||
err := client.Schema.Create(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
srv, err := server.New(cfg, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
err := srv.Start()
|
||||
if err != nil {
|
||||
t.Logf("Server stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Test graceful shutdown
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = srv.Shutdown(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
411
backend/cmd/user.go
Normal file
411
backend/cmd/user.go
Normal file
|
@ -0,0 +1,411 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"tss-rocks-be/ent/user"
|
||||
"tss-rocks-be/ent/role"
|
||||
"tss-rocks-be/internal/config"
|
||||
"tss-rocks-be/internal/server"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func GetUserCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "User management commands",
|
||||
Long: `Commands for managing users, including create, delete, password, and role management.`,
|
||||
}
|
||||
|
||||
configPath := cmd.PersistentFlags().String("config", "config/config.yaml", "Path to config file")
|
||||
|
||||
cmd.AddCommand(getCreateUserCmd(*configPath))
|
||||
cmd.AddCommand(getPasswordCmd(*configPath))
|
||||
cmd.AddCommand(getListUsersCmd(*configPath))
|
||||
cmd.AddCommand(getDeleteUserCmd(*configPath))
|
||||
cmd.AddCommand(getRoleCmd(*configPath))
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getCreateUserCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a new user",
|
||||
Long: `Create a new user with specified username, email and password.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username, _ := cmd.Flags().GetString("username")
|
||||
email, _ := cmd.Flags().GetString("email")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 检查用户名是否已存在
|
||||
exists, err := client.User.Query().Where(user.Username(username)).Exist(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking username: %v", err)
|
||||
}
|
||||
if exists {
|
||||
return fmt.Errorf("username '%s' already exists", username)
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
exists, err = client.User.Query().Where(user.Email(email)).Exist(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking email: %v", err)
|
||||
}
|
||||
if exists {
|
||||
return fmt.Errorf("email '%s' already exists", email)
|
||||
}
|
||||
|
||||
// 生成密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error hashing password: %v", err)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
u, err := client.User.Create().
|
||||
SetUsername(username).
|
||||
SetEmail(email).
|
||||
SetPasswordHash(string(hashedPassword)).
|
||||
SetStatus("active").
|
||||
Save(cmd.Context())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating user: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully created user '%s' with email '%s'\n", u.Username, u.Email)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 添加命令行参数
|
||||
cmd.Flags().String("username", "", "username for the new user")
|
||||
cmd.Flags().String("email", "", "email for the new user")
|
||||
cmd.Flags().String("password", "", "password for the new user")
|
||||
|
||||
// 设置必需的参数
|
||||
cmd.MarkFlagRequired("username")
|
||||
cmd.MarkFlagRequired("email")
|
||||
cmd.MarkFlagRequired("password")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getPasswordCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "password",
|
||||
Short: "Change user password",
|
||||
Long: `Change the password for an existing user.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username, _ := cmd.Flags().GetString("username")
|
||||
password, _ := cmd.Flags().GetString("password")
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 查找用户
|
||||
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
// 生成新的密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error hashing password: %v", err)
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
_, err = u.Update().SetPasswordHash(string(hashedPassword)).Save(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating password: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully updated password for user '%s'\n", username)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 添加命令行参数
|
||||
cmd.Flags().String("username", "", "username of the user")
|
||||
cmd.Flags().String("password", "", "new password")
|
||||
|
||||
// 设置必需的参数
|
||||
cmd.MarkFlagRequired("username")
|
||||
cmd.MarkFlagRequired("password")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getListUsersCmd(configPath string) *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all users",
|
||||
Long: `List all users in the system.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 获取所有用户
|
||||
users, err := client.User.Query().All(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error querying users: %v", err)
|
||||
}
|
||||
|
||||
// 创建表格写入器
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "USERNAME\tEMAIL\tSTATUS\tCREATED AT")
|
||||
|
||||
// 输出用户信息
|
||||
for _, u := range users {
|
||||
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
|
||||
u.Username,
|
||||
u.Email,
|
||||
u.Status,
|
||||
u.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
w.Flush()
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getDeleteUserCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "delete",
|
||||
Short: "Delete a user",
|
||||
Long: `Delete an existing user from the system.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username, _ := cmd.Flags().GetString("username")
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 查找用户
|
||||
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
// 检查用户的角色
|
||||
hasRoles, err := client.User.QueryRoles(u).Exist(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking user roles: %v", err)
|
||||
}
|
||||
if hasRoles && !force {
|
||||
return fmt.Errorf("user '%s' has roles assigned. Use --force to override", username)
|
||||
}
|
||||
|
||||
// 删除用户
|
||||
err = client.User.DeleteOne(u).Exec(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting user: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully deleted user '%s'\n", username)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 添加命令行参数
|
||||
cmd.Flags().String("username", "", "username of the user to delete")
|
||||
cmd.Flags().Bool("force", false, "force delete even if user has roles")
|
||||
|
||||
// 设置必需的参数
|
||||
cmd.MarkFlagRequired("username")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getRoleCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "role",
|
||||
Short: "User role management commands",
|
||||
Long: `Commands for managing user roles, including adding and removing roles.`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(getRoleAddCmd(configPath))
|
||||
cmd.AddCommand(getRoleRemoveCmd(configPath))
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getRoleAddCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "add",
|
||||
Short: "Add a role to a user",
|
||||
Long: `Add a role to an existing user.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username, _ := cmd.Flags().GetString("username")
|
||||
roleName, _ := cmd.Flags().GetString("role")
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 查找用户
|
||||
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
// 检查角色是否已存在
|
||||
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking user roles: %v", err)
|
||||
}
|
||||
if hasRole {
|
||||
return fmt.Errorf("user '%s' already has role '%s'", username, roleName)
|
||||
}
|
||||
|
||||
// 获取要添加的角色
|
||||
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("role '%s' not found", roleName)
|
||||
}
|
||||
|
||||
// 添加角色
|
||||
err = client.User.UpdateOne(u).AddRoles(r).Exec(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding role: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully added role '%s' to user '%s'\n", roleName, username)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 添加命令行参数
|
||||
cmd.Flags().String("username", "", "username of the user")
|
||||
cmd.Flags().String("role", "", "role to add (admin/editor/contributor)")
|
||||
|
||||
// 设置必需的参数
|
||||
cmd.MarkFlagRequired("username")
|
||||
cmd.MarkFlagRequired("role")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func getRoleRemoveCmd(configPath string) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "remove",
|
||||
Short: "Remove a role from a user",
|
||||
Long: `Remove a role from an existing user.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username, _ := cmd.Flags().GetString("username")
|
||||
roleName, _ := cmd.Flags().GetString("role")
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 初始化数据库连接
|
||||
client := server.NewEntClient(cfg)
|
||||
if client == nil {
|
||||
return fmt.Errorf("failed to create database client")
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 查找用户
|
||||
u, err := client.User.Query().Where(user.Username(username)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("user with username '%s' not found", username)
|
||||
}
|
||||
|
||||
// 检查角色是否存在
|
||||
r, err := client.Role.Query().Where(role.Name(roleName)).Only(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("role '%s' not found", roleName)
|
||||
}
|
||||
|
||||
// 检查用户是否有该角色
|
||||
hasRole, err := client.User.QueryRoles(u).Where(role.Name(roleName)).Exist(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking user roles: %v", err)
|
||||
}
|
||||
if !hasRole {
|
||||
return fmt.Errorf("user '%s' does not have role '%s'", username, roleName)
|
||||
}
|
||||
|
||||
// 移除角色
|
||||
err = client.User.UpdateOne(u).RemoveRoles(r).Exec(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error removing role: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully removed role '%s' from user '%s'\n", roleName, username)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 添加命令行参数
|
||||
cmd.Flags().String("username", "", "username of the user")
|
||||
cmd.Flags().String("role", "", "role to remove (admin/editor/contributor)")
|
||||
|
||||
// 设置必需的参数
|
||||
cmd.MarkFlagRequired("username")
|
||||
cmd.MarkFlagRequired("role")
|
||||
|
||||
return cmd
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue