package middleware import ( "fmt" "os" "path/filepath" "strings" "time" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog" "gopkg.in/natefinch/lumberjack.v2" "tss-rocks-be/internal/types" ) // AccessLogConfig 访问日志配置 type AccessLogConfig struct { // 是否启用控制台输出 EnableConsole bool `yaml:"enable_console"` // 是否启用文件日志 EnableFile bool `yaml:"enable_file"` // 日志文件路径 FilePath string `yaml:"file_path"` // 日志格式 (json 或 text) Format string `yaml:"format"` // 日志级别 Level string `yaml:"level"` // 日志轮转配置 Rotation struct { MaxSize int `yaml:"max_size"` // 每个日志文件的最大大小(MB) MaxAge int `yaml:"max_age"` // 保留旧日志文件的最大天数 MaxBackups int `yaml:"max_backups"` // 保留的旧日志文件的最大数量 Compress bool `yaml:"compress"` // 是否压缩旧日志文件 LocalTime bool `yaml:"local_time"` // 使用本地时间作为轮转时间 } `yaml:"rotation"` } // accessLogger 访问日志记录器 type accessLogger struct { consoleLogger *zerolog.Logger fileLogger *zerolog.Logger logWriter *lumberjack.Logger config *types.AccessLogConfig } // Close 关闭日志文件 func (l *accessLogger) Close() error { if l.logWriter != nil { return l.logWriter.Close() } return nil } // newAccessLogger 创建新的访问日志记录器 func newAccessLogger(config *types.AccessLogConfig) (*accessLogger, error) { var consoleLogger, fileLogger *zerolog.Logger var logWriter *lumberjack.Logger // 设置日志级别 level, err := zerolog.ParseLevel(config.Level) if err != nil { level = zerolog.InfoLevel } zerolog.SetGlobalLevel(level) // 配置控制台日志 if config.EnableConsole { logger := zerolog.New(os.Stdout). With(). Timestamp(). Logger() if config.Format == "text" { logger = logger.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}) } consoleLogger = &logger } // 配置文件日志 if config.EnableFile { // 验证文件路径 if config.FilePath == "" { return nil, fmt.Errorf("file path cannot be empty") } // 验证路径是否包含无效字符 if strings.ContainsAny(config.FilePath, "\x00") { return nil, fmt.Errorf("file path contains invalid characters") } dir := filepath.Dir(config.FilePath) // 检查目录是否存在或是否可以创建 if err := os.MkdirAll(dir, 0755); err != nil { return nil, fmt.Errorf("failed to create log directory: %w", err) } // 尝试打开或创建文件,验证路径是否有效且有写入权限 file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err != nil { return nil, fmt.Errorf("failed to open or create log file: %w", err) } file.Close() // 配置文件日志 logWriter = &lumberjack.Logger{ Filename: config.FilePath, MaxSize: config.Rotation.MaxSize, // MB MaxBackups: config.Rotation.MaxBackups, // 文件个数 MaxAge: config.Rotation.MaxAge, // 天数 Compress: config.Rotation.Compress, // 是否压缩 LocalTime: config.Rotation.LocalTime, // 使用本地时间 } logger := zerolog.New(logWriter).With().Timestamp().Logger() fileLogger = &logger } return &accessLogger{ consoleLogger: consoleLogger, fileLogger: fileLogger, logWriter: logWriter, config: config, }, nil } // logEvent 记录日志事件 func (l *accessLogger) logEvent(fields map[string]interface{}, msg string) { if l.consoleLogger != nil { event := l.consoleLogger.Info() for k, v := range fields { event = event.Interface(k, v) } event.Msg(msg) } if l.fileLogger != nil { event := l.fileLogger.Info() for k, v := range fields { event = event.Interface(k, v) } event.Msg(msg) } } // AccessLog 创建访问日志中间件 func AccessLog(config *types.AccessLogConfig) (gin.HandlerFunc, error) { logger, err := newAccessLogger(config) if err != nil { return nil, err } return func(c *gin.Context) { // 用于测试时关闭日志文件 if c == nil { if err := logger.Close(); err != nil { fmt.Printf("Error closing log file: %v\n", err) } return } start := time.Now() requestID := uuid.New().String() path := c.Request.URL.Path query := c.Request.URL.RawQuery // 设置请求ID到上下文 c.Set("request_id", requestID) // 处理请求 c.Next() // 计算处理时间 latency := time.Since(start) // 获取用户ID(如果已认证) var userID interface{} if id, exists := c.Get("user_id"); exists { userID = id } // 准备日志字段 fields := map[string]interface{}{ "request_id": requestID, "method": c.Request.Method, "path": path, "query": query, "ip": c.ClientIP(), "user_agent": c.Request.UserAgent(), "status": c.Writer.Status(), "size": c.Writer.Size(), "latency_ms": latency.Milliseconds(), "component": "access_log", } if userID != nil { fields["user_id"] = userID } // 如果有错误,添加到日志中 if len(c.Errors) > 0 { fields["error"] = c.Errors.String() } // 记录日志 logger.logEvent(fields, fmt.Sprintf("%s %s", c.Request.Method, path)) }, nil }