mirror of
https://github.com/fankes/komari-agent.git
synced 2025-10-18 10:39:24 +08:00
feat: 添加任务执行功能
fix: 版本解析逻辑
This commit is contained in:
81
server/task.go
Normal file
81
server/task.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
)
|
||||
|
||||
func NewTask(task_id, command string) {
|
||||
if task_id == "" {
|
||||
return
|
||||
}
|
||||
if command == "" {
|
||||
uploadTaskResult(task_id, "No command provided", 0, time.Now())
|
||||
return
|
||||
}
|
||||
if flags.DisableWebSsh {
|
||||
uploadTaskResult(task_id, "Web SSH (REC) is disabled.", -1, time.Now())
|
||||
return
|
||||
}
|
||||
log.Printf("Executing task %s with command: %s", task_id, command)
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("powershell", "-NoProfile", "-ExecutionPolicy", "Bypass", "-Command", "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "+command)
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", command)
|
||||
}
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
finishedAt := time.Now()
|
||||
|
||||
result := stdout.String()
|
||||
if stderr.Len() > 0 {
|
||||
result += "\n" + stderr.String()
|
||||
}
|
||||
result = strings.ReplaceAll(result, "\r\n", "\n")
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitError.ExitCode()
|
||||
}
|
||||
}
|
||||
|
||||
uploadTaskResult(task_id, result, exitCode, finishedAt)
|
||||
}
|
||||
|
||||
func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) {
|
||||
payload := map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"result": result,
|
||||
"exit_code": exitCode,
|
||||
"finished_at": finishedAt,
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(payload)
|
||||
endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token
|
||||
|
||||
resp, _ := http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData))
|
||||
maxRetry := flags.MaxRetries
|
||||
for i := 0; i < maxRetry && resp.StatusCode != http.StatusOK; i++ {
|
||||
log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry)
|
||||
time.Sleep(2 * time.Second) // Wait before retrying
|
||||
resp, _ = http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData))
|
||||
}
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("Failed to upload task result: %s", resp.Status)
|
||||
}
|
||||
}
|
||||
}
|
@@ -99,7 +99,11 @@ func handleWebSocketMessages(conn *websocket.Conn, done chan<- struct{}) {
|
||||
}
|
||||
var message struct {
|
||||
Message string `json:"message"`
|
||||
ID string `json:"request_id"`
|
||||
// Terminal
|
||||
TerminalId string `json:"request_id,omitempty"`
|
||||
// Remote Exec
|
||||
ExecCommand string `json:"command,omitempty"`
|
||||
ExecTaskID string `json:"task_id,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(message_raw, &message)
|
||||
if err != nil {
|
||||
@@ -107,8 +111,12 @@ func handleWebSocketMessages(conn *websocket.Conn, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
|
||||
if message.Message == "terminal" || message.ID != "" {
|
||||
go establishTerminalConnection(flags.Token, message.ID, flags.Endpoint)
|
||||
if message.Message == "terminal" || message.TerminalId != "" {
|
||||
go establishTerminalConnection(flags.Token, message.TerminalId, flags.Endpoint)
|
||||
continue
|
||||
}
|
||||
if message.Message == "exec" {
|
||||
go NewTask(message.ExecTaskID, message.ExecCommand)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
@@ -15,6 +16,19 @@ var (
|
||||
Repo string = "komari-monitor/komari-agent"
|
||||
)
|
||||
|
||||
// parseVersion 解析可能带有 v/V 前缀,以及预发布或构建元数据的版本字符串
|
||||
func parseVersion(ver string) (semver.Version, error) {
|
||||
ver = strings.TrimPrefix(ver, "v")
|
||||
ver = strings.TrimPrefix(ver, "V")
|
||||
return semver.ParseTolerant(ver)
|
||||
}
|
||||
|
||||
// needUpdate 判断是否需要更新
|
||||
func needUpdate(current, latest semver.Version) bool {
|
||||
// 返回最新版本大于当前版本时需要更新
|
||||
return latest.Compare(current) > 0
|
||||
}
|
||||
|
||||
func DoUpdateWorks() {
|
||||
ticker_ := time.NewTicker(time.Duration(6) * time.Hour)
|
||||
for range ticker_.C {
|
||||
@@ -26,7 +40,7 @@ func DoUpdateWorks() {
|
||||
func CheckAndUpdate() error {
|
||||
log.Println("Checking update...")
|
||||
// Parse current version
|
||||
currentSemVer, err := semver.Parse(CurrentVersion)
|
||||
currentSemVer, err := parseVersion(CurrentVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse current version: %v", err)
|
||||
}
|
||||
|
72
update/update_test.go
Normal file
72
update/update_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package update
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestParseVersion 验证 parseVersion 能够解析各种版本号格式,包括带 v/V 前缀、预发布和构建元数据
|
||||
func TestParseVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"v1.2.3", "1.2.3"},
|
||||
{"V1.2.3", "1.2.3"},
|
||||
{"1.2.3-beta.1", "1.2.3-beta.1"},
|
||||
{"v1.2.3+meta", "1.2.3+meta"},
|
||||
{"1.2.3-pre.1+build.123", "1.2.3-pre.1+build.123"},
|
||||
{" v2.0.0 ", "2.0.0"},
|
||||
{"invalid", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := parseVersion(strings.TrimSpace(tt.input))
|
||||
if tt.want == "" {
|
||||
if err == nil {
|
||||
t.Errorf("parseVersion(%q) expected error, got %v", tt.input, got)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("parseVersion(%q) unexpected error: %v", tt.input, err)
|
||||
continue
|
||||
}
|
||||
if got.String() != tt.want {
|
||||
t.Errorf("parseVersion(%q) = %q, want %q", tt.input, got.String(), tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestNeedUpdate 验证 needUpdate 在不同版本组合下的判断
|
||||
func TestNeedUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
current string
|
||||
latest string
|
||||
want bool
|
||||
}{
|
||||
{"1.0.0", "1.0.1", true},
|
||||
{"v1.0.0", "1.1.0", true},
|
||||
{"1.2.3", "1.2.3", false},
|
||||
{"1.2.4", "1.2.3", false},
|
||||
{"1.2.3-beta", "1.2.3", true},
|
||||
{"1.2.3", "1.2.3-beta", false},
|
||||
{"0.0.5", "0.0.6+build.1", true},
|
||||
{"0.0.6", "v0.0.6+build.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
cur, err := parseVersion(strings.TrimSpace(tt.current))
|
||||
if err != nil {
|
||||
t.Fatalf("parseVersion(%q) error: %v", tt.current, err)
|
||||
}
|
||||
lat, err := parseVersion(strings.TrimSpace(tt.latest))
|
||||
if err != nil {
|
||||
t.Fatalf("parseVersion(%q) error: %v", tt.latest, err)
|
||||
}
|
||||
got := needUpdate(cur, lat)
|
||||
if got != tt.want {
|
||||
t.Errorf("needUpdate(%q, %q) = %v, want %v", tt.current, tt.latest, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user