diff --git a/server/task.go b/server/task.go index 6348502..510d366 100644 --- a/server/task.go +++ b/server/task.go @@ -88,7 +88,7 @@ func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) // resolveIP 解析域名到 IP 地址,排除 DNS 查询时间 func resolveIP(target string) (string, error) { // 如果已经是 IP 地址,直接返回 - if net.ParseIP(target) != nil { + if ip := net.ParseIP(target); ip != nil { return target, nil } // 解析域名到 IP @@ -100,8 +100,16 @@ func resolveIP(target string) (string, error) { } func icmpPing(target string, timeout time.Duration) (int64, error) { + host, _, err := net.SplitHostPort(target) + if err != nil { + host = target + } + // For ICMP, we only need the host/IP, port is irrelevant. + // If the host is an IPv6 literal, it might be wrapped in brackets. + host = strings.Trim(host, "[]") + // 先解析 IP 地址 - ip, err := resolveIP(target) + ip, err := resolveIP(host) if err != nil { return -1, err } @@ -125,17 +133,19 @@ func icmpPing(target string, timeout time.Duration) (int64, error) { } func tcpPing(target string, timeout time.Duration) (int64, error) { - addr := strings.Split(target, ":") - ip, err := resolveIP(addr[0]) + host, port, err := net.SplitHostPort(target) + if err != nil { + // No port, assume port 80 + host = target + port = "80" + } + + ip, err := resolveIP(host) if err != nil { return -1, err } - port := "80" - if len(addr) > 1 { - port = addr[1] - } - targetAddr := ip + ":" + port + targetAddr := net.JoinHostPort(ip, port) start := time.Now() conn, err := net.DialTimeout("tcp", targetAddr, timeout) if err != nil { @@ -146,6 +156,14 @@ func tcpPing(target string, timeout time.Duration) (int64, error) { } func httpPing(target string, timeout time.Duration) (int64, error) { + // Handle raw IPv6 address for URL + if strings.Contains(target, ":") && !strings.Contains(target, "[") { + // check if it's a valid IP to avoid wrapping hostnames + if ip := net.ParseIP(target); ip != nil && ip.To4() == nil { + target = "[" + target + "]" + } + } + if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") { target = "http://" + target } @@ -163,7 +181,7 @@ func httpPing(target string, timeout time.Duration) (int64, error) { if err != nil { return nil, err } - return net.DialTimeout(network, ip+":"+port, timeout) + return net.DialTimeout(network, net.JoinHostPort(ip, port), timeout) }, }, } diff --git a/server/task_test.go b/server/task_test.go new file mode 100644 index 0000000..4032899 --- /dev/null +++ b/server/task_test.go @@ -0,0 +1,62 @@ +package server + +import ( + "testing" + "time" +) + +var testTargets = []struct { + target string +}{ + {"v6-sh-cm.oojj.de"}, + {"2409:8c1e:8f80:2:6a::"}, + {"[2409:8c1e:8f80:2:6a::]:80"}, + {"v4-sh-cm.oojj.de"}, + {"117.185.125.154"}, + {"117.185.125.154:80"}, +} + +func TestICMPPing(t *testing.T) { + timeout := 3 * time.Second + for _, tt := range testTargets { + t.Run(tt.target, func(t *testing.T) { + latency, err := icmpPing(tt.target, timeout) + if latency < -1 { + t.Errorf("ICMP ping %s: invalid latency %d", tt.target, latency) + } + if err != nil { + t.Errorf("ICMP ping %s error: %v", tt.target, err) + } + }) + } +} + +func TestTCPPing(t *testing.T) { + timeout := 3 * time.Second + for _, tt := range testTargets { + t.Run(tt.target, func(t *testing.T) { + latency, err := tcpPing(tt.target, timeout) + if latency < -1 { + t.Errorf("TCP ping %s: invalid latency %d", tt.target, latency) + } + if err != nil { + t.Errorf("TCP ping %s error: %v", tt.target, err) + } + }) + } +} + +func TestHTTPPing(t *testing.T) { + timeout := 3 * time.Second + for _, tt := range testTargets { + t.Run(tt.target, func(t *testing.T) { + latency, err := httpPing(tt.target, timeout) + if latency < -1 { + t.Errorf("HTTP ping %s: invalid latency %d", tt.target, latency) + } + if err != nil { + t.Errorf("HTTP ping %s error: %v", tt.target, err) + } + }) + } +}