mirror of
https://github.com/fankes/beszel.git
synced 2025-10-19 17:59:28 +08:00
feat(agent): NETWORK env var and support for multiple keys
- merges agent.Run with agent.NewAgent - separates StartServer method - bumps go version to 1.24 - add tests
This commit is contained in:
@@ -8,12 +8,19 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
type cmdConfig struct {
|
||||||
// Define flags for key and port
|
key string // key is the public key(s) for SSH authentication.
|
||||||
keyFlag := flag.String("key", "", "Public key")
|
addr string // addr is the address or port to listen on.
|
||||||
portFlag := flag.String("port", "45876", "Port number")
|
}
|
||||||
|
|
||||||
|
// parseFlags parses the command line flags and populates the config struct.
|
||||||
|
func parseFlags(cfg *cmdConfig) {
|
||||||
|
flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication")
|
||||||
|
flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on")
|
||||||
|
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0])
|
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0])
|
||||||
@@ -24,14 +31,16 @@ func main() {
|
|||||||
fmt.Println(" help Display this help message")
|
fmt.Println(" help Display this help message")
|
||||||
fmt.Println(" update Update the agent to the latest version")
|
fmt.Println(" update Update the agent to the latest version")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Parse the flags
|
// handleSubcommand handles subcommands such as version, help, and update.
|
||||||
flag.Parse()
|
// It returns true if a subcommand was handled, false otherwise.
|
||||||
|
func handleSubcommand() bool {
|
||||||
// handle flags / subcommands
|
if len(os.Args) <= 1 {
|
||||||
if len(os.Args) > 1 {
|
return false
|
||||||
|
}
|
||||||
switch os.Args[1] {
|
switch os.Args[1] {
|
||||||
case "version":
|
case "version", "-v":
|
||||||
fmt.Println(beszel.AppName+"-agent", beszel.Version)
|
fmt.Println(beszel.AppName+"-agent", beszel.Version)
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
case "help":
|
case "help":
|
||||||
@@ -41,48 +50,84 @@ func main() {
|
|||||||
agent.Update()
|
agent.Update()
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var pubKey []byte
|
// loadPublicKeys loads the public keys from the command line flag, environment variable, or key file.
|
||||||
// Override the key if the -key flag is provided
|
func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
|
||||||
if *keyFlag != "" {
|
// Try command line flag first
|
||||||
pubKey = []byte(*keyFlag)
|
if cfg.key != "" {
|
||||||
} else {
|
return agent.ParseKeys(cfg.key)
|
||||||
// Try to get the key from the KEY environment variable.
|
|
||||||
key, _ := agent.GetEnv("KEY")
|
|
||||||
pubKey = []byte(key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If KEY is not set, try to read the key from the file specified by KEY_FILE.
|
// Try environment variable
|
||||||
if len(pubKey) == 0 {
|
if key, ok := agent.GetEnv("KEY"); ok && key != "" {
|
||||||
keyFile, exists := agent.GetEnv("KEY_FILE")
|
return agent.ParseKeys(key)
|
||||||
if !exists {
|
|
||||||
log.Fatal("Must set KEY or KEY_FILE environment variable or supply as input argument. Use 'beszel-agent help' for more information.")
|
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
pubKey, err = os.ReadFile(keyFile)
|
// Try key file
|
||||||
|
keyFile, ok := agent.GetEnv("KEY_FILE")
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ")
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey, err := os.ReadFile(keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
return nil, fmt.Errorf("failed to read key file: %w", err)
|
||||||
}
|
}
|
||||||
|
return agent.ParseKeys(string(pubKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init with default port
|
// getAddress gets the address to listen on from the command line flag, environment variable, or default value.
|
||||||
addr := ":" + *portFlag
|
func getAddress(addr string) string {
|
||||||
|
// Try command line flag first
|
||||||
//Use port from ENV if it exists
|
if addr != "" {
|
||||||
// TODO: change env var to ADDR
|
return addr
|
||||||
if portEnvVar, exists := agent.GetEnv("PORT"); exists {
|
|
||||||
// allow passing an address in the form of "127.0.0.1:45876"
|
|
||||||
if !strings.Contains(portEnvVar, ":") {
|
|
||||||
portEnvVar = ":" + portEnvVar
|
|
||||||
}
|
}
|
||||||
addr = portEnvVar
|
// Try environment variables
|
||||||
|
if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
// Legacy PORT environment variable support
|
||||||
|
if port, ok := agent.GetEnv("PORT"); ok && port != "" {
|
||||||
|
return port
|
||||||
|
}
|
||||||
|
return ":45876"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override the default and ENV port if the -port flag is provided and is non default
|
// getNetwork returns the network type to use for the server.
|
||||||
if *portFlag != "45876" {
|
func getNetwork(addr string) string {
|
||||||
addr = ":" + *portFlag
|
if network, _ := agent.GetEnv("NETWORK"); network != "" {
|
||||||
|
return network
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(addr, "/") {
|
||||||
|
return "unix"
|
||||||
|
}
|
||||||
|
return "tcp"
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.NewAgent().Run(pubKey, addr)
|
func main() {
|
||||||
|
var cfg cmdConfig
|
||||||
|
parseFlags(&cfg)
|
||||||
|
|
||||||
|
if handleSubcommand() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
var serverConfig agent.ServerConfig
|
||||||
|
var err error
|
||||||
|
serverConfig.Keys, err = loadPublicKeys(cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to load public keys:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfig.Addr = getAddress(cfg.addr)
|
||||||
|
serverConfig.Network = getNetwork(cfg.addr)
|
||||||
|
|
||||||
|
agent := agent.NewAgent()
|
||||||
|
if err := agent.StartServer(serverConfig); err != nil {
|
||||||
|
log.Fatal("Failed to start server:", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
285
beszel/cmd/agent/agent_test.go
Normal file
285
beszel/cmd/agent/agent_test.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetAddress(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg cmdConfig
|
||||||
|
envVars map[string]string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default port when no config",
|
||||||
|
cfg: cmdConfig{},
|
||||||
|
expected: ":45876",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use address from flag",
|
||||||
|
cfg: cmdConfig{
|
||||||
|
addr: "8080",
|
||||||
|
},
|
||||||
|
expected: "8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use unix socket from flag",
|
||||||
|
cfg: cmdConfig{
|
||||||
|
addr: "/tmp/beszel.sock",
|
||||||
|
},
|
||||||
|
expected: "/tmp/beszel.sock",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use ADDR env var",
|
||||||
|
cfg: cmdConfig{},
|
||||||
|
envVars: map[string]string{
|
||||||
|
"ADDR": "1.2.3.4:9090",
|
||||||
|
},
|
||||||
|
expected: "1.2.3.4:9090",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use legacy PORT env var",
|
||||||
|
cfg: cmdConfig{},
|
||||||
|
envVars: map[string]string{
|
||||||
|
"PORT": "7070",
|
||||||
|
},
|
||||||
|
expected: "7070",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "flag takes precedence over env vars",
|
||||||
|
cfg: cmdConfig{
|
||||||
|
addr: ":8080",
|
||||||
|
},
|
||||||
|
envVars: map[string]string{
|
||||||
|
"ADDR": ":9090",
|
||||||
|
"PORT": "7070",
|
||||||
|
},
|
||||||
|
expected: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Setup environment
|
||||||
|
for k, v := range tt.envVars {
|
||||||
|
t.Setenv(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := getAddress(tt.cfg.addr)
|
||||||
|
assert.Equal(t, tt.expected, addr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadPublicKeys(t *testing.T) {
|
||||||
|
// Generate a test key
|
||||||
|
_, priv, err := ed25519.GenerateKey(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
signer, err := ssh.NewSignerFromKey(priv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg cmdConfig
|
||||||
|
envVars map[string]string
|
||||||
|
setupFiles map[string][]byte
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "load key from flag",
|
||||||
|
cfg: cmdConfig{
|
||||||
|
key: string(pubKey),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "load key from env var",
|
||||||
|
envVars: map[string]string{
|
||||||
|
"KEY": string(pubKey),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "load key from file",
|
||||||
|
envVars: map[string]string{
|
||||||
|
"KEY_FILE": "testkey.pub",
|
||||||
|
},
|
||||||
|
setupFiles: map[string][]byte{
|
||||||
|
"testkey.pub": pubKey,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error when no key provided",
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "no key provided",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error on invalid key file",
|
||||||
|
envVars: map[string]string{
|
||||||
|
"KEY_FILE": "nonexistent.pub",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "failed to read key file",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error on invalid key data",
|
||||||
|
cfg: cmdConfig{
|
||||||
|
key: "invalid-key-data",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create a temporary directory for test files
|
||||||
|
if len(tt.setupFiles) > 0 {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
for name, content := range tt.setupFiles {
|
||||||
|
path := filepath.Join(tmpDir, name)
|
||||||
|
err := os.WriteFile(path, content, 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if tt.envVars != nil {
|
||||||
|
tt.envVars["KEY_FILE"] = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up environment
|
||||||
|
for k, v := range tt.envVars {
|
||||||
|
t.Setenv(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := loadPublicKeys(tt.cfg)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, keys, 1)
|
||||||
|
assert.Equal(t, signer.PublicKey().Type(), keys[0].Type())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
envVars map[string]string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "only port",
|
||||||
|
addr: "8080",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 address",
|
||||||
|
addr: "1.2.3.4:8080",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 address",
|
||||||
|
addr: "[2001:db8::1]:8080",
|
||||||
|
expected: "tcp",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix network",
|
||||||
|
addr: "/tmp/beszel.sock",
|
||||||
|
expected: "unix",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "env var network",
|
||||||
|
addr: ":8080",
|
||||||
|
envVars: map[string]string{"NETWORK": "tcp4"},
|
||||||
|
expected: "tcp4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Setup environment
|
||||||
|
for k, v := range tt.envVars {
|
||||||
|
t.Setenv(k, v)
|
||||||
|
}
|
||||||
|
network := getNetwork(tt.addr)
|
||||||
|
assert.Equal(t, tt.expected, network)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFlags(t *testing.T) {
|
||||||
|
// Save original command line arguments and restore after test
|
||||||
|
oldArgs := os.Args
|
||||||
|
defer func() {
|
||||||
|
os.Args = oldArgs
|
||||||
|
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
expected cmdConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no flags",
|
||||||
|
args: []string{"cmd"},
|
||||||
|
expected: cmdConfig{
|
||||||
|
key: "",
|
||||||
|
addr: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "key flag only",
|
||||||
|
args: []string{"cmd", "-key", "testkey"},
|
||||||
|
expected: cmdConfig{
|
||||||
|
key: "testkey",
|
||||||
|
addr: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "addr flag only",
|
||||||
|
args: []string{"cmd", "-addr", ":8080"},
|
||||||
|
expected: cmdConfig{
|
||||||
|
key: "",
|
||||||
|
addr: ":8080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both flags",
|
||||||
|
args: []string{"cmd", "-key", "testkey", "-addr", ":8080"},
|
||||||
|
expected: cmdConfig{
|
||||||
|
key: "testkey",
|
||||||
|
addr: ":8080",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset flags for each test
|
||||||
|
flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError)
|
||||||
|
os.Args = tt.args
|
||||||
|
|
||||||
|
var cfg cmdConfig
|
||||||
|
parseFlags(&cfg)
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expected, cfg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@@ -1,8 +1,6 @@
|
|||||||
module beszel
|
module beszel
|
||||||
|
|
||||||
go 1.23
|
go 1.24
|
||||||
|
|
||||||
toolchain go1.23.2
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/blang/semver v3.5.1+incompatible
|
github.com/blang/semver v3.5.1+incompatible
|
||||||
@@ -15,6 +13,7 @@ require (
|
|||||||
github.com/shirou/gopsutil/v4 v4.25.1
|
github.com/shirou/gopsutil/v4 v4.25.1
|
||||||
github.com/spf13/cast v1.7.1
|
github.com/spf13/cast v1.7.1
|
||||||
github.com/spf13/cobra v1.8.1
|
github.com/spf13/cobra v1.8.1
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
golang.org/x/crypto v0.32.0
|
golang.org/x/crypto v0.32.0
|
||||||
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c
|
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
@@ -28,29 +28,16 @@ type Agent struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent() *Agent {
|
func NewAgent() *Agent {
|
||||||
newAgent := &Agent{
|
agent := &Agent{
|
||||||
sensorsContext: context.Background(),
|
|
||||||
fsStats: make(map[string]*system.FsStats),
|
fsStats: make(map[string]*system.FsStats),
|
||||||
}
|
}
|
||||||
newAgent.memCalc, _ = GetEnv("MEM_CALC")
|
agent.memCalc, _ = GetEnv("MEM_CALC")
|
||||||
return newAgent
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key.
|
|
||||||
func GetEnv(key string) (value string, exists bool) {
|
|
||||||
if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists {
|
|
||||||
return value, exists
|
|
||||||
}
|
|
||||||
// Fallback to the old unprefixed key
|
|
||||||
return os.LookupEnv(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Agent) Run(pubKey []byte, addr string) {
|
|
||||||
// Set up slog with a log level determined by the LOG_LEVEL env var
|
// Set up slog with a log level determined by the LOG_LEVEL env var
|
||||||
if logLevelStr, exists := GetEnv("LOG_LEVEL"); exists {
|
if logLevelStr, exists := GetEnv("LOG_LEVEL"); exists {
|
||||||
switch strings.ToLower(logLevelStr) {
|
switch strings.ToLower(logLevelStr) {
|
||||||
case "debug":
|
case "debug":
|
||||||
a.debug = true
|
agent.debug = true
|
||||||
slog.SetLogLoggerLevel(slog.LevelDebug)
|
slog.SetLogLoggerLevel(slog.LevelDebug)
|
||||||
case "warn":
|
case "warn":
|
||||||
slog.SetLogLoggerLevel(slog.LevelWarn)
|
slog.SetLogLoggerLevel(slog.LevelWarn)
|
||||||
@@ -64,40 +51,51 @@ func (a *Agent) Run(pubKey []byte, addr string) {
|
|||||||
// Set sensors context (allows overriding sys location for sensors)
|
// Set sensors context (allows overriding sys location for sensors)
|
||||||
if sysSensors, exists := GetEnv("SYS_SENSORS"); exists {
|
if sysSensors, exists := GetEnv("SYS_SENSORS"); exists {
|
||||||
slog.Info("SYS_SENSORS", "path", sysSensors)
|
slog.Info("SYS_SENSORS", "path", sysSensors)
|
||||||
a.sensorsContext = context.WithValue(a.sensorsContext,
|
agent.sensorsContext = context.WithValue(agent.sensorsContext,
|
||||||
common.EnvKey, common.EnvMap{common.HostSysEnvKey: sysSensors},
|
common.EnvKey, common.EnvMap{common.HostSysEnvKey: sysSensors},
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
agent.sensorsContext = context.Background()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set sensors whitelist
|
// Set sensors whitelist
|
||||||
if sensors, exists := GetEnv("SENSORS"); exists {
|
if sensors, exists := GetEnv("SENSORS"); exists {
|
||||||
a.sensorsWhitelist = make(map[string]struct{})
|
agent.sensorsWhitelist = make(map[string]struct{})
|
||||||
for _, sensor := range strings.Split(sensors, ",") {
|
for _, sensor := range strings.Split(sensors, ",") {
|
||||||
if sensor != "" {
|
if sensor != "" {
|
||||||
a.sensorsWhitelist[sensor] = struct{}{}
|
agent.sensorsWhitelist[sensor] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize system info / docker manager
|
// initialize system info / docker manager
|
||||||
a.initializeSystemInfo()
|
agent.initializeSystemInfo()
|
||||||
a.initializeDiskInfo()
|
agent.initializeDiskInfo()
|
||||||
a.initializeNetIoStats()
|
agent.initializeNetIoStats()
|
||||||
a.dockerManager = newDockerManager(a)
|
agent.dockerManager = newDockerManager(agent)
|
||||||
|
|
||||||
// initialize GPU manager
|
// initialize GPU manager
|
||||||
if gm, err := NewGPUManager(); err != nil {
|
if gm, err := NewGPUManager(); err != nil {
|
||||||
slog.Debug("GPU", "err", err)
|
slog.Debug("GPU", "err", err)
|
||||||
} else {
|
} else {
|
||||||
a.gpuManager = gm
|
agent.gpuManager = gm
|
||||||
}
|
}
|
||||||
|
|
||||||
// if debugging, print stats
|
// if debugging, print stats
|
||||||
if a.debug {
|
if agent.debug {
|
||||||
slog.Debug("Stats", "data", a.gatherStats())
|
slog.Debug("Stats", "data", agent.gatherStats())
|
||||||
}
|
}
|
||||||
|
|
||||||
a.startServer(pubKey, addr)
|
return agent
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key.
|
||||||
|
func GetEnv(key string) (value string, exists bool) {
|
||||||
|
if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists {
|
||||||
|
return value, exists
|
||||||
|
}
|
||||||
|
// Fallback to the old unprefixed key
|
||||||
|
return os.LookupEnv(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) gatherStats() system.CombinedData {
|
func (a *Agent) gatherStats() system.CombinedData {
|
||||||
|
@@ -2,33 +2,96 @@ package agent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
sshServer "github.com/gliderlabs/ssh"
|
sshServer "github.com/gliderlabs/ssh"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *Agent) startServer(pubKey []byte, addr string) {
|
type ServerConfig struct {
|
||||||
|
Addr string
|
||||||
|
Network string
|
||||||
|
Keys []ssh.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) StartServer(cfg ServerConfig) error {
|
||||||
sshServer.Handle(a.handleSession)
|
sshServer.Handle(a.handleSession)
|
||||||
|
|
||||||
slog.Info("Starting SSH server", "address", addr)
|
slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network)
|
||||||
if err := sshServer.ListenAndServe(addr, nil, sshServer.NoPty(),
|
|
||||||
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
switch cfg.Network {
|
||||||
allowed, _, _, _, _ := sshServer.ParseAuthorizedKey(pubKey)
|
case "unix":
|
||||||
return sshServer.KeysEqual(key, allowed)
|
// remove existing socket file if it exists
|
||||||
}),
|
if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) {
|
||||||
); err != nil {
|
return err
|
||||||
slog.Error("Error starting SSH server", "err", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
// prefix with : if only port was provided
|
||||||
|
if !strings.Contains(cfg.Addr, ":") {
|
||||||
|
cfg.Addr = ":" + cfg.Addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen on the address
|
||||||
|
ln, err := net.Listen(cfg.Network, cfg.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
// Start server on the listener
|
||||||
|
err = sshServer.Serve(ln, nil, sshServer.NoPty(),
|
||||||
|
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
||||||
|
for _, pubKey := range cfg.Keys {
|
||||||
|
if sshServer.KeysEqual(key, pubKey) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) handleSession(s sshServer.Session) {
|
func (a *Agent) handleSession(s sshServer.Session) {
|
||||||
|
// slog.Debug("connection", "remoteaddr", s.RemoteAddr(), "user", s.User())
|
||||||
stats := a.gatherStats()
|
stats := a.gatherStats()
|
||||||
if err := json.NewEncoder(s).Encode(stats); err != nil {
|
if err := json.NewEncoder(s).Encode(stats); err != nil {
|
||||||
slog.Error("Error encoding stats", "err", err, "stats", stats)
|
slog.Error("Error encoding stats", "err", err, "stats", stats)
|
||||||
s.Exit(1)
|
s.Exit(1)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
s.Exit(0)
|
s.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseKeys parses a string containing SSH public keys in authorized_keys format.
|
||||||
|
// It returns a slice of ssh.PublicKey and an error if any key fails to parse.
|
||||||
|
func ParseKeys(input string) ([]ssh.PublicKey, error) {
|
||||||
|
var parsedKeys []ssh.PublicKey
|
||||||
|
|
||||||
|
for line := range strings.Lines(input) {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
// Skip empty lines or comments
|
||||||
|
if len(line) == 0 || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the key
|
||||||
|
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(line))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse key: %s, error: %w", line, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the parsed key to the list
|
||||||
|
parsedKeys = append(parsedKeys, parsedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsedKeys, nil
|
||||||
|
}
|
||||||
|
281
beszel/internal/agent/server_test.go
Normal file
281
beszel/internal/agent/server_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStartServer(t *testing.T) {
|
||||||
|
// Generate a test key pair
|
||||||
|
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
signer, err := ssh.NewSignerFromKey(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate a different key pair for bad key test
|
||||||
|
badPubKey, badPrivKey, err := ed25519.GenerateKey(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
badSigner, err := ssh.NewSignerFromKey(badPrivKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
sshBadPubKey, err := ssh.NewPublicKey(badPubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
socketFile := filepath.Join(t.TempDir(), "beszel-test.sock")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config ServerConfig
|
||||||
|
wantErr bool
|
||||||
|
errContains string
|
||||||
|
setup func() error
|
||||||
|
cleanup func() error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp port only",
|
||||||
|
config: ServerConfig{
|
||||||
|
Network: "tcp",
|
||||||
|
Addr: "45987",
|
||||||
|
Keys: []ssh.PublicKey{sshPubKey},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp with ipv4",
|
||||||
|
config: ServerConfig{
|
||||||
|
Network: "tcp4",
|
||||||
|
Addr: "127.0.0.1:45988",
|
||||||
|
Keys: []ssh.PublicKey{sshPubKey},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp with ipv6",
|
||||||
|
config: ServerConfig{
|
||||||
|
Network: "tcp6",
|
||||||
|
Addr: "[::1]:45989",
|
||||||
|
Keys: []ssh.PublicKey{sshPubKey},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket",
|
||||||
|
config: ServerConfig{
|
||||||
|
Network: "unix",
|
||||||
|
Addr: socketFile,
|
||||||
|
Keys: []ssh.PublicKey{sshPubKey},
|
||||||
|
},
|
||||||
|
setup: func() error {
|
||||||
|
// Create a socket file that should be removed
|
||||||
|
f, err := os.Create(socketFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return f.Close()
|
||||||
|
},
|
||||||
|
cleanup: func() error {
|
||||||
|
return os.Remove(socketFile)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad key should fail",
|
||||||
|
config: ServerConfig{
|
||||||
|
Network: "tcp",
|
||||||
|
Addr: "45987",
|
||||||
|
Keys: []ssh.PublicKey{sshBadPubKey},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContains: "ssh: handshake failed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.setup != nil {
|
||||||
|
err := tt.setup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.cleanup != nil {
|
||||||
|
defer tt.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
agent := NewAgent()
|
||||||
|
|
||||||
|
// Start server in a goroutine since it blocks
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errChan <- agent.StartServer(tt.config)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add a short delay to allow the server to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Try to connect to verify server is running
|
||||||
|
var client *ssh.Client
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Choose the appropriate signer based on the test case
|
||||||
|
testSigner := signer
|
||||||
|
if tt.name == "bad key should fail" {
|
||||||
|
testSigner = badSigner
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConfig := &ssh.ClientConfig{
|
||||||
|
User: "a",
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.PublicKeys(testSigner),
|
||||||
|
},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 4 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tt.config.Network {
|
||||||
|
case "unix":
|
||||||
|
client, err = ssh.Dial("unix", tt.config.Addr, sshClientConfig)
|
||||||
|
default:
|
||||||
|
if !strings.Contains(tt.config.Addr, ":") {
|
||||||
|
tt.config.Addr = ":" + tt.config.Addr
|
||||||
|
}
|
||||||
|
client, err = ssh.Dial("tcp", tt.config.Addr, sshClientConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.errContains != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, client)
|
||||||
|
client.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
//////////////////// ParseKeys Tests ////////////////////////////
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Helper function to generate a temporary file with content
|
||||||
|
func createTempFile(content string) (string, error) {
|
||||||
|
tmpFile, err := os.CreateTemp("", "ssh_keys_*.txt")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||||
|
}
|
||||||
|
defer tmpFile.Close()
|
||||||
|
|
||||||
|
if _, err := tmpFile.WriteString(content); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to write to temp file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpFile.Name(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 1: String with a single SSH key
|
||||||
|
func TestParseSingleKeyFromString(t *testing.T) {
|
||||||
|
input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo"
|
||||||
|
keys, err := ParseKeys(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if len(keys) != 1 {
|
||||||
|
t.Fatalf("Expected 1 key, got %d keys", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].Type() != "ssh-ed25519" {
|
||||||
|
t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 2: String with multiple SSH keys
|
||||||
|
func TestParseMultipleKeysFromString(t *testing.T) {
|
||||||
|
input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"
|
||||||
|
keys, err := ParseKeys(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if len(keys) != 3 {
|
||||||
|
t.Fatalf("Expected 3 keys, got %d keys", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" {
|
||||||
|
t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 3: File with a single SSH key
|
||||||
|
func TestParseSingleKeyFromFile(t *testing.T) {
|
||||||
|
content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo"
|
||||||
|
filePath, err := createTempFile(content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(filePath) // Clean up the file after the test
|
||||||
|
|
||||||
|
// Read the file content
|
||||||
|
fileContent, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the keys
|
||||||
|
keys, err := ParseKeys(string(fileContent))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if len(keys) != 1 {
|
||||||
|
t.Fatalf("Expected 1 key, got %d keys", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].Type() != "ssh-ed25519" {
|
||||||
|
t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 4: File with multiple SSH keys
|
||||||
|
func TestParseMultipleKeysFromFile(t *testing.T) {
|
||||||
|
content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"
|
||||||
|
filePath, err := createTempFile(content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
// defer os.Remove(filePath) // Clean up the file after the test
|
||||||
|
|
||||||
|
// Read the file content
|
||||||
|
fileContent, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the keys
|
||||||
|
keys, err := ParseKeys(string(fileContent))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if len(keys) != 3 {
|
||||||
|
t.Fatalf("Expected 3 keys, got %d keys", len(keys))
|
||||||
|
}
|
||||||
|
if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" {
|
||||||
|
t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 5: Invalid SSH key input
|
||||||
|
func TestParseInvalidKey(t *testing.T) {
|
||||||
|
input := "invalid-key-data"
|
||||||
|
_, err := ParseKeys(input)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Expected an error for invalid key, got nil")
|
||||||
|
}
|
||||||
|
expectedErrMsg := "failed to parse key"
|
||||||
|
if !strings.Contains(err.Error(), expectedErrMsg) {
|
||||||
|
t.Fatalf("Expected error message to contain '%s', got: %v", expectedErrMsg, err)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user