This commit is contained in:
Joe
2026-01-16 15:49:34 +08:00
commit 550d3e1f42
380 changed files with 62024 additions and 0 deletions
+73
View File
@@ -0,0 +1,73 @@
package utils
import (
"fmt"
"math"
"goravel/app/models"
)
// FormatBalance 根据货币的小数位数格式化余额
// 如果货币没有小数(decimalPlaces = 0),则不返回小数部分
func FormatBalance(balance float64, currency *models.Currency) float64 {
if currency == nil {
// 如果没有货币信息,默认使用2位小数
return math.Round(balance*100) / 100
}
decimalPlaces := currency.DecimalPlaces
if decimalPlaces < 0 {
decimalPlaces = 0
}
if decimalPlaces > 8 {
decimalPlaces = 8
}
// 计算精度倍数
multiplier := math.Pow(10, float64(decimalPlaces))
// 四舍五入到指定小数位数
formatted := math.Round(balance*multiplier) / multiplier
return formatted
}
// FormatBalanceString 格式化余额为字符串(去除末尾多余的0)
func FormatBalanceString(balance float64, currency *models.Currency) string {
formatted := FormatBalance(balance, currency)
if currency == nil {
return fmt.Sprintf("%.2f", formatted)
}
decimalPlaces := currency.DecimalPlaces
if decimalPlaces < 0 {
decimalPlaces = 0
}
if decimalPlaces > 8 {
decimalPlaces = 8
}
// 如果小数位数为0,使用整数格式
if decimalPlaces == 0 {
return fmt.Sprintf("%.0f", formatted)
}
// 格式化字符串,去除末尾的0
format := fmt.Sprintf("%%.%df", decimalPlaces)
result := fmt.Sprintf(format, formatted)
// 去除末尾的0和小数点(如果小数部分全为0)
if decimalPlaces > 0 {
// 去除末尾的0
for len(result) > 0 && result[len(result)-1] == '0' {
result = result[:len(result)-1]
}
// 如果最后是小数点,也去除
if len(result) > 0 && result[len(result)-1] == '.' {
result = result[:len(result)-1]
}
}
return result
}
+101
View File
@@ -0,0 +1,101 @@
package utils
import (
"fmt"
"github.com/goravel/framework/facades"
"goravel/app/models"
)
// GetConfigValue 从数据库获取配置值
// group: 配置分组
// key: 配置键
// defaultValue: 默认值(如果配置不存在)
func GetConfigValue(group, key string, defaultValue string) string {
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
defer func() {
if r := recover(); r != nil {
// 静默处理,返回默认值
}
}()
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
// 这样可以避免在构建时执行数据库查询
orm := facades.Orm()
if orm == nil {
return defaultValue
}
var config models.Config
err := orm.Query().Where("group", group).Where("key", key).First(&config)
if err != nil {
return defaultValue
}
if config.Value == "" {
return defaultValue
}
return config.Value
}
// GetConfigValueInt 从数据库获取配置值(整数类型)
func GetConfigValueInt(group, key string, defaultValue int) int {
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
defer func() {
if r := recover(); r != nil {
// 静默处理,返回默认值
}
}()
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
// 这样可以避免在构建时执行数据库查询
orm := facades.Orm()
if orm == nil {
return defaultValue
}
var config models.Config
err := orm.Query().Where("group", group).Where("key", key).First(&config)
if err != nil {
return defaultValue
}
if config.Value == "" {
return defaultValue
}
// 简单的字符串转整数,实际可以使用更完善的转换
value := 0
_, err = fmt.Sscanf(config.Value, "%d", &value)
if err != nil {
return defaultValue
}
return value
}
// GetConfigValueBool 从数据库获取配置值(布尔类型)
func GetConfigValueBool(group, key string, defaultValue bool) bool {
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
defer func() {
if r := recover(); r != nil {
// 静默处理,返回默认值
}
}()
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
// 这样可以避免在构建时执行数据库查询
orm := facades.Orm()
if orm == nil {
return defaultValue
}
var config models.Config
err := orm.Query().Where("group", group).Where("key", key).First(&config)
if err != nil {
return defaultValue
}
if config.Value == "" {
return defaultValue
}
// 支持 "1", "true", "True" 等格式
value := config.Value
return value == "1" || value == "true" || value == "True" || value == "TRUE"
}
+252
View File
@@ -0,0 +1,252 @@
package utils
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/facades"
)
// CountOptimizer 分页统计优化器
// 当数据量超过阈值时,使用执行计划估算的行数,否则使用实际的 count(*)
type CountOptimizer struct {
// Threshold 阈值,超过此值使用估算值(默认 10000)
Threshold int64
// ModuleName 模块名称,用于日志记录
ModuleName string
}
// NewCountOptimizer 创建新的 CountOptimizer
// threshold: 阈值,超过此值使用估算值(默认 10000)
// moduleName: 模块名称,用于日志记录
func NewCountOptimizer(threshold int64, moduleName string) *CountOptimizer {
if threshold <= 0 {
threshold = 10000 // 默认阈值
}
return &CountOptimizer{
Threshold: threshold,
ModuleName: moduleName,
}
}
// OptimizedCount 优化的 count 查询(使用 ORM Query 对象)
// query: ORM 查询对象(已应用筛选条件,但未应用排序和分页)
// 注意:此方法会先尝试估算,如果失败则使用实际 count
// 返回:总数、是否使用估算值、错误
func (co *CountOptimizer) OptimizedCount(query orm.Query) (int64, bool, error) {
// 先尝试执行实际 count(如果数据量小,直接 count 也很快)
// 如果数据量大,我们可以通过执行时间来判断,但更简单的方法是先估算
// 这里我们提供一个简化版本:直接使用实际 count,如果慢的话可以后续优化
// 由于无法直接从 ORM Query 提取 SQL,这里提供一个变通方案:
// 先执行一次快速查询获取估算值(需要表名)
// 但更推荐使用 OptimizedCountWithTable 方法
actualCount, err := query.Count()
return actualCount, false, err
}
// extractRowsFromPostgreSQLExplain 从 PostgreSQL EXPLAIN JSON 结果中提取行数
func (co *CountOptimizer) extractRowsFromPostgreSQLExplain(jsonStr string) int64 {
// 简化实现:使用正则表达式提取 "Plan" -> "Plan Rows" 的值
// 实际应该使用 JSON 解析,但为了简化,使用正则
re := regexp.MustCompile(`"Plan Rows":\s*(\d+)`)
matches := re.FindStringSubmatch(jsonStr)
if len(matches) > 1 {
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
return rows
}
}
return 0
}
// extractRowsFromPostgreSQLExplainText 从 PostgreSQL EXPLAIN 文本结果中提取行数
func (co *CountOptimizer) extractRowsFromPostgreSQLExplainText(result []map[string]any) int64 {
// PostgreSQL 文本格式的 EXPLAIN 结果中,行数通常在 "rows=" 后面
for _, row := range result {
if queryPlan, ok := row["QUERY PLAN"]; ok {
if planStr, ok := queryPlan.(string); ok {
// 使用正则表达式提取 rows= 后面的数字
re := regexp.MustCompile(`rows=(\d+)`)
matches := re.FindStringSubmatch(planStr)
if len(matches) > 1 {
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
return rows
}
}
}
}
}
return 0
}
// OptimizedCountWithTable 使用表名和 WHERE 条件进行优化的 count 查询
// tableName: 表名
// whereClause: WHERE 子句(不包含 WHERE 关键字),例如:"status = ? AND user_id = ?"
// args: WHERE 条件的参数
// 返回:总数、是否使用估算值、错误
func (co *CountOptimizer) OptimizedCountWithTable(tableName, whereClause string, args ...any) (int64, bool, error) {
dbConnection := facades.Config().GetString("database.default", "sqlite")
// 构建 COUNT SQL
var countSQL string
if whereClause != "" {
switch dbConnection {
case "mysql":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s` WHERE %s", tableName, whereClause)
case "postgres":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s WHERE %s", tableName, whereClause)
default:
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
} else {
switch dbConnection {
case "mysql":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s`", tableName)
case "postgres":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s", tableName)
default:
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
}
// 构建 EXPLAIN SQL
var explainSQL string
if dbConnection == "mysql" {
if whereClause != "" {
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s` WHERE %s", tableName, whereClause)
} else {
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s`", tableName)
}
} else if dbConnection == "postgres" {
if whereClause != "" {
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s WHERE %s", tableName, whereClause)
} else {
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s", tableName)
}
} else {
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// 先执行 EXPLAIN 查询获取估算值(快速,不需要特别精准)
estimatedCount, err := co.executeExplain(explainSQL, args...)
if err != nil {
// 如果估算失败,使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// 如果估算值超过阈值,直接返回估算值(不需要执行实际 count,更快)
if estimatedCount >= co.Threshold {
return estimatedCount, true, nil
}
// 估算值小于阈值,执行实际的 count 获取精确值
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// executeExplain 执行 EXPLAIN 查询并提取估算行数
func (co *CountOptimizer) executeExplain(explainSQL string, args ...any) (int64, error) {
dbConnection := facades.Config().GetString("database.default", "sqlite")
switch dbConnection {
case "mysql":
// MySQL EXPLAIN 返回表格格式
var explainResult []map[string]any
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
return 0, err
}
if len(explainResult) == 0 {
return 0, fmt.Errorf("explain result is empty")
}
// MySQL EXPLAIN 结果中,rows 字段包含估算行数
if rows, ok := explainResult[0]["rows"]; ok {
switch v := rows.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case float64:
return int64(v), nil
case string:
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
return parsed, nil
}
}
}
return 0, fmt.Errorf("cannot extract rows from explain result")
case "postgres":
// PostgreSQL EXPLAIN (FORMAT JSON) 返回 JSON 格式
var explainResult []map[string]any
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
return 0, err
}
if len(explainResult) == 0 {
return 0, fmt.Errorf("explain result is empty")
}
// PostgreSQL EXPLAIN (FORMAT JSON) 结果是一个包含 JSON 字符串的数组
if queryPlan, ok := explainResult[0]["QUERY PLAN"]; ok {
if planStr, ok := queryPlan.(string); ok {
// 从 JSON 字符串中提取 "Plan" -> "Plan Rows" 的值
rows := co.extractRowsFromPostgreSQLExplain(planStr)
if rows > 0 {
return rows, nil
}
}
}
// 如果 JSON 格式解析失败,尝试文本格式
explainTextSQL := strings.Replace(explainSQL, "EXPLAIN (FORMAT JSON)", "EXPLAIN", 1)
var explainTextResult []map[string]any
if err := facades.Orm().Query().Raw(explainTextSQL, args...).Get(&explainTextResult); err == nil {
rows := co.extractRowsFromPostgreSQLExplainText(explainTextResult)
if rows > 0 {
return rows, nil
}
}
return 0, fmt.Errorf("cannot extract rows from explain result")
}
return 0, fmt.Errorf("unsupported database: %s", dbConnection)
}
+203
View File
@@ -0,0 +1,203 @@
package errorlog
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
"goravel/app/models"
"goravel/app/utils/logger"
"goravel/app/utils/traceid"
)
// RecordHTTP 同时记录文件日志和数据库日志(用于系统级错误,默认 error 级别)
// 使用场景:数据库操作失败、系统服务异常、关键业务逻辑错误等
//
// 示例:
//
// if err != nil {
// errorlog.RecordHTTP(ctx, "auth", "Failed to save admin profile", map[string]any{
// "error": err.Error(),
// "admin_id": admin.ID,
// }, "Save admin profile error: %v", err)
// return response.Error(ctx, http.StatusInternalServerError, "update_failed")
// }
func RecordHTTP(ctx http.Context, module, message string, attributes map[string]any, format string, args ...any) {
RecordHTTPWithLevel(ctx, "error", module, message, attributes, format, args...)
}
// RecordHTTPWithLevel 同时记录文件日志和数据库日志(可指定日志级别)
// level: 日志级别,支持 "error", "warning", "info", "debug"
// 使用场景:需要记录不同级别的系统日志
//
// 示例:
//
// // 记录警告
// errorlog.RecordHTTPWithLevel(ctx, "warning", "auth", "Unusual login pattern detected", map[string]any{
// "admin_id": admin.ID,
// "ip": ctx.Request().Ip(),
// }, "Unusual login pattern: %s", pattern)
//
// // 记录信息
// errorlog.RecordHTTPWithLevel(ctx, "info", "payment", "Payment processed successfully", map[string]any{
// "order_id": order.ID,
// "amount": order.Amount,
// }, "Payment processed: %d", order.ID)
func RecordHTTPWithLevel(ctx http.Context, level, module, message string, attributes map[string]any, format string, args ...any) {
// 清理和验证日志级别,防止伪造
level = sanitizeLogLevel(level)
// 清理日志格式字符串,防止格式字符串注入
sanitizedFormat := sanitizeLogFormat(format)
// 清理日志参数,防止注入攻击
sanitizedArgs := sanitizeLogArgs(args...)
// 根据级别选择不同的日志函数
// 注意:logger 包目前只有 ErrorfHTTP,所以所有级别都使用它
// 但在日志消息中添加级别前缀,便于区分和过滤
levelPrefix := "[" + strings.ToUpper(level) + "] "
formattedMessage := levelPrefix + fmt.Sprintf(sanitizedFormat, sanitizedArgs...)
switch level {
case "error":
logger.ErrorfHTTP(ctx, formattedMessage)
case "warning":
logger.ErrorfHTTP(ctx, formattedMessage)
case "info", "debug":
logger.ErrorfHTTP(ctx, formattedMessage)
default:
logger.ErrorfHTTP(ctx, formattedMessage)
}
// 记录到数据库(所有级别都记录 trace_id)
// 清理 message 和 attributes,防止注入
if ctx != nil {
sanitizedMessage := sanitizeLogString(message, 500) // 限制消息长度
sanitizedAttributes := sanitizeAttributes(attributes)
recordToDatabaseHTTPWithLevel(ctx, level, module, sanitizedMessage, sanitizedAttributes)
}
}
// Record 同时记录文件日志和数据库日志(用于标准 context,默认 error 级别)
// 使用场景:goroutine、后台任务等
//
// 示例:
//
// go func(ctx context.Context) {
// if err != nil {
// errorlog.Record(ctx, "operation-log", "Failed to create operation log", map[string]any{
// "error": err.Error(),
// }, "Create operation log error: %v", err)
// }
// }(traceCtx)
func Record(ctx context.Context, module, message string, attributes map[string]any, format string, args ...any) {
RecordWithLevel(ctx, "error", module, message, attributes, format, args...)
}
// RecordWithLevel 同时记录文件日志和数据库日志(可指定日志级别,用于标准 context)
// level: 日志级别,支持 "error", "warning", "info", "debug"
// 使用场景:goroutine、后台任务中需要记录不同级别的日志
//
// 示例:
//
// go func(ctx context.Context) {
// errorlog.RecordWithLevel(ctx, "info", "background-task", "Task completed", map[string]any{
// "task_id": taskID,
// }, "Background task completed: %s", taskID)
// }(traceCtx)
func RecordWithLevel(ctx context.Context, level, module, message string, attributes map[string]any, format string, args ...any) {
// 清理和验证日志级别,防止伪造
level = sanitizeLogLevel(level)
// 清理日志格式字符串,防止格式字符串注入
sanitizedFormat := sanitizeLogFormat(format)
// 清理日志参数,防止注入攻击
sanitizedArgs := sanitizeLogArgs(args...)
// 根据级别选择不同的日志函数
// 注意:logger 包目前只有 ErrorfContext,所以所有级别都使用它
// 但在日志消息中添加级别前缀,便于区分和过滤
levelPrefix := "[" + strings.ToUpper(level) + "] "
formattedMessage := levelPrefix + fmt.Sprintf(sanitizedFormat, sanitizedArgs...)
switch level {
case "error":
logger.ErrorfContext(ctx, formattedMessage)
case "warning", "info", "debug":
logger.ErrorfContext(ctx, formattedMessage)
default:
logger.ErrorfContext(ctx, formattedMessage)
}
// 记录到数据库(所有级别都记录 trace_id)
// 清理 message 和 attributes,防止注入
if ctx != nil {
sanitizedMessage := sanitizeLogString(message, 1000) // 限制消息长度
sanitizedAttributes := sanitizeAttributes(attributes)
recordToDatabaseWithLevel(ctx, level, module, sanitizedMessage, sanitizedAttributes)
}
}
// recordToDatabaseHTTPWithLevel 将日志记录到数据库(HTTP context,支持所有级别)
func recordToDatabaseHTTPWithLevel(ctx http.Context, level, module, message string, attributes map[string]any) {
var contextJSON string
if len(attributes) > 0 {
if data, err := json.Marshal(attributes); err == nil {
contextJSON = string(data)
}
}
traceID := traceid.FromHTTPContext(ctx)
if traceID == "" {
traceID = traceid.EnsureHTTPContext(ctx, "")
}
log := models.SystemLog{
Level: level,
Module: module,
TraceID: traceID,
Message: message,
Context: contextJSON,
IP: ctx.Request().Ip(),
UserAgent: ctx.Request().Header("User-Agent", ""),
}
_ = facades.Orm().Query().Create(&log)
}
// recordToDatabaseWithLevel 将日志记录到数据库(标准 context,支持所有级别)
func recordToDatabaseWithLevel(ctx context.Context, level, module, message string, attributes map[string]any) {
if ctx == nil {
ctx = context.Background()
}
var contextJSON string
if len(attributes) > 0 {
if data, err := json.Marshal(attributes); err == nil {
contextJSON = string(data)
}
}
traceID := traceid.FromContext(ctx)
if traceID == "" {
var newCtx context.Context
newCtx, traceID = traceid.EnsureContext(ctx)
ctx = newCtx
}
log := models.SystemLog{
Level: level,
Module: module,
TraceID: traceID,
Message: message,
Context: contextJSON,
}
_ = facades.Orm().Query().Create(&log)
}
+182
View File
@@ -0,0 +1,182 @@
package errorlog
import (
"strings"
"unicode"
)
// sanitizeLogString 清理日志字符串,防止日志注入攻击
// 移除或转义危险字符:
// - 换行符 (\n, \r) - 可能用于伪造多行日志
// - 制表符 (\t) - 可能用于格式化攻击
// - 控制字符 - 可能用于隐藏攻击痕迹
//
// 参数:
// - s: 要清理的字符串
// - maxLength: 最大长度限制(0 表示不限制)
//
// 返回:
// - 清理后的字符串
func sanitizeLogString(s string, maxLength int) string {
if s == "" {
return s
}
// 移除控制字符和换行符
var builder strings.Builder
builder.Grow(len(s))
for _, r := range s {
// 允许打印字符和空格,移除控制字符
if unicode.IsPrint(r) || r == ' ' {
builder.WriteRune(r)
} else {
// 将控制字符替换为转义序列(可选,或直接移除)
// 这里选择移除,因为转义序列也可能被利用
}
}
result := builder.String()
// 限制长度
if maxLength > 0 && len(result) > maxLength {
result = result[:maxLength] + "...[truncated]"
}
return result
}
// sanitizeLogLevel 验证和清理日志级别,防止伪造
// 只允许预定义的日志级别
func sanitizeLogLevel(level string) string {
level = strings.ToLower(strings.TrimSpace(level))
// 只允许预定义的级别
allowedLevels := map[string]string{
"error": "error",
"warning": "warning",
"warn": "warning", // 兼容 warn
"info": "info",
"debug": "debug",
}
if validLevel, ok := allowedLevels[level]; ok {
return validLevel
}
// 如果级别无效,默认返回 error(最安全的级别)
return "error"
}
// sanitizeLogArgs 清理日志参数,防止注入攻击
// 将参数转换为安全的字符串表示
func sanitizeLogArgs(args ...any) []any {
if len(args) == 0 {
return args
}
sanitized := make([]any, len(args))
for i, arg := range args {
switch v := arg.(type) {
case string:
// 字符串参数:清理控制字符
sanitized[i] = sanitizeLogString(v, 0)
case []byte:
// 字节数组:转换为字符串后清理
sanitized[i] = sanitizeLogString(string(v), 0)
case error:
// 错误对象:清理错误消息
if v != nil {
sanitized[i] = sanitizeLogString(v.Error(), 0)
} else {
sanitized[i] = "<nil>"
}
default:
// 其他类型:保持原样(数字、布尔等通常是安全的)
sanitized[i] = arg
}
}
return sanitized
}
// sanitizeLogFormat 清理日志格式字符串,防止格式字符串注入
// 移除危险的控制字符,但保留格式占位符(%v, %s 等)
func sanitizeLogFormat(format string) string {
if format == "" {
return format
}
// 移除换行符和控制字符,但保留格式占位符
var builder strings.Builder
builder.Grow(len(format))
for i := 0; i < len(format); i++ {
r := rune(format[i])
// 允许打印字符、空格和格式占位符
if unicode.IsPrint(r) || r == ' ' {
builder.WriteRune(r)
} else if r == '\n' || r == '\r' || r == '\t' {
// 将换行符和制表符替换为空格
builder.WriteRune(' ')
}
// 其他控制字符直接移除
}
result := builder.String()
// 限制格式字符串长度(防止过长的格式字符串攻击)
maxFormatLength := 1000
if len(result) > maxFormatLength {
result = result[:maxFormatLength] + "...[truncated]"
}
return result
}
// sanitizeAttributes 清理 attributes map,防止注入攻击
func sanitizeAttributes(attributes map[string]any) map[string]any {
if attributes == nil || len(attributes) == 0 {
return attributes
}
sanitized := make(map[string]any, len(attributes))
for key, value := range attributes {
// 清理 key
sanitizedKey := sanitizeLogString(key, 100)
// 清理 value
switch v := value.(type) {
case string:
sanitized[sanitizedKey] = sanitizeLogString(v, 1000)
case []byte:
sanitized[sanitizedKey] = sanitizeLogString(string(v), 1000)
case error:
if v != nil {
sanitized[sanitizedKey] = sanitizeLogString(v.Error(), 1000)
} else {
sanitized[sanitizedKey] = "<nil>"
}
case map[string]any:
// 递归清理嵌套 map
sanitized[sanitizedKey] = sanitizeAttributes(v)
case []any:
// 清理数组
sanitizedArray := make([]any, len(v))
for i, item := range v {
if str, ok := item.(string); ok {
sanitizedArray[i] = sanitizeLogString(str, 500)
} else {
sanitizedArray[i] = item
}
}
sanitized[sanitizedKey] = sanitizedArray
default:
// 其他类型(数字、布尔等)保持原样
sanitized[sanitizedKey] = value
}
}
return sanitized
}
+83
View File
@@ -0,0 +1,83 @@
package utils
import (
"strings"
"github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/facades"
)
// ApplyFulltextSearch 应用全文索引搜索条件
// column: 要搜索的字段名(如 "request", "content" 等)
// keyword: 搜索关键词
// query: ORM 查询对象
// 返回: 应用了搜索条件的查询对象
func ApplyFulltextSearch(query orm.Query, column, keyword string) orm.Query {
if keyword == "" {
return query
}
// 获取数据库类型
dbConnection := facades.Config().GetString("database.default", "sqlite")
isPostgreSQL := dbConnection == "postgres"
// 判断是否应该使用全文索引
// 对于短词(少于3个字符)或包含特殊字符的搜索,使用 LIKE/ILIKE
// 对于长词,使用全文索引
useFulltext := len(keyword) >= 3
// 检查是否包含特殊字符(逗号、引号、括号等)
specialChars := []string{",", "\"", "'", "(", ")", "[", "]", "{", "}", ":", ";", "=", "+", "-", "*", "/", "\\", "|", "&", "%", "$", "#", "@", "!", "?", "<", ">", "~", "`"}
for _, char := range specialChars {
if strings.Contains(keyword, char) {
useFulltext = false
break
}
}
if useFulltext {
if isPostgreSQL {
// PostgreSQL: 使用 pg_trgm 相似度搜索(需要已创建 GIN 索引)
// 使用 % 操作符进行相似度匹配,阈值默认 0.3
return query.Where(column+" % ?", keyword)
} else {
// MySQL: 使用 ngram 全文索引
// 注意:需要确保字段已创建全文索引,索引名格式为 ft_{column}
return query.Where("MATCH("+column+") AGAINST(? IN BOOLEAN MODE)", keyword)
}
} else {
// 短词或包含特殊字符:使用 LIKE/ILIKE
if isPostgreSQL {
// PostgreSQL: 使用 ILIKE(不区分大小写)
return query.Where(column+" ILIKE ?", "%"+keyword+"%")
} else {
// MySQL: 使用 LIKE
return query.Where(column+" LIKE ?", "%"+keyword+"%")
}
}
}
// ShouldUseFulltextIndex 判断是否应该使用全文索引
// keyword: 搜索关键词
// 返回: true 表示应该使用全文索引,false 表示使用 LIKE/ILIKE
func ShouldUseFulltextIndex(keyword string) bool {
if len(keyword) < 3 {
return false
}
// 检查是否包含特殊字符
specialChars := []string{",", "\"", "'", "(", ")", "[", "]", "{", "}", ":", ";", "=", "+", "-", "*", "/", "\\", "|", "&", "%", "$", "#", "@", "!", "?", "<", ">", "~", "`"}
for _, char := range specialChars {
if strings.Contains(keyword, char) {
return false
}
}
return true
}
// IsPostgreSQL 判断当前数据库是否为 PostgreSQL
func IsPostgreSQL() bool {
dbConnection := facades.Config().GetString("database.default", "sqlite")
return dbConnection == "postgres"
}
+71
View File
@@ -0,0 +1,71 @@
package utils
import (
"fmt"
"reflect"
"sync"
"github.com/goravel/framework/facades"
"gorm.io/gorm"
)
var (
gormDBInstance *gorm.DB
gormDBOnce sync.Once
gormDBErr error
)
// GetGormDB 获取原生 GORM DB 实例
// 通过框架的 Query().Instance() 方法获取 GORM DB 实例
// 使用单例模式,确保只获取一次
func GetGormDB() (*gorm.DB, error) {
gormDBOnce.Do(func() {
gormDBInstance, gormDBErr = tryGetGormFromFramework()
if gormDBErr == nil && gormDBInstance != nil {
facades.Log().Infof("成功从框架获取 GORM DB 实例")
}
})
return gormDBInstance, gormDBErr
}
// tryGetGormFromFramework 尝试从框架的 ORM 获取底层 GORM DB 实例
// 通过 Query().Instance() 方法获取 GORM DB 实例
func tryGetGormFromFramework() (*gorm.DB, error) {
orm := facades.Orm()
if orm == nil {
return nil, fmt.Errorf("框架 ORM 不可用")
}
ormValue := reflect.ValueOf(orm)
// 通过 Query() 方法获取查询对象,然后调用 Instance() 方法获取 GORM DB
if queryMethod := ormValue.MethodByName("Query"); queryMethod.IsValid() {
queryResults := queryMethod.Call(nil)
if len(queryResults) > 0 {
queryValue := queryResults[0]
if queryValue.Kind() == reflect.Interface && !queryValue.IsNil() {
if queryElem := queryValue.Elem(); queryElem.IsValid() {
// 调用 Instance() 方法
if instanceMethod := queryElem.MethodByName("Instance"); instanceMethod.IsValid() {
methodType := instanceMethod.Type()
if methodType.NumIn() == 0 && methodType.NumOut() > 0 {
results := instanceMethod.Call(nil)
if len(results) > 0 {
result := results[0]
// 检查返回类型是否为 *gorm.DB
if result.Type().String() == "*gorm.DB" && !result.IsNil() {
if db, ok := result.Interface().(*gorm.DB); ok && db != nil {
facades.Log().Infof("成功通过 Instance() 方法从 Query 对象获取 GORM DB 实例")
return db, nil
}
}
}
}
}
}
}
}
}
return nil, fmt.Errorf("无法从框架获取 GORM DB 实例")
}
+18
View File
@@ -0,0 +1,18 @@
package utils
import "html"
// EscapeString 转义 HTML 特殊字符,防止 XSS 攻击
// 将 < > & " ' 等字符转换为 HTML 实体
//
// 示例:
// EscapeString("<script>alert('XSS')</script>")
// 返回: "&lt;script&gt;alert(&#39;XSS&#39;)&lt;/script&gt;"
func EscapeString(s string) string {
return html.EscapeString(s)
}
// EscapeBytes 转义字节数组中的 HTML 特殊字符
func EscapeBytes(b []byte) []byte {
return []byte(html.EscapeString(string(b)))
}
+201
View File
@@ -0,0 +1,201 @@
package utils
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// IPLocation IP 地理位置信息
type IPLocation struct {
Country string `json:"country"` // 国家
Region string `json:"region"` // 省份/州
City string `json:"city"` // 城市
ISP string `json:"isp"` // ISP
CountryCode string `json:"countryCode"` // 国家代码
}
// GetIPLocation 根据 IP 地址获取地理位置信息
// 使用 ip-api.com 免费 API(无需 API Key,有速率限制)
// 如果查询失败,返回空字符串,不影响主流程
func GetIPLocation(ip string) string {
if ip == "" || ip == "127.0.0.1" || ip == "::1" || strings.HasPrefix(ip, "192.168.") || strings.HasPrefix(ip, "10.") || strings.HasPrefix(ip, "172.") {
return "内网IP"
}
// 使用 ip-api.com 免费 API
// 格式:http://ip-api.com/json/{ip}?fields=status,message,country,regionName,city,isp,countryCode
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=status,message,country,regionName,city,isp,countryCode&lang=zh-CN", ip)
client := &http.Client{
Timeout: 3 * time.Second, // 3秒超时,避免阻塞
}
resp, err := client.Get(url)
if err != nil {
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return ""
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return ""
}
var result struct {
Status string `json:"status"`
Message string `json:"message"`
Country string `json:"country"`
RegionName string `json:"regionName"`
City string `json:"city"`
ISP string `json:"isp"`
CountryCode string `json:"countryCode"`
}
if err := json.Unmarshal(body, &result); err != nil {
return ""
}
if result.Status != "success" {
return ""
}
// 构建位置字符串:国家 省份 城市
var locationParts []string
if result.Country != "" {
locationParts = append(locationParts, result.Country)
}
if result.RegionName != "" {
locationParts = append(locationParts, result.RegionName)
}
if result.City != "" {
locationParts = append(locationParts, result.City)
}
location := strings.Join(locationParts, " ")
if location == "" {
return ""
}
// 如果位置信息太长,截断
if len(location) > 100 {
location = location[:100]
}
return location
}
// GetIPLocationAsync 异步获取 IP 地理位置信息
// 用于不阻塞主流程的场景
//
// 安全特性:
// - 使用 context 超时控制(默认 10 秒)
// - 添加 panic recovery,防止 goroutine 崩溃
// - callback 执行超时保护(默认 5 秒),防止阻塞
//
// 参数:
// - ip: IP 地址
// - callback: 回调函数,接收位置信息
func GetIPLocationAsync(ip string, callback func(location string)) {
GetIPLocationAsyncWithTimeout(ip, callback, 10*time.Second, 5*time.Second)
}
// GetIPLocationAsyncWithTimeout 异步获取 IP 地理位置信息(带超时控制)
//
// 参数:
// - ip: IP 地址
// - callback: 回调函数,接收位置信息
// - locationTimeout: IP 查询超时时间(默认 10 秒)
// - callbackTimeout: 回调函数执行超时时间(默认 5 秒)
func GetIPLocationAsyncWithTimeout(ip string, callback func(location string), locationTimeout, callbackTimeout time.Duration) {
// 设置默认超时时间
if locationTimeout <= 0 {
locationTimeout = 10 * time.Second
}
if callbackTimeout <= 0 {
callbackTimeout = 5 * time.Second
}
go func() {
// 添加 panic recovery,防止 goroutine 崩溃导致程序退出
defer func() {
if r := recover(); r != nil {
// 记录 panic 但不影响主流程
// 注意:这里不能使用 facades.Log(),因为可能没有初始化 context
// 如果需要记录日志,可以通过参数传入 logger
}
}()
// 创建带超时的 context(确保总超时时间不超过 locationTimeout
// 注意:GetIPLocation 内部已有 3 秒 HTTP 超时,这里设置外层超时作为兜底
ctx, cancel := context.WithTimeout(context.Background(), locationTimeout)
defer cancel()
// 在 goroutine 中执行 IP 查询
locationChan := make(chan string, 1)
go func() {
defer func() {
if r := recover(); r != nil {
// panic 时发送空字符串,避免阻塞
select {
case locationChan <- "":
default:
}
}
}()
location := GetIPLocation(ip)
// 非阻塞发送,避免 goroutine 泄露
select {
case locationChan <- location:
case <-ctx.Done():
// 如果已经超时,直接返回,不发送结果
return
}
}()
// 等待结果或超时
var location string
select {
case location = <-locationChan:
// 成功获取位置信息
case <-ctx.Done():
// 超时,返回空字符串
// 注意:内部的 goroutine 可能仍在执行,但由于 HTTP client 有 3 秒超时,会很快结束
location = ""
}
// 执行回调(带超时保护)
if callback != nil {
callbackCtx, callbackCancel := context.WithTimeout(context.Background(), callbackTimeout)
defer callbackCancel()
callbackDone := make(chan struct{}, 1)
go func() {
defer func() {
if r := recover(); r != nil {
// callback 中发生 panic,静默处理
}
callbackDone <- struct{}{}
}()
callback(location)
}()
// 等待 callback 完成或超时
select {
case <-callbackDone:
// callback 正常完成
case <-callbackCtx.Done():
// callback 执行超时,不等待(防止阻塞)
// 注意:callback 仍在执行,但不会阻塞当前 goroutine
}
}
}()
}
+163
View File
@@ -0,0 +1,163 @@
package utils
import (
"fmt"
"net"
"strings"
"github.com/goravel/framework/support/str"
apperrors "goravel/app/errors"
)
// IsIPInBlacklist 检查IP是否在黑名单中
// ip: 要检查的IP地址
// blacklistIPs: 黑名单IP字符串,支持:
// - 单个IP: 192.168.1.1
// - 多个IP(逗号分隔): 192.168.1.1,192.168.1.2
// - CIDR格式: 192.168.0.0/24
// - IP范围: 192.168.1.1-192.168.1.100
func IsIPInBlacklist(ip string, blacklistIPs string) bool {
if blacklistIPs == "" {
return false
}
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 分割多个IP(支持逗号分隔)
ipList := strings.SplitSeq(blacklistIPs, ",")
for blacklistIP := range ipList {
blacklistIP = str.Of(blacklistIP).Trim().String()
if str.Of(blacklistIP).IsEmpty() {
continue
}
// 检查单个IP
if blacklistIP == ip {
return true
}
// 检查CIDR格式 (192.168.0.0/24)
if str.Of(blacklistIP).Contains("/") {
_, ipNet, err := net.ParseCIDR(blacklistIP)
if err == nil && ipNet.Contains(parsedIP) {
return true
}
}
// 检查IP范围格式 (192.168.1.1-192.168.1.100)
if str.Of(blacklistIP).Contains("-") {
parts := str.Of(blacklistIP).Split("-")
if len(parts) == 2 {
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
if startIP != nil && endIP != nil {
if isIPInRange(parsedIP, startIP, endIP) {
return true
}
}
}
}
}
return false
}
// isIPInRange 检查IP是否在指定范围内
func isIPInRange(ip, startIP, endIP net.IP) bool {
ipBytes := ip.To4()
startBytes := startIP.To4()
endBytes := endIP.To4()
if ipBytes == nil || startBytes == nil || endBytes == nil {
return false
}
// 将IP地址转换为32位整数进行比较
ipInt := uint32(ipBytes[0])<<24 | uint32(ipBytes[1])<<16 | uint32(ipBytes[2])<<8 | uint32(ipBytes[3])
startInt := uint32(startBytes[0])<<24 | uint32(startBytes[1])<<16 | uint32(startBytes[2])<<8 | uint32(startBytes[3])
endInt := uint32(endBytes[0])<<24 | uint32(endBytes[1])<<16 | uint32(endBytes[2])<<8 | uint32(endBytes[3])
return ipInt >= startInt && ipInt <= endInt
}
// ValidateBlacklistIP 验证黑名单IP格式
// 返回业务错误类型,如果格式正确返回 nil
func ValidateBlacklistIP(ipStr string) error {
if ipStr == "" {
return apperrors.ErrIPAddressRequired
}
ipList := strings.SplitSeq(ipStr, ",")
for ip := range ipList {
ip = str.Of(ip).Trim().String()
if str.Of(ip).IsEmpty() {
continue
}
// 检查CIDR格式
if str.Of(ip).Contains("/") {
_, _, err := net.ParseCIDR(ip)
if err != nil {
return apperrors.ErrInvalidCIDRFormat.WithMessage(fmt.Sprintf("CIDR格式错误: %s", ip))
}
continue
}
// 检查IP范围格式
if str.Of(ip).Contains("-") {
parts := str.Of(ip).Split("-")
if len(parts) != 2 {
return apperrors.ErrInvalidIPRangeFormat.WithMessage(fmt.Sprintf("IP范围格式错误: %s (格式应为: 192.168.1.1-192.168.1.100)", ip))
}
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
if startIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("起始IP格式错误: %s", parts[0]))
}
if endIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("结束IP格式错误: %s", parts[1]))
}
// 验证范围是否有效(结束IP应该大于等于起始IP)
if !isIPInRange(endIP, startIP, endIP) && !endIP.Equal(startIP) {
// 简单比较:检查结束IP是否大于起始IP
startBytes := startIP.To4()
endBytes := endIP.To4()
if startBytes != nil && endBytes != nil {
for i := range 4 {
if endBytes[i] < startBytes[i] {
return apperrors.ErrInvalidIPRangeOrder.WithMessage(fmt.Sprintf("结束IP必须大于等于起始IP: %s", ip))
}
if endBytes[i] > startBytes[i] {
break
}
}
}
}
continue
}
// 检查单个IP格式
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("IP地址格式错误: %s", ip))
}
}
return nil
}
// FormatIPRange 格式化IP范围显示
func FormatIPRange(ipStr string) string {
if str.Of(ipStr).Contains("-") {
parts := str.Of(ipStr).Split("-")
if len(parts) == 2 {
return fmt.Sprintf("%s ~ %s", str.Of(parts[0]).Trim().String(), str.Of(parts[1]).Trim().String())
}
}
return ipStr
}
+76
View File
@@ -0,0 +1,76 @@
package utils
import (
"strings"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
)
// ParseAcceptLanguage 解析 Accept-Language 请求头
// 格式: "zh-CN,zh;q=0.9,en;q=0.8" 或 "en-US,en;q=0.9"
// 返回: 语言代码("cn" 或 "en"),如果无法解析则返回空字符串
func ParseAcceptLanguage(acceptLanguage string) string {
if acceptLanguage == "" {
return ""
}
// 分割语言列表
languages := strings.Split(acceptLanguage, ",")
if len(languages) == 0 {
return ""
}
// 取第一个语言
firstLang := strings.TrimSpace(languages[0])
// 移除质量值(如果有)
if idx := strings.Index(firstLang, ";"); idx != -1 {
firstLang = firstLang[:idx]
}
// 转换为小写并提取语言代码
firstLang = strings.ToLower(strings.TrimSpace(firstLang))
// 处理语言代码(如 zh-CN -> cn, en-US -> en
if strings.HasPrefix(firstLang, "zh") {
return "cn"
}
if strings.HasPrefix(firstLang, "en") {
return "en"
}
return ""
}
// GetCurrentLanguage 获取当前请求的语言(从 HTTP Context
// 优先从请求头 Accept-Language 获取,其次从查询参数获取,最后使用默认语言
func GetCurrentLanguage(ctx http.Context) string {
// 优先从请求头 Accept-Language 获取语言
acceptLanguage := ctx.Request().Header("Accept-Language", "")
lang := ParseAcceptLanguage(acceptLanguage)
// 如果请求头没有,尝试从查询参数获取
if lang == "" {
lang = ctx.Request().Input("lang")
}
// 如果都没有,使用默认语言
if lang == "" {
lang = facades.Config().GetString("app.locale")
}
// 验证并规范化语言代码
return NormalizeLanguage(lang)
}
// NormalizeLanguage 验证并规范化语言代码
// 只支持 "cn" 和 "en",其他值会返回默认语言
func NormalizeLanguage(lang string) string {
// 验证语言是否支持(只支持 cn 和 en)
if lang != "cn" && lang != "en" {
return facades.Config().GetString("app.locale", "cn")
}
return lang
}
+205
View File
@@ -0,0 +1,205 @@
package utils
import (
"context"
"fmt"
"time"
"github.com/goravel/framework/facades"
"github.com/redis/go-redis/v9"
)
// LockResult 锁操作结果
type LockResult struct {
Acquired bool // 是否成功获取锁
Error error // 错误信息
Client *redis.Client // Redis 客户端(如果使用 Redis
}
// TryAcquireLock 尝试获取分布式锁(原子操作)
// lockKey: 锁的键名
// lockValue: 锁的值(用于标识锁的拥有者)
// ttl: 锁的过期时间
// 返回 LockResult,包含是否成功获取锁和错误信息
func TryAcquireLock(lockKey, lockValue string, ttl time.Duration) *LockResult {
result := &LockResult{
Acquired: false,
}
// 优先使用 Redis SETNX 原子操作
redisClient, err := GetRedisClient("default")
if err != nil {
facades.Log().Warningf("获取 Redis 客户端失败,降级到缓存锁: %v", err)
// 降级到普通缓存检查(不保证原子性,但至少提供基本保护)
return tryAcquireLockWithCache(lockKey, lockValue, ttl)
}
// 使用 Redis SETNX 原子操作(SET key value NX EX seconds
// SetNX 会自动设置过期时间,如果键已存在且未过期,返回 false
ctx := context.Background()
acquired, err := redisClient.SetNX(ctx, lockKey, lockValue, ttl).Result()
if err != nil {
facades.Log().Errorf("Redis 获取锁失败: key=%s, error=%v", lockKey, err)
result.Error = fmt.Errorf("获取锁失败: %v", err)
// 注意:使用公共 Redis 客户端,不需要手动关闭
return result
}
if !acquired {
// 锁已被占用,检查锁的剩余过期时间
ttlResult, _ := redisClient.TTL(ctx, lockKey).Result()
if ttlResult <= 0 {
// 锁已过期,删除它并重试
redisClient.Del(ctx, lockKey)
// 重试一次
acquired, err = redisClient.SetNX(ctx, lockKey, lockValue, ttl).Result()
if err != nil {
facades.Log().Errorf("Redis 重试获取锁失败: key=%s, error=%v", lockKey, err)
result.Error = fmt.Errorf("获取锁失败: %v", err)
// 注意:使用公共 Redis 客户端,不需要手动关闭
return result
}
}
}
result.Acquired = acquired
if acquired {
result.Client = redisClient // 只有获取成功才保留客户端
} else {
result.Client = nil
}
return result
}
// ReleaseLock 释放锁
// 注意:只有锁的拥有者才能释放锁(通过 lockValue 验证)
func ReleaseLock(lockKey, lockValue string, client *redis.Client) error {
if client == nil {
// 如果没有 Redis 客户端,使用缓存删除
_ = facades.Cache().Forget(lockKey)
return nil
}
ctx := context.Background()
// 使用 Lua 脚本确保只有锁的拥有者才能释放锁
script := `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`
_, err := client.Eval(ctx, script, []string{lockKey}, lockValue).Result()
if err != nil {
// 如果 Lua 脚本执行失败,尝试直接删除(降级处理)
return client.Del(ctx, lockKey).Err()
}
return nil
}
// CloseLockClient 关闭锁的 Redis 客户端
// 注意:由于现在使用公共 Redis 客户端池,通常不需要手动关闭
// 此函数保留用于兼容性,但不会真正关闭客户端(客户端由连接池管理)
func CloseLockClient(client *redis.Client) {
// 使用公共 Redis 客户端池,不需要手动关闭
// 客户端由连接池统一管理
}
// tryAcquireLockWithCache 使用缓存实现锁(降级方案,不保证原子性)
// 注意:由于缓存操作的检查-设置不是原子的,在高并发下可能仍有竞态条件
// 但至少可以提供基本的保护
func tryAcquireLockWithCache(lockKey, lockValue string, ttl time.Duration) *LockResult {
result := &LockResult{
Acquired: false,
}
// 检查是否已有锁
var cachedValue string
cacheErr := facades.Cache().Get(lockKey, &cachedValue)
// 如果 Get 返回 nil(成功)且值不为空,说明锁已存在
if cacheErr == nil && cachedValue != "" {
// 锁已存在
result.Error = fmt.Errorf("锁已被占用")
return result
}
// 设置锁(使用 Put,如果键已存在会覆盖,但至少可以防止大部分重复请求)
if err := facades.Cache().Put(lockKey, lockValue, ttl); err != nil {
facades.Log().Errorf("设置缓存锁失败: key=%s, error=%v", lockKey, err)
result.Error = fmt.Errorf("设置锁失败: %v", err)
return result
}
// 再次检查,确保锁设置成功(双重检查,减少竞态条件的影响)
// 短暂延迟,让其他并发请求有机会检测到锁
time.Sleep(50 * time.Millisecond)
var verifyValue string
if facades.Cache().Get(lockKey, &verifyValue) == nil {
if verifyValue == lockValue {
// 锁设置成功且值匹配
result.Acquired = true
return result
}
// 值不匹配,说明被其他请求覆盖了(竞态条件)
result.Error = fmt.Errorf("锁已被占用")
return result
}
// 锁设置后无法验证,可能是缓存问题
result.Error = fmt.Errorf("设置锁失败")
return result
}
// LockGuard 锁保护器,自动管理锁的生命周期
type LockGuard struct {
lockKey string
lockValue string
client *redis.Client
acquired bool
}
// AcquireLock 获取锁
// lockKey: 锁的键名(会自动添加用户ID前缀,格式:lockKey:userID
// userID: 用户ID(用于生成唯一的锁键)
// ttl: 锁的过期时间
// 返回 LockGuard 和错误,如果锁已被占用,返回 ErrLockAcquired 错误
func AcquireLock(lockKey string, userID uint, ttl time.Duration) (*LockGuard, error) {
// 生成完整的锁键和值
fullLockKey := fmt.Sprintf("%s:%d", lockKey, userID)
lockValue := fmt.Sprintf("%d_%d", userID, time.Now().Unix())
// 尝试获取锁
lockResult := TryAcquireLock(fullLockKey, lockValue, ttl)
if lockResult.Error != nil {
return nil, lockResult.Error
}
if !lockResult.Acquired {
return nil, fmt.Errorf("锁已被占用")
}
return &LockGuard{
lockKey: fullLockKey,
lockValue: lockValue,
client: lockResult.Client,
acquired: true,
}, nil
}
// Release 释放锁
func (g *LockGuard) Release() {
if !g.acquired {
return
}
if g.client != nil {
if err := ReleaseLock(g.lockKey, g.lockValue, g.client); err != nil {
facades.Log().Errorf("释放 Redis 锁失败: key=%s, error=%v", g.lockKey, err)
}
CloseLockClient(g.client)
g.client = nil // 防止重复关闭
} else {
// 使用缓存时,直接删除
_ = facades.Cache().Forget(g.lockKey)
}
g.acquired = false
}
+85
View File
@@ -0,0 +1,85 @@
package logger
import (
"context"
"fmt"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
"goravel/app/utils/traceid"
)
// DebugfHTTP logs a debug message and automatically attaches trace_id from the http context.
// Debug messages are only shown when APP_DEBUG=true
func DebugfHTTP(ctx http.Context, format string, args ...any) {
if ctx == nil {
facades.Log().Debugf(format, args...)
return
}
trace := traceid.FromHTTPContext(ctx)
facades.Log().Debugf(prependTrace(trace, format), args...)
}
// Debugf logs a debug message without any context.
// Debug messages are only shown when APP_DEBUG=true
func Debugf(format string, args ...any) {
facades.Log().Debugf(format, args...)
}
// InfofHTTP logs an info message and automatically attaches trace_id from the http context.
func InfofHTTP(ctx http.Context, format string, args ...any) {
if ctx == nil {
facades.Log().Infof(format, args...)
return
}
trace := traceid.FromHTTPContext(ctx)
facades.Log().Infof(prependTrace(trace, format), args...)
}
// WarnfHTTP logs a warning and automatically attaches trace_id from the http context.
func WarnfHTTP(ctx http.Context, format string, args ...any) {
if ctx == nil {
facades.Log().Warningf(format, args...)
return
}
trace := traceid.FromHTTPContext(ctx)
facades.Log().Warningf(prependTrace(trace, format), args...)
}
// ErrorfHTTP logs an error and automatically attaches trace_id from the http context.
func ErrorfHTTP(ctx http.Context, format string, args ...any) {
if ctx == nil {
facades.Log().Errorf(format, args...)
return
}
trace := traceid.FromHTTPContext(ctx)
facades.Log().Errorf(prependTrace(trace, format), args...)
}
// ErrorfContext logs an error with a standard context's trace id (if available).
func ErrorfContext(ctx context.Context, format string, args ...any) {
if ctx == nil {
facades.Log().Errorf(format, args...)
return
}
trace := traceid.FromContext(ctx)
facades.Log().Errorf(prependTrace(trace, format), args...)
}
// Errorf logs an error without any context (fallback).
func Errorf(format string, args ...any) {
facades.Log().Errorf(format, args...)
}
func prependTrace(traceID, format string) string {
if traceID == "" {
return format
}
return fmt.Sprintf("[trace_id=%s] %s", traceID, format)
}
+294
View File
@@ -0,0 +1,294 @@
package utils
import (
"strings"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
"github.com/goravel/framework/support/str"
"goravel/app/models"
)
// GetOperationTitleFromContext 从 context 中获取操作标题
// 优先使用权限标识(permission_slug),如果没有权限标识,则根据路径和方法生成默认标题
func GetOperationTitleFromContext(ctx http.Context) string {
if ctx == nil {
return "operation.unknown"
}
// 优先从 context 中获取权限标识(由权限中间件设置)
permissionSlugValue := ctx.Value("permission_slug")
if permissionSlugValue != nil {
if permissionSlug, ok := permissionSlugValue.(string); ok && permissionSlug != "" {
// 直接返回权限标识,前端多语言文件中已有对应翻译
return permissionSlug
}
}
// 如果没有权限标识,尝试从权限表中查询匹配的权限
method := ctx.Request().Method()
path := ctx.Request().Path()
// 先从权限表中查询匹配的权限
permissionSlug := findPermissionSlugFromDB(method, path)
if permissionSlug != "" {
return permissionSlug
}
// 如果权限表中没有找到,根据路径和方法生成默认标题
defaultTitle := generateDefaultTitle(method, path)
if defaultTitle != "" {
return defaultTitle
}
// 无法生成标题时,返回未知操作
return "operation.unknown"
}
// generateDefaultTitle 根据方法和路径生成默认操作标题
func generateDefaultTitle(method, path string) string {
pathStr := str.Of(path)
// 分片上传相关(与权限配置中的 slug 保持一致)
if pathStr.Contains("/attachments/chunk") {
if method == "POST" || method == "GET" {
// 权限配置中的 slug 是 attachment.chunk
return "attachment.chunk"
}
}
// 附件上传
if pathStr.Contains("/attachments/upload") && method == "POST" {
return "attachment.upload"
}
// 附件删除
if pathStr.Contains("/attachments/") && pathStr.EndsWith("/batch-delete") && method == "POST" {
return "attachment.batch_delete"
}
if pathStr.Contains("/attachments/") && method == "DELETE" {
return "attachment.destroy"
}
// 附件更新显示名称
if pathStr.Contains("/attachments/") && pathStr.EndsWith("/display-name") && method == "PUT" {
return "attachment.update_display_name"
}
// 导出下载
if pathStr.Contains("/exports/") && pathStr.EndsWith("/download") && method == "GET" {
return "export.download"
}
// 订单导入
if pathStr.Contains("/orders/import") && method == "POST" {
return "order.import"
}
// 订单导出
if pathStr.Contains("/orders/export") && method == "POST" {
return "order.export"
}
// 管理员解绑谷歌验证码
if pathStr.Contains("/admins/") && pathStr.EndsWith("/unbind-google-auth") && method == "POST" {
return "admin.unbind_google_auth"
}
// 更新个人资料
if pathStr.EndsWith("/profile") && (method == "PUT" || method == "PATCH") {
return "profile.update"
}
// 修改密码
if pathStr.EndsWith("/password") && (method == "PUT" || method == "PATCH") {
return "password.update"
}
// 批量删除(通用模式)
if pathStr.EndsWith("/batch-delete") && method == "POST" {
parts := pathStr.ChopStart("/api/admin/").Split("/")
if len(parts) > 0 {
module := str.Of(parts[0]).Replace("-", "_").String()
return str.Of(module).Append(".batch_delete").String()
}
}
// 清理操作(通用模式)
if pathStr.EndsWith("/clean") && method == "POST" {
parts := pathStr.ChopStart("/api/admin/").Split("/")
if len(parts) > 0 {
module := str.Of(parts[0]).Replace("-", "_").String()
return str.Of(module).Append(".clean").String()
}
}
// 标准 CRUD 操作(通用模式)
parts := pathStr.ChopStart("/api/admin/").Split("/")
if len(parts) >= 1 {
module := str.Of(parts[0]).Replace("-", "_").String()
switch method {
case "POST":
// 创建操作
if len(parts) == 1 || (len(parts) == 2 && parts[1] != "batch-delete" && parts[1] != "clean") {
return str.Of(module).Append(".store").String()
}
case "PUT", "PATCH":
// 更新操作
if len(parts) >= 2 {
return str.Of(module).Append(".update").String()
}
case "DELETE":
// 删除操作
if len(parts) >= 2 {
return str.Of(module).Append(".destroy").String()
}
}
}
return ""
}
// findPermissionSlugFromDB 从权限表中查询匹配的权限标识
// 优先匹配精确路径,然后匹配通配符路径
func findPermissionSlugFromDB(method, path string) string {
var permissions []models.Permission
// 查询所有启用的权限,方法匹配或方法为空
query := facades.Orm().Query().Model(&models.Permission{}).
Where("status", 1).
Where("(method = ? OR method = '')", method)
if err := query.Find(&permissions); err != nil {
// 查询失败时返回空,使用默认逻辑
return ""
}
// 优先匹配精确路径
for _, perm := range permissions {
if perm.Path == path {
return perm.Slug
}
}
// 然后匹配通配符路径
for _, perm := range permissions {
if perm.Path != "" && perm.Path != path {
if matchPermissionPath(perm.Path, path) {
return perm.Slug
}
}
}
return ""
}
// matchPermissionPath 路径匹配,支持通配符(与权限中间件中的 matchPath 逻辑一致)
func matchPermissionPath(pattern, path string) bool {
if pattern == path {
return true
}
// 如果模式不包含通配符,直接返回 false
if !containsChar(pattern, '*') {
return false
}
// 将模式按 * 分割成多个部分
parts := splitPatternString(pattern)
if len(parts) == 0 {
return false
}
// 如果模式以 * 开头,需要特殊处理
if pattern[0] == '*' {
// 检查路径是否以模式的剩余部分结尾
if len(parts) > 1 {
suffix := parts[1]
return len(path) >= len(suffix) && path[len(path)-len(suffix):] == suffix
}
return true
}
// 如果模式以 * 结尾
if pattern[len(pattern)-1] == '*' {
prefix := pattern[:len(pattern)-1]
if len(path) >= len(prefix) {
pathPrefix := path[:len(prefix)]
if pathPrefix == prefix {
// 如果前缀以 / 结尾,路径必须比前缀长(即后面还有内容)
if len(prefix) > 0 && prefix[len(prefix)-1] == '/' {
return len(path) > len(prefix)
}
// 如果前缀不以 / 结尾,路径可以等于或长于前缀
return true
}
}
return false
}
// 处理中间有通配符的情况,如 /api/admin/attachments/*/display-name
// 过滤掉 "*" 标记,只保留实际的部分
var actualParts []string
for _, part := range parts {
if part != "*" {
actualParts = append(actualParts, part)
}
}
if len(actualParts) < 2 {
return false
}
// 检查路径是否以第一部分开头
firstPart := actualParts[0]
if len(path) < len(firstPart) || path[:len(firstPart)] != firstPart {
return false
}
// 检查路径是否以最后一部分结尾
lastPart := actualParts[len(actualParts)-1]
if len(path) < len(lastPart) || path[len(path)-len(lastPart):] != lastPart {
return false
}
// 检查中间部分是否存在(通配符匹配任意内容)
remainingPath := path[len(firstPart) : len(path)-len(lastPart)]
// 确保中间部分不为空(至少有一个字符,通常是数字ID)
return len(remainingPath) > 0
}
// containsChar 检查字符串是否包含指定字符
func containsChar(s string, c byte) bool {
for i := range s {
if s[i] == c {
return true
}
}
return false
}
// splitPatternString 按 * 分割模式字符串
func splitPatternString(pattern string) []string {
var parts []string
var current strings.Builder
for i := range pattern {
if pattern[i] == '*' {
if current.Len() > 0 {
parts = append(parts, current.String())
current.Reset()
}
parts = append(parts, "*")
} else {
current.WriteByte(pattern[i])
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
+139
View File
@@ -0,0 +1,139 @@
package utils
import (
"context"
"fmt"
"sync"
"time"
"github.com/goravel/framework/facades"
"github.com/redis/go-redis/v9"
)
var (
// redisClients 缓存不同连接名的 Redis 客户端
redisClients sync.Map
// redisMutex 用于保护客户端创建过程
redisMutex sync.Mutex
)
// GetRedisClient 获取 Redis 客户端(使用连接池,支持多连接)
// connectionName: Redis 连接名称,默认为 "default"
// 返回缓存的 Redis 客户端,如果不存在则创建并缓存
func GetRedisClient(connectionName string) (*redis.Client, error) {
if connectionName == "" {
connectionName = "default"
}
// 先从缓存中获取
if client, ok := redisClients.Load(connectionName); ok {
redisClient := client.(*redis.Client)
// 测试连接是否有效
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
err := redisClient.Ping(ctx).Err()
cancel()
if err == nil {
return redisClient, nil
}
// 连接失效,从缓存中移除
redisClients.Delete(connectionName)
}
// 使用互斥锁确保只创建一个客户端
redisMutex.Lock()
defer redisMutex.Unlock()
// 双重检查,防止并发创建
if client, ok := redisClients.Load(connectionName); ok {
return client.(*redis.Client), nil
}
// 创建新的 Redis 客户端
client, err := createRedisClient(connectionName)
if err != nil {
return nil, err
}
// 缓存客户端
redisClients.Store(connectionName, client)
return client, nil
}
// createRedisClient 创建 Redis 客户端
func createRedisClient(connectionName string) (*redis.Client, error) {
// 获取 Redis 配置
host := facades.Config().GetString(fmt.Sprintf("database.redis.%s.host", connectionName), "")
if host == "" {
// 尝试使用 default 连接
host = facades.Config().GetString("database.redis.default.host", "127.0.0.1")
}
port := facades.Config().GetInt(fmt.Sprintf("database.redis.%s.port", connectionName), 0)
if port == 0 {
port = facades.Config().GetInt("database.redis.default.port", 6379)
}
password := facades.Config().GetString(fmt.Sprintf("database.redis.%s.password", connectionName), "")
if password == "" {
password = facades.Config().GetString("database.redis.default.password", "")
}
db := facades.Config().GetInt(fmt.Sprintf("database.redis.%s.database", connectionName), -1)
if db == -1 {
db = facades.Config().GetInt("database.redis.default.database", 0)
}
// 创建 Redis 客户端(使用连接池配置)
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", host, port),
Password: password,
DB: db,
PoolSize: 10, // 连接池大小
MinIdleConns: 5, // 最小空闲连接数
MaxRetries: 3, // 最大重试次数
})
// 测试连接(设置超时)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
client.Close() // 连接失败,关闭客户端
return nil, fmt.Errorf("Redis 连接失败 [%s]: %v", connectionName, err)
}
return client, nil
}
// CloseRedisClient 关闭指定的 Redis 客户端并从缓存中移除
func CloseRedisClient(connectionName string) error {
if connectionName == "" {
connectionName = "default"
}
if client, ok := redisClients.LoadAndDelete(connectionName); ok {
redisClient := client.(*redis.Client)
return redisClient.Close()
}
return nil
}
// CloseAllRedisClients 关闭所有缓存的 Redis 客户端
func CloseAllRedisClients() error {
var errs []error
redisClients.Range(func(key, value any) bool {
client := value.(*redis.Client)
if err := client.Close(); err != nil {
errs = append(errs, err)
}
redisClients.Delete(key)
return true
})
if len(errs) > 0 {
return fmt.Errorf("关闭 Redis 客户端时发生错误: %v", errs)
}
return nil
}
+49
View File
@@ -0,0 +1,49 @@
package utils
import (
"github.com/goravel/framework/facades"
"github.com/goravel/framework/support/str"
)
// IsSensitiveField 检查字段名是否是敏感字段
func IsSensitiveField(fieldName string) bool {
keyLower := str.Of(fieldName).Lower().String()
// 获取配置的敏感字段列表
sensitiveFieldsInterface := facades.Config().Get("operation_log.sensitive_fields", []string{})
if sensitiveFields, ok := sensitiveFieldsInterface.([]any); ok {
for _, fieldInterface := range sensitiveFields {
if field, ok := fieldInterface.(string); ok {
if keyLower == str.Of(field).Lower().String() {
return true
}
}
}
} else if sensitiveFields, ok := sensitiveFieldsInterface.([]string); ok {
for _, field := range sensitiveFields {
if keyLower == str.Of(field).Lower().String() {
return true
}
}
}
// 检查是否包含敏感关键词
sensitiveKeywordsInterface := facades.Config().Get("operation_log.sensitive_keywords", []string{})
if sensitiveKeywords, ok := sensitiveKeywordsInterface.([]any); ok {
for _, keywordInterface := range sensitiveKeywords {
if keyword, ok := keywordInterface.(string); ok {
if str.Of(keyLower).Contains(str.Of(keyword).Lower().String()) {
return true
}
}
}
} else if sensitiveKeywords, ok := sensitiveKeywordsInterface.([]string); ok {
for _, keyword := range sensitiveKeywords {
if str.Of(keyLower).Contains(str.Of(keyword).Lower().String()) {
return true
}
}
}
return false
}
+285
View File
@@ -0,0 +1,285 @@
package utils
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/goravel/framework/facades"
"goravel/app/constants"
"goravel/app/utils/errorlog"
)
// GetShardingTableName 根据时间获取分表名称
// baseTableName: 基础表名,如 "orders"
// orderTime: 订单时间
// 返回: 分表名称,如 "orders_202501"
func GetShardingTableName(baseTableName string, orderTime time.Time) string {
return fmt.Sprintf("%s_%s", baseTableName, orderTime.Format("200601"))
}
// GetUserBalanceLogsShardingTableName 根据 user_id 获取用户余额变动记录分表名称
// userID: 用户ID
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
// 分表逻辑:user_id % UserBalanceLogsShards
// 注意:此函数为特定表实现,如需为其他表实现哈希分表,请使用 GetHashShardingTableName
func GetUserBalanceLogsShardingTableName(userID uint) string {
return GetHashShardingTableName("user_balance_logs", userID, constants.UserBalanceLogsShards)
}
// HashShardingConfig 哈希分表配置
type HashShardingConfig struct {
BaseTableName string // 基础表名,如 "user_balance_logs"
NumberOfShards int // 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
}
// 预定义的哈希分表配置
var (
// UserBalanceLogsShardingConfig 用户余额变动记录分表配置
UserBalanceLogsShardingConfig = HashShardingConfig{
BaseTableName: "user_balance_logs",
NumberOfShards: constants.UserBalanceLogsShards,
}
)
// GetHashShardingTableName 通用的哈希分表名称生成函数
// baseTableName: 基础表名,如 "user_balance_logs"
// shardingKey: 分表键值(uint 类型),如 user_id
// numberOfShards: 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
// 分表逻辑:shardingKey % numberOfShards
//
// 使用示例:
//
// // 为 user_balance_logs 表分表(4个分表)
// tableName := GetHashShardingTableName("user_balance_logs", userID, 4)
//
// // 为新的表 example_table 分表(8个分表)
// tableName := GetHashShardingTableName("example_table", entityID, 8)
func GetHashShardingTableName(baseTableName string, shardingKey uint, numberOfShards int) string {
if numberOfShards <= 0 {
// 如果分表数量无效,返回基础表名(不分表)
return baseTableName
}
shardIndex := int(shardingKey) % numberOfShards
return fmt.Sprintf("%s_%d", baseTableName, shardIndex)
}
// GetHashShardingTableNameByConfig 通过配置获取哈希分表名称
func GetHashShardingTableNameByConfig(config HashShardingConfig, shardingKey uint) string {
return GetHashShardingTableName(config.BaseTableName, shardingKey, config.NumberOfShards)
}
// GetAllHashShardingTableNames 获取所有哈希分表名称列表
// 用于批量操作所有分表(如迁移、统计等)
func GetAllHashShardingTableNames(config HashShardingConfig) []string {
tableNames := make([]string, config.NumberOfShards)
for i := 0; i < config.NumberOfShards; i++ {
tableNames[i] = fmt.Sprintf("%s_%d", config.BaseTableName, i)
}
return tableNames
}
// GetHashShardIndex 获取分表索引
func GetHashShardIndex(shardingKey uint, numberOfShards int) int {
if numberOfShards <= 0 {
return 0
}
return int(shardingKey) % numberOfShards
}
// GetShardingTableNames 获取时间范围内的所有分表名称
// baseTableName: 基础表名
// startTime: 开始时间
// endTime: 结束时间
// 返回: 分表名称列表
func GetShardingTableNames(baseTableName string, startTime, endTime time.Time) []string {
var tableNames []string
// 如果结束时间是零值,使用当前时间
if endTime.IsZero() {
endTime = time.Now().UTC()
}
// 确保开始时间不晚于结束时间
if startTime.After(endTime) {
return tableNames
}
// 从开始时间到结束时间,按月遍历
current := time.Date(startTime.Year(), startTime.Month(), 1, 0, 0, 0, 0, startTime.Location())
end := time.Date(endTime.Year(), endTime.Month(), 1, 0, 0, 0, 0, endTime.Location())
for !current.After(end) {
tableNames = append(tableNames, GetShardingTableName(baseTableName, current))
current = current.AddDate(0, 1, 0) // 加一个月
}
return tableNames
}
// DefaultMaxTimeRangeMonths 默认最大时间范围(月数)
// 可以通过配置覆盖,用于限制查询时间范围,避免跨太多分表
const DefaultMaxTimeRangeMonths = 3
// TimeRangeError 时间范围验证错误
type TimeRangeError struct {
Key string
Params map[string]any
}
func (e *TimeRangeError) Error() string {
// 返回翻译键,由调用方进行翻译
return e.Key
}
// ValidateTimeRange 验证时间范围是否超过指定月数
// startTime: 开始时间
// endTime: 结束时间
// maxMonths: 最大允许的月数,如果为0则使用默认值 DefaultMaxTimeRangeMonths
// 返回: 是否有效,错误信息(错误信息包含翻译键和参数)
func ValidateTimeRange(startTime, endTime time.Time, maxMonths ...int) (bool, error) {
// 检查开始时间是否晚于结束时间(忽略零值时间)
if !startTime.IsZero() && !endTime.IsZero() && startTime.After(endTime) {
return false, &TimeRangeError{
Key: "start_time_after_end_time",
Params: nil,
}
}
// 确定最大月数
maxMonthsValue := DefaultMaxTimeRangeMonths
if len(maxMonths) > 0 && maxMonths[0] > 0 {
maxMonthsValue = maxMonths[0]
}
// 检查时间范围是否超过指定月数(忽略零值结束时间)
if !endTime.IsZero() {
maxTimeLater := startTime.AddDate(0, maxMonthsValue, 0)
if endTime.After(maxTimeLater) {
return false, &TimeRangeError{
Key: "time_range_exceeded",
Params: map[string]any{
"months": maxMonthsValue,
},
}
}
}
return true, nil
}
// GetAllExistingShardingTables 获取数据库中所有已存在的分表名称
// baseTableName: 基础表名,如 "orders" 或 "order_details"
// 返回: 已存在的分表名称列表
func GetAllExistingShardingTables(baseTableName string) ([]string, error) {
var tableNames []string
// 获取当前数据库名
dbName := facades.Config().GetString("database.connections.mysql.database")
if dbName == "" {
dbName = facades.Config().GetString("database.connections.postgresql.database")
}
// 构建表名匹配模式:orders_YYYYMM 或 order_details_YYYYMM
pattern := fmt.Sprintf("%s_%%", baseTableName)
// 查询所有匹配的表名
query := `
SELECT table_name
FROM information_schema.tables
WHERE table_schema = ?
AND table_name LIKE ?
ORDER BY table_name
`
// 执行查询,使用 Scan 获取结果
var rows []map[string]any
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
"pattern": pattern,
"error": err.Error(),
}, "查询分表失败: %v", err)
return nil, fmt.Errorf("查询分表失败: %v", err)
}
// 验证表名格式(确保是有效的分表名称,格式为 baseTableName_YYYYMM
// 从 pattern 中提取基础表名(移除末尾的 %)
baseTableNameFromPattern := strings.TrimSuffix(pattern, "_%")
patternRegex := regexp.MustCompile(fmt.Sprintf("^%s_\\d{6}$", regexp.QuoteMeta(baseTableNameFromPattern)))
for _, row := range rows {
// 尝试不同的字段名格式(MySQL 可能返回不同的大小写)
var tableName string
var ok bool
// 尝试 table_name (小写)
if tableName, ok = row["table_name"].(string); !ok {
// 尝试 TABLE_NAME (大写)
if tableName, ok = row["TABLE_NAME"].(string); !ok {
// 尝试遍历所有键
for key, value := range row {
if (key == "table_name" || key == "TABLE_NAME" || key == "Table_Name") && value != nil {
if str, ok := value.(string); ok {
tableName = str
ok = true
break
}
}
}
}
}
if ok && tableName != "" {
// 验证表名格式
if patternRegex.MatchString(tableName) {
tableNames = append(tableNames, tableName)
}
}
}
return tableNames, nil
}
// GetAllExistingShardingTablesByPattern 通过表名模式获取所有已存在的分表
// 这是一个更通用的方法,可以通过自定义模式匹配
// pattern: 表名匹配模式,如 "orders_%" 或 "order_details_%"
func GetAllExistingShardingTablesByPattern(pattern string) ([]string, error) {
var tableNames []string
// 获取当前数据库名
dbName := facades.Config().GetString("database.connections.mysql.database")
if dbName == "" {
dbName = facades.Config().GetString("database.connections.postgresql.database")
}
// 查询所有匹配的表名
query := `
SELECT table_name
FROM information_schema.tables
WHERE table_schema = ?
AND table_name LIKE ?
ORDER BY table_name
`
// 执行查询,使用 Scan 获取结果
var rows []map[string]any
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
"pattern": pattern,
"error": err.Error(),
}, "查询分表失败: %v", err)
return nil, fmt.Errorf("查询分表失败: %v", err)
}
for _, row := range rows {
if tableName, ok := row["table_name"].(string); ok {
tableNames = append(tableNames, tableName)
}
}
return tableNames, nil
}
+81
View File
@@ -0,0 +1,81 @@
package utils
import (
"fmt"
"strings"
"time"
"github.com/oklog/ulid/v2"
)
// ShardingNoConfig 分表单号配置
type ShardingNoConfig struct {
Prefix string // 单号前缀,如 "ORD"、"PAY"
DateFormat string // 日期格式,如 "200601"(年月)、"20060102"(年月日)
}
// 预定义配置
var (
// OrderNoConfig 订单号配置(ORD + YYYYMM + ULID
OrderNoConfig = ShardingNoConfig{
Prefix: "ORD",
DateFormat: "200601", // YYYYMM
}
// PaymentNoConfig 支付单号配置(PAY + YYYYMMDD + ULID
PaymentNoConfig = ShardingNoConfig{
Prefix: "PAY",
DateFormat: "20060102", // YYYYMMDD
}
)
// GenerateShardingNo 生成分表单号
// 格式:前缀 + 日期 + ULID
// 例如:ORD20250101ARZ3S0K5M2X9P4Q6R8T1V3W5Y7Z9
func GenerateShardingNo(config ShardingNoConfig) string {
now := time.Now().UTC()
dateStr := now.Format(config.DateFormat)
ulidStr := ulid.Make().String()
return fmt.Sprintf("%s%s%s", config.Prefix, dateStr, ulidStr)
}
// ParseShardingNoDate 从分表单号解析日期
// 返回:解析的时间和是否成功
func ParseShardingNoDate(no string, config ShardingNoConfig) (time.Time, bool) {
prefixLen := len(config.Prefix)
dateLen := len(config.DateFormat)
minLen := prefixLen + dateLen + 1 // 前缀 + 日期 + 至少1位ULID
// 验证长度和前缀
if len(no) < minLen || !strings.HasPrefix(no, config.Prefix) {
return time.Time{}, false
}
// 提取日期部分
dateStr := no[prefixLen : prefixLen+dateLen]
// 验证日期格式(简单验证:都是数字)
for _, c := range dateStr {
if c < '0' || c > '9' {
return time.Time{}, false
}
}
// 解析日期
parsedTime, err := time.Parse(config.DateFormat, dateStr)
if err != nil {
return time.Time{}, false
}
return parsedTime, true
}
// ParseShardingNoYearMonth 从分表单号解析年月字符串(用于分表定位)
// 返回:年月字符串(如 "202501")和是否成功
func ParseShardingNoYearMonth(no string, config ShardingNoConfig) (string, bool) {
parsedTime, ok := ParseShardingNoDate(no, config)
if !ok {
return "", false
}
return parsedTime.Format("200601"), true
}
+91
View File
@@ -0,0 +1,91 @@
package utils
import "time"
// 时间格式常量
const (
// DateTimeFormat 标准日期时间格式 (YYYY-MM-DD HH:mm:ss)
DateTimeFormat = "2006-01-02 15:04:05"
// DateFormat 标准日期格式 (YYYY-MM-DD)
DateFormat = "2006-01-02"
// DateTimeFormatT 带T的日期时间格式 (YYYY-MM-DDTHH:mm:ss)
DateTimeFormatT = "2006-01-02T15:04:05"
// DateTimeFormatMs 带毫秒的日期时间格式
DateTimeFormatMs = "2006-01-02 15:04:05.000"
// DateTimeFormatTZ 带时区的日期时间格式
DateTimeFormatTZ = "2006-01-02T15:04:05.000Z07:00"
// YearMonthFormat 年月格式 (YYYYMM)
YearMonthFormat = "200601"
)
// ParseDateTime 解析标准格式的日期时间字符串 (2006-01-02 15:04:05)
func ParseDateTime(s string) (time.Time, error) {
return time.Parse(DateTimeFormat, s)
}
// ParseDateTimeUTC 解析标准格式的日期时间字符串并转换为 UTC
func ParseDateTimeUTC(s string) (time.Time, error) {
t, err := time.Parse(DateTimeFormat, s)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
// ParseDate 解析标准日期格式字符串 (2006-01-02)
func ParseDate(s string) (time.Time, error) {
return time.Parse(DateFormat, s)
}
// ParseDateUTC 解析标准日期格式字符串并转换为 UTC
func ParseDateUTC(s string) (time.Time, error) {
t, err := time.Parse(DateFormat, s)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
// FormatDateTime 格式化时间为标准日期时间字符串 (2006-01-02 15:04:05)
func FormatDateTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(DateTimeFormat)
}
// FormatDate 格式化时间为标准日期字符串 (2006-01-02)
func FormatDate(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(DateFormat)
}
// FormatDateTimePtr 格式化时间指针为标准日期时间字符串
func FormatDateTimePtr(t *time.Time) string {
if t == nil || t.IsZero() {
return ""
}
return t.Format(DateTimeFormat)
}
// ParseDateTimeInLocation 在指定时区解析日期时间字符串
func ParseDateTimeInLocation(s string, loc *time.Location) (time.Time, error) {
return time.ParseInLocation(DateTimeFormat, s, loc)
}
// FormatYearMonth 格式化时间为年月字符串 (200601)
func FormatYearMonth(t time.Time) string {
return t.Format(YearMonthFormat)
}
// ParseYearMonth 解析年月格式字符串 (200601)
func ParseYearMonth(s string) (time.Time, error) {
return time.Parse(YearMonthFormat, s)
}
+118
View File
@@ -0,0 +1,118 @@
package traceid
import (
"context"
"strings"
"github.com/goravel/framework/contracts/http"
"github.com/oklog/ulid/v2"
)
type contextKey string
const (
ContextKey contextKey = "trace_id"
headerKey string = "X-Trace-Id"
requestIDHeader string = "X-Request-Id"
)
// Generate returns a new trace id using ULID to ensure it sorts well and is URL safe.
func Generate() string {
return strings.ToLower(ulid.Make().String())
}
// EnsureHTTPContext stores a trace id on the http.Context (and returns it).
func EnsureHTTPContext(ctx http.Context, preferred string) string {
if ctx == nil {
return Generate()
}
traceID := preferred
if traceID == "" {
traceID = ctx.Request().Header(headerKey, "")
}
if traceID == "" {
traceID = ctx.Request().Header(requestIDHeader, "")
}
if traceID == "" {
traceID = Generate()
}
ctx.WithValue(string(ContextKey), traceID)
return traceID
}
// FromHTTPContext retrieves the stored trace id from http.Context.
func FromHTTPContext(ctx http.Context) string {
if ctx == nil {
return ""
}
if value := ctx.Value(string(ContextKey)); value != nil {
if traceID, ok := value.(string); ok {
return traceID
}
}
return ""
}
// StoreHTTPContext stores an existing trace id into the http context.
func StoreHTTPContext(ctx http.Context, traceID string) {
if ctx == nil || traceID == "" {
return
}
ctx.WithValue(string(ContextKey), traceID)
}
// EnsureContext ensures a standard context carries a trace id and returns both.
func EnsureContext(ctx context.Context) (context.Context, string) {
traceID := FromContext(ctx)
if traceID == "" {
traceID = Generate()
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, ContextKey, traceID), traceID
}
// WithTrace assigns a trace id into the provided context (or background if nil).
func WithTrace(ctx context.Context, traceID string) context.Context {
if ctx == nil {
ctx = context.Background()
}
if traceID == "" {
traceID = Generate()
}
return context.WithValue(ctx, ContextKey, traceID)
}
// FromContext reads the trace id from a standard context.
func FromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if traceID, ok := ctx.Value(ContextKey).(string); ok {
return traceID
}
return ""
}
// DeriveContextFromHTTP builds a standard context containing the http trace id.
func DeriveContextFromHTTP(ctx http.Context) context.Context {
traceID := FromHTTPContext(ctx)
if traceID == "" {
traceID = Generate()
StoreHTTPContext(ctx, traceID)
}
return context.WithValue(context.Background(), ContextKey, traceID)
}
// HeaderName exposes the header used to propagate the trace id.
func HeaderName() string {
return headerKey
}
// RequestHeaderFallback exposes secondary header name for incoming trace ids.
func RequestHeaderFallback() string {
return requestIDHeader
}
+122
View File
@@ -0,0 +1,122 @@
package utils
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/goravel/framework/facades"
"goravel/lang"
)
// loadLangMessages 加载语言文件的 messages 对象(内部辅助函数)
// 优先从文件系统读取,失败时从 embed FS 读取(支持生产环境)
// 返回 messages map 和是否成功加载
func loadLangMessages(langCode string) (map[string]any, bool) {
var langData []byte
var err error
// 获取语言文件路径(使用框架配置的路径)
langPath := facades.Config().GetString("app.lang_path", "lang")
langFile := filepath.Join(langPath, fmt.Sprintf("%s.json", langCode))
// 优先尝试从文件系统读取
langData, err = os.ReadFile(langFile)
if err != nil {
// 文件系统读取失败,尝试从 embed FS 读取
embedFile := fmt.Sprintf("%s.json", langCode)
langData, err = lang.FS.ReadFile(embedFile)
if err != nil {
facades.Log().Debugf("读取语言文件失败: %s, embed: %s, error=%v", langFile, embedFile, err)
return nil, false
}
}
// 解析 JSON
var langMap map[string]any
if err := json.Unmarshal(langData, &langMap); err != nil {
facades.Log().Debugf("解析语言文件失败: lang=%s, error=%v", langCode, err)
return nil, false
}
// 获取 messages 对象(使用类型辅助函数)
messages, ok := GetMap(langMap, "messages")
if !ok {
facades.Log().Debugf("语言文件中没有 messages 对象: lang=%s", langCode)
return nil, false
}
return messages, true
}
// TranslateHeaders 翻译表头(直接读取语言文件)
// 这是一个通用函数,可以被任何需要翻译表头的 job 使用
//
// 参数:
// - headerKeys: 需要翻译的键列表(如 ["export_header_id", "export_header_order_no"]
// - lang: 语言代码(如 "cn" 或 "en"
//
// 返回:
// - []string: 翻译后的表头列表,如果翻译失败则返回原始键
//
// 示例:
//
// headerKeys := []string{"export_header_id", "export_header_order_no"}
// headers := utils.TranslateHeaders(headerKeys, "cn")
func TranslateHeaders(headerKeys []string, lang string) []string {
headers := make([]string, len(headerKeys))
// 加载语言文件的 messages 对象
messages, ok := loadLangMessages(lang)
if !ok {
// 如果读取失败,使用原始键
facades.Log().Warningf("加载语言文件失败: lang=%s, 使用原始键", lang)
return append([]string(nil), headerKeys...)
}
// 翻译每个键(使用泛型辅助函数)
for i, key := range headerKeys {
fullKey := "messages." + key
if value, ok := GetString(messages, key); ok && value != "" {
headers[i] = value
} else {
// 如果翻译失败,使用原始键
headers[i] = key
facades.Log().Debugf("翻译键未找到: %s (语言: %s)", fullKey, lang)
}
}
return headers
}
// TranslateKey 翻译单个键(从语言文件读取,支持文件系统和embed FS)
// 这是一个通用函数,可以被任何需要翻译单个键的地方使用(包括Job等无http.Context的场景)
//
// 参数:
// - key: 需要翻译的键(如 "export_order_status_pending"
// - lang: 语言代码(如 "cn" 或 "en"
// - defaultValue: 如果翻译失败时返回的默认值(通常为原始键)
//
// 返回:
// - string: 翻译后的文本,如果翻译失败则返回 defaultValue
//
// 示例:
//
// statusText := utils.TranslateKey("export_order_status_pending", "cn", "pending")
func TranslateKey(key, lang, defaultValue string) string {
// 加载语言文件的 messages 对象(支持文件系统和embed FS)
messages, ok := loadLangMessages(lang)
if !ok {
return defaultValue
}
// 尝试获取翻译值
if value, ok := GetString(messages, key); ok && value != "" {
return value
}
// 如果翻译失败,返回默认值
return defaultValue
}
+226
View File
@@ -0,0 +1,226 @@
package utils
import (
"fmt"
"reflect"
)
// GetValue 从 map[string]any 中安全地获取指定类型的值
// 支持多种类型转换,如果转换失败返回零值和 false
func GetValue[T any](m map[string]any, key string) (T, bool) {
var zero T
val, ok := m[key]
if !ok {
return zero, false
}
// 尝试直接类型断言
if v, ok := val.(T); ok {
return v, true
}
// 对于数字类型,尝试从其他数字类型转换
return convertNumeric[T](val)
}
// GetUint 从 map[string]any 中获取 uint 值(支持多种数字类型转换)
func GetUint(m map[string]any, key string) (uint, bool) {
return GetValue[uint](m, key)
}
// GetFloat64 从 map[string]any 中获取 float64 值(支持多种数字类型转换)
func GetFloat64(m map[string]any, key string) (float64, bool) {
return GetValue[float64](m, key)
}
// GetString 从 map[string]any 中获取 string 值
func GetString(m map[string]any, key string) (string, bool) {
val, ok := m[key]
if !ok {
return "", false
}
if v, ok := val.(string); ok {
return v, true
}
return "", false
}
// GetMap 从 map[string]any 中获取 map[string]any 值
func GetMap(m map[string]any, key string) (map[string]any, bool) {
val, ok := m[key]
if !ok {
return nil, false
}
if v, ok := val.(map[string]any); ok {
return v, true
}
return nil, false
}
// convertNumeric 将值转换为数字类型(支持多种数字类型)
func convertNumeric[T any](val any) (T, bool) {
var zero T
switch v := val.(type) {
case float64:
return convertFromFloat64[T](v)
case int:
return convertFromInt[T](v)
case uint:
return convertFromUint[T](v)
case int64:
return convertFromInt64[T](v)
case uint64:
return convertFromUint64[T](v)
default:
return zero, false
}
}
// convertFromFloat64 从 float64 转换
func convertFromFloat64[T any](v float64) (T, bool) {
var zero T
switch any(zero).(type) {
case uint:
return any(uint(v)).(T), true
case int:
return any(int(v)).(T), true
case float64:
return any(v).(T), true
default:
return zero, false
}
}
// convertFromInt 从 int 转换
func convertFromInt[T any](v int) (T, bool) {
var zero T
switch any(zero).(type) {
case uint:
return any(uint(v)).(T), true
case int:
return any(v).(T), true
case float64:
return any(float64(v)).(T), true
default:
return zero, false
}
}
// convertFromUint 从 uint 转换
func convertFromUint[T any](v uint) (T, bool) {
var zero T
switch any(zero).(type) {
case uint:
return any(v).(T), true
case int:
return any(int(v)).(T), true
case float64:
return any(float64(v)).(T), true
default:
return zero, false
}
}
// convertFromInt64 从 int64 转换
func convertFromInt64[T any](v int64) (T, bool) {
var zero T
switch any(zero).(type) {
case uint:
return any(uint(v)).(T), true
case int:
return any(int(v)).(T), true
case float64:
return any(float64(v)).(T), true
default:
return zero, false
}
}
// convertFromUint64 从 uint64 转换
func convertFromUint64[T any](v uint64) (T, bool) {
var zero T
switch any(zero).(type) {
case uint:
return any(uint(v)).(T), true
case int:
return any(int(v)).(T), true
case float64:
return any(float64(v)).(T), true
default:
return zero, false
}
}
// MustGetValue 从 map[string]any 中获取值,如果不存在或类型不匹配则 panic
// 仅在确定值存在且类型正确时使用
func MustGetValue[T any](m map[string]any, key string) T {
val, ok := GetValue[T](m, key)
if !ok {
panic(fmt.Sprintf("key %s not found or type mismatch in map", key))
}
return val
}
// FillFiltersFromMap 从 map[string]any 填充 Filters 结构体
// 支持 string, uint, float64 类型,使用字段名的 snake_case 作为 map 的 key
// 示例:
//
// filters := services.OrderFilters{}
// utils.FillFiltersFromMap(m, &filters)
func FillFiltersFromMap(m map[string]any, filtersPtr any) {
v := reflect.ValueOf(filtersPtr)
if v.Kind() != reflect.Ptr || v.IsNil() {
return
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return
}
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanSet() {
continue
}
structField := t.Field(i)
// 获取 json tag 或使用 snake_case 字段名
key := structField.Tag.Get("json")
if key == "" || key == "-" {
key = toSnakeCase(structField.Name)
}
switch field.Kind() {
case reflect.String:
if val, ok := GetString(m, key); ok {
field.SetString(val)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if val, ok := GetUint(m, key); ok {
field.SetUint(uint64(val))
}
case reflect.Float64, reflect.Float32:
if val, ok := GetFloat64(m, key); ok {
field.SetFloat(val)
}
}
}
}
// toSnakeCase 将 PascalCase/camelCase 转换为 snake_case
func toSnakeCase(s string) string {
var result []byte
for i, c := range s {
if c >= 'A' && c <= 'Z' {
if i > 0 {
result = append(result, '_')
}
result = append(result, byte(c+'a'-'A'))
} else {
result = append(result, byte(c))
}
}
return string(result)
}
+77
View File
@@ -0,0 +1,77 @@
package utils
import (
"time"
"github.com/goravel/framework/support/str"
"github.com/oklog/ulid/v2"
)
// GenerateULID 生成 ULID(默认使用当前时间,与 Laravel 行为一致)
// 返回: ULID 字符串
func GenerateULID() string {
entropy := ulid.DefaultEntropy()
return ulid.MustNew(ulid.Timestamp(time.Now()), entropy).String()
}
// GenerateULIDWithTime 使用指定时间生成 ULID
// t: 指定的时间
// 返回: ULID 字符串
func GenerateULIDWithTime(t time.Time) string {
entropy := ulid.DefaultEntropy()
return ulid.MustNew(ulid.Timestamp(t), entropy).String()
}
// ParseULID 解析 ULID 字符串
// ulidStr: ULID 字符串
// 返回: ULID 对象和错误
func ParseULID(ulidStr string) (ulid.ULID, error) {
return ulid.Parse(ulidStr)
}
// ParseULIDTime 从 ULID 解析出时间
// ulidStr: ULID 字符串
// 返回: 时间对象和错误
func ParseULIDTime(ulidStr string) (time.Time, error) {
id, err := ulid.Parse(ulidStr)
if err != nil {
return time.Time{}, err
}
return ulid.Time(id.Time()), nil
}
// ParseULIDTimeString 从 ULID 解析出时间字符串
// ulidStr: ULID 字符串
// format: 时间格式,如 DateTimeFormat
// 返回: 格式化的时间字符串和错误
func ParseULIDTimeString(ulidStr string, format string) (string, error) {
t, err := ParseULIDTime(ulidStr)
if err != nil {
return "", err
}
if format == "" {
format = DateTimeFormat
}
return t.Format(format), nil
}
// IsValidULID 验证 ULID 字符串是否有效
// 使用 Goravel 框架的字符串库进行验证(与框架的 IsUlid 方法一致)
// 参考: https://www.goravel.dev/zh_CN/digging-deeper/strings.html#isulid
// ulidStr: ULID 字符串
// 返回: 是否有效
func IsValidULID(ulidStr string) bool {
// 使用 Goravel 框架的字符串库验证(框架只提供验证功能,不提供生成功能)
return str.Of(ulidStr).IsUlid()
}
// GetULIDTimestamp 获取 ULID 的时间戳(毫秒)
// ulidStr: ULID 字符串
// 返回: Unix 时间戳(毫秒)和错误
func GetULIDTimestamp(ulidStr string) (int64, error) {
id, err := ulid.Parse(ulidStr)
if err != nil {
return 0, err
}
return int64(id.Time()), nil
}