diff --git a/cmd/autodiscovery.go b/cmd/autodiscovery.go index 0403999..7a46a33 100644 --- a/cmd/autodiscovery.go +++ b/cmd/autodiscovery.go @@ -120,6 +120,12 @@ func registerWithAutoDiscovery() error { // 设置请求头 req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", flags.AutoDiscoveryKey)) + + // 添加Cloudflare Access头部 + if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" { + req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID) + req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) + } // 发送请求 client := &http.Client{} diff --git a/cmd/flags/flag.go b/cmd/flags/flag.go index 14d2a43..0768eb1 100644 --- a/cmd/flags/flag.go +++ b/cmd/flags/flag.go @@ -16,4 +16,6 @@ var ( ExcludeNics string IncludeMountpoints string MonthRotate int + CFAccessClientID string + CFAccessClientSecret string ) diff --git a/cmd/root.go b/cmd/root.go index 8efcb77..34539c2 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -92,5 +92,7 @@ func init() { RootCmd.PersistentFlags().StringVar(&flags.ExcludeNics, "exclude-nics", "", "Comma-separated list of network interfaces to exclude") RootCmd.PersistentFlags().StringVar(&flags.IncludeMountpoints, "include-mountpoint", "", "Semicolon-separated list of mount points to include for disk statistics") RootCmd.PersistentFlags().IntVar(&flags.MonthRotate, "month-rotate", 0, "Month reset for network statistics (0 to disable)") + RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID") + RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret") RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true } diff --git a/server/basicInfo.go b/server/basicInfo.go index dec25a6..c9b027c 100644 --- a/server/basicInfo.go +++ b/server/basicInfo.go @@ -79,6 +79,12 @@ func tryUploadData(data map[string]interface{}) error { return err } req.Header.Set("Content-Type", "application/json") + + // 添加Cloudflare Access头部 + if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" { + req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID) + req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) + } client := &http.Client{} resp, err := client.Do(req) diff --git a/server/task.go b/server/task.go index 415cb15..86a02d2 100644 --- a/server/task.go +++ b/server/task.go @@ -70,12 +70,27 @@ func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time) jsonData, _ := json.Marshal(payload) endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token - resp, _ := http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) + // 创建HTTP请求以支持自定义头部 + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData)) + if err != nil { + log.Printf("Failed to create task result request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + + // 添加Cloudflare Access头部(如果配置了) + if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" { + req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID) + req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) + } + + client := &http.Client{} + resp, err := client.Do(req) maxRetry := flags.MaxRetries - for i := 0; i < maxRetry && resp.StatusCode != http.StatusOK; i++ { + for i := 0; i < maxRetry && (err != nil || resp.StatusCode != http.StatusOK); i++ { log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry) time.Sleep(2 * time.Second) // Wait before retrying - resp, _ = http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) + resp, err = client.Do(req) } if resp != nil { defer resp.Body.Close() diff --git a/server/websocket.go b/server/websocket.go index f79411d..8db0479 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "log" + "net/http" "strings" "time" @@ -92,7 +93,15 @@ func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) { dialer := &websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, resp, err := dialer.Dial(websocketEndpoint, nil) + + // 创建请求头并添加Cloudflare Access头部 + headers := http.Header{} + if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" { + headers.Set("CF-Access-Client-Id", flags.CFAccessClientID) + headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) + } + + conn, resp, err := dialer.Dial(websocketEndpoint, headers) if err != nil { if resp != nil && resp.StatusCode != 101 { return nil, fmt.Errorf("%s", resp.Status) @@ -153,7 +162,15 @@ func establishTerminalConnection(token, id, endpoint string) { dialer := &websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, _, err := dialer.Dial(endpoint, nil) + + // 创建请求头并添加Cloudflare Access头部 + headers := http.Header{} + if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" { + headers.Set("CF-Access-Client-Id", flags.CFAccessClientID) + headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret) + } + + conn, _, err := dialer.Dial(endpoint, headers) if err != nil { log.Println("Failed to establish terminal connection:", err) return