Files
komari-agent/server/task.go
2025-08-23 16:54:31 +08:00

284 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"net"
"net/http"
"os/exec"
"runtime"
"strings"
"time"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/komari-monitor/komari-agent/ws"
ping "github.com/prometheus-community/pro-bing"
)
// NewTask executes a remote command for the given task and uploads the
// combined stdout/stderr output together with the exit code.
//
// An empty task_id is ignored. An empty command, or remote control being
// disabled through flags.DisableWebSsh, is reported back immediately without
// running anything. On Windows the command runs through PowerShell with UTF-8
// console output; on other systems it runs through `sh -c`.
//
// If the command could not be started or waited on at all (for example the
// shell binary is missing), the task is reported with exit code -1 and the
// error text appended to the output, rather than as a silent success.
func NewTask(task_id, command string) {
	if task_id == "" {
		return
	}
	if command == "" {
		uploadTaskResult(task_id, "No command provided", 0, time.Now())
		return
	}
	if flags.DisableWebSsh {
		uploadTaskResult(task_id, "Remote control is disabled.", -1, time.Now())
		return
	}
	log.Printf("Executing task %s with command: %s", task_id, command)

	var cmd *exec.Cmd
	if runtime.GOOS == "windows" {
		// Force UTF-8 console output so non-ASCII text is not mangled by the
		// active code page.
		cmd = exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "+command)
	} else {
		cmd = exec.Command("sh", "-c", command)
	}

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err := cmd.Run()
	finishedAt := time.Now()

	result := stdout.String()
	if stderr.Len() > 0 {
		result += "\n" + stderr.String()
	}

	exitCode := 0
	if err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			exitCode = exitErr.ExitCode()
		} else {
			// The process never ran (or could not be waited on). Report a
			// failure instead of exit code 0 with empty output.
			exitCode = -1
			if result != "" {
				result += "\n"
			}
			result += err.Error()
		}
	}
	result = strings.ReplaceAll(result, "\r\n", "\n")
	uploadTaskResult(task_id, result, exitCode, finishedAt)
}
// uploadTaskResult posts the outcome of a remote command as JSON to
// flags.Endpoint + "/api/clients/task/result", authenticated with flags.Token.
//
// Cloudflare Access headers are added when both flags.CFAccessClientID and
// flags.CFAccessClientSecret are set. The upload is tried once, then retried up
// to flags.MaxRetries times, 2 seconds apart, until the server answers 200 OK.
// Failures are logged, not returned.
func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) {
	payload := map[string]interface{}{
		"task_id":     taskID,
		"result":      result,
		"exit_code":   exitCode,
		"finished_at": finishedAt,
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Failed to encode task result: %v", err)
		return
	}
	endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token

	// newRequest builds a fresh request for every attempt. A request's body
	// reader is consumed by client.Do, so reusing one request would make every
	// retry send an empty body.
	newRequest := func() (*http.Request, error) {
		req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonData))
		if err != nil {
			return nil, err
		}
		req.Header.Set("Content-Type", "application/json")
		// Add Cloudflare Access headers if configured.
		if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
			req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)
			req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
		}
		return req, nil
	}

	// Bound each attempt so an unresponsive server cannot block forever.
	client := &http.Client{Timeout: 30 * time.Second}
	maxRetry := flags.MaxRetries
	for attempt := 0; ; attempt++ {
		req, err := newRequest()
		if err != nil {
			log.Printf("Failed to create task result request: %v", err)
			return
		}
		resp, err := client.Do(req)
		if err == nil {
			// Close every response, including unsuccessful ones, so retries
			// do not leak connections.
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				return
			}
		}
		if attempt >= maxRetry {
			if err != nil {
				log.Printf("Failed to upload task result: %v", err)
			} else {
				log.Printf("Failed to upload task result: %s", resp.Status)
			}
			return
		}
		log.Printf("Failed to upload task result, retrying %d/%d", attempt+1, maxRetry)
		time.Sleep(2 * time.Second) // Wait before retrying
	}
}
// resolveIP turns target into an IP address string. An IP literal is returned
// unchanged; a host name is looked up and the first address returned by
// net.LookupHost is used. Probes call this before starting their timer so
// that DNS lookup time is kept out of the measured latency.
func resolveIP(target string) (string, error) {
	if net.ParseIP(target) == nil {
		addrs, err := net.LookupHost(target)
		if err == nil && len(addrs) > 0 {
			return addrs[0], nil
		}
		return "", errors.New("failed to resolve target")
	}
	return target, nil
}
// icmpPing sends one ICMP echo request to target and returns the round-trip
// time in whole milliseconds, or -1 and an error.
//
// ICMP has no ports, so a "host:port" target is reduced to its host, and any
// IPv6 brackets are stripped. The host is resolved before the pinger starts so
// DNS lookup time is not counted. The pinger is put in privileged mode
// (SetPrivileged(true)); NOTE(review): per pro-bing docs this typically needs
// root / CAP_NET_RAW on Linux — confirm deployment permissions.
func icmpPing(target string, timeout time.Duration) (int64, error) {
	host := target
	if h, _, splitErr := net.SplitHostPort(target); splitErr == nil {
		host = h
	}
	ip, err := resolveIP(strings.Trim(host, "[]"))
	if err != nil {
		return -1, err
	}

	pinger, err := ping.NewPinger(ip)
	if err != nil {
		return -1, err
	}
	pinger.Count = 1
	pinger.Timeout = timeout
	pinger.SetPrivileged(true)
	if err := pinger.Run(); err != nil {
		return -1, err
	}

	if stats := pinger.Statistics(); stats.PacketsRecv > 0 {
		return stats.AvgRtt.Milliseconds(), nil
	}
	return -1, errors.New("no packets received")
}
// tcpPing measures how long a TCP connection to target takes to establish and
// returns it in whole milliseconds, or -1 and an error.
//
// target is "host:port"; when no port is given, port 80 is assumed. A
// bracketed IPv6 literal without a port (e.g. "[::1]") is accepted, matching
// icmpPing. The host is resolved before the timer starts so DNS lookup time is
// excluded. The connection is closed immediately after it is established.
func tcpPing(target string, timeout time.Duration) (int64, error) {
	host, port, err := net.SplitHostPort(target)
	if err != nil {
		// No port given: default to 80.
		host = target
		port = "80"
	}
	// SplitHostPort already strips brackets when it succeeds. On the fallback
	// path a "[v6]" target would otherwise reach resolveIP with the brackets
	// attached and fail to parse.
	host = strings.Trim(host, "[]")
	ip, err := resolveIP(host)
	if err != nil {
		return -1, err
	}
	targetAddr := net.JoinHostPort(ip, port)
	start := time.Now()
	conn, err := net.DialTimeout("tcp", targetAddr, timeout)
	if err != nil {
		return -1, err
	}
	defer conn.Close()
	return time.Since(start).Milliseconds(), nil
}
// httpPing measures the latency of a single HTTP GET to target, in whole
// milliseconds.
//
// target may be a bare host, host:port, a raw IPv6 literal, or a full
// http:// / https:// URL; targets without a scheme default to http://. The
// URL's host is resolved before the timer starts so DNS lookup time is
// excluded, matching icmpPing and tcpPing. Redirects are followed by the
// default client policy; hosts reached through a redirect are resolved on
// demand.
//
// A final status in [200, 400) counts as success. Any other status returns
// the measured latency together with an error. Transport failures return -1.
func httpPing(target string, timeout time.Duration) (int64, error) {
	// Wrap a raw IPv6 literal in brackets so it forms a valid URL host. It is
	// only wrapped when it parses as IPv6, so "host:port" strings are left alone.
	if strings.Contains(target, ":") && !strings.Contains(target, "[") {
		if ip := net.ParseIP(target); ip != nil && ip.To4() == nil {
			target = "[" + target + "]"
		}
	}
	if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
		target = "http://" + target
	}

	req, err := http.NewRequest(http.MethodGet, target, nil)
	if err != nil {
		return -1, err
	}
	// Resolve up front so the DNS lookup is not part of the timed request.
	targetHost := req.URL.Hostname()
	targetIP, err := resolveIP(targetHost)
	if err != nil {
		return -1, err
	}

	dialer := &net.Dialer{Timeout: timeout}
	client := &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			// This transport is built per call and then discarded. Without
			// keep-alives, connections are closed instead of lingering in an
			// idle pool that nothing will reuse or clean up.
			DisableKeepAlives: true,
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				host, port, err := net.SplitHostPort(addr)
				if err != nil {
					return nil, err
				}
				ip := targetIP
				if host != targetHost {
					// Reached via a redirect to another host: resolve it now.
					if ip, err = resolveIP(host); err != nil {
						return nil, err
					}
				}
				return dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
			},
		},
	}

	start := time.Now()
	resp, err := client.Do(req)
	latency := time.Since(start).Milliseconds()
	if err != nil {
		return -1, err
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
		return latency, nil
	}
	return latency, errors.New("http status not ok")
}
// NewPingTask runs one latency probe of kind pingType ("icmp", "tcp" or
// "http") against pingTarget and writes a "ping_result" message to conn.
//
// Each probe uses a 3-second timeout. If the first measurement exceeds
// 1000 ms, it is re-measured up to 3 times, and the first retry at or below the
// threshold replaces it. If every retry is still above the threshold, or any
// probe fails, the reported value is -1. A task ID of 0 is rejected without
// probing.
func NewPingTask(conn *ws.SafeConn, taskID uint, pingType, pingTarget string) {
	if taskID == 0 {
		log.Printf("Invalid task ID: %d", taskID)
		return
	}

	const (
		timeout              = 3 * time.Second // default per-probe timeout
		highLatencyThreshold = 1000            // ms
		highLatencyRetries   = 3
	)

	var probe func(string, time.Duration) (int64, error)
	switch pingType {
	case "icmp":
		probe = icmpPing
	case "tcp":
		probe = tcpPing
	case "http":
		probe = httpPing
	}
	measure := func() (int64, error) {
		if probe == nil {
			return -1, errors.New("unsupported ping type")
		}
		return probe(pingTarget, timeout)
	}

	latency, err := measure()
	if err == nil && latency > highLatencyThreshold {
		// Re-measure to filter out one-off spikes.
		for attempt := 1; attempt <= highLatencyRetries; attempt++ {
			retry, retryErr := measure()
			if retryErr != nil {
				err = retryErr
				break
			}
			if retry <= highLatencyThreshold {
				latency = retry
				break
			}
			if attempt == highLatencyRetries {
				err = errors.New("latency remains high after retries")
			}
		}
	}

	value := int(latency)
	if err != nil {
		log.Printf("Ping task %d failed: %v", taskID, err)
		value = -1 // an error is reported as -1
	}

	payload := map[string]interface{}{
		"type":        "ping_result",
		"task_id":     taskID,
		"ping_type":   pingType,
		"value":       value,
		"finished_at": time.Now(),
	}
	// https://github.com/komari-monitor/komari/commit/eb87a4fc330b7d1c407fa4ff70177615a4f50a1f
	// -1 means packet loss; the server computes loss from it, so -1 results
	// are still sent rather than dropped here.
	if err := conn.WriteJSON(payload); err != nil {
		log.Printf("Failed to write JSON to WebSocket: %v", err)
	}
}