Files
server/app/utils/count_optimizer.go
2026-01-16 15:49:34 +08:00

253 lines
8.2 KiB
Go

package utils
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/goravel/framework/contracts/database/orm"
"github.com/goravel/framework/facades"
)
// CountOptimizer 分页统计优化器
// 当数据量超过阈值时,使用执行计划估算的行数,否则使用实际的 count(*)
type CountOptimizer struct {
// Threshold 阈值,超过此值使用估算值(默认 10000)
Threshold int64
// ModuleName 模块名称,用于日志记录
ModuleName string
}
// NewCountOptimizer 创建新的 CountOptimizer
// threshold: 阈值,超过此值使用估算值(默认 10000)
// moduleName: 模块名称,用于日志记录
func NewCountOptimizer(threshold int64, moduleName string) *CountOptimizer {
if threshold <= 0 {
threshold = 10000 // 默认阈值
}
return &CountOptimizer{
Threshold: threshold,
ModuleName: moduleName,
}
}
// OptimizedCount 优化的 count 查询(使用 ORM Query 对象)
// query: ORM 查询对象(已应用筛选条件,但未应用排序和分页)
// 注意:此方法会先尝试估算,如果失败则使用实际 count
// 返回:总数、是否使用估算值、错误
func (co *CountOptimizer) OptimizedCount(query orm.Query) (int64, bool, error) {
// 先尝试执行实际 count(如果数据量小,直接 count 也很快)
// 如果数据量大,我们可以通过执行时间来判断,但更简单的方法是先估算
// 这里我们提供一个简化版本:直接使用实际 count,如果慢的话可以后续优化
// 由于无法直接从 ORM Query 提取 SQL,这里提供一个变通方案:
// 先执行一次快速查询获取估算值(需要表名)
// 但更推荐使用 OptimizedCountWithTable 方法
actualCount, err := query.Count()
return actualCount, false, err
}
// extractRowsFromPostgreSQLExplain 从 PostgreSQL EXPLAIN JSON 结果中提取行数
func (co *CountOptimizer) extractRowsFromPostgreSQLExplain(jsonStr string) int64 {
// 简化实现:使用正则表达式提取 "Plan" -> "Plan Rows" 的值
// 实际应该使用 JSON 解析,但为了简化,使用正则
re := regexp.MustCompile(`"Plan Rows":\s*(\d+)`)
matches := re.FindStringSubmatch(jsonStr)
if len(matches) > 1 {
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
return rows
}
}
return 0
}
// extractRowsFromPostgreSQLExplainText 从 PostgreSQL EXPLAIN 文本结果中提取行数
func (co *CountOptimizer) extractRowsFromPostgreSQLExplainText(result []map[string]any) int64 {
// PostgreSQL 文本格式的 EXPLAIN 结果中,行数通常在 "rows=" 后面
for _, row := range result {
if queryPlan, ok := row["QUERY PLAN"]; ok {
if planStr, ok := queryPlan.(string); ok {
// 使用正则表达式提取 rows= 后面的数字
re := regexp.MustCompile(`rows=(\d+)`)
matches := re.FindStringSubmatch(planStr)
if len(matches) > 1 {
if rows, err := strconv.ParseInt(matches[1], 10, 64); err == nil {
return rows
}
}
}
}
}
return 0
}
// OptimizedCountWithTable 使用表名和 WHERE 条件进行优化的 count 查询
// tableName: 表名
// whereClause: WHERE 子句(不包含 WHERE 关键字),例如:"status = ? AND user_id = ?"
// args: WHERE 条件的参数
// 返回:总数、是否使用估算值、错误
func (co *CountOptimizer) OptimizedCountWithTable(tableName, whereClause string, args ...any) (int64, bool, error) {
dbConnection := facades.Config().GetString("database.default", "sqlite")
// 构建 COUNT SQL
var countSQL string
if whereClause != "" {
switch dbConnection {
case "mysql":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s` WHERE %s", tableName, whereClause)
case "postgres":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s WHERE %s", tableName, whereClause)
default:
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
} else {
switch dbConnection {
case "mysql":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM `%s`", tableName)
case "postgres":
countSQL = fmt.Sprintf("SELECT COUNT(*) as cnt FROM %s", tableName)
default:
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
}
// 构建 EXPLAIN SQL
var explainSQL string
if dbConnection == "mysql" {
if whereClause != "" {
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s` WHERE %s", tableName, whereClause)
} else {
explainSQL = fmt.Sprintf("EXPLAIN SELECT COUNT(*) FROM `%s`", tableName)
}
} else if dbConnection == "postgres" {
if whereClause != "" {
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s WHERE %s", tableName, whereClause)
} else {
explainSQL = fmt.Sprintf("EXPLAIN (FORMAT JSON) SELECT COUNT(*) FROM %s", tableName)
}
} else {
// 其他数据库不支持估算,直接使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// 先执行 EXPLAIN 查询获取估算值(快速,不需要特别精准)
estimatedCount, err := co.executeExplain(explainSQL, args...)
if err != nil {
// 如果估算失败,使用实际 count
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// 如果估算值超过阈值,直接返回估算值(不需要执行实际 count,更快)
if estimatedCount >= co.Threshold {
return estimatedCount, true, nil
}
// 估算值小于阈值,执行实际的 count 获取精确值
var result struct {
Cnt int64
}
if err := facades.Orm().Query().Raw(countSQL, args...).Scan(&result); err != nil {
return 0, false, err
}
return result.Cnt, false, nil
}
// executeExplain 执行 EXPLAIN 查询并提取估算行数
func (co *CountOptimizer) executeExplain(explainSQL string, args ...any) (int64, error) {
dbConnection := facades.Config().GetString("database.default", "sqlite")
switch dbConnection {
case "mysql":
// MySQL EXPLAIN 返回表格格式
var explainResult []map[string]any
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
return 0, err
}
if len(explainResult) == 0 {
return 0, fmt.Errorf("explain result is empty")
}
// MySQL EXPLAIN 结果中,rows 字段包含估算行数
if rows, ok := explainResult[0]["rows"]; ok {
switch v := rows.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case float64:
return int64(v), nil
case string:
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
return parsed, nil
}
}
}
return 0, fmt.Errorf("cannot extract rows from explain result")
case "postgres":
// PostgreSQL EXPLAIN (FORMAT JSON) 返回 JSON 格式
var explainResult []map[string]any
if err := facades.Orm().Query().Raw(explainSQL, args...).Get(&explainResult); err != nil {
return 0, err
}
if len(explainResult) == 0 {
return 0, fmt.Errorf("explain result is empty")
}
// PostgreSQL EXPLAIN (FORMAT JSON) 结果是一个包含 JSON 字符串的数组
if queryPlan, ok := explainResult[0]["QUERY PLAN"]; ok {
if planStr, ok := queryPlan.(string); ok {
// 从 JSON 字符串中提取 "Plan" -> "Plan Rows" 的值
rows := co.extractRowsFromPostgreSQLExplain(planStr)
if rows > 0 {
return rows, nil
}
}
}
// 如果 JSON 格式解析失败,尝试文本格式
explainTextSQL := strings.Replace(explainSQL, "EXPLAIN (FORMAT JSON)", "EXPLAIN", 1)
var explainTextResult []map[string]any
if err := facades.Orm().Query().Raw(explainTextSQL, args...).Get(&explainTextResult); err == nil {
rows := co.extractRowsFromPostgreSQLExplainText(explainTextResult)
if rows > 0 {
return rows, nil
}
}
return 0, fmt.Errorf("cannot extract rows from explain result")
}
return 0, fmt.Errorf("unsupported database: %s", dbConnection)
}