mirror of
https://github.com/fankes/beszel.git
synced 2025-10-19 09:49:28 +08:00
refactor(agent): refactor option parsing logic for agent command
This commit is contained in:
@@ -12,15 +12,16 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type cmdConfig struct {
|
||||
// cli options
|
||||
type cmdOptions struct {
|
||||
key string // key is the public key(s) for SSH authentication.
|
||||
addr string // addr is the address or port to listen on.
|
||||
}
|
||||
|
||||
// parseFlags parses the command line flags and populates the config struct.
|
||||
func parseFlags(cfg *cmdConfig) {
|
||||
flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication")
|
||||
flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on")
|
||||
func (opts *cmdOptions) parseFlags() {
|
||||
flag.StringVar(&opts.key, "key", "", "Public key(s) for SSH authentication")
|
||||
flag.StringVar(&opts.addr, "addr", "", "Address or port to listen on")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0])
|
||||
@@ -54,10 +55,10 @@ func handleSubcommand() bool {
|
||||
}
|
||||
|
||||
// loadPublicKeys loads the public keys from the command line flag, environment variable, or key file.
|
||||
func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
|
||||
func (opts *cmdOptions) loadPublicKeys() ([]ssh.PublicKey, error) {
|
||||
// Try command line flag first
|
||||
if cfg.key != "" {
|
||||
return agent.ParseKeys(cfg.key)
|
||||
if opts.key != "" {
|
||||
return agent.ParseKeys(opts.key)
|
||||
}
|
||||
|
||||
// Try environment variable
|
||||
@@ -68,7 +69,7 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
|
||||
// Try key file
|
||||
keyFile, ok := agent.GetEnv("KEY_FILE")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ")
|
||||
return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. Use 'beszel-agent help' for usage")
|
||||
}
|
||||
|
||||
pubKey, err := os.ReadFile(keyFile)
|
||||
@@ -79,10 +80,10 @@ func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) {
|
||||
}
|
||||
|
||||
// getAddress gets the address to listen on from the command line flag, environment variable, or default value.
|
||||
func getAddress(addr string) string {
|
||||
func (opts *cmdOptions) getAddress() string {
|
||||
// Try command line flag first
|
||||
if addr != "" {
|
||||
return addr
|
||||
if opts.addr != "" {
|
||||
return opts.addr
|
||||
}
|
||||
// Try environment variables
|
||||
if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" {
|
||||
@@ -96,19 +97,19 @@ func getAddress(addr string) string {
|
||||
}
|
||||
|
||||
// getNetwork returns the network type to use for the server.
|
||||
func getNetwork(addr string) string {
|
||||
func (opts *cmdOptions) getNetwork() string {
|
||||
if network, _ := agent.GetEnv("NETWORK"); network != "" {
|
||||
return network
|
||||
}
|
||||
if strings.HasPrefix(addr, "/") {
|
||||
if strings.HasPrefix(opts.addr, "/") {
|
||||
return "unix"
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func main() {
|
||||
var cfg cmdConfig
|
||||
parseFlags(&cfg)
|
||||
var opts cmdOptions
|
||||
opts.parseFlags()
|
||||
|
||||
if handleSubcommand() {
|
||||
return
|
||||
@@ -116,15 +117,15 @@ func main() {
|
||||
|
||||
flag.Parse()
|
||||
|
||||
var serverConfig agent.ServerConfig
|
||||
var serverConfig agent.ServerOptions
|
||||
var err error
|
||||
serverConfig.Keys, err = loadPublicKeys(cfg)
|
||||
serverConfig.Keys, err = opts.loadPublicKeys()
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load public keys:", err)
|
||||
}
|
||||
|
||||
serverConfig.Addr = getAddress(cfg.addr)
|
||||
serverConfig.Network = getNetwork(cfg.addr)
|
||||
serverConfig.Addr = opts.getAddress()
|
||||
serverConfig.Network = opts.getNetwork()
|
||||
|
||||
agent := agent.NewAgent()
|
||||
if err := agent.StartServer(serverConfig); err != nil {
|
||||
|
@@ -15,32 +15,32 @@ import (
|
||||
func TestGetAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg cmdConfig
|
||||
opts cmdOptions
|
||||
envVars map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "default port when no config",
|
||||
cfg: cmdConfig{},
|
||||
opts: cmdOptions{},
|
||||
expected: ":45876",
|
||||
},
|
||||
{
|
||||
name: "use address from flag",
|
||||
cfg: cmdConfig{
|
||||
opts: cmdOptions{
|
||||
addr: "8080",
|
||||
},
|
||||
expected: "8080",
|
||||
},
|
||||
{
|
||||
name: "use unix socket from flag",
|
||||
cfg: cmdConfig{
|
||||
opts: cmdOptions{
|
||||
addr: "/tmp/beszel.sock",
|
||||
},
|
||||
expected: "/tmp/beszel.sock",
|
||||
},
|
||||
{
|
||||
name: "use ADDR env var",
|
||||
cfg: cmdConfig{},
|
||||
opts: cmdOptions{},
|
||||
envVars: map[string]string{
|
||||
"ADDR": "1.2.3.4:9090",
|
||||
},
|
||||
@@ -48,7 +48,7 @@ func TestGetAddress(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "use legacy PORT env var",
|
||||
cfg: cmdConfig{},
|
||||
opts: cmdOptions{},
|
||||
envVars: map[string]string{
|
||||
"PORT": "7070",
|
||||
},
|
||||
@@ -56,7 +56,7 @@ func TestGetAddress(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "flag takes precedence over env vars",
|
||||
cfg: cmdConfig{
|
||||
opts: cmdOptions{
|
||||
addr: ":8080",
|
||||
},
|
||||
envVars: map[string]string{
|
||||
@@ -74,7 +74,7 @@ func TestGetAddress(t *testing.T) {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
addr := getAddress(tt.cfg.addr)
|
||||
addr := tt.opts.getAddress()
|
||||
assert.Equal(t, tt.expected, addr)
|
||||
})
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func TestLoadPublicKeys(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg cmdConfig
|
||||
opts cmdOptions
|
||||
envVars map[string]string
|
||||
setupFiles map[string][]byte
|
||||
wantErr bool
|
||||
@@ -98,7 +98,7 @@ func TestLoadPublicKeys(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "load key from flag",
|
||||
cfg: cmdConfig{
|
||||
opts: cmdOptions{
|
||||
key: string(pubKey),
|
||||
},
|
||||
},
|
||||
@@ -132,7 +132,7 @@ func TestLoadPublicKeys(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error on invalid key data",
|
||||
cfg: cmdConfig{
|
||||
opts: cmdOptions{
|
||||
key: "invalid-key-data",
|
||||
},
|
||||
wantErr: true,
|
||||
@@ -159,7 +159,7 @@ func TestLoadPublicKeys(t *testing.T) {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
keys, err := loadPublicKeys(tt.cfg)
|
||||
keys, err := tt.opts.loadPublicKeys()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
@@ -178,33 +178,40 @@ func TestLoadPublicKeys(t *testing.T) {
|
||||
func TestGetNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
opts cmdOptions
|
||||
envVars map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "NETWORK env var",
|
||||
envVars: map[string]string{
|
||||
"NETWORK": "tcp4",
|
||||
},
|
||||
expected: "tcp4",
|
||||
},
|
||||
{
|
||||
name: "only port",
|
||||
addr: "8080",
|
||||
opts: cmdOptions{addr: "8080"},
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "ipv4 address",
|
||||
addr: "1.2.3.4:8080",
|
||||
opts: cmdOptions{addr: "1.2.3.4:8080"},
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "ipv6 address",
|
||||
addr: "[2001:db8::1]:8080",
|
||||
opts: cmdOptions{addr: "[2001:db8::1]:8080"},
|
||||
expected: "tcp",
|
||||
},
|
||||
{
|
||||
name: "unix network",
|
||||
addr: "/tmp/beszel.sock",
|
||||
opts: cmdOptions{addr: "/tmp/beszel.sock"},
|
||||
expected: "unix",
|
||||
},
|
||||
{
|
||||
name: "env var network",
|
||||
addr: ":8080",
|
||||
opts: cmdOptions{addr: ":8080"},
|
||||
envVars: map[string]string{"NETWORK": "tcp4"},
|
||||
expected: "tcp4",
|
||||
},
|
||||
@@ -216,7 +223,7 @@ func TestGetNetwork(t *testing.T) {
|
||||
for k, v := range tt.envVars {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
network := getNetwork(tt.addr)
|
||||
network := tt.opts.getNetwork()
|
||||
assert.Equal(t, tt.expected, network)
|
||||
})
|
||||
}
|
||||
@@ -233,12 +240,12 @@ func TestParseFlags(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected cmdConfig
|
||||
expected cmdOptions
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
args: []string{"cmd"},
|
||||
expected: cmdConfig{
|
||||
expected: cmdOptions{
|
||||
key: "",
|
||||
addr: "",
|
||||
},
|
||||
@@ -246,7 +253,7 @@ func TestParseFlags(t *testing.T) {
|
||||
{
|
||||
name: "key flag only",
|
||||
args: []string{"cmd", "-key", "testkey"},
|
||||
expected: cmdConfig{
|
||||
expected: cmdOptions{
|
||||
key: "testkey",
|
||||
addr: "",
|
||||
},
|
||||
@@ -254,7 +261,7 @@ func TestParseFlags(t *testing.T) {
|
||||
{
|
||||
name: "addr flag only",
|
||||
args: []string{"cmd", "-addr", ":8080"},
|
||||
expected: cmdConfig{
|
||||
expected: cmdOptions{
|
||||
key: "",
|
||||
addr: ":8080",
|
||||
},
|
||||
@@ -262,7 +269,7 @@ func TestParseFlags(t *testing.T) {
|
||||
{
|
||||
name: "both flags",
|
||||
args: []string{"cmd", "-key", "testkey", "-addr", ":8080"},
|
||||
expected: cmdConfig{
|
||||
expected: cmdOptions{
|
||||
key: "testkey",
|
||||
addr: ":8080",
|
||||
},
|
||||
@@ -275,11 +282,11 @@ func TestParseFlags(t *testing.T) {
|
||||
flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError)
|
||||
os.Args = tt.args
|
||||
|
||||
var cfg cmdConfig
|
||||
parseFlags(&cfg)
|
||||
var opts cmdOptions
|
||||
opts.parseFlags()
|
||||
flag.Parse()
|
||||
|
||||
assert.Equal(t, tt.expected, cfg)
|
||||
assert.Equal(t, tt.expected, opts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -12,41 +12,41 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
type ServerOptions struct {
|
||||
Addr string
|
||||
Network string
|
||||
Keys []ssh.PublicKey
|
||||
}
|
||||
|
||||
func (a *Agent) StartServer(cfg ServerConfig) error {
|
||||
func (a *Agent) StartServer(opts ServerOptions) error {
|
||||
sshServer.Handle(a.handleSession)
|
||||
|
||||
slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network)
|
||||
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
|
||||
|
||||
switch cfg.Network {
|
||||
switch opts.Network {
|
||||
case "unix":
|
||||
// remove existing socket file if it exists
|
||||
if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(opts.Addr); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// prefix with : if only port was provided
|
||||
if !strings.Contains(cfg.Addr, ":") {
|
||||
cfg.Addr = ":" + cfg.Addr
|
||||
if !strings.Contains(opts.Addr, ":") {
|
||||
opts.Addr = ":" + opts.Addr
|
||||
}
|
||||
}
|
||||
|
||||
// Listen on the address
|
||||
ln, err := net.Listen(cfg.Network, cfg.Addr)
|
||||
ln, err := net.Listen(opts.Network, opts.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Start server on the listener
|
||||
// Start SSH server on the listener
|
||||
err = sshServer.Serve(ln, nil, sshServer.NoPty(),
|
||||
sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool {
|
||||
for _, pubKey := range cfg.Keys {
|
||||
for _, pubKey := range opts.Keys {
|
||||
if sshServer.KeysEqual(key, pubKey) {
|
||||
return true
|
||||
}
|
||||
|
@@ -35,7 +35,7 @@ func TestStartServer(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config ServerConfig
|
||||
config ServerOptions
|
||||
wantErr bool
|
||||
errContains string
|
||||
setup func() error
|
||||
@@ -43,7 +43,7 @@ func TestStartServer(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "tcp port only",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -51,7 +51,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv4",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp4",
|
||||
Addr: "127.0.0.1:45988",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -59,7 +59,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tcp with ipv6",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp6",
|
||||
Addr: "[::1]:45989",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -67,7 +67,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "unix",
|
||||
Addr: socketFile,
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
@@ -86,7 +86,7 @@ func TestStartServer(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bad key should fail",
|
||||
config: ServerConfig{
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshBadPubKey},
|
||||
@@ -94,6 +94,14 @@ func TestStartServer(t *testing.T) {
|
||||
wantErr: true,
|
||||
errContains: "ssh: handshake failed",
|
||||
},
|
||||
{
|
||||
name: "good key still good",
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: "45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
Reference in New Issue
Block a user