From d170e7a00d71a8a1bf05487b667940f5f02326b4 Mon Sep 17 00:00:00 2001 From: henrygd Date: Wed, 19 Feb 2025 00:32:27 -0500 Subject: [PATCH] feat(agent): NETWORK env var and support for multiple keys - merges agent.Run with agent.NewAgent - separates StartServer method - bumps go version to 1.24 - add tests --- beszel/cmd/agent/agent.go | 161 +++++++++------ beszel/cmd/agent/agent_test.go | 285 +++++++++++++++++++++++++++ beszel/go.mod | 5 +- beszel/internal/agent/agent.go | 54 +++-- beszel/internal/agent/server.go | 85 ++++++-- beszel/internal/agent/server_test.go | 281 ++++++++++++++++++++++++++ 6 files changed, 771 insertions(+), 100 deletions(-) create mode 100644 beszel/cmd/agent/agent_test.go create mode 100644 beszel/internal/agent/server_test.go diff --git a/beszel/cmd/agent/agent.go b/beszel/cmd/agent/agent.go index 9ad4834..f789b54 100644 --- a/beszel/cmd/agent/agent.go +++ b/beszel/cmd/agent/agent.go @@ -8,12 +8,19 @@ import ( "log" "os" "strings" + + "golang.org/x/crypto/ssh" ) -func main() { - // Define flags for key and port - keyFlag := flag.String("key", "", "Public key") - portFlag := flag.String("port", "45876", "Port number") +type cmdConfig struct { + key string // key is the public key(s) for SSH authentication. + addr string // addr is the address or port to listen on. +} + +// parseFlags parses the command line flags and populates the config struct. +func parseFlags(cfg *cmdConfig) { + flag.StringVar(&cfg.key, "key", "", "Public key(s) for SSH authentication") + flag.StringVar(&cfg.addr, "addr", "", "Address or port to listen on") flag.Usage = func() { fmt.Printf("Usage: %s [options] [subcommand]\n", os.Args[0]) @@ -24,65 +31,103 @@ func main() { fmt.Println(" help Display this help message") fmt.Println(" update Update the agent to the latest version") } +} + +// handleSubcommand handles subcommands such as version, help, and update. +// It returns true if a subcommand was handled, false otherwise. +func handleSubcommand() bool { + if len(os.Args) <= 1 { + return false + } + switch os.Args[1] { + case "version", "-v": + fmt.Println(beszel.AppName+"-agent", beszel.Version) + os.Exit(0) + case "help": + flag.Usage() + os.Exit(0) + case "update": + agent.Update() + os.Exit(0) + } + return false +} + +// loadPublicKeys loads the public keys from the command line flag, environment variable, or key file. +func loadPublicKeys(cfg cmdConfig) ([]ssh.PublicKey, error) { + // Try command line flag first + if cfg.key != "" { + return agent.ParseKeys(cfg.key) + } + + // Try environment variable + if key, ok := agent.GetEnv("KEY"); ok && key != "" { + return agent.ParseKeys(key) + } + + // Try key file + keyFile, ok := agent.GetEnv("KEY_FILE") + if !ok { + return nil, fmt.Errorf("no key provided: must set -key flag, KEY env var, or KEY_FILE env var. ") + } + + pubKey, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("failed to read key file: %w", err) + } + return agent.ParseKeys(string(pubKey)) +} + +// getAddress gets the address to listen on from the command line flag, environment variable, or default value. +func getAddress(addr string) string { + // Try command line flag first + if addr != "" { + return addr + } + // Try environment variables + if addr, ok := agent.GetEnv("ADDR"); ok && addr != "" { + return addr + } + // Legacy PORT environment variable support + if port, ok := agent.GetEnv("PORT"); ok && port != "" { + return port + } + return ":45876" +} + +// getNetwork returns the network type to use for the server. +func getNetwork(addr string) string { + if network, _ := agent.GetEnv("NETWORK"); network != "" { + return network + } + if strings.HasPrefix(addr, "/") { + return "unix" + } + return "tcp" +} + +func main() { + var cfg cmdConfig + parseFlags(&cfg) + + if handleSubcommand() { + return + } - // Parse the flags flag.Parse() - // handle flags / subcommands - if len(os.Args) > 1 { - switch os.Args[1] { - case "version": - fmt.Println(beszel.AppName+"-agent", beszel.Version) - os.Exit(0) - case "help": - flag.Usage() - os.Exit(0) - case "update": - agent.Update() - os.Exit(0) - } + var serverConfig agent.ServerConfig + var err error + serverConfig.Keys, err = loadPublicKeys(cfg) + if err != nil { + log.Fatal("Failed to load public keys:", err) } - var pubKey []byte - // Override the key if the -key flag is provided - if *keyFlag != "" { - pubKey = []byte(*keyFlag) - } else { - // Try to get the key from the KEY environment variable. - key, _ := agent.GetEnv("KEY") - pubKey = []byte(key) - } + serverConfig.Addr = getAddress(cfg.addr) + serverConfig.Network = getNetwork(cfg.addr) - // If KEY is not set, try to read the key from the file specified by KEY_FILE. - if len(pubKey) == 0 { - keyFile, exists := agent.GetEnv("KEY_FILE") - if !exists { - log.Fatal("Must set KEY or KEY_FILE environment variable or supply as input argument. Use 'beszel-agent help' for more information.") - } - var err error - pubKey, err = os.ReadFile(keyFile) - if err != nil { - log.Fatal(err) - } + agent := agent.NewAgent() + if err := agent.StartServer(serverConfig); err != nil { + log.Fatal("Failed to start server:", err) } - - // Init with default port - addr := ":" + *portFlag - - //Use port from ENV if it exists - // TODO: change env var to ADDR - if portEnvVar, exists := agent.GetEnv("PORT"); exists { - // allow passing an address in the form of "127.0.0.1:45876" - if !strings.Contains(portEnvVar, ":") { - portEnvVar = ":" + portEnvVar - } - addr = portEnvVar - } - - // Override the default and ENV port if the -port flag is provided and is non default - if *portFlag != "45876" { - addr = ":" + *portFlag - } - - agent.NewAgent().Run(pubKey, addr) } diff --git a/beszel/cmd/agent/agent_test.go b/beszel/cmd/agent/agent_test.go new file mode 100644 index 0000000..94582f8 --- /dev/null +++ b/beszel/cmd/agent/agent_test.go @@ -0,0 +1,285 @@ +package main + +import ( + "crypto/ed25519" + "flag" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestGetAddress(t *testing.T) { + tests := []struct { + name string + cfg cmdConfig + envVars map[string]string + expected string + }{ + { + name: "default port when no config", + cfg: cmdConfig{}, + expected: ":45876", + }, + { + name: "use address from flag", + cfg: cmdConfig{ + addr: "8080", + }, + expected: "8080", + }, + { + name: "use unix socket from flag", + cfg: cmdConfig{ + addr: "/tmp/beszel.sock", + }, + expected: "/tmp/beszel.sock", + }, + { + name: "use ADDR env var", + cfg: cmdConfig{}, + envVars: map[string]string{ + "ADDR": "1.2.3.4:9090", + }, + expected: "1.2.3.4:9090", + }, + { + name: "use legacy PORT env var", + cfg: cmdConfig{}, + envVars: map[string]string{ + "PORT": "7070", + }, + expected: "7070", + }, + { + name: "flag takes precedence over env vars", + cfg: cmdConfig{ + addr: ":8080", + }, + envVars: map[string]string{ + "ADDR": ":9090", + "PORT": "7070", + }, + expected: ":8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup environment + for k, v := range tt.envVars { + t.Setenv(k, v) + } + + addr := getAddress(tt.cfg.addr) + assert.Equal(t, tt.expected, addr) + }) + } +} + +func TestLoadPublicKeys(t *testing.T) { + // Generate a test key + _, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey()) + + tests := []struct { + name string + cfg cmdConfig + envVars map[string]string + setupFiles map[string][]byte + wantErr bool + errContains string + }{ + { + name: "load key from flag", + cfg: cmdConfig{ + key: string(pubKey), + }, + }, + { + name: "load key from env var", + envVars: map[string]string{ + "KEY": string(pubKey), + }, + }, + { + name: "load key from file", + envVars: map[string]string{ + "KEY_FILE": "testkey.pub", + }, + setupFiles: map[string][]byte{ + "testkey.pub": pubKey, + }, + }, + { + name: "error when no key provided", + wantErr: true, + errContains: "no key provided", + }, + { + name: "error on invalid key file", + envVars: map[string]string{ + "KEY_FILE": "nonexistent.pub", + }, + wantErr: true, + errContains: "failed to read key file", + }, + { + name: "error on invalid key data", + cfg: cmdConfig{ + key: "invalid-key-data", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary directory for test files + if len(tt.setupFiles) > 0 { + tmpDir := t.TempDir() + for name, content := range tt.setupFiles { + path := filepath.Join(tmpDir, name) + err := os.WriteFile(path, content, 0600) + require.NoError(t, err) + if tt.envVars != nil { + tt.envVars["KEY_FILE"] = path + } + } + } + + // Set up environment + for k, v := range tt.envVars { + t.Setenv(k, v) + } + + keys, err := loadPublicKeys(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Len(t, keys, 1) + assert.Equal(t, signer.PublicKey().Type(), keys[0].Type()) + }) + } +} + +func TestGetNetwork(t *testing.T) { + tests := []struct { + name string + addr string + envVars map[string]string + expected string + }{ + { + name: "only port", + addr: "8080", + expected: "tcp", + }, + { + name: "ipv4 address", + addr: "1.2.3.4:8080", + expected: "tcp", + }, + { + name: "ipv6 address", + addr: "[2001:db8::1]:8080", + expected: "tcp", + }, + { + name: "unix network", + addr: "/tmp/beszel.sock", + expected: "unix", + }, + { + name: "env var network", + addr: ":8080", + envVars: map[string]string{"NETWORK": "tcp4"}, + expected: "tcp4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup environment + for k, v := range tt.envVars { + t.Setenv(k, v) + } + network := getNetwork(tt.addr) + assert.Equal(t, tt.expected, network) + }) + } +} + +func TestParseFlags(t *testing.T) { + // Save original command line arguments and restore after test + oldArgs := os.Args + defer func() { + os.Args = oldArgs + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + }() + + tests := []struct { + name string + args []string + expected cmdConfig + }{ + { + name: "no flags", + args: []string{"cmd"}, + expected: cmdConfig{ + key: "", + addr: "", + }, + }, + { + name: "key flag only", + args: []string{"cmd", "-key", "testkey"}, + expected: cmdConfig{ + key: "testkey", + addr: "", + }, + }, + { + name: "addr flag only", + args: []string{"cmd", "-addr", ":8080"}, + expected: cmdConfig{ + key: "", + addr: ":8080", + }, + }, + { + name: "both flags", + args: []string{"cmd", "-key", "testkey", "-addr", ":8080"}, + expected: cmdConfig{ + key: "testkey", + addr: ":8080", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset flags for each test + flag.CommandLine = flag.NewFlagSet(tt.args[0], flag.ExitOnError) + os.Args = tt.args + + var cfg cmdConfig + parseFlags(&cfg) + flag.Parse() + + assert.Equal(t, tt.expected, cfg) + }) + } +} diff --git a/beszel/go.mod b/beszel/go.mod index dc8c56f..26595f7 100644 --- a/beszel/go.mod +++ b/beszel/go.mod @@ -1,8 +1,6 @@ module beszel -go 1.23 - -toolchain go1.23.2 +go 1.24 require ( github.com/blang/semver v3.5.1+incompatible @@ -15,6 +13,7 @@ require ( github.com/shirou/gopsutil/v4 v4.25.1 github.com/spf13/cast v1.7.1 github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c gopkg.in/yaml.v3 v3.0.1 diff --git a/beszel/internal/agent/agent.go b/beszel/internal/agent/agent.go index 153d411..2ed693c 100644 --- a/beszel/internal/agent/agent.go +++ b/beszel/internal/agent/agent.go @@ -28,29 +28,16 @@ type Agent struct { } func NewAgent() *Agent { - newAgent := &Agent{ - sensorsContext: context.Background(), - fsStats: make(map[string]*system.FsStats), + agent := &Agent{ + fsStats: make(map[string]*system.FsStats), } - newAgent.memCalc, _ = GetEnv("MEM_CALC") - return newAgent -} + agent.memCalc, _ = GetEnv("MEM_CALC") -// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key. -func GetEnv(key string) (value string, exists bool) { - if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists { - return value, exists - } - // Fallback to the old unprefixed key - return os.LookupEnv(key) -} - -func (a *Agent) Run(pubKey []byte, addr string) { // Set up slog with a log level determined by the LOG_LEVEL env var if logLevelStr, exists := GetEnv("LOG_LEVEL"); exists { switch strings.ToLower(logLevelStr) { case "debug": - a.debug = true + agent.debug = true slog.SetLogLoggerLevel(slog.LevelDebug) case "warn": slog.SetLogLoggerLevel(slog.LevelWarn) @@ -64,40 +51,51 @@ func (a *Agent) Run(pubKey []byte, addr string) { // Set sensors context (allows overriding sys location for sensors) if sysSensors, exists := GetEnv("SYS_SENSORS"); exists { slog.Info("SYS_SENSORS", "path", sysSensors) - a.sensorsContext = context.WithValue(a.sensorsContext, + agent.sensorsContext = context.WithValue(agent.sensorsContext, common.EnvKey, common.EnvMap{common.HostSysEnvKey: sysSensors}, ) + } else { + agent.sensorsContext = context.Background() } // Set sensors whitelist if sensors, exists := GetEnv("SENSORS"); exists { - a.sensorsWhitelist = make(map[string]struct{}) + agent.sensorsWhitelist = make(map[string]struct{}) for _, sensor := range strings.Split(sensors, ",") { if sensor != "" { - a.sensorsWhitelist[sensor] = struct{}{} + agent.sensorsWhitelist[sensor] = struct{}{} } } } // initialize system info / docker manager - a.initializeSystemInfo() - a.initializeDiskInfo() - a.initializeNetIoStats() - a.dockerManager = newDockerManager(a) + agent.initializeSystemInfo() + agent.initializeDiskInfo() + agent.initializeNetIoStats() + agent.dockerManager = newDockerManager(agent) // initialize GPU manager if gm, err := NewGPUManager(); err != nil { slog.Debug("GPU", "err", err) } else { - a.gpuManager = gm + agent.gpuManager = gm } // if debugging, print stats - if a.debug { - slog.Debug("Stats", "data", a.gatherStats()) + if agent.debug { + slog.Debug("Stats", "data", agent.gatherStats()) } - a.startServer(pubKey, addr) + return agent +} + +// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key. +func GetEnv(key string) (value string, exists bool) { + if value, exists = os.LookupEnv("BESZEL_AGENT_" + key); exists { + return value, exists + } + // Fallback to the old unprefixed key + return os.LookupEnv(key) } func (a *Agent) gatherStats() system.CombinedData { diff --git a/beszel/internal/agent/server.go b/beszel/internal/agent/server.go index 67c8a88..1830fc6 100644 --- a/beszel/internal/agent/server.go +++ b/beszel/internal/agent/server.go @@ -2,33 +2,96 @@ package agent import ( "encoding/json" + "fmt" "log/slog" + "net" "os" + "strings" sshServer "github.com/gliderlabs/ssh" + "golang.org/x/crypto/ssh" ) -func (a *Agent) startServer(pubKey []byte, addr string) { +type ServerConfig struct { + Addr string + Network string + Keys []ssh.PublicKey +} + +func (a *Agent) StartServer(cfg ServerConfig) error { sshServer.Handle(a.handleSession) - slog.Info("Starting SSH server", "address", addr) - if err := sshServer.ListenAndServe(addr, nil, sshServer.NoPty(), - sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { - allowed, _, _, _, _ := sshServer.ParseAuthorizedKey(pubKey) - return sshServer.KeysEqual(key, allowed) - }), - ); err != nil { - slog.Error("Error starting SSH server", "err", err) - os.Exit(1) + slog.Info("Starting SSH server", "addr", cfg.Addr, "network", cfg.Network) + + switch cfg.Network { + case "unix": + // remove existing socket file if it exists + if err := os.Remove(cfg.Addr); err != nil && !os.IsNotExist(err) { + return err + } + default: + // prefix with : if only port was provided + if !strings.Contains(cfg.Addr, ":") { + cfg.Addr = ":" + cfg.Addr + } } + + // Listen on the address + ln, err := net.Listen(cfg.Network, cfg.Addr) + if err != nil { + return err + } + defer ln.Close() + + // Start server on the listener + err = sshServer.Serve(ln, nil, sshServer.NoPty(), + sshServer.PublicKeyAuth(func(ctx sshServer.Context, key sshServer.PublicKey) bool { + for _, pubKey := range cfg.Keys { + if sshServer.KeysEqual(key, pubKey) { + return true + } + } + return false + }), + ) + if err != nil { + return err + } + return nil } func (a *Agent) handleSession(s sshServer.Session) { + // slog.Debug("connection", "remoteaddr", s.RemoteAddr(), "user", s.User()) stats := a.gatherStats() if err := json.NewEncoder(s).Encode(stats); err != nil { slog.Error("Error encoding stats", "err", err, "stats", stats) s.Exit(1) - return } s.Exit(0) } + +// ParseKeys parses a string containing SSH public keys in authorized_keys format. +// It returns a slice of ssh.PublicKey and an error if any key fails to parse. +func ParseKeys(input string) ([]ssh.PublicKey, error) { + var parsedKeys []ssh.PublicKey + + for line := range strings.Lines(input) { + line = strings.TrimSpace(line) + + // Skip empty lines or comments + if len(line) == 0 || strings.HasPrefix(line, "#") { + continue + } + + // Parse the key + parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(line)) + if err != nil { + return nil, fmt.Errorf("failed to parse key: %s, error: %w", line, err) + } + + // Append the parsed key to the list + parsedKeys = append(parsedKeys, parsedKey) + } + + return parsedKeys, nil +} diff --git a/beszel/internal/agent/server_test.go b/beszel/internal/agent/server_test.go new file mode 100644 index 0000000..41b3399 --- /dev/null +++ b/beszel/internal/agent/server_test.go @@ -0,0 +1,281 @@ +package agent + +import ( + "crypto/ed25519" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestStartServer(t *testing.T) { + // Generate a test key pair + pubKey, privKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(privKey) + require.NoError(t, err) + sshPubKey, err := ssh.NewPublicKey(pubKey) + require.NoError(t, err) + + // Generate a different key pair for bad key test + badPubKey, badPrivKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + badSigner, err := ssh.NewSignerFromKey(badPrivKey) + require.NoError(t, err) + sshBadPubKey, err := ssh.NewPublicKey(badPubKey) + require.NoError(t, err) + + socketFile := filepath.Join(t.TempDir(), "beszel-test.sock") + + tests := []struct { + name string + config ServerConfig + wantErr bool + errContains string + setup func() error + cleanup func() error + }{ + { + name: "tcp port only", + config: ServerConfig{ + Network: "tcp", + Addr: "45987", + Keys: []ssh.PublicKey{sshPubKey}, + }, + }, + { + name: "tcp with ipv4", + config: ServerConfig{ + Network: "tcp4", + Addr: "127.0.0.1:45988", + Keys: []ssh.PublicKey{sshPubKey}, + }, + }, + { + name: "tcp with ipv6", + config: ServerConfig{ + Network: "tcp6", + Addr: "[::1]:45989", + Keys: []ssh.PublicKey{sshPubKey}, + }, + }, + { + name: "unix socket", + config: ServerConfig{ + Network: "unix", + Addr: socketFile, + Keys: []ssh.PublicKey{sshPubKey}, + }, + setup: func() error { + // Create a socket file that should be removed + f, err := os.Create(socketFile) + if err != nil { + return err + } + return f.Close() + }, + cleanup: func() error { + return os.Remove(socketFile) + }, + }, + { + name: "bad key should fail", + config: ServerConfig{ + Network: "tcp", + Addr: "45987", + Keys: []ssh.PublicKey{sshBadPubKey}, + }, + wantErr: true, + errContains: "ssh: handshake failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + err := tt.setup() + require.NoError(t, err) + } + + if tt.cleanup != nil { + defer tt.cleanup() + } + + agent := NewAgent() + + // Start server in a goroutine since it blocks + errChan := make(chan error, 1) + go func() { + errChan <- agent.StartServer(tt.config) + }() + + // Add a short delay to allow the server to start + time.Sleep(100 * time.Millisecond) + + // Try to connect to verify server is running + var client *ssh.Client + var err error + + // Choose the appropriate signer based on the test case + testSigner := signer + if tt.name == "bad key should fail" { + testSigner = badSigner + } + + sshClientConfig := &ssh.ClientConfig{ + User: "a", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(testSigner), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 4 * time.Second, + } + + switch tt.config.Network { + case "unix": + client, err = ssh.Dial("unix", tt.config.Addr, sshClientConfig) + default: + if !strings.Contains(tt.config.Addr, ":") { + tt.config.Addr = ":" + tt.config.Addr + } + client, err = ssh.Dial("tcp", tt.config.Addr, sshClientConfig) + } + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, client) + client.Close() + }) + } +} + +///////////////////////////////////////////////////////////////// +//////////////////// ParseKeys Tests //////////////////////////// +///////////////////////////////////////////////////////////////// + +// Helper function to generate a temporary file with content +func createTempFile(content string) (string, error) { + tmpFile, err := os.CreateTemp("", "ssh_keys_*.txt") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + defer tmpFile.Close() + + if _, err := tmpFile.WriteString(content); err != nil { + return "", fmt.Errorf("failed to write to temp file: %w", err) + } + + return tmpFile.Name(), nil +} + +// Test case 1: String with a single SSH key +func TestParseSingleKeyFromString(t *testing.T) { + input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo" + keys, err := ParseKeys(input) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if len(keys) != 1 { + t.Fatalf("Expected 1 key, got %d keys", len(keys)) + } + if keys[0].Type() != "ssh-ed25519" { + t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type()) + } +} + +// Test case 2: String with multiple SSH keys +func TestParseMultipleKeysFromString(t *testing.T) { + input := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D" + keys, err := ParseKeys(input) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if len(keys) != 3 { + t.Fatalf("Expected 3 keys, got %d keys", len(keys)) + } + if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" { + t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type()) + } +} + +// Test case 3: File with a single SSH key +func TestParseSingleKeyFromFile(t *testing.T) { + content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo" + filePath, err := createTempFile(content) + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(filePath) // Clean up the file after the test + + // Read the file content + fileContent, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read temp file: %v", err) + } + + // Parse the keys + keys, err := ParseKeys(string(fileContent)) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if len(keys) != 1 { + t.Fatalf("Expected 1 key, got %d keys", len(keys)) + } + if keys[0].Type() != "ssh-ed25519" { + t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keys[0].Type()) + } +} + +// Test case 4: File with multiple SSH keys +func TestParseMultipleKeysFromFile(t *testing.T) { + content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D" + filePath, err := createTempFile(content) + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + // defer os.Remove(filePath) // Clean up the file after the test + + // Read the file content + fileContent, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read temp file: %v", err) + } + + // Parse the keys + keys, err := ParseKeys(string(fileContent)) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if len(keys) != 3 { + t.Fatalf("Expected 3 keys, got %d keys", len(keys)) + } + if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" { + t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type()) + } +} + +// Test case 5: Invalid SSH key input +func TestParseInvalidKey(t *testing.T) { + input := "invalid-key-data" + _, err := ParseKeys(input) + if err == nil { + t.Fatalf("Expected an error for invalid key, got nil") + } + expectedErrMsg := "failed to parse key" + if !strings.Contains(err.Error(), expectedErrMsg) { + t.Fatalf("Expected error message to contain '%s', got: %v", expectedErrMsg, err) + } +}