feat: 添加网络接口包含和排除选项

This commit is contained in:
Akizon77
2025-07-03 14:01:57 +08:00
parent ba8fc8c6d7
commit b18f79bf4e
3 changed files with 50 additions and 11 deletions

View File

@@ -11,4 +11,6 @@ var (
MaxRetries int MaxRetries int
ReconnectInterval int ReconnectInterval int
InfoReportInterval int InfoReportInterval int
IncludeNics string
ExcludeNics string
) )

View File

@@ -68,5 +68,7 @@ func init() {
RootCmd.PersistentFlags().IntVarP(&flags.MaxRetries, "max-retries", "r", 3, "Maximum number of retries") RootCmd.PersistentFlags().IntVarP(&flags.MaxRetries, "max-retries", "r", 3, "Maximum number of retries")
RootCmd.PersistentFlags().IntVarP(&flags.ReconnectInterval, "reconnect-interval", "c", 5, "Reconnect interval in seconds") RootCmd.PersistentFlags().IntVarP(&flags.ReconnectInterval, "reconnect-interval", "c", 5, "Reconnect interval in seconds")
RootCmd.PersistentFlags().IntVar(&flags.InfoReportInterval, "info-report-interval", 5, "Interval in minutes for reporting basic info") RootCmd.PersistentFlags().IntVar(&flags.InfoReportInterval, "info-report-interval", 5, "Interval in minutes for reporting basic info")
RootCmd.PersistentFlags().StringVar(&flags.IncludeNics, "include-nics", "", "Comma-separated list of network interfaces to include")
RootCmd.PersistentFlags().StringVar(&flags.ExcludeNics, "exclude-nics", "", "Comma-separated list of network interfaces to exclude")
RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true
} }

View File

@@ -2,8 +2,10 @@ package monitoring
import ( import (
"fmt" "fmt"
"strings"
"time" "time"
"github.com/komari-monitor/komari-agent/cmd/flags"
"github.com/shirou/gopsutil/net" "github.com/shirou/gopsutil/net"
) )
@@ -31,8 +33,11 @@ var (
) )
func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) {
includeNics := parseNics(flags.IncludeNics)
excludeNics := parseNics(flags.ExcludeNics)
// 获取第一次网络IO计数器 // 获取第一次网络IO计数器
ioCounters1, err := net.IOCounters(false) ioCounters1, err := net.IOCounters(true)
if err != nil { if err != nil {
return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err)
} }
@@ -44,19 +49,17 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) {
// 统计第一次所有非回环接口的流量 // 统计第一次所有非回环接口的流量
var totalUp1, totalDown1 uint64 var totalUp1, totalDown1 uint64
for _, interfaceStats := range ioCounters1 { for _, interfaceStats := range ioCounters1 {
// 使用映射表进行O(1)查找 if shouldInclude(interfaceStats.Name, includeNics, excludeNics) {
if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { totalUp1 += interfaceStats.BytesSent
continue // 跳过回环接口 totalDown1 += interfaceStats.BytesRecv
} }
totalUp1 += interfaceStats.BytesSent
totalDown1 += interfaceStats.BytesRecv
} }
// 等待1秒 // 等待1秒
time.Sleep(time.Second) time.Sleep(time.Second)
// 获取第二次网络IO计数器 // 获取第二次网络IO计数器
ioCounters2, err := net.IOCounters(false) ioCounters2, err := net.IOCounters(true)
if err != nil { if err != nil {
return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err)
} }
@@ -68,11 +71,10 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) {
// 统计第二次所有非回环接口的流量 // 统计第二次所有非回环接口的流量
var totalUp2, totalDown2 uint64 var totalUp2, totalDown2 uint64
for _, interfaceStats := range ioCounters2 { for _, interfaceStats := range ioCounters2 {
if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { if shouldInclude(interfaceStats.Name, includeNics, excludeNics) {
continue // 跳过回环接口 totalUp2 += interfaceStats.BytesSent
totalDown2 += interfaceStats.BytesRecv
} }
totalUp2 += interfaceStats.BytesSent
totalDown2 += interfaceStats.BytesRecv
} }
// 计算速度 (每秒的速率) // 计算速度 (每秒的速率)
@@ -81,3 +83,36 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) {
return totalUp2, totalDown2, upSpeed, downSpeed, nil return totalUp2, totalDown2, upSpeed, downSpeed, nil
} }
func parseNics(nics string) map[string]struct{} {
if nics == "" {
return nil
}
nicSet := make(map[string]struct{})
for _, nic := range strings.Split(nics, ",") {
nicSet[strings.TrimSpace(nic)] = struct{}{}
}
return nicSet
}
func shouldInclude(nicName string, includeNics, excludeNics map[string]struct{}) bool {
// 默认排除回环接口
if _, isLoopback := loopbackNames[nicName]; isLoopback {
return false
}
// 如果定义了白名单,则只包括白名单中的接口
if len(includeNics) > 0 {
_, ok := includeNics[nicName]
return ok
}
// 如果定义了黑名单,则排除黑名单中的接口
if len(excludeNics) > 0 {
if _, ok := excludeNics[nicName]; ok {
return false
}
}
return true
}