add health check for agent

- Updated command-line flag parsing.
- Moved GetAddress and GetNetwork to server.go
This commit is contained in:
henrygd
2025-03-14 03:33:25 -04:00
parent 400ea89587
commit edefc6f53e
7 changed files with 228 additions and 74 deletions

View File

@@ -6,8 +6,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"log/slog"
"os" "os"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -18,39 +18,50 @@ type cmdOptions struct {
listen string // listen is the address or port to listen on. listen string // listen is the address or port to listen on.
} }
// parseFlags parses the command line flags and populates the config struct. // parse parses the command line flags and populates the config struct.
func (opts *cmdOptions) parseFlags() { // It returns true if a subcommand was handled and the program should exit.
func (opts *cmdOptions) parse() bool {
flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication") flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication")
flag.StringVar(&opts.listen, "listen", "", "Address or port to listen on") flag.StringVar(&opts.listen, "listen", "", "Address or port to listen on")
flag.Usage = func() { flag.Usage = func() {
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0]) fmt.Printf("Usage: %s [command] [flags]\n", os.Args[0])
fmt.Println("\nOptions:") fmt.Println("\nCommands:")
flag.PrintDefaults() fmt.Println(" health Check if the agent is running")
fmt.Println("\nSubcommands:")
fmt.Println(" version Display the version")
fmt.Println(" help Display this help message") fmt.Println(" help Display this help message")
fmt.Println(" update Update the agent to the latest version") fmt.Println(" update Update to the latest version")
} fmt.Println(" version Display the version")
fmt.Println("\nFlags:")
flag.PrintDefaults()
} }
// handleSubcommand handles subcommands such as version, help, and update. subcommand := ""
// It returns true if a subcommand was handled, false otherwise. if len(os.Args) > 1 {
func handleSubcommand() bool { subcommand = os.Args[1]
if len(os.Args) <= 1 {
return false
} }
switch os.Args[1] {
case "version", "-v": switch subcommand {
case "-v", "version":
fmt.Println(beszel.AppName+"-agent", beszel.Version) fmt.Println(beszel.AppName+"-agent", beszel.Version)
os.Exit(0) return true
case "help": case "help":
flag.Usage() flag.Usage()
os.Exit(0) return true
case "update": case "update":
agent.Update() agent.Update()
os.Exit(0) return true
case "health":
// for health, we need to parse flags first to get the listen address
args := append(os.Args[2:], subcommand)
flag.CommandLine.Parse(args)
addr := opts.getAddress()
network := agent.GetNetwork(addr)
exitCode, err := agent.Health(addr, network)
slog.Info("Health", "code", exitCode, "err", err)
os.Exit(exitCode)
} }
flag.Parse()
return false return false
} }
@@ -79,46 +90,18 @@ func (opts *cmdOptions) loadPublicKeys() ([]ssh.PublicKey, error) {
return agent.ParseKeys(string(pubKey)) return agent.ParseKeys(string(pubKey))
} }
// getAddress gets the address to listen on from the command line flag, environment variable, or default value.
func (opts *cmdOptions) getAddress() string { func (opts *cmdOptions) getAddress() string {
// Try command line flag first return agent.GetAddress(opts.listen)
if opts.listen != "" {
return opts.listen
}
// Try environment variables
if addr, ok := agent.GetEnv("LISTEN"); ok && addr != "" {
return addr
}
// Legacy PORT environment variable support
if port, ok := agent.GetEnv("PORT"); ok && port != "" {
return port
}
return ":45876"
}
// getNetwork returns the network type to use for the server.
func (opts *cmdOptions) getNetwork() string {
if network, _ := agent.GetEnv("NETWORK"); network != "" {
return network
}
if strings.HasPrefix(opts.listen, "/") {
return "unix"
}
return "tcp"
} }
func main() { func main() {
var opts cmdOptions var opts cmdOptions
opts.parseFlags() subcommandHandled := opts.parse()
if handleSubcommand() { if subcommandHandled {
return return
} }
flag.Parse()
opts.listen = opts.getAddress()
var serverConfig agent.ServerOptions var serverConfig agent.ServerOptions
var err error var err error
serverConfig.Keys, err = opts.loadPublicKeys() serverConfig.Keys, err = opts.loadPublicKeys()
@@ -126,8 +109,9 @@ func main() {
log.Fatal("Failed to load public keys:", err) log.Fatal("Failed to load public keys:", err)
} }
serverConfig.Addr = opts.listen addr := opts.getAddress()
serverConfig.Network = opts.getNetwork() serverConfig.Addr = addr
serverConfig.Network = agent.GetNetwork(addr)
agent := agent.NewAgent() agent := agent.NewAgent()
if err := agent.StartServer(serverConfig); err != nil { if err := agent.StartServer(serverConfig); err != nil {

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"beszel/internal/agent"
"crypto/ed25519" "crypto/ed25519"
"flag" "flag"
"os" "os"
@@ -29,7 +30,7 @@ func TestGetAddress(t *testing.T) {
opts: cmdOptions{ opts: cmdOptions{
listen: "8080", listen: "8080",
}, },
expected: "8080", expected: ":8080",
}, },
{ {
name: "use unix socket from flag", name: "use unix socket from flag",
@@ -52,7 +53,7 @@ func TestGetAddress(t *testing.T) {
envVars: map[string]string{ envVars: map[string]string{
"PORT": "7070", "PORT": "7070",
}, },
expected: "7070", expected: ":7070",
}, },
{ {
name: "use unix socket from env var", name: "use unix socket from env var",
@@ -233,7 +234,7 @@ func TestGetNetwork(t *testing.T) {
for k, v := range tt.envVars { for k, v := range tt.envVars {
t.Setenv(k, v) t.Setenv(k, v)
} }
network := tt.opts.getNetwork() network := agent.GetNetwork(tt.opts.listen)
assert.Equal(t, tt.expected, network) assert.Equal(t, tt.expected, network)
}) })
} }
@@ -293,7 +294,7 @@ func TestParseFlags(t *testing.T) {
os.Args = tt.args os.Args = tt.args
var opts cmdOptions var opts cmdOptions
opts.parseFlags() opts.parse()
flag.Parse() flag.Parse()
assert.Equal(t, tt.expected, opts) assert.Equal(t, tt.expected, opts)

View File

@@ -0,0 +1,19 @@
package agent
import (
"net"
"time"
)
// Health checks if the agent's server is running by attempting to connect to it.
// It returns 0 if the server is running, 1 otherwise (as in exit codes).
//
// If an error occurs when attempting to connect to the server, it returns the error.
func Health(addr string, network string) (int, error) {
conn, err := net.DialTimeout(network, addr, 4*time.Second)
if err != nil {
return 1, err
}
conn.Close()
return 0, nil
}

View File

@@ -0,0 +1,130 @@
//go:build testing
// +build testing
package agent_test
import (
"fmt"
"net"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"beszel/internal/agent"
)
// setupTestServer creates a temporary server for testing
func setupTestServer(t *testing.T) (string, func()) {
// Create a temporary socket file for Unix socket testing
tempSockFile := os.TempDir() + "/beszel_health_test.sock"
// Clean up any existing socket file
os.Remove(tempSockFile)
// Create a listener
listener, err := net.Listen("unix", tempSockFile)
require.NoError(t, err, "Failed to create test listener")
// Start a simple server in a goroutine
go func() {
conn, err := listener.Accept()
if err != nil {
return // Listener closed
}
defer conn.Close()
// Just accept the connection and do nothing
}()
// Return the socket file path and a cleanup function
return tempSockFile, func() {
listener.Close()
os.Remove(tempSockFile)
}
}
// setupTCPTestServer creates a temporary TCP server for testing
func setupTCPTestServer(t *testing.T) (string, func()) {
// Listen on a random available port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "Failed to create test listener")
// Get the port that was assigned
addr := listener.Addr().(*net.TCPAddr)
port := addr.Port
// Start a simple server in a goroutine
go func() {
conn, err := listener.Accept()
if err != nil {
return // Listener closed
}
defer conn.Close()
// Just accept the connection and do nothing
}()
// Return the address and a cleanup function
return fmt.Sprintf("127.0.0.1:%d", port), func() {
listener.Close()
}
}
func TestHealth(t *testing.T) {
t.Run("server is running (unix socket)", func(t *testing.T) {
// Setup a test server
sockFile, cleanup := setupTestServer(t)
defer cleanup()
// Run the health check with explicit parameters
result, err := agent.Health(sockFile, "unix")
require.NoError(t, err, "Failed to check health")
// Verify the result
assert.Equal(t, 0, result, "Health check should return 0 when server is running")
})
t.Run("server is running (tcp address)", func(t *testing.T) {
// Setup a test server
addr, cleanup := setupTCPTestServer(t)
defer cleanup()
// Run the health check with explicit parameters
result, err := agent.Health(addr, "tcp")
require.NoError(t, err, "Failed to check health")
// Verify the result
assert.Equal(t, 0, result, "Health check should return 0 when server is running")
})
t.Run("server is not running", func(t *testing.T) {
// Use an address that's likely not in use
addr := "127.0.0.1:65535"
// Run the health check with explicit parameters
result, err := agent.Health(addr, "tcp")
require.Error(t, err, "Health check should return an error when server is not running")
// Verify the result
assert.Equal(t, 1, result, "Health check should return 1 when server is not running")
})
t.Run("invalid network", func(t *testing.T) {
// Use an invalid network type
result, err := agent.Health("127.0.0.1:8080", "invalid_network")
require.Error(t, err, "Health check should return an error with invalid network")
assert.Equal(t, 1, result, "Health check should return 1 when network is invalid")
})
t.Run("unix socket not found", func(t *testing.T) {
// Use a non-existent unix socket
nonExistentSocket := os.TempDir() + "/non_existent_socket.sock"
// Make sure it really doesn't exist
os.Remove(nonExistentSocket)
result, err := agent.Health(nonExistentSocket, "unix")
require.Error(t, err, "Health check should return an error when socket doesn't exist")
assert.Equal(t, 1, result, "Health check should return 1 when socket doesn't exist")
})
}

View File

@@ -17,7 +17,7 @@ func (a *Agent) initializeNetIoStats() {
nics, nicsEnvExists := GetEnv("NICS") nics, nicsEnvExists := GetEnv("NICS")
if nicsEnvExists { if nicsEnvExists {
nicsMap = make(map[string]struct{}, 0) nicsMap = make(map[string]struct{}, 0)
for _, nic := range strings.Split(nics, ",") { for nic := range strings.SplitSeq(nics, ",") {
nicsMap[nic] = struct{}{} nicsMap[nic] = struct{}{}
} }
} }

View File

@@ -23,20 +23,14 @@ func (a *Agent) StartServer(opts ServerOptions) error {
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network) slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
switch opts.Network { if opts.Network == "unix" {
case "unix":
// remove existing socket file if it exists // remove existing socket file if it exists
if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) { if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) {
return err return err
} }
default:
// prefix with : if only port was provided
if !strings.Contains(opts.Addr, ":") {
opts.Addr = ":" + opts.Addr
}
} }
// Listen on the address // start listening on the address
ln, err := net.Listen(opts.Network, opts.Addr) ln, err := net.Listen(opts.Network, opts.Addr)
if err != nil { if err != nil {
return err return err
@@ -44,7 +38,7 @@ func (a *Agent) StartServer(opts ServerOptions) error {
defer ln.Close() defer ln.Close()
// Start SSH server on the listener // Start SSH server on the listener
err = sshServer.Serve(ln, nil, sshServer.NoPty(), return sshServer.Serve(ln, nil, sshServer.NoPty(),
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
for _, pubKey := range opts.Keys { for _, pubKey := range opts.Keys {
if sshServer.KeysEqual(key, pubKey) { if sshServer.KeysEqual(key, pubKey) {
@@ -54,10 +48,6 @@ func (a *Agent) StartServer(opts ServerOptions) error {
return false return false
}), }),
) )
if err != nil {
return err
}
return nil
} }
func (a *Agent) handleSession(s sshServer.Session) { func (a *Agent) handleSession(s sshServer.Session) {
@@ -89,3 +79,33 @@ func ParseKeys(input string) ([]ssh.PublicKey, error) {
} }
return parsedKeys, nil return parsedKeys, nil
} }
// GetAddress gets the address to listen on or connect to from environment variables or default value.
func GetAddress(addr string) string {
if addr == "" {
addr, _ = GetEnv("LISTEN")
}
if addr == "" {
// Legacy PORT environment variable support
addr, _ = GetEnv("PORT")
}
if addr == "" {
return ":45876"
}
// prefix with : if only port was provided
if GetNetwork(addr) != "unix" && !strings.Contains(addr, ":") {
addr = ":" + addr
}
return addr
}
// GetNetwork returns the network type to use based on the address
func GetNetwork(addr string) string {
if network, ok := GetEnv("NETWORK"); ok && network != "" {
return network
}
if strings.HasPrefix(addr, "/") {
return "unix"
}
return "tcp"
}

View File

@@ -45,7 +45,7 @@ func TestStartServer(t *testing.T) {
name: "tcp port only", name: "tcp port only",
config: ServerOptions{ config: ServerOptions{
Network: "tcp", Network: "tcp",
Addr: "45987", Addr: ":45987",
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
}, },
}, },
@@ -88,7 +88,7 @@ func TestStartServer(t *testing.T) {
name: "bad key should fail", name: "bad key should fail",
config: ServerOptions{ config: ServerOptions{
Network: "tcp", Network: "tcp",
Addr: "45987", Addr: ":45987",
Keys: []ssh.PublicKey{sshBadPubKey}, Keys: []ssh.PublicKey{sshBadPubKey},
}, },
wantErr: true, wantErr: true,
@@ -98,7 +98,7 @@ func TestStartServer(t *testing.T) {
name: "good key still good", name: "good key still good",
config: ServerOptions{ config: ServerOptions{
Network: "tcp", Network: "tcp",
Addr: "45987", Addr: ":45987",
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
}, },
}, },