Files
server/app/utils/sharding_helper.go
2026-01-16 15:49:34 +08:00

286 lines
9.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package utils
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/goravel/framework/facades"
"goravel/app/constants"
"goravel/app/utils/errorlog"
)
// GetShardingTableName 根据时间获取分表名称
// baseTableName: 基础表名,如 "orders"
// orderTime: 订单时间
// 返回: 分表名称,如 "orders_202501"
func GetShardingTableName(baseTableName string, orderTime time.Time) string {
return fmt.Sprintf("%s_%s", baseTableName, orderTime.Format("200601"))
}
// GetUserBalanceLogsShardingTableName 根据 user_id 获取用户余额变动记录分表名称
// userID: 用户ID
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
// 分表逻辑:user_id % UserBalanceLogsShards
// 注意:此函数为特定表实现,如需为其他表实现哈希分表,请使用 GetHashShardingTableName
func GetUserBalanceLogsShardingTableName(userID uint) string {
return GetHashShardingTableName("user_balance_logs", userID, constants.UserBalanceLogsShards)
}
// HashShardingConfig 哈希分表配置
type HashShardingConfig struct {
BaseTableName string // 基础表名,如 "user_balance_logs"
NumberOfShards int // 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
}
// 预定义的哈希分表配置
var (
// UserBalanceLogsShardingConfig 用户余额变动记录分表配置
UserBalanceLogsShardingConfig = HashShardingConfig{
BaseTableName: "user_balance_logs",
NumberOfShards: constants.UserBalanceLogsShards,
}
)
// GetHashShardingTableName 通用的哈希分表名称生成函数
// baseTableName: 基础表名,如 "user_balance_logs"
// shardingKey: 分表键值(uint 类型),如 user_id
// numberOfShards: 分表数量,建议为 2 的幂次(如 4, 8, 16, 32, 64 等)
// 返回: 分表名称,如 "user_balance_logs_0", "user_balance_logs_1" 等
// 分表逻辑:shardingKey % numberOfShards
//
// 使用示例:
//
// // 为 user_balance_logs 表分表(4个分表)
// tableName := GetHashShardingTableName("user_balance_logs", userID, 4)
//
// // 为新的表 example_table 分表(8个分表)
// tableName := GetHashShardingTableName("example_table", entityID, 8)
func GetHashShardingTableName(baseTableName string, shardingKey uint, numberOfShards int) string {
if numberOfShards <= 0 {
// 如果分表数量无效,返回基础表名(不分表)
return baseTableName
}
shardIndex := int(shardingKey) % numberOfShards
return fmt.Sprintf("%s_%d", baseTableName, shardIndex)
}
// GetHashShardingTableNameByConfig 通过配置获取哈希分表名称
func GetHashShardingTableNameByConfig(config HashShardingConfig, shardingKey uint) string {
return GetHashShardingTableName(config.BaseTableName, shardingKey, config.NumberOfShards)
}
// GetAllHashShardingTableNames 获取所有哈希分表名称列表
// 用于批量操作所有分表(如迁移、统计等)
func GetAllHashShardingTableNames(config HashShardingConfig) []string {
tableNames := make([]string, config.NumberOfShards)
for i := 0; i < config.NumberOfShards; i++ {
tableNames[i] = fmt.Sprintf("%s_%d", config.BaseTableName, i)
}
return tableNames
}
// GetHashShardIndex 获取分表索引
func GetHashShardIndex(shardingKey uint, numberOfShards int) int {
if numberOfShards <= 0 {
return 0
}
return int(shardingKey) % numberOfShards
}
// GetShardingTableNames 获取时间范围内的所有分表名称
// baseTableName: 基础表名
// startTime: 开始时间
// endTime: 结束时间
// 返回: 分表名称列表
func GetShardingTableNames(baseTableName string, startTime, endTime time.Time) []string {
var tableNames []string
// 如果结束时间是零值,使用当前时间
if endTime.IsZero() {
endTime = time.Now().UTC()
}
// 确保开始时间不晚于结束时间
if startTime.After(endTime) {
return tableNames
}
// 从开始时间到结束时间,按月遍历
current := time.Date(startTime.Year(), startTime.Month(), 1, 0, 0, 0, 0, startTime.Location())
end := time.Date(endTime.Year(), endTime.Month(), 1, 0, 0, 0, 0, endTime.Location())
for !current.After(end) {
tableNames = append(tableNames, GetShardingTableName(baseTableName, current))
current = current.AddDate(0, 1, 0) // 加一个月
}
return tableNames
}
// DefaultMaxTimeRangeMonths 默认最大时间范围(月数)
// 可以通过配置覆盖,用于限制查询时间范围,避免跨太多分表
const DefaultMaxTimeRangeMonths = 3
// TimeRangeError 时间范围验证错误
type TimeRangeError struct {
Key string
Params map[string]any
}
func (e *TimeRangeError) Error() string {
// 返回翻译键,由调用方进行翻译
return e.Key
}
// ValidateTimeRange 验证时间范围是否超过指定月数
// startTime: 开始时间
// endTime: 结束时间
// maxMonths: 最大允许的月数,如果为0则使用默认值 DefaultMaxTimeRangeMonths
// 返回: 是否有效,错误信息(错误信息包含翻译键和参数)
func ValidateTimeRange(startTime, endTime time.Time, maxMonths ...int) (bool, error) {
// 检查开始时间是否晚于结束时间(忽略零值时间)
if !startTime.IsZero() && !endTime.IsZero() && startTime.After(endTime) {
return false, &TimeRangeError{
Key: "start_time_after_end_time",
Params: nil,
}
}
// 确定最大月数
maxMonthsValue := DefaultMaxTimeRangeMonths
if len(maxMonths) > 0 && maxMonths[0] > 0 {
maxMonthsValue = maxMonths[0]
}
// 检查时间范围是否超过指定月数(忽略零值结束时间)
if !endTime.IsZero() {
maxTimeLater := startTime.AddDate(0, maxMonthsValue, 0)
if endTime.After(maxTimeLater) {
return false, &TimeRangeError{
Key: "time_range_exceeded",
Params: map[string]any{
"months": maxMonthsValue,
},
}
}
}
return true, nil
}
// GetAllExistingShardingTables 获取数据库中所有已存在的分表名称
// baseTableName: 基础表名,如 "orders" 或 "order_details"
// 返回: 已存在的分表名称列表
func GetAllExistingShardingTables(baseTableName string) ([]string, error) {
var tableNames []string
// 获取当前数据库名
dbName := facades.Config().GetString("database.connections.mysql.database")
if dbName == "" {
dbName = facades.Config().GetString("database.connections.postgresql.database")
}
// 构建表名匹配模式:orders_YYYYMM 或 order_details_YYYYMM
pattern := fmt.Sprintf("%s_%%", baseTableName)
// 查询所有匹配的表名
query := `
SELECT table_name
FROM information_schema.tables
WHERE table_schema = ?
AND table_name LIKE ?
ORDER BY table_name
`
// 执行查询,使用 Scan 获取结果
var rows []map[string]any
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
"pattern": pattern,
"error": err.Error(),
}, "查询分表失败: %v", err)
return nil, fmt.Errorf("查询分表失败: %v", err)
}
// 验证表名格式(确保是有效的分表名称,格式为 baseTableName_YYYYMM
// 从 pattern 中提取基础表名(移除末尾的 %)
baseTableNameFromPattern := strings.TrimSuffix(pattern, "_%")
patternRegex := regexp.MustCompile(fmt.Sprintf("^%s_\\d{6}$", regexp.QuoteMeta(baseTableNameFromPattern)))
for _, row := range rows {
// 尝试不同的字段名格式(MySQL 可能返回不同的大小写)
var tableName string
var ok bool
// 尝试 table_name (小写)
if tableName, ok = row["table_name"].(string); !ok {
// 尝试 TABLE_NAME (大写)
if tableName, ok = row["TABLE_NAME"].(string); !ok {
// 尝试遍历所有键
for key, value := range row {
if (key == "table_name" || key == "TABLE_NAME" || key == "Table_Name") && value != nil {
if str, ok := value.(string); ok {
tableName = str
ok = true
break
}
}
}
}
}
if ok && tableName != "" {
// 验证表名格式
if patternRegex.MatchString(tableName) {
tableNames = append(tableNames, tableName)
}
}
}
return tableNames, nil
}
// GetAllExistingShardingTablesByPattern 通过表名模式获取所有已存在的分表
// 这是一个更通用的方法,可以通过自定义模式匹配
// pattern: 表名匹配模式,如 "orders_%" 或 "order_details_%"
func GetAllExistingShardingTablesByPattern(pattern string) ([]string, error) {
var tableNames []string
// 获取当前数据库名
dbName := facades.Config().GetString("database.connections.mysql.database")
if dbName == "" {
dbName = facades.Config().GetString("database.connections.postgresql.database")
}
// 查询所有匹配的表名
query := `
SELECT table_name
FROM information_schema.tables
WHERE table_schema = ?
AND table_name LIKE ?
ORDER BY table_name
`
// 执行查询,使用 Scan 获取结果
var rows []map[string]any
if err := facades.Orm().Query().Raw(query, dbName, pattern).Scan(&rows); err != nil {
errorlog.Record(context.Background(), "sharding", "查询分表失败", map[string]any{
"pattern": pattern,
"error": err.Error(),
}, "查询分表失败: %v", err)
return nil, fmt.Errorf("查询分表失败: %v", err)
}
for _, row := range rows {
if tableName, ok := row["table_name"].(string); ok {
tableNames = append(tableNames, tableName)
}
}
return tableNames, nil
}