diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0ac8076..edbd46b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -36,7 +36,7 @@ jobs: if [ "${{ matrix.goos }}" = "windows" ]; then BINARY_NAME=${BINARY_NAME}.exe fi - go build -trimpath -ldflags="-s -w -X main.CurrentVersion=${VERSION}" -o $BINARY_NAME + go build -trimpath -ldflags="-s -w -X update.CurrentVersion=${VERSION}" -o $BINARY_NAME - name: Upload binary to release env: diff --git a/cmd/flags/flag.go b/cmd/flags/flag.go new file mode 100644 index 0000000..1d6f004 --- /dev/null +++ b/cmd/flags/flag.go @@ -0,0 +1,12 @@ +package flags + +var ( + DisableAutoUpdate bool + DisableWebSsh bool + Token string + Endpoint string + Interval float64 + IgnoreUnsafeCert bool + MaxRetries int + ReconnectInterval int +) diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..911b4ff --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,68 @@ +package cmd + +import ( + "crypto/tls" + "log" + "net/http" + "os" + + "github.com/komari-monitor/komari-agent/cmd/flags" + "github.com/komari-monitor/komari-agent/server" + "github.com/komari-monitor/komari-agent/update" + "github.com/spf13/cobra" +) + +var RootCmd = &cobra.Command{ + Use: "komari-agent", + Short: "komari agent", + Long: `komari agent`, + Run: func(cmd *cobra.Command, args []string) { + log.Println("Komari Agent", update.CurrentVersion) + log.Println("Github Repo:", update.Repo) + + // 忽略不安全的证书 + if flags.IgnoreUnsafeCert { + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + // 自动更新 + if !flags.DisableAutoUpdate { + err := update.CheckAndUpdate() + if err != nil { + log.Println("[ERROR]", err) + } + go update.DoUpdateWorks() + } + go server.DoUploadBasicInfoWorks() + server.EstablishWebSocketConnection() + os.Exit(0) + }, +} + +func Execute() { + for i, arg := range os.Args { + if arg == "-autoUpdate" || arg == "--autoUpdate" { + log.Println("WARNING: The -autoUpdate flag is deprecated in version 0.0.9 and later. Use --disable-auto-update to configure auto-update behavior.") + // 从参数列表中移除该参数,防止cobra解析错误 + os.Args = append(os.Args[:i], os.Args[i+1:]...) + break + } + } + + if err := RootCmd.Execute(); err != nil { + log.Println(err) + } +} + +func init() { + RootCmd.PersistentFlags().StringVarP(&flags.Token, "token", "t", "", "API token") + RootCmd.MarkPersistentFlagRequired("token") + RootCmd.PersistentFlags().StringVarP(&flags.Endpoint, "endpoint", "e", "", "API endpoint") + RootCmd.MarkPersistentFlagRequired("endpoint") + RootCmd.PersistentFlags().BoolVarP(&flags.DisableAutoUpdate, "disable-auto-update", "d", false, "Disable automatic updates") + RootCmd.PersistentFlags().BoolVarP(&flags.DisableWebSsh, "disable-web-ssh", "w", false, "Disable web SSH") + RootCmd.PersistentFlags().Float64VarP(&flags.Interval, "interval", "i", 1.0, "Interval in seconds") + RootCmd.PersistentFlags().BoolVarP(&flags.IgnoreUnsafeCert, "ignore-unsafe-cert", "u", false, "Ignore unsafe certificate errors") + RootCmd.PersistentFlags().IntVarP(&flags.MaxRetries, "max-retries", "r", 3, "Maximum number of retries") + RootCmd.PersistentFlags().IntVarP(&flags.ReconnectInterval, "reconnect-interval", "c", 5, "Reconnect interval in seconds") + RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true +} diff --git a/config/local.go b/config/local.go deleted file mode 100644 index 95533ac..0000000 --- a/config/local.go +++ /dev/null @@ -1,76 +0,0 @@ -package config - -import ( - "encoding/json" - "flag" - "os" -) - -type LocalConfig struct { - Endpoint string `json:"endpoint"` - Token string `json:"token"` - Terminal bool `json:"terminal"` - MaxRetries int `json:"maxRetries"` - ReconnectInterval int `json:"reconnectInterval"` - IgnoreUnsafeCert bool `json:"ignoreUnsafeCert"` - Interval float64 `json:"interval"` - AutoUpdate bool `json:"autoUpdate"` -} - -func LoadConfig() (LocalConfig, error) { - - var ( - endpoint string - token string - terminal bool - path string - maxRetries int - reconnectInterval int - ignoreUnsafeCert bool - interval float64 - autoUpdate bool - ) - - flag.StringVar(&endpoint, "e", "", "The endpoint URL") - flag.StringVar(&token, "token", "", "The authentication token") - flag.BoolVar(&terminal, "terminal", false, "Enable or disable terminal (default: false)") - flag.StringVar(&path, "c", "agent.json", "Path to the configuration file") - flag.IntVar(&maxRetries, "maxRetries", 10, "Maximum number of retries for WebSocket connection") - flag.IntVar(&reconnectInterval, "reconnectInterval", 5, "Reconnect interval in seconds") - flag.Float64Var(&interval, "interval", 1.1, "Interval in seconds for sending data to the server") - flag.BoolVar(&ignoreUnsafeCert, "ignoreUnsafeCert", false, "Ignore unsafe certificate errors") - flag.BoolVar(&autoUpdate, "autoUpdate", false, "Enable or disable auto update (default: false)") - flag.Parse() - - // Ensure -c cannot coexist with other flags - if path != "agent.json" && (endpoint != "" || token != "" || !terminal) { - return LocalConfig{}, flag.ErrHelp - } - - // 必填项 Endpoint、Token 没有读取配置文件 - if endpoint == "" || token == "" { - file, err := os.Open(path) - if err != nil { - return LocalConfig{}, err - } - defer file.Close() - - var localConfig LocalConfig - if err := json.NewDecoder(file).Decode(&localConfig); err != nil { - return LocalConfig{}, err - } - - return localConfig, nil - } - - return LocalConfig{ - Endpoint: endpoint, - Token: token, - Terminal: terminal, - MaxRetries: maxRetries, - ReconnectInterval: reconnectInterval, - IgnoreUnsafeCert: ignoreUnsafeCert, - Interval: interval, - AutoUpdate: autoUpdate, - }, nil -} diff --git a/go.mod b/go.mod index 79790e3..7fd0835 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,13 @@ module github.com/komari-monitor/komari-agent go 1.23.2 require ( + github.com/UserExistsError/conpty v0.1.4 github.com/blang/semver v3.5.1+incompatible + github.com/creack/pty v1.1.24 github.com/gorilla/websocket v1.5.3 github.com/rhysd/go-github-selfupdate v1.2.3 github.com/shirou/gopsutil v3.21.11+incompatible + github.com/spf13/cobra v1.9.1 golang.org/x/sys v0.32.0 ) @@ -16,6 +19,8 @@ require ( github.com/google/go-github/v30 v30.1.0 // indirect github.com/google/go-querystring v1.0.0 // indirect github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/testify v1.10.0 // indirect github.com/tcnksm/go-gitconfig v0.1.2 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect diff --git a/go.sum b/go.sum index 2fc9e10..1ab038e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,10 @@ +github.com/UserExistsError/conpty v0.1.4 h1:+3FhJhiqhyEJa+K5qaK3/w6w+sN3Nh9O9VbJyBS02to= +github.com/UserExistsError/conpty v0.1.4/go.mod h1:PDglKIkX3O/2xVk0MV9a6bCWxRmPVfxqZoTG/5sSd9I= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= @@ -17,6 +22,8 @@ github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf h1:WfD7VjIE6z8dIvMsI4/s+1qr5EL+zoIGev1BQj1eoJ8= github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf/go.mod h1:hyb9oH7vZsitZCiBt0ZvifOrB+qc8PS5IiilCIb87rg= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -27,8 +34,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rhysd/go-github-selfupdate v1.2.3 h1:iaa+J202f+Nc+A8zi75uccC8Wg3omaM7HDeimXA22Ag= github.com/rhysd/go-github-selfupdate v1.2.3/go.mod h1:mp/N8zj6jFfBQy/XMYoWsmfzxazpPAODuqarmPDe2Rg= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tcnksm/go-gitconfig v0.1.2 h1:iiDhRitByXAEyjgBqsKi9QU4o2TNtv9kPP3RgPgXBPw= @@ -57,6 +69,7 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= diff --git a/main.go b/main.go index 6720235..58cb3ba 100644 --- a/main.go +++ b/main.go @@ -1,297 +1,12 @@ package main import ( - "crypto/tls" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "strings" - "time" + "os" - "github.com/gorilla/websocket" - "github.com/komari-monitor/komari-agent/config" - "github.com/komari-monitor/komari-agent/monitoring" - "github.com/komari-monitor/komari-agent/update" -) - -var ( - CurrentVersion string = "0.0.1" - repo = "komari-monitor/komari-agent" + "github.com/komari-monitor/komari-agent/cmd" ) func main() { - log.Printf("Komari Agent %s\n", CurrentVersion) - localConfig, err := config.LoadConfig() - if err != nil { - log.Fatalln("Failed to load local config:", err) - } - if localConfig.IgnoreUnsafeCert { - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - go func() { - err = uploadBasicInfo(localConfig.Endpoint, localConfig.Token) - ticker := time.NewTicker(time.Duration(time.Minute * 15)) - for range ticker.C { - err = uploadBasicInfo(localConfig.Endpoint, localConfig.Token) - if err != nil { - log.Fatalln("Failed to upload basic info:", err) - } - } - }() - - websocketEndpoint := strings.TrimSuffix(localConfig.Endpoint, "/") + "/api/clients/report?token=" + localConfig.Token - websocketEndpoint = "ws" + strings.TrimPrefix(websocketEndpoint, "http") - - var conn *websocket.Conn - defer func() { - if conn != nil { - conn.Close() - } - }() - - ticker := time.NewTicker(time.Duration(localConfig.Interval * float64(time.Second))) - defer ticker.Stop() - - if localConfig.AutoUpdate { - go func() { - ticker_ := time.NewTicker(time.Duration(6) * time.Hour) - update_komari() - for range ticker_.C { - update_komari() - } - }() - } - - for range ticker.C { - // If no connection, attempt to connect - if conn == nil { - log.Println("Attempting to connect to WebSocket...") - retry := 0 - for retry < localConfig.MaxRetries { - conn, err = connectWebSocket(websocketEndpoint) - if err == nil { - log.Println("WebSocket connected") - go handleWebSocketMessages(localConfig, conn, make(chan struct{})) - break - } - retry++ - time.Sleep(time.Duration(localConfig.ReconnectInterval) * time.Second) - } - - if retry >= localConfig.MaxRetries { - log.Println("Max retries reached, falling back to POST") - // Send report via POST and continue - data := report(localConfig) - if err := reportWithPOST(localConfig.Endpoint, data); err != nil { - log.Println("Failed to send POST report:", err) - } - continue - } - } - - // Send report via WebSocket - data := report(localConfig) - err = conn.WriteMessage(websocket.TextMessage, data) - if err != nil { - log.Println("Failed to send WebSocket message:", err) - conn.Close() - conn = nil // Mark connection as dead - continue - } - - } -} - -func update_komari() { - // 初始化 Updater - updater := update.NewUpdater(CurrentVersion, repo) - - // 检查并更新 - err := updater.CheckAndUpdate() - if err != nil { - log.Printf("Update Failed: %v", err) - } -} - -// connectWebSocket attempts to establish a WebSocket connection and upload basic info -func connectWebSocket(websocketEndpoint string) (*websocket.Conn, error) { - dialer := &websocket.Dialer{ - HandshakeTimeout: 5 * time.Second, - } - conn, _, err := dialer.Dial(websocketEndpoint, nil) - if err != nil { - return nil, err - } - - return conn, nil -} - -func handleWebSocketMessages(localConfig config.LocalConfig, conn *websocket.Conn, done chan<- struct{}) { - defer close(done) - for { - _, message_raw, err := conn.ReadMessage() - if err != nil { - log.Println("WebSocket read error:", err) - return - } - // TODO: Remote config update - // TODO: Handle incoming messages - log.Println("Received message:", string(message_raw)) - message := make(map[string]interface{}) - err = json.Unmarshal(message_raw, &message) - if err != nil { - log.Println("Bad ws message:", err) - continue - } - - } -} - -func reportWithPOST(endpoint string, data []byte) error { - url := strings.TrimSuffix(endpoint, "/") + "/api/clients/report" - req, err := http.NewRequest("POST", url, strings.NewReader(string(data))) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return err - } - return nil -} - -func uploadBasicInfo(endpoint string, token string) error { - log.Println("Uploading basic info...") - defer log.Println("Upload complete") - cpu := monitoring.Cpu() - - osname := monitoring.OSName() - ipv4, ipv6, _ := monitoring.GetIPAddress() - - data := map[string]interface{}{ - "cpu_name": cpu.CPUName, - "cpu_cores": cpu.CPUCores, - "arch": cpu.CPUArchitecture, - "os": osname, - "ipv4": ipv4, - "ipv6": ipv6, - "mem_total": monitoring.Ram().Total, - "swap_total": monitoring.Swap().Total, - "disk_total": monitoring.Disk().Total, - "gpu_name": "Unknown", - "version": CurrentVersion, - } - - endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/uploadBasicInfo?token=" + token - payload, err := json.Marshal(data) - if err != nil { - return err - } - - req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(payload))) - if err != nil { - return err - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - message := string(body) - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("status code: %d,%s", resp.StatusCode, message) - } - - return nil -} - -func report(localConfig config.LocalConfig) []byte { - message := "" - data := map[string]interface{}{} - - cpu := monitoring.Cpu() - data["cpu"] = map[string]interface{}{ - "usage": cpu.CPUUsage, - } - - ram := monitoring.Ram() - data["ram"] = map[string]interface{}{ - "total": ram.Total, - "used": ram.Used, - } - - swap := monitoring.Swap() - data["swap"] = map[string]interface{}{ - "total": swap.Total, - "used": swap.Used, - } - load := monitoring.Load() - data["load"] = map[string]interface{}{ - "load1": load.Load1, - "load5": load.Load5, - "load15": load.Load15, - } - - disk := monitoring.Disk() - data["disk"] = map[string]interface{}{ - "total": disk.Total, - "used": disk.Used, - } - - totalUp, totalDown, networkUp, networkDown, err := monitoring.NetworkSpeed(int(localConfig.Interval)) - if err != nil { - message += fmt.Sprintf("failed to get network speed: %v\n", err) - } - data["network"] = map[string]interface{}{ - "up": networkUp, - "down": networkDown, - "totalUp": totalUp, - "totalDown": totalDown, - } - - tcpCount, udpCount, err := monitoring.ConnectionsCount() - if err != nil { - message += fmt.Sprintf("failed to get connections: %v\n", err) - } - data["connections"] = map[string]interface{}{ - "tcp": tcpCount, - "udp": udpCount, - } - - uptime, err := monitoring.Uptime() - if err != nil { - message += fmt.Sprintf("failed to get uptime: %v\n", err) - } - data["uptime"] = uptime - - processcount := monitoring.ProcessCount() - data["process"] = processcount - - data["message"] = message - - s, err := json.Marshal(data) - if err != nil { - log.Println("Failed to marshal data:", err) - } - return s + cmd.Execute() + os.Exit(0) } diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go new file mode 100644 index 0000000..5926c3b --- /dev/null +++ b/monitoring/monitoring.go @@ -0,0 +1,84 @@ +package monitoring + +import ( + "encoding/json" + "fmt" + "log" + + monitoring "github.com/komari-monitor/komari-agent/monitoring/unit" +) + +func GenerateReport() []byte { + message := "" + data := map[string]interface{}{} + + cpu := monitoring.Cpu() + cpuUsage := cpu.CPUUsage + if cpuUsage <= 0.001 { + cpuUsage = 0.001 + } + data["cpu"] = map[string]interface{}{ + "usage": cpuUsage, + } + + ram := monitoring.Ram() + data["ram"] = map[string]interface{}{ + "total": ram.Total, + "used": ram.Used, + } + + swap := monitoring.Swap() + data["swap"] = map[string]interface{}{ + "total": swap.Total, + "used": swap.Used, + } + load := monitoring.Load() + data["load"] = map[string]interface{}{ + "load1": load.Load1, + "load5": load.Load5, + "load15": load.Load15, + } + + disk := monitoring.Disk() + data["disk"] = map[string]interface{}{ + "total": disk.Total, + "used": disk.Used, + } + + totalUp, totalDown, networkUp, networkDown, err := monitoring.NetworkSpeed() + if err != nil { + message += fmt.Sprintf("failed to get network speed: %v\n", err) + } + data["network"] = map[string]interface{}{ + "up": networkUp, + "down": networkDown, + "totalUp": totalUp, + "totalDown": totalDown, + } + + tcpCount, udpCount, err := monitoring.ConnectionsCount() + if err != nil { + message += fmt.Sprintf("failed to get connections: %v\n", err) + } + data["connections"] = map[string]interface{}{ + "tcp": tcpCount, + "udp": udpCount, + } + + uptime, err := monitoring.Uptime() + if err != nil { + message += fmt.Sprintf("failed to get uptime: %v\n", err) + } + data["uptime"] = uptime + + processcount := monitoring.ProcessCount() + data["process"] = processcount + + data["message"] = message + + s, err := json.Marshal(data) + if err != nil { + log.Println("Failed to marshal data:", err) + } + return s +} diff --git a/monitoring/net.go b/monitoring/net.go deleted file mode 100644 index 735d508..0000000 --- a/monitoring/net.go +++ /dev/null @@ -1,60 +0,0 @@ -package monitoring - -import ( - "fmt" - - "github.com/shirou/gopsutil/net" -) - -func ConnectionsCount() (tcpCount, udpCount int, err error) { - tcps, err := net.Connections("tcp") - if err != nil { - return 0, 0, fmt.Errorf("failed to get TCP connections: %w", err) - } - udps, err := net.Connections("udp") - if err != nil { - return 0, 0, fmt.Errorf("failed to get UDP connections: %w", err) - } - - return len(tcps), len(udps), nil -} - -var ( - lastUp uint64 - lastDown uint64 -) - -func NetworkSpeed(interval int) (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { - // Get the network IO counters - ioCounters, err := net.IOCounters(false) - if err != nil { - return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) - } - - if len(ioCounters) == 0 { - return 0, 0, 0, 0, fmt.Errorf("no network interfaces found") - } - - for _, interfaceStats := range ioCounters { - loopbackNames := []string{"lo", "lo0", "localhost", "brd0", "docker0", "docker1", "veth0", "veth1", "veth2", "veth3", "veth4", "veth5", "veth6", "veth7"} - isLoopback := false - for _, name := range loopbackNames { - if interfaceStats.Name == name { - isLoopback = true - break - } - } - if isLoopback { - continue // Skip loopback interface - } - totalUp += interfaceStats.BytesSent - totalDown += interfaceStats.BytesRecv - - } - upSpeed = (totalUp - lastUp) / uint64(interval) - downSpeed = (totalDown - lastDown) / uint64(interval) - - lastUp = totalUp - lastDown = totalDown - return totalUp, totalDown, upSpeed, downSpeed, nil -} diff --git a/monitoring/cpu.go b/monitoring/unit/cpu.go similarity index 100% rename from monitoring/cpu.go rename to monitoring/unit/cpu.go diff --git a/monitoring/disk.go b/monitoring/unit/disk.go similarity index 100% rename from monitoring/disk.go rename to monitoring/unit/disk.go diff --git a/monitoring/gpu.go b/monitoring/unit/gpu.go similarity index 100% rename from monitoring/gpu.go rename to monitoring/unit/gpu.go diff --git a/monitoring/ip.go b/monitoring/unit/ip.go similarity index 100% rename from monitoring/ip.go rename to monitoring/unit/ip.go diff --git a/monitoring/load.go b/monitoring/unit/load.go similarity index 100% rename from monitoring/load.go rename to monitoring/unit/load.go diff --git a/monitoring/mem.go b/monitoring/unit/mem.go similarity index 100% rename from monitoring/mem.go rename to monitoring/unit/mem.go diff --git a/monitoring/unit/net.go b/monitoring/unit/net.go new file mode 100644 index 0000000..c3193cb --- /dev/null +++ b/monitoring/unit/net.go @@ -0,0 +1,83 @@ +package monitoring + +import ( + "fmt" + "time" + + "github.com/shirou/gopsutil/net" +) + +func ConnectionsCount() (tcpCount, udpCount int, err error) { + tcps, err := net.Connections("tcp") + if err != nil { + return 0, 0, fmt.Errorf("failed to get TCP connections: %w", err) + } + udps, err := net.Connections("udp") + if err != nil { + return 0, 0, fmt.Errorf("failed to get UDP connections: %w", err) + } + + return len(tcps), len(udps), nil +} + +var ( + // 预定义常见的回环和虚拟接口名称 + loopbackNames = map[string]struct{}{ + "lo": {}, "lo0": {}, "localhost": {}, + "brd0": {}, "docker0": {}, "docker1": {}, + "veth0": {}, "veth1": {}, "veth2": {}, "veth3": {}, + "veth4": {}, "veth5": {}, "veth6": {}, "veth7": {}, + } +) + +func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { + // 获取第一次网络IO计数器 + ioCounters1, err := net.IOCounters(false) + if err != nil { + return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) + } + + if len(ioCounters1) == 0 { + return 0, 0, 0, 0, fmt.Errorf("no network interfaces found") + } + + // 统计第一次所有非回环接口的流量 + var totalUp1, totalDown1 uint64 + for _, interfaceStats := range ioCounters1 { + // 使用映射表进行O(1)查找 + if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { + continue // 跳过回环接口 + } + totalUp1 += interfaceStats.BytesSent + totalDown1 += interfaceStats.BytesRecv + } + + // 等待1秒 + time.Sleep(time.Second) + + // 获取第二次网络IO计数器 + ioCounters2, err := net.IOCounters(false) + if err != nil { + return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) + } + + if len(ioCounters2) == 0 { + return 0, 0, 0, 0, fmt.Errorf("no network interfaces found") + } + + // 统计第二次所有非回环接口的流量 + var totalUp2, totalDown2 uint64 + for _, interfaceStats := range ioCounters2 { + if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { + continue // 跳过回环接口 + } + totalUp2 += interfaceStats.BytesSent + totalDown2 += interfaceStats.BytesRecv + } + + // 计算速度 (每秒的速率) + upSpeed = totalUp2 - totalUp1 + downSpeed = totalDown2 - totalDown1 + + return totalUp2, totalDown2, upSpeed, downSpeed, nil +} diff --git a/monitoring/os_linux.go b/monitoring/unit/os_linux.go similarity index 92% rename from monitoring/os_linux.go rename to monitoring/unit/os_linux.go index 0a1e2c8..9457d3c 100644 --- a/monitoring/os_linux.go +++ b/monitoring/unit/os_linux.go @@ -1,5 +1,5 @@ -//go:build linux -// +build linux +//go:build !windows +// +build !windows package monitoring diff --git a/monitoring/os_windows.go b/monitoring/unit/os_windows.go similarity index 100% rename from monitoring/os_windows.go rename to monitoring/unit/os_windows.go diff --git a/monitoring/process_linux.go b/monitoring/unit/process_linux.go similarity index 93% rename from monitoring/process_linux.go rename to monitoring/unit/process_linux.go index 6760e7b..97e64da 100644 --- a/monitoring/process_linux.go +++ b/monitoring/unit/process_linux.go @@ -1,5 +1,5 @@ -//go:build linux -// +build linux +//go:build !windows +// +build !windows package monitoring diff --git a/monitoring/process_windows.go b/monitoring/unit/process_windows.go similarity index 100% rename from monitoring/process_windows.go rename to monitoring/unit/process_windows.go diff --git a/monitoring/uptime.go b/monitoring/unit/uptime.go similarity index 100% rename from monitoring/uptime.go rename to monitoring/unit/uptime.go diff --git a/server/basicInfo.go b/server/basicInfo.go new file mode 100644 index 0000000..08da2e9 --- /dev/null +++ b/server/basicInfo.go @@ -0,0 +1,81 @@ +package server + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "github.com/komari-monitor/komari-agent/cmd/flags" + monitoring "github.com/komari-monitor/komari-agent/monitoring/unit" + "github.com/komari-monitor/komari-agent/update" +) + +func DoUploadBasicInfoWorks() { + err := uploadBasicInfo() + if err != nil { + log.Println("Error uploading basic info:", err) + } + ticker := time.NewTicker(time.Duration(15) * time.Minute) + for range ticker.C { + err := uploadBasicInfo() + if err != nil { + log.Println("Error uploading basic info:", err) + } + } +} + +func uploadBasicInfo() error { + cpu := monitoring.Cpu() + + osname := monitoring.OSName() + ipv4, ipv6, _ := monitoring.GetIPAddress() + + data := map[string]interface{}{ + "cpu_name": cpu.CPUName, + "cpu_cores": cpu.CPUCores, + "arch": cpu.CPUArchitecture, + "os": osname, + "ipv4": ipv4, + "ipv6": ipv6, + "mem_total": monitoring.Ram().Total, + "swap_total": monitoring.Swap().Total, + "disk_total": monitoring.Disk().Total, + "gpu_name": "Unknown", + "version": update.CurrentVersion, + } + + endpoint := strings.TrimSuffix(flags.Endpoint, "/") + "/api/clients/uploadBasicInfo?token=" + flags.Token + payload, err := json.Marshal(data) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(payload))) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + message := string(body) + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status code: %d,%s", resp.StatusCode, message) + } + + return nil +} diff --git a/server/websocket.go b/server/websocket.go new file mode 100644 index 0000000..0344093 --- /dev/null +++ b/server/websocket.go @@ -0,0 +1,138 @@ +package server + +import ( + "encoding/json" + "fmt" + "log" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/komari-monitor/komari-agent/cmd/flags" + "github.com/komari-monitor/komari-agent/monitoring" + "github.com/komari-monitor/komari-agent/terminal" +) + +func EstablishWebSocketConnection() { + + websocketEndpoint := strings.TrimSuffix(flags.Endpoint, "/") + "/api/clients/report?token=" + flags.Token + websocketEndpoint = "ws" + strings.TrimPrefix(websocketEndpoint, "http") + + var conn *websocket.Conn + defer func() { + if conn != nil { + conn.Close() + } + }() + var err error + var interval float64 + if flags.Interval <= 1 { + interval = 1 + } else { + interval = flags.Interval - 1 + } + + ticker := time.NewTicker(time.Duration(interval * float64(time.Second))) + defer ticker.Stop() + + for range ticker.C { + // If no connection, attempt to connect + if conn == nil { + log.Println("Attempting to connect to WebSocket...") + retry := 0 + for retry <= flags.MaxRetries { + if retry > 0 { + log.Println("Retrying websocket connection, attempt:", retry) + } + conn, err = connectWebSocket(websocketEndpoint) + if err == nil { + log.Println("WebSocket connected") + go handleWebSocketMessages(conn, make(chan struct{})) + break + } else { + log.Println("Failed to connect to WebSocket:", err) + } + retry++ + time.Sleep(time.Duration(flags.ReconnectInterval) * time.Second) + } + + if retry > flags.MaxRetries { + log.Println("Max retries reached.") + return + } + } + + data := monitoring.GenerateReport() + err = conn.WriteMessage(websocket.TextMessage, data) + if err != nil { + log.Println("Failed to send WebSocket message:", err) + conn.Close() + conn = nil // Mark connection as dead + continue + } + } +} + +func connectWebSocket(websocketEndpoint string) (*websocket.Conn, error) { + dialer := &websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + conn, resp, err := dialer.Dial(websocketEndpoint, nil) + if err != nil { + if resp != nil && resp.StatusCode != 101 { + return nil, fmt.Errorf("%s", resp.Status) + } + return nil, err + } + + return conn, nil +} + +func handleWebSocketMessages(conn *websocket.Conn, done chan<- struct{}) { + + defer close(done) + for { + _, message_raw, err := conn.ReadMessage() + if err != nil { + log.Println("WebSocket read error:", err) + return + } + var message struct { + Message string `json:"message"` + ID string `json:"request_id"` + } + err = json.Unmarshal(message_raw, &message) + if err != nil { + log.Println("Bad ws message:", err) + continue + } + + if message.Message == "terminal" || message.ID != "" { + go establishTerminalConnection(flags.Token, message.ID, flags.Endpoint) + continue + } + + } +} + +// connectWebSocket attempts to establish a WebSocket connection and upload basic info + +// establishTerminalConnection 建立终端连接并使用terminal包处理终端操作 +func establishTerminalConnection(token, id, endpoint string) { + endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id + endpoint = "ws" + strings.TrimPrefix(endpoint, "http") + dialer := &websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + conn, _, err := dialer.Dial(endpoint, nil) + if err != nil { + log.Println("Failed to establish terminal connection:", err) + return + } + + // 启动终端 + terminal.StartTerminal(conn) + if conn != nil { + conn.Close() + } +} diff --git a/terminal/terminal_unix.go b/terminal/terminal_unix.go new file mode 100644 index 0000000..32a3d12 --- /dev/null +++ b/terminal/terminal_unix.go @@ -0,0 +1,119 @@ +//go:build !windows + +package terminal + +import ( + "encoding/json" + "fmt" + "os/exec" + "syscall" + + "github.com/creack/pty" + "github.com/gorilla/websocket" +) + +// StartTerminal 在Unix/Linux系统上启动终端 +func StartTerminal(conn *websocket.Conn) { + // 获取shell + defalut_shell := []string{"zsh", "bash", "sh"} + shell := "" + for _, s := range defalut_shell { + if _, err := exec.LookPath(s); err == nil { + shell = s + break + } + } + if shell == "" { + conn.WriteMessage(websocket.TextMessage, []byte("No supported shell found.")) + return + } + // 创建进程 + cmd := exec.Command(shell) + cmd.Env = append(cmd.Env, "TERM=xterm-256color") + tty, err := pty.Start(cmd) + if err != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + return + } + defer tty.Close() + // 设置终端大小 + pty.Setsize(tty, &pty.Winsize{ + Rows: 24, + Cols: 80, + X: 0, + Y: 0, + }) + terminateConn := func() { + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err != nil { + cmd.Process.Kill() + } + syscall.Kill(-pgid, syscall.SIGKILL) + if conn != nil { + conn.Close() + } + } + err_chan := make(chan error, 1) + // 从WebSocket读取数据并写入pty + go func() { + for { + t, p, err := conn.ReadMessage() + if err != nil { + err_chan <- err + return + } + if t == websocket.TextMessage { + var cmd struct { + Type string `json:"type"` + Cols int `json:"cols,omitempty"` + Rows int `json:"rows,omitempty"` + Input string `json:"input,omitempty"` + } + + if err := json.Unmarshal(p, &cmd); err == nil { + switch cmd.Type { + case "resize": + if cmd.Cols > 0 && cmd.Rows > 0 { + pty.Setsize(tty, &pty.Winsize{ + Rows: uint16(cmd.Rows), + Cols: uint16(cmd.Cols), + }) + } + case "input": + if cmd.Input != "" { + tty.Write([]byte(cmd.Input)) + } + } + } else { + tty.Write(p) + } + } + if t == websocket.BinaryMessage { + tty.Write(p) + } + } + }() + + go func() { + buf := make([]byte, 4096) + for { + n, err := tty.Read(buf) + if err != nil { + err_chan <- err + return + } + + err = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) + if err != nil { + err_chan <- err + return + } + } + }() + + err = <-err_chan + if err != nil && conn != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + } + terminateConn() +} diff --git a/terminal/terminal_windows.go b/terminal/terminal_windows.go new file mode 100644 index 0000000..526d0e9 --- /dev/null +++ b/terminal/terminal_windows.go @@ -0,0 +1,108 @@ +//go:build windows + +package terminal + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/UserExistsError/conpty" + "github.com/gorilla/websocket" +) + +// StartTerminal 在Windows系统上启动终端 +func StartTerminal(conn *websocket.Conn) { + // 创建进程 + shell, err := exec.LookPath("powershell.exe") + if err != nil || shell == "" { + shell = "cmd.exe" + } + current_dir := "." + executable, err := os.Executable() + if err == nil { + current_dir = filepath.Dir(executable) + } + if shell == "" || current_dir == "" { + conn.WriteMessage(websocket.TextMessage, []byte("No supported shell found.")) + return + } + + tty, err := conpty.Start(shell, conpty.ConPtyWorkDir(current_dir)) + if err != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + return + } + defer tty.Close() + err_chan := make(chan error, 1) + // 设置终端大小 + tty.Resize(80, 24) + + go func() { + for { + t, p, err := conn.ReadMessage() + if err != nil { + err_chan <- err + return + } + if t == websocket.TextMessage { + var cmd struct { + Type string `json:"type"` + Cols int `json:"cols,omitempty"` + Rows int `json:"rows,omitempty"` + Input string `json:"input,omitempty"` + } + + if err := json.Unmarshal(p, &cmd); err == nil { + switch cmd.Type { + case "resize": + if cmd.Cols > 0 && cmd.Rows > 0 { + tty.Resize(cmd.Cols, cmd.Rows) + } + case "input": + if cmd.Input != "" { + tty.Write([]byte(cmd.Input)) + } + } + } else { + tty.Write(p) + } + } + if t == websocket.BinaryMessage { + tty.Write(p) + } + } + }() + + go func() { + buf := make([]byte, 4096) + for { + n, err := tty.Read(buf) + if err != nil { + err_chan <- err + return + } + + err = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) + if err != nil { + err_chan <- err + return + } + } + }() + + go func() { + err := <-err_chan + if err != nil && tty != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + } + conn.Close() + tty.Close() + }() + tty.Wait(context.Background()) + tty.Close() + +} diff --git a/update/update.go b/update/update.go index 0ce1ef9..c1f47ac 100644 --- a/update/update.go +++ b/update/update.go @@ -4,63 +4,64 @@ import ( "fmt" "log" "os" + "time" "github.com/blang/semver" "github.com/rhysd/go-github-selfupdate/selfupdate" ) -type Updater struct { - CurrentVersion string // 当前版本 - Repo string // GitHub 仓库,例如 "komari-monitor/komari-agent" -} +var ( + CurrentVersion string = "0.0.1" + Repo string = "komari-monitor/komari-agent" +) -func NewUpdater(currentVersion, repo string) *Updater { - return &Updater{ - CurrentVersion: currentVersion, - Repo: repo, +func DoUpdateWorks() { + ticker_ := time.NewTicker(time.Duration(6) * time.Hour) + for range ticker_.C { + CheckAndUpdate() } } // 检查更新并执行自动更新 -func (u *Updater) CheckAndUpdate() error { +func CheckAndUpdate() error { log.Println("Checking update...") - // 解析当前版本 - currentSemVer, err := semver.Parse(u.CurrentVersion) + // Parse current version + currentSemVer, err := semver.Parse(CurrentVersion) if err != nil { - return fmt.Errorf("解析当前版本失败: %v", err) + return fmt.Errorf("failed to parse current version: %v", err) } - // 创建 selfupdate 配置 + // Create selfupdate configuration config := selfupdate.Config{} updater, err := selfupdate.NewUpdater(config) if err != nil { - return fmt.Errorf("创建 updater 失败: %v", err) + return fmt.Errorf("failed to create updater: %v", err) } - // 检查最新版本 - latest, err := updater.UpdateSelf(currentSemVer, u.Repo) + // Check for latest version + latest, err := updater.UpdateSelf(currentSemVer, Repo) if err != nil { - return fmt.Errorf("检查更新失败: %v", err) + return fmt.Errorf("failed to check for updates: %v", err) } - // 判断是否需要更新 + // Determine if update is needed if latest.Version.Equals(currentSemVer) { - fmt.Println("当前版本已是最新版本:", u.CurrentVersion) + fmt.Println("Current version is the latest:", CurrentVersion) return nil } - execPath, err := os.Executable() - if err != nil { - return fmt.Errorf("获取当前执行路径失败: %v", err) - } + // Default is installed as a service, so don't automatically restart + //execPath, err := os.Executable() + //if err != nil { + // return fmt.Errorf("failed to get current executable path: %v", err) + //} - _, err = os.StartProcess(execPath, os.Args, &os.ProcAttr{ - Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}, - }) - if err != nil { - return fmt.Errorf("重新启动程序失败: %v", err) - } - fmt.Printf("成功更新到版本 %s\n", latest.Version) - fmt.Printf("发布说明:\n%s\n", latest.ReleaseNotes) + // _, err = os.StartProcess(execPath, os.Args, &os.ProcAttr{ + // Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}, + // }) + // if err != nil { + // return fmt.Errorf("failed to restart program: %v", err) + // } + fmt.Printf("Successfully updated to version %s\n", latest.Version) os.Exit(0) return nil }