refactor(agent): refactor option parsing logic for agent command

This commit is contained in:
henrygd
2025-02-19 19:39:24 -05:00
parent d170e7a00d
commit 7485f79071
4 changed files with 78 additions and 62 deletions

View File

@@ -12,15 +12,16 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type cmdConfig struct { // cli options
type cmdOptions struct {
key string // key is the public key(s) for SSH authentication. key string // key is the public key(s) for SSH authentication.
addr string // addr is the address or port to listen on. addr string // addr is the address or port to listen on.
} }
// parseFlags parses the command line flags and populates the config struct. // parseFlags parses the command line flags and populates the config struct.
func parseFlags(cfg *cmdConfig) { func (opts *cmdOptions) parseFlags() {
flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication") flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication")
flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on") flag.StringVar(&opts.addr, "addr", "", "Address or port to listen on")
flag.Usage = func() { flag.Usage = func() {
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0]) fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0])
@@ -54,10 +55,10 @@ func handleSubcommand() bool {
} }
// loadPublicKeys loads the public keys from the command line flag, environment variable, or key file. // loadPublicKeys loads the public keys from the command line flag, environment variable, or key file.
func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) { func (opts *cmdOptions) loadPublicKeys() ([]ssh.PublicKey, error) {
// Try command line flag first // Try command line flag first
if cfg.key != "" { if opts.key != "" {
return agent.ParseKeys(cfg.key) return agent.ParseKeys(opts.key)
} }
// Try environment variable // Try environment variable
@@ -68,7 +69,7 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
// Try key file // Try key file
keyFile, ok := agent.GetEnv("KEY_FILE") keyFile, ok := agent.GetEnv("KEY_FILE")
if !ok { if !ok {
return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ") return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. Use 'beszel-agent help' for usage")
} }
pubKey, err := os.ReadFile(keyFile) pubKey, err := os.ReadFile(keyFile)
@@ -79,10 +80,10 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
} }
// getAddress gets the address to listen on from the command line flag, environment variable, or default value. // getAddress gets the address to listen on from the command line flag, environment variable, or default value.
func getAddress(addr string) string { func (opts *cmdOptions) getAddress() string {
// Try command line flag first // Try command line flag first
if addr != "" { if opts.addr != "" {
return addr return opts.addr
} }
// Try environment variables // Try environment variables
if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" { if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" {
@@ -96,19 +97,19 @@ func getAddress(addr string) string {
} }
// getNetwork returns the network type to use for the server. // getNetwork returns the network type to use for the server.
func getNetwork(addr string) string { func (opts *cmdOptions) getNetwork() string {
if network, _ := agent.GetEnv("NETWORK"); network != "" { if network, _ := agent.GetEnv("NETWORK"); network != "" {
return network return network
} }
if strings.HasPrefix(addr, "/") { if strings.HasPrefix(opts.addr, "/") {
return "unix" return "unix"
} }
return "tcp" return "tcp"
} }
func main() { func main() {
var cfg cmdConfig var opts cmdOptions
parseFlags(&cfg) opts.parseFlags()
if handleSubcommand() { if handleSubcommand() {
return return
@@ -116,15 +117,15 @@ func main() {
flag.Parse() flag.Parse()
var serverConfig agent.ServerConfig var serverConfig agent.ServerOptions
var err error var err error
serverConfig.Keys, err = loadPublicKeys(cfg) serverConfig.Keys, err = opts.loadPublicKeys()
if err != nil { if err != nil {
log.Fatal("Failed to load public keys:", err) log.Fatal("Failed to load public keys:", err)
} }
serverConfig.Addr = getAddress(cfg.addr) serverConfig.Addr = opts.getAddress()
serverConfig.Network = getNetwork(cfg.addr) serverConfig.Network = opts.getNetwork()
agent := agent.NewAgent() agent := agent.NewAgent()
if err := agent.StartServer(serverConfig); err != nil { if err := agent.StartServer(serverConfig); err != nil {

View File

@@ -15,32 +15,32 @@ import (
func TestGetAddress(t *testing.T) { func TestGetAddress(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cfg cmdConfig opts cmdOptions
envVars map[string]string envVars map[string]string
expected string expected string
}{ }{
{ {
name: "default port when no config", name: "default port when no config",
cfg: cmdConfig{}, opts: cmdOptions{},
expected: ":45876", expected: ":45876",
}, },
{ {
name: "use address from flag", name: "use address from flag",
cfg: cmdConfig{ opts: cmdOptions{
addr: "8080", addr: "8080",
}, },
expected: "8080", expected: "8080",
}, },
{ {
name: "use unix socket from flag", name: "use unix socket from flag",
cfg: cmdConfig{ opts: cmdOptions{
addr: "/tmp/beszel.sock", addr: "/tmp/beszel.sock",
}, },
expected: "/tmp/beszel.sock", expected: "/tmp/beszel.sock",
}, },
{ {
name: "use ADDR env var", name: "use ADDR env var",
cfg: cmdConfig{}, opts: cmdOptions{},
envVars: map[string]string{ envVars: map[string]string{
"ADDR": "1.2.3.4:9090", "ADDR": "1.2.3.4:9090",
}, },
@@ -48,7 +48,7 @@ func TestGetAddress(t *testing.T) {
}, },
{ {
name: "use legacy PORT env var", name: "use legacy PORT env var",
cfg: cmdConfig{}, opts: cmdOptions{},
envVars: map[string]string{ envVars: map[string]string{
"PORT": "7070", "PORT": "7070",
}, },
@@ -56,7 +56,7 @@ func TestGetAddress(t *testing.T) {
}, },
{ {
name: "flag takes precedence over env vars", name: "flag takes precedence over env vars",
cfg: cmdConfig{ opts: cmdOptions{
addr: ":8080", addr: ":8080",
}, },
envVars: map[string]string{ envVars: map[string]string{
@@ -74,7 +74,7 @@ func TestGetAddress(t *testing.T) {
t.Setenv(k, v) t.Setenv(k, v)
} }
addr := getAddress(tt.cfg.addr) addr := tt.opts.getAddress()
assert.Equal(t, tt.expected, addr) assert.Equal(t, tt.expected, addr)
}) })
} }
@@ -90,7 +90,7 @@ func TestLoadPublicKeys(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cfg cmdConfig opts cmdOptions
envVars map[string]string envVars map[string]string
setupFiles map[string][]byte setupFiles map[string][]byte
wantErr bool wantErr bool
@@ -98,7 +98,7 @@ func TestLoadPublicKeys(t *testing.T) {
}{ }{
{ {
name: "load key from flag", name: "load key from flag",
cfg: cmdConfig{ opts: cmdOptions{
key: string(pubKey), key: string(pubKey),
}, },
}, },
@@ -132,7 +132,7 @@ func TestLoadPublicKeys(t *testing.T) {
}, },
{ {
name: "error on invalid key data", name: "error on invalid key data",
cfg: cmdConfig{ opts: cmdOptions{
key: "invalid-key-data", key: "invalid-key-data",
}, },
wantErr: true, wantErr: true,
@@ -159,7 +159,7 @@ func TestLoadPublicKeys(t *testing.T) {
t.Setenv(k, v) t.Setenv(k, v)
} }
keys, err := loadPublicKeys(tt.cfg) keys, err := tt.opts.loadPublicKeys()
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
if tt.errContains != "" { if tt.errContains != "" {
@@ -178,33 +178,40 @@ func TestLoadPublicKeys(t *testing.T) {
func TestGetNetwork(t *testing.T) { func TestGetNetwork(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
addr string opts cmdOptions
envVars map[string]string envVars map[string]string
expected string expected string
}{ }{
{
name: "NETWORK env var",
envVars: map[string]string{
"NETWORK": "tcp4",
},
expected: "tcp4",
},
{ {
name: "only port", name: "only port",
addr: "8080", opts: cmdOptions{addr: "8080"},
expected: "tcp", expected: "tcp",
}, },
{ {
name: "ipv4 address", name: "ipv4 address",
addr: "1.2.3.4:8080", opts: cmdOptions{addr: "1.2.3.4:8080"},
expected: "tcp", expected: "tcp",
}, },
{ {
name: "ipv6 address", name: "ipv6 address",
addr: "[2001:db8::1]:8080", opts: cmdOptions{addr: "[2001:db8::1]:8080"},
expected: "tcp", expected: "tcp",
}, },
{ {
name: "unix network", name: "unix network",
addr: "/tmp/beszel.sock", opts: cmdOptions{addr: "/tmp/beszel.sock"},
expected: "unix", expected: "unix",
}, },
{ {
name: "env var network", name: "env var network",
addr: ":8080", opts: cmdOptions{addr: ":8080"},
envVars: map[string]string{"NETWORK": "tcp4"}, envVars: map[string]string{"NETWORK": "tcp4"},
expected: "tcp4", expected: "tcp4",
}, },
@@ -216,7 +223,7 @@ func TestGetNetwork(t *testing.T) {
for k, v := range tt.envVars { for k, v := range tt.envVars {
t.Setenv(k, v) t.Setenv(k, v)
} }
network := getNetwork(tt.addr) network := tt.opts.getNetwork()
assert.Equal(t, tt.expected, network) assert.Equal(t, tt.expected, network)
}) })
} }
@@ -233,12 +240,12 @@ func TestParseFlags(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args []string args []string
expected cmdConfig expected cmdOptions
}{ }{
{ {
name: "no flags", name: "no flags",
args: []string{"cmd"}, args: []string{"cmd"},
expected: cmdConfig{ expected: cmdOptions{
key: "", key: "",
addr: "", addr: "",
}, },
@@ -246,7 +253,7 @@ func TestParseFlags(t *testing.T) {
{ {
name: "key flag only", name: "key flag only",
args: []string{"cmd", "-key", "testkey"}, args: []string{"cmd", "-key", "testkey"},
expected: cmdConfig{ expected: cmdOptions{
key: "testkey", key: "testkey",
addr: "", addr: "",
}, },
@@ -254,7 +261,7 @@ func TestParseFlags(t *testing.T) {
{ {
name: "addr flag only", name: "addr flag only",
args: []string{"cmd", "-addr", ":8080"}, args: []string{"cmd", "-addr", ":8080"},
expected: cmdConfig{ expected: cmdOptions{
key: "", key: "",
addr: ":8080", addr: ":8080",
}, },
@@ -262,7 +269,7 @@ func TestParseFlags(t *testing.T) {
{ {
name: "both flags", name: "both flags",
args: []string{"cmd", "-key", "testkey", "-addr", ":8080"}, args: []string{"cmd", "-key", "testkey", "-addr", ":8080"},
expected: cmdConfig{ expected: cmdOptions{
key: "testkey", key: "testkey",
addr: ":8080", addr: ":8080",
}, },
@@ -275,11 +282,11 @@ func TestParseFlags(t *testing.T) {
flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError)
os.Args = tt.args os.Args = tt.args
var cfg cmdConfig var opts cmdOptions
parseFlags(&cfg) opts.parseFlags()
flag.Parse() flag.Parse()
assert.Equal(t, tt.expected, cfg) assert.Equal(t, tt.expected, opts)
}) })
} }
} }

View File

@@ -12,41 +12,41 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type ServerConfig struct { type ServerOptions struct {
Addr string Addr string
Network string Network string
Keys []ssh.PublicKey Keys []ssh.PublicKey
} }
func (a *Agent) StartServer(cfg ServerConfig) error { func (a *Agent) StartServer(opts ServerOptions) error {
sshServer.Handle(a.handleSession) sshServer.Handle(a.handleSession)
slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network) slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
switch cfg.Network { switch opts.Network {
case "unix": case "unix":
// remove existing socket file if it exists // remove existing socket file if it exists
if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) { if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) {
return err return err
} }
default: default:
// prefix with : if only port was provided // prefix with : if only port was provided
if !strings.Contains(cfg.Addr, ":") { if !strings.Contains(opts.Addr, ":") {
cfg.Addr = ":" + cfg.Addr opts.Addr = ":" + opts.Addr
} }
} }
// Listen on the address // Listen on the address
ln, err := net.Listen(cfg.Network, cfg.Addr) ln, err := net.Listen(opts.Network, opts.Addr)
if err != nil { if err != nil {
return err return err
} }
defer ln.Close() defer ln.Close()
// Start server on the listener // Start SSH server on the listener
err = sshServer.Serve(ln, nil, sshServer.NoPty(), err = sshServer.Serve(ln, nil, sshServer.NoPty(),
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
for _, pubKey := range cfg.Keys { for _, pubKey := range opts.Keys {
if sshServer.KeysEqual(key, pubKey) { if sshServer.KeysEqual(key, pubKey) {
return true return true
} }

View File

@@ -35,7 +35,7 @@ func TestStartServer(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
config ServerConfig config ServerOptions
wantErr bool wantErr bool
errContains string errContains string
setup func() error setup func() error
@@ -43,7 +43,7 @@ func TestStartServer(t *testing.T) {
}{ }{
{ {
name: "tcp port only", name: "tcp port only",
config: ServerConfig{ config: ServerOptions{
Network: "tcp", Network: "tcp",
Addr: "45987", Addr: "45987",
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
@@ -51,7 +51,7 @@ func TestStartServer(t *testing.T) {
}, },
{ {
name: "tcp with ipv4", name: "tcp with ipv4",
config: ServerConfig{ config: ServerOptions{
Network: "tcp4", Network: "tcp4",
Addr: "127.0.0.1:45988", Addr: "127.0.0.1:45988",
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
@@ -59,7 +59,7 @@ func TestStartServer(t *testing.T) {
}, },
{ {
name: "tcp with ipv6", name: "tcp with ipv6",
config: ServerConfig{ config: ServerOptions{
Network: "tcp6", Network: "tcp6",
Addr: "[::1]:45989", Addr: "[::1]:45989",
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
@@ -67,7 +67,7 @@ func TestStartServer(t *testing.T) {
}, },
{ {
name: "unix socket", name: "unix socket",
config: ServerConfig{ config: ServerOptions{
Network: "unix", Network: "unix",
Addr: socketFile, Addr: socketFile,
Keys: []ssh.PublicKey{sshPubKey}, Keys: []ssh.PublicKey{sshPubKey},
@@ -86,7 +86,7 @@ func TestStartServer(t *testing.T) {
}, },
{ {
name: "bad key should fail", name: "bad key should fail",
config: ServerConfig{ config: ServerOptions{
Network: "tcp", Network: "tcp",
Addr: "45987", Addr: "45987",
Keys: []ssh.PublicKey{sshBadPubKey}, Keys: []ssh.PublicKey{sshBadPubKey},
@@ -94,6 +94,14 @@ func TestStartServer(t *testing.T) {
wantErr: true, wantErr: true,
errContains: "ssh: handshake failed", errContains: "ssh: handshake failed",
}, },
{
name: "good key still good",
config: ServerOptions{
Network: "tcp",
Addr: "45987",
Keys: []ssh.PublicKey{sshPubKey},
},
},
} }
for _, tt := range tests { for _, tt := range tests {