Files
server/app/utils/ip_matcher.go
T
2026-01-16 15:49:34 +08:00

164 lines
4.4 KiB
Go

package utils
import (
"fmt"
"net"
"strings"
"github.com/goravel/framework/support/str"
apperrors "goravel/app/errors"
)
// IsIPInBlacklist 检查IP是否在黑名单中
// ip: 要检查的IP地址
// blacklistIPs: 黑名单IP字符串,支持:
// - 单个IP: 192.168.1.1
// - 多个IP(逗号分隔): 192.168.1.1,192.168.1.2
// - CIDR格式: 192.168.0.0/24
// - IP范围: 192.168.1.1-192.168.1.100
func IsIPInBlacklist(ip string, blacklistIPs string) bool {
if blacklistIPs == "" {
return false
}
// 解析IP地址
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// 分割多个IP(支持逗号分隔)
ipList := strings.SplitSeq(blacklistIPs, ",")
for blacklistIP := range ipList {
blacklistIP = str.Of(blacklistIP).Trim().String()
if str.Of(blacklistIP).IsEmpty() {
continue
}
// 检查单个IP
if blacklistIP == ip {
return true
}
// 检查CIDR格式 (192.168.0.0/24)
if str.Of(blacklistIP).Contains("/") {
_, ipNet, err := net.ParseCIDR(blacklistIP)
if err == nil && ipNet.Contains(parsedIP) {
return true
}
}
// 检查IP范围格式 (192.168.1.1-192.168.1.100)
if str.Of(blacklistIP).Contains("-") {
parts := str.Of(blacklistIP).Split("-")
if len(parts) == 2 {
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
if startIP != nil && endIP != nil {
if isIPInRange(parsedIP, startIP, endIP) {
return true
}
}
}
}
}
return false
}
// isIPInRange 检查IP是否在指定范围内
func isIPInRange(ip, startIP, endIP net.IP) bool {
ipBytes := ip.To4()
startBytes := startIP.To4()
endBytes := endIP.To4()
if ipBytes == nil || startBytes == nil || endBytes == nil {
return false
}
// 将IP地址转换为32位整数进行比较
ipInt := uint32(ipBytes[0])<<24 | uint32(ipBytes[1])<<16 | uint32(ipBytes[2])<<8 | uint32(ipBytes[3])
startInt := uint32(startBytes[0])<<24 | uint32(startBytes[1])<<16 | uint32(startBytes[2])<<8 | uint32(startBytes[3])
endInt := uint32(endBytes[0])<<24 | uint32(endBytes[1])<<16 | uint32(endBytes[2])<<8 | uint32(endBytes[3])
return ipInt >= startInt && ipInt <= endInt
}
// ValidateBlacklistIP 验证黑名单IP格式
// 返回业务错误类型,如果格式正确返回 nil
func ValidateBlacklistIP(ipStr string) error {
if ipStr == "" {
return apperrors.ErrIPAddressRequired
}
ipList := strings.SplitSeq(ipStr, ",")
for ip := range ipList {
ip = str.Of(ip).Trim().String()
if str.Of(ip).IsEmpty() {
continue
}
// 检查CIDR格式
if str.Of(ip).Contains("/") {
_, _, err := net.ParseCIDR(ip)
if err != nil {
return apperrors.ErrInvalidCIDRFormat.WithMessage(fmt.Sprintf("CIDR格式错误: %s", ip))
}
continue
}
// 检查IP范围格式
if str.Of(ip).Contains("-") {
parts := str.Of(ip).Split("-")
if len(parts) != 2 {
return apperrors.ErrInvalidIPRangeFormat.WithMessage(fmt.Sprintf("IP范围格式错误: %s (格式应为: 192.168.1.1-192.168.1.100)", ip))
}
startIP := net.ParseIP(str.Of(parts[0]).Trim().String())
endIP := net.ParseIP(str.Of(parts[1]).Trim().String())
if startIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("起始IP格式错误: %s", parts[0]))
}
if endIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("结束IP格式错误: %s", parts[1]))
}
// 验证范围是否有效(结束IP应该大于等于起始IP)
if !isIPInRange(endIP, startIP, endIP) && !endIP.Equal(startIP) {
// 简单比较:检查结束IP是否大于起始IP
startBytes := startIP.To4()
endBytes := endIP.To4()
if startBytes != nil && endBytes != nil {
for i := range 4 {
if endBytes[i] < startBytes[i] {
return apperrors.ErrInvalidIPRangeOrder.WithMessage(fmt.Sprintf("结束IP必须大于等于起始IP: %s", ip))
}
if endBytes[i] > startBytes[i] {
break
}
}
}
}
continue
}
// 检查单个IP格式
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return apperrors.ErrInvalidIPFormat.WithMessage(fmt.Sprintf("IP地址格式错误: %s", ip))
}
}
return nil
}
// FormatIPRange 格式化IP范围显示
func FormatIPRange(ipStr string) string {
if str.Of(ipStr).Contains("-") {
parts := str.Of(ipStr).Split("-")
if len(parts) == 2 {
return fmt.Sprintf("%s ~ %s", str.Of(parts[0]).Trim().String(), str.Of(parts[1]).Trim().String())
}
}
return ipStr
}