mirror of
https://github.com/fankes/komari-agent.git
synced 2025-10-19 02:59:23 +08:00
248 lines
6.3 KiB
Go
248 lines
6.3 KiB
Go
package server
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"os/exec"
|
||
"runtime"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||
"github.com/komari-monitor/komari-agent/ws"
|
||
ping "github.com/prometheus-community/pro-bing"
|
||
)
|
||
|
||
func NewTask(task_id, command string) {
|
||
if task_id == "" {
|
||
return
|
||
}
|
||
if command == "" {
|
||
uploadTaskResult(task_id, "No command provided", 0, time.Now())
|
||
return
|
||
}
|
||
if flags.DisableWebSsh {
|
||
uploadTaskResult(task_id, "Remote control is disabled.", -1, time.Now())
|
||
return
|
||
}
|
||
log.Printf("Executing task %s with command: %s", task_id, command)
|
||
var cmd *exec.Cmd
|
||
if runtime.GOOS == "windows" {
|
||
cmd = exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "+command)
|
||
} else {
|
||
cmd = exec.Command("sh", "-c", command)
|
||
}
|
||
var stdout, stderr bytes.Buffer
|
||
cmd.Stdout = &stdout
|
||
cmd.Stderr = &stderr
|
||
|
||
err := cmd.Run()
|
||
finishedAt := time.Now()
|
||
|
||
result := stdout.String()
|
||
if stderr.Len() > 0 {
|
||
result += "\n" + stderr.String()
|
||
}
|
||
result = strings.ReplaceAll(result, "\r\n", "\n")
|
||
exitCode := 0
|
||
if err != nil {
|
||
if exitError, ok := err.(*exec.ExitError); ok {
|
||
exitCode = exitError.ExitCode()
|
||
}
|
||
}
|
||
|
||
uploadTaskResult(task_id, result, exitCode, finishedAt)
|
||
}
|
||
|
||
func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) {
|
||
payload := map[string]interface{}{
|
||
"task_id": taskID,
|
||
"result": result,
|
||
"exit_code": exitCode,
|
||
"finished_at": finishedAt,
|
||
}
|
||
|
||
jsonData, _ := json.Marshal(payload)
|
||
endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token
|
||
|
||
resp, _ := http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData))
|
||
maxRetry := flags.MaxRetries
|
||
for i := 0; i < maxRetry && resp.StatusCode != http.StatusOK; i++ {
|
||
log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry)
|
||
time.Sleep(2 * time.Second) // Wait before retrying
|
||
resp, _ = http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData))
|
||
}
|
||
if resp != nil {
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
log.Printf("Failed to upload task result: %s", resp.Status)
|
||
}
|
||
}
|
||
}
|
||
|
||
// resolveIP returns an IP address for target. Callers use it to finish name
// resolution before they start a latency timer, so DNS lookup time is not
// counted as latency.
//
// If target is already an IP literal it is returned unchanged. Otherwise the
// first address returned by net.LookupHost is used. A lookup failure is
// returned as-is (typically a *net.DNSError naming the host) so callers can
// see why resolution failed.
func resolveIP(target string) (string, error) {
	// Already an IP address: nothing to resolve.
	if ip := net.ParseIP(target); ip != nil {
		return target, nil
	}

	addrs, err := net.LookupHost(target)
	if err != nil {
		return "", err
	}
	if len(addrs) == 0 {
		return "", errors.New("failed to resolve target")
	}
	return addrs[0], nil // use the first resolved address
}
|
||
|
||
func icmpPing(target string, timeout time.Duration) (int64, error) {
|
||
host, _, err := net.SplitHostPort(target)
|
||
if err != nil {
|
||
host = target
|
||
}
|
||
// For ICMP, we only need the host/IP, port is irrelevant.
|
||
// If the host is an IPv6 literal, it might be wrapped in brackets.
|
||
host = strings.Trim(host, "[]")
|
||
|
||
// 先解析 IP 地址
|
||
ip, err := resolveIP(host)
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
|
||
pinger, err := ping.NewPinger(ip)
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
pinger.Count = 1
|
||
pinger.Timeout = timeout
|
||
pinger.SetPrivileged(true)
|
||
err = pinger.Run()
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
stats := pinger.Statistics()
|
||
if stats.PacketsRecv == 0 {
|
||
return -1, errors.New("no packets received")
|
||
}
|
||
return stats.AvgRtt.Milliseconds(), nil
|
||
}
|
||
|
||
func tcpPing(target string, timeout time.Duration) (int64, error) {
|
||
host, port, err := net.SplitHostPort(target)
|
||
if err != nil {
|
||
// No port, assume port 80
|
||
host = target
|
||
port = "80"
|
||
}
|
||
|
||
ip, err := resolveIP(host)
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
|
||
targetAddr := net.JoinHostPort(ip, port)
|
||
start := time.Now()
|
||
conn, err := net.DialTimeout("tcp", targetAddr, timeout)
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
defer conn.Close()
|
||
return time.Since(start).Milliseconds(), nil
|
||
}
|
||
|
||
func httpPing(target string, timeout time.Duration) (int64, error) {
|
||
// Handle raw IPv6 address for URL
|
||
if strings.Contains(target, ":") && !strings.Contains(target, "[") {
|
||
// check if it's a valid IP to avoid wrapping hostnames
|
||
if ip := net.ParseIP(target); ip != nil && ip.To4() == nil {
|
||
target = "[" + target + "]"
|
||
}
|
||
}
|
||
|
||
if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
|
||
target = "http://" + target
|
||
}
|
||
|
||
client := &http.Client{
|
||
Timeout: timeout,
|
||
Transport: &http.Transport{
|
||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 在 Dial 之前解析 IP,排除 DNS 时间
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
ip, err := resolveIP(host)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return net.DialTimeout(network, net.JoinHostPort(ip, port), timeout)
|
||
},
|
||
},
|
||
}
|
||
start := time.Now()
|
||
resp, err := client.Get(target)
|
||
latency := time.Since(start).Milliseconds()
|
||
if err != nil {
|
||
return -1, err
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode >= 200 && resp.StatusCode < 400 {
|
||
return latency, nil
|
||
}
|
||
return latency, errors.New("http status not ok")
|
||
}
|
||
|
||
func NewPingTask(conn *ws.SafeConn, taskID uint, pingType, pingTarget string) {
|
||
if taskID == 0 {
|
||
log.Printf("Invalid task ID: %d", taskID)
|
||
return
|
||
}
|
||
var err error = nil
|
||
var latency int64
|
||
pingResult := -1
|
||
timeout := 3 * time.Second // 默认超时时间
|
||
switch pingType {
|
||
case "icmp":
|
||
if latency, err = icmpPing(pingTarget, timeout); err == nil {
|
||
pingResult = int(latency)
|
||
}
|
||
case "tcp":
|
||
if latency, err = tcpPing(pingTarget, timeout); err == nil {
|
||
pingResult = int(latency)
|
||
}
|
||
case "http":
|
||
if latency, err = httpPing(pingTarget, timeout); err == nil {
|
||
pingResult = int(latency)
|
||
}
|
||
default:
|
||
log.Printf("Unsupported ping type: %s", pingType)
|
||
return
|
||
}
|
||
if err != nil {
|
||
log.Printf("Ping task %d failed: %v", taskID, err)
|
||
pingResult = -1 // 如果有错误,设置结果为 -1
|
||
}
|
||
payload := map[string]interface{}{
|
||
"type": "ping_result",
|
||
"task_id": taskID,
|
||
"ping_type": pingType,
|
||
"value": pingResult,
|
||
"finished_at": time.Now(),
|
||
}
|
||
// https://github.com/komari-monitor/komari/commit/eb87a4fc330b7d1c407fa4ff70177615a4f50a1f
|
||
// -1 代表丢包,服务端计算
|
||
//if pingResult == -1 {
|
||
// return
|
||
//}
|
||
if err := conn.WriteJSON(payload); err != nil {
|
||
log.Printf("Failed to write JSON to WebSocket: %v", err)
|
||
}
|
||
|
||
}
|