mirror of
https://github.com/fankes/komari-agent.git
synced 2025-10-18 18:49:23 +08:00
Compare commits
17 Commits
b59d54fe66
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
7a8a951f9b | ||
|
0dcdb89bb5 | ||
|
8e31514e9c | ||
|
a78756f101 | ||
|
bc40bdc04d | ||
|
c85230da2e | ||
|
cb7a2c09d2 | ||
|
e9e935cb96 | ||
|
39450ef39a | ||
|
868c576d7a | ||
|
e4d3703d3b | ||
|
9da02f615f | ||
|
6bdf718dc0 | ||
|
518f782185 | ||
|
396fe5cfc2 | ||
|
ad3c02c22c | ||
|
5304a68d5d |
76
.github/workflows/release-from-commits.yml
vendored
Normal file
76
.github/workflows/release-from-commits.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
name: Release notes from commits
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
generate-notes:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # fetch all tags/commits for range and logs
|
||||
|
||||
- name: Compute previous tag
|
||||
id: prev
|
||||
env:
|
||||
CURRENT_TAG: ${{ github.event.release.tag_name }}
|
||||
shell: bash
|
||||
run: |
|
||||
set -e
|
||||
git fetch --tags --force
|
||||
if [ -z "$CURRENT_TAG" ]; then
|
||||
echo "No current tag found from release payload" 1>&2
|
||||
exit 1
|
||||
fi
|
||||
# Sort by version, pick the latest tag that is not the current tag
|
||||
PREV_TAG=$(git tag --sort=-v:refname | grep -v "^${CURRENT_TAG}$" | head -n 1 || true)
|
||||
echo "Previous tag: $PREV_TAG"
|
||||
echo "prev_tag=$PREV_TAG" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Generate release notes from commits
|
||||
id: gen
|
||||
env:
|
||||
CURRENT_TAG: ${{ github.event.release.tag_name }}
|
||||
PREV_TAG: ${{ steps.prev.outputs.prev_tag }}
|
||||
REPO: ${{ github.repository }}
|
||||
shell: bash
|
||||
run: |
|
||||
set -e
|
||||
if [ -n "$PREV_TAG" ]; then
|
||||
RANGE="$PREV_TAG..$CURRENT_TAG"
|
||||
echo "# Release $CURRENT_TAG" > RELEASE_NOTES.md
|
||||
echo >> RELEASE_NOTES.md
|
||||
echo "[Full Changelog](https://github.com/$REPO/compare/$PREV_TAG...$CURRENT_TAG)" >> RELEASE_NOTES.md
|
||||
echo >> RELEASE_NOTES.md
|
||||
else
|
||||
RANGE="$CURRENT_TAG"
|
||||
echo "# Release $CURRENT_TAG" > RELEASE_NOTES.md
|
||||
echo >> RELEASE_NOTES.md
|
||||
fi
|
||||
|
||||
echo "## Changes" >> RELEASE_NOTES.md
|
||||
echo >> RELEASE_NOTES.md
|
||||
# Use commit subjects; skip merge commits for clarity
|
||||
git log --no-merges --pretty=format:'- %s (%h) by %an' $RANGE >> RELEASE_NOTES.md || true
|
||||
|
||||
echo "Generated notes:" && echo "-----" && cat RELEASE_NOTES.md && echo "-----"
|
||||
|
||||
- name: Update GitHub Release body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const notes = fs.readFileSync('RELEASE_NOTES.md', 'utf8');
|
||||
const releaseId = context.payload.release.id;
|
||||
await github.rest.repos.updateRelease({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
release_id: releaseId,
|
||||
body: notes,
|
||||
});
|
72
build_all.ps1
Normal file
72
build_all.ps1
Normal file
@@ -0,0 +1,72 @@
|
||||
# Requires: PowerShell 5+, Go toolchain, git (optional for version)
|
||||
|
||||
# Colors
|
||||
$Red = 'Red'
|
||||
$Green = 'Green'
|
||||
$White = 'White'
|
||||
|
||||
# OS/ARCH matrix
|
||||
$osList = @('windows','linux','darwin','freebsd')
|
||||
$archList = @('amd64','arm64','386','arm')
|
||||
|
||||
# Ensure build directory
|
||||
$buildDir = Join-Path -Path (Get-Location) -ChildPath 'build'
|
||||
New-Item -ItemType Directory -Force -Path $buildDir | Out-Null
|
||||
|
||||
# Detect version from git tags or fallback to dev
|
||||
$version = (git describe --tags --abbrev=0 2>$null)
|
||||
if (-not $version) { $version = 'dev' }
|
||||
$version = $version.Trim()
|
||||
|
||||
# Check go exists
|
||||
if (-not (Get-Command go -ErrorAction SilentlyContinue)) {
|
||||
Write-Host 'Go toolchain not found in PATH. Please install Go and try again.' -ForegroundColor $Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
$failedBuilds = @()
|
||||
|
||||
foreach ($goos in $osList) {
|
||||
foreach ($goarch in $archList) {
|
||||
# Skip unsupported combos: windows/arm, darwin/386, darwin/arm
|
||||
if ((($goos -eq 'windows') -and ($goarch -eq 'arm')) -or
|
||||
(($goos -eq 'darwin') -and (($goarch -eq '386') -or ($goarch -eq 'arm')))) {
|
||||
continue
|
||||
}
|
||||
|
||||
Write-Host "Building for $goos/$goarch..." -ForegroundColor $White
|
||||
|
||||
$binaryName = "komari-agent-$goos-$goarch"
|
||||
if ($goos -eq 'windows') { $binaryName = "$binaryName.exe" }
|
||||
$outPath = Join-Path $buildDir $binaryName
|
||||
|
||||
# Set env per invocation
|
||||
$env:GOOS = $goos
|
||||
$env:GOARCH = $goarch
|
||||
$env:CGO_ENABLED = '0'
|
||||
|
||||
& go build -trimpath -ldflags "-s -w -X github.com/komari-monitor/komari-agent/update.CurrentVersion=$version" -o "$outPath"
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "Failed to build for $goos/$goarch" -ForegroundColor $Red
|
||||
$failedBuilds += "$goos/$goarch"
|
||||
}
|
||||
else {
|
||||
Write-Host "Successfully built $binaryName" -ForegroundColor $Green
|
||||
}
|
||||
|
||||
# Clear env to avoid affecting subsequent shells (optional)
|
||||
Remove-Item Env:GOOS -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:GOARCH -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:CGO_ENABLED -ErrorAction SilentlyContinue
|
||||
}
|
||||
}
|
||||
|
||||
if ($failedBuilds.Count -gt 0) {
|
||||
Write-Host "`nThe following builds failed:" -ForegroundColor $Red
|
||||
foreach ($b in $failedBuilds) { Write-Host "- $b" -ForegroundColor $Red }
|
||||
}
|
||||
else {
|
||||
Write-Host "`nAll builds completed successfully." -ForegroundColor $Green
|
||||
}
|
||||
|
||||
Write-Host "`nBinaries are in the ./build directory." -ForegroundColor $White
|
@@ -20,4 +20,6 @@ var (
|
||||
CFAccessClientSecret string
|
||||
MemoryIncludeCache bool
|
||||
CustomDNS string
|
||||
EnableGPU bool // 启用详细GPU监控
|
||||
ShowWarning bool // Windows 上显示安全警告,作为子进程运行一次
|
||||
)
|
||||
|
22
cmd/root.go
22
cmd/root.go
@@ -19,18 +19,26 @@ var RootCmd = &cobra.Command{
|
||||
Short: "komari agent",
|
||||
Long: `komari agent`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
||||
if flags.ShowWarning {
|
||||
ShowToast()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if !flags.DisableWebSsh {
|
||||
go WarnKomariRunning()
|
||||
}
|
||||
|
||||
log.Println("Komari Agent", update.CurrentVersion)
|
||||
log.Println("Github Repo:", update.Repo)
|
||||
|
||||
// 设置自定义DNS解析器
|
||||
// 设置 DNS 解析行为
|
||||
if flags.CustomDNS != "" {
|
||||
dnsresolver.SetCustomDNSServer(flags.CustomDNS)
|
||||
log.Printf("Using custom DNS server: %s", flags.CustomDNS)
|
||||
} else {
|
||||
log.Printf("Using default DNS servers, primary: %s (failover servers available)", dnsresolver.DNSServers[0])
|
||||
if len(dnsresolver.DNSServers) > 1 {
|
||||
log.Printf("Available failover DNS servers: %v", dnsresolver.DNSServers[1:])
|
||||
}
|
||||
// 未设置则使用系统默认 DNS(不使用内置列表)
|
||||
log.Printf("Using system default DNS resolver")
|
||||
}
|
||||
|
||||
// Auto discovery
|
||||
@@ -113,6 +121,8 @@ func init() {
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientID, "cf-access-client-id", "", "Cloudflare Access Client ID")
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CFAccessClientSecret, "cf-access-client-secret", "", "Cloudflare Access Client Secret")
|
||||
RootCmd.PersistentFlags().BoolVar(&flags.MemoryIncludeCache, "memory-include-cache", false, "Include cache/buffer in memory usage")
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CustomDNS, "custom-dns", "", "Custom DNS server to use (e.g. 8.8.8.8, 114.114.114.114). By default, the program will use multiple built-in DNS servers with failover support.")
|
||||
RootCmd.PersistentFlags().StringVar(&flags.CustomDNS, "custom-dns", "", "Custom DNS server to use (e.g. 8.8.8.8, 114.114.114.114). By default, the program uses the system DNS resolver.")
|
||||
RootCmd.PersistentFlags().BoolVar(&flags.EnableGPU, "gpu", false, "Enable detailed GPU monitoring (usage, memory, multi-GPU support)")
|
||||
RootCmd.PersistentFlags().BoolVar(&flags.ShowWarning, "show-warning", false, "Show security warning on Windows, run once as a subprocess")
|
||||
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
|
||||
}
|
||||
|
13
cmd/warn.go
Normal file
13
cmd/warn.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
func WarnKomariRunning() {
|
||||
// No-op on non-Windows platforms
|
||||
return
|
||||
}
|
||||
|
||||
func ShowToast() {
|
||||
// No-op on non-Windows platforms
|
||||
return
|
||||
}
|
408
cmd/warn_windows.go
Normal file
408
cmd/warn_windows.go
Normal file
@@ -0,0 +1,408 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
toast "gopkg.in/toast.v1"
|
||||
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// WarnKomariRunning
|
||||
// 作为 SYSTEM(Session 0)运行时:
|
||||
// 1) 轮询已登录的交互会话(WTSActive)
|
||||
// 2) 对新检测到的会话,以该用户身份在其会话内启动当前进程(追加 --show-warning 参数)
|
||||
// 3) 用户态子进程会进入 ShowToast() 分支并发送 Toast
|
||||
func WarnKomariRunning() {
|
||||
|
||||
// 启用权限
|
||||
if err := enablePrivileges([]string{"SeAssignPrimaryTokenPrivilege", "SeIncreaseQuotaPrivilege"}); err != nil {
|
||||
log.Printf("[warn] enabling privileges failed: %v", err)
|
||||
}
|
||||
|
||||
seen := map[uint32]struct{}{}
|
||||
var mu sync.Mutex
|
||||
|
||||
sessions := []uint32{}
|
||||
for _, sid := range sessions {
|
||||
seen[sid] = struct{}{}
|
||||
}
|
||||
|
||||
// 轮询新登录
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
current, err := enumerateActiveSessions()
|
||||
if err != nil {
|
||||
log.Printf("[warn] enumerateActiveSessions error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 将 current 列表转换为集合,便于清理旧会话
|
||||
currentSet := make(map[uint32]struct{}, len(current))
|
||||
for _, sid := range current {
|
||||
currentSet[sid] = struct{}{}
|
||||
}
|
||||
|
||||
// 找到新出现的会话 -> 在该会话启动进程
|
||||
for _, sid := range current {
|
||||
mu.Lock()
|
||||
_, known := seen[sid]
|
||||
if !known {
|
||||
seen[sid] = struct{}{}
|
||||
mu.Unlock()
|
||||
if err := launchSelfInSession(sid, []string{"--show-warning"}); err != nil {
|
||||
log.Printf("[warn] launch in session %d failed: %v", sid, err)
|
||||
} else {
|
||||
log.Printf("[info] launched toast helper in session %d", sid)
|
||||
}
|
||||
} else {
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// 清理不再存在的会话,避免 map 膨胀
|
||||
mu.Lock()
|
||||
for sid := range seen {
|
||||
if _, ok := currentSet[sid]; !ok {
|
||||
delete(seen, sid)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ShowToast 在用户态中执行
|
||||
func ShowToast() {
|
||||
title := "Komari is Running"
|
||||
message := "The remote control software \"Komari\" is running, which allows others to control your computer. If this was not initiated by you, please terminate the program immediately."
|
||||
|
||||
const aumid = "Komari.Monitor.Agent"
|
||||
const linkName = "Komari Warning (Auto Delete Later)"
|
||||
|
||||
if err := ensureStartMenuShortcut(aumid, linkName); err != nil {
|
||||
log.Printf("[warn] ensureStartMenuShortcut failed: %v", err)
|
||||
}
|
||||
|
||||
n := toast.Notification{
|
||||
AppID: aumid,
|
||||
Title: title,
|
||||
Message: message,
|
||||
Actions: []toast.Action{
|
||||
{Type: "protocol", Label: "Help", Arguments: "https://komari-document.pages.dev/faq/uninstall.html"},
|
||||
},
|
||||
}
|
||||
if err := n.Push(); err != nil {
|
||||
log.Printf("[warn] toast push failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待 15 秒后删除快捷方式
|
||||
shortcutPath := getStartMenuShortcutPath(linkName)
|
||||
time.Sleep(15 * time.Second)
|
||||
if err := os.Remove(shortcutPath); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.Printf("[warn] remove shortcut failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureStartMenuShortcut 使用 WScript.Shell 创建 .lnk 并设置 AppUserModelID
|
||||
func ensureStartMenuShortcut(aumid, linkName string) error {
|
||||
programs := getStartMenuProgramsDir()
|
||||
if err := os.MkdirAll(programs, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
shortcutPath := filepath.Join(programs, sanitizeFileName(linkName)+".lnk")
|
||||
if _, err := os.Stat(shortcutPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hr := ole.CoInitializeEx(0, ole.COINIT_APARTMENTTHREADED); hr != nil {
|
||||
// S_OK (0) 或 S_FALSE (1) 都视为成功;go-ole 将非零 HRESULT 作为 error 返回
|
||||
// 当返回错误时,我们再进行 Uninitialize 保护即可
|
||||
// 这里直接继续执行,由后续操作决定是否可用
|
||||
}
|
||||
defer ole.CoUninitialize()
|
||||
|
||||
unknown, err := oleutil.CreateObject("WScript.Shell")
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateObject WScript.Shell: %w", err)
|
||||
}
|
||||
defer unknown.Release()
|
||||
|
||||
shell, err := unknown.QueryInterface(ole.IID_IDispatch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("QueryInterface IDispatch: %w", err)
|
||||
}
|
||||
defer shell.Release()
|
||||
|
||||
cs, err := oleutil.CallMethod(shell, "CreateShortcut", shortcutPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateShortcut: %w", err)
|
||||
}
|
||||
shortcut := cs.ToIDispatch()
|
||||
defer shortcut.Release()
|
||||
|
||||
exePath, _ := os.Executable()
|
||||
exeDir := filepath.Dir(exePath)
|
||||
|
||||
if _, err = oleutil.PutProperty(shortcut, "TargetPath", exePath); err != nil {
|
||||
return fmt.Errorf("set TargetPath: %w", err)
|
||||
}
|
||||
if _, err = oleutil.PutProperty(shortcut, "WorkingDirectory", exeDir); err != nil {
|
||||
return fmt.Errorf("set WorkingDirectory: %w", err)
|
||||
}
|
||||
_, _ = oleutil.PutProperty(shortcut, "Description", "Komari Agent")
|
||||
// 设置 AUMID
|
||||
if _, err = oleutil.PutProperty(shortcut, "AppUserModelID", aumid); err != nil {
|
||||
// 某些系统该属性不存在时,依然尝试保存;Toast 可能仍然显示
|
||||
log.Printf("[warn] set AppUserModelID failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err = oleutil.CallMethod(shortcut, "Save"); err != nil {
|
||||
return fmt.Errorf("shortcut.Save: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回当前用户开始菜单 Programs 目录
|
||||
func getStartMenuProgramsDir() string {
|
||||
return filepath.Join(os.Getenv("APPDATA"), "Microsoft", "Windows", "Start Menu", "Programs")
|
||||
}
|
||||
|
||||
// 获取快捷方式完整路径
|
||||
func getStartMenuShortcutPath(linkName string) string {
|
||||
return filepath.Join(getStartMenuProgramsDir(), sanitizeFileName(linkName)+".lnk")
|
||||
}
|
||||
|
||||
func sanitizeFileName(name string) string {
|
||||
replacer := strings.NewReplacer("\\", "_", "/", "_", ":", "_", "*", "_", "?", "_", "\"", "_", "<", "_", ">", "_", "|", "_")
|
||||
return replacer.Replace(name)
|
||||
}
|
||||
|
||||
// enumerateActiveSessions 列出当前处于交互活动状态的会话 ID
|
||||
func enumerateActiveSessions() ([]uint32, error) {
|
||||
type wtsSessionInfo struct {
|
||||
SessionID uint32
|
||||
WinStation *uint16
|
||||
State uint32
|
||||
}
|
||||
|
||||
wtsapi := windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procEnum := wtsapi.NewProc("WTSEnumerateSessionsW")
|
||||
procFree := wtsapi.NewProc("WTSFreeMemory")
|
||||
|
||||
var (
|
||||
server windows.Handle // WTS_CURRENT_SERVER_HANDLE == 0
|
||||
pinfo *wtsSessionInfo
|
||||
count uint32
|
||||
version uint32 = 1
|
||||
)
|
||||
r1, _, err := procEnum.Call(
|
||||
uintptr(server),
|
||||
0,
|
||||
uintptr(version),
|
||||
uintptr(unsafe.Pointer(&pinfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r1 == 0 {
|
||||
return nil, fmt.Errorf("WTSEnumerateSessionsW: %w", err)
|
||||
}
|
||||
defer procFree.Call(uintptr(unsafe.Pointer(pinfo)))
|
||||
|
||||
// WTS_CONNECTSTATE_CLASS
|
||||
const WTSActive = 0
|
||||
|
||||
// 遍历结构数组
|
||||
res := make([]uint32, 0, count)
|
||||
infos := unsafe.Slice(pinfo, int(count))
|
||||
for i := 0; i < len(infos); i++ {
|
||||
info := &infos[i]
|
||||
if info.State == WTSActive {
|
||||
if hasUserName(info.SessionID) {
|
||||
res = append(res, info.SessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// hasUserName 检查会话是否有用户名(避免将空会话当作登录)
|
||||
func hasUserName(sessionID uint32) bool {
|
||||
const WTSUserName = 5
|
||||
wtsapi := windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procQuery := wtsapi.NewProc("WTSQuerySessionInformationW")
|
||||
procFree := wtsapi.NewProc("WTSFreeMemory")
|
||||
|
||||
var buf *uint16
|
||||
var blen uint32
|
||||
r1, _, _ := procQuery.Call(0, uintptr(sessionID), uintptr(WTSUserName), uintptr(unsafe.Pointer(&buf)), uintptr(unsafe.Pointer(&blen)))
|
||||
if r1 == 0 || buf == nil {
|
||||
return false
|
||||
}
|
||||
defer procFree.Call(uintptr(unsafe.Pointer(buf)))
|
||||
name := windows.UTF16PtrToString(buf)
|
||||
return strings.TrimSpace(name) != ""
|
||||
}
|
||||
|
||||
// launchSelfInSession 在指定会话中以该用户身份启动当前进程并追加 args
|
||||
func launchSelfInSession(sessionID uint32, extraArgs []string) error {
|
||||
// 获取用户令牌
|
||||
userToken, err := queryUserToken(sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("queryUserToken: %w", err)
|
||||
}
|
||||
defer userToken.Close()
|
||||
|
||||
primary, err := duplicateTokenPrimary(userToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("duplicateToken: %w", err)
|
||||
}
|
||||
defer primary.Close()
|
||||
|
||||
exePath, _ := os.Executable()
|
||||
// 仅保留进程名,去掉已有的 --show-warning,避免递归
|
||||
baseArgs := filterArgs(os.Args[1:], "--show-warning")
|
||||
fullArgs := append([]string{quoteIfNeeded(exePath)}, baseArgs...)
|
||||
fullArgs = append(fullArgs, extraArgs...)
|
||||
cmdlineStr := strings.Join(fullArgs, " ")
|
||||
cmdline, err := windows.UTF16PtrFromString(cmdlineStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16PtrFromString: %w", err)
|
||||
}
|
||||
|
||||
env, err := createEnvironmentBlock(primary)
|
||||
if err != nil {
|
||||
return fmt.Errorf("createEnvironmentBlock: %w", err)
|
||||
}
|
||||
defer destroyEnvironmentBlock(env)
|
||||
|
||||
var si windows.StartupInfo
|
||||
si.Cb = uint32(unsafe.Sizeof(si))
|
||||
si.Flags = 0
|
||||
si.ShowWindow = 0
|
||||
// 指定桌面,确保窗口可见
|
||||
desktop, _ := windows.UTF16PtrFromString("winsta0\\default")
|
||||
si.Desktop = desktop
|
||||
|
||||
var pi windows.ProcessInformation
|
||||
// CREATE_UNICODE_ENVIRONMENT | DETACHED_PROCESS
|
||||
const CREATE_UNICODE_ENVIRONMENT = 0x00000400
|
||||
const DETACHED_PROCESS = 0x00000008
|
||||
|
||||
err = windows.CreateProcessAsUser(primary, nil, cmdline, nil, nil, false, CREATE_UNICODE_ENVIRONMENT|DETACHED_PROCESS, env, nil, &si, &pi)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
windows.CloseHandle(pi.Thread)
|
||||
windows.CloseHandle(pi.Process)
|
||||
return nil
|
||||
}
|
||||
|
||||
// enablePrivileges 尝试启用一组权限
|
||||
func enablePrivileges(names []string) error {
|
||||
var errs []string
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||
return err
|
||||
}
|
||||
defer token.Close()
|
||||
for _, name := range names {
|
||||
if err := setPrivilege(token, name, true); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("%s: %v", name, err))
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setPrivilege(token windows.Token, privName string, enable bool) error {
|
||||
var luid windows.LUID
|
||||
nameUTF16, _ := windows.UTF16PtrFromString(privName)
|
||||
if err := windows.LookupPrivilegeValue(nil, nameUTF16, &luid); err != nil {
|
||||
return err
|
||||
}
|
||||
tp := windows.Tokenprivileges{
|
||||
PrivilegeCount: 1,
|
||||
Privileges: [1]windows.LUIDAndAttributes{
|
||||
{Luid: luid, Attributes: 0},
|
||||
},
|
||||
}
|
||||
if enable {
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
}
|
||||
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
// queryUserToken 调用 WTSQueryUserToken 获取指定会话的用户令牌(模拟令牌)
|
||||
func queryUserToken(sessionID uint32) (windows.Token, error) {
|
||||
wtsapi := windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
proc := wtsapi.NewProc("WTSQueryUserToken")
|
||||
var h windows.Handle
|
||||
r1, _, err := proc.Call(uintptr(sessionID), uintptr(unsafe.Pointer(&h)))
|
||||
if r1 == 0 {
|
||||
return 0, fmt.Errorf("WTSQueryUserToken: %w", err)
|
||||
}
|
||||
return windows.Token(h), nil
|
||||
}
|
||||
|
||||
// duplicateTokenPrimary 将模拟令牌复制为主令牌,以供 CreateProcessAsUser 使用
|
||||
func duplicateTokenPrimary(token windows.Token) (windows.Token, error) {
|
||||
var primary windows.Token
|
||||
err := windows.DuplicateTokenEx(token, windows.TOKEN_ALL_ACCESS, nil, windows.SecurityIdentification, windows.TokenPrimary, &primary)
|
||||
return primary, err
|
||||
}
|
||||
|
||||
// createEnvironmentBlock 为用户令牌创建环境块
|
||||
func createEnvironmentBlock(token windows.Token) (*uint16, error) {
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
proc := userenv.NewProc("CreateEnvironmentBlock")
|
||||
var env *uint16
|
||||
r1, _, err := proc.Call(uintptr(unsafe.Pointer(&env)), uintptr(token), 0)
|
||||
if r1 == 0 {
|
||||
return nil, fmt.Errorf("CreateEnvironmentBlock: %w", err)
|
||||
}
|
||||
return env, nil
|
||||
}
|
||||
|
||||
func destroyEnvironmentBlock(env *uint16) {
|
||||
if env == nil {
|
||||
return
|
||||
}
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
proc := userenv.NewProc("DestroyEnvironmentBlock")
|
||||
_, _, _ = proc.Call(uintptr(unsafe.Pointer(env)))
|
||||
}
|
||||
|
||||
func quoteIfNeeded(s string) string {
|
||||
if strings.ContainsAny(s, " \t\"") {
|
||||
return "\"" + strings.ReplaceAll(s, "\"", "\\\"") + "\""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func filterArgs(args []string, drop string) []string {
|
||||
out := make([]string, 0, len(args))
|
||||
for i := 0; i < len(args); i++ {
|
||||
if args[i] == drop {
|
||||
continue
|
||||
}
|
||||
out = append(out, args[i])
|
||||
}
|
||||
return out
|
||||
}
|
@@ -2,37 +2,67 @@ package dnsresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
)
|
||||
|
||||
var (
|
||||
// DNS服务器列表,按优先级排序
|
||||
DNSServers = []string{
|
||||
"[2606:4700:4700::1111]:53", // Cloudflare IPv6
|
||||
"[2606:4700:4700::1001]:53", // Cloudflare IPv6 备用
|
||||
"[2001:4860:4860::8888]:53", // Google IPv6
|
||||
"[2001:4860:4860::8844]:53", // Google IPv6 备用
|
||||
|
||||
"114.114.114.114:53", // 114DNS,中国大陆
|
||||
"8.8.8.8:53", // Google DNS,全球
|
||||
"8.8.4.4:53", // Google DNS备用,全球
|
||||
"1.1.1.1:53", // Cloudflare DNS,全球
|
||||
"1.1.1.1:53", // Cloudflare IPv4
|
||||
"8.8.8.8:53", // Google IPv4
|
||||
"8.8.4.4:53", // Google IPv4 备用
|
||||
"223.5.5.5:53", // 阿里DNS,中国大陆
|
||||
"119.29.29.29:53", // DNSPod,中国大陆
|
||||
}
|
||||
|
||||
// CustomDNSServer 自定义DNS服务器,可以通过命令行参数设置
|
||||
CustomDNSServer string
|
||||
|
||||
preferV4Once sync.Once
|
||||
hasIPv4 bool
|
||||
)
|
||||
|
||||
// SetCustomDNSServer 设置自定义DNS服务器
|
||||
func SetCustomDNSServer(dnsServer string) {
|
||||
if dnsServer != "" {
|
||||
// 检查是否已包含端口,如果没有则添加默认端口53
|
||||
if !strings.Contains(dnsServer, ":") {
|
||||
dnsServer = dnsServer + ":53"
|
||||
}
|
||||
CustomDNSServer = dnsServer
|
||||
if dnsServer == "" {
|
||||
return
|
||||
}
|
||||
CustomDNSServer = normalizeDNSServer(dnsServer)
|
||||
}
|
||||
|
||||
// normalizeDNSServer 将输入的 DNS 服务器字符串规范化为 host:port 形式:
|
||||
// - IPv6 地址自动加方括号并补全端口 :53(若未提供)
|
||||
// - IPv4/域名未提供端口时补全 :53
|
||||
func normalizeDNSServer(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
// 已是 [ipv6]:port 或 host:port 形式
|
||||
if (strings.HasPrefix(s, "[") && strings.Contains(s, "]:")) || (strings.Count(s, ":") == 1 && !strings.Contains(s, "]")) {
|
||||
return s
|
||||
}
|
||||
// 纯 IPv6(未加端口/括号)
|
||||
if strings.Count(s, ":") >= 2 && !strings.Contains(s, "]") {
|
||||
return "[" + s + "]:53"
|
||||
}
|
||||
// 其它情况:若未包含端口则补 53
|
||||
if !strings.Contains(s, ":") {
|
||||
return s + ":53"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// getCurrentDNSServer 获取当前要使用的DNS服务器
|
||||
@@ -40,80 +70,107 @@ func getCurrentDNSServer() string {
|
||||
if CustomDNSServer != "" {
|
||||
return CustomDNSServer
|
||||
}
|
||||
// 如果没有设置自定义DNS,返回默认的第一个
|
||||
return DNSServers[0]
|
||||
// 如果没有设置自定义DNS,返回空字符串,表示应使用系统默认解析器
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCustomResolver 返回一个使用指定DNS服务器的解析器
|
||||
// GetCustomResolver 返回一个解析器:
|
||||
// - 若设置了自定义 DNS:使用该服务器(并在失败时尝试内置列表作为兜底)。
|
||||
// - 若未设置自定义 DNS:返回系统默认解析器(不使用内置列表)。
|
||||
func GetCustomResolver() *net.Resolver {
|
||||
// 未设置自定义 DNS,直接使用系统默认解析器
|
||||
if getCurrentDNSServer() == "" {
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
// 设置了自定义 DNS,则构造使用自定义 DNS 的解析器
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
d := net.Dialer{Timeout: 10 * time.Second}
|
||||
|
||||
// 尝试自定义DNS或默认DNS
|
||||
// 优先使用自定义 DNS 服务器
|
||||
dnsServer := getCurrentDNSServer()
|
||||
conn, err := d.DialContext(ctx, "udp", dnsServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
if dnsServer != "" {
|
||||
if conn, err := d.DialContext(ctx, "udp", dnsServer); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果连接失败,尝试其他DNS服务器
|
||||
log.Printf("Custom DNS server %s is unreachable, trying fallback servers", dnsServer)
|
||||
// 如果自定义DNS不可用,则尝试内置列表作为兜底
|
||||
for _, server := range DNSServers {
|
||||
if server != dnsServer { // 避免重复尝试
|
||||
conn, err := d.DialContext(ctx, "udp", server)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if server == dnsServer {
|
||||
continue
|
||||
}
|
||||
if conn, err := d.DialContext(ctx, "udp", server); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 所有DNS服务器都失败,返回最后一次的错误
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("no available DNS server")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetHTTPClient 返回一个使用自定义DNS解析器的HTTP客户端
|
||||
func GetHTTPClient(timeout time.Duration) *http.Client {
|
||||
// buildTransport 构建带有自定义解析/拨号策略的 HTTP 传输层,可注入 TLS 配置
|
||||
func buildTransport(timeout time.Duration, tlsConfig *tls.Config) *http.Transport {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
customResolver := GetCustomResolver()
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips, err := customResolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 根据本机是否具备 IPv4 动态排序
|
||||
preferIPv4 := preferIPv4First()
|
||||
sort.SliceStable(ips, func(i, j int) bool {
|
||||
ip1 := net.ParseIP(ips[i])
|
||||
ip2 := net.ParseIP(ips[j])
|
||||
if ip1 == nil || ip2 == nil {
|
||||
return false
|
||||
}
|
||||
ips, err := customResolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if preferIPv4 {
|
||||
return ip1.To4() != nil && ip2.To4() == nil
|
||||
}
|
||||
for _, ip := range ips {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
// IPv6 优先
|
||||
return ip1.To4() == nil && ip2.To4() != nil
|
||||
})
|
||||
for _, ip := range ips {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
return nil, fmt.Errorf("failed to dial to any of the resolved IPs")
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to dial to any of the resolved IPs")
|
||||
},
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
}
|
||||
|
||||
func GetHTTPClient(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: buildTransport(timeout, &tls.Config{
|
||||
InsecureSkipVerify: flags.IgnoreUnsafeCert,
|
||||
}),
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
@@ -130,3 +187,91 @@ func GetNetDialer(timeout time.Duration) *net.Dialer {
|
||||
Resolver: GetCustomResolver(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetDialContext 返回一个自定义 DialContext:
|
||||
// - 使用自定义解析器解析主机名
|
||||
// - 优先尝试 IPv4,再尝试 IPv6
|
||||
// - 逐个 IP 进行连接尝试,直到成功或全部失败
|
||||
func GetDialContext(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 15 * time.Second
|
||||
}
|
||||
|
||||
resolver := GetCustomResolver()
|
||||
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 为解析设置一个带超时的子 context,避免整体拨号过快超时
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := resolver.LookupHost(lookupCtx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 根据本机是否具备 IPv4 动态排序
|
||||
preferIPv4 := preferIPv4First()
|
||||
sort.SliceStable(ips, func(i, j int) bool {
|
||||
ip1 := net.ParseIP(ips[i])
|
||||
ip2 := net.ParseIP(ips[j])
|
||||
if ip1 == nil || ip2 == nil {
|
||||
return false
|
||||
}
|
||||
if preferIPv4 {
|
||||
return ip1.To4() != nil && ip2.To4() == nil
|
||||
}
|
||||
// IPv6 优先
|
||||
return ip1.To4() == nil && ip2.To4() != nil
|
||||
})
|
||||
|
||||
// 逐个 IP 尝试连接
|
||||
for _, ip := range ips {
|
||||
d := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
c, err := d.DialContext(ctx, network, net.JoinHostPort(ip, port))
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to dial to any of the resolved IPs")
|
||||
}
|
||||
}
|
||||
|
||||
// preferIPv4First 检测本机是否存在可用的 IPv4 地址,若没有则在连接尝试中优先 IPv6
|
||||
func preferIPv4First() bool {
|
||||
preferV4Once.Do(func() {
|
||||
ifaces, _ := net.Interfaces()
|
||||
for _, iface := range ifaces {
|
||||
if (iface.Flags&net.FlagUp) == 0 || (iface.Flags&net.FlagLoopback) != 0 {
|
||||
continue
|
||||
}
|
||||
addrs, _ := iface.Addrs()
|
||||
for _, a := range addrs {
|
||||
var ip net.IP
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
if ip == nil || ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
hasIPv4 = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
hasIPv4 = false
|
||||
})
|
||||
return hasIPv4
|
||||
}
|
||||
|
4
go.mod
4
go.mod
@@ -6,6 +6,7 @@ require (
|
||||
github.com/UserExistsError/conpty v0.1.4
|
||||
github.com/blang/semver v3.5.1+incompatible
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/go-ole/go-ole v1.2.6
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/klauspost/cpuid/v2 v2.3.0
|
||||
github.com/prometheus-community/pro-bing v0.7.0
|
||||
@@ -13,17 +14,18 @@ require (
|
||||
github.com/shirou/gopsutil/v4 v4.25.6
|
||||
github.com/spf13/cobra v1.9.1
|
||||
golang.org/x/sys v0.33.0
|
||||
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/google/go-github/v30 v30.1.0 // indirect
|
||||
github.com/google/go-querystring v1.0.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/inconshreveable/go-update v0.0.0-20160112193335-8152e7eb6ccf // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/tcnksm/go-gitconfig v0.1.2 // indirect
|
||||
|
4
go.sum
4
go.sum
@@ -37,6 +37,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d h1:VhgPp6v9qf9Agr/56bj7Y/xa04UccTW04VP0Qed4vnQ=
|
||||
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d/go.mod h1:YUTz3bUH2ZwIWBy3CJBeOBEugqcmXREj14T+iG/4k4U=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I=
|
||||
github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
@@ -104,6 +106,8 @@ google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2 h1:MZF6J7CV6s/h0HBkfqebrYfKCVEo5iN+wzE4QhV3Evo=
|
||||
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2/go.mod h1:s1Sn2yZos05Qfs7NKt867Xe18emOmtsO3eAKbDaon0o=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
149
install.sh
149
install.sh
@@ -389,8 +389,94 @@ log_success "Komari-agent installed to ${GREEN}$komari_agent_path${NC}"
|
||||
# Detect init system and configure service
|
||||
log_step "Configuring system service..."
|
||||
|
||||
# Check if running on NixOS
|
||||
if [ -f /etc/NIXOS ]; then
|
||||
# Function to detect actual init system
|
||||
detect_init_system() {
|
||||
# Check if running on NixOS (special case)
|
||||
if [ -f /etc/NIXOS ]; then
|
||||
echo "nixos"
|
||||
return
|
||||
fi
|
||||
|
||||
# Alpine Linux MUST be checked first
|
||||
# Alpine always uses OpenRC, even in containers where PID 1 might be different
|
||||
if [ -f /etc/alpine-release ]; then
|
||||
if command -v rc-service >/dev/null 2>&1 || [ -f /sbin/openrc-run ]; then
|
||||
echo "openrc"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Get PID 1 process for other detection
|
||||
local pid1_process=$(ps -p 1 -o comm= 2>/dev/null | tr -d ' ')
|
||||
|
||||
# If PID 1 is systemd, use systemd
|
||||
if [ "$pid1_process" = "systemd" ] || [ -d /run/systemd/system ]; then
|
||||
if command -v systemctl >/dev/null 2>&1; then
|
||||
# Additional verification that systemd is actually functioning
|
||||
if systemctl list-units >/dev/null 2>&1; then
|
||||
echo "systemd"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check for Gentoo OpenRC (PID 1 is openrc-init)
|
||||
if [ "$pid1_process" = "openrc-init" ]; then
|
||||
if command -v rc-service >/dev/null 2>&1; then
|
||||
echo "openrc"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check for other OpenRC systems (not Alpine, already handled)
|
||||
# Some systems use traditional init with OpenRC
|
||||
if [ "$pid1_process" = "init" ] && [ ! -f /etc/alpine-release ]; then
|
||||
# Check if OpenRC is actually managing services
|
||||
if [ -d /run/openrc ] && command -v rc-service >/dev/null 2>&1; then
|
||||
echo "openrc"
|
||||
return
|
||||
fi
|
||||
# Check for OpenRC files
|
||||
if [ -f /sbin/openrc ] && command -v rc-service >/dev/null 2>&1; then
|
||||
echo "openrc"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check for OpenWrt's procd
|
||||
if command -v uci >/dev/null 2>&1 && [ -f /etc/rc.common ]; then
|
||||
echo "procd"
|
||||
return
|
||||
fi
|
||||
|
||||
# Check for macOS launchd
|
||||
if [ "$os_name" = "darwin" ] && command -v launchctl >/dev/null 2>&1; then
|
||||
echo "launchd"
|
||||
return
|
||||
fi
|
||||
|
||||
# Fallback: if systemctl exists and appears functional, assume systemd
|
||||
if command -v systemctl >/dev/null 2>&1; then
|
||||
if systemctl list-units >/dev/null 2>&1; then
|
||||
echo "systemd"
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Last resort: check for OpenRC without other indicators
|
||||
if command -v rc-service >/dev/null 2>&1 && [ -d /etc/init.d ]; then
|
||||
echo "openrc"
|
||||
return
|
||||
fi
|
||||
|
||||
echo "unknown"
|
||||
}
|
||||
|
||||
init_system=$(detect_init_system)
|
||||
log_info "Detected init system: ${GREEN}$init_system${NC}"
|
||||
|
||||
# Handle each init system
|
||||
if [ "$init_system" = "nixos" ]; then
|
||||
log_warning "NixOS detected. System services must be configured declaratively."
|
||||
log_info "Please add the following to your NixOS configuration:"
|
||||
echo ""
|
||||
@@ -409,32 +495,7 @@ if [ -f /etc/NIXOS ]; then
|
||||
echo ""
|
||||
log_info "Then run: sudo nixos-rebuild switch"
|
||||
log_warning "Service not started automatically on NixOS. Please rebuild your configuration."
|
||||
elif command -v systemctl >/dev/null 2>&1; then
|
||||
# Systemd service configuration
|
||||
log_info "Using systemd for service management"
|
||||
service_file="/etc/systemd/system/${service_name}.service"
|
||||
cat > "$service_file" << EOF
|
||||
[Unit]
|
||||
Description=Komari Agent Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=${komari_agent_path} ${komari_args}
|
||||
WorkingDirectory=${target_dir}
|
||||
Restart=always
|
||||
User=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# Reload systemd and start service
|
||||
systemctl daemon-reload
|
||||
systemctl enable ${service_name}.service
|
||||
systemctl start ${service_name}.service
|
||||
log_success "Systemd service configured and started"
|
||||
elif command -v rc-service >/dev/null 2>&1; then
|
||||
elif [ "$init_system" = "openrc" ]; then
|
||||
# OpenRC service configuration
|
||||
log_info "Using OpenRC for service management"
|
||||
service_file="/etc/init.d/${service_name}"
|
||||
@@ -462,7 +523,32 @@ EOF
|
||||
rc-update add ${service_name} default
|
||||
rc-service ${service_name} start
|
||||
log_success "OpenRC service configured and started"
|
||||
elif command -v uci >/dev/null 2>&1; then
|
||||
elif [ "$init_system" = "systemd" ]; then
|
||||
# Systemd service configuration
|
||||
log_info "Using systemd for service management"
|
||||
service_file="/etc/systemd/system/${service_name}.service"
|
||||
cat > "$service_file" << EOF
|
||||
[Unit]
|
||||
Description=Komari Agent Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=${komari_agent_path} ${komari_args}
|
||||
WorkingDirectory=${target_dir}
|
||||
Restart=always
|
||||
User=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# Reload systemd and start service
|
||||
systemctl daemon-reload
|
||||
systemctl enable ${service_name}.service
|
||||
systemctl start ${service_name}.service
|
||||
log_success "Systemd service configured and started"
|
||||
elif [ "$init_system" = "procd" ]; then
|
||||
# procd service configuration (OpenWrt)
|
||||
log_info "Using procd for service management"
|
||||
service_file="/etc/init.d/${service_name}"
|
||||
@@ -502,7 +588,7 @@ EOF
|
||||
/etc/init.d/${service_name} enable
|
||||
/etc/init.d/${service_name} start
|
||||
log_success "procd service configured and started"
|
||||
elif [ "$os_name" = "darwin" ] && command -v launchctl >/dev/null 2>&1; then
|
||||
elif [ "$init_system" = "launchd" ]; then
|
||||
# macOS launchd service configuration
|
||||
log_info "Using launchd for service management"
|
||||
|
||||
@@ -581,7 +667,8 @@ EOF
|
||||
fi
|
||||
fi
|
||||
else
|
||||
log_error "Unsupported init system (systemd, openrc, procd, or launchd not found)"
|
||||
log_error "Unsupported or unknown init system detected: $init_system"
|
||||
log_error "Supported init systems: systemd, openrc, procd, launchd"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
monitoring "github.com/komari-monitor/komari-agent/monitoring/unit"
|
||||
)
|
||||
|
||||
@@ -74,6 +75,46 @@ func GenerateReport() []byte {
|
||||
processcount := monitoring.ProcessCount()
|
||||
data["process"] = processcount
|
||||
|
||||
// GPU监控 - 根据标志决定详细程度
|
||||
if flags.EnableGPU {
|
||||
// 详细GPU监控模式
|
||||
gpuInfo, err := monitoring.GetDetailedGPUInfo()
|
||||
if err != nil {
|
||||
message += fmt.Sprintf("failed to get detailed GPU info: %v\n", err)
|
||||
// 降级到基础GPU信息
|
||||
gpuNames, nameErr := monitoring.GetDetailedGPUHost()
|
||||
if nameErr == nil && len(gpuNames) > 0 {
|
||||
data["gpu"] = map[string]interface{}{
|
||||
"models": gpuNames,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 成功获取详细信息
|
||||
gpuData := make([]map[string]interface{}, len(gpuInfo))
|
||||
totalGPUUsage := 0.0
|
||||
|
||||
for i, info := range gpuInfo {
|
||||
gpuData[i] = map[string]interface{}{
|
||||
"name": info.Name,
|
||||
"memory_total": info.MemoryTotal,
|
||||
"memory_used": info.MemoryUsed,
|
||||
"utilization": info.Utilization,
|
||||
"temperature": info.Temperature,
|
||||
}
|
||||
totalGPUUsage += info.Utilization
|
||||
}
|
||||
|
||||
avgGPUUsage := totalGPUUsage / float64(len(gpuInfo))
|
||||
|
||||
data["gpu"] = map[string]interface{}{
|
||||
"count": len(gpuInfo),
|
||||
"average_usage": avgGPUUsage,
|
||||
"detailed_info": gpuData,
|
||||
}
|
||||
}
|
||||
}
|
||||
// 基础模式下,GPU信息已在basicInfo中处理
|
||||
|
||||
data["message"] = message
|
||||
|
||||
s, err := json.Marshal(data)
|
||||
|
@@ -77,6 +77,7 @@ func isPhysicalDisk(part disk.PartitionStat) bool {
|
||||
"/dev/mqueue",
|
||||
"/etc/resolv.conf",
|
||||
"/etc/host", // /etc/hosts,/etc/hostname
|
||||
"/dev/hugepages",
|
||||
}
|
||||
for _, mp := range mountpointsToExclude {
|
||||
if mountpoint == mp || strings.HasPrefix(mountpoint, mp) {
|
||||
@@ -100,6 +101,7 @@ func isPhysicalDisk(part disk.PartitionStat) bool {
|
||||
"sysfs",
|
||||
"cgroup",
|
||||
"mqueue",
|
||||
"hugetlbfs",
|
||||
}
|
||||
for _, fs := range fstypeToExclude {
|
||||
if fstype == fs || strings.HasPrefix(fstype, fs) {
|
||||
|
246
monitoring/unit/gpu_amd_rocm_smi.go
Normal file
246
monitoring/unit/gpu_amd_rocm_smi.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package monitoring
|
||||
|
||||
// Modified from https://github.com/influxdata/telegraf/blob/master/plugins/inputs/amd_rocm_smi/amd_rocm_smi.go
|
||||
// Original License: MIT
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ROCmSMI struct {
|
||||
BinPath string
|
||||
data []byte
|
||||
}
|
||||
|
||||
// AMDGPUInfo AMD GPU详细信息
|
||||
type AMDGPUInfo struct {
|
||||
Name string // GPU型号
|
||||
MemoryTotal uint64 // 总显存 (字节)
|
||||
MemoryUsed uint64 // 已用显存 (字节)
|
||||
Utilization float64 // GPU使用率 (0-100)
|
||||
Temperature uint64 // 温度 (摄氏度)
|
||||
}
|
||||
|
||||
// ROCmSMI JSON响应结构
|
||||
type ROCmResponse map[string]ROCmGPUInfo
|
||||
|
||||
type ROCmGPUInfo struct {
|
||||
CardSeries string `json:"Card series"`
|
||||
GPUUsage string `json:"GPU use (%)"`
|
||||
VRAMTotalMemory string `json:"VRAM Total Memory (B)"`
|
||||
VRAMTotalUsedMemory string `json:"VRAM Total Used Memory (B)"`
|
||||
TemperatureJunction string `json:"Temperature (Sensor junction) (C)"`
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) GatherModel() ([]string, error) {
|
||||
return rsmi.gatherModel()
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) GatherUsage() ([]float64, error) {
|
||||
return rsmi.gatherUsage()
|
||||
}
|
||||
|
||||
// GatherDetailedInfo 获取详细GPU信息
|
||||
func (rsmi *ROCmSMI) GatherDetailedInfo() ([]AMDGPUInfo, error) {
|
||||
return rsmi.gatherDetailedInfo()
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) Start() error {
|
||||
if _, err := os.Stat(rsmi.BinPath); os.IsNotExist(err) {
|
||||
binPath, err := exec.LookPath("rocm-smi")
|
||||
if err != nil {
|
||||
return errors.New("rocm-smi tool not found")
|
||||
}
|
||||
rsmi.BinPath = binPath
|
||||
}
|
||||
|
||||
rsmi.data = rsmi.pollROCmSMI()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) pollROCmSMI() []byte {
|
||||
cmd := exec.Command(rsmi.BinPath, "--showallinfo", "--json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) gatherModel() ([]string, error) {
|
||||
var data map[string]interface{}
|
||||
var models []string
|
||||
|
||||
if err := json.Unmarshal(rsmi.data, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析JSON结构获取GPU型号
|
||||
for key, value := range data {
|
||||
if strings.HasPrefix(key, "card") {
|
||||
if cardData, ok := value.(map[string]interface{}); ok {
|
||||
if name, exists := cardData["Card series"]; exists {
|
||||
if nameStr, ok := name.(string); ok && nameStr != "" {
|
||||
models = append(models, nameStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) gatherUsage() ([]float64, error) {
|
||||
var data map[string]interface{}
|
||||
var usageList []float64
|
||||
|
||||
if err := json.Unmarshal(rsmi.data, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析JSON结构获取GPU使用率
|
||||
for key, value := range data {
|
||||
if strings.HasPrefix(key, "card") {
|
||||
if cardData, ok := value.(map[string]interface{}); ok {
|
||||
usage := 0.0
|
||||
if utilizationData, exists := cardData["GPU use (%)"]; exists {
|
||||
if utilizationStr, ok := utilizationData.(string); ok {
|
||||
if parsed, err := parseAMDPercentage(utilizationStr); err == nil {
|
||||
usage = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
usageList = append(usageList, usage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return usageList, nil
|
||||
}
|
||||
|
||||
func (rsmi *ROCmSMI) gatherDetailedInfo() ([]AMDGPUInfo, error) {
|
||||
if rsmi.data == nil {
|
||||
return nil, errors.New("no data available")
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
var gpuInfos []AMDGPUInfo
|
||||
|
||||
if err := json.Unmarshal(rsmi.data, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析每个GPU卡的详细信息
|
||||
for key, value := range data {
|
||||
if strings.HasPrefix(key, "card") {
|
||||
if cardData, ok := value.(map[string]interface{}); ok {
|
||||
gpuInfo := AMDGPUInfo{}
|
||||
|
||||
// 获取GPU名称
|
||||
if name, exists := cardData["Card series"]; exists {
|
||||
if nameStr, ok := name.(string); ok {
|
||||
gpuInfo.Name = nameStr
|
||||
}
|
||||
}
|
||||
|
||||
// 获取使用率
|
||||
if utilizationData, exists := cardData["GPU use (%)"]; exists {
|
||||
if utilizationStr, ok := utilizationData.(string); ok {
|
||||
if usage, err := parseAMDPercentage(utilizationStr); err == nil {
|
||||
gpuInfo.Utilization = usage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取显存信息
|
||||
if memUsedData, exists := cardData["VRAM Total Used Memory (B)"]; exists {
|
||||
if memUsedStr, ok := memUsedData.(string); ok {
|
||||
if memUsed, err := parseAMDMemoryBytes(memUsedStr); err == nil {
|
||||
gpuInfo.MemoryUsed = memUsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if memTotalData, exists := cardData["VRAM Total Memory (B)"]; exists {
|
||||
if memTotalStr, ok := memTotalData.(string); ok {
|
||||
if memTotal, err := parseAMDMemoryBytes(memTotalStr); err == nil {
|
||||
gpuInfo.MemoryTotal = memTotal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取温度信息
|
||||
if tempData, exists := cardData["Temperature (Sensor junction) (C)"]; exists {
|
||||
if tempStr, ok := tempData.(string); ok {
|
||||
if temp, err := parseAMDTemperature(tempStr); err == nil {
|
||||
gpuInfo.Temperature = temp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gpuInfos = append(gpuInfos, gpuInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return gpuInfos, nil
|
||||
}
|
||||
|
||||
// 解析AMD百分比值 (例如 "25" -> 25.0)
|
||||
func parseAMDPercentage(value string) (float64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
cleaned = strings.TrimSuffix(cleaned, "%")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0.0, nil
|
||||
}
|
||||
|
||||
result, err := strconv.ParseFloat(cleaned, 64)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 解析AMD显存字节 (例如 "1073741824" -> 1073741824字节)
|
||||
func parseAMDMemoryBytes(value string) (uint64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
bytes, err := strconv.ParseUint(cleaned, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 直接返回字节数
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
// 解析AMD温度值 (例如 "65" -> 65)
|
||||
func parseAMDTemperature(value string) (uint64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
cleaned = strings.TrimSuffix(cleaned, "C")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result, err := strconv.ParseUint(cleaned, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
31
monitoring/unit/gpu_detailed_fallback.go
Normal file
31
monitoring/unit/gpu_detailed_fallback.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build !linux
|
||||
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// DetailedGPUInfo 详细GPU信息结构体
|
||||
type DetailedGPUInfo struct {
|
||||
Name string `json:"name"` // GPU型号
|
||||
MemoryTotal uint64 `json:"memory_total"` // 总显存 (字节)
|
||||
MemoryUsed uint64 `json:"memory_used"` // 已用显存 (字节)
|
||||
Utilization float64 `json:"utilization"` // GPU使用率 (0-100)
|
||||
Temperature uint64 `json:"temperature"` // 温度 (摄氏度)
|
||||
}
|
||||
|
||||
// GetDetailedGPUHost 获取GPU型号信息 - 回退实现
|
||||
func GetDetailedGPUHost() ([]string, error) {
|
||||
return nil, errors.New("detailed GPU monitoring not supported on this platform")
|
||||
}
|
||||
|
||||
// GetDetailedGPUState 获取GPU使用率 - 回退实现
|
||||
func GetDetailedGPUState() ([]float64, error) {
|
||||
return nil, errors.New("detailed GPU monitoring not supported on this platform")
|
||||
}
|
||||
|
||||
// GetDetailedGPUInfo 获取详细GPU信息 - 回退实现
|
||||
func GetDetailedGPUInfo() ([]DetailedGPUInfo, error) {
|
||||
return nil, errors.New("detailed GPU monitoring not supported on this platform")
|
||||
}
|
213
monitoring/unit/gpu_detailed_linux.go
Normal file
213
monitoring/unit/gpu_detailed_linux.go
Normal file
@@ -0,0 +1,213 @@
|
||||
//go:build linux
|
||||
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
vendorAMD = iota + 1
|
||||
vendorNVIDIA
|
||||
)
|
||||
|
||||
var vendorType = getDetailedVendor()
|
||||
|
||||
// DetailedGPUInfo 详细GPU信息结构体
|
||||
type DetailedGPUInfo struct {
|
||||
Name string `json:"name"` // GPU型号
|
||||
MemoryTotal uint64 `json:"memory_total"` // 总显存 (字节)
|
||||
MemoryUsed uint64 `json:"memory_used"` // 已用显存 (字节)
|
||||
Utilization float64 `json:"utilization"` // GPU使用率 (0-100)
|
||||
Temperature uint64 `json:"temperature"` // 温度 (摄氏度)
|
||||
}
|
||||
|
||||
func getDetailedVendor() uint8 {
|
||||
_, err := getNvidiaDetailedStat()
|
||||
if err != nil {
|
||||
return vendorAMD
|
||||
} else {
|
||||
return vendorNVIDIA
|
||||
}
|
||||
}
|
||||
|
||||
func getNvidiaDetailedStat() ([]float64, error) {
|
||||
smi := &NvidiaSMI{
|
||||
BinPath: "/usr/bin/nvidia-smi",
|
||||
}
|
||||
err1 := smi.Start()
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
data, err2 := smi.GatherUsage()
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func getAMDDetailedStat() ([]float64, error) {
|
||||
rsmi := &ROCmSMI{
|
||||
BinPath: "/opt/rocm/bin/rocm-smi",
|
||||
}
|
||||
err := rsmi.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := rsmi.GatherUsage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func getNvidiaDetailedHost() ([]string, error) {
|
||||
smi := &NvidiaSMI{
|
||||
BinPath: "/usr/bin/nvidia-smi",
|
||||
}
|
||||
err := smi.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := smi.GatherModel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func getAMDDetailedHost() ([]string, error) {
|
||||
rsmi := &ROCmSMI{
|
||||
BinPath: "/opt/rocm/bin/rocm-smi",
|
||||
}
|
||||
err := rsmi.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := rsmi.GatherModel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// GetDetailedGPUHost 获取GPU型号信息
|
||||
func GetDetailedGPUHost() ([]string, error) {
|
||||
var gi []string
|
||||
var err error
|
||||
|
||||
switch vendorType {
|
||||
case vendorAMD:
|
||||
gi, err = getAMDDetailedHost()
|
||||
case vendorNVIDIA:
|
||||
gi, err = getNvidiaDetailedHost()
|
||||
default:
|
||||
return nil, errors.New("invalid vendor")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gi, nil
|
||||
}
|
||||
|
||||
// GetDetailedGPUState 获取GPU使用率
|
||||
func GetDetailedGPUState() ([]float64, error) {
|
||||
var gs []float64
|
||||
var err error
|
||||
|
||||
switch vendorType {
|
||||
case vendorAMD:
|
||||
gs, err = getAMDDetailedStat()
|
||||
case vendorNVIDIA:
|
||||
gs, err = getNvidiaDetailedStat()
|
||||
default:
|
||||
return nil, errors.New("invalid vendor")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gs, nil
|
||||
}
|
||||
|
||||
// GetDetailedGPUInfo 获取详细GPU信息
|
||||
func GetDetailedGPUInfo() ([]DetailedGPUInfo, error) {
|
||||
var gpuInfos []DetailedGPUInfo
|
||||
var err error
|
||||
|
||||
switch vendorType {
|
||||
case vendorAMD:
|
||||
gpuInfos, err = getAMDDetailedInfo()
|
||||
case vendorNVIDIA:
|
||||
gpuInfos, err = getNvidiaDetailedInfo()
|
||||
default:
|
||||
return nil, errors.New("invalid vendor")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gpuInfos, nil
|
||||
}
|
||||
|
||||
func getNvidiaDetailedInfo() ([]DetailedGPUInfo, error) {
|
||||
smi := &NvidiaSMI{
|
||||
BinPath: "/usr/bin/nvidia-smi",
|
||||
}
|
||||
err := smi.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := smi.GatherDetailedInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gpuInfos []DetailedGPUInfo
|
||||
for _, nvidiaInfo := range data {
|
||||
gpuInfo := DetailedGPUInfo{
|
||||
Name: nvidiaInfo.Name,
|
||||
MemoryTotal: nvidiaInfo.MemoryTotal,
|
||||
MemoryUsed: nvidiaInfo.MemoryUsed,
|
||||
Utilization: nvidiaInfo.Utilization,
|
||||
Temperature: nvidiaInfo.Temperature,
|
||||
}
|
||||
gpuInfos = append(gpuInfos, gpuInfo)
|
||||
}
|
||||
|
||||
return gpuInfos, nil
|
||||
}
|
||||
|
||||
func getAMDDetailedInfo() ([]DetailedGPUInfo, error) {
|
||||
rsmi := &ROCmSMI{
|
||||
BinPath: "/opt/rocm/bin/rocm-smi",
|
||||
}
|
||||
err := rsmi.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := rsmi.GatherDetailedInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var gpuInfos []DetailedGPUInfo
|
||||
for _, amdInfo := range data {
|
||||
gpuInfo := DetailedGPUInfo{
|
||||
Name: amdInfo.Name,
|
||||
MemoryTotal: amdInfo.MemoryTotal,
|
||||
MemoryUsed: amdInfo.MemoryUsed,
|
||||
Utilization: amdInfo.Utilization,
|
||||
Temperature: amdInfo.Temperature,
|
||||
}
|
||||
gpuInfos = append(gpuInfos, gpuInfo)
|
||||
}
|
||||
|
||||
return gpuInfos, nil
|
||||
}
|
67
monitoring/unit/gpu_detailed_test.go
Normal file
67
monitoring/unit/gpu_detailed_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDetailedGPUDetection(t *testing.T) {
|
||||
models, err := GetDetailedGPUHost()
|
||||
if err != nil {
|
||||
t.Logf("Detailed GPU detection failed (may be normal on non-Linux or non-GPU systems): %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Detected GPUs: %v", models)
|
||||
|
||||
if len(models) > 0 {
|
||||
usage, err := GetDetailedGPUState()
|
||||
if err != nil {
|
||||
t.Logf("GPU state collection failed: %v", err)
|
||||
} else {
|
||||
t.Logf("GPU usage: %v", usage)
|
||||
}
|
||||
|
||||
// 测试详细信息获取
|
||||
detailedInfo, err := GetDetailedGPUInfo()
|
||||
if err != nil {
|
||||
t.Logf("GPU detailed info collection failed: %v", err)
|
||||
} else {
|
||||
for i, info := range detailedInfo {
|
||||
t.Logf("GPU %d: %s - Memory: %dMB/%dMB, Usage: %.1f%%, Temp: %d°C",
|
||||
i, info.Name, info.MemoryUsed, info.MemoryTotal, info.Utilization, info.Temperature)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetailedGPUInfo(t *testing.T) {
|
||||
detailedInfo, err := GetDetailedGPUInfo()
|
||||
if err != nil {
|
||||
t.Logf("GPU detailed info test failed (may be normal): %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(detailedInfo) == 0 {
|
||||
t.Log("No detailed GPU info available")
|
||||
return
|
||||
}
|
||||
|
||||
for i, info := range detailedInfo {
|
||||
t.Logf("GPU %d Details:", i)
|
||||
t.Logf(" Name: %s", info.Name)
|
||||
t.Logf(" Memory Total: %d MB", info.MemoryTotal)
|
||||
t.Logf(" Memory Used: %d MB", info.MemoryUsed)
|
||||
//t.Logf(" Memory Free: %d MB", info.MemoryFree)
|
||||
t.Logf(" Utilization: %.1f%%", info.Utilization)
|
||||
t.Logf(" Temperature: %d°C", info.Temperature)
|
||||
|
||||
// 验证数据的合理性
|
||||
//if info.MemoryTotal > 0 && info.MemoryUsed+info.MemoryFree != info.MemoryTotal {
|
||||
// t.Logf("Warning: Memory usage calculation may be inconsistent for %s", info.Name)
|
||||
//}
|
||||
|
||||
if info.Utilization < 0 || info.Utilization > 100 {
|
||||
t.Errorf("Invalid utilization value for %s: %.1f%%", info.Name, info.Utilization)
|
||||
}
|
||||
}
|
||||
}
|
@@ -9,13 +9,26 @@ import (
|
||||
)
|
||||
|
||||
func GpuName() string {
|
||||
accept := []string{"vga", "nvidia", "amd", "radeon", "render"}
|
||||
// 调整优先级:专用显卡厂商优先,避免只识别集成显卡
|
||||
accept := []string{"nvidia", "amd", "radeon", "vga", "3d"}
|
||||
out, err := exec.Command("lspci").Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(out), "\n")
|
||||
|
||||
// 首先尝试找专用显卡
|
||||
for _, line := range lines {
|
||||
lower := strings.ToLower(line)
|
||||
|
||||
// 跳过集成显卡和管理控制器
|
||||
if strings.Contains(lower, "aspeed") ||
|
||||
strings.Contains(lower, "matrox") ||
|
||||
strings.Contains(lower, "management") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 优先匹配专用显卡厂商
|
||||
for _, a := range accept {
|
||||
if strings.Contains(strings.ToLower(line), a) {
|
||||
if strings.Contains(lower, a) {
|
||||
parts := strings.SplitN(line, ":", 4)
|
||||
if len(parts) >= 4 {
|
||||
return strings.TrimSpace(parts[3])
|
||||
@@ -27,6 +40,16 @@ func GpuName() string {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到专用显卡,返回第一个VGA设备作为兜底
|
||||
for _, line := range lines {
|
||||
if strings.Contains(strings.ToLower(line), "vga") {
|
||||
parts := strings.SplitN(line, ":", 4)
|
||||
if len(parts) >= 3 {
|
||||
return strings.TrimSpace(parts[2])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "None"
|
||||
}
|
||||
|
200
monitoring/unit/gpu_nvidia_smi.go
Normal file
200
monitoring/unit/gpu_nvidia_smi.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package monitoring
|
||||
|
||||
// Modified from https://github.com/influxdata/telegraf/blob/master/plugins/inputs/nvidia_smi/nvidia_smi.go
|
||||
// Original License: MIT
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type NvidiaSMI struct {
|
||||
BinPath string
|
||||
data []byte
|
||||
}
|
||||
|
||||
// NVIDIAGPUInfo 包含详细的NVIDIA GPU信息
|
||||
type NVIDIAGPUInfo struct {
|
||||
Name string // GPU型号
|
||||
MemoryTotal uint64 // 总显存 (字节)
|
||||
MemoryUsed uint64 // 已用显存 (字节)
|
||||
Utilization float64 // GPU使用率 (0-100)
|
||||
Temperature uint64 // 温度 (摄氏度)
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) GatherModel() ([]string, error) {
|
||||
return smi.gatherModel()
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) GatherUsage() ([]float64, error) {
|
||||
return smi.gatherUsage()
|
||||
}
|
||||
|
||||
// GatherDetailedInfo 获取详细GPU信息
|
||||
func (smi *NvidiaSMI) GatherDetailedInfo() ([]NVIDIAGPUInfo, error) {
|
||||
return smi.gatherDetailedInfo()
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) Start() error {
|
||||
if _, err := os.Stat(smi.BinPath); os.IsNotExist(err) {
|
||||
binPath, err := exec.LookPath("nvidia-smi")
|
||||
if err != nil {
|
||||
return errors.New("nvidia-smi tool not found")
|
||||
}
|
||||
smi.BinPath = binPath
|
||||
}
|
||||
smi.data = smi.pollNvidiaSMI()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) pollNvidiaSMI() []byte {
|
||||
cmd := exec.Command(smi.BinPath, "-q", "-x")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) gatherModel() ([]string, error) {
|
||||
var stats nvidiaSMIXMLResult
|
||||
var models []string
|
||||
|
||||
if err := xml.Unmarshal(smi.data, &stats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, gpu := range stats.GPUs {
|
||||
if gpu.ProductName != "" {
|
||||
models = append(models, gpu.ProductName)
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) gatherUsage() ([]float64, error) {
|
||||
var stats nvidiaSMIXMLResult
|
||||
var usageList []float64
|
||||
|
||||
if err := xml.Unmarshal(smi.data, &stats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, gpu := range stats.GPUs {
|
||||
usage, err := parsePercentageValue(gpu.Utilization.GPUUtil)
|
||||
if err != nil {
|
||||
usage = 0.0 // 默认为0,不中断处理
|
||||
}
|
||||
usageList = append(usageList, usage)
|
||||
}
|
||||
|
||||
return usageList, nil
|
||||
}
|
||||
|
||||
func (smi *NvidiaSMI) gatherDetailedInfo() ([]NVIDIAGPUInfo, error) {
|
||||
var stats nvidiaSMIXMLResult
|
||||
var gpuInfos []NVIDIAGPUInfo
|
||||
|
||||
if err := xml.Unmarshal(smi.data, &stats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, gpu := range stats.GPUs {
|
||||
utilization, _ := parsePercentageValue(gpu.Utilization.GPUUtil)
|
||||
memTotal, _ := parseMemoryValue(gpu.FrameBufferMemoryUsage.Total)
|
||||
memUsed, _ := parseMemoryValue(gpu.FrameBufferMemoryUsage.Used)
|
||||
temp, _ := parseTemperatureValue(gpu.Temperature.GPUTemp)
|
||||
|
||||
gpuInfo := NVIDIAGPUInfo{
|
||||
Name: gpu.ProductName,
|
||||
MemoryTotal: memTotal,
|
||||
MemoryUsed: memUsed,
|
||||
Utilization: utilization,
|
||||
Temperature: temp,
|
||||
}
|
||||
|
||||
gpuInfos = append(gpuInfos, gpuInfo)
|
||||
}
|
||||
|
||||
return gpuInfos, nil
|
||||
}
|
||||
|
||||
// 解析百分比值 (例如 "25 %" -> 25.0)
|
||||
func parsePercentageValue(value string) (float64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
cleaned = strings.TrimSuffix(cleaned, "%")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0.0, nil
|
||||
}
|
||||
|
||||
result, err := strconv.ParseFloat(cleaned, 64)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 解析内存值 (例如 "1024 MiB" -> 1073741824字节)
|
||||
func parseMemoryValue(value string) (uint64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
cleaned = strings.TrimSuffix(cleaned, "MiB")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result, err := strconv.ParseUint(cleaned, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 转换MiB为字节 (1 MiB = 1024*1024 bytes)
|
||||
return result * 1024 * 1024, nil
|
||||
}
|
||||
|
||||
// 解析温度值 (例如 "65 C" -> 65)
|
||||
func parseTemperatureValue(value string) (uint64, error) {
|
||||
cleaned := strings.TrimSpace(value)
|
||||
cleaned = strings.TrimSuffix(cleaned, "C")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
if cleaned == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result, err := strconv.ParseUint(cleaned, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// NVIDIA-SMI XML结构定义
|
||||
type nvidiaSMIXMLResult struct {
|
||||
GPUs []nvidiaSMIGPU `xml:"gpu"`
|
||||
}
|
||||
|
||||
type nvidiaSMIGPU struct {
|
||||
ProductName string `xml:"product_name"`
|
||||
Utilization struct {
|
||||
GPUUtil string `xml:"gpu_util"`
|
||||
} `xml:"utilization"`
|
||||
FrameBufferMemoryUsage struct {
|
||||
Total string `xml:"total"`
|
||||
Used string `xml:"used"`
|
||||
Free string `xml:"free"`
|
||||
} `xml:"fb_memory_usage"`
|
||||
Temperature struct {
|
||||
GPUTemp string `xml:"gpu_temp"`
|
||||
} `xml:"temperature"`
|
||||
}
|
@@ -108,7 +108,7 @@ func GetIPv6Address() (string, error) {
|
||||
}
|
||||
|
||||
// 使用正则表达式从响应体中提取IPv6地址
|
||||
re := regexp.MustCompile(`(([0-9A-Fa-f]{1,4}:){7})([0-9A-Fa-f]{1,4})|(([0-9A-Fa-f]{1,4}:){1,6}:)(([0-9A-Fa-f]{1,4}:){0,4})([0-9A-Fa-f]{1,4})`)
|
||||
re := regexp.MustCompile(`(([0-9A-Fa-f]{1,4}:){7})([0-9A-Fa-f]{1,4})|(([0-9A-Fa-f]{1,4}:){1,6}:)(([0-9A-Fa-f]{1,4}:){0,4})([0-9A-Fa-f]{0,4})`)
|
||||
ipv6 := re.FindString(string(body))
|
||||
if ipv6 != "" {
|
||||
log.Printf("Get IPV6 Success: %s", ipv6)
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/komari-monitor/komari-agent/cmd/flags"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
)
|
||||
@@ -24,7 +26,11 @@ func Ram() RamInfo {
|
||||
return raminfo
|
||||
}
|
||||
raminfo.Total = v.Total
|
||||
raminfo.Used = v.Total - v.Available
|
||||
if runtime.GOOS == "windows" {
|
||||
raminfo.Used = v.Total - v.Available
|
||||
} else {
|
||||
raminfo.Used = v.Total - v.Free - v.Buffers - v.Cached
|
||||
}
|
||||
|
||||
return raminfo
|
||||
}
|
||||
|
@@ -87,7 +87,6 @@ func tryUploadData(data map[string]interface{}) error {
|
||||
req.Header.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
|
||||
}
|
||||
|
||||
// 使用dnsresolver获取自定义HTTP客户端
|
||||
client := dnsresolver.GetHTTPClient(30 * time.Second)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -91,20 +92,9 @@ func EstablishWebSocketConnection() {
|
||||
}
|
||||
|
||||
func connectWebSocket(websocketEndpoint string) (*ws.SafeConn, error) {
|
||||
// 使用dnsresolver获取自定义网络拨号器
|
||||
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
|
||||
dialer := newWSDialer()
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
NetDialContext: netDialer.DialContext,
|
||||
}
|
||||
|
||||
// 创建请求头并添加Cloudflare Access头部
|
||||
headers := http.Header{}
|
||||
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
|
||||
headers.Set("CF-Access-Client-Id", flags.CFAccessClientID)
|
||||
headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
|
||||
}
|
||||
headers := newWSHeaders()
|
||||
|
||||
conn, resp, err := dialer.Dial(websocketEndpoint, headers)
|
||||
if err != nil {
|
||||
@@ -165,20 +155,10 @@ func establishTerminalConnection(token, id, endpoint string) {
|
||||
endpoint = strings.TrimSuffix(endpoint, "/") + "/api/clients/terminal?token=" + token + "&id=" + id
|
||||
endpoint = "ws" + strings.TrimPrefix(endpoint, "http")
|
||||
|
||||
// 使用dnsresolver获取自定义网络拨号器
|
||||
netDialer := dnsresolver.GetNetDialer(5 * time.Second)
|
||||
// 使用与主 WS 相同的拨号策略
|
||||
dialer := newWSDialer()
|
||||
|
||||
dialer := &websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
NetDialContext: netDialer.DialContext,
|
||||
}
|
||||
|
||||
// 创建请求头并添加Cloudflare Access头部
|
||||
headers := http.Header{}
|
||||
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
|
||||
headers.Set("CF-Access-Client-Id", flags.CFAccessClientID)
|
||||
headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
|
||||
}
|
||||
headers := newWSHeaders()
|
||||
|
||||
conn, _, err := dialer.Dial(endpoint, headers)
|
||||
if err != nil {
|
||||
@@ -192,3 +172,25 @@ func establishTerminalConnection(token, id, endpoint string) {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// newWSDialer 构造统一的 WebSocket 拨号器(自定义解析、IPv4/IPv6 动态排序、可选 TLS 忽略)
|
||||
func newWSDialer() *websocket.Dialer {
|
||||
d := &websocket.Dialer{
|
||||
HandshakeTimeout: 15 * time.Second,
|
||||
NetDialContext: dnsresolver.GetDialContext(15 * time.Second),
|
||||
}
|
||||
if flags.IgnoreUnsafeCert {
|
||||
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// newWSHeaders 统一构造 WS 请求头(含 Cloudflare Access 头)
|
||||
func newWSHeaders() http.Header {
|
||||
headers := http.Header{}
|
||||
if flags.CFAccessClientID != "" && flags.CFAccessClientSecret != "" {
|
||||
headers.Set("CF-Access-Client-Id", flags.CFAccessClientID)
|
||||
headers.Set("CF-Access-Client-Secret", flags.CFAccessClientSecret)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
@@ -47,9 +47,7 @@ func CheckAndUpdate() error {
|
||||
return fmt.Errorf("failed to parse current version: %v", err)
|
||||
}
|
||||
|
||||
// 使用dnsresolver创建自定义HTTP客户端并设置为全局默认客户端
|
||||
// 这会影响所有HTTP请求,包括selfupdate库中的请求
|
||||
http.DefaultClient = dnsresolver.GetHTTPClient(60 * time.Second) // Create selfupdate configuration
|
||||
http.DefaultClient = dnsresolver.GetHTTPClient(60 * time.Second)
|
||||
config := selfupdate.Config{}
|
||||
updater, err := selfupdate.NewUpdater(config)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user