init
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"goravel/app/models"
|
||||
)
|
||||
|
||||
// FormatBalance 根据货币的小数位数格式化余额
|
||||
// 如果货币没有小数(decimalPlaces = 0),则不返回小数部分
|
||||
func FormatBalance(balance float64, currency *models.Currency) float64 {
|
||||
if currency == nil {
|
||||
// 如果没有货币信息,默认使用2位小数
|
||||
return math.Round(balance*100) / 100
|
||||
}
|
||||
|
||||
decimalPlaces := currency.DecimalPlaces
|
||||
if decimalPlaces < 0 {
|
||||
decimalPlaces = 0
|
||||
}
|
||||
if decimalPlaces > 8 {
|
||||
decimalPlaces = 8
|
||||
}
|
||||
|
||||
// 计算精度倍数
|
||||
multiplier := math.Pow(10, float64(decimalPlaces))
|
||||
|
||||
// 四舍五入到指定小数位数
|
||||
formatted := math.Round(balance*multiplier) / multiplier
|
||||
|
||||
return formatted
|
||||
}
|
||||
|
||||
// FormatBalanceString 格式化余额为字符串(去除末尾多余的0)
|
||||
func FormatBalanceString(balance float64, currency *models.Currency) string {
|
||||
formatted := FormatBalance(balance, currency)
|
||||
|
||||
if currency == nil {
|
||||
return fmt.Sprintf("%.2f", formatted)
|
||||
}
|
||||
|
||||
decimalPlaces := currency.DecimalPlaces
|
||||
if decimalPlaces < 0 {
|
||||
decimalPlaces = 0
|
||||
}
|
||||
if decimalPlaces > 8 {
|
||||
decimalPlaces = 8
|
||||
}
|
||||
|
||||
// 如果小数位数为0,使用整数格式
|
||||
if decimalPlaces == 0 {
|
||||
return fmt.Sprintf("%.0f", formatted)
|
||||
}
|
||||
|
||||
// 格式化字符串,去除末尾的0
|
||||
format := fmt.Sprintf("%%.%df", decimalPlaces)
|
||||
result := fmt.Sprintf(format, formatted)
|
||||
|
||||
// 去除末尾的0和小数点(如果小数部分全为0)
|
||||
if decimalPlaces > 0 {
|
||||
// 去除末尾的0
|
||||
for len(result) > 0 && result[len(result)-1] == '0' {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
// 如果最后是小数点,也去除
|
||||
if len(result) > 0 && result[len(result)-1] == '.' {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
|
||||
"goravel/app/models"
|
||||
)
|
||||
|
||||
// GetConfigValue 从数据库获取配置值
|
||||
// group: 配置分组
|
||||
// key: 配置键
|
||||
// defaultValue: 默认值(如果配置不存在)
|
||||
func GetConfigValue(group, key string, defaultValue string) string {
|
||||
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// 静默处理,返回默认值
|
||||
}
|
||||
}()
|
||||
|
||||
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
|
||||
// 这样可以避免在构建时执行数据库查询
|
||||
orm := facades.Orm()
|
||||
if orm == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
var config models.Config
|
||||
err := orm.Query().Where("group", group).Where("key", key).First(&config)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
if config.Value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return config.Value
|
||||
}
|
||||
|
||||
// GetConfigValueInt 从数据库获取配置值(整数类型)
|
||||
func GetConfigValueInt(group, key string, defaultValue int) int {
|
||||
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// 静默处理,返回默认值
|
||||
}
|
||||
}()
|
||||
|
||||
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
|
||||
// 这样可以避免在构建时执行数据库查询
|
||||
orm := facades.Orm()
|
||||
if orm == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
var config models.Config
|
||||
err := orm.Query().Where("group", group).Where("key", key).First(&config)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
if config.Value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
// 简单的字符串转整数,实际可以使用更完善的转换
|
||||
value := 0
|
||||
_, err = fmt.Sscanf(config.Value, "%d", &value)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// GetConfigValueBool 从数据库获取配置值(布尔类型)
|
||||
func GetConfigValueBool(group, key string, defaultValue bool) bool {
|
||||
// 使用 recover 来捕获可能的 panic(例如在构建时数据库不可用)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// 静默处理,返回默认值
|
||||
}
|
||||
}()
|
||||
|
||||
// 尝试检查数据库连接是否可用,如果不可用则直接返回默认值
|
||||
// 这样可以避免在构建时执行数据库查询
|
||||
orm := facades.Orm()
|
||||
if orm == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
var config models.Config
|
||||
err := orm.Query().Where("group", group).Where("key", key).First(&config)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
if config.Value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
// 支持 "1", "true", "True" 等格式
|
||||
value := config.Value
|
||||
return value == "1" || value == "true" || value == "True" || value == "TRUE"
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/database/orm"
|
||||
"github.com/goravel/framework/facades"
|
||||
)
|
||||
|
||||
// CountOptimizer 分页统计优化器
|
||||
// 当数据量超过阈值时,使用执行计划估算的行数,否则使用实际的 count(*)
|
||||
type CountOptimizer struct {
|
||||
// Threshold 阈值,超过此值使用估算值(默认 10000)
|
||||
Threshold int64
|
||||
// ModuleName 模块名称,用于日志记录
|
||||
ModuleName string
|
||||
}
|
||||
|
||||
// NewCountOptimizer 创建新的 CountOptimizer
|
||||
// threshold: 阈值,超过此值使用估算值(默认 10000)
|
||||
// moduleName: 模块名称,用于日志记录
|
||||
func NewCountOptimizer(threshold int64, moduleName string) *CountOptimizer {
|
||||
if threshold <= 0 {
|
||||
threshold = 10000 // 默认阈值
|
||||
}
|
||||
return &CountOptimizer{
|
||||
Threshold: threshold,
|
||||
ModuleName: moduleName,
|
||||
}
|
||||
}
|
||||
|
||||
// OptimizedCount 优化的 count 查询(使用 ORM Query 对象)
|
||||
// query: ORM 查询对象(已应用筛选条件,但未应用排序和分页)
|
||||
// 注意:此方法会先尝试估算,如果失败则使用实际 count
|
||||
// 返回:总数、是否使用估算值、错误
|
||||
func (co *CountOptimizer) OptimizedCount(query orm.Query) (int64, bool, error) {
|
||||
// 先尝试执行实际 count(如果数据量小,直接 count 也很快)
|
||||
// 如果数据量大,我们可以通过执行时间来判断,但更简单的方法是先估算
|
||||
// 这里我们提供一个简化版本:直接使用实际 count,如果慢的话可以后续优化
|
||||
|
||||
// 由于无法直接从 ORM Query 提取 SQL,这里提供一个变通方案:
|
||||
// 先执行一次快速查询获取估算值(需要表名)
|
||||
// 但更推荐使用 OptimizedCountWithTable 方法
|
||||
|
||||
actualCount, err := query.Count()
|
||||
return actualCount, false, err
|
||||
}
|
||||
|
||||
// extractRowsFromPostgreSQLExplain 从 PostgreSQL EXPLAIN JSON 结果中提取行数
|
||||
func (co *CountOptimizer) extractRowsFromPostgreSQLExplain(jsonStr string) int64 {
|
||||
// 简化实现:使用正则表达式提取 "Plan" -> "Plan Rows" 的值
|
||||
// 实际应该使用 JSON 解析,但为了简化,使用正则
|
||||
re := regexp.MustCompile(`"Plan Rows":\s*(\d+)`)
|
||||
matches := re.FindStringSubmatch(jsonStr)
|
||||
if len(matches) > 1 {
|
||||
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
|
||||
return rows
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// extractRowsFromPostgreSQLExplainText 从 PostgreSQL EXPLAIN 文本结果中提取行数
|
||||
func (co *CountOptimizer) extractRowsFromPostgreSQLExplainText(result []map[string]any) int64 {
|
||||
// PostgreSQL 文本格式的 EXPLAIN 结果中,行数通常在 "rows=" 后面
|
||||
for _, row := range result {
|
||||
if queryPlan, ok := row["QUERY PLAN"]; ok {
|
||||
if planStr, ok := queryPlan.(string); ok {
|
||||
// 使用正则表达式提取 rows= 后面的数字
|
||||
re := regexp.MustCompile(`rows=(\d+)`)
|
||||
matches := re.FindStringSubmatch(planStr)
|
||||
if len(matches) > 1 {
|
||||
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
|
||||
return rows
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// OptimizedCountWithTable 使用表名和 WHERE 条件进行优化的 count 查询
|
||||
// tableName: 表名
|
||||
// whereClause: WHERE 子句(不包含 WHERE 关键字),例如:"status = ? AND user_id = ?"
|
||||
// args: WHERE 条件的参数
|
||||
// 返回:总数、是否使用估算值、错误
|
||||
func (co *CountOptimizer) OptimizedCountWithTable(tableName, whereClause string, args ...any) (int64, bool, error) {
|
||||
dbConnection := facades.Config().GetString("database.default", "sqlite")
|
||||
|
||||
// 构建 COUNT SQL
|
||||
var countSQL string
|
||||
if whereClause != "" {
|
||||
switch dbConnection {
|
||||
case "mysql":
|
||||
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s` WHERE %s", tableName, whereClause)
|
||||
case "postgres":
|
||||
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s WHERE %s", tableName, whereClause)
|
||||
default:
|
||||
// 其他数据库不支持估算,直接使用实际 count
|
||||
var result struct {
|
||||
Cnt int64
|
||||
}
|
||||
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return result.Cnt, false, nil
|
||||
}
|
||||
} else {
|
||||
switch dbConnection {
|
||||
case "mysql":
|
||||
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s`", tableName)
|
||||
case "postgres":
|
||||
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s", tableName)
|
||||
default:
|
||||
// 其他数据库不支持估算,直接使用实际 count
|
||||
var result struct {
|
||||
Cnt int64
|
||||
}
|
||||
if err := facades.Orm().Query().Raw(countSQL).Scan(&result); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return result.Cnt, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 EXPLAIN SQL
|
||||
var explainSQL string
|
||||
if dbConnection == "mysql" {
|
||||
if whereClause != "" {
|
||||
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s` WHERE %s", tableName, whereClause)
|
||||
} else {
|
||||
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s`", tableName)
|
||||
}
|
||||
} else if dbConnection == "postgres" {
|
||||
if whereClause != "" {
|
||||
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s WHERE %s", tableName, whereClause)
|
||||
} else {
|
||||
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s", tableName)
|
||||
}
|
||||
} else {
|
||||
// 其他数据库不支持估算,直接使用实际 count
|
||||
var result struct {
|
||||
Cnt int64
|
||||
}
|
||||
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return result.Cnt, false, nil
|
||||
}
|
||||
|
||||
// 先执行 EXPLAIN 查询获取估算值(快速,不需要特别精准)
|
||||
estimatedCount, err := co.executeExplain(explainSQL, args...)
|
||||
if err != nil {
|
||||
// 如果估算失败,使用实际 count
|
||||
var result struct {
|
||||
Cnt int64
|
||||
}
|
||||
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return result.Cnt, false, nil
|
||||
}
|
||||
|
||||
// 如果估算值超过阈值,直接返回估算值(不需要执行实际 count,更快)
|
||||
if estimatedCount >= co.Threshold {
|
||||
return estimatedCount, true, nil
|
||||
}
|
||||
|
||||
// 估算值小于阈值,执行实际的 count 获取精确值
|
||||
var result struct {
|
||||
Cnt int64
|
||||
}
|
||||
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return result.Cnt, false, nil
|
||||
}
|
||||
|
||||
// executeExplain 执行 EXPLAIN 查询并提取估算行数
|
||||
func (co *CountOptimizer) executeExplain(explainSQL string, args ...any) (int64, error) {
|
||||
dbConnection := facades.Config().GetString("database.default", "sqlite")
|
||||
|
||||
switch dbConnection {
|
||||
case "mysql":
|
||||
// MySQL EXPLAIN 返回表格格式
|
||||
var explainResult []map[string]any
|
||||
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(explainResult) == 0 {
|
||||
return 0, fmt.Errorf("explain result is empty")
|
||||
}
|
||||
|
||||
// MySQL EXPLAIN 结果中,rows 字段包含估算行数
|
||||
if rows, ok := explainResult[0]["rows"]; ok {
|
||||
switch v := rows.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case float64:
|
||||
return int64(v), nil
|
||||
case string:
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return parsed, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("cannot extract rows from explain result")
|
||||
|
||||
case "postgres":
|
||||
// PostgreSQL EXPLAIN (FORMAT JSON) 返回 JSON 格式
|
||||
var explainResult []map[string]any
|
||||
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(explainResult) == 0 {
|
||||
return 0, fmt.Errorf("explain result is empty")
|
||||
}
|
||||
|
||||
// PostgreSQL EXPLAIN (FORMAT JSON) 结果是一个包含 JSON 字符串的数组
|
||||
if queryPlan, ok := explainResult[0]["QUERY PLAN"]; ok {
|
||||
if planStr, ok := queryPlan.(string); ok {
|
||||
// 从 JSON 字符串中提取 "Plan" -> "Plan Rows" 的值
|
||||
rows := co.extractRowsFromPostgreSQLExplain(planStr)
|
||||
if rows > 0 {
|
||||
return rows, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 JSON 格式解析失败,尝试文本格式
|
||||
explainTextSQL := strings.Replace(explainSQL, "EXPLAIN (FORMAT JSON)", "EXPLAIN", 1)
|
||||
var explainTextResult []map[string]any
|
||||
if err := facades.Orm().Query().Raw(explainTextSQL, args...).Get(&explainTextResult); err == nil {
|
||||
rows := co.extractRowsFromPostgreSQLExplainText(explainTextResult)
|
||||
if rows > 0 {
|
||||
return rows, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("cannot extract rows from explain result")
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("unsupported database: %s", dbConnection)
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package errorlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/http"
|
||||
"github.com/goravel/framework/facades"
|
||||
|
||||
"goravel/app/models"
|
||||
"goravel/app/utils/logger"
|
||||
"goravel/app/utils/traceid"
|
||||
)
|
||||
|
||||
// RecordHTTP 同时记录文件日志和数据库日志(用于系统级错误,默认 error 级别)
|
||||
// 使用场景:数据库操作失败、系统服务异常、关键业务逻辑错误等
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// if err != nil {
|
||||
// errorlog.RecordHTTP(ctx, "auth", "Failed to save admin profile", map[string]any{
|
||||
// "error": err.Error(),
|
||||
// "admin_id": admin.ID,
|
||||
// }, "Save admin profile error: %v", err)
|
||||
// return response.Error(ctx, http.StatusInternalServerError, "update_failed")
|
||||
// }
|
||||
func RecordHTTP(ctx http.Context, module, message string, attributes map[string]any, format string, args ...any) {
|
||||
RecordHTTPWithLevel(ctx, "error", module, message, attributes, format, args...)
|
||||
}
|
||||
|
||||
// RecordHTTPWithLevel 同时记录文件日志和数据库日志(可指定日志级别)
|
||||
// level: 日志级别,支持 "error", "warning", "info", "debug"
|
||||
// 使用场景:需要记录不同级别的系统日志
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// // 记录警告
|
||||
// errorlog.RecordHTTPWithLevel(ctx, "warning", "auth", "Unusual login pattern detected", map[string]any{
|
||||
// "admin_id": admin.ID,
|
||||
// "ip": ctx.Request().Ip(),
|
||||
// }, "Unusual login pattern: %s", pattern)
|
||||
//
|
||||
// // 记录信息
|
||||
// errorlog.RecordHTTPWithLevel(ctx, "info", "payment", "Payment processed successfully", map[string]any{
|
||||
// "order_id": order.ID,
|
||||
// "amount": order.Amount,
|
||||
// }, "Payment processed: %d", order.ID)
|
||||
func RecordHTTPWithLevel(ctx http.Context, level, module, message string, attributes map[string]any, format string, args ...any) {
|
||||
// 清理和验证日志级别,防止伪造
|
||||
level = sanitizeLogLevel(level)
|
||||
|
||||
// 清理日志格式字符串,防止格式字符串注入
|
||||
sanitizedFormat := sanitizeLogFormat(format)
|
||||
|
||||
// 清理日志参数,防止注入攻击
|
||||
sanitizedArgs := sanitizeLogArgs(args...)
|
||||
|
||||
// 根据级别选择不同的日志函数
|
||||
// 注意:logger 包目前只有 ErrorfHTTP,所以所有级别都使用它
|
||||
// 但在日志消息中添加级别前缀,便于区分和过滤
|
||||
levelPrefix := "[" + strings.ToUpper(level) + "] "
|
||||
formattedMessage := levelPrefix + fmt.Sprintf(sanitizedFormat, sanitizedArgs...)
|
||||
|
||||
switch level {
|
||||
case "error":
|
||||
logger.ErrorfHTTP(ctx, formattedMessage)
|
||||
case "warning":
|
||||
logger.ErrorfHTTP(ctx, formattedMessage)
|
||||
case "info", "debug":
|
||||
logger.ErrorfHTTP(ctx, formattedMessage)
|
||||
default:
|
||||
logger.ErrorfHTTP(ctx, formattedMessage)
|
||||
}
|
||||
|
||||
// 记录到数据库(所有级别都记录 trace_id)
|
||||
// 清理 message 和 attributes,防止注入
|
||||
if ctx != nil {
|
||||
sanitizedMessage := sanitizeLogString(message, 500) // 限制消息长度
|
||||
sanitizedAttributes := sanitizeAttributes(attributes)
|
||||
recordToDatabaseHTTPWithLevel(ctx, level, module, sanitizedMessage, sanitizedAttributes)
|
||||
}
|
||||
}
|
||||
|
||||
// Record 同时记录文件日志和数据库日志(用于标准 context,默认 error 级别)
|
||||
// 使用场景:goroutine、后台任务等
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// go func(ctx context.Context) {
|
||||
// if err != nil {
|
||||
// errorlog.Record(ctx, "operation-log", "Failed to create operation log", map[string]any{
|
||||
// "error": err.Error(),
|
||||
// }, "Create operation log error: %v", err)
|
||||
// }
|
||||
// }(traceCtx)
|
||||
func Record(ctx context.Context, module, message string, attributes map[string]any, format string, args ...any) {
|
||||
RecordWithLevel(ctx, "error", module, message, attributes, format, args...)
|
||||
}
|
||||
|
||||
// RecordWithLevel 同时记录文件日志和数据库日志(可指定日志级别,用于标准 context)
|
||||
// level: 日志级别,支持 "error", "warning", "info", "debug"
|
||||
// 使用场景:goroutine、后台任务中需要记录不同级别的日志
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// go func(ctx context.Context) {
|
||||
// errorlog.RecordWithLevel(ctx, "info", "background-task", "Task completed", map[string]any{
|
||||
// "task_id": taskID,
|
||||
// }, "Background task completed: %s", taskID)
|
||||
// }(traceCtx)
|
||||
func RecordWithLevel(ctx context.Context, level, module, message string, attributes map[string]any, format string, args ...any) {
|
||||
// 清理和验证日志级别,防止伪造
|
||||
level = sanitizeLogLevel(level)
|
||||
|
||||
// 清理日志格式字符串,防止格式字符串注入
|
||||
sanitizedFormat := sanitizeLogFormat(format)
|
||||
|
||||
// 清理日志参数,防止注入攻击
|
||||
sanitizedArgs := sanitizeLogArgs(args...)
|
||||
|
||||
// 根据级别选择不同的日志函数
|
||||
// 注意:logger 包目前只有 ErrorfContext,所以所有级别都使用它
|
||||
// 但在日志消息中添加级别前缀,便于区分和过滤
|
||||
levelPrefix := "[" + strings.ToUpper(level) + "] "
|
||||
formattedMessage := levelPrefix + fmt.Sprintf(sanitizedFormat, sanitizedArgs...)
|
||||
|
||||
switch level {
|
||||
case "error":
|
||||
logger.ErrorfContext(ctx, formattedMessage)
|
||||
case "warning", "info", "debug":
|
||||
logger.ErrorfContext(ctx, formattedMessage)
|
||||
default:
|
||||
logger.ErrorfContext(ctx, formattedMessage)
|
||||
}
|
||||
|
||||
// 记录到数据库(所有级别都记录 trace_id)
|
||||
// 清理 message 和 attributes,防止注入
|
||||
if ctx != nil {
|
||||
sanitizedMessage := sanitizeLogString(message, 1000) // 限制消息长度
|
||||
sanitizedAttributes := sanitizeAttributes(attributes)
|
||||
recordToDatabaseWithLevel(ctx, level, module, sanitizedMessage, sanitizedAttributes)
|
||||
}
|
||||
}
|
||||
|
||||
// recordToDatabaseHTTPWithLevel 将日志记录到数据库(HTTP context,支持所有级别)
|
||||
func recordToDatabaseHTTPWithLevel(ctx http.Context, level, module, message string, attributes map[string]any) {
|
||||
var contextJSON string
|
||||
if len(attributes) > 0 {
|
||||
if data, err := json.Marshal(attributes); err == nil {
|
||||
contextJSON = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
traceID := traceid.FromHTTPContext(ctx)
|
||||
if traceID == "" {
|
||||
traceID = traceid.EnsureHTTPContext(ctx, "")
|
||||
}
|
||||
|
||||
log := models.SystemLog{
|
||||
Level: level,
|
||||
Module: module,
|
||||
TraceID: traceID,
|
||||
Message: message,
|
||||
Context: contextJSON,
|
||||
IP: ctx.Request().Ip(),
|
||||
UserAgent: ctx.Request().Header("User-Agent", ""),
|
||||
}
|
||||
|
||||
_ = facades.Orm().Query().Create(&log)
|
||||
}
|
||||
|
||||
// recordToDatabaseWithLevel 将日志记录到数据库(标准 context,支持所有级别)
|
||||
func recordToDatabaseWithLevel(ctx context.Context, level, module, message string, attributes map[string]any) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var contextJSON string
|
||||
if len(attributes) > 0 {
|
||||
if data, err := json.Marshal(attributes); err == nil {
|
||||
contextJSON = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
traceID := traceid.FromContext(ctx)
|
||||
if traceID == "" {
|
||||
var newCtx context.Context
|
||||
newCtx, traceID = traceid.EnsureContext(ctx)
|
||||
ctx = newCtx
|
||||
}
|
||||
|
||||
log := models.SystemLog{
|
||||
Level: level,
|
||||
Module: module,
|
||||
TraceID: traceID,
|
||||
Message: message,
|
||||
Context: contextJSON,
|
||||
}
|
||||
|
||||
_ = facades.Orm().Query().Create(&log)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package errorlog
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// sanitizeLogString 清理日志字符串,防止日志注入攻击
|
||||
// 移除或转义危险字符:
|
||||
// - 换行符 (\n, \r) - 可能用于伪造多行日志
|
||||
// - 制表符 (\t) - 可能用于格式化攻击
|
||||
// - 控制字符 - 可能用于隐藏攻击痕迹
|
||||
//
|
||||
// 参数:
|
||||
// - s: 要清理的字符串
|
||||
// - maxLength: 最大长度限制(0 表示不限制)
|
||||
//
|
||||
// 返回:
|
||||
// - 清理后的字符串
|
||||
func sanitizeLogString(s string, maxLength int) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
|
||||
// 移除控制字符和换行符
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(s))
|
||||
|
||||
for _, r := range s {
|
||||
// 允许打印字符和空格,移除控制字符
|
||||
if unicode.IsPrint(r) || r == ' ' {
|
||||
builder.WriteRune(r)
|
||||
} else {
|
||||
// 将控制字符替换为转义序列(可选,或直接移除)
|
||||
// 这里选择移除,因为转义序列也可能被利用
|
||||
}
|
||||
}
|
||||
|
||||
result := builder.String()
|
||||
|
||||
// 限制长度
|
||||
if maxLength > 0 && len(result) > maxLength {
|
||||
result = result[:maxLength] + "...[truncated]"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// sanitizeLogLevel 验证和清理日志级别,防止伪造
|
||||
// 只允许预定义的日志级别
|
||||
func sanitizeLogLevel(level string) string {
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
|
||||
// 只允许预定义的级别
|
||||
allowedLevels := map[string]string{
|
||||
"error": "error",
|
||||
"warning": "warning",
|
||||
"warn": "warning", // 兼容 warn
|
||||
"info": "info",
|
||||
"debug": "debug",
|
||||
}
|
||||
|
||||
if validLevel, ok := allowedLevels[level]; ok {
|
||||
return validLevel
|
||||
}
|
||||
|
||||
// 如果级别无效,默认返回 error(最安全的级别)
|
||||
return "error"
|
||||
}
|
||||
|
||||
// sanitizeLogArgs 清理日志参数,防止注入攻击
|
||||
// 将参数转换为安全的字符串表示
|
||||
func sanitizeLogArgs(args ...any) []any {
|
||||
if len(args) == 0 {
|
||||
return args
|
||||
}
|
||||
|
||||
sanitized := make([]any, len(args))
|
||||
for i, arg := range args {
|
||||
switch v := arg.(type) {
|
||||
case string:
|
||||
// 字符串参数:清理控制字符
|
||||
sanitized[i] = sanitizeLogString(v, 0)
|
||||
case []byte:
|
||||
// 字节数组:转换为字符串后清理
|
||||
sanitized[i] = sanitizeLogString(string(v), 0)
|
||||
case error:
|
||||
// 错误对象:清理错误消息
|
||||
if v != nil {
|
||||
sanitized[i] = sanitizeLogString(v.Error(), 0)
|
||||
} else {
|
||||
sanitized[i] = "<nil>"
|
||||
}
|
||||
default:
|
||||
// 其他类型:保持原样(数字、布尔等通常是安全的)
|
||||
sanitized[i] = arg
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// sanitizeLogFormat 清理日志格式字符串,防止格式字符串注入
|
||||
// 移除危险的控制字符,但保留格式占位符(%v, %s 等)
|
||||
func sanitizeLogFormat(format string) string {
|
||||
if format == "" {
|
||||
return format
|
||||
}
|
||||
|
||||
// 移除换行符和控制字符,但保留格式占位符
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(format))
|
||||
|
||||
for i := 0; i < len(format); i++ {
|
||||
r := rune(format[i])
|
||||
|
||||
// 允许打印字符、空格和格式占位符
|
||||
if unicode.IsPrint(r) || r == ' ' {
|
||||
builder.WriteRune(r)
|
||||
} else if r == '\n' || r == '\r' || r == '\t' {
|
||||
// 将换行符和制表符替换为空格
|
||||
builder.WriteRune(' ')
|
||||
}
|
||||
// 其他控制字符直接移除
|
||||
}
|
||||
|
||||
result := builder.String()
|
||||
|
||||
// 限制格式字符串长度(防止过长的格式字符串攻击)
|
||||
maxFormatLength := 1000
|
||||
if len(result) > maxFormatLength {
|
||||
result = result[:maxFormatLength] + "...[truncated]"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// sanitizeAttributes 清理 attributes map,防止注入攻击
|
||||
func sanitizeAttributes(attributes map[string]any) map[string]any {
|
||||
if attributes == nil || len(attributes) == 0 {
|
||||
return attributes
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(attributes))
|
||||
for key, value := range attributes {
|
||||
// 清理 key
|
||||
sanitizedKey := sanitizeLogString(key, 100)
|
||||
|
||||
// 清理 value
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
sanitized[sanitizedKey] = sanitizeLogString(v, 1000)
|
||||
case []byte:
|
||||
sanitized[sanitizedKey] = sanitizeLogString(string(v), 1000)
|
||||
case error:
|
||||
if v != nil {
|
||||
sanitized[sanitizedKey] = sanitizeLogString(v.Error(), 1000)
|
||||
} else {
|
||||
sanitized[sanitizedKey] = "<nil>"
|
||||
}
|
||||
case map[string]any:
|
||||
// 递归清理嵌套 map
|
||||
sanitized[sanitizedKey] = sanitizeAttributes(v)
|
||||
case []any:
|
||||
// 清理数组
|
||||
sanitizedArray := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
sanitizedArray[i] = sanitizeLogString(str, 500)
|
||||
} else {
|
||||
sanitizedArray[i] = item
|
||||
}
|
||||
}
|
||||
sanitized[sanitizedKey] = sanitizedArray
|
||||
default:
|
||||
// 其他类型(数字、布尔等)保持原样
|
||||
sanitized[sanitizedKey] = value
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/database/orm"
|
||||
"github.com/goravel/framework/facades"
|
||||
)
|
||||
|
||||
// ApplyFulltextSearch 应用全文索引搜索条件
|
||||
// column: 要搜索的字段名(如 "request", "content" 等)
|
||||
// keyword: 搜索关键词
|
||||
// query: ORM 查询对象
|
||||
// 返回: 应用了搜索条件的查询对象
|
||||
func ApplyFulltextSearch(query orm.Query, column, keyword string) orm.Query {
|
||||
if keyword == "" {
|
||||
return query
|
||||
}
|
||||
|
||||
// 获取数据库类型
|
||||
dbConnection := facades.Config().GetString("database.default", "sqlite")
|
||||
isPostgreSQL := dbConnection == "postgres"
|
||||
|
||||
// 判断是否应该使用全文索引
|
||||
// 对于短词(少于3个字符)或包含特殊字符的搜索,使用 LIKE/ILIKE
|
||||
// 对于长词,使用全文索引
|
||||
useFulltext := len(keyword) >= 3
|
||||
|
||||
// 检查是否包含特殊字符(逗号、引号、括号等)
|
||||
specialChars := []string{",", "\"", "'", "(", ")", "[", "]", "{", "}", ":", ";", "=", "+", "-", "*", "/", "\\", "|", "&", "%", "$", "#", "@", "!", "?", "<", ">", "~", "`"}
|
||||
for _, char := range specialChars {
|
||||
if strings.Contains(keyword, char) {
|
||||
useFulltext = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if useFulltext {
|
||||
if isPostgreSQL {
|
||||
// PostgreSQL: 使用 pg_trgm 相似度搜索(需要已创建 GIN 索引)
|
||||
// 使用 % 操作符进行相似度匹配,阈值默认 0.3
|
||||
return query.Where(column+" % ?", keyword)
|
||||
} else {
|
||||
// MySQL: 使用 ngram 全文索引
|
||||
// 注意:需要确保字段已创建全文索引,索引名格式为 ft_{column}
|
||||
return query.Where("MATCH("+column+") AGAINST(? IN BOOLEAN MODE)", keyword)
|
||||
}
|
||||
} else {
|
||||
// 短词或包含特殊字符:使用 LIKE/ILIKE
|
||||
if isPostgreSQL {
|
||||
// PostgreSQL: 使用 ILIKE(不区分大小写)
|
||||
return query.Where(column+" ILIKE ?", "%"+keyword+"%")
|
||||
} else {
|
||||
// MySQL: 使用 LIKE
|
||||
return query.Where(column+" LIKE ?", "%"+keyword+"%")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldUseFulltextIndex 判断是否应该使用全文索引
|
||||
// keyword: 搜索关键词
|
||||
// 返回: true 表示应该使用全文索引,false 表示使用 LIKE/ILIKE
|
||||
func ShouldUseFulltextIndex(keyword string) bool {
|
||||
if len(keyword) < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否包含特殊字符
|
||||
specialChars := []string{",", "\"", "'", "(", ")", "[", "]", "{", "}", ":", ";", "=", "+", "-", "*", "/", "\\", "|", "&", "%", "$", "#", "@", "!", "?", "<", ">", "~", "`"}
|
||||
for _, char := range specialChars {
|
||||
if strings.Contains(keyword, char) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsPostgreSQL 判断当前数据库是否为 PostgreSQL
|
||||
func IsPostgreSQL() bool {
|
||||
dbConnection := facades.Config().GetString("database.default", "sqlite")
|
||||
return dbConnection == "postgres"
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
gormDBInstance *gorm.DB
|
||||
gormDBOnce sync.Once
|
||||
gormDBErr error
|
||||
)
|
||||
|
||||
// GetGormDB 获取原生 GORM DB 实例
|
||||
// 通过框架的 Query().Instance() 方法获取 GORM DB 实例
|
||||
// 使用单例模式,确保只获取一次
|
||||
func GetGormDB() (*gorm.DB, error) {
|
||||
gormDBOnce.Do(func() {
|
||||
gormDBInstance, gormDBErr = tryGetGormFromFramework()
|
||||
if gormDBErr == nil && gormDBInstance != nil {
|
||||
facades.Log().Infof("成功从框架获取 GORM DB 实例")
|
||||
}
|
||||
})
|
||||
return gormDBInstance, gormDBErr
|
||||
}
|
||||
|
||||
// tryGetGormFromFramework 尝试从框架的 ORM 获取底层 GORM DB 实例
|
||||
// 通过 Query().Instance() 方法获取 GORM DB 实例
|
||||
func tryGetGormFromFramework() (*gorm.DB, error) {
|
||||
orm := facades.Orm()
|
||||
if orm == nil {
|
||||
return nil, fmt.Errorf("框架 ORM 不可用")
|
||||
}
|
||||
|
||||
ormValue := reflect.ValueOf(orm)
|
||||
|
||||
// 通过 Query() 方法获取查询对象,然后调用 Instance() 方法获取 GORM DB
|
||||
if queryMethod := ormValue.MethodByName("Query"); queryMethod.IsValid() {
|
||||
queryResults := queryMethod.Call(nil)
|
||||
if len(queryResults) > 0 {
|
||||
queryValue := queryResults[0]
|
||||
if queryValue.Kind() == reflect.Interface && !queryValue.IsNil() {
|
||||
if queryElem := queryValue.Elem(); queryElem.IsValid() {
|
||||
// 调用 Instance() 方法
|
||||
if instanceMethod := queryElem.MethodByName("Instance"); instanceMethod.IsValid() {
|
||||
methodType := instanceMethod.Type()
|
||||
if methodType.NumIn() == 0 && methodType.NumOut() > 0 {
|
||||
results := instanceMethod.Call(nil)
|
||||
if len(results) > 0 {
|
||||
result := results[0]
|
||||
// 检查返回类型是否为 *gorm.DB
|
||||
if result.Type().String() == "*gorm.DB" && !result.IsNil() {
|
||||
if db, ok := result.Interface().(*gorm.DB); ok && db != nil {
|
||||
facades.Log().Infof("成功通过 Instance() 方法从 Query 对象获取 GORM DB 实例")
|
||||
return db, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("无法从框架获取 GORM DB 实例")
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package utils
|
||||
|
||||
import "html"
|
||||
|
||||
// EscapeString 转义 HTML 特殊字符,防止 XSS 攻击
|
||||
// 将 < > & " ' 等字符转换为 HTML 实体
|
||||
//
|
||||
// 示例:
|
||||
// EscapeString("<script>alert('XSS')</script>")
|
||||
// 返回: "<script>alert('XSS')</script>"
|
||||
func EscapeString(s string) string {
|
||||
return html.EscapeString(s)
|
||||
}
|
||||
|
||||
// EscapeBytes 转义字节数组中的 HTML 特殊字符
|
||||
func EscapeBytes(b []byte) []byte {
|
||||
return []byte(html.EscapeString(string(b)))
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPLocation IP 地理位置信息
|
||||
type IPLocation struct {
|
||||
Country string `json:"country"` // 国家
|
||||
Region string `json:"region"` // 省份/州
|
||||
City string `json:"city"` // 城市
|
||||
ISP string `json:"isp"` // ISP
|
||||
CountryCode string `json:"countryCode"` // 国家代码
|
||||
}
|
||||
|
||||
// GetIPLocation 根据 IP 地址获取地理位置信息
|
||||
// 使用 ip-api.com 免费 API(无需 API Key,有速率限制)
|
||||
// 如果查询失败,返回空字符串,不影响主流程
|
||||
func GetIPLocation(ip string) string {
|
||||
if ip == "" || ip == "127.0.0.1" || ip == "::1" || strings.HasPrefix(ip, "192.168.") || strings.HasPrefix(ip, "10.") || strings.HasPrefix(ip, "172.") {
|
||||
return "内网IP"
|
||||
}
|
||||
|
||||
// 使用 ip-api.com 免费 API
|
||||
// 格式:http://ip-api.com/json/{ip}?fields=status,message,country,regionName,city,isp,countryCode
|
||||
url := fmt.Sprintf("http://ip-api.com/json/%s?fields=status,message,country,regionName,city,isp,countryCode&lang=zh-CN", ip)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 3 * time.Second, // 3秒超时,避免阻塞
|
||||
}
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Country string `json:"country"`
|
||||
RegionName string `json:"regionName"`
|
||||
City string `json:"city"`
|
||||
ISP string `json:"isp"`
|
||||
CountryCode string `json:"countryCode"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if result.Status != "success" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 构建位置字符串:国家 省份 城市
|
||||
var locationParts []string
|
||||
if result.Country != "" {
|
||||
locationParts = append(locationParts, result.Country)
|
||||
}
|
||||
if result.RegionName != "" {
|
||||
locationParts = append(locationParts, result.RegionName)
|
||||
}
|
||||
if result.City != "" {
|
||||
locationParts = append(locationParts, result.City)
|
||||
}
|
||||
|
||||
location := strings.Join(locationParts, " ")
|
||||
if location == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 如果位置信息太长,截断
|
||||
if len(location) > 100 {
|
||||
location = location[:100]
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
// GetIPLocationAsync 异步获取 IP 地理位置信息
|
||||
// 用于不阻塞主流程的场景
|
||||
//
|
||||
// 安全特性:
|
||||
// - 使用 context 超时控制(默认 10 秒)
|
||||
// - 添加 panic recovery,防止 goroutine 崩溃
|
||||
// - callback 执行超时保护(默认 5 秒),防止阻塞
|
||||
//
|
||||
// 参数:
|
||||
// - ip: IP 地址
|
||||
// - callback: 回调函数,接收位置信息
|
||||
func GetIPLocationAsync(ip string, callback func(location string)) {
|
||||
GetIPLocationAsyncWithTimeout(ip, callback, 10*time.Second, 5*time.Second)
|
||||
}
|
||||
|
||||
// GetIPLocationAsyncWithTimeout 异步获取 IP 地理位置信息(带超时控制)
|
||||
//
|
||||
// 参数:
|
||||
// - ip: IP 地址
|
||||
// - callback: 回调函数,接收位置信息
|
||||
// - locationTimeout: IP 查询超时时间(默认 10 秒)
|
||||
// - callbackTimeout: 回调函数执行超时时间(默认 5 秒)
|
||||
func GetIPLocationAsyncWithTimeout(ip string, callback func(location string), locationTimeout, callbackTimeout time.Duration) {
|
||||
// 设置默认超时时间
|
||||
if locationTimeout <= 0 {
|
||||
locationTimeout = 10 * time.Second
|
||||
}
|
||||
if callbackTimeout <= 0 {
|
||||
callbackTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
go func() {
|
||||
// 添加 panic recovery,防止 goroutine 崩溃导致程序退出
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// 记录 panic 但不影响主流程
|
||||
// 注意:这里不能使用 facades.Log(),因为可能没有初始化 context
|
||||
// 如果需要记录日志,可以通过参数传入 logger
|
||||
}
|
||||
}()
|
||||
|
||||
// 创建带超时的 context(确保总超时时间不超过 locationTimeout)
|
||||
// 注意:GetIPLocation 内部已有 3 秒 HTTP 超时,这里设置外层超时作为兜底
|
||||
ctx, cancel := context.WithTimeout(context.Background(), locationTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 在 goroutine 中执行 IP 查询
|
||||
locationChan := make(chan string, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// panic 时发送空字符串,避免阻塞
|
||||
select {
|
||||
case locationChan <- "":
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
location := GetIPLocation(ip)
|
||||
// 非阻塞发送,避免 goroutine 泄露
|
||||
select {
|
||||
case locationChan <- location:
|
||||
case <-ctx.Done():
|
||||
// 如果已经超时,直接返回,不发送结果
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待结果或超时
|
||||
var location string
|
||||
select {
|
||||
case location = <-locationChan:
|
||||
// 成功获取位置信息
|
||||
case <-ctx.Done():
|
||||
// 超时,返回空字符串
|
||||
// 注意:内部的 goroutine 可能仍在执行,但由于 HTTP client 有 3 秒超时,会很快结束
|
||||
location = ""
|
||||
}
|
||||
|
||||
// 执行回调(带超时保护)
|
||||
if callback != nil {
|
||||
callbackCtx, callbackCancel := context.WithTimeout(context.Background(), callbackTimeout)
|
||||
defer callbackCancel()
|
||||
|
||||
callbackDone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// callback 中发生 panic,静默处理
|
||||
}
|
||||
callbackDone <- struct{}{}
|
||||
}()
|
||||
callback(location)
|
||||
}()
|
||||
|
||||
// 等待 callback 完成或超时
|
||||
select {
|
||||
case <-callbackDone:
|
||||
// callback 正常完成
|
||||
case <-callbackCtx.Done():
|
||||
// callback 执行超时,不等待(防止阻塞)
|
||||
// 注意:callback 仍在执行,但不会阻塞当前 goroutine
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/support/str"
|
||||
|
||||
apperrors "goravel/app/errors"
|
||||
)
|
||||
|
||||
// IsIPInBlacklist 检查IP是否在黑名单中
|
||||
// ip: 要检查的IP地址
|
||||
// blacklistIPs: 黑名单IP字符串,支持:
|
||||
// - 单个IP: 192.168.1.1
|
||||
// - 多个IP(逗号分隔): 192.168.1.1,192.168.1.2
|
||||
// - CIDR格式: 192.168.0.0/24
|
||||
// - IP范围: 192.168.1.1-192.168.1.100
|
||||
func IsIPInBlacklist(ip string, blacklistIPs string) bool {
|
||||
if blacklistIPs == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 解析IP地址
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 分割多个IP(支持逗号分隔)
|
||||
ipList := strings.SplitSeq(blacklistIPs, ",")
|
||||
for blacklistIP := range ipList {
|
||||
blacklistIP = str.Of(blacklistIP).Trim().String()
|
||||
if str.Of(blacklistIP).IsEmpty() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查单个IP
|
||||
if blacklistIP == ip {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查CIDR格式 (192.168.0.0/24)
|
||||
if str.Of(blacklistIP).Contains("/") {
|
||||
_, ipNet, err := net.ParseCIDR(blacklistIP)
|
||||
if err == nil && ipNet.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IP范围格式 (192.168.1.1-192.168.1.100)
|
||||
if str.Of(blacklistIP).Contains("-") {
|
||||
parts := str.Of(blacklistIP).Split("-")
|
||||
if len(parts) == 2 {
|
||||
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
|
||||
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
|
||||
if startIP != nil && endIP != nil {
|
||||
if isIPInRange(parsedIP, startIP, endIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isIPInRange 检查IP是否在指定范围内
|
||||
func isIPInRange(ip, startIP, endIP net.IP) bool {
|
||||
ipBytes := ip.To4()
|
||||
startBytes := startIP.To4()
|
||||
endBytes := endIP.To4()
|
||||
|
||||
if ipBytes == nil || startBytes == nil || endBytes == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 将IP地址转换为32位整数进行比较
|
||||
ipInt := uint32(ipBytes[0])<<24 | uint32(ipBytes[1])<<16 | uint32(ipBytes[2])<<8 | uint32(ipBytes[3])
|
||||
startInt := uint32(startBytes[0])<<24 | uint32(startBytes[1])<<16 | uint32(startBytes[2])<<8 | uint32(startBytes[3])
|
||||
endInt := uint32(endBytes[0])<<24 | uint32(endBytes[1])<<16 | uint32(endBytes[2])<<8 | uint32(endBytes[3])
|
||||
|
||||
return ipInt >= startInt && ipInt <= endInt
|
||||
}
|
||||
|
||||
// ValidateBlacklistIP 验证黑名单IP格式
|
||||
// 返回业务错误类型,如果格式正确返回 nil
|
||||
func ValidateBlacklistIP(ipStr string) error {
|
||||
if ipStr == "" {
|
||||
return apperrors.ErrIPAddressRequired
|
||||
}
|
||||
|
||||
ipList := strings.SplitSeq(ipStr, ",")
|
||||
for ip := range ipList {
|
||||
ip = str.Of(ip).Trim().String()
|
||||
if str.Of(ip).IsEmpty() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查CIDR格式
|
||||
if str.Of(ip).Contains("/") {
|
||||
_, _, err := net.ParseCIDR(ip)
|
||||
if err != nil {
|
||||
return apperrors.ErrInvalidCIDRFormat.WithMessage(fmt.Sprintf("CIDR格式错误: %s", ip))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查IP范围格式
|
||||
if str.Of(ip).Contains("-") {
|
||||
parts := str.Of(ip).Split("-")
|
||||
if len(parts) != 2 {
|
||||
return apperrors.ErrInvalidIPRangeFormat.WithMessage(fmt.Sprintf("IP范围格式错误: %s (格式应为: 192.168.1.1-192.168.1.100)", ip))
|
||||
}
|
||||
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
|
||||
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
|
||||
if startIP == nil {
|
||||
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("起始IP格式错误: %s", parts[0]))
|
||||
}
|
||||
if endIP == nil {
|
||||
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("结束IP格式错误: %s", parts[1]))
|
||||
}
|
||||
// 验证范围是否有效(结束IP应该大于等于起始IP)
|
||||
if !isIPInRange(endIP, startIP, endIP) && !endIP.Equal(startIP) {
|
||||
// 简单比较:检查结束IP是否大于起始IP
|
||||
startBytes := startIP.To4()
|
||||
endBytes := endIP.To4()
|
||||
if startBytes != nil && endBytes != nil {
|
||||
for i := range 4 {
|
||||
if endBytes[i] < startBytes[i] {
|
||||
return apperrors.ErrInvalidIPRangeOrder.WithMessage(fmt.Sprintf("结束IP必须大于等于起始IP: %s", ip))
|
||||
}
|
||||
if endBytes[i] > startBytes[i] {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查单个IP格式
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("IP地址格式错误: %s", ip))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FormatIPRange 格式化IP范围显示
|
||||
func FormatIPRange(ipStr string) string {
|
||||
if str.Of(ipStr).Contains("-") {
|
||||
parts := str.Of(ipStr).Split("-")
|
||||
if len(parts) == 2 {
|
||||
return fmt.Sprintf("%s ~ %s", str.Of(parts[0]).Trim().String(), str.Of(parts[1]).Trim().String())
|
||||
}
|
||||
}
|
||||
return ipStr
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/http"
|
||||
"github.com/goravel/framework/facades"
|
||||
)
|
||||
|
||||
// ParseAcceptLanguage 解析 Accept-Language 请求头
|
||||
// 格式: "zh-CN,zh;q=0.9,en;q=0.8" 或 "en-US,en;q=0.9"
|
||||
// 返回: 语言代码("cn" 或 "en"),如果无法解析则返回空字符串
|
||||
func ParseAcceptLanguage(acceptLanguage string) string {
|
||||
if acceptLanguage == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 分割语言列表
|
||||
languages := strings.Split(acceptLanguage, ",")
|
||||
if len(languages) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 取第一个语言
|
||||
firstLang := strings.TrimSpace(languages[0])
|
||||
|
||||
// 移除质量值(如果有)
|
||||
if idx := strings.Index(firstLang, ";"); idx != -1 {
|
||||
firstLang = firstLang[:idx]
|
||||
}
|
||||
|
||||
// 转换为小写并提取语言代码
|
||||
firstLang = strings.ToLower(strings.TrimSpace(firstLang))
|
||||
|
||||
// 处理语言代码(如 zh-CN -> cn, en-US -> en)
|
||||
if strings.HasPrefix(firstLang, "zh") {
|
||||
return "cn"
|
||||
}
|
||||
if strings.HasPrefix(firstLang, "en") {
|
||||
return "en"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCurrentLanguage 获取当前请求的语言(从 HTTP Context)
|
||||
// 优先从请求头 Accept-Language 获取,其次从查询参数获取,最后使用默认语言
|
||||
func GetCurrentLanguage(ctx http.Context) string {
|
||||
// 优先从请求头 Accept-Language 获取语言
|
||||
acceptLanguage := ctx.Request().Header("Accept-Language", "")
|
||||
lang := ParseAcceptLanguage(acceptLanguage)
|
||||
|
||||
// 如果请求头没有,尝试从查询参数获取
|
||||
if lang == "" {
|
||||
lang = ctx.Request().Input("lang")
|
||||
}
|
||||
|
||||
// 如果都没有,使用默认语言
|
||||
if lang == "" {
|
||||
lang = facades.Config().GetString("app.locale")
|
||||
}
|
||||
|
||||
// 验证并规范化语言代码
|
||||
return NormalizeLanguage(lang)
|
||||
}
|
||||
|
||||
// NormalizeLanguage 验证并规范化语言代码
|
||||
// 只支持 "cn" 和 "en",其他值会返回默认语言
|
||||
func NormalizeLanguage(lang string) string {
|
||||
// 验证语言是否支持(只支持 cn 和 en)
|
||||
if lang != "cn" && lang != "en" {
|
||||
return facades.Config().GetString("app.locale", "cn")
|
||||
}
|
||||
return lang
|
||||
}
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// LockResult 锁操作结果
|
||||
type LockResult struct {
|
||||
Acquired bool // 是否成功获取锁
|
||||
Error error // 错误信息
|
||||
Client *redis.Client // Redis 客户端(如果使用 Redis)
|
||||
}
|
||||
|
||||
// TryAcquireLock 尝试获取分布式锁(原子操作)
|
||||
// lockKey: 锁的键名
|
||||
// lockValue: 锁的值(用于标识锁的拥有者)
|
||||
// ttl: 锁的过期时间
|
||||
// 返回 LockResult,包含是否成功获取锁和错误信息
|
||||
func TryAcquireLock(lockKey, lockValue string, ttl time.Duration) *LockResult {
|
||||
result := &LockResult{
|
||||
Acquired: false,
|
||||
}
|
||||
|
||||
// 优先使用 Redis SETNX 原子操作
|
||||
redisClient, err := GetRedisClient("default")
|
||||
if err != nil {
|
||||
facades.Log().Warningf("获取 Redis 客户端失败,降级到缓存锁: %v", err)
|
||||
// 降级到普通缓存检查(不保证原子性,但至少提供基本保护)
|
||||
return tryAcquireLockWithCache(lockKey, lockValue, ttl)
|
||||
}
|
||||
|
||||
// 使用 Redis SETNX 原子操作(SET key value NX EX seconds)
|
||||
// SetNX 会自动设置过期时间,如果键已存在且未过期,返回 false
|
||||
ctx := context.Background()
|
||||
acquired, err := redisClient.SetNX(ctx, lockKey, lockValue, ttl).Result()
|
||||
if err != nil {
|
||||
facades.Log().Errorf("Redis 获取锁失败: key=%s, error=%v", lockKey, err)
|
||||
result.Error = fmt.Errorf("获取锁失败: %v", err)
|
||||
// 注意:使用公共 Redis 客户端,不需要手动关闭
|
||||
return result
|
||||
}
|
||||
|
||||
if !acquired {
|
||||
// 锁已被占用,检查锁的剩余过期时间
|
||||
ttlResult, _ := redisClient.TTL(ctx, lockKey).Result()
|
||||
if ttlResult <= 0 {
|
||||
// 锁已过期,删除它并重试
|
||||
redisClient.Del(ctx, lockKey)
|
||||
// 重试一次
|
||||
acquired, err = redisClient.SetNX(ctx, lockKey, lockValue, ttl).Result()
|
||||
if err != nil {
|
||||
facades.Log().Errorf("Redis 重试获取锁失败: key=%s, error=%v", lockKey, err)
|
||||
result.Error = fmt.Errorf("获取锁失败: %v", err)
|
||||
// 注意:使用公共 Redis 客户端,不需要手动关闭
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.Acquired = acquired
|
||||
if acquired {
|
||||
result.Client = redisClient // 只有获取成功才保留客户端
|
||||
} else {
|
||||
result.Client = nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ReleaseLock 释放锁
|
||||
// 注意:只有锁的拥有者才能释放锁(通过 lockValue 验证)
|
||||
func ReleaseLock(lockKey, lockValue string, client *redis.Client) error {
|
||||
if client == nil {
|
||||
// 如果没有 Redis 客户端,使用缓存删除
|
||||
_ = facades.Cache().Forget(lockKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
// 使用 Lua 脚本确保只有锁的拥有者才能释放锁
|
||||
script := `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`
|
||||
_, err := client.Eval(ctx, script, []string{lockKey}, lockValue).Result()
|
||||
if err != nil {
|
||||
// 如果 Lua 脚本执行失败,尝试直接删除(降级处理)
|
||||
return client.Del(ctx, lockKey).Err()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseLockClient 关闭锁的 Redis 客户端
|
||||
// 注意:由于现在使用公共 Redis 客户端池,通常不需要手动关闭
|
||||
// 此函数保留用于兼容性,但不会真正关闭客户端(客户端由连接池管理)
|
||||
func CloseLockClient(client *redis.Client) {
|
||||
// 使用公共 Redis 客户端池,不需要手动关闭
|
||||
// 客户端由连接池统一管理
|
||||
}
|
||||
|
||||
// tryAcquireLockWithCache 使用缓存实现锁(降级方案,不保证原子性)
|
||||
// 注意:由于缓存操作的检查-设置不是原子的,在高并发下可能仍有竞态条件
|
||||
// 但至少可以提供基本的保护
|
||||
func tryAcquireLockWithCache(lockKey, lockValue string, ttl time.Duration) *LockResult {
|
||||
result := &LockResult{
|
||||
Acquired: false,
|
||||
}
|
||||
|
||||
// 检查是否已有锁
|
||||
var cachedValue string
|
||||
cacheErr := facades.Cache().Get(lockKey, &cachedValue)
|
||||
// 如果 Get 返回 nil(成功)且值不为空,说明锁已存在
|
||||
if cacheErr == nil && cachedValue != "" {
|
||||
// 锁已存在
|
||||
result.Error = fmt.Errorf("锁已被占用")
|
||||
return result
|
||||
}
|
||||
|
||||
// 设置锁(使用 Put,如果键已存在会覆盖,但至少可以防止大部分重复请求)
|
||||
if err := facades.Cache().Put(lockKey, lockValue, ttl); err != nil {
|
||||
facades.Log().Errorf("设置缓存锁失败: key=%s, error=%v", lockKey, err)
|
||||
result.Error = fmt.Errorf("设置锁失败: %v", err)
|
||||
return result
|
||||
}
|
||||
|
||||
// 再次检查,确保锁设置成功(双重检查,减少竞态条件的影响)
|
||||
// 短暂延迟,让其他并发请求有机会检测到锁
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
var verifyValue string
|
||||
if facades.Cache().Get(lockKey, &verifyValue) == nil {
|
||||
if verifyValue == lockValue {
|
||||
// 锁设置成功且值匹配
|
||||
result.Acquired = true
|
||||
return result
|
||||
}
|
||||
// 值不匹配,说明被其他请求覆盖了(竞态条件)
|
||||
result.Error = fmt.Errorf("锁已被占用")
|
||||
return result
|
||||
}
|
||||
|
||||
// 锁设置后无法验证,可能是缓存问题
|
||||
result.Error = fmt.Errorf("设置锁失败")
|
||||
return result
|
||||
}
|
||||
|
||||
// LockGuard 锁保护器,自动管理锁的生命周期
|
||||
type LockGuard struct {
|
||||
lockKey string
|
||||
lockValue string
|
||||
client *redis.Client
|
||||
acquired bool
|
||||
}
|
||||
|
||||
// AcquireLock 获取锁
|
||||
// lockKey: 锁的键名(会自动添加用户ID前缀,格式:lockKey:userID)
|
||||
// userID: 用户ID(用于生成唯一的锁键)
|
||||
// ttl: 锁的过期时间
|
||||
// 返回 LockGuard 和错误,如果锁已被占用,返回 ErrLockAcquired 错误
|
||||
func AcquireLock(lockKey string, userID uint, ttl time.Duration) (*LockGuard, error) {
|
||||
// 生成完整的锁键和值
|
||||
fullLockKey := fmt.Sprintf("%s:%d", lockKey, userID)
|
||||
lockValue := fmt.Sprintf("%d_%d", userID, time.Now().Unix())
|
||||
|
||||
// 尝试获取锁
|
||||
lockResult := TryAcquireLock(fullLockKey, lockValue, ttl)
|
||||
if lockResult.Error != nil {
|
||||
return nil, lockResult.Error
|
||||
}
|
||||
if !lockResult.Acquired {
|
||||
return nil, fmt.Errorf("锁已被占用")
|
||||
}
|
||||
|
||||
return &LockGuard{
|
||||
lockKey: fullLockKey,
|
||||
lockValue: lockValue,
|
||||
client: lockResult.Client,
|
||||
acquired: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Release 释放锁
|
||||
func (g *LockGuard) Release() {
|
||||
if !g.acquired {
|
||||
return
|
||||
}
|
||||
|
||||
if g.client != nil {
|
||||
if err := ReleaseLock(g.lockKey, g.lockValue, g.client); err != nil {
|
||||
facades.Log().Errorf("释放 Redis 锁失败: key=%s, error=%v", g.lockKey, err)
|
||||
}
|
||||
CloseLockClient(g.client)
|
||||
g.client = nil // 防止重复关闭
|
||||
} else {
|
||||
// 使用缓存时,直接删除
|
||||
_ = facades.Cache().Forget(g.lockKey)
|
||||
}
|
||||
g.acquired = false
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goravel/framework/contracts/http"
|
||||
"github.com/goravel/framework/facades"
|
||||
|
||||
"goravel/app/utils/traceid"
|
||||
)
|
||||
|
||||
// DebugfHTTP logs a debug message and automatically attaches trace_id from the http context.
|
||||
// Debug messages are only shown when APP_DEBUG=true
|
||||
func DebugfHTTP(ctx http.Context, format string, args ...any) {
|
||||
if ctx == nil {
|
||||
facades.Log().Debugf(format, args...)
|
||||
return
|
||||
}
|
||||
|
||||
trace := traceid.FromHTTPContext(ctx)
|
||||
facades.Log().Debugf(prependTrace(trace, format), args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message without any context.
|
||||
// Debug messages are only shown when APP_DEBUG=true
|
||||
func Debugf(format string, args ...any) {
|
||||
facades.Log().Debugf(format, args...)
|
||||
}
|
||||
|
||||
// InfofHTTP logs an info message and automatically attaches trace_id from the http context.
|
||||
func InfofHTTP(ctx http.Context, format string, args ...any) {
|
||||
if ctx == nil {
|
||||
facades.Log().Infof(format, args...)
|
||||
return
|
||||
}
|
||||
|
||||
trace := traceid.FromHTTPContext(ctx)
|
||||
facades.Log().Infof(prependTrace(trace, format), args...)
|
||||
}
|
||||
|
||||
// WarnfHTTP logs a warning and automatically attaches trace_id from the http context.
|
||||
func WarnfHTTP(ctx http.Context, format string, args ...any) {
|
||||
if ctx == nil {
|
||||
facades.Log().Warningf(format, args...)
|
||||
return
|
||||
}
|
||||
|
||||
trace := traceid.FromHTTPContext(ctx)
|
||||
facades.Log().Warningf(prependTrace(trace, format), args...)
|
||||
}
|
||||
|
||||
// ErrorfHTTP logs an error and automatically attaches trace_id from the http context.
|
||||
func ErrorfHTTP(ctx http.Context, format string, args ...any) {
|
||||
if ctx == nil {
|
||||
facades.Log().Errorf(format, args...)
|
||||
return
|
||||
}
|
||||
|
||||
trace := traceid.FromHTTPContext(ctx)
|
||||
facades.Log().Errorf(prependTrace(trace, format), args...)
|
||||
}
|
||||
|
||||
// ErrorfContext logs an error with a standard context's trace id (if available).
|
||||
func ErrorfContext(ctx context.Context, format string, args ...any) {
|
||||
if ctx == nil {
|
||||
facades.Log().Errorf(format, args...)
|
||||
return
|
||||
}
|
||||
|
||||
trace := traceid.FromContext(ctx)
|
||||
facades.Log().Errorf(prependTrace(trace, format), args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error without any context (fallback).
|
||||
func Errorf(format string, args ...any) {
|
||||
facades.Log().Errorf(format, args...)
|
||||
}
|
||||
|
||||
func prependTrace(traceID, format string) string {
|
||||
if traceID == "" {
|
||||
return format
|
||||
}
|
||||
return fmt.Sprintf("[trace_id=%s] %s", traceID, format)
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/http"
|
||||
"github.com/goravel/framework/facades"
|
||||
"github.com/goravel/framework/support/str"
|
||||
|
||||
"goravel/app/models"
|
||||
)
|
||||
|
||||
// GetOperationTitleFromContext 从 context 中获取操作标题
|
||||
// 优先使用权限标识(permission_slug),如果没有权限标识,则根据路径和方法生成默认标题
|
||||
func GetOperationTitleFromContext(ctx http.Context) string {
|
||||
if ctx == nil {
|
||||
return "operation.unknown"
|
||||
}
|
||||
|
||||
// 优先从 context 中获取权限标识(由权限中间件设置)
|
||||
permissionSlugValue := ctx.Value("permission_slug")
|
||||
if permissionSlugValue != nil {
|
||||
if permissionSlug, ok := permissionSlugValue.(string); ok && permissionSlug != "" {
|
||||
// 直接返回权限标识,前端多语言文件中已有对应翻译
|
||||
return permissionSlug
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有权限标识,尝试从权限表中查询匹配的权限
|
||||
method := ctx.Request().Method()
|
||||
path := ctx.Request().Path()
|
||||
|
||||
// 先从权限表中查询匹配的权限
|
||||
permissionSlug := findPermissionSlugFromDB(method, path)
|
||||
if permissionSlug != "" {
|
||||
return permissionSlug
|
||||
}
|
||||
|
||||
// 如果权限表中没有找到,根据路径和方法生成默认标题
|
||||
defaultTitle := generateDefaultTitle(method, path)
|
||||
if defaultTitle != "" {
|
||||
return defaultTitle
|
||||
}
|
||||
|
||||
// 无法生成标题时,返回未知操作
|
||||
return "operation.unknown"
|
||||
}
|
||||
|
||||
// generateDefaultTitle 根据方法和路径生成默认操作标题
|
||||
func generateDefaultTitle(method, path string) string {
|
||||
pathStr := str.Of(path)
|
||||
|
||||
// 分片上传相关(与权限配置中的 slug 保持一致)
|
||||
if pathStr.Contains("/attachments/chunk") {
|
||||
if method == "POST" || method == "GET" {
|
||||
// 权限配置中的 slug 是 attachment.chunk
|
||||
return "attachment.chunk"
|
||||
}
|
||||
}
|
||||
|
||||
// 附件上传
|
||||
if pathStr.Contains("/attachments/upload") && method == "POST" {
|
||||
return "attachment.upload"
|
||||
}
|
||||
|
||||
// 附件删除
|
||||
if pathStr.Contains("/attachments/") && pathStr.EndsWith("/batch-delete") && method == "POST" {
|
||||
return "attachment.batch_delete"
|
||||
}
|
||||
if pathStr.Contains("/attachments/") && method == "DELETE" {
|
||||
return "attachment.destroy"
|
||||
}
|
||||
|
||||
// 附件更新显示名称
|
||||
if pathStr.Contains("/attachments/") && pathStr.EndsWith("/display-name") && method == "PUT" {
|
||||
return "attachment.update_display_name"
|
||||
}
|
||||
|
||||
// 导出下载
|
||||
if pathStr.Contains("/exports/") && pathStr.EndsWith("/download") && method == "GET" {
|
||||
return "export.download"
|
||||
}
|
||||
|
||||
// 订单导入
|
||||
if pathStr.Contains("/orders/import") && method == "POST" {
|
||||
return "order.import"
|
||||
}
|
||||
|
||||
// 订单导出
|
||||
if pathStr.Contains("/orders/export") && method == "POST" {
|
||||
return "order.export"
|
||||
}
|
||||
|
||||
// 管理员解绑谷歌验证码
|
||||
if pathStr.Contains("/admins/") && pathStr.EndsWith("/unbind-google-auth") && method == "POST" {
|
||||
return "admin.unbind_google_auth"
|
||||
}
|
||||
|
||||
// 更新个人资料
|
||||
if pathStr.EndsWith("/profile") && (method == "PUT" || method == "PATCH") {
|
||||
return "profile.update"
|
||||
}
|
||||
|
||||
// 修改密码
|
||||
if pathStr.EndsWith("/password") && (method == "PUT" || method == "PATCH") {
|
||||
return "password.update"
|
||||
}
|
||||
|
||||
// 批量删除(通用模式)
|
||||
if pathStr.EndsWith("/batch-delete") && method == "POST" {
|
||||
parts := pathStr.ChopStart("/api/admin/").Split("/")
|
||||
if len(parts) > 0 {
|
||||
module := str.Of(parts[0]).Replace("-", "_").String()
|
||||
return str.Of(module).Append(".batch_delete").String()
|
||||
}
|
||||
}
|
||||
|
||||
// 清理操作(通用模式)
|
||||
if pathStr.EndsWith("/clean") && method == "POST" {
|
||||
parts := pathStr.ChopStart("/api/admin/").Split("/")
|
||||
if len(parts) > 0 {
|
||||
module := str.Of(parts[0]).Replace("-", "_").String()
|
||||
return str.Of(module).Append(".clean").String()
|
||||
}
|
||||
}
|
||||
|
||||
// 标准 CRUD 操作(通用模式)
|
||||
parts := pathStr.ChopStart("/api/admin/").Split("/")
|
||||
if len(parts) >= 1 {
|
||||
module := str.Of(parts[0]).Replace("-", "_").String()
|
||||
switch method {
|
||||
case "POST":
|
||||
// 创建操作
|
||||
if len(parts) == 1 || (len(parts) == 2 && parts[1] != "batch-delete" && parts[1] != "clean") {
|
||||
return str.Of(module).Append(".store").String()
|
||||
}
|
||||
case "PUT", "PATCH":
|
||||
// 更新操作
|
||||
if len(parts) >= 2 {
|
||||
return str.Of(module).Append(".update").String()
|
||||
}
|
||||
case "DELETE":
|
||||
// 删除操作
|
||||
if len(parts) >= 2 {
|
||||
return str.Of(module).Append(".destroy").String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// findPermissionSlugFromDB 从权限表中查询匹配的权限标识
|
||||
// 优先匹配精确路径,然后匹配通配符路径
|
||||
func findPermissionSlugFromDB(method, path string) string {
|
||||
var permissions []models.Permission
|
||||
|
||||
// 查询所有启用的权限,方法匹配或方法为空
|
||||
query := facades.Orm().Query().Model(&models.Permission{}).
|
||||
Where("status", 1).
|
||||
Where("(method = ? OR method = '')", method)
|
||||
|
||||
if err := query.Find(&permissions); err != nil {
|
||||
// 查询失败时返回空,使用默认逻辑
|
||||
return ""
|
||||
}
|
||||
|
||||
// 优先匹配精确路径
|
||||
for _, perm := range permissions {
|
||||
if perm.Path == path {
|
||||
return perm.Slug
|
||||
}
|
||||
}
|
||||
|
||||
// 然后匹配通配符路径
|
||||
for _, perm := range permissions {
|
||||
if perm.Path != "" && perm.Path != path {
|
||||
if matchPermissionPath(perm.Path, path) {
|
||||
return perm.Slug
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// matchPermissionPath 路径匹配,支持通配符(与权限中间件中的 matchPath 逻辑一致)
|
||||
func matchPermissionPath(pattern, path string) bool {
|
||||
if pattern == path {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果模式不包含通配符,直接返回 false
|
||||
if !containsChar(pattern, '*') {
|
||||
return false
|
||||
}
|
||||
|
||||
// 将模式按 * 分割成多个部分
|
||||
parts := splitPatternString(pattern)
|
||||
if len(parts) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果模式以 * 开头,需要特殊处理
|
||||
if pattern[0] == '*' {
|
||||
// 检查路径是否以模式的剩余部分结尾
|
||||
if len(parts) > 1 {
|
||||
suffix := parts[1]
|
||||
return len(path) >= len(suffix) && path[len(path)-len(suffix):] == suffix
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果模式以 * 结尾
|
||||
if pattern[len(pattern)-1] == '*' {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
if len(path) >= len(prefix) {
|
||||
pathPrefix := path[:len(prefix)]
|
||||
if pathPrefix == prefix {
|
||||
// 如果前缀以 / 结尾,路径必须比前缀长(即后面还有内容)
|
||||
if len(prefix) > 0 && prefix[len(prefix)-1] == '/' {
|
||||
return len(path) > len(prefix)
|
||||
}
|
||||
// 如果前缀不以 / 结尾,路径可以等于或长于前缀
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 处理中间有通配符的情况,如 /api/admin/attachments/*/display-name
|
||||
// 过滤掉 "*" 标记,只保留实际的部分
|
||||
var actualParts []string
|
||||
for _, part := range parts {
|
||||
if part != "*" {
|
||||
actualParts = append(actualParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
if len(actualParts) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查路径是否以第一部分开头
|
||||
firstPart := actualParts[0]
|
||||
if len(path) < len(firstPart) || path[:len(firstPart)] != firstPart {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查路径是否以最后一部分结尾
|
||||
lastPart := actualParts[len(actualParts)-1]
|
||||
if len(path) < len(lastPart) || path[len(path)-len(lastPart):] != lastPart {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查中间部分是否存在(通配符匹配任意内容)
|
||||
remainingPath := path[len(firstPart) : len(path)-len(lastPart)]
|
||||
// 确保中间部分不为空(至少有一个字符,通常是数字ID)
|
||||
return len(remainingPath) > 0
|
||||
}
|
||||
|
||||
// containsChar 检查字符串是否包含指定字符
|
||||
func containsChar(s string, c byte) bool {
|
||||
for i := range s {
|
||||
if s[i] == c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// splitPatternString 按 * 分割模式字符串
|
||||
func splitPatternString(pattern string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
|
||||
for i := range pattern {
|
||||
if pattern[i] == '*' {
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
parts = append(parts, "*")
|
||||
} else {
|
||||
current.WriteByte(pattern[i])
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
// redisClients 缓存不同连接名的 Redis 客户端
|
||||
redisClients sync.Map
|
||||
// redisMutex 用于保护客户端创建过程
|
||||
redisMutex sync.Mutex
|
||||
)
|
||||
|
||||
// GetRedisClient 获取 Redis 客户端(使用连接池,支持多连接)
|
||||
// connectionName: Redis 连接名称,默认为 "default"
|
||||
// 返回缓存的 Redis 客户端,如果不存在则创建并缓存
|
||||
func GetRedisClient(connectionName string) (*redis.Client, error) {
|
||||
if connectionName == "" {
|
||||
connectionName = "default"
|
||||
}
|
||||
|
||||
// 先从缓存中获取
|
||||
if client, ok := redisClients.Load(connectionName); ok {
|
||||
redisClient := client.(*redis.Client)
|
||||
// 测试连接是否有效
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
err := redisClient.Ping(ctx).Err()
|
||||
cancel()
|
||||
if err == nil {
|
||||
return redisClient, nil
|
||||
}
|
||||
// 连接失效,从缓存中移除
|
||||
redisClients.Delete(connectionName)
|
||||
}
|
||||
|
||||
// 使用互斥锁确保只创建一个客户端
|
||||
redisMutex.Lock()
|
||||
defer redisMutex.Unlock()
|
||||
|
||||
// 双重检查,防止并发创建
|
||||
if client, ok := redisClients.Load(connectionName); ok {
|
||||
return client.(*redis.Client), nil
|
||||
}
|
||||
|
||||
// 创建新的 Redis 客户端
|
||||
client, err := createRedisClient(connectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存客户端
|
||||
redisClients.Store(connectionName, client)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// createRedisClient 创建 Redis 客户端
|
||||
func createRedisClient(connectionName string) (*redis.Client, error) {
|
||||
// 获取 Redis 配置
|
||||
host := facades.Config().GetString(fmt.Sprintf("database.redis.%s.host", connectionName), "")
|
||||
if host == "" {
|
||||
// 尝试使用 default 连接
|
||||
host = facades.Config().GetString("database.redis.default.host", "127.0.0.1")
|
||||
}
|
||||
|
||||
port := facades.Config().GetInt(fmt.Sprintf("database.redis.%s.port", connectionName), 0)
|
||||
if port == 0 {
|
||||
port = facades.Config().GetInt("database.redis.default.port", 6379)
|
||||
}
|
||||
|
||||
password := facades.Config().GetString(fmt.Sprintf("database.redis.%s.password", connectionName), "")
|
||||
if password == "" {
|
||||
password = facades.Config().GetString("database.redis.default.password", "")
|
||||
}
|
||||
|
||||
db := facades.Config().GetInt(fmt.Sprintf("database.redis.%s.database", connectionName), -1)
|
||||
if db == -1 {
|
||||
db = facades.Config().GetInt("database.redis.default.database", 0)
|
||||
}
|
||||
|
||||
// 创建 Redis 客户端(使用连接池配置)
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port),
|
||||
Password: password,
|
||||
DB: db,
|
||||
PoolSize: 10, // 连接池大小
|
||||
MinIdleConns: 5, // 最小空闲连接数
|
||||
MaxRetries: 3, // 最大重试次数
|
||||
})
|
||||
|
||||
// 测试连接(设置超时)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
client.Close() // 连接失败,关闭客户端
|
||||
return nil, fmt.Errorf("Redis 连接失败 [%s]: %v", connectionName, err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// CloseRedisClient 关闭指定的 Redis 客户端并从缓存中移除
|
||||
func CloseRedisClient(connectionName string) error {
|
||||
if connectionName == "" {
|
||||
connectionName = "default"
|
||||
}
|
||||
|
||||
if client, ok := redisClients.LoadAndDelete(connectionName); ok {
|
||||
redisClient := client.(*redis.Client)
|
||||
return redisClient.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseAllRedisClients 关闭所有缓存的 Redis 客户端
|
||||
func CloseAllRedisClients() error {
|
||||
var errs []error
|
||||
redisClients.Range(func(key, value any) bool {
|
||||
client := value.(*redis.Client)
|
||||
if err := client.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
redisClients.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("关闭 Redis 客户端时发生错误: %v", errs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/goravel/framework/facades"
|
||||
"github.com/goravel/framework/support/str"
|
||||
)
|
||||
|
||||
// IsSensitiveField 检查字段名是否是敏感字段
|
||||
func IsSensitiveField(fieldName string) bool {
|
||||
keyLower := str.Of(fieldName).Lower().String()
|
||||
|
||||
// 获取配置的敏感字段列表
|
||||
sensitiveFieldsInterface := facades.Config().Get("operation_log.sensitive_fields", []string{})
|
||||
if sensitiveFields, ok := sensitiveFieldsInterface.([]any); ok {
|
||||
for _, fieldInterface := range sensitiveFields {
|
||||
if field, ok := fieldInterface.(string); ok {
|
||||
if keyLower == str.Of(field).Lower().String() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if sensitiveFields, ok := sensitiveFieldsInterface.([]string); ok {
|
||||
for _, field := range sensitiveFields {
|
||||
if keyLower == str.Of(field).Lower().String() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否包含敏感关键词
|
||||
sensitiveKeywordsInterface := facades.Config().Get("operation_log.sensitive_keywords", []string{})
|
||||
if sensitiveKeywords, ok := sensitiveKeywordsInterface.([]any); ok {
|
||||
for _, keywordInterface := range sensitiveKeywords {
|
||||
if keyword, ok := keywordInterface.(string); ok {
|
||||
if str.Of(keyLower).Contains(str.Of(keyword).Lower().String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if sensitiveKeywords, ok := sensitiveKeywordsInterface.([]string); ok {
|
||||
for _, keyword := range sensitiveKeywords {
|
||||
if str.Of(keyLower).Contains(str.Of(keyword).Lower().String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
|
||||
"goravel/app/constants"
|
||||
"goravel/app/utils/errorlog"
|
||||
)
|
||||
|
||||
// GetShardingTableName 根据时间获取分表名称
|
||||
// baseTableName: 基础表名,如 "orders"
|
||||
// orderTime: 订单时间
|
||||
// 返回: 分表名称,如 "orders_202501"
|
||||
func GetShardingTableName(baseTableName string, orderTime time.Time) string {
|
||||
return fmt.Sprintf("%s_%s", baseTableName, orderTime.Format("200601"))
|
||||
}
|
||||
|
||||
// GetUserBalanceLogsShardingTableName 根据 user_id 获取用户余额变动记录分表名称
|
||||
// userID: 用户ID
|
||||
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
|
||||
// 分表逻辑:user_id % UserBalanceLogsShards
|
||||
// 注意:此函数为特定表实现,如需为其他表实现哈希分表,请使用 GetHashShardingTableName
|
||||
func GetUserBalanceLogsShardingTableName(userID uint) string {
|
||||
return GetHashShardingTableName("user_balance_logs", userID, constants.UserBalanceLogsShards)
|
||||
}
|
||||
|
||||
// HashShardingConfig 哈希分表配置
|
||||
type HashShardingConfig struct {
|
||||
BaseTableName string // 基础表名,如 "user_balance_logs"
|
||||
NumberOfShards int // 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
|
||||
}
|
||||
|
||||
// 预定义的哈希分表配置
|
||||
var (
|
||||
// UserBalanceLogsShardingConfig 用户余额变动记录分表配置
|
||||
UserBalanceLogsShardingConfig = HashShardingConfig{
|
||||
BaseTableName: "user_balance_logs",
|
||||
NumberOfShards: constants.UserBalanceLogsShards,
|
||||
}
|
||||
)
|
||||
|
||||
// GetHashShardingTableName 通用的哈希分表名称生成函数
|
||||
// baseTableName: 基础表名,如 "user_balance_logs"
|
||||
// shardingKey: 分表键值(uint 类型),如 user_id
|
||||
// numberOfShards: 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
|
||||
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
|
||||
// 分表逻辑:shardingKey % numberOfShards
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// // 为 user_balance_logs 表分表(4个分表)
|
||||
// tableName := GetHashShardingTableName("user_balance_logs", userID, 4)
|
||||
//
|
||||
// // 为新的表 example_table 分表(8个分表)
|
||||
// tableName := GetHashShardingTableName("example_table", entityID, 8)
|
||||
func GetHashShardingTableName(baseTableName string, shardingKey uint, numberOfShards int) string {
|
||||
if numberOfShards <= 0 {
|
||||
// 如果分表数量无效,返回基础表名(不分表)
|
||||
return baseTableName
|
||||
}
|
||||
shardIndex := int(shardingKey) % numberOfShards
|
||||
return fmt.Sprintf("%s_%d", baseTableName, shardIndex)
|
||||
}
|
||||
|
||||
// GetHashShardingTableNameByConfig 通过配置获取哈希分表名称
|
||||
func GetHashShardingTableNameByConfig(config HashShardingConfig, shardingKey uint) string {
|
||||
return GetHashShardingTableName(config.BaseTableName, shardingKey, config.NumberOfShards)
|
||||
}
|
||||
|
||||
// GetAllHashShardingTableNames 获取所有哈希分表名称列表
|
||||
// 用于批量操作所有分表(如迁移、统计等)
|
||||
func GetAllHashShardingTableNames(config HashShardingConfig) []string {
|
||||
tableNames := make([]string, config.NumberOfShards)
|
||||
for i := 0; i < config.NumberOfShards; i++ {
|
||||
tableNames[i] = fmt.Sprintf("%s_%d", config.BaseTableName, i)
|
||||
}
|
||||
return tableNames
|
||||
}
|
||||
|
||||
// GetHashShardIndex 获取分表索引
|
||||
func GetHashShardIndex(shardingKey uint, numberOfShards int) int {
|
||||
if numberOfShards <= 0 {
|
||||
return 0
|
||||
}
|
||||
return int(shardingKey) % numberOfShards
|
||||
}
|
||||
|
||||
// GetShardingTableNames 获取时间范围内的所有分表名称
|
||||
// baseTableName: 基础表名
|
||||
// startTime: 开始时间
|
||||
// endTime: 结束时间
|
||||
// 返回: 分表名称列表
|
||||
func GetShardingTableNames(baseTableName string, startTime, endTime time.Time) []string {
|
||||
var tableNames []string
|
||||
|
||||
// 如果结束时间是零值,使用当前时间
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now().UTC()
|
||||
}
|
||||
|
||||
// 确保开始时间不晚于结束时间
|
||||
if startTime.After(endTime) {
|
||||
return tableNames
|
||||
}
|
||||
|
||||
// 从开始时间到结束时间,按月遍历
|
||||
current := time.Date(startTime.Year(), startTime.Month(), 1, 0, 0, 0, 0, startTime.Location())
|
||||
end := time.Date(endTime.Year(), endTime.Month(), 1, 0, 0, 0, 0, endTime.Location())
|
||||
|
||||
for !current.After(end) {
|
||||
tableNames = append(tableNames, GetShardingTableName(baseTableName, current))
|
||||
current = current.AddDate(0, 1, 0) // 加一个月
|
||||
}
|
||||
|
||||
return tableNames
|
||||
}
|
||||
|
||||
// DefaultMaxTimeRangeMonths 默认最大时间范围(月数)
|
||||
// 可以通过配置覆盖,用于限制查询时间范围,避免跨太多分表
|
||||
const DefaultMaxTimeRangeMonths = 3
|
||||
|
||||
// TimeRangeError 时间范围验证错误
|
||||
type TimeRangeError struct {
|
||||
Key string
|
||||
Params map[string]any
|
||||
}
|
||||
|
||||
func (e *TimeRangeError) Error() string {
|
||||
// 返回翻译键,由调用方进行翻译
|
||||
return e.Key
|
||||
}
|
||||
|
||||
// ValidateTimeRange 验证时间范围是否超过指定月数
|
||||
// startTime: 开始时间
|
||||
// endTime: 结束时间
|
||||
// maxMonths: 最大允许的月数,如果为0则使用默认值 DefaultMaxTimeRangeMonths
|
||||
// 返回: 是否有效,错误信息(错误信息包含翻译键和参数)
|
||||
func ValidateTimeRange(startTime, endTime time.Time, maxMonths ...int) (bool, error) {
|
||||
// 检查开始时间是否晚于结束时间(忽略零值时间)
|
||||
if !startTime.IsZero() && !endTime.IsZero() && startTime.After(endTime) {
|
||||
return false, &TimeRangeError{
|
||||
Key: "start_time_after_end_time",
|
||||
Params: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// 确定最大月数
|
||||
maxMonthsValue := DefaultMaxTimeRangeMonths
|
||||
if len(maxMonths) > 0 && maxMonths[0] > 0 {
|
||||
maxMonthsValue = maxMonths[0]
|
||||
}
|
||||
|
||||
// 检查时间范围是否超过指定月数(忽略零值结束时间)
|
||||
if !endTime.IsZero() {
|
||||
maxTimeLater := startTime.AddDate(0, maxMonthsValue, 0)
|
||||
if endTime.After(maxTimeLater) {
|
||||
return false, &TimeRangeError{
|
||||
Key: "time_range_exceeded",
|
||||
Params: map[string]any{
|
||||
"months": maxMonthsValue,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetAllExistingShardingTables 获取数据库中所有已存在的分表名称
|
||||
// baseTableName: 基础表名,如 "orders" 或 "order_details"
|
||||
// 返回: 已存在的分表名称列表
|
||||
func GetAllExistingShardingTables(baseTableName string) ([]string, error) {
|
||||
var tableNames []string
|
||||
|
||||
// 获取当前数据库名
|
||||
dbName := facades.Config().GetString("database.connections.mysql.database")
|
||||
if dbName == "" {
|
||||
dbName = facades.Config().GetString("database.connections.postgresql.database")
|
||||
}
|
||||
|
||||
// 构建表名匹配模式:orders_YYYYMM 或 order_details_YYYYMM
|
||||
pattern := fmt.Sprintf("%s_%%", baseTableName)
|
||||
|
||||
// 查询所有匹配的表名
|
||||
query := `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = ?
|
||||
AND table_name LIKE ?
|
||||
ORDER BY table_name
|
||||
`
|
||||
|
||||
// 执行查询,使用 Scan 获取结果
|
||||
var rows []map[string]any
|
||||
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
|
||||
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
|
||||
"pattern": pattern,
|
||||
"error": err.Error(),
|
||||
}, "查询分表失败: %v", err)
|
||||
return nil, fmt.Errorf("查询分表失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证表名格式(确保是有效的分表名称,格式为 baseTableName_YYYYMM)
|
||||
// 从 pattern 中提取基础表名(移除末尾的 %)
|
||||
baseTableNameFromPattern := strings.TrimSuffix(pattern, "_%")
|
||||
patternRegex := regexp.MustCompile(fmt.Sprintf("^%s_\\d{6}$", regexp.QuoteMeta(baseTableNameFromPattern)))
|
||||
|
||||
for _, row := range rows {
|
||||
// 尝试不同的字段名格式(MySQL 可能返回不同的大小写)
|
||||
var tableName string
|
||||
var ok bool
|
||||
|
||||
// 尝试 table_name (小写)
|
||||
if tableName, ok = row["table_name"].(string); !ok {
|
||||
// 尝试 TABLE_NAME (大写)
|
||||
if tableName, ok = row["TABLE_NAME"].(string); !ok {
|
||||
// 尝试遍历所有键
|
||||
for key, value := range row {
|
||||
if (key == "table_name" || key == "TABLE_NAME" || key == "Table_Name") && value != nil {
|
||||
if str, ok := value.(string); ok {
|
||||
tableName = str
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ok && tableName != "" {
|
||||
// 验证表名格式
|
||||
if patternRegex.MatchString(tableName) {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tableNames, nil
|
||||
}
|
||||
|
||||
// GetAllExistingShardingTablesByPattern 通过表名模式获取所有已存在的分表
|
||||
// 这是一个更通用的方法,可以通过自定义模式匹配
|
||||
// pattern: 表名匹配模式,如 "orders_%" 或 "order_details_%"
|
||||
func GetAllExistingShardingTablesByPattern(pattern string) ([]string, error) {
|
||||
var tableNames []string
|
||||
|
||||
// 获取当前数据库名
|
||||
dbName := facades.Config().GetString("database.connections.mysql.database")
|
||||
if dbName == "" {
|
||||
dbName = facades.Config().GetString("database.connections.postgresql.database")
|
||||
}
|
||||
|
||||
// 查询所有匹配的表名
|
||||
query := `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = ?
|
||||
AND table_name LIKE ?
|
||||
ORDER BY table_name
|
||||
`
|
||||
|
||||
// 执行查询,使用 Scan 获取结果
|
||||
var rows []map[string]any
|
||||
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
|
||||
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
|
||||
"pattern": pattern,
|
||||
"error": err.Error(),
|
||||
}, "查询分表失败: %v", err)
|
||||
return nil, fmt.Errorf("查询分表失败: %v", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
if tableName, ok := row["table_name"].(string); ok {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
|
||||
return tableNames, nil
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
// ShardingNoConfig 分表单号配置
|
||||
type ShardingNoConfig struct {
|
||||
Prefix string // 单号前缀,如 "ORD"、"PAY"
|
||||
DateFormat string // 日期格式,如 "200601"(年月)、"20060102"(年月日)
|
||||
}
|
||||
|
||||
// 预定义配置
|
||||
var (
|
||||
// OrderNoConfig 订单号配置(ORD + YYYYMM + ULID)
|
||||
OrderNoConfig = ShardingNoConfig{
|
||||
Prefix: "ORD",
|
||||
DateFormat: "200601", // YYYYMM
|
||||
}
|
||||
|
||||
// PaymentNoConfig 支付单号配置(PAY + YYYYMMDD + ULID)
|
||||
PaymentNoConfig = ShardingNoConfig{
|
||||
Prefix: "PAY",
|
||||
DateFormat: "20060102", // YYYYMMDD
|
||||
}
|
||||
)
|
||||
|
||||
// GenerateShardingNo 生成分表单号
|
||||
// 格式:前缀 + 日期 + ULID
|
||||
// 例如:ORD20250101ARZ3S0K5M2X9P4Q6R8T1V3W5Y7Z9
|
||||
func GenerateShardingNo(config ShardingNoConfig) string {
|
||||
now := time.Now().UTC()
|
||||
dateStr := now.Format(config.DateFormat)
|
||||
ulidStr := ulid.Make().String()
|
||||
return fmt.Sprintf("%s%s%s", config.Prefix, dateStr, ulidStr)
|
||||
}
|
||||
|
||||
// ParseShardingNoDate 从分表单号解析日期
|
||||
// 返回:解析的时间和是否成功
|
||||
func ParseShardingNoDate(no string, config ShardingNoConfig) (time.Time, bool) {
|
||||
prefixLen := len(config.Prefix)
|
||||
dateLen := len(config.DateFormat)
|
||||
minLen := prefixLen + dateLen + 1 // 前缀 + 日期 + 至少1位ULID
|
||||
|
||||
// 验证长度和前缀
|
||||
if len(no) < minLen || !strings.HasPrefix(no, config.Prefix) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// 提取日期部分
|
||||
dateStr := no[prefixLen : prefixLen+dateLen]
|
||||
|
||||
// 验证日期格式(简单验证:都是数字)
|
||||
for _, c := range dateStr {
|
||||
if c < '0' || c > '9' {
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// 解析日期
|
||||
parsedTime, err := time.Parse(config.DateFormat, dateStr)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
return parsedTime, true
|
||||
}
|
||||
|
||||
// ParseShardingNoYearMonth 从分表单号解析年月字符串(用于分表定位)
|
||||
// 返回:年月字符串(如 "202501")和是否成功
|
||||
func ParseShardingNoYearMonth(no string, config ShardingNoConfig) (string, bool) {
|
||||
parsedTime, ok := ParseShardingNoDate(no, config)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return parsedTime.Format("200601"), true
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package utils
|
||||
|
||||
import "time"
|
||||
|
||||
// 时间格式常量
|
||||
const (
|
||||
// DateTimeFormat 标准日期时间格式 (YYYY-MM-DD HH:mm:ss)
|
||||
DateTimeFormat = "2006-01-02 15:04:05"
|
||||
|
||||
// DateFormat 标准日期格式 (YYYY-MM-DD)
|
||||
DateFormat = "2006-01-02"
|
||||
|
||||
// DateTimeFormatT 带T的日期时间格式 (YYYY-MM-DDTHH:mm:ss)
|
||||
DateTimeFormatT = "2006-01-02T15:04:05"
|
||||
|
||||
// DateTimeFormatMs 带毫秒的日期时间格式
|
||||
DateTimeFormatMs = "2006-01-02 15:04:05.000"
|
||||
|
||||
// DateTimeFormatTZ 带时区的日期时间格式
|
||||
DateTimeFormatTZ = "2006-01-02T15:04:05.000Z07:00"
|
||||
|
||||
// YearMonthFormat 年月格式 (YYYYMM)
|
||||
YearMonthFormat = "200601"
|
||||
)
|
||||
|
||||
// ParseDateTime 解析标准格式的日期时间字符串 (2006-01-02 15:04:05)
|
||||
func ParseDateTime(s string) (time.Time, error) {
|
||||
return time.Parse(DateTimeFormat, s)
|
||||
}
|
||||
|
||||
// ParseDateTimeUTC 解析标准格式的日期时间字符串并转换为 UTC
|
||||
func ParseDateTimeUTC(s string) (time.Time, error) {
|
||||
t, err := time.Parse(DateTimeFormat, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// ParseDate 解析标准日期格式字符串 (2006-01-02)
|
||||
func ParseDate(s string) (time.Time, error) {
|
||||
return time.Parse(DateFormat, s)
|
||||
}
|
||||
|
||||
// ParseDateUTC 解析标准日期格式字符串并转换为 UTC
|
||||
func ParseDateUTC(s string) (time.Time, error) {
|
||||
t, err := time.Parse(DateFormat, s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// FormatDateTime 格式化时间为标准日期时间字符串 (2006-01-02 15:04:05)
|
||||
func FormatDateTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(DateTimeFormat)
|
||||
}
|
||||
|
||||
// FormatDate 格式化时间为标准日期字符串 (2006-01-02)
|
||||
func FormatDate(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(DateFormat)
|
||||
}
|
||||
|
||||
// FormatDateTimePtr 格式化时间指针为标准日期时间字符串
|
||||
func FormatDateTimePtr(t *time.Time) string {
|
||||
if t == nil || t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(DateTimeFormat)
|
||||
}
|
||||
|
||||
// ParseDateTimeInLocation 在指定时区解析日期时间字符串
|
||||
func ParseDateTimeInLocation(s string, loc *time.Location) (time.Time, error) {
|
||||
return time.ParseInLocation(DateTimeFormat, s, loc)
|
||||
}
|
||||
|
||||
// FormatYearMonth 格式化时间为年月字符串 (200601)
|
||||
func FormatYearMonth(t time.Time) string {
|
||||
return t.Format(YearMonthFormat)
|
||||
}
|
||||
|
||||
// ParseYearMonth 解析年月格式字符串 (200601)
|
||||
func ParseYearMonth(s string) (time.Time, error) {
|
||||
return time.Parse(YearMonthFormat, s)
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package traceid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/contracts/http"
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
ContextKey contextKey = "trace_id"
|
||||
headerKey string = "X-Trace-Id"
|
||||
requestIDHeader string = "X-Request-Id"
|
||||
)
|
||||
|
||||
// Generate returns a new trace id using ULID to ensure it sorts well and is URL safe.
|
||||
func Generate() string {
|
||||
return strings.ToLower(ulid.Make().String())
|
||||
}
|
||||
|
||||
// EnsureHTTPContext stores a trace id on the http.Context (and returns it).
|
||||
func EnsureHTTPContext(ctx http.Context, preferred string) string {
|
||||
if ctx == nil {
|
||||
return Generate()
|
||||
}
|
||||
|
||||
traceID := preferred
|
||||
if traceID == "" {
|
||||
traceID = ctx.Request().Header(headerKey, "")
|
||||
}
|
||||
if traceID == "" {
|
||||
traceID = ctx.Request().Header(requestIDHeader, "")
|
||||
}
|
||||
if traceID == "" {
|
||||
traceID = Generate()
|
||||
}
|
||||
|
||||
ctx.WithValue(string(ContextKey), traceID)
|
||||
return traceID
|
||||
}
|
||||
|
||||
// FromHTTPContext retrieves the stored trace id from http.Context.
|
||||
func FromHTTPContext(ctx http.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if value := ctx.Value(string(ContextKey)); value != nil {
|
||||
if traceID, ok := value.(string); ok {
|
||||
return traceID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// StoreHTTPContext stores an existing trace id into the http context.
|
||||
func StoreHTTPContext(ctx http.Context, traceID string) {
|
||||
if ctx == nil || traceID == "" {
|
||||
return
|
||||
}
|
||||
ctx.WithValue(string(ContextKey), traceID)
|
||||
}
|
||||
|
||||
// EnsureContext ensures a standard context carries a trace id and returns both.
|
||||
func EnsureContext(ctx context.Context) (context.Context, string) {
|
||||
traceID := FromContext(ctx)
|
||||
if traceID == "" {
|
||||
traceID = Generate()
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, ContextKey, traceID), traceID
|
||||
}
|
||||
|
||||
// WithTrace assigns a trace id into the provided context (or background if nil).
|
||||
func WithTrace(ctx context.Context, traceID string) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if traceID == "" {
|
||||
traceID = Generate()
|
||||
}
|
||||
return context.WithValue(ctx, ContextKey, traceID)
|
||||
}
|
||||
|
||||
// FromContext reads the trace id from a standard context.
|
||||
func FromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if traceID, ok := ctx.Value(ContextKey).(string); ok {
|
||||
return traceID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// DeriveContextFromHTTP builds a standard context containing the http trace id.
|
||||
func DeriveContextFromHTTP(ctx http.Context) context.Context {
|
||||
traceID := FromHTTPContext(ctx)
|
||||
if traceID == "" {
|
||||
traceID = Generate()
|
||||
StoreHTTPContext(ctx, traceID)
|
||||
}
|
||||
return context.WithValue(context.Background(), ContextKey, traceID)
|
||||
}
|
||||
|
||||
// HeaderName exposes the header used to propagate the trace id.
|
||||
func HeaderName() string {
|
||||
return headerKey
|
||||
}
|
||||
|
||||
// RequestHeaderFallback exposes secondary header name for incoming trace ids.
|
||||
func RequestHeaderFallback() string {
|
||||
return requestIDHeader
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
|
||||
"goravel/lang"
|
||||
)
|
||||
|
||||
// loadLangMessages 加载语言文件的 messages 对象(内部辅助函数)
|
||||
// 优先从文件系统读取,失败时从 embed FS 读取(支持生产环境)
|
||||
// 返回 messages map 和是否成功加载
|
||||
func loadLangMessages(langCode string) (map[string]any, bool) {
|
||||
var langData []byte
|
||||
var err error
|
||||
|
||||
// 获取语言文件路径(使用框架配置的路径)
|
||||
langPath := facades.Config().GetString("app.lang_path", "lang")
|
||||
langFile := filepath.Join(langPath, fmt.Sprintf("%s.json", langCode))
|
||||
|
||||
// 优先尝试从文件系统读取
|
||||
langData, err = os.ReadFile(langFile)
|
||||
if err != nil {
|
||||
// 文件系统读取失败,尝试从 embed FS 读取
|
||||
embedFile := fmt.Sprintf("%s.json", langCode)
|
||||
langData, err = lang.FS.ReadFile(embedFile)
|
||||
if err != nil {
|
||||
facades.Log().Debugf("读取语言文件失败: %s, embed: %s, error=%v", langFile, embedFile, err)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var langMap map[string]any
|
||||
if err := json.Unmarshal(langData, &langMap); err != nil {
|
||||
facades.Log().Debugf("解析语言文件失败: lang=%s, error=%v", langCode, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 获取 messages 对象(使用类型辅助函数)
|
||||
messages, ok := GetMap(langMap, "messages")
|
||||
if !ok {
|
||||
facades.Log().Debugf("语言文件中没有 messages 对象: lang=%s", langCode)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return messages, true
|
||||
}
|
||||
|
||||
// TranslateHeaders 翻译表头(直接读取语言文件)
|
||||
// 这是一个通用函数,可以被任何需要翻译表头的 job 使用
|
||||
//
|
||||
// 参数:
|
||||
// - headerKeys: 需要翻译的键列表(如 ["export_header_id", "export_header_order_no"])
|
||||
// - lang: 语言代码(如 "cn" 或 "en")
|
||||
//
|
||||
// 返回:
|
||||
// - []string: 翻译后的表头列表,如果翻译失败则返回原始键
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// headerKeys := []string{"export_header_id", "export_header_order_no"}
|
||||
// headers := utils.TranslateHeaders(headerKeys, "cn")
|
||||
func TranslateHeaders(headerKeys []string, lang string) []string {
|
||||
headers := make([]string, len(headerKeys))
|
||||
|
||||
// 加载语言文件的 messages 对象
|
||||
messages, ok := loadLangMessages(lang)
|
||||
if !ok {
|
||||
// 如果读取失败,使用原始键
|
||||
facades.Log().Warningf("加载语言文件失败: lang=%s, 使用原始键", lang)
|
||||
return append([]string(nil), headerKeys...)
|
||||
}
|
||||
|
||||
// 翻译每个键(使用泛型辅助函数)
|
||||
for i, key := range headerKeys {
|
||||
fullKey := "messages." + key
|
||||
if value, ok := GetString(messages, key); ok && value != "" {
|
||||
headers[i] = value
|
||||
} else {
|
||||
// 如果翻译失败,使用原始键
|
||||
headers[i] = key
|
||||
facades.Log().Debugf("翻译键未找到: %s (语言: %s)", fullKey, lang)
|
||||
}
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// TranslateKey 翻译单个键(从语言文件读取,支持文件系统和embed FS)
|
||||
// 这是一个通用函数,可以被任何需要翻译单个键的地方使用(包括Job等无http.Context的场景)
|
||||
//
|
||||
// 参数:
|
||||
// - key: 需要翻译的键(如 "export_order_status_pending")
|
||||
// - lang: 语言代码(如 "cn" 或 "en")
|
||||
// - defaultValue: 如果翻译失败时返回的默认值(通常为原始键)
|
||||
//
|
||||
// 返回:
|
||||
// - string: 翻译后的文本,如果翻译失败则返回 defaultValue
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// statusText := utils.TranslateKey("export_order_status_pending", "cn", "pending")
|
||||
func TranslateKey(key, lang, defaultValue string) string {
|
||||
// 加载语言文件的 messages 对象(支持文件系统和embed FS)
|
||||
messages, ok := loadLangMessages(lang)
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// 尝试获取翻译值
|
||||
if value, ok := GetString(messages, key); ok && value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
// 如果翻译失败,返回默认值
|
||||
return defaultValue
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// GetValue 从 map[string]any 中安全地获取指定类型的值
|
||||
// 支持多种类型转换,如果转换失败返回零值和 false
|
||||
func GetValue[T any](m map[string]any, key string) (T, bool) {
|
||||
var zero T
|
||||
val, ok := m[key]
|
||||
if !ok {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// 尝试直接类型断言
|
||||
if v, ok := val.(T); ok {
|
||||
return v, true
|
||||
}
|
||||
|
||||
// 对于数字类型,尝试从其他数字类型转换
|
||||
return convertNumeric[T](val)
|
||||
}
|
||||
|
||||
// GetUint 从 map[string]any 中获取 uint 值(支持多种数字类型转换)
|
||||
func GetUint(m map[string]any, key string) (uint, bool) {
|
||||
return GetValue[uint](m, key)
|
||||
}
|
||||
|
||||
// GetFloat64 从 map[string]any 中获取 float64 值(支持多种数字类型转换)
|
||||
func GetFloat64(m map[string]any, key string) (float64, bool) {
|
||||
return GetValue[float64](m, key)
|
||||
}
|
||||
|
||||
// GetString 从 map[string]any 中获取 string 值
|
||||
func GetString(m map[string]any, key string) (string, bool) {
|
||||
val, ok := m[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if v, ok := val.(string); ok {
|
||||
return v, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetMap 从 map[string]any 中获取 map[string]any 值
|
||||
func GetMap(m map[string]any, key string) (map[string]any, bool) {
|
||||
val, ok := m[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if v, ok := val.(map[string]any); ok {
|
||||
return v, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// convertNumeric 将值转换为数字类型(支持多种数字类型)
|
||||
func convertNumeric[T any](val any) (T, bool) {
|
||||
var zero T
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
return convertFromFloat64[T](v)
|
||||
case int:
|
||||
return convertFromInt[T](v)
|
||||
case uint:
|
||||
return convertFromUint[T](v)
|
||||
case int64:
|
||||
return convertFromInt64[T](v)
|
||||
case uint64:
|
||||
return convertFromUint64[T](v)
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// convertFromFloat64 从 float64 转换
|
||||
func convertFromFloat64[T any](v float64) (T, bool) {
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case uint:
|
||||
return any(uint(v)).(T), true
|
||||
case int:
|
||||
return any(int(v)).(T), true
|
||||
case float64:
|
||||
return any(v).(T), true
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// convertFromInt 从 int 转换
|
||||
func convertFromInt[T any](v int) (T, bool) {
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case uint:
|
||||
return any(uint(v)).(T), true
|
||||
case int:
|
||||
return any(v).(T), true
|
||||
case float64:
|
||||
return any(float64(v)).(T), true
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// convertFromUint 从 uint 转换
|
||||
func convertFromUint[T any](v uint) (T, bool) {
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case uint:
|
||||
return any(v).(T), true
|
||||
case int:
|
||||
return any(int(v)).(T), true
|
||||
case float64:
|
||||
return any(float64(v)).(T), true
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// convertFromInt64 从 int64 转换
|
||||
func convertFromInt64[T any](v int64) (T, bool) {
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case uint:
|
||||
return any(uint(v)).(T), true
|
||||
case int:
|
||||
return any(int(v)).(T), true
|
||||
case float64:
|
||||
return any(float64(v)).(T), true
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// convertFromUint64 从 uint64 转换
|
||||
func convertFromUint64[T any](v uint64) (T, bool) {
|
||||
var zero T
|
||||
switch any(zero).(type) {
|
||||
case uint:
|
||||
return any(uint(v)).(T), true
|
||||
case int:
|
||||
return any(int(v)).(T), true
|
||||
case float64:
|
||||
return any(float64(v)).(T), true
|
||||
default:
|
||||
return zero, false
|
||||
}
|
||||
}
|
||||
|
||||
// MustGetValue 从 map[string]any 中获取值,如果不存在或类型不匹配则 panic
|
||||
// 仅在确定值存在且类型正确时使用
|
||||
func MustGetValue[T any](m map[string]any, key string) T {
|
||||
val, ok := GetValue[T](m, key)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("key %s not found or type mismatch in map", key))
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// FillFiltersFromMap 从 map[string]any 填充 Filters 结构体
|
||||
// 支持 string, uint, float64 类型,使用字段名的 snake_case 作为 map 的 key
|
||||
// 示例:
|
||||
//
|
||||
// filters := services.OrderFilters{}
|
||||
// utils.FillFiltersFromMap(m, &filters)
|
||||
func FillFiltersFromMap(m map[string]any, filtersPtr any) {
|
||||
v := reflect.ValueOf(filtersPtr)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
t := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
structField := t.Field(i)
|
||||
|
||||
// 获取 json tag 或使用 snake_case 字段名
|
||||
key := structField.Tag.Get("json")
|
||||
if key == "" || key == "-" {
|
||||
key = toSnakeCase(structField.Name)
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.String:
|
||||
if val, ok := GetString(m, key); ok {
|
||||
field.SetString(val)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
if val, ok := GetUint(m, key); ok {
|
||||
field.SetUint(uint64(val))
|
||||
}
|
||||
case reflect.Float64, reflect.Float32:
|
||||
if val, ok := GetFloat64(m, key); ok {
|
||||
field.SetFloat(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// toSnakeCase 将 PascalCase/camelCase 转换为 snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result []byte
|
||||
for i, c := range s {
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
if i > 0 {
|
||||
result = append(result, '_')
|
||||
}
|
||||
result = append(result, byte(c+'a'-'A'))
|
||||
} else {
|
||||
result = append(result, byte(c))
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/goravel/framework/support/str"
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
// GenerateULID 生成 ULID(默认使用当前时间,与 Laravel 行为一致)
|
||||
// 返回: ULID 字符串
|
||||
func GenerateULID() string {
|
||||
entropy := ulid.DefaultEntropy()
|
||||
return ulid.MustNew(ulid.Timestamp(time.Now()), entropy).String()
|
||||
}
|
||||
|
||||
// GenerateULIDWithTime 使用指定时间生成 ULID
|
||||
// t: 指定的时间
|
||||
// 返回: ULID 字符串
|
||||
func GenerateULIDWithTime(t time.Time) string {
|
||||
entropy := ulid.DefaultEntropy()
|
||||
return ulid.MustNew(ulid.Timestamp(t), entropy).String()
|
||||
}
|
||||
|
||||
// ParseULID 解析 ULID 字符串
|
||||
// ulidStr: ULID 字符串
|
||||
// 返回: ULID 对象和错误
|
||||
func ParseULID(ulidStr string) (ulid.ULID, error) {
|
||||
return ulid.Parse(ulidStr)
|
||||
}
|
||||
|
||||
// ParseULIDTime 从 ULID 解析出时间
|
||||
// ulidStr: ULID 字符串
|
||||
// 返回: 时间对象和错误
|
||||
func ParseULIDTime(ulidStr string) (time.Time, error) {
|
||||
id, err := ulid.Parse(ulidStr)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return ulid.Time(id.Time()), nil
|
||||
}
|
||||
|
||||
// ParseULIDTimeString 从 ULID 解析出时间字符串
|
||||
// ulidStr: ULID 字符串
|
||||
// format: 时间格式,如 DateTimeFormat
|
||||
// 返回: 格式化的时间字符串和错误
|
||||
func ParseULIDTimeString(ulidStr string, format string) (string, error) {
|
||||
t, err := ParseULIDTime(ulidStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if format == "" {
|
||||
format = DateTimeFormat
|
||||
}
|
||||
return t.Format(format), nil
|
||||
}
|
||||
|
||||
// IsValidULID 验证 ULID 字符串是否有效
|
||||
// 使用 Goravel 框架的字符串库进行验证(与框架的 IsUlid 方法一致)
|
||||
// 参考: https://www.goravel.dev/zh_CN/digging-deeper/strings.html#isulid
|
||||
// ulidStr: ULID 字符串
|
||||
// 返回: 是否有效
|
||||
func IsValidULID(ulidStr string) bool {
|
||||
// 使用 Goravel 框架的字符串库验证(框架只提供验证功能,不提供生成功能)
|
||||
return str.Of(ulidStr).IsUlid()
|
||||
}
|
||||
|
||||
// GetULIDTimestamp 获取 ULID 的时间戳(毫秒)
|
||||
// ulidStr: ULID 字符串
|
||||
// 返回: Unix 时间戳(毫秒)和错误
|
||||
func GetULIDTimestamp(ulidStr string) (int64, error) {
|
||||
id, err := ulid.Parse(ulidStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(id.Time()), nil
|
||||
}
|
||||
Reference in New Issue
Block a user