212 lines
5.6 KiB
Go
212 lines
5.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/goravel/framework/contracts/http"
|
|
"github.com/goravel/framework/facades"
|
|
)
|
|
|
|
// Domain 域名验证中间件
|
|
func Domain(configValueOrDomains ...any) http.Middleware {
|
|
return func(ctx http.Context) {
|
|
var domains []string
|
|
|
|
// 如果没有参数,不验证(允许所有域名)
|
|
if len(configValueOrDomains) == 0 {
|
|
ctx.Request().Next()
|
|
return
|
|
}
|
|
|
|
// 解析配置值的辅助函数
|
|
parseConfigValue := func(value any) []string {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
|
|
switch v := value.(type) {
|
|
case []string:
|
|
return v
|
|
case string:
|
|
if v == "" {
|
|
return nil
|
|
}
|
|
// 分割逗号分隔的域名
|
|
domainsList := strings.Split(v, ",")
|
|
result := make([]string, 0, len(domainsList))
|
|
for _, d := range domainsList {
|
|
d = strings.TrimSpace(d)
|
|
if d != "" {
|
|
result = append(result, d)
|
|
}
|
|
}
|
|
return result
|
|
case []any:
|
|
result := make([]string, 0, len(v))
|
|
for _, item := range v {
|
|
if str, ok := item.(string); ok && str != "" {
|
|
result = append(result, str)
|
|
}
|
|
}
|
|
return result
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// 处理参数
|
|
for _, param := range configValueOrDomains {
|
|
if param == nil {
|
|
continue
|
|
}
|
|
|
|
switch v := param.(type) {
|
|
case []string:
|
|
// 如果是字符串数组,直接使用
|
|
if len(v) > 0 {
|
|
domains = append(domains, v...)
|
|
}
|
|
case string:
|
|
// 如果是字符串,可能是单个域名或配置键
|
|
if v != "" {
|
|
// 先尝试作为配置键读取
|
|
configValue := facades.Config().Get(v, nil)
|
|
if configValue != nil {
|
|
// 如果配置存在,解析配置值
|
|
parsedDomains := parseConfigValue(configValue)
|
|
if len(parsedDomains) > 0 {
|
|
domains = append(domains, parsedDomains...)
|
|
continue
|
|
}
|
|
}
|
|
// 否则当作域名
|
|
domains = append(domains, v)
|
|
}
|
|
case []any:
|
|
// 如果是 any 数组,递归处理
|
|
for _, item := range v {
|
|
if str, ok := item.(string); ok && str != "" {
|
|
domains = append(domains, str)
|
|
}
|
|
}
|
|
default:
|
|
// 其他类型,尝试解析为配置值
|
|
parsedDomains := parseConfigValue(v)
|
|
if len(parsedDomains) > 0 {
|
|
domains = append(domains, parsedDomains...)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 如果没有指定允许的域名,允许所有域名访问(直接放行)
|
|
if len(domains) == 0 {
|
|
ctx.Request().Next()
|
|
return
|
|
}
|
|
|
|
// 获取请求的 Host
|
|
// 优先从 X-Forwarded-Host 获取(适用于反向代理场景)
|
|
host := ctx.Request().Header("X-Forwarded-Host", "")
|
|
if host == "" {
|
|
// 使用框架提供的 Host() 方法获取(推荐方式)
|
|
host = ctx.Request().Host()
|
|
}
|
|
|
|
// 如果 X-Forwarded-Host 包含多个值(逗号分隔),取第一个
|
|
if host != "" && strings.Contains(host, ",") {
|
|
host = strings.TrimSpace(strings.Split(host, ",")[0])
|
|
}
|
|
|
|
// 调试日志:记录获取到的 Host 值
|
|
if facades.Config().GetBool("app.debug", false) {
|
|
facades.Log().Debugf("Domain middleware: Host detection - X-Forwarded-Host: %s, Host(): %s, Final host: %s",
|
|
ctx.Request().Header("X-Forwarded-Host", ""),
|
|
ctx.Request().Host(),
|
|
host)
|
|
}
|
|
|
|
// 规范化 Host(移除端口号,转换为小写)
|
|
normalizedHost := normalizeHost(host)
|
|
|
|
// 调试日志:记录规范化后的 Host 和配置的域名
|
|
if facades.Config().GetBool("app.debug", false) {
|
|
facades.Log().Debugf("Domain middleware: Normalized host: %s, Configured domains: %v", normalizedHost, domains)
|
|
}
|
|
|
|
// 检查是否在允许的域名列表中
|
|
allowed := false
|
|
var matchedDomain string
|
|
for _, allowedDomain := range domains {
|
|
normalizedAllowed := normalizeHost(allowedDomain)
|
|
// 支持精确匹配和通配符匹配
|
|
if normalizedHost == normalizedAllowed || matchDomain(normalizedHost, normalizedAllowed) {
|
|
allowed = true
|
|
matchedDomain = allowedDomain
|
|
break
|
|
}
|
|
}
|
|
|
|
if !allowed {
|
|
// 域名不在允许列表中,拒绝访问
|
|
// facades.Log().Warningf("Domain middleware: Access denied. Request host: %s (normalized: %s), Allowed domains: %v", host, normalizedHost, domains)
|
|
_ = ctx.Response().Json(http.StatusForbidden, http.Json{
|
|
"code": http.StatusForbidden,
|
|
"message": "Access denied: domain not allowed",
|
|
}).Abort()
|
|
return
|
|
}
|
|
|
|
// 记录匹配的域名(仅在调试模式下)
|
|
if facades.Config().GetBool("app.debug", false) {
|
|
facades.Log().Debugf("Domain middleware: Access allowed. Request host: %s, Matched domain: %s", normalizedHost, matchedDomain)
|
|
}
|
|
|
|
// 域名验证通过,继续处理请求
|
|
ctx.Request().Next()
|
|
}
|
|
}
|
|
|
|
// normalizeHost 规范化域名(移除端口号,转换为小写,去除前后空格)
|
|
func normalizeHost(host string) string {
|
|
host = strings.ToLower(strings.TrimSpace(host))
|
|
if host == "" {
|
|
return ""
|
|
}
|
|
|
|
// 移除端口号
|
|
if hostname, _, err := net.SplitHostPort(host); err == nil {
|
|
host = hostname
|
|
}
|
|
|
|
return host
|
|
}
|
|
|
|
// matchDomain 域名匹配,支持通配符
|
|
// 例如:*.example.com 可以匹配 a.example.com, b.example.com 等
|
|
func matchDomain(host, pattern string) bool {
|
|
if pattern == "" {
|
|
return false
|
|
}
|
|
|
|
// 如果模式以 * 开头,进行通配符匹配
|
|
if after, ok := strings.CutPrefix(pattern, "*."); ok {
|
|
// 移除 *.
|
|
suffix := after
|
|
// 检查 host 是否以 .suffix 结尾
|
|
if strings.HasSuffix(host, "."+suffix) || host == suffix {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// 如果模式以 * 结尾,进行前缀匹配
|
|
if before, ok := strings.CutSuffix(pattern, ".*"); ok {
|
|
prefix := before
|
|
if strings.HasPrefix(host, prefix+".") || host == prefix {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|