Files
beszel/agent/server_test.go
2025-09-09 13:35:34 -04:00

607 lines
17 KiB
Go

package agent
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/henrygd/beszel/internal/entities/container"
"github.com/henrygd/beszel/internal/entities/system"
"github.com/blang/semver"
"github.com/fxamacker/cbor/v2"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"
)
// TestStartServer starts the agent's SSH server with several listener
// configurations (tcp, tcp4, tcp6, unix socket) and dials it with an SSH
// client. A client presenting the expected key must complete the handshake;
// the "bad key should fail" case dials with a different key and must be
// rejected with a handshake error.
//
// NOTE(review): no server is stopped inside this test and the last two cases
// reuse ":45987" from the first case — presumably they end up talking to the
// listener started by "tcp port only"; confirm against StartServer.
func TestStartServer(t *testing.T) {
	// Generate a test key pair
	pubKey, privKey, err := ed25519.GenerateKey(nil)
	require.NoError(t, err)
	signer, err := gossh.NewSignerFromKey(privKey)
	require.NoError(t, err)
	sshPubKey, err := gossh.NewPublicKey(pubKey)
	require.NoError(t, err)

	// Generate a different key pair for bad key test
	badPubKey, badPrivKey, err := ed25519.GenerateKey(nil)
	require.NoError(t, err)
	badSigner, err := gossh.NewSignerFromKey(badPrivKey)
	require.NoError(t, err)
	sshBadPubKey, err := gossh.NewPublicKey(badPubKey)
	require.NoError(t, err)

	socketFile := filepath.Join(t.TempDir(), "beszel-test.sock")

	tests := []struct {
		name        string
		config      ServerOptions
		wantErr     bool         // the client dial is expected to fail
		errContains string       // substring the dial error must contain (checked only with wantErr)
		setup       func() error // optional, runs before the server is started
		cleanup     func() error // optional, deferred for the duration of the subtest
	}{
		{
			name: "tcp port only",
			config: ServerOptions{
				Network: "tcp",
				Addr:    ":45987",
				Keys:    []gossh.PublicKey{sshPubKey},
			},
		},
		{
			name: "tcp with ipv4",
			config: ServerOptions{
				Network: "tcp4",
				Addr:    "127.0.0.1:45988",
				Keys:    []gossh.PublicKey{sshPubKey},
			},
		},
		{
			name: "tcp with ipv6",
			config: ServerOptions{
				Network: "tcp6",
				Addr:    "[::1]:45989",
				Keys:    []gossh.PublicKey{sshPubKey},
			},
		},
		{
			name: "unix socket",
			config: ServerOptions{
				Network: "unix",
				Addr:    socketFile,
				Keys:    []gossh.PublicKey{sshPubKey},
			},
			setup: func() error {
				// Create a socket file that should be removed
				f, err := os.Create(socketFile)
				if err != nil {
					return err
				}
				return f.Close()
			},
			cleanup: func() error {
				return os.Remove(socketFile)
			},
		},
		{
			name: "bad key should fail",
			config: ServerOptions{
				Network: "tcp",
				Addr:    ":45987",
				Keys:    []gossh.PublicKey{sshBadPubKey},
			},
			wantErr:     true,
			errContains: "ssh: handshake failed",
		},
		{
			name: "good key still good",
			config: ServerOptions{
				Network: "tcp",
				Addr:    ":45987",
				Keys:    []gossh.PublicKey{sshPubKey},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.setup != nil {
				err := tt.setup()
				require.NoError(t, err)
			}

			if tt.cleanup != nil {
				// Best-effort cleanup; its error is intentionally not checked.
				defer tt.cleanup()
			}

			agent, err := NewAgent("")
			require.NoError(t, err)

			// Start server in a goroutine since it blocks. The channel has
			// capacity 1 so the goroutine can finish even though nothing
			// reads the result.
			errChan := make(chan error, 1)
			go func() {
				errChan <- agent.StartServer(tt.config)
			}()

			// Add a short delay to allow the server to start
			time.Sleep(100 * time.Millisecond)

			// Try to connect to verify server is running
			var client *gossh.Client

			// Choose the appropriate signer based on the test case
			testSigner := signer
			if tt.name == "bad key should fail" {
				testSigner = badSigner
			}

			sshClientConfig := &gossh.ClientConfig{
				User: "a",
				Auth: []gossh.AuthMethod{
					gossh.PublicKeys(testSigner),
				},
				HostKeyCallback: gossh.InsecureIgnoreHostKey(),
				Timeout:         4 * time.Second,
			}

			switch tt.config.Network {
			case "unix":
				client, err = gossh.Dial("unix", tt.config.Addr, sshClientConfig)
			default:
				// A bare port is turned into ":port" so it can be dialed.
				if !strings.Contains(tt.config.Addr, ":") {
					tt.config.Addr = ":" + tt.config.Addr
				}
				client, err = gossh.Dial("tcp", tt.config.Addr, sshClientConfig)
			}

			if tt.wantErr {
				if client != nil {
					// The dial unexpectedly succeeded; release the
					// connection before the test is failed below.
					client.Close()
				}
				// require (not assert) stops the subtest here, so a nil
				// err can never reach err.Error() and panic.
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}

			require.NoError(t, err)
			require.NotNil(t, client)
			defer client.Close()
		})
	}
}
/////////////////////////////////////////////////////////////////
//////////////////// ParseKeys Tests ////////////////////////////
/////////////////////////////////////////////////////////////////
// createTempFile writes content to a new temporary file created with the
// pattern "ssh_keys_*.txt" in the default temp directory and returns the
// file's path. The caller is responsible for removing the file.
//
// If writing or closing the file fails, the file is removed again so that no
// stray temp file is left behind, and an empty path plus the wrapped error is
// returned.
func createTempFile(content string) (string, error) {
	tmpFile, err := os.CreateTemp("", "ssh_keys_*.txt")
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	name := tmpFile.Name()

	if _, err := tmpFile.WriteString(content); err != nil {
		// Best-effort cleanup: the write error is the one worth reporting.
		_ = tmpFile.Close()
		_ = os.Remove(name)
		return "", fmt.Errorf("failed to write to temp file: %w", err)
	}

	// Close explicitly (rather than via defer) so a close failure is
	// reported to the caller instead of being silently dropped.
	if err := tmpFile.Close(); err != nil {
		_ = os.Remove(name)
		return "", fmt.Errorf("failed to close temp file: %w", err)
	}
	return name, nil
}
// TestParseSingleKeyFromString (case 1) feeds ParseKeys one SSH key given
// directly as a string and expects exactly one key of type ssh-ed25519 back.
func TestParseSingleKeyFromString(t *testing.T) {
	const keyLine = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo"

	parsed, err := ParseKeys(keyLine)
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	if count := len(parsed); count != 1 {
		t.Fatalf("Expected 1 key, got %d keys", count)
	}

	if keyType := parsed[0].Type(); keyType != "ssh-ed25519" {
		t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keyType)
	}
}
// TestParseMultipleKeysFromString (case 2) passes ParseKeys a multi-line
// string holding three key lines plus a "#comment" line, with stray leading
// and trailing spaces, and expects three ssh-ed25519 keys.
func TestParseMultipleKeysFromString(t *testing.T) {
	const keyLines = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"

	parsed, err := ParseKeys(keyLines)
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	if count := len(parsed); count != 3 {
		t.Fatalf("Expected 3 keys, got %d keys", count)
	}

	// Collect the types once so the failure message can list all of them.
	first, second, third := parsed[0].Type(), parsed[1].Type(), parsed[2].Type()
	allEd25519 := first == "ssh-ed25519" && second == "ssh-ed25519" && third == "ssh-ed25519"
	if !allEd25519 {
		t.Fatalf("Unexpected key types: %s, %s, %s", first, second, third)
	}
}
// TestParseSingleKeyFromFile (case 3) writes one SSH key to a temp file,
// reads the file back and expects ParseKeys to return exactly one
// ssh-ed25519 key from its content.
func TestParseSingleKeyFromFile(t *testing.T) {
	keyFile, err := createTempFile("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	// Delete the temp file when the test returns.
	defer os.Remove(keyFile)

	// Load what was just written.
	raw, err := os.ReadFile(keyFile)
	if err != nil {
		t.Fatalf("Failed to read temp file: %v", err)
	}

	// Hand the file content to the parser.
	parsed, err := ParseKeys(string(raw))
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	if count := len(parsed); count != 1 {
		t.Fatalf("Expected 1 key, got %d keys", count)
	}

	if keyType := parsed[0].Type(); keyType != "ssh-ed25519" {
		t.Fatalf("Expected key type 'ssh-ed25519', got '%s'", keyType)
	}
}
// TestParseMultipleKeysFromFile (case 4) writes three key lines plus a
// "#comment" line to a temp file, reads the file back and expects ParseKeys to
// return three ssh-ed25519 keys from its content.
func TestParseMultipleKeysFromFile(t *testing.T) {
	content := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKCBM91kukN7hbvFKtbpEeo2JXjCcNxXcdBH7V7ADMBo\nssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D \n #comment\n ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJDMtAOQfxDlCxe+A5lVbUY/DHxK1LAF2Z3AV0FYv36D"

	filePath, err := createTempFile(content)
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	// Clean up the file after the test; without this every run leaves a
	// stray ssh_keys_*.txt file in the temp directory.
	defer os.Remove(filePath)

	// Read the file content
	fileContent, err := os.ReadFile(filePath)
	if err != nil {
		t.Fatalf("Failed to read temp file: %v", err)
	}

	// Parse the keys
	keys, err := ParseKeys(string(fileContent))
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	if len(keys) != 3 {
		t.Fatalf("Expected 3 keys, got %d keys", len(keys))
	}

	if keys[0].Type() != "ssh-ed25519" || keys[1].Type() != "ssh-ed25519" || keys[2].Type() != "ssh-ed25519" {
		t.Fatalf("Unexpected key types: %s, %s, %s", keys[0].Type(), keys[1].Type(), keys[2].Type())
	}
}
// TestParseInvalidKey (case 5) expects ParseKeys to reject input that is not
// an SSH key, with an error whose text contains "failed to parse key".
func TestParseInvalidKey(t *testing.T) {
	const wantSubstring = "failed to parse key"

	_, parseErr := ParseKeys("invalid-key-data")

	switch {
	case parseErr == nil:
		t.Fatalf("Expected an error for invalid key, got nil")
	case !strings.Contains(parseErr.Error(), wantSubstring):
		t.Fatalf("Expected error message to contain '%s', got: %v", wantSubstring, parseErr)
	}
}
/////////////////////////////////////////////////////////////////
//////////////////// Hub Version Tests //////////////////////////
/////////////////////////////////////////////////////////////////
// TestExtractHubVersion runs extractHubVersion over a table of SSH client
// version strings. Cases with wantErr only require an error; all other cases
// require success and compare the String() form of the result against want.
func TestExtractHubVersion(t *testing.T) {
	cases := []struct {
		name          string
		clientVersion string
		want          string // not consulted when wantErr is set
		wantErr       bool
	}{
		{
			name:          "valid beszel client version with underscore",
			clientVersion: "SSH-2.0-beszel_0.11.1",
			want:          "0.11.1",
		},
		{
			name:          "valid beszel client version with beta",
			clientVersion: "SSH-2.0-beszel_1.0.0-beta",
			want:          "1.0.0-beta",
		},
		{
			name:          "valid beszel client version with rc",
			clientVersion: "SSH-2.0-beszel_0.12.0-rc1",
			want:          "0.12.0-rc1",
		},
		{
			name:          "different SSH client",
			clientVersion: "SSH-2.0-OpenSSH_8.0",
			want:          "8.0",
			wantErr:       true,
		},
		{
			name:          "malformed version string without underscore",
			clientVersion: "SSH-2.0-beszel",
			wantErr:       true,
		},
		{
			name:          "empty version string",
			clientVersion: "",
			wantErr:       true,
		},
		{
			name:          "version string with underscore but no version",
			clientVersion: "beszel_",
			wantErr:       true,
		},
		{
			name:          "version with patch and build metadata",
			clientVersion: "SSH-2.0-beszel_1.2.3+build.123",
			want:          "1.2.3+build.123",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := extractHubVersion(tc.clientVersion)

			if !tc.wantErr {
				require.NoError(t, err)
				assert.Equal(t, tc.want, got.String())
				return
			}

			// Error cases: the returned version is deliberately ignored.
			assert.Error(t, err)
		})
	}
}
/////////////////////////////////////////////////////////////////
/////////////// Hub Version Detection Tests ////////////////////
/////////////////////////////////////////////////////////////////
// TestGetHubVersion drives Agent.getHubVersion through a sequence of calls
// that share one mock SSH context whose client version is rewritten before
// each call. It checks that a version is returned per session ID, that a
// repeated session ID keeps its first version even after the context changes,
// and that non-beszel or empty client versions yield the zero version.
func TestGetHubVersion(t *testing.T) {
	agent, err := NewAgent("")
	require.NoError(t, err)

	// Mock SSH context that implements the ssh.Context interface
	ctx := &mockSSHContext{sessionID: "test-session-123"}

	// The steps run in order; each one sets the context's client version and
	// then asks for the hub version of the given session.
	steps := []struct {
		sessionID     string
		clientVersion string
		want          string
	}{
		// First call - should extract and cache version
		{sessionID: "test-session-123", clientVersion: "SSH-2.0-beszel_0.12.0", want: "0.12.0"},
		// Same session with a changed client version - should still return cached
		{sessionID: "test-session-123", clientVersion: "SSH-2.0-beszel_0.11.0", want: "0.12.0"},
		// Different session - should extract new version
		{sessionID: "different-session", clientVersion: "SSH-2.0-beszel_0.11.0", want: "0.11.0"},
		// Invalid version string (non-beszel client) - should be empty version
		{sessionID: "invalid-session", clientVersion: "SSH-2.0-OpenSSH_8.0", want: "0.0.0"},
	}
	for _, step := range steps {
		ctx.clientVersion = step.clientVersion
		got := agent.getHubVersion(step.sessionID, ctx)
		assert.Equal(t, step.want, got.String())
	}

	// No client version at all - should be empty version
	ctx.clientVersion = ""
	got := agent.getHubVersion("no-version-session", ctx)
	assert.True(t, got.EQ(semver.Version{}))
}
// mockSSHContext is a test double for ssh.Context. It carries a fixed session
// ID and a client version string that tests may overwrite between calls.
type mockSSHContext struct {
	// Embedded to supply the context.Context method set; the tests shown
	// here never assign it, so it is left nil.
	context.Context
	// Embedded to supply Lock and Unlock.
	sync.Mutex

	sessionID, clientVersion string
}

// SessionID returns the session ID the mock was constructed with.
func (m *mockSSHContext) SessionID() string {
	return m.sessionID
}

// ClientVersion returns the mock's current client version string.
func (m *mockSSHContext) ClientVersion() string {
	return m.clientVersion
}

// ServerVersion always returns the fixed string "SSH-2.0-beszel_test".
func (m *mockSSHContext) ServerVersion() string {
	return "SSH-2.0-beszel_test"
}
// Value takes precedence over the Value method of the embedded
// context.Context. It answers only ssh.ContextKeyClientVersion, for which it
// returns the mock's client version; every other key yields nil.
func (m *mockSSHContext) Value(key interface{}) interface{} {
	switch key {
	case ssh.ContextKeyClientVersion:
		return m.clientVersion
	default:
		return nil
	}
}
// User always reports the fixed user name "test-user".
func (m *mockSSHContext) User() string {
	return "test-user"
}

// RemoteAddr is a stub that returns nil.
func (m *mockSSHContext) RemoteAddr() net.Addr {
	return nil
}

// LocalAddr is a stub that returns nil.
func (m *mockSSHContext) LocalAddr() net.Addr {
	return nil
}

// Permissions is a stub that returns nil.
func (m *mockSSHContext) Permissions() *ssh.Permissions {
	return nil
}

// SetValue is a no-op: the mock does not store context values.
func (m *mockSSHContext) SetValue(key, value interface{}) {
}
/////////////////////////////////////////////////////////////////
/////////////// CBOR vs JSON Encoding Tests ////////////////////
/////////////////////////////////////////////////////////////////
// TestWriteToSessionEncoding checks which wire format writeToSession emits
// for a given hub version. For each case the output must decode with the
// expected codec (CBOR or JSON), must fail to decode with the other one, and
// must round-trip the hostname and CPU value of the test payload.
func TestWriteToSessionEncoding(t *testing.T) {
	cases := []struct {
		name       string
		hubVersion string
		wantCbor   bool
	}{
		{name: "old hub version should use JSON", hubVersion: "0.11.1", wantCbor: false},
		{name: "non-beta release should use CBOR", hubVersion: "0.12.0", wantCbor: true},
		{name: "even newer hub version should use CBOR", hubVersion: "0.16.4", wantCbor: true},
		{name: "beta version below release threshold should use JSON", hubVersion: "0.12.0-beta0", wantCbor: false},
		// {
		// 	name:       "matching beta version should use CBOR",
		// 	hubVersion: "0.12.0-beta2",
		// 	wantCbor:   true,
		// },
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Reset the global hubVersions map to ensure clean state for each test
			hubVersions = nil

			agent, err := NewAgent("")
			require.NoError(t, err)

			// Turn the case's version string into a semver value.
			hubVersion, err := semver.Parse(tc.hubVersion)
			require.NoError(t, err)

			// Encode the shared test payload into an in-memory writer.
			payload := createTestCombinedData()
			var out strings.Builder
			err = agent.writeToSession(&out, payload, hubVersion)
			require.NoError(t, err)

			encoded := out.String()
			require.NotEmpty(t, encoded)
			raw := []byte(encoded)

			// Try both codecs up front; the assertions below decide which
			// outcome is expected for this case.
			var viaCbor, viaJSON system.CombinedData
			cborErr := cbor.Unmarshal(raw, &viaCbor)
			jsonErr := json.Unmarshal(raw, &viaJSON)

			if !tc.wantCbor {
				assert.NoError(t, jsonErr, "Should be valid JSON data")
				assert.Error(t, cborErr, "Should not be valid CBOR data")

				// The decoded JSON must carry the payload's values.
				assert.Equal(t, payload.Info.Hostname, viaJSON.Info.Hostname)
				assert.Equal(t, payload.Stats.Cpu, viaJSON.Stats.Cpu)

				// It must also look like JSON: an object with readable field names.
				assert.True(t, strings.HasPrefix(encoded, "{"), "JSON should start with '{'")
				assert.Contains(t, encoded, `"info"`, "JSON should contain readable field names")
				assert.Contains(t, encoded, `"stats"`, "JSON should contain readable field names")
				return
			}

			assert.NoError(t, cborErr, "Should be valid CBOR data")
			assert.Error(t, jsonErr, "Should not be valid JSON data")

			// The decoded CBOR must carry the payload's values.
			assert.Equal(t, payload.Info.Hostname, viaCbor.Info.Hostname)
			assert.Equal(t, payload.Stats.Cpu, viaCbor.Stats.Cpu)
		})
	}
}
// createTestCombinedData returns a freshly allocated system.CombinedData
// populated with fixed stats, system info and one container entry, for use as
// the payload in the encoding tests.
func createTestCombinedData() *system.CombinedData {
	data := new(system.CombinedData)

	data.Stats = system.Stats{
		Cpu:       25.5,
		Mem:       8589934592, // 8GB
		MemUsed:   4294967296, // 4GB
		MemPct:    50.0,
		DiskTotal: 1099511627776, // 1TB
		DiskUsed:  549755813888,  // 512GB
		DiskPct:   50.0,
	}

	data.Info = system.Info{
		Hostname:     "test-host",
		Cores:        8,
		CpuModel:     "Test CPU Model",
		Uptime:       3600,
		AgentVersion: "0.12.0",
		Os:           system.Linux,
	}

	testContainer := &container.Stats{
		Name: "test-container",
		Cpu:  10.5,
		Mem:  1073741824, // 1GB
	}
	data.Containers = []*container.Stats{testContainer}

	return data
}
// TestHubVersionCaching checks that Agent.getHubVersion remembers the version
// per session ID: two sessions get their own versions, a repeated session ID
// keeps its first version even after its context's client version changes,
// and a third session gets the version from its own context.
func TestHubVersionCaching(t *testing.T) {
	// Reset the global hubVersions map to ensure clean state
	hubVersions = nil

	agent, err := NewAgent("")
	require.NoError(t, err)

	// newCtx builds a mock SSH context for one session.
	newCtx := func(sessionID, clientVersion string) *mockSSHContext {
		return &mockSSHContext{sessionID: sessionID, clientVersion: clientVersion}
	}

	first := newCtx("session1", "SSH-2.0-beszel_0.12.0")
	second := newCtx("session2", "SSH-2.0-beszel_0.11.0")

	// The initial lookups should store a version for each session.
	firstVersion := agent.getHubVersion("session1", first)
	secondVersion := agent.getHubVersion("session2", second)
	assert.Equal(t, "0.12.0", firstVersion.String())
	assert.Equal(t, "0.11.0", secondVersion.String())

	// Changing the context must not affect a session that was already seen.
	first.clientVersion = "SSH-2.0-beszel_0.10.0"
	firstAgain := agent.getHubVersion("session1", first)
	assert.Equal(t, "0.12.0", firstAgain.String())

	// A session ID not seen before takes the version from its own context.
	third := newCtx("session3", "SSH-2.0-beszel_0.13.0")
	thirdVersion := agent.getHubVersion("session3", third)
	assert.Equal(t, "0.13.0", thirdVersion.String())
}