diff --git a/cmd/flags/flag.go b/cmd/flags/flag.go index fc00d88..bf2a34d 100644 --- a/cmd/flags/flag.go +++ b/cmd/flags/flag.go @@ -11,4 +11,6 @@ var ( MaxRetries int ReconnectInterval int InfoReportInterval int + IncludeNics string + ExcludeNics string ) diff --git a/cmd/root.go b/cmd/root.go index 8ede6df..a1e904e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -68,5 +68,7 @@ func init() { RootCmd.PersistentFlags().IntVarP(&flags.MaxRetries, "max-retries", "r", 3, "Maximum number of retries") RootCmd.PersistentFlags().IntVarP(&flags.ReconnectInterval, "reconnect-interval", "c", 5, "Reconnect interval in seconds") RootCmd.PersistentFlags().IntVar(&flags.InfoReportInterval, "info-report-interval", 5, "Interval in minutes for reporting basic info") + RootCmd.PersistentFlags().StringVar(&flags.IncludeNics, "include-nics", "", "Comma-separated list of network interfaces to include") + RootCmd.PersistentFlags().StringVar(&flags.ExcludeNics, "exclude-nics", "", "Comma-separated list of network interfaces to exclude") RootCmd.PersistentFlags().ParseErrorsWhitelist.UnknownFlags = true } diff --git a/monitoring/unit/net.go b/monitoring/unit/net.go index c3193cb..d605b00 100644 --- a/monitoring/unit/net.go +++ b/monitoring/unit/net.go @@ -2,8 +2,10 @@ package monitoring import ( "fmt" + "strings" "time" + "github.com/komari-monitor/komari-agent/cmd/flags" "github.com/shirou/gopsutil/net" ) @@ -31,8 +33,11 @@ var ( ) func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { + includeNics := parseNics(flags.IncludeNics) + excludeNics := parseNics(flags.ExcludeNics) + // 获取第一次网络IO计数器 - ioCounters1, err := net.IOCounters(false) + ioCounters1, err := net.IOCounters(true) if err != nil { return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) } @@ -44,19 +49,17 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { // 统计第一次所有非回环接口的流量 var totalUp1, totalDown1 uint64 for _, interfaceStats := range ioCounters1 { - // 使用映射表进行O(1)查找 - if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { - continue // 跳过回环接口 + if shouldInclude(interfaceStats.Name, includeNics, excludeNics) { + totalUp1 += interfaceStats.BytesSent + totalDown1 += interfaceStats.BytesRecv } - totalUp1 += interfaceStats.BytesSent - totalDown1 += interfaceStats.BytesRecv } // 等待1秒 time.Sleep(time.Second) // 获取第二次网络IO计数器 - ioCounters2, err := net.IOCounters(false) + ioCounters2, err := net.IOCounters(true) if err != nil { return 0, 0, 0, 0, fmt.Errorf("failed to get network IO counters: %w", err) } @@ -68,11 +71,10 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { // 统计第二次所有非回环接口的流量 var totalUp2, totalDown2 uint64 for _, interfaceStats := range ioCounters2 { - if _, isLoopback := loopbackNames[interfaceStats.Name]; isLoopback { - continue // 跳过回环接口 + if shouldInclude(interfaceStats.Name, includeNics, excludeNics) { + totalUp2 += interfaceStats.BytesSent + totalDown2 += interfaceStats.BytesRecv } - totalUp2 += interfaceStats.BytesSent - totalDown2 += interfaceStats.BytesRecv } // 计算速度 (每秒的速率) @@ -81,3 +83,36 @@ func NetworkSpeed() (totalUp, totalDown, upSpeed, downSpeed uint64, err error) { return totalUp2, totalDown2, upSpeed, downSpeed, nil } + +func parseNics(nics string) map[string]struct{} { + if nics == "" { + return nil + } + nicSet := make(map[string]struct{}) + for _, nic := range strings.Split(nics, ",") { + nicSet[strings.TrimSpace(nic)] = struct{}{} + } + return nicSet +} + +func shouldInclude(nicName string, includeNics, excludeNics map[string]struct{}) bool { + // 默认排除回环接口 + if _, isLoopback := loopbackNames[nicName]; isLoopback { + return false + } + + // 如果定义了白名单,则只包括白名单中的接口 + if len(includeNics) > 0 { + _, ok := includeNics[nicName] + return ok + } + + // 如果定义了黑名单,则排除黑名单中的接口 + if len(excludeNics) > 0 { + if _, ok := excludeNics[nicName]; ok { + return false + } + } + + return true +}