From b247a6ebee41db851ff56cffe3f7bbf23b13c65c Mon Sep 17 00:00:00 2001 From: Akizon77 Date: Wed, 4 Jun 2025 17:20:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E5=8A=9F=E8=83=BD=20fix:=20=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E8=A7=A3=E6=9E=90=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/task.go | 81 +++++++++++++++++++++++++++++++++++++++++++ server/websocket.go | 14 ++++++-- update/update.go | 16 ++++++++- update/update_test.go | 72 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 179 insertions(+), 4 deletions(-) create mode 100644 server/task.go create mode 100644 update/update_test.go diff --git a/server/task.go b/server/task.go new file mode 100644 index 0000000..e7e8bb6 --- /dev/null +++ b/server/task.go @@ -0,0 +1,81 @@ +package server + +import ( + "bytes" + "encoding/json" + "log" + "net/http" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/komari-monitor/komari-agent/cmd/flags" +) + +func NewTask(task_id, command string) { + if task_id == "" { + return + } + if command == "" { + uploadTaskResult(task_id, "No command provided", 0, time.Now()) + return + } + if flags.DisableWebSsh { + uploadTaskResult(task_id, "Web SSH (REC) is disabled.", -1, time.Now()) + return + } + log.Printf("Executing task %s with command: %s", task_id, command) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "+command) + } else { + cmd = exec.Command("sh", "-c", command) + } + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + finishedAt := time.Now() + + result := stdout.String() + if stderr.Len() > 0 { + result += "\n" + stderr.String() + } + result = strings.ReplaceAll(result, "\r\n", "\n") + exitCode := 0 + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + exitCode = exitError.ExitCode() + } + } + + uploadTaskResult(task_id, result, exitCode, finishedAt) +} + +func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) { + payload := map[string]interface{}{ + "task_id": taskID, + "result": result, + "exit_code": exitCode, + "finished_at": finishedAt, + } + + jsonData, _ := json.Marshal(payload) + endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token + + resp, _ := http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) + maxRetry := flags.MaxRetries + for i := 0; i < maxRetry && resp.StatusCode != http.StatusOK; i++ { + log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry) + time.Sleep(2 * time.Second) // Wait before retrying + resp, _ = http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) + } + if resp != nil { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + log.Printf("Failed to upload task result: %s", resp.Status) + } + } +} diff --git a/server/websocket.go b/server/websocket.go index 0344093..8394e59 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -99,7 +99,11 @@ func handleWebSocketMessages(conn *websocket.Conn, done chan<- struct{}) { } var message struct { Message string `json:"message"` - ID string `json:"request_id"` + // Terminal + TerminalId string `json:"request_id,omitempty"` + // Remote Exec + ExecCommand string `json:"command,omitempty"` + ExecTaskID string `json:"task_id,omitempty"` } err = json.Unmarshal(message_raw, &message) if err != nil { @@ -107,8 +111,12 @@ func handleWebSocketMessages(conn *websocket.Conn, done chan<- struct{}) { continue } - if message.Message == "terminal" || message.ID != "" { - go establishTerminalConnection(flags.Token, message.ID, flags.Endpoint) + if message.Message == "terminal" || message.TerminalId != "" { + go establishTerminalConnection(flags.Token, message.TerminalId, flags.Endpoint) + continue + } + if message.Message == "exec" { + go NewTask(message.ExecTaskID, message.ExecCommand) continue } diff --git a/update/update.go b/update/update.go index ffbfb79..21a2208 100644 --- a/update/update.go +++ b/update/update.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "strings" "time" "github.com/blang/semver" @@ -15,6 +16,19 @@ var ( Repo string = "komari-monitor/komari-agent" ) +// parseVersion 解析可能带有 v/V 前缀,以及预发布或构建元数据的版本字符串 +func parseVersion(ver string) (semver.Version, error) { + ver = strings.TrimPrefix(ver, "v") + ver = strings.TrimPrefix(ver, "V") + return semver.ParseTolerant(ver) +} + +// needUpdate 判断是否需要更新 +func needUpdate(current, latest semver.Version) bool { + // 返回最新版本大于当前版本时需要更新 + return latest.Compare(current) > 0 +} + func DoUpdateWorks() { ticker_ := time.NewTicker(time.Duration(6) * time.Hour) for range ticker_.C { @@ -26,7 +40,7 @@ func DoUpdateWorks() { func CheckAndUpdate() error { log.Println("Checking update...") // Parse current version - currentSemVer, err := semver.Parse(CurrentVersion) + currentSemVer, err := parseVersion(CurrentVersion) if err != nil { return fmt.Errorf("failed to parse current version: %v", err) } diff --git a/update/update_test.go b/update/update_test.go new file mode 100644 index 0000000..f6d8f70 --- /dev/null +++ b/update/update_test.go @@ -0,0 +1,72 @@ +package update + +import ( + "strings" + "testing" +) + +// TestParseVersion 验证 parseVersion 能够解析各种版本号格式,包括带 v/V 前缀、预发布和构建元数据 +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"v1.2.3", "1.2.3"}, + {"V1.2.3", "1.2.3"}, + {"1.2.3-beta.1", "1.2.3-beta.1"}, + {"v1.2.3+meta", "1.2.3+meta"}, + {"1.2.3-pre.1+build.123", "1.2.3-pre.1+build.123"}, + {" v2.0.0 ", "2.0.0"}, + {"invalid", ""}, + } + + for _, tt := range tests { + got, err := parseVersion(strings.TrimSpace(tt.input)) + if tt.want == "" { + if err == nil { + t.Errorf("parseVersion(%q) expected error, got %v", tt.input, got) + } + } else { + if err != nil { + t.Errorf("parseVersion(%q) unexpected error: %v", tt.input, err) + continue + } + if got.String() != tt.want { + t.Errorf("parseVersion(%q) = %q, want %q", tt.input, got.String(), tt.want) + } + } + } +} + +// TestNeedUpdate 验证 needUpdate 在不同版本组合下的判断 +func TestNeedUpdate(t *testing.T) { + tests := []struct { + current string + latest string + want bool + }{ + {"1.0.0", "1.0.1", true}, + {"v1.0.0", "1.1.0", true}, + {"1.2.3", "1.2.3", false}, + {"1.2.4", "1.2.3", false}, + {"1.2.3-beta", "1.2.3", true}, + {"1.2.3", "1.2.3-beta", false}, + {"0.0.5", "0.0.6+build.1", true}, + {"0.0.6", "v0.0.6+build.1", false}, + } + + for _, tt := range tests { + cur, err := parseVersion(strings.TrimSpace(tt.current)) + if err != nil { + t.Fatalf("parseVersion(%q) error: %v", tt.current, err) + } + lat, err := parseVersion(strings.TrimSpace(tt.latest)) + if err != nil { + t.Fatalf("parseVersion(%q) error: %v", tt.latest, err) + } + got := needUpdate(cur, lat) + if got != tt.want { + t.Errorf("needUpdate(%q, %q) = %v, want %v", tt.current, tt.latest, got, tt.want) + } + } +}