mirror of
https://github.com/fankes/komari-agent.git
synced 2025-10-18 18:49:23 +08:00
feat: 添加自定义 DNS 解析器
This commit is contained in:
@@ -19,4 +19,5 @@ var (
|
||||
CFAccessClientID string
|
||||
CFAccessClientSecret string
|
||||
MemoryIncludeCache bool
|
||||
CustomDNS string
|
||||
)
|
||||
|
11
cmd/root.go
11
cmd/root.go
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
"github.com/komari-monitor/komari-agent/dnsresolver"
|
||||
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
|
||||
"github.com/komari-monitor/komari-agent/server"
|
||||
"github.com/komari-monitor/komari-agent/update"
|
||||
@@ -20,6 +21,15 @@ var RootCmd = &cobra.Command{
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
log.Println("Komari Agent", update.CurrentVersion)
|
||||
log.Println("Github Repo:", update.Repo)
|
||||
|
||||
// 设置自定义DNS解析器
|
||||
if flags.CustomDNS != "" {
|
||||
dnsresolver.SetCustomDNSServer(flags.CustomDNS)
|
||||
log.Printf("Using custom DNS server: %s", flags.CustomDNS)
|
||||
} else {
|
||||
log.Printf("Using default DNS servers, primary: %s", dnsresolver.DNSServers[0])
|
||||
}
|
||||
|
||||
// Auto discovery
|
||||
if flags.AutoDiscoveryKey != "" {
|
||||
err := handleAutoDiscovery()
|
||||
@@ -100,5 +110,6 @@ func init() {
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID")
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret")
|
||||
RootCmd.PersistentFlags().BoolVar(&flags.MemoryIncludeCache, "memory-include-cache", false, "Include cache/buffer in memory usage")
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CustomDNS, "custom-dns", "", "Custom DNS server to use (e.g. 8.8.8.8)")
|
||||
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
|
||||
}
|
||||
|
132
dnsresolver/resolver.go
Normal file
132
dnsresolver/resolver.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package dnsresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// DNS服务器列表,按优先级排序
|
||||
DNSServers = []string{
|
||||
"114.114.114.114:53", // 114DNS,中国大陆
|
||||
"8.8.8.8:53", // Google DNS,全球
|
||||
"8.8.4.4:53", // Google DNS备用,全球
|
||||
"1.1.1.1:53", // Cloudflare DNS,全球
|
||||
"223.5.5.5:53", // 阿里DNS,中国大陆
|
||||
"119.29.29.29:53", // DNSPod,中国大陆
|
||||
}
|
||||
|
||||
// CustomDNSServer 自定义DNS服务器,可以通过命令行参数设置
|
||||
CustomDNSServer string
|
||||
)
|
||||
|
||||
// SetCustomDNSServer 设置自定义DNS服务器
|
||||
func SetCustomDNSServer(dnsServer string) {
|
||||
if dnsServer != "" {
|
||||
// 检查是否已包含端口,如果没有则添加默认端口53
|
||||
if !strings.Contains(dnsServer, ":") {
|
||||
dnsServer = dnsServer + ":53"
|
||||
}
|
||||
CustomDNSServer = dnsServer
|
||||
}
|
||||
}
|
||||
|
||||
// getCurrentDNSServer 获取当前要使用的DNS服务器
|
||||
func getCurrentDNSServer() string {
|
||||
if CustomDNSServer != "" {
|
||||
return CustomDNSServer
|
||||
}
|
||||
// 如果没有设置自定义DNS,返回默认的第一个
|
||||
return DNSServers[0]
|
||||
}
|
||||
|
||||
// GetCustomResolver 返回一个使用指定DNS服务器的解析器
|
||||
func GetCustomResolver() *net.Resolver {
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// 尝试自定义DNS或默认DNS
|
||||
dnsServer := getCurrentDNSServer()
|
||||
conn, err := d.DialContext(ctx, "udp", dnsServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// 如果连接失败,尝试其他DNS服务器
|
||||
for _, server := range DNSServers {
|
||||
if server != dnsServer { // 避免重复尝试
|
||||
conn, err := d.DialContext(ctx, "udp", server)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 所有DNS服务器都失败,返回最后一次的错误
|
||||
return nil, err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetHTTPClient 返回一个使用自定义DNS解析器的HTTP客户端
|
||||
func GetHTTPClient(timeout time.Duration) *http.Client {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
customResolver := GetCustomResolver()
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips, err := customResolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to dial to any of the resolved IPs")
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetDialer 返回一个使用自定义DNS解析器的网络拨号器
|
||||
func GetNetDialer(timeout time.Duration) *net.Dialer {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
return &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: GetCustomResolver(),
|
||||
}
|
||||
}
|
@@ -8,17 +8,17 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/dnsresolver"
|
||||
)
|
||||
|
||||
var (
|
||||
// 创建适用于IPv4和IPv6的HTTP客户端
|
||||
ipv4HTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return d.DialContext(ctx, "tcp4", addr) // 锁v4防止出现问题
|
||||
dialer := dnsresolver.GetNetDialer(15 * time.Second)
|
||||
return dialer.DialContext(ctx, "tcp4", addr) // 锁v4防止出现问题
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
@@ -30,11 +30,8 @@ var (
|
||||
ipv6HTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return d.DialContext(ctx, "tcp6", addr) // 锁v6防止出现问题
|
||||
dialer := dnsresolver.GetNetDialer(15 * time.Second)
|
||||
return dialer.DialContext(ctx, "tcp6", addr) // 锁v6防止出现问题
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
"github.com/komari-monitor/komari-agent/dnsresolver"
|
||||
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
|
||||
"github.com/komari-monitor/komari-agent/update"
|
||||
)
|
||||
@@ -86,7 +87,9 @@ func tryUploadData(data map[string]interface{}) error {
|
||||
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
// 使用dnsresolver获取自定义HTTP客户端
|
||||
client := dnsresolver.GetHTTPClient(30 * time.Second)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
"github.com/komari-monitor/komari-agent/dnsresolver"
|
||||
"github.com/komari-monitor/komari-agent/monitoring"
|
||||
"github.com/komari-monitor/komari-agent/terminal"
|
||||
"github.com/komari-monitor/komari-agent/ws"
|
||||
@@ -90,8 +91,12 @@ func EstablishWebSocketConnection() {
|
||||
}
|
||||
|
||||
func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) {
|
||||
// 使用dnsresolver获取自定义网络拨号器
|
||||
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
NetDialContext: netDialer.DialContext,
|
||||
}
|
||||
|
||||
// 创建请求头并添加Cloudflare Access头部
|
||||
@@ -159,8 +164,13 @@ func handleWebSocketMessages(conn *ws.SafeConn, done chan<- struct{}) {
|
||||
func establishTerminalConnection(token, id, endpoint string) {
|
||||
endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id
|
||||
endpoint = "ws" + strings.TrimPrefix(endpoint, "http")
|
||||
|
||||
// 使用dnsresolver获取自定义网络拨号器
|
||||
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
NetDialContext: netDialer.DialContext,
|
||||
}
|
||||
|
||||
// 创建请求头并添加Cloudflare Access头部
|
||||
|
@@ -3,11 +3,13 @@ package update
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/komari-monitor/komari-agent/dnsresolver"
|
||||
"github.com/rhysd/go-github-selfupdate/selfupdate"
|
||||
)
|
||||
|
||||
@@ -45,7 +47,9 @@ func CheckAndUpdate() error {
|
||||
return fmt.Errorf("failed to parse current version: %v", err)
|
||||
}
|
||||
|
||||
// Create selfupdate configuration
|
||||
// 使用dnsresolver创建自定义HTTP客户端并设置为全局默认客户端
|
||||
// 这会影响所有HTTP请求,包括selfupdate库中的请求
|
||||
http.DefaultClient = dnsresolver.GetHTTPClient(60 * time.Second) // Create selfupdate configuration
|
||||
config := selfupdate.Config{}
|
||||
updater, err := selfupdate.NewUpdater(config)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user