Compare commits

...

4 Commits

21 changed files with 366 additions and 84 deletions

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@ komari-agent
build/
auto-discovery.json
net_static.json
config.json

View File

@@ -11,7 +11,7 @@ import (
"os"
"path/filepath"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/utils"
)
// AutoDiscoveryConfig 自动发现配置结构体
@@ -109,6 +109,14 @@ func registerWithAutoDiscovery() error {
if len(endpoint) > 0 && endpoint[len(endpoint)-1] == '/' {
endpoint = endpoint[:len(endpoint)-1]
}
// 转换中文域名为 ASCII 兼容编码
endpoint, err = utils.ConvertIDNToASCII(endpoint)
if err != nil {
log.Printf("Warning: Failed to convert IDN to ASCII: %v", err)
// 继续使用原始 endpoint可能在某些情况下仍能工作
}
registerURL := fmt.Sprintf("%s/api/clients/register?name=%s", endpoint, url.QueryEscape(hostname))
// 创建HTTP请求
@@ -120,7 +128,7 @@ func registerWithAutoDiscovery() error {
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", flags.AutoDiscoveryKey))
// 添加Cloudflare Access头部
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)

View File

@@ -1,25 +1,31 @@
package flags
package flags_pkg
var (
AutoDiscoveryKey string
DisableAutoUpdate bool
DisableWebSsh bool
MemoryModeAvailable bool
Token string
Endpoint string
Interval float64
IgnoreUnsafeCert bool
MaxRetries int
ReconnectInterval int
InfoReportInterval int
IncludeNics string
ExcludeNics string
IncludeMountpoints string
MonthRotate int
CFAccessClientID string
CFAccessClientSecret string
MemoryIncludeCache bool
CustomDNS string
EnableGPU bool // 启用详细GPU监控
ShowWarning bool // Windows 上显示安全警告,作为子进程运行一次
)
type Config struct {
AutoDiscoveryKey string `json:"auto_discovery_key" env:"AGENT_AUTO_DISCOVERY_KEY"` // 自动发现密钥
DisableAutoUpdate bool `json:"disable_auto_update" env:"AGENT_DISABLE_AUTO_UPDATE"` // 禁用自动更新
DisableWebSsh bool `json:"disable_web_ssh" env:"AGENT_DISABLE_WEB_SSH"` // 禁用远程控制web ssh 和 rce
MemoryModeAvailable bool `json:"memory_mode_available" env:"AGENT_MEMORY_MODE_AVAILABLE"` // [deprecated] 已弃用,请使用 MemoryIncludeCache
Token string `json:"token" env:"AGENT_TOKEN"` // Token
Endpoint string `json:"endpoint" env:"AGENT_ENDPOINT"` // 面板地址
Interval float64 `json:"interval" env:"AGENT_INTERVAL"` // 数据采集间隔,单位秒
IgnoreUnsafeCert bool `json:"ignore_unsafe_cert" env:"AGENT_IGNORE_UNSAFE_CERT"` // 忽略不安全的证书
MaxRetries int `json:"max_retries" env:"AGENT_MAX_RETRIES"` // 最大重试次数
ReconnectInterval int `json:"reconnect_interval" env:"AGENT_RECONNECT_INTERVAL"` // 重连间隔,单位秒
InfoReportInterval int `json:"info_report_interval" env:"AGENT_INFO_REPORT_INTERVAL"` // 基础信息上报间隔,单位分钟
IncludeNics string `json:"include_nics" env:"AGENT_INCLUDE_NICS"` // 仅统计网卡,逗号分隔的网卡名称列表,支持通配符
ExcludeNics string `json:"exclude_nics" env:"AGENT_EXCLUDE_NICS"` // 统计时排除的网卡,逗号分隔的网卡名称列表,支持通配符
IncludeMountpoints string `json:"include_mountpoints" env:"AGENT_INCLUDE_MOUNTPOINTS"` // 磁盘统计的包含挂载点列表,使用分号分隔
MonthRotate int `json:"month_rotate" env:"AGENT_MONTH_ROTATE"` // 流量统计的月份重置日期0表示禁用
CFAccessClientID string `json:"cf_access_client_id" env:"AGENT_CF_ACCESS_CLIENT_ID"` // Cloudflare Access Client ID
CFAccessClientSecret string `json:"cf_access_client_secret" env:"AGENT_CF_ACCESS_CLIENT_SECRET"` // Cloudflare Access Client Secret
MemoryIncludeCache bool `json:"memory_include_cache" env:"AGENT_MEMORY_INCLUDE_CACHE"` // 包括缓存/缓冲区的内存使用情况
CustomDNS string `json:"custom_dns" env:"AGENT_CUSTOM_DNS"` // 使用的自定义DNS服务器
EnableGPU bool `json:"enable_gpu" env:"AGENT_ENABLE_GPU"` // 启用详细GPU监控
ShowWarning bool `json:"show_warning" env:"AGENT_SHOW_WARNING"` // Windows 上显示安全警告,作为子进程运行一次
CustomIpv4 string `json:"custom_ipv4" env:"AGENT_CUSTOM_IPV4"` // 自定义 IPv4 地址
CustomIpv6 string `json:"custom_ipv6" env:"AGENT_CUSTOM_IPV6"` // 自定义 IPv6 地址
GetIpAddrFromNic bool `json:"get_ip_addr_from_nic" env:"AGENT_GET_IP_ADDR_FROM_NIC"` // 从网卡获取IP地址
ConfigFile string `json:"config_file" env:"AGENT_CONFIG_FILE"` // JSON配置文件路径
}
var GlobalConfig = &Config{}

View File

@@ -3,26 +3,44 @@ package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"log"
"net/http"
"os"
"os/signal"
"reflect"
"strconv"
"strings"
"syscall"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
"github.com/komari-monitor/komari-agent/monitoring/netstatic"
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
"github.com/komari-monitor/komari-agent/server"
"github.com/komari-monitor/komari-agent/update"
"github.com/spf13/cobra"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
)
var flags = pkg_flags.GlobalConfig
var RootCmd = &cobra.Command{
Use: "komari-agent",
Short: "komari agent",
Long: `komari agent`,
Run: func(cmd *cobra.Command, args []string) {
loadFromEnv() // 从环境变量加载配置,覆盖解析
if flags.ConfigFile != "" {
bytes, err := os.ReadFile(flags.ConfigFile)
if err != nil {
log.Fatalf("Failed to read config file: %v", err)
}
err = json.Unmarshal(bytes, flags)
if err != nil {
log.Fatalf("Failed to parse config file: %v", err)
}
}
// 捕获中止信号,优雅退出
stopCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
@@ -154,5 +172,49 @@ func init() {
RootCmd.PersistentFlags().StringVar(&flags.CustomDNS, "custom-dns", "", "Custom DNS server to use (e.g. 8.8.8.8, 114.114.114.114). By default, the program uses the system DNS resolver.")
RootCmd.PersistentFlags().BoolVar(&flags.EnableGPU, "gpu", false, "Enable detailed GPU monitoring (usage, memory, multi-GPU support)")
RootCmd.PersistentFlags().BoolVar(&flags.ShowWarning, "show-warning", false, "Show security warning on Windows, run once as a subprocess")
RootCmd.PersistentFlags().StringVar(&flags.CustomIpv4, "custom-ipv4", "", "Custom IPv4 address to use")
RootCmd.PersistentFlags().StringVar(&flags.CustomIpv6, "custom-ipv6", "", "Custom IPv6 address to use")
RootCmd.PersistentFlags().BoolVar(&flags.GetIpAddrFromNic, "get-ip-addr-from-nic", false, "Get IP address from network interface")
RootCmd.PersistentFlags().StringVar(&flags.ConfigFile, "config", "", "Path to the configuration file")
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
}
func loadFromEnv() {
val := reflect.ValueOf(flags).Elem()
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// Get the env tag
envTag := fieldType.Tag.Get("env")
if envTag == "" {
continue
}
// Get the environment variable value
envValue := os.Getenv(envTag)
if envValue == "" {
continue
}
// Set the field based on its type
switch field.Kind() {
case reflect.String:
field.SetString(envValue)
case reflect.Bool:
if strings.ToLower(envValue) == "true" || envValue == "1" {
field.SetBool(true)
}
case reflect.Int:
if intVal, err := strconv.Atoi(envValue); err == nil {
field.SetInt(int64(intVal))
}
case reflect.Float64:
if floatVal, err := strconv.ParseFloat(envValue, 64); err == nil {
field.SetFloat(floatVal)
}
}
}
}

View File

@@ -12,9 +12,10 @@ import (
"sync"
"time"
"github.com/komari-monitor/komari-agent/cmd/flags"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
)
var flags = pkg_flags.GlobalConfig
var (
DNSServers = []string{
"[2606:4700:4700::1111]:53", // Cloudflare IPv6

5
go.mod
View File

@@ -13,6 +13,7 @@ require (
github.com/rhysd/go-github-selfupdate v1.2.3
github.com/shirou/gopsutil/v4 v4.25.6
github.com/spf13/cobra v1.9.1
golang.org/x/net v0.38.0
golang.org/x/sys v0.33.0
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2
)
@@ -34,7 +35,7 @@ require (
github.com/ulikunitz/xz v0.5.9 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/text v0.26.0 // indirect
)

4
go.sum
View File

@@ -84,8 +84,8 @@ golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAG
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@@ -5,10 +5,12 @@ import (
"fmt"
"log"
"github.com/komari-monitor/komari-agent/cmd/flags"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
)
var flags = pkg_flags.GlobalConfig
func GenerateReport() []byte {
message := ""
data := map[string]interface{}{}
@@ -92,7 +94,7 @@ func GenerateReport() []byte {
// 成功获取详细信息
gpuData := make([]map[string]interface{}, len(gpuInfo))
totalGPUUsage := 0.0
for i, info := range gpuInfo {
gpuData[i] = map[string]interface{}{
"name": info.Name,
@@ -103,13 +105,13 @@ func GenerateReport() []byte {
}
totalGPUUsage += info.Utilization
}
avgGPUUsage := totalGPUUsage / float64(len(gpuInfo))
data["gpu"] = map[string]interface{}{
"count": len(gpuInfo),
"average_usage": avgGPUUsage,
"detailed_info": gpuData,
"count": len(gpuInfo),
"average_usage": avgGPUUsage,
"detailed_info": gpuData,
}
}
}

View File

@@ -8,9 +8,12 @@ import (
"strings"
"time"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/shirou/gopsutil/v4/cpu"
)
var flags = pkg_flags.GlobalConfig
type CpuInfo struct {
CPUName string `json:"cpu_name"`
CPUArchitecture string `json:"cpu_architecture"`

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"strings"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/shirou/gopsutil/v4/disk"
)

View File

@@ -119,16 +119,101 @@ func GetIPv6Address() (string, error) {
}
func GetIPAddress() (ipv4, ipv6 string, err error) {
ipv4, err = GetIPv4Address()
if err != nil {
log.Printf("Get IPV4 Error: %v", err)
ipv4 = ""
if flags.GetIpAddrFromNic {
allowNics, err := InterfaceList()
if err != nil {
log.Printf("Get Interface List Error: %v", err)
} else {
ipv4, ipv6 = getIPFromInterfaces(allowNics)
if ipv4 != "" || ipv6 != "" {
log.Printf("Get IP from NIC - IPv4: %s, IPv6: %s", ipv4, ipv6)
return ipv4, ipv6, nil
}
}
}
ipv6, err = GetIPv6Address()
if err != nil {
log.Printf("Get IPV6 Error: %v", err)
ipv6 = ""
if flags.CustomIpv4 != "" {
ipv4 = flags.CustomIpv4
} else {
ipv4, err = GetIPv4Address()
if err != nil {
log.Printf("Get IPV4 Error: %v", err)
ipv4 = ""
}
}
if flags.CustomIpv6 != "" {
ipv6 = flags.CustomIpv6
} else {
ipv6, err = GetIPv6Address()
if err != nil {
log.Printf("Get IPV6 Error: %v", err)
ipv6 = ""
}
}
return ipv4, ipv6, nil
}
// getIPFromInterfaces 从指定的网卡接口获取 IPv4 和 IPv6 地址
func getIPFromInterfaces(nicNames []string) (ipv4, ipv6 string) {
interfaces, err := net.Interfaces()
if err != nil {
log.Printf("Failed to get network interfaces: %v", err)
return "", ""
}
for _, iface := range interfaces {
// 检查接口是否在允许列表中
if !func(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}(nicNames, iface.Name) {
continue
}
// 跳过未启动的接口
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
// 获取 IPv4 地址
if ipv4 == "" && ip.To4() != nil {
ipv4 = ip.String()
}
// 获取 IPv6 地址(排除链路本地地址)
if ipv6 == "" && ip.To4() == nil && !ip.IsLinkLocalUnicast() {
ipv6 = ip.String()
}
// 如果已经找到 IPv4 和 IPv6,提前返回
if ipv4 != "" && ipv6 != "" {
return ipv4, ipv6
}
}
}
return ipv4, ipv6
}

View File

@@ -3,7 +3,6 @@ package monitoring
import (
"runtime"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/shirou/gopsutil/v4/mem"
)

View File

@@ -5,7 +5,6 @@ import (
"strings"
"time"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/monitoring/netstatic"
"github.com/komari-monitor/komari-agent/utils"
"github.com/shirou/gopsutil/v4/net"

View File

@@ -2,8 +2,6 @@ package monitoring
import (
"testing"
"github.com/komari-monitor/komari-agent/cmd/flags"
)
func TestConnectionsCount(t *testing.T) {

View File

@@ -1 +1,5 @@
# komari-agent
# komari-agent
支持使用环境变量 / JSON配置文件来传入 agent 参数
详见 `cmd/flags/flags.go``cmd/root.go`

View File

@@ -9,12 +9,15 @@ import (
"strings"
"time"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
"github.com/komari-monitor/komari-agent/update"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
)
var flags = pkg_flags.GlobalConfig
func DoUploadBasicInfoWorks() {
ticker := time.NewTicker(time.Duration(flags.InfoReportInterval) * time.Minute)
for range ticker.C {

View File

@@ -13,7 +13,6 @@ import (
"strings"
"time"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/ws"
ping "github.com/prometheus-community/pro-bing"
)

View File

@@ -10,10 +10,10 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
"github.com/komari-monitor/komari-agent/monitoring"
"github.com/komari-monitor/komari-agent/terminal"
"github.com/komari-monitor/komari-agent/utils"
"github.com/komari-monitor/komari-agent/ws"
)
@@ -22,6 +22,13 @@ func EstablishWebSocketConnection() {
websocketEndpoint := strings.TrimSuffix(flags.Endpoint, "/") + "/api/clients/report?token=" + flags.Token
websocketEndpoint = "ws" + strings.TrimPrefix(websocketEndpoint, "http")
// 转换中文域名为 ASCII 兼容编码
if convertedEndpoint, err := utils.ConvertIDNToASCII(websocketEndpoint); err == nil {
websocketEndpoint = convertedEndpoint
} else {
log.Printf("Warning: Failed to convert WebSocket IDN to ASCII: %v", err)
}
var conn *ws.SafeConn
defer func() {
if conn != nil {
@@ -155,6 +162,13 @@ func establishTerminalConnection(token, id, endpoint string) {
endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id
endpoint = "ws" + strings.TrimPrefix(endpoint, "http")
// 转换中文域名为 ASCII 兼容编码
if convertedEndpoint, err := utils.ConvertIDNToASCII(endpoint); err == nil {
endpoint = convertedEndpoint
} else {
log.Printf("Warning: Failed to convert Terminal WebSocket IDN to ASCII: %v", err)
}
// 使用与主 WS 相同的拨号策略
dialer := newWSDialer()

View File

@@ -3,11 +3,15 @@ package terminal
import (
"encoding/json"
"fmt"
"time"
"github.com/gorilla/websocket"
"github.com/komari-monitor/komari-agent/cmd/flags"
pkg_flags "github.com/komari-monitor/komari-agent/cmd/flags"
)
var flags = pkg_flags.GlobalConfig
// Terminal 接口定义平台特定的终端操作
type Terminal interface {
Close() error
@@ -37,41 +41,69 @@ func StartTerminal(conn *websocket.Conn) {
return
}
errChan := make(chan error, 1)
defer impl.term.Close()
errChan := make(chan error, 3) // 增加容量以容纳多个错误源
done := make(chan struct{})
defer func() {
gracefulShutdown(impl.term)
impl.term.Close()
conn.Close()
close(done)
}()
// 从 WebSocket 读取消息并写入终端
go handleWebSocketInput(conn, impl.term, errChan)
go handleWebSocketInput(conn, impl.term, errChan, done)
// 从终端读取输出并写入 WebSocket
go handleTerminalOutput(conn, impl.term, errChan)
go handleTerminalOutput(conn, impl.term, errChan, done)
// 错误处理和清理
go func() {
err := <-errChan
if err != nil && conn != nil {
conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err)))
conn.Close()
}
impl.term.Close()
}()
// 等待终端进程结束
if err := impl.term.Wait(); err != nil {
select {
case errChan <- err:
// 错误已发送
default:
// 错误通道已满或已关闭
conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Terminal exited with error: %v\r\n", err)))
// 等待终端进程结束或出现错误
select {
case err := <-errChan:
// WebSocket 连接断开或出现错误
if err != nil {
conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("\r\nConnection error: %v\r\n", err)))
}
case <-done:
// 已经被其他地方关闭
}
}
// gracefulShutdown 尝试优雅地关闭终端
func gracefulShutdown(term Terminal) {
// Ctrl+C
for i := 0; i < 3; i++ {
if _, err := term.Write([]byte{3}); err != nil {
return
}
time.Sleep(50 * time.Millisecond)
}
time.Sleep(200 * time.Millisecond)
// Ctrl+D (EOF)
term.Write([]byte{4})
time.Sleep(100 * time.Millisecond)
term.Write([]byte("exit\n"))
time.Sleep(100 * time.Millisecond)
}
// handleWebSocketInput 处理 WebSocket 输入
func handleWebSocketInput(conn *websocket.Conn, term Terminal, errChan chan<- error) {
func handleWebSocketInput(conn *websocket.Conn, term Terminal, errChan chan<- error, done <-chan struct{}) {
for {
select {
case <-done:
return
default:
}
t, p, err := conn.ReadMessage()
if err != nil {
errChan <- err
select {
case errChan <- err:
default:
}
return
}
if t == websocket.TextMessage {
@@ -103,16 +135,28 @@ func handleWebSocketInput(conn *websocket.Conn, term Terminal, errChan chan<- er
}
// handleTerminalOutput 处理终端输出
func handleTerminalOutput(conn *websocket.Conn, term Terminal, errChan chan<- error) {
func handleTerminalOutput(conn *websocket.Conn, term Terminal, errChan chan<- error, done <-chan struct{}) {
buf := make([]byte, 4096)
for {
select {
case <-done:
return
default:
}
n, err := term.Read(buf)
if err != nil {
errChan <- err
select {
case errChan <- err:
default:
}
return
}
if err := conn.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil {
errChan <- err
select {
case errChan <- err:
default:
}
return
}
}

View File

@@ -29,7 +29,7 @@ func newTerminalImpl() (*terminalImpl, error) {
parts := strings.Split(line, ":")
if len(parts) >= 7 && parts[6] != "" {
shell = parts[6]
log.Printf("Found shell from /etc/passwd: %s for user home: %s\n", shell, userHomeDir)
//log.Printf("Found shell from /etc/passwd: %s for user home: %s\n", shell, userHomeDir)
break
}
}

54
utils/idna.go Normal file
View File

@@ -0,0 +1,54 @@
package utils
import (
"net/url"
"strings"
"golang.org/x/net/idna"
)
// ConvertIDNToASCII 将包含国际化域名(IDN)的 URL 转换为 ASCII 兼容编码(ACE)格式
// 例如: "https://中文域名.com" -> "https://xn--fiq228c.com"
func ConvertIDNToASCII(urlStr string) (string, error) {
// 解析 URL
parsedURL, err := url.Parse(urlStr)
if err != nil {
return urlStr, err
}
// 转换主机名为 Punycode
asciiHost, err := idna.ToASCII(parsedURL.Hostname())
if err != nil {
return urlStr, err
}
// 如果有端口,需要保留
if parsedURL.Port() != "" {
parsedURL.Host = asciiHost + ":" + parsedURL.Port()
} else {
parsedURL.Host = asciiHost
}
return parsedURL.String(), nil
}
// ConvertHostToASCII 将主机名(可能包含端口)转换为 ASCII 兼容编码格式
// 例如: "中文域名.com:8080" -> "xn--fiq228c.com:8080"
func ConvertHostToASCII(host string) (string, error) {
// 分离主机名和端口
var hostname, port string
if idx := strings.LastIndex(host, ":"); idx != -1 {
hostname = host[:idx]
port = host[idx:]
} else {
hostname = host
}
// 转换为 ASCII
asciiHost, err := idna.ToASCII(hostname)
if err != nil {
return host, err
}
return asciiHost + port, nil
}