From edefc6f53e73842282b0d1a1c779cfc76b3eab3f Mon Sep 17 00:00:00 2001 From: henrygd Date: Fri, 14 Mar 2025 03:33:25 -0400 Subject: [PATCH] add health check for agent - Updated command-line flag parsing. - Moved GetAddress and GetNetwork to server.go --- beszel/cmd/agent/agent.go | 90 ++++++++----------- beszel/cmd/agent/agent_test.go | 9 +- beszel/internal/agent/health.go | 19 ++++ beszel/internal/agent/health_test.go | 130 +++++++++++++++++++++++++++ beszel/internal/agent/network.go | 2 +- beszel/internal/agent/server.go | 46 +++++++--- beszel/internal/agent/server_test.go | 6 +- 7 files changed, 228 insertions(+), 74 deletions(-) create mode 100644 beszel/internal/agent/health.go create mode 100644 beszel/internal/agent/health_test.go diff --git a/beszel/cmd/agent/agent.go b/beszel/cmd/agent/agent.go index 0507a0b..f519335 100644 --- a/beszel/cmd/agent/agent.go +++ b/beszel/cmd/agent/agent.go @@ -6,8 +6,8 @@ import ( "flag" "fmt" "log" + "log/slog" "os" - "strings" "golang.org/x/crypto/ssh" ) @@ -18,39 +18,50 @@ type cmdOptions struct { listen string // listen is the address or port to listen on. } -// parseFlags parses the command line flags and populates the config struct. -func (opts *cmdOptions) parseFlags() { +// parse parses the command line flags and populates the config struct. +// It returns true if a subcommand was handled and the program should exit. +func (opts *cmdOptions) parse() bool { flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication") flag.StringVar(&opts.listen, "listen", "", "Address or port to listen on") flag.Usage = func() { - fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0]) - fmt.Println("\nOptions:") + fmt.Printf("Usage: %s [command] [flags]\n", os.Args[0]) + fmt.Println("\nCommands:") + fmt.Println(" health Check if the agent is running") + fmt.Println(" help Display this help message") + fmt.Println(" update Update to the latest version") + fmt.Println(" version Display the version") + fmt.Println("\nFlags:") flag.PrintDefaults() - fmt.Println("\nSubcommands:") - fmt.Println(" version Display the version") - fmt.Println(" help Display this help message") - fmt.Println(" update Update the agent to the latest version") } -} -// handleSubcommand handles subcommands such as version, help, and update. -// It returns true if a subcommand was handled, false otherwise. -func handleSubcommand() bool { - if len(os.Args) <= 1 { - return false + subcommand := "" + if len(os.Args) > 1 { + subcommand = os.Args[1] } - switch os.Args[1] { - case "version", "-v": + + switch subcommand { + case "-v", "version": fmt.Println(beszel.AppName+"-agent", beszel.Version) - os.Exit(0) + return true case "help": flag.Usage() - os.Exit(0) + return true case "update": agent.Update() - os.Exit(0) + return true + case "health": + // for health, we need to parse flags first to get the listen address + args := append(os.Args[2:], subcommand) + flag.CommandLine.Parse(args) + addr := opts.getAddress() + network := agent.GetNetwork(addr) + exitCode, err := agent.Health(addr, network) + slog.Info("Health", "code", exitCode, "err", err) + os.Exit(exitCode) } + + flag.Parse() return false } @@ -79,46 +90,18 @@ func (opts *cmdOptions) loadPublicKeys() ([]ssh.PublicKey, error) { return agent.ParseKeys(string(pubKey)) } -// getAddress gets the address to listen on from the command line flag, environment variable, or default value. func (opts *cmdOptions) getAddress() string { - // Try command line flag first - if opts.listen != "" { - return opts.listen - } - // Try environment variables - if addr, ok := agent.GetEnv("LISTEN"); ok && addr != "" { - return addr - } - // Legacy PORT environment variable support - if port, ok := agent.GetEnv("PORT"); ok && port != "" { - return port - } - return ":45876" -} - -// getNetwork returns the network type to use for the server. -func (opts *cmdOptions) getNetwork() string { - if network, _ := agent.GetEnv("NETWORK"); network != "" { - return network - } - if strings.HasPrefix(opts.listen, "/") { - return "unix" - } - return "tcp" + return agent.GetAddress(opts.listen) } func main() { var opts cmdOptions - opts.parseFlags() + subcommandHandled := opts.parse() - if handleSubcommand() { + if subcommandHandled { return } - flag.Parse() - - opts.listen = opts.getAddress() - var serverConfig agent.ServerOptions var err error serverConfig.Keys, err = opts.loadPublicKeys() @@ -126,8 +109,9 @@ func main() { log.Fatal("Failed to load public keys:", err) } - serverConfig.Addr = opts.listen - serverConfig.Network = opts.getNetwork() + addr := opts.getAddress() + serverConfig.Addr = addr + serverConfig.Network = agent.GetNetwork(addr) agent := agent.NewAgent() if err := agent.StartServer(serverConfig); err != nil { diff --git a/beszel/cmd/agent/agent_test.go b/beszel/cmd/agent/agent_test.go index 52ebb2b..1983ab4 100644 --- a/beszel/cmd/agent/agent_test.go +++ b/beszel/cmd/agent/agent_test.go @@ -1,6 +1,7 @@ package main import ( + "beszel/internal/agent" "crypto/ed25519" "flag" "os" @@ -29,7 +30,7 @@ func TestGetAddress(t *testing.T) { opts: cmdOptions{ listen: "8080", }, - expected: "8080", + expected: ":8080", }, { name: "use unix socket from flag", @@ -52,7 +53,7 @@ func TestGetAddress(t *testing.T) { envVars: map[string]string{ "PORT": "7070", }, - expected: "7070", + expected: ":7070", }, { name: "use unix socket from env var", @@ -233,7 +234,7 @@ func TestGetNetwork(t *testing.T) { for k, v := range tt.envVars { t.Setenv(k, v) } - network := tt.opts.getNetwork() + network := agent.GetNetwork(tt.opts.listen) assert.Equal(t, tt.expected, network) }) } @@ -293,7 +294,7 @@ func TestParseFlags(t *testing.T) { os.Args = tt.args var opts cmdOptions - opts.parseFlags() + opts.parse() flag.Parse() assert.Equal(t, tt.expected, opts) diff --git a/beszel/internal/agent/health.go b/beszel/internal/agent/health.go new file mode 100644 index 0000000..ff1b7e0 --- /dev/null +++ b/beszel/internal/agent/health.go @@ -0,0 +1,19 @@ +package agent + +import ( + "net" + "time" +) + +// Health checks if the agent's server is running by attempting to connect to it. +// It returns 0 if the server is running, 1 otherwise (as in exit codes). +// +// If an error occurs when attempting to connect to the server, it returns the error. +func Health(addr string, network string) (int, error) { + conn, err := net.DialTimeout(network, addr, 4*time.Second) + if err != nil { + return 1, err + } + conn.Close() + return 0, nil +} diff --git a/beszel/internal/agent/health_test.go b/beszel/internal/agent/health_test.go new file mode 100644 index 0000000..87e9892 --- /dev/null +++ b/beszel/internal/agent/health_test.go @@ -0,0 +1,130 @@ +//go:build testing +// +build testing + +package agent_test + +import ( + "fmt" + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "beszel/internal/agent" +) + +// setupTestServer creates a temporary server for testing +func setupTestServer(t *testing.T) (string, func()) { + // Create a temporary socket file for Unix socket testing + tempSockFile := os.TempDir() + "/beszel_health_test.sock" + + // Clean up any existing socket file + os.Remove(tempSockFile) + + // Create a listener + listener, err := net.Listen("unix", tempSockFile) + require.NoError(t, err, "Failed to create test listener") + + // Start a simple server in a goroutine + go func() { + conn, err := listener.Accept() + if err != nil { + return // Listener closed + } + defer conn.Close() + // Just accept the connection and do nothing + }() + + // Return the socket file path and a cleanup function + return tempSockFile, func() { + listener.Close() + os.Remove(tempSockFile) + } +} + +// setupTCPTestServer creates a temporary TCP server for testing +func setupTCPTestServer(t *testing.T) (string, func()) { + // Listen on a random available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "Failed to create test listener") + + // Get the port that was assigned + addr := listener.Addr().(*net.TCPAddr) + port := addr.Port + + // Start a simple server in a goroutine + go func() { + conn, err := listener.Accept() + if err != nil { + return // Listener closed + } + defer conn.Close() + // Just accept the connection and do nothing + }() + + // Return the address and a cleanup function + return fmt.Sprintf("127.0.0.1:%d", port), func() { + listener.Close() + } +} + +func TestHealth(t *testing.T) { + t.Run("server is running (unix socket)", func(t *testing.T) { + // Setup a test server + sockFile, cleanup := setupTestServer(t) + defer cleanup() + + // Run the health check with explicit parameters + result, err := agent.Health(sockFile, "unix") + require.NoError(t, err, "Failed to check health") + + // Verify the result + assert.Equal(t, 0, result, "Health check should return 0 when server is running") + }) + + t.Run("server is running (tcp address)", func(t *testing.T) { + // Setup a test server + addr, cleanup := setupTCPTestServer(t) + defer cleanup() + + // Run the health check with explicit parameters + result, err := agent.Health(addr, "tcp") + require.NoError(t, err, "Failed to check health") + + // Verify the result + assert.Equal(t, 0, result, "Health check should return 0 when server is running") + }) + + t.Run("server is not running", func(t *testing.T) { + // Use an address that's likely not in use + addr := "127.0.0.1:65535" + + // Run the health check with explicit parameters + result, err := agent.Health(addr, "tcp") + require.Error(t, err, "Health check should return an error when server is not running") + + // Verify the result + assert.Equal(t, 1, result, "Health check should return 1 when server is not running") + }) + + t.Run("invalid network", func(t *testing.T) { + // Use an invalid network type + result, err := agent.Health("127.0.0.1:8080", "invalid_network") + require.Error(t, err, "Health check should return an error with invalid network") + assert.Equal(t, 1, result, "Health check should return 1 when network is invalid") + }) + + t.Run("unix socket not found", func(t *testing.T) { + // Use a non-existent unix socket + nonExistentSocket := os.TempDir() + "/non_existent_socket.sock" + + // Make sure it really doesn't exist + os.Remove(nonExistentSocket) + + result, err := agent.Health(nonExistentSocket, "unix") + require.Error(t, err, "Health check should return an error when socket doesn't exist") + assert.Equal(t, 1, result, "Health check should return 1 when socket doesn't exist") + }) +} diff --git a/beszel/internal/agent/network.go b/beszel/internal/agent/network.go index dbba6bb..6b5b9c2 100644 --- a/beszel/internal/agent/network.go +++ b/beszel/internal/agent/network.go @@ -17,7 +17,7 @@ func (a *Agent) initializeNetIoStats() { nics, nicsEnvExists := GetEnv("NICS") if nicsEnvExists { nicsMap = make(map[string]struct{}, 0) - for _, nic := range strings.Split(nics, ",") { + for nic := range strings.SplitSeq(nics, ",") { nicsMap[nic] = struct{}{} } } diff --git a/beszel/internal/agent/server.go b/beszel/internal/agent/server.go index 7d6df23..b32e219 100644 --- a/beszel/internal/agent/server.go +++ b/beszel/internal/agent/server.go @@ -23,20 +23,14 @@ func (a *Agent) StartServer(opts ServerOptions) error { slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network) - switch opts.Network { - case "unix": + if opts.Network == "unix" { // remove existing socket file if it exists if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) { return err } - default: - // prefix with : if only port was provided - if !strings.Contains(opts.Addr, ":") { - opts.Addr = ":" + opts.Addr - } } - // Listen on the address + // start listening on the address ln, err := net.Listen(opts.Network, opts.Addr) if err != nil { return err @@ -44,7 +38,7 @@ func (a *Agent) StartServer(opts ServerOptions) error { defer ln.Close() // Start SSH server on the listener - err = sshServer.Serve(ln, nil, sshServer.NoPty(), + return sshServer.Serve(ln, nil, sshServer.NoPty(), sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { for _, pubKey := range opts.Keys { if sshServer.KeysEqual(key, pubKey) { @@ -54,10 +48,6 @@ func (a *Agent) StartServer(opts ServerOptions) error { return false }), ) - if err != nil { - return err - } - return nil } func (a *Agent) handleSession(s sshServer.Session) { @@ -89,3 +79,33 @@ func ParseKeys(input string) ([]ssh.PublicKey, error) { } return parsedKeys, nil } + +// GetAddress gets the address to listen on or connect to from environment variables or default value. +func GetAddress(addr string) string { + if addr == "" { + addr, _ = GetEnv("LISTEN") + } + if addr == "" { + // Legacy PORT environment variable support + addr, _ = GetEnv("PORT") + } + if addr == "" { + return ":45876" + } + // prefix with : if only port was provided + if GetNetwork(addr) != "unix" && !strings.Contains(addr, ":") { + addr = ":" + addr + } + return addr +} + +// GetNetwork returns the network type to use based on the address +func GetNetwork(addr string) string { + if network, ok := GetEnv("NETWORK"); ok && network != "" { + return network + } + if strings.HasPrefix(addr, "/") { + return "unix" + } + return "tcp" +} diff --git a/beszel/internal/agent/server_test.go b/beszel/internal/agent/server_test.go index 6bbc90e..c9a34f3 100644 --- a/beszel/internal/agent/server_test.go +++ b/beszel/internal/agent/server_test.go @@ -45,7 +45,7 @@ func TestStartServer(t *testing.T) { name: "tcp port only", config: ServerOptions{ Network: "tcp", - Addr: "45987", + Addr: ":45987", Keys: []ssh.PublicKey{sshPubKey}, }, }, @@ -88,7 +88,7 @@ func TestStartServer(t *testing.T) { name: "bad key should fail", config: ServerOptions{ Network: "tcp", - Addr: "45987", + Addr: ":45987", Keys: []ssh.PublicKey{sshBadPubKey}, }, wantErr: true, @@ -98,7 +98,7 @@ func TestStartServer(t *testing.T) { name: "good key still good", config: ServerOptions{ Network: "tcp", - Addr: "45987", + Addr: ":45987", Keys: []ssh.PublicKey{sshPubKey}, }, },