package server

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os/exec"
	"runtime"
	"strings"
	"time"

	"github.com/komari-monitor/komari-agent/cmd/flags"
	"github.com/komari-monitor/komari-agent/ws"
	ping "github.com/prometheus-community/pro-bing"
)

const (
	// taskResultUploadTimeout bounds a single task-result upload attempt so a
	// stalled server cannot block the uploading goroutine forever.
	taskResultUploadTimeout = 30 * time.Second
	// taskResultRetryDelay is the pause between failed upload attempts.
	taskResultRetryDelay = 2 * time.Second
)

// NewTask runs command through the platform shell (PowerShell on Windows,
// sh -c elsewhere) and uploads its stdout, followed by any stderr, together
// with the exit code as the result of task taskID.
//
// An empty taskID is ignored. An empty command is reported as
// "No command provided" with exit code 0. When flags.DisableWebSsh is set the
// command is not run and a refusal is reported with exit code -1. If the
// command cannot be run at all (it did not produce an *exec.ExitError), the
// error text is appended to the result and the exit code is -1.
//
// The command runs synchronously with no timeout, so the call blocks until
// the command exits.
func NewTask(taskID, command string) {
	if taskID == "" {
		return
	}
	if command == "" {
		uploadTaskResult(taskID, "No command provided", 0, time.Now())
		return
	}
	if flags.DisableWebSsh {
		uploadTaskResult(taskID, "Remote control is disabled.", -1, time.Now())
		return
	}

	log.Printf("Executing task %s with command: %s", taskID, command)

	var cmd *exec.Cmd
	if runtime.GOOS == "windows" {
		// Switch PowerShell's console output encoding to UTF-8 before running
		// the user command.
		cmd = exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command",
			"[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "+command)
	} else {
		cmd = exec.Command("sh", "-c", command)
	}

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	err := cmd.Run()
	finishedAt := time.Now()

	result := stdout.String()
	if stderr.Len() > 0 {
		result += "\n" + stderr.String()
	}
	// Normalize Windows line endings.
	result = strings.ReplaceAll(result, "\r\n", "\n")

	exitCode := 0
	if err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			exitCode = exitErr.ExitCode()
		} else {
			// The command could not be started or waited on (e.g. the shell
			// binary is missing). Report a failure with the reason instead of
			// a silent success with empty output.
			exitCode = -1
			if result != "" {
				result += "\n"
			}
			result += err.Error()
		}
	}

	uploadTaskResult(taskID, result, exitCode, finishedAt)
}

// uploadTaskResult POSTs the outcome of a task as JSON to
// flags.Endpoint + "/api/clients/task/result", authenticating with
// flags.Token in the query string.
//
// A failed attempt (transport error or non-200 status) is retried up to
// flags.MaxRetries times, waiting taskResultRetryDelay between attempts.
// Failures are logged, never returned.
func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) {
	payload := map[string]interface{}{
		"task_id":     taskID,
		"result":      result,
		"exit_code":   exitCode,
		"finished_at": finishedAt,
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Failed to encode task result: %v", err)
		return
	}
	endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token

	client := &http.Client{Timeout: taskResultUploadTimeout}
	maxRetry := flags.MaxRetries
	for attempt := 0; ; attempt++ {
		// Build a fresh request for every attempt: the Transport consumes and
		// closes the request body, so re-sending the same *http.Request would
		// post an empty (or rejected) body on retries.
		req, err := newTaskResultRequest(endpoint, jsonData)
		if err != nil {
			log.Printf("Failed to create task result request: %v", err)
			return
		}

		resp, err := client.Do(req)
		if err == nil {
			// Drain and close every response, including ones we are about to
			// retry, so neither the body nor its connection is leaked.
			_, _ = io.Copy(io.Discard, resp.Body)
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return
			}
		}

		if attempt >= maxRetry {
			if err != nil {
				log.Printf("Failed to upload task result: %v", err)
			} else {
				log.Printf("Failed to upload task result: %s", resp.Status)
			}
			return
		}
		log.Printf("Failed to upload task result, retrying %d/%d", attempt+1, maxRetry)
		time.Sleep(taskResultRetryDelay)
	}
}

// newTaskResultRequest builds a JSON POST request to endpoint carrying body.
// The Cloudflare Access client headers are added only when both
// flags.CFAccessClientID and flags.CFAccessClientSecret are configured.
func newTaskResultRequest(endpoint string, body []byte) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
		req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)
		req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
	}
	return req, nil
}

// resolveIP returns target unchanged if it is already an IP literal;
// otherwise it resolves the host name and returns the first address found.
// Callers resolve up front so DNS lookup time is excluded from latency
// measurements.
func resolveIP(target string) (string, error) {
	if ip := net.ParseIP(target); ip != nil {
		return target, nil
	}
	addrs, err := net.LookupHost(target)
	if err != nil {
		return "", fmt.Errorf("failed to resolve target %q: %w", target, err)
	}
	if len(addrs) == 0 {
		return "", fmt.Errorf("failed to resolve target %q: no addresses", target)
	}
	return addrs[0], nil
}

// icmpPing sends a single ICMP echo request to target and returns the
// round-trip time in whole milliseconds. A ":port" suffix and IPv6 brackets
// are stripped, since ICMP has no ports. The host is resolved before the
// measurement starts. The pinger is put in privileged mode
// (SetPrivileged(true)).
func icmpPing(target string, timeout time.Duration) (int64, error) {
	host, _, err := net.SplitHostPort(target)
	if err != nil {
		host = target
	}
	// An IPv6 literal may still be wrapped in brackets, e.g. "[::1]".
	host = strings.Trim(host, "[]")

	ip, err := resolveIP(host)
	if err != nil {
		return -1, err
	}

	pinger, err := ping.NewPinger(ip)
	if err != nil {
		return -1, err
	}
	pinger.Count = 1
	pinger.Timeout = timeout
	pinger.SetPrivileged(true)

	if err := pinger.Run(); err != nil {
		return -1, err
	}
	stats := pinger.Statistics()
	if stats.PacketsRecv == 0 {
		return -1, errors.New("no packets received")
	}
	return stats.AvgRtt.Milliseconds(), nil
}

// tcpPing measures how long a TCP connect to target takes, in milliseconds.
// target is "host:port"; when no port is given, port 80 is used. The host is
// resolved before the timer starts, so DNS time is excluded.
func tcpPing(target string, timeout time.Duration) (int64, error) {
	host, port, err := net.SplitHostPort(target)
	if err != nil {
		// No port given: assume 80. Strip IPv6 brackets ("[::1]") the same way
		// icmpPing does, otherwise resolution of the bracketed form fails.
		host = strings.Trim(target, "[]")
		port = "80"
	}

	ip, err := resolveIP(host)
	if err != nil {
		return -1, err
	}
	targetAddr := net.JoinHostPort(ip, port)

	start := time.Now()
	conn, err := net.DialTimeout("tcp", targetAddr, timeout)
	if err != nil {
		return -1, err
	}
	defer conn.Close()
	return time.Since(start).Milliseconds(), nil
}

// httpPing issues a GET request to target and returns the time client.Get
// took to return, in milliseconds. "http://" is prepended when target has no
// scheme, and a bare IPv6 literal is bracketed first. Host names are resolved
// inside the dialer so DNS lookup time is not counted. A 2xx/3xx status is
// success; any other status returns the latency together with an error.
func httpPing(target string, timeout time.Duration) (int64, error) {
	// Wrap a raw IPv6 address in brackets so it forms a valid URL host.
	if strings.Contains(target, ":") && !strings.Contains(target, "[") {
		// Only wrap real IPv6 literals, never host names.
		if ip := net.ParseIP(target); ip != nil && ip.To4() == nil {
			target = "[" + target + "]"
		}
	}
	if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
		target = "http://" + target
	}

	dialer := &net.Dialer{Timeout: timeout}
	transport := &http.Transport{
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			// Resolve before dialing so DNS time is excluded.
			host, port, err := net.SplitHostPort(addr)
			if err != nil {
				return nil, err
			}
			ip, err := resolveIP(host)
			if err != nil {
				return nil, err
			}
			// DialContext honours the request context, so the client
			// timeout can also abort an in-progress dial.
			return dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
		},
	}
	// The transport is private to this call; release any pooled keep-alive
	// connections when we return, or they would be held indefinitely.
	defer transport.CloseIdleConnections()

	client := &http.Client{
		Timeout:   timeout,
		Transport: transport,
	}

	start := time.Now()
	resp, err := client.Get(target)
	latency := time.Since(start).Milliseconds()
	if err != nil {
		return -1, err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
		return latency, nil
	}
	return latency, errors.New("http status not ok")
}

// NewPingTask measures latency to pingTarget using pingType ("icmp", "tcp"
// or "http", each with a 3s timeout) and writes a "ping_result" message over
// conn. The reported value is the latency in milliseconds, or -1 when the
// measurement failed.
//
// If the first sample exceeds 1000 ms it is re-measured up to 3 times. The
// first retry at or below the threshold is reported. If a retry errors, or
// every retry stays above the threshold, the result is -1.
func NewPingTask(conn *ws.SafeConn, taskID uint, pingType, pingTarget string) {
	if taskID == 0 {
		log.Printf("Invalid task ID: %d", taskID)
		return
	}

	const (
		timeout                = 3 * time.Second // per-measurement timeout
		highLatencyThreshold   = 1000            // ms
		pingHighLatencyRetries = 3
	)

	measure := func() (int64, error) {
		switch pingType {
		case "icmp":
			return icmpPing(pingTarget, timeout)
		case "tcp":
			return tcpPing(pingTarget, timeout)
		case "http":
			return httpPing(pingTarget, timeout)
		default:
			return -1, errors.New("unsupported ping type")
		}
	}

	latency, err := measure()
	if err == nil && latency > highLatencyThreshold {
		// A single slow sample may be a transient spike: re-measure and accept
		// the first sample at or below the threshold.
		for i := 0; i < pingHighLatencyRetries; i++ {
			retry, retryErr := measure()
			if retryErr != nil {
				err = retryErr
				break
			}
			if retry <= highLatencyThreshold {
				latency = retry
				break
			}
			if i == pingHighLatencyRetries-1 {
				// Still high on the final attempt.
				err = errors.New("latency remains high after retries")
			}
		}
	}

	pingResult := -1 // -1 signals a failed measurement
	if err != nil {
		log.Printf("Ping task %d failed: %v", taskID, err)
	} else {
		pingResult = int(latency)
	}

	payload := map[string]interface{}{
		"type":        "ping_result",
		"task_id":     taskID,
		"ping_type":   pingType,
		"value":       pingResult,
		"finished_at": time.Now(),
	}

	// https://github.com/komari-monitor/komari/commit/eb87a4fc330b7d1c407fa4ff70177615a4f50a1f
	// -1 represents packet loss and is still sent; the server computes loss from it.
	//if pingResult == -1 {
	//	return
	//}

	if err := conn.WriteJSON(payload); err != nil {
		log.Printf("Failed to write JSON to WebSocket: %v", err)
	}
}