diff --git a/beszel/internal/agent/client_test.go b/beszel/internal/agent/client_test.go new file mode 100644 index 0000000..cd90038 --- /dev/null +++ b/beszel/internal/agent/client_test.go @@ -0,0 +1,391 @@ +//go:build testing +// +build testing + +package agent + +import ( + "beszel" + "beszel/internal/common" + "crypto/ed25519" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +// TestNewWebSocketClient tests WebSocket client creation +func TestNewWebSocketClient(t *testing.T) { + agent := createTestAgent(t) + + testCases := []struct { + name string + hubURL string + token string + expectError bool + errorMsg string + }{ + { + name: "valid configuration", + hubURL: "http://localhost:8080", + token: "test-token-123", + expectError: false, + }, + { + name: "valid https URL", + hubURL: "https://hub.example.com", + token: "secure-token", + expectError: false, + }, + { + name: "missing hub URL", + hubURL: "", + token: "test-token", + expectError: true, + errorMsg: "HUB_URL environment variable not set", + }, + { + name: "invalid URL", + hubURL: "ht\ttp://invalid", + token: "test-token", + expectError: true, + errorMsg: "invalid hub URL", + }, + { + name: "missing token", + hubURL: "http://localhost:8080", + token: "", + expectError: true, + errorMsg: "TOKEN environment variable not set", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up environment + if tc.hubURL != "" { + os.Setenv("BESZEL_AGENT_HUB_URL", tc.hubURL) + } else { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + } + if tc.token != "" { + os.Setenv("BESZEL_AGENT_TOKEN", tc.token) + } else { + os.Unsetenv("BESZEL_AGENT_TOKEN") + } + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + + if tc.expectError { + assert.Error(t, err) + if err != nil && tc.errorMsg != "" { + assert.Contains(t, err.Error(), tc.errorMsg) + } + assert.Nil(t, client) + } else { + require.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, agent, client.agent) + assert.Equal(t, tc.token, client.token) + assert.Equal(t, tc.hubURL, client.hubURL.String()) + assert.NotEmpty(t, client.fingerprint) + assert.NotNil(t, client.hubRequest) + } + }) + } +} + +// TestWebSocketClient_GetOptions tests WebSocket client options configuration +func TestWebSocketClient_GetOptions(t *testing.T) { + agent := createTestAgent(t) + + testCases := []struct { + name string + inputURL string + expectedScheme string + expectedPath string + }{ + { + name: "http to ws conversion", + inputURL: "http://localhost:8080", + expectedScheme: "ws", + expectedPath: "/api/beszel/agent-connect", + }, + { + name: "https to wss conversion", + inputURL: "https://hub.example.com", + expectedScheme: "wss", + expectedPath: "/api/beszel/agent-connect", + }, + { + name: "existing path preservation", + inputURL: "http://localhost:8080/custom/path", + expectedScheme: "ws", + expectedPath: "/custom/path/api/beszel/agent-connect", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up environment + os.Setenv("BESZEL_AGENT_HUB_URL", tc.inputURL) + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + require.NoError(t, err) + + options := client.getOptions() + + // Parse the WebSocket URL + wsURL, err := url.Parse(options.Addr) + require.NoError(t, err) + + assert.Equal(t, tc.expectedScheme, wsURL.Scheme) + assert.Equal(t, tc.expectedPath, wsURL.Path) + + // Check headers + assert.Equal(t, "test-token", options.RequestHeader.Get("X-Token")) + assert.Equal(t, beszel.Version, options.RequestHeader.Get("X-Beszel")) + assert.Contains(t, options.RequestHeader.Get("User-Agent"), "Mozilla/5.0") + + // Test options caching + options2 := client.getOptions() + assert.Same(t, options, options2, "Options should be cached") + }) + } +} + +// TestWebSocketClient_VerifySignature tests signature verification +func TestWebSocketClient_VerifySignature(t *testing.T) { + agent := createTestAgent(t) + + // Generate test key pairs + _, goodPrivKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + goodPubKey, err := ssh.NewPublicKey(goodPrivKey.Public().(ed25519.PublicKey)) + require.NoError(t, err) + + _, badPrivKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + badPubKey, err := ssh.NewPublicKey(badPrivKey.Public().(ed25519.PublicKey)) + require.NoError(t, err) + + // Set up environment + os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + require.NoError(t, err) + + testCases := []struct { + name string + keys []ssh.PublicKey + token string + signWith ed25519.PrivateKey + expectError bool + }{ + { + name: "valid signature with correct key", + keys: []ssh.PublicKey{goodPubKey}, + token: "test-token", + signWith: goodPrivKey, + expectError: false, + }, + { + name: "invalid signature with wrong key", + keys: []ssh.PublicKey{goodPubKey}, + token: "test-token", + signWith: badPrivKey, + expectError: true, + }, + { + name: "valid signature with multiple keys", + keys: []ssh.PublicKey{badPubKey, goodPubKey}, + token: "test-token", + signWith: goodPrivKey, + expectError: false, + }, + { + name: "no valid keys", + keys: []ssh.PublicKey{badPubKey}, + token: "test-token", + signWith: goodPrivKey, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up agent with test keys + agent.keys = tc.keys + client.token = tc.token + + // Create signature + signature := ed25519.Sign(tc.signWith, []byte(tc.token)) + + err := client.verifySignature(signature) + + if tc.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid signature") + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestWebSocketClient_HandleHubRequest tests hub request routing (basic verification logic) +func TestWebSocketClient_HandleHubRequest(t *testing.T) { + agent := createTestAgent(t) + + // Set up environment + os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + require.NoError(t, err) + + testCases := []struct { + name string + action common.WebSocketAction + hubVerified bool + expectError bool + errorMsg string + }{ + { + name: "CheckFingerprint without verification", + action: common.CheckFingerprint, + hubVerified: false, + expectError: false, // CheckFingerprint is allowed without verification + }, + { + name: "GetData without verification", + action: common.GetData, + hubVerified: false, + expectError: true, + errorMsg: "hub not verified", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client.hubVerified = tc.hubVerified + + // Create minimal request + hubRequest := &common.HubRequest[cbor.RawMessage]{ + Action: tc.action, + Data: cbor.RawMessage{}, + } + + err := client.handleHubRequest(hubRequest) + + if tc.expectError { + assert.Error(t, err) + if tc.errorMsg != "" { + assert.Contains(t, err.Error(), tc.errorMsg) + } + } else { + // For CheckFingerprint, we expect a decode error since we're not providing valid data, + // but it shouldn't be the "hub not verified" error + if err != nil && tc.errorMsg != "" { + assert.NotContains(t, err.Error(), tc.errorMsg) + } + } + }) + } +} + +// TestWebSocketClient_GetUserAgent tests user agent generation +func TestGetUserAgent(t *testing.T) { + // Run multiple times to check both variants + userAgents := make(map[string]bool) + + for range 20 { + ua := getUserAgent() + userAgents[ua] = true + + // Check that it's a valid Mozilla user agent + assert.Contains(t, ua, "Mozilla/5.0") + assert.Contains(t, ua, "AppleWebKit/537.36") + assert.Contains(t, ua, "Chrome/124.0.0.0") + assert.Contains(t, ua, "Safari/537.36") + + // Should contain either Windows or Mac + isWindows := strings.Contains(ua, "Windows NT 11.0") + isMac := strings.Contains(ua, "Macintosh; Intel Mac OS X 14_0_0") + assert.True(t, isWindows || isMac, "User agent should contain either Windows or Mac identifier") + } + + // With enough iterations, we should see both variants + // though this might occasionally fail + if len(userAgents) == 1 { + t.Log("Note: Only one user agent variant was generated in this test run") + } +} + +// TestWebSocketClient_Close tests connection closing +func TestWebSocketClient_Close(t *testing.T) { + agent := createTestAgent(t) + + // Set up environment + os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + require.NoError(t, err) + + // Test closing with nil connection (should not panic) + assert.NotPanics(t, func() { + client.Close() + }) +} + +// TestWebSocketClient_ConnectRateLimit tests connection rate limiting +func TestWebSocketClient_ConnectRateLimit(t *testing.T) { + agent := createTestAgent(t) + + // Set up environment + os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + client, err := newWebSocketClient(agent) + require.NoError(t, err) + + // Set recent connection attempt + client.lastConnectAttempt = time.Now() + + // Test that connection fails quickly due to rate limiting + // This won't actually connect but should fail fast + err = client.Connect() + assert.Error(t, err, "Connection should fail but not hang") +} diff --git a/beszel/internal/hub/ws/ws.go b/beszel/internal/hub/ws/ws.go index e198739..b960be0 100644 --- a/beszel/internal/hub/ws/ws.go +++ b/beszel/internal/hub/ws/ws.go @@ -114,6 +114,9 @@ func (ws *WsConn) Ping() error { // sendMessage encodes data to CBOR and sends it as a binary message to the agent. func (ws *WsConn) sendMessage(data common.HubRequest[any]) error { + if ws.conn == nil { + return gws.ErrConnClosed + } bytes, err := cbor.Marshal(data) if err != nil { return err diff --git a/beszel/internal/hub/ws/ws_test.go b/beszel/internal/hub/ws/ws_test.go new file mode 100644 index 0000000..935ac7c --- /dev/null +++ b/beszel/internal/hub/ws/ws_test.go @@ -0,0 +1,221 @@ +//go:build testing +// +build testing + +package ws + +import ( + "beszel/internal/common" + "crypto/ed25519" + "testing" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +// TestGetUpgrader tests the singleton upgrader +func TestGetUpgrader(t *testing.T) { + // Reset the global upgrader to test singleton behavior + upgrader = nil + + // First call should create the upgrader + upgrader1 := GetUpgrader() + assert.NotNil(t, upgrader1, "Upgrader should not be nil") + + // Second call should return the same instance + upgrader2 := GetUpgrader() + assert.Same(t, upgrader1, upgrader2, "Should return the same upgrader instance") + + // Verify it's properly configured + assert.NotNil(t, upgrader1, "Upgrader should be configured") +} + +// TestNewWsConnection tests WebSocket connection creation +func TestNewWsConnection(t *testing.T) { + // We can't easily mock gws.Conn, so we'll pass nil and test the structure + wsConn := NewWsConnection(nil) + + assert.NotNil(t, wsConn, "WebSocket connection should not be nil") + assert.Nil(t, wsConn.conn, "Connection should be nil as passed") + assert.NotNil(t, wsConn.responseChan, "Response channel should be initialized") + assert.NotNil(t, wsConn.DownChan, "Down channel should be initialized") + assert.Equal(t, 1, cap(wsConn.responseChan), "Response channel should have capacity of 1") + assert.Equal(t, 1, cap(wsConn.DownChan), "Down channel should have capacity of 1") +} + +// TestWsConn_IsConnected tests the connection status check +func TestWsConn_IsConnected(t *testing.T) { + // Test with nil connection + wsConn := NewWsConnection(nil) + assert.False(t, wsConn.IsConnected(), "Should not be connected when conn is nil") +} + +// TestWsConn_Close tests the connection closing with nil connection +func TestWsConn_Close(t *testing.T) { + wsConn := NewWsConnection(nil) + + // Should handle nil connection gracefully + assert.NotPanics(t, func() { + wsConn.Close([]byte("test message")) + }, "Should not panic when closing nil connection") +} + +// TestWsConn_SendMessage_CBOR tests CBOR encoding in sendMessage +func TestWsConn_SendMessage_CBOR(t *testing.T) { + wsConn := NewWsConnection(nil) + + testData := common.HubRequest[any]{ + Action: common.GetData, + Data: "test data", + } + + // This will fail because conn is nil, but we can test the CBOR encoding logic + // by checking that the function properly encodes to CBOR before failing + err := wsConn.sendMessage(testData) + assert.Error(t, err, "Should error with nil connection") + + // Test CBOR encoding separately + bytes, err := cbor.Marshal(testData) + assert.NoError(t, err, "Should encode to CBOR successfully") + + // Verify we can decode it back + var decodedData common.HubRequest[any] + err = cbor.Unmarshal(bytes, &decodedData) + assert.NoError(t, err, "Should decode from CBOR successfully") + assert.Equal(t, testData.Action, decodedData.Action, "Action should match") +} + +// TestWsConn_GetFingerprint_SignatureGeneration tests signature creation logic +func TestWsConn_GetFingerprint_SignatureGeneration(t *testing.T) { + // Generate test key pair + _, privKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + signer, err := ssh.NewSignerFromKey(privKey) + require.NoError(t, err) + + token := "test-token" + + // This will timeout since conn is nil, but we can verify the signature logic + // We can't test the full flow, but we can test that the signature is created properly + challenge := []byte(token) + signature, err := signer.Sign(nil, challenge) + assert.NoError(t, err, "Should create signature successfully") + assert.NotEmpty(t, signature.Blob, "Signature blob should not be empty") + assert.Equal(t, signer.PublicKey().Type(), signature.Format, "Signature format should match key type") + + // Test the fingerprint request structure + fpRequest := common.FingerprintRequest{ + Signature: signature.Blob, + NeedSysInfo: true, + } + + // Test CBOR encoding of fingerprint request + fpData, err := cbor.Marshal(fpRequest) + assert.NoError(t, err, "Should encode fingerprint request to CBOR") + + var decodedFpRequest common.FingerprintRequest + err = cbor.Unmarshal(fpData, &decodedFpRequest) + assert.NoError(t, err, "Should decode fingerprint request from CBOR") + assert.Equal(t, fpRequest.Signature, decodedFpRequest.Signature, "Signature should match") + assert.Equal(t, fpRequest.NeedSysInfo, decodedFpRequest.NeedSysInfo, "NeedSysInfo should match") + + // Test the full hub request structure + hubRequest := common.HubRequest[any]{ + Action: common.CheckFingerprint, + Data: fpRequest, + } + + hubData, err := cbor.Marshal(hubRequest) + assert.NoError(t, err, "Should encode hub request to CBOR") + + var decodedHubRequest common.HubRequest[cbor.RawMessage] + err = cbor.Unmarshal(hubData, &decodedHubRequest) + assert.NoError(t, err, "Should decode hub request from CBOR") + assert.Equal(t, common.CheckFingerprint, decodedHubRequest.Action, "Action should be CheckFingerprint") +} + +// TestWsConn_RequestSystemData_RequestFormat tests system data request format +func TestWsConn_RequestSystemData_RequestFormat(t *testing.T) { + // Test the request format that would be sent + request := common.HubRequest[any]{ + Action: common.GetData, + } + + // Test CBOR encoding + data, err := cbor.Marshal(request) + assert.NoError(t, err, "Should encode request to CBOR") + + // Test decoding + var decodedRequest common.HubRequest[any] + err = cbor.Unmarshal(data, &decodedRequest) + assert.NoError(t, err, "Should decode request from CBOR") + assert.Equal(t, common.GetData, decodedRequest.Action, "Should have GetData action") +} + +// TestFingerprintRecord tests the FingerprintRecord struct +func TestFingerprintRecord(t *testing.T) { + record := FingerprintRecord{ + Id: "test-id", + SystemId: "system-123", + Fingerprint: "test-fingerprint", + Token: "test-token", + } + + assert.Equal(t, "test-id", record.Id) + assert.Equal(t, "system-123", record.SystemId) + assert.Equal(t, "test-fingerprint", record.Fingerprint) + assert.Equal(t, "test-token", record.Token) +} + +// TestDeadlineConstant tests that the deadline constant is reasonable +func TestDeadlineConstant(t *testing.T) { + assert.Equal(t, 70*time.Second, deadline, "Deadline should be 70 seconds") +} + +// TestCommonActions tests that the common actions are properly defined +func TestCommonActions(t *testing.T) { + // Test that the actions we use exist and have expected values + assert.Equal(t, common.WebSocketAction(0), common.GetData, "GetData should be action 0") + assert.Equal(t, common.WebSocketAction(1), common.CheckFingerprint, "CheckFingerprint should be action 1") +} + +// TestHandler tests that we can create a Handler +func TestHandler(t *testing.T) { + handler := &Handler{} + assert.NotNil(t, handler, "Handler should be created successfully") + + // The Handler embeds gws.BuiltinEventHandler, so it should have the embedded type + assert.NotNil(t, handler.BuiltinEventHandler, "Should have embedded BuiltinEventHandler") +} + +// TestWsConnChannelBehavior tests channel behavior without WebSocket connections +func TestWsConnChannelBehavior(t *testing.T) { + wsConn := NewWsConnection(nil) + + // Test that channels are properly initialized and can be used + select { + case wsConn.DownChan <- struct{}{}: + // Should be able to write to channel + default: + t.Error("Should be able to write to DownChan") + } + + // Test reading from DownChan + select { + case <-wsConn.DownChan: + // Should be able to read from channel + case <-time.After(10 * time.Millisecond): + t.Error("Should be able to read from DownChan") + } + + // Response channel should be empty initially + select { + case <-wsConn.responseChan: + t.Error("Response channel should be empty initially") + default: + // Expected - channel should be empty + } +}