feat: 添加Cloudflare Access支持,允许通过请求头传递Client ID和Client Secret

This commit is contained in:
xdf
2025-08-20 16:32:53 +09:00
parent aa461f2189
commit fb9828378b
6 changed files with 53 additions and 5 deletions

View File

@@ -120,6 +120,12 @@ func registerWithAutoDiscovery() error {
// 设置请求头 // 设置请求头
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", flags.AutoDiscoveryKey)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", flags.AutoDiscoveryKey))
// 添加Cloudflare Access头部
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
}
// 发送请求 // 发送请求
client := &http.Client{} client := &http.Client{}

View File

@@ -16,4 +16,6 @@ var (
ExcludeNics string ExcludeNics string
IncludeMountpoints string IncludeMountpoints string
MonthRotate int MonthRotate int
CFAccessClientID string
CFAccessClientSecret string
) )

View File

@@ -92,5 +92,7 @@ func init() {
RootCmd.PersistentFlags().StringVar(&flags.ExcludeNics, "exclude-nics", "", "Comma-separated list of network interfaces to exclude") RootCmd.PersistentFlags().StringVar(&flags.ExcludeNics, "exclude-nics", "", "Comma-separated list of network interfaces to exclude")
RootCmd.PersistentFlags().StringVar(&flags.IncludeMountpoints, "include-mountpoint", "", "Semicolon-separated list of mount points to include for disk statistics") RootCmd.PersistentFlags().StringVar(&flags.IncludeMountpoints, "include-mountpoint", "", "Semicolon-separated list of mount points to include for disk statistics")
RootCmd.PersistentFlags().IntVar(&flags.MonthRotate, "month-rotate", 0, "Month reset for network statistics (0 to disable)") RootCmd.PersistentFlags().IntVar(&flags.MonthRotate, "month-rotate", 0, "Month reset for network statistics (0 to disable)")
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID")
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret")
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
} }

View File

@@ -79,6 +79,12 @@ func tryUploadData(data map[string]interface{}) error {
return err return err
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
// 添加Cloudflare Access头部
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
}
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)

View File

@@ -70,12 +70,27 @@ func uploadTaskResult(taskID, result string, exitCode int, finishedAt time.Time)
jsonData, _ := json.Marshal(payload) jsonData, _ := json.Marshal(payload)
endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token endpoint := flags.Endpoint + "/api/clients/task/result?token=" + flags.Token
resp, _ := http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) // 创建HTTP请求以支持自定义头部
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
log.Printf("Failed to create task result request: %v", err)
return
}
req.Header.Set("Content-Type", "application/json")
// 添加Cloudflare Access头部如果配置了
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
req.Header.Set("CF-Access-Client-Id", flags.CFAccessClientID)
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
}
client := &http.Client{}
resp, err := client.Do(req)
maxRetry := flags.MaxRetries maxRetry := flags.MaxRetries
for i := 0; i < maxRetry && resp.StatusCode != http.StatusOK; i++ { for i := 0; i < maxRetry && (err != nil || resp.StatusCode != http.StatusOK); i++ {
log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry) log.Printf("Failed to upload task result, retrying %d/%d", i+1, maxRetry)
time.Sleep(2 * time.Second) // Wait before retrying time.Sleep(2 * time.Second) // Wait before retrying
resp, _ = http.Post(endpoint, "application/json", bytes.NewBuffer(jsonData)) resp, err = client.Do(req)
} }
if resp != nil { if resp != nil {
defer resp.Body.Close() defer resp.Body.Close()

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"net/http"
"strings" "strings"
"time" "time"
@@ -92,7 +93,15 @@ func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) {
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
conn, resp, err := dialer.Dial(websocketEndpoint, nil)
// 创建请求头并添加Cloudflare Access头部
headers := http.Header{}
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
headers.Set("CF-Access-Client-Id", flags.CFAccessClientID)
headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
}
conn, resp, err := dialer.Dial(websocketEndpoint, headers)
if err != nil { if err != nil {
if resp != nil && resp.StatusCode != 101 { if resp != nil && resp.StatusCode != 101 {
return nil, fmt.Errorf("%s", resp.Status) return nil, fmt.Errorf("%s", resp.Status)
@@ -153,7 +162,15 @@ func establishTerminalConnection(token, id, endpoint string) {
dialer := &websocket.Dialer{ dialer := &websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
conn, _, err := dialer.Dial(endpoint, nil)
// 创建请求头并添加Cloudflare Access头部
headers := http.Header{}
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
headers.Set("CF-Access-Client-Id", flags.CFAccessClientID)
headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
}
conn, _, err := dialer.Dial(endpoint, headers)
if err != nil { if err != nil {
log.Println("Failed to establish terminal connection:", err) log.Println("Failed to establish terminal connection:", err)
return return