Files
server/app/http/middleware/domain.go
T
2026-01-16 15:49:34 +08:00

212 lines
5.6 KiB
Go

package middleware
import (
"net"
"strings"
"github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/facades"
)
// Domain 域名验证中间件
func Domain(configValueOrDomains ...any) http.Middleware {
return func(ctx http.Context) {
var domains []string
// 如果没有参数,不验证(允许所有域名)
if len(configValueOrDomains) == 0 {
ctx.Request().Next()
return
}
// 解析配置值的辅助函数
parseConfigValue := func(value any) []string {
if value == nil {
return nil
}
switch v := value.(type) {
case []string:
return v
case string:
if v == "" {
return nil
}
// 分割逗号分隔的域名
domainsList := strings.Split(v, ",")
result := make([]string, 0, len(domainsList))
for _, d := range domainsList {
d = strings.TrimSpace(d)
if d != "" {
result = append(result, d)
}
}
return result
case []any:
result := make([]string, 0, len(v))
for _, item := range v {
if str, ok := item.(string); ok && str != "" {
result = append(result, str)
}
}
return result
default:
return nil
}
}
// 处理参数
for _, param := range configValueOrDomains {
if param == nil {
continue
}
switch v := param.(type) {
case []string:
// 如果是字符串数组,直接使用
if len(v) > 0 {
domains = append(domains, v...)
}
case string:
// 如果是字符串,可能是单个域名或配置键
if v != "" {
// 先尝试作为配置键读取
configValue := facades.Config().Get(v, nil)
if configValue != nil {
// 如果配置存在,解析配置值
parsedDomains := parseConfigValue(configValue)
if len(parsedDomains) > 0 {
domains = append(domains, parsedDomains...)
continue
}
}
// 否则当作域名
domains = append(domains, v)
}
case []any:
// 如果是 any 数组,递归处理
for _, item := range v {
if str, ok := item.(string); ok && str != "" {
domains = append(domains, str)
}
}
default:
// 其他类型,尝试解析为配置值
parsedDomains := parseConfigValue(v)
if len(parsedDomains) > 0 {
domains = append(domains, parsedDomains...)
}
}
}
// 如果没有指定允许的域名,允许所有域名访问(直接放行)
if len(domains) == 0 {
ctx.Request().Next()
return
}
// 获取请求的 Host
// 优先从 X-Forwarded-Host 获取(适用于反向代理场景)
host := ctx.Request().Header("X-Forwarded-Host", "")
if host == "" {
// 使用框架提供的 Host() 方法获取(推荐方式)
host = ctx.Request().Host()
}
// 如果 X-Forwarded-Host 包含多个值(逗号分隔),取第一个
if host != "" && strings.Contains(host, ",") {
host = strings.TrimSpace(strings.Split(host, ",")[0])
}
// 调试日志:记录获取到的 Host 值
if facades.Config().GetBool("app.debug", false) {
facades.Log().Debugf("Domain middleware: Host detection - X-Forwarded-Host: %s, Host(): %s, Final host: %s",
ctx.Request().Header("X-Forwarded-Host", ""),
ctx.Request().Host(),
host)
}
// 规范化 Host(移除端口号,转换为小写)
normalizedHost := normalizeHost(host)
// 调试日志:记录规范化后的 Host 和配置的域名
if facades.Config().GetBool("app.debug", false) {
facades.Log().Debugf("Domain middleware: Normalized host: %s, Configured domains: %v", normalizedHost, domains)
}
// 检查是否在允许的域名列表中
allowed := false
var matchedDomain string
for _, allowedDomain := range domains {
normalizedAllowed := normalizeHost(allowedDomain)
// 支持精确匹配和通配符匹配
if normalizedHost == normalizedAllowed || matchDomain(normalizedHost, normalizedAllowed) {
allowed = true
matchedDomain = allowedDomain
break
}
}
if !allowed {
// 域名不在允许列表中,拒绝访问
// facades.Log().Warningf("Domain middleware: Access denied. Request host: %s (normalized: %s), Allowed domains: %v", host, normalizedHost, domains)
_ = ctx.Response().Json(http.StatusForbidden, http.Json{
"code": http.StatusForbidden,
"message": "Access denied: domain not allowed",
}).Abort()
return
}
// 记录匹配的域名(仅在调试模式下)
if facades.Config().GetBool("app.debug", false) {
facades.Log().Debugf("Domain middleware: Access allowed. Request host: %s, Matched domain: %s", normalizedHost, matchedDomain)
}
// 域名验证通过,继续处理请求
ctx.Request().Next()
}
}
// normalizeHost 规范化域名(移除端口号,转换为小写,去除前后空格)
func normalizeHost(host string) string {
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
return ""
}
// 移除端口号
if hostname, _, err := net.SplitHostPort(host); err == nil {
host = hostname
}
return host
}
// matchDomain 域名匹配,支持通配符
// 例如:*.example.com 可以匹配 a.example.com, b.example.com 等
func matchDomain(host, pattern string) bool {
if pattern == "" {
return false
}
// 如果模式以 * 开头,进行通配符匹配
if after, ok := strings.CutPrefix(pattern, "*."); ok {
// 移除 *.
suffix := after
// 检查 host 是否以 .suffix 结尾
if strings.HasSuffix(host, "."+suffix) || host == suffix {
return true
}
}
// 如果模式以 * 结尾,进行前缀匹配
if before, ok := strings.CutSuffix(pattern, ".*"); ok {
prefix := before
if strings.HasPrefix(host, prefix+".") || host == prefix {
return true
}
}
return false
}