diff --git a/cmd/root.go b/cmd/root.go index d9b81d5..81e62b4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -34,6 +34,7 @@ var RootCmd = &cobra.Command{ } go server.DoUploadBasicInfoWorks() for { + server.UpdateBasicInfo() server.EstablishWebSocketConnection() } }, diff --git a/server/basicInfo.go b/server/basicInfo.go index 1f267e7..e66a176 100644 --- a/server/basicInfo.go +++ b/server/basicInfo.go @@ -15,10 +15,6 @@ import ( ) func DoUploadBasicInfoWorks() { - err := uploadBasicInfo() - if err != nil { - log.Println("Error uploading basic info:", err) - } ticker := time.NewTicker(time.Duration(15) * time.Minute) for range ticker.C { err := uploadBasicInfo() @@ -27,7 +23,14 @@ func DoUploadBasicInfoWorks() { } } } - +func UpdateBasicInfo() { + err := uploadBasicInfo() + if err != nil { + log.Println("Error uploading basic info:", err) + } else { + log.Println("Basic info uploaded successfully") + } +} func uploadBasicInfo() error { cpu := monitoring.Cpu() diff --git a/server/task.go b/server/task.go index a9aa681..eebceff 100644 --- a/server/task.go +++ b/server/task.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "encoding/json" "errors" "log" @@ -83,36 +84,88 @@ func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) } } } -func icmpPing(target string, timeout time.Duration) error { - pinger, err := getPinger(target) + +// resolveIP 解析域名到 IP 地址,排除 DNS 查询时间 +func resolveIP(target string) (string, error) { + // 如果已经是 IP 地址,直接返回 + if net.ParseIP(target) != nil { + return target, nil + } + // 解析域名到 IP + addrs, err := net.LookupHost(target) + if err != nil || len(addrs) == 0 { + return "", errors.New("failed to resolve target") + } + return addrs[0], nil // 返回第一个解析的 IP +} + +func icmpPing(target string, timeout time.Duration) (int64, error) { + // 先解析 IP 地址 + ip, err := resolveIP(target) if err != nil { - return err + return 0, err + } + + pinger, err := ping.NewPinger(ip) + if err != nil { + return 0, err } pinger.Count = 1 pinger.Timeout = timeout pinger.SetPrivileged(true) - return pinger.Run() -} - -func getPinger(target string) (*ping.Pinger, error) { - return ping.NewPinger(target) -} - -func tcpPing(target string, timeout time.Duration) error { - if !strings.Contains(target, ":") { - target += ":80" - } - conn, err := net.DialTimeout("tcp", target, timeout) + err = pinger.Run() if err != nil { - return err + return 0, err + } + stats := pinger.Statistics() + if stats.PacketsRecv == 0 { + return 0, errors.New("no packets received") + } + return stats.AvgRtt.Milliseconds(), nil +} + +func tcpPing(target string, timeout time.Duration) (int64, error) { + addr := strings.Split(target, ":") + ip, err := resolveIP(addr[0]) + if err != nil { + return 0, err + } + + port := "80" + if len(addr) > 1 { + port = addr[1] + } + targetAddr := ip + ":" + port + start := time.Now() + conn, err := net.DialTimeout("tcp", targetAddr, timeout) + if err != nil { + return 0, err } defer conn.Close() - return nil + return time.Since(start).Milliseconds(), nil } func httpPing(target string, timeout time.Duration) (int64, error) { - client := http.Client{ + if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") { + target = "http://" + target + } + + client := &http.Client{ Timeout: timeout, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // 在 Dial 之前解析 IP,排除 DNS 时间 + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ip, err := resolveIP(host) + if err != nil { + return nil, err + } + return net.DialTimeout(network, ip+":"+port, timeout) + }, + }, } start := time.Now() resp, err := client.Get(target) @@ -129,33 +182,39 @@ func httpPing(target string, timeout time.Duration) (int64, error) { func NewPingTask(conn *ws.SafeConn, taskID uint, pingType, pingTarget string) { if taskID == 0 { + log.Printf("Invalid task ID: %d", taskID) return } + pingResult := 0 - timeout := 3 * time.Second + timeout := 3 * time.Second // 默认超时时间 switch pingType { case "icmp": - start := time.Now() - if err := icmpPing(pingTarget, timeout); err == nil { - pingResult = int(time.Since(start).Milliseconds()) + if latency, err := icmpPing(pingTarget, timeout); err == nil { + pingResult = int(latency) } case "tcp": - start := time.Now() - if err := tcpPing(pingTarget, timeout); err == nil { - pingResult = int(time.Since(start).Milliseconds()) + if latency, err := tcpPing(pingTarget, timeout); err == nil { + pingResult = int(latency) } case "http": if latency, err := httpPing(pingTarget, timeout); err == nil { pingResult = int(latency) } default: + log.Printf("Unsupported ping type: %s", pingType) return } + payload := map[string]interface{}{ "type": "ping_result", "task_id": taskID, + "ping_type": pingType, "value": pingResult, "finished_at": time.Now(), } - _ = conn.WriteJSON(payload) + if err := conn.WriteJSON(payload); err != nil { + log.Printf("Failed to write JSON to WebSocket: %v", err) + } + } diff --git a/ws/safaConn.go b/ws/safaConn.go index 54e1e7c..7ec729c 100644 --- a/ws/safaConn.go +++ b/ws/safaConn.go @@ -37,18 +37,18 @@ func (sc *SafeConn) Close() error { return sc.conn.Close() } func (sc *SafeConn) ReadMessage() (int, []byte, error) { - sc.mu.Lock() - defer sc.mu.Unlock() + // sc.mu.Lock() + // defer sc.mu.Unlock() return sc.conn.ReadMessage() } func (sc *SafeConn) ReadJSON(v interface{}) error { - sc.mu.Lock() - defer sc.mu.Unlock() + // sc.mu.Lock() + // defer sc.mu.Unlock() return sc.conn.ReadJSON(v) } func (sc *SafeConn) SetReadDeadline(t time.Time) error { - sc.mu.Lock() - defer sc.mu.Unlock() + // sc.mu.Lock() + // defer sc.mu.Unlock() return sc.conn.SetReadDeadline(t) } func (sc *SafeConn) GetConn() *websocket.Conn {