feat: 添加自定义 DNS 解析器

This commit is contained in:
2025-09-13 02:05:00 +08:00
parent 7af8e540db
commit 9834a55e5d
7 changed files with 175 additions and 17 deletions

View File

@@ -19,4 +19,5 @@ var (
CFAccessClientID string CFAccessClientID string
CFAccessClientSecret string CFAccessClientSecret string
MemoryIncludeCache bool MemoryIncludeCache bool
CustomDNS string
) )

View File

@@ -7,6 +7,7 @@ import (
"os" "os"
"github.com/komari-monitor/komari-agent/cmd/flags" "github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit" monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
"github.com/komari-monitor/komari-agent/server" "github.com/komari-monitor/komari-agent/server"
"github.com/komari-monitor/komari-agent/update" "github.com/komari-monitor/komari-agent/update"
@@ -20,6 +21,15 @@ var RootCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
log.Println("Komari Agent", update.CurrentVersion) log.Println("Komari Agent", update.CurrentVersion)
log.Println("Github Repo:", update.Repo) log.Println("Github Repo:", update.Repo)
// 设置自定义DNS解析器
if flags.CustomDNS != "" {
dnsresolver.SetCustomDNSServer(flags.CustomDNS)
log.Printf("Using custom DNS server: %s", flags.CustomDNS)
} else {
log.Printf("Using default DNS servers, primary: %s", dnsresolver.DNSServers[0])
}
// Auto discovery // Auto discovery
if flags.AutoDiscoveryKey != "" { if flags.AutoDiscoveryKey != "" {
err := handleAutoDiscovery() err := handleAutoDiscovery()
@@ -100,5 +110,6 @@ func init() {
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID") RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID")
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret") RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret")
RootCmd.PersistentFlags().BoolVar(&flags.MemoryIncludeCache, "memory-include-cache", false, "Include cache/buffer in memory usage") RootCmd.PersistentFlags().BoolVar(&flags.MemoryIncludeCache, "memory-include-cache", false, "Include cache/buffer in memory usage")
RootCmd.PersistentFlags().StringVar(&flags.CustomDNS, "custom-dns", "", "Custom DNS server to use (e.g. 8.8.8.8)")
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
} }

132
dnsresolver/resolver.go Normal file
View File

@@ -0,0 +1,132 @@
package dnsresolver
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
)
var (
// DNS服务器列表按优先级排序
DNSServers = []string{
"114.114.114.114:53", // 114DNS中国大陆
"8.8.8.8:53", // Google DNS全球
"8.8.4.4:53", // Google DNS备用全球
"1.1.1.1:53", // Cloudflare DNS全球
"223.5.5.5:53", // 阿里DNS中国大陆
"119.29.29.29:53", // DNSPod中国大陆
}
// CustomDNSServer 自定义DNS服务器可以通过命令行参数设置
CustomDNSServer string
)
// SetCustomDNSServer 设置自定义DNS服务器
func SetCustomDNSServer(dnsServer string) {
if dnsServer != "" {
// 检查是否已包含端口如果没有则添加默认端口53
if !strings.Contains(dnsServer, ":") {
dnsServer = dnsServer + ":53"
}
CustomDNSServer = dnsServer
}
}
// getCurrentDNSServer 获取当前要使用的DNS服务器
func getCurrentDNSServer() string {
if CustomDNSServer != "" {
return CustomDNSServer
}
// 如果没有设置自定义DNS返回默认的第一个
return DNSServers[0]
}
// GetCustomResolver 返回一个使用指定DNS服务器的解析器
func GetCustomResolver() *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: 10 * time.Second,
}
// 尝试自定义DNS或默认DNS
dnsServer := getCurrentDNSServer()
conn, err := d.DialContext(ctx, "udp", dnsServer)
if err == nil {
return conn, nil
}
// 如果连接失败尝试其他DNS服务器
for _, server := range DNSServers {
if server != dnsServer { // 避免重复尝试
conn, err := d.DialContext(ctx, "udp", server)
if err == nil {
return conn, nil
}
}
}
// 所有DNS服务器都失败返回最后一次的错误
return nil, err
},
}
}
// GetHTTPClient 返回一个使用自定义DNS解析器的HTTP客户端
func GetHTTPClient(timeout time.Duration) *http.Client {
if timeout <= 0 {
timeout = 30 * time.Second
}
customResolver := GetCustomResolver()
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ips, err := customResolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
if err == nil {
return conn, nil
}
}
return nil, fmt.Errorf("failed to dial to any of the resolved IPs")
},
MaxIdleConns: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
Timeout: timeout,
}
}
// GetNetDialer 返回一个使用自定义DNS解析器的网络拨号器
func GetNetDialer(timeout time.Duration) *net.Dialer {
if timeout <= 0 {
timeout = 5 * time.Second
}
return &net.Dialer{
Timeout: timeout,
KeepAlive: 30 * time.Second,
Resolver: GetCustomResolver(),
}
}

View File

@@ -8,17 +8,17 @@ import (
"net/http" "net/http"
"regexp" "regexp"
"time" "time"
"github.com/komari-monitor/komari-agent/dnsresolver"
) )
var ( var (
// 创建适用于IPv4和IPv6的HTTP客户端
ipv4HTTPClient = &http.Client{ ipv4HTTPClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := net.Dialer{ dialer := dnsresolver.GetNetDialer(15 * time.Second)
Timeout: 15 * time.Second, return dialer.DialContext(ctx, "tcp4", addr) // 锁v4防止出现问题
KeepAlive: 30 * time.Second,
}
return d.DialContext(ctx, "tcp4", addr) // 锁v4防止出现问题
}, },
MaxIdleConns: 10, MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second, IdleConnTimeout: 30 * time.Second,
@@ -30,11 +30,8 @@ var (
ipv6HTTPClient = &http.Client{ ipv6HTTPClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
d := net.Dialer{ dialer := dnsresolver.GetNetDialer(15 * time.Second)
Timeout: 15 * time.Second, return dialer.DialContext(ctx, "tcp6", addr) // 锁v6防止出现问题
KeepAlive: 30 * time.Second,
}
return d.DialContext(ctx, "tcp6", addr) // 锁v6防止出现问题
}, },
MaxIdleConns: 10, MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second, IdleConnTimeout: 30 * time.Second,

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/komari-monitor/komari-agent/cmd/flags" "github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit" monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
"github.com/komari-monitor/komari-agent/update" "github.com/komari-monitor/komari-agent/update"
) )
@@ -86,7 +87,9 @@ func tryUploadData(data map[string]interface{}) error {
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
} }
client := &http.Client{} // 使用dnsresolver获取自定义HTTP客户端
client := dnsresolver.GetHTTPClient(30 * time.Second)
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return err return err

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/komari-monitor/komari-agent/cmd/flags" "github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/dnsresolver"
"github.com/komari-monitor/komari-agent/monitoring" "github.com/komari-monitor/komari-agent/monitoring"
"github.com/komari-monitor/komari-agent/terminal" "github.com/komari-monitor/komari-agent/terminal"
"github.com/komari-monitor/komari-agent/ws" "github.com/komari-monitor/komari-agent/ws"
@@ -90,8 +91,12 @@ func EstablishWebSocketConnection() {
} }
func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) { func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) {
// 使用dnsresolver获取自定义网络拨号器
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
NetDialContext: netDialer.DialContext,
} }
// 创建请求头并添加Cloudflare Access头部 // 创建请求头并添加Cloudflare Access头部
@@ -159,8 +164,13 @@ func handleWebSocketMessages(conn *ws.SafeConn, done chan<- struct{}) {
func establishTerminalConnection(token, id, endpoint string) { func establishTerminalConnection(token, id, endpoint string) {
endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id
endpoint = "ws" + strings.TrimPrefix(endpoint, "http") endpoint = "ws" + strings.TrimPrefix(endpoint, "http")
// 使用dnsresolver获取自定义网络拨号器
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
NetDialContext: netDialer.DialContext,
} }
// 创建请求头并添加Cloudflare Access头部 // 创建请求头并添加Cloudflare Access头部

View File

@@ -3,11 +3,13 @@ package update
import ( import (
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/blang/semver" "github.com/blang/semver"
"github.com/komari-monitor/komari-agent/dnsresolver"
"github.com/rhysd/go-github-selfupdate/selfupdate" "github.com/rhysd/go-github-selfupdate/selfupdate"
) )
@@ -45,7 +47,9 @@ func CheckAndUpdate() error {
return fmt.Errorf("failed to parse current version: %v", err) return fmt.Errorf("failed to parse current version: %v", err)
} }
// Create selfupdate configuration // 使用dnsresolver创建自定义HTTP客户端并设置为全局默认客户端
// 这会影响所有HTTP请求包括selfupdate库中的请求
http.DefaultClient = dnsresolver.GetHTTPClient(60 * time.Second) // Create selfupdate configuration
config := selfupdate.Config{} config := selfupdate.Config{}
updater, err := selfupdate.NewUpdater(config) updater, err := selfupdate.NewUpdater(config)
if err != nil { if err != nil {