init
This commit is contained in:
@@ -0,0 +1,351 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/goravel/framework/facades"
|
||||
"github.com/samber/lo"
|
||||
|
||||
apperrors "goravel/app/errors"
|
||||
"goravel/app/utils"
|
||||
"goravel/app/utils/errorlog"
|
||||
)
|
||||
|
||||
// ShardingQueryConfig 分表查询配置
|
||||
type ShardingQueryConfig struct {
|
||||
// BaseTableName 基础表名,如 "orders"
|
||||
BaseTableName string
|
||||
// GetColumns 获取表的所有列名(用于 UNION ALL 查询)
|
||||
// 返回格式:如 "id, order_no, user_id, created_at, updated_at, deleted_at"
|
||||
GetColumns func() string
|
||||
// BuildWhereClause 构建 WHERE 条件
|
||||
// 返回:WHERE 子句(不包含 WHERE 关键字)和参数列表
|
||||
// 例如:返回 ("user_id = ? AND status = ?", []any{1, "paid"})
|
||||
BuildWhereClause func(filters any) (string, []any)
|
||||
// GetAllowedOrderFields 获取允许排序的字段列表
|
||||
// 返回:字段名到 bool 的映射,如 map[string]bool{"id": true, "created_at": true}
|
||||
GetAllowedOrderFields func() map[string]bool
|
||||
// DefaultOrderBy 默认排序,格式:字段:方向,如 "created_at:desc"
|
||||
DefaultOrderBy string
|
||||
// ModuleName 模块名称,用于日志记录,如 "order"
|
||||
ModuleName string
|
||||
// CountThreshold count 查询优化阈值,超过此值使用执行计划估算(默认 10000)
|
||||
// 不同模块可以设置不同的阈值
|
||||
CountThreshold int64
|
||||
}
|
||||
|
||||
// ShardingQueryService 分表查询服务接口
|
||||
type ShardingQueryService interface {
|
||||
// QueryMultipleTables 查询多个分表(带分页)
|
||||
// tableNames: 分表名称列表
|
||||
// filters: 筛选条件(类型由具体实现决定)
|
||||
// page: 页码(从1开始)
|
||||
// pageSize: 每页数量
|
||||
// result: 查询结果(必须是指针类型)
|
||||
// 返回:结果列表、总数、错误
|
||||
QueryMultipleTables(tableNames []string, filters any, page, pageSize int, result any) (int64, error)
|
||||
|
||||
// QueryMultipleTablesForExport 查询多个分表(不分页,用于导出)
|
||||
// tableNames: 分表名称列表
|
||||
// filters: 筛选条件(类型由具体实现决定)
|
||||
// result: 查询结果(必须是指针类型)
|
||||
// 返回:结果列表、错误
|
||||
QueryMultipleTablesForExport(tableNames []string, filters any, result any) error
|
||||
}
|
||||
|
||||
type ShardingQueryServiceImpl struct {
|
||||
config ShardingQueryConfig
|
||||
}
|
||||
|
||||
// NewShardingQueryService 创建分表查询服务
|
||||
func NewShardingQueryService(config ShardingQueryConfig) ShardingQueryService {
|
||||
// 设置默认值
|
||||
if config.DefaultOrderBy == "" {
|
||||
config.DefaultOrderBy = "created_at:desc"
|
||||
}
|
||||
if config.ModuleName == "" {
|
||||
config.ModuleName = "sharding"
|
||||
}
|
||||
|
||||
return &ShardingQueryServiceImpl{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryMultipleTables 查询多个分表(带分页)
|
||||
func (s *ShardingQueryServiceImpl) QueryMultipleTables(tableNames []string, filters any, page, pageSize int, result any) (int64, error) {
|
||||
// 构建 WHERE 条件
|
||||
whereClause, whereConditions := s.config.BuildWhereClause(filters)
|
||||
if whereClause == "" {
|
||||
whereClause = "1=1"
|
||||
} else {
|
||||
whereClause = "1=1 AND " + whereClause
|
||||
}
|
||||
|
||||
// 构建排序
|
||||
orderField, orderDir := s.parseOrderBy(filters)
|
||||
|
||||
// 获取列名
|
||||
columnsStr := s.config.GetColumns()
|
||||
|
||||
// 优化:每个分表先排序和限制,然后再合并(避免合并大量数据后再排序)
|
||||
// 计算每个分表需要查询的数量
|
||||
// 为了确保合并后有足够的数据进行分页,每个分表查询更多数据
|
||||
// 公式:limitPerTable = (page * pageSize) + pageSize,确保有足够数据
|
||||
offset := (page - 1) * pageSize
|
||||
limitPerTable := offset + pageSize + pageSize // 额外查询一页数据,确保有足够数据
|
||||
|
||||
// 如果 limitPerTable 太大(超过10000),限制为10000,避免单个查询太慢
|
||||
if limitPerTable > 10000 {
|
||||
limitPerTable = 10000
|
||||
}
|
||||
|
||||
// 构建 UNION ALL 查询
|
||||
// 过滤掉不存在的分表,避免查询错误
|
||||
existingTableNames := lo.Filter(tableNames, func(tableName string, _ int) bool {
|
||||
return facades.Schema().HasTable(tableName)
|
||||
})
|
||||
|
||||
if len(existingTableNames) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 为每个存在的表构建查询
|
||||
unionQueries := lo.Map(existingTableNames, func(tableName string, _ int) string {
|
||||
return fmt.Sprintf(
|
||||
"(SELECT %s FROM `%s` WHERE %s ORDER BY `%s` %s LIMIT %d)",
|
||||
columnsStr, tableName, whereClause, orderField, orderDir, limitPerTable,
|
||||
)
|
||||
})
|
||||
|
||||
// 每个查询都需要相同的参数
|
||||
allArgs := lo.Flatten(lo.Map(lo.Range(len(existingTableNames)), func(_ int, _ int) []any {
|
||||
return whereConditions
|
||||
}))
|
||||
|
||||
// 合并所有查询
|
||||
unionSQL := strings.Join(unionQueries, " UNION ALL ")
|
||||
|
||||
// 优化:分别对每个分表执行 COUNT,然后相加(性能更好,可以利用索引)
|
||||
// 而不是对 UNION ALL 结果进行 COUNT(需要先合并所有数据)
|
||||
var total int64
|
||||
threshold := s.config.CountThreshold
|
||||
|
||||
// 如果配置了阈值,使用执行计划优化;否则直接使用 count
|
||||
if threshold > 0 {
|
||||
// 使用优化的 count 查询(先估算,超过阈值用估算值,否则用实际 count)
|
||||
countOptimizer := utils.NewCountOptimizer(threshold, s.config.ModuleName)
|
||||
for _, tableName := range existingTableNames {
|
||||
// 使用对应的参数(每个分表使用相同的参数)
|
||||
args := whereConditions
|
||||
tableTotal, _, err := countOptimizer.OptimizedCountWithTable(tableName, whereClause, args...)
|
||||
if err != nil {
|
||||
errorlog.Record(context.Background(), s.config.ModuleName, "查询分表总数失败", map[string]any{
|
||||
"table_name": tableName,
|
||||
"error": err.Error(),
|
||||
}, "查询分表 %s 总数失败: %v", tableName, err)
|
||||
// 如果某个分表查询失败,继续查询其他分表,但记录错误
|
||||
continue
|
||||
}
|
||||
total += tableTotal
|
||||
}
|
||||
} else {
|
||||
// 没有配置阈值,直接使用传统的 count 统计
|
||||
// 根据数据库类型决定表名引号(MySQL 用反引号,PostgreSQL 不用)
|
||||
dbConnection := facades.Config().GetString("database.default", "sqlite")
|
||||
tableQuote := ""
|
||||
if dbConnection == "mysql" {
|
||||
tableQuote = "`"
|
||||
}
|
||||
|
||||
for _, tableName := range existingTableNames {
|
||||
// 每个分表使用相同的 WHERE 条件
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(*) as total FROM %s%s%s WHERE %s", tableQuote, tableName, tableQuote, whereClause)
|
||||
var countResult struct {
|
||||
Total int64
|
||||
}
|
||||
// 使用对应的参数(每个分表使用相同的参数)
|
||||
args := whereConditions
|
||||
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&countResult); err != nil {
|
||||
errorlog.Record(context.Background(), s.config.ModuleName, "查询分表总数失败", map[string]any{
|
||||
"table_name": tableName,
|
||||
"error": err.Error(),
|
||||
}, "查询分表 %s 总数失败: %v", tableName, err)
|
||||
// 如果某个分表查询失败,继续查询其他分表,但记录错误
|
||||
continue
|
||||
}
|
||||
total += countResult.Total
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有数据,直接返回
|
||||
if total == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 分页查询(在外层再次排序和分页)
|
||||
// 注意:虽然每个分表已经排序,但合并后需要重新排序以确保全局顺序正确
|
||||
// 但由于每个分表已经限制了数量,合并后的数据量大大减少,排序会快很多
|
||||
paginatedSQL := fmt.Sprintf(
|
||||
"SELECT %s FROM (%s) as combined ORDER BY `%s` %s LIMIT ? OFFSET ?",
|
||||
columnsStr,
|
||||
unionSQL,
|
||||
orderField,
|
||||
orderDir,
|
||||
)
|
||||
|
||||
// 添加 LIMIT 和 OFFSET 参数
|
||||
paginatedArgs := append(allArgs, pageSize, offset)
|
||||
|
||||
// 执行查询
|
||||
if err := facades.Orm().Query().Raw(paginatedSQL, paginatedArgs...).Scan(result); err != nil {
|
||||
errorlog.Record(context.Background(), s.config.ModuleName, "查询列表失败", map[string]any{
|
||||
"table_count": len(tableNames),
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"error": err.Error(),
|
||||
}, "查询列表失败: %v", err)
|
||||
return 0, apperrors.ErrQueryFailed.WithError(err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// QueryMultipleTablesForExport 查询多个分表(不分页,用于导出)
|
||||
func (s *ShardingQueryServiceImpl) QueryMultipleTablesForExport(tableNames []string, filters any, result any) error {
|
||||
// 构建 WHERE 条件
|
||||
whereClause, whereConditions := s.config.BuildWhereClause(filters)
|
||||
if whereClause == "" {
|
||||
whereClause = "1=1"
|
||||
} else {
|
||||
whereClause = "1=1 AND " + whereClause
|
||||
}
|
||||
|
||||
// 构建排序
|
||||
orderField, orderDir := s.parseOrderBy(filters)
|
||||
|
||||
// 获取列名
|
||||
columnsStr := s.config.GetColumns()
|
||||
|
||||
// 优化:每个分表先排序,然后再合并(对于导出,虽然需要所有数据,但先排序可以减少合并后的排序成本)
|
||||
// 构建 UNION ALL 查询
|
||||
// 过滤掉不存在的分表,避免查询错误
|
||||
existingTableNames := lo.Filter(tableNames, func(tableName string, _ int) bool {
|
||||
return facades.Schema().HasTable(tableName)
|
||||
})
|
||||
|
||||
if len(existingTableNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 为每个存在的表构建查询
|
||||
unionQueries := lo.Map(existingTableNames, func(tableName string, _ int) string {
|
||||
// 优化:每个分表先排序,然后再合并
|
||||
// 使用子查询包装,确保每个分表先排序
|
||||
// 注意:导出需要所有数据,所以不限制数量,但先排序可以优化合并后的排序性能
|
||||
return fmt.Sprintf(
|
||||
"(SELECT %s FROM `%s` WHERE %s ORDER BY `%s` %s)",
|
||||
columnsStr, tableName, whereClause, orderField, orderDir,
|
||||
)
|
||||
})
|
||||
|
||||
// 每个查询都需要相同的参数
|
||||
allArgs := lo.Flatten(lo.Map(lo.Range(len(existingTableNames)), func(_ int, _ int) []any {
|
||||
return whereConditions
|
||||
}))
|
||||
|
||||
if len(unionQueries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 合并所有查询
|
||||
unionSQL := strings.Join(unionQueries, " UNION ALL ")
|
||||
|
||||
// 导出查询(不分页,但需要排序)
|
||||
// 注意:虽然每个分表已经排序,但合并后需要重新排序以确保全局顺序正确
|
||||
// 但由于每个分表已经排序,合并后的排序会更快(归并排序)
|
||||
exportSQL := fmt.Sprintf(
|
||||
"SELECT %s FROM (%s) as combined ORDER BY `%s` %s",
|
||||
columnsStr,
|
||||
unionSQL,
|
||||
orderField,
|
||||
orderDir,
|
||||
)
|
||||
|
||||
// 执行查询
|
||||
if err := facades.Orm().Query().Raw(exportSQL, allArgs...).Scan(result); err != nil {
|
||||
errorlog.Record(context.Background(), s.config.ModuleName, "导出查询失败", map[string]any{
|
||||
"table_count": len(tableNames),
|
||||
"error": err.Error(),
|
||||
}, "导出查询失败: %v", err)
|
||||
return apperrors.ErrQueryFailed.WithError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOrderBy 解析排序字段
|
||||
// 从 filters 中提取 OrderBy 字段(如果 filters 有 OrderBy 字段)
|
||||
// 返回:排序字段名和方向
|
||||
func (s *ShardingQueryServiceImpl) parseOrderBy(filters any) (string, string) {
|
||||
// 尝试从 filters 中提取 OrderBy 字段
|
||||
// 使用类型断言或反射来获取 OrderBy 字段
|
||||
// 这里使用一个辅助函数来处理
|
||||
orderBy := s.extractOrderBy(filters)
|
||||
if orderBy == "" {
|
||||
orderBy = s.config.DefaultOrderBy
|
||||
}
|
||||
|
||||
orderParts := strings.Split(orderBy, ":")
|
||||
orderField := "created_at"
|
||||
orderDir := "desc"
|
||||
if len(orderParts) == 2 {
|
||||
field := orderParts[0]
|
||||
direction := strings.ToLower(orderParts[1])
|
||||
|
||||
// 验证排序字段是否允许
|
||||
allowedFields := s.config.GetAllowedOrderFields()
|
||||
if allowedFields != nil && allowedFields[field] {
|
||||
orderField = field
|
||||
if direction == "asc" {
|
||||
orderDir = "asc"
|
||||
} else {
|
||||
orderDir = "desc"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return orderField, orderDir
|
||||
}
|
||||
|
||||
// extractOrderBy 从 filters 中提取 OrderBy 字段
|
||||
// 支持结构体和 map[string]any
|
||||
func (s *ShardingQueryServiceImpl) extractOrderBy(filters any) string {
|
||||
if filters == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 尝试类型断言为 map[string]any
|
||||
if m, ok := filters.(map[string]any); ok {
|
||||
if orderBy, ok := m["OrderBy"].(string); ok {
|
||||
return orderBy
|
||||
}
|
||||
}
|
||||
|
||||
// 使用反射从结构体中提取 OrderBy 字段
|
||||
v := reflect.ValueOf(filters)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Struct {
|
||||
field := v.FieldByName("OrderBy")
|
||||
if field.IsValid() && field.Kind() == reflect.String {
|
||||
return field.String()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user