diff --git a/beszel/internal/hub/agent_connect.go b/beszel/internal/hub/agent_connect.go index 11e75a1..b42978e 100644 --- a/beszel/internal/hub/agent_connect.go +++ b/beszel/internal/hub/agent_connect.go @@ -5,10 +5,10 @@ import ( "beszel/internal/hub/expirymap" "beszel/internal/hub/ws" "errors" - "fmt" "net" "net/http" "strings" + "sync" "time" "github.com/blang/semver" @@ -17,118 +17,97 @@ import ( "github.com/pocketbase/pocketbase/core" ) -// tokenMap maps tokens to user IDs for universal tokens -var tokenMap *expirymap.ExpiryMap[string] - +// agentConnectRequest holds information related to an agent's connection attempt. type agentConnectRequest struct { + hub *Hub + req *http.Request + res http.ResponseWriter token string agentSemVer semver.Version - // for universal token + // isUniversalToken is true if the token is a universal token. isUniversalToken bool - userId string - remoteAddr string + // userId is the user ID associated with the universal token. + userId string } -// validateAgentHeaders validates the required headers from agent connection requests. -func (h *Hub) validateAgentHeaders(headers http.Header) (string, string, error) { - token := headers.Get("X-Token") - agentVersion := headers.Get("X-Beszel") +// universalTokenMap stores active universal tokens and their associated user IDs. +var universalTokenMap tokenMap - if agentVersion == "" || token == "" || len(token) > 512 { - return "", "", errors.New("") - } - return token, agentVersion, nil +type tokenMap struct { + store *expirymap.ExpiryMap[string] + once sync.Once } -// getFingerprintRecord retrieves fingerprint data from the database by token. -func (h *Hub) getFingerprintRecord(token string, recordData *ws.FingerprintRecord) error { - err := h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}"). - Bind(dbx.Params{ - "token": token, - }). - One(recordData) - return err +// getMap returns the expirymap, creating it if necessary. +func (tm *tokenMap) GetMap() *expirymap.ExpiryMap[string] { + tm.once.Do(func() { + tm.store = expirymap.New[string](time.Hour) + }) + return tm.store } -// sendResponseError sends an HTTP error response with the given status code and message. -func sendResponseError(res http.ResponseWriter, code int, message string) error { - res.WriteHeader(code) - if message != "" { - res.Write([]byte(message)) - } - return nil -} - -// handleAgentConnect handles the incoming connection request from the agent. +// handleAgentConnect is the HTTP handler for an agent's connection request. func (h *Hub) handleAgentConnect(e *core.RequestEvent) error { - if err := h.agentConnect(e.Request, e.Response); err != nil { - return err - } + agentRequest := agentConnectRequest{req: e.Request, res: e.Response, hub: h} + _ = agentRequest.agentConnect() return nil } -// agentConnect handles agent connection requests, validating credentials and upgrading to WebSocket. -func (h *Hub) agentConnect(req *http.Request, res http.ResponseWriter) (err error) { - var agentConnectRequest agentConnectRequest +// agentConnect validates agent credentials and upgrades the connection to a WebSocket. +func (acr *agentConnectRequest) agentConnect() (err error) { var agentVersion string - // check if user agent and token are valid - agentConnectRequest.token, agentVersion, err = h.validateAgentHeaders(req.Header) + + acr.token, agentVersion, err = acr.validateAgentHeaders(acr.req.Header) if err != nil { - return sendResponseError(res, http.StatusUnauthorized, "") + return acr.sendResponseError(acr.res, http.StatusBadRequest, "") } - // Pull fingerprint from database matching token - var fpRecord ws.FingerprintRecord - err = h.getFingerprintRecord(agentConnectRequest.token, &fpRecord) + // Check if token is an active universal token + acr.userId, acr.isUniversalToken = universalTokenMap.GetMap().GetOk(acr.token) - // if no existing record, check if token is a universal token - if err != nil { - if err = checkUniversalToken(&agentConnectRequest); err == nil { - // if this is a universal token, set the remote address and new record token - agentConnectRequest.remoteAddr = getRealIP(req) - fpRecord.Token = agentConnectRequest.token - } - } - - // If no matching token, return unauthorized - if err != nil { - return sendResponseError(res, http.StatusUnauthorized, "Invalid token") + // Find matching fingerprint records for this token + fpRecords := getFingerprintRecordsByToken(acr.token, acr.hub) + if len(fpRecords) == 0 && !acr.isUniversalToken { + // Invalid token - no records found and not a universal token + return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid token") } // Validate agent version - agentConnectRequest.agentSemVer, err = semver.Parse(agentVersion) + acr.agentSemVer, err = semver.Parse(agentVersion) if err != nil { - return sendResponseError(res, http.StatusUnauthorized, "Invalid agent version") + return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid agent version") } // Upgrade connection to WebSocket - conn, err := ws.GetUpgrader().Upgrade(res, req) + conn, err := ws.GetUpgrader().Upgrade(acr.res, acr.req) if err != nil { - return sendResponseError(res, http.StatusInternalServerError, "WebSocket upgrade failed") + return acr.sendResponseError(acr.res, http.StatusInternalServerError, "WebSocket upgrade failed") } - go h.verifyWsConn(conn, agentConnectRequest, fpRecord) + go acr.verifyWsConn(conn, fpRecords) return nil } -// verifyWsConn verifies the WebSocket connection using agent's fingerprint and SSH key signature. -func (h *Hub) verifyWsConn(conn *gws.Conn, acr agentConnectRequest, fpRecord ws.FingerprintRecord) (err error) { +// verifyWsConn verifies the WebSocket connection using the agent's fingerprint and +// SSH key signature, then adds the system to the system manager. +func (acr *agentConnectRequest) verifyWsConn(conn *gws.Conn, fpRecords []ws.FingerprintRecord) (err error) { wsConn := ws.NewWsConnection(conn) - // must be set before the read loop + + // must set wsConn in connection store before the read loop conn.Session().Store("wsConn", wsConn) // make sure connection is closed if there is an error defer func() { if err != nil { wsConn.Close() - h.Logger().Error("WebSocket error", "error", err, "system", fpRecord.SystemId) + acr.hub.Logger().Error("WebSocket error", "error", err, "systems", fpRecords) } }() go conn.ReadLoop() - signer, err := h.GetSSHKey("") + signer, err := acr.hub.GetSSHKey("") if err != nil { return err } @@ -138,40 +117,152 @@ func (h *Hub) verifyWsConn(conn *gws.Conn, acr agentConnectRequest, fpRecord ws. return err } - // Create system if using universal token - if acr.isUniversalToken { - if acr.userId == "" { - return errors.New("token user not found") - } - fpRecord.SystemId, err = h.createSystemFromAgentData(&acr, agentFingerprint) - if err != nil { - return fmt.Errorf("failed to create system from universal token: %w", err) - } + // Find or create the appropriate system for this token and fingerprint + fpRecord, err := acr.findOrCreateSystemForToken(fpRecords, agentFingerprint) + if err != nil { + return err } - switch { - // If no current fingerprint, update with new fingerprint (first time connecting) - case fpRecord.Fingerprint == "": - if err := h.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil { - return err - } - // Abort if fingerprint exists but doesn't match (different machine) - case fpRecord.Fingerprint != agentFingerprint.Fingerprint: - return errors.New("fingerprint mismatch") - } - - return h.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn) + return acr.hub.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn) } -// createSystemFromAgentData creates a new system record using data from the agent -func (h *Hub) createSystemFromAgentData(acr *agentConnectRequest, agentFingerprint common.FingerprintResponse) (recordId string, err error) { - systemsCollection, err := h.FindCollectionByNameOrId("systems") - if err != nil { - return "", fmt.Errorf("failed to find systems collection: %w", err) +// validateAgentHeaders extracts and validates the token and agent version from HTTP headers. +func (acr *agentConnectRequest) validateAgentHeaders(headers http.Header) (string, string, error) { + token := headers.Get("X-Token") + agentVersion := headers.Get("X-Beszel") + + if agentVersion == "" || token == "" || len(token) > 64 { + return "", "", errors.New("") } + return token, agentVersion, nil +} + +// sendResponseError writes an HTTP error response. +func (acr *agentConnectRequest) sendResponseError(res http.ResponseWriter, code int, message string) error { + res.WriteHeader(code) + if message != "" { + res.Write([]byte(message)) + } + return nil +} + +// getFingerprintRecordsByToken retrieves all fingerprint records associated with a given token. +func getFingerprintRecordsByToken(token string, h *Hub) []ws.FingerprintRecord { + var records []ws.FingerprintRecord + // All will populate empty slice even on error + _ = h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}"). + Bind(dbx.Params{ + "token": token, + }). + All(&records) + return records +} + +// findOrCreateSystemForToken finds an existing system matching the token and fingerprint, +// or creates a new one for a universal token. +func (acr *agentConnectRequest) findOrCreateSystemForToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) { + // No records - only valid for active universal tokens + if len(fpRecords) == 0 { + return acr.handleNoRecords(agentFingerprint) + } + + // Single record - handle as regular token + if len(fpRecords) == 1 && !acr.isUniversalToken { + return acr.handleSingleRecord(fpRecords[0], agentFingerprint) + } + + // Multiple records or universal token - look for matching fingerprint + return acr.handleMultipleRecordsOrUniversalToken(fpRecords, agentFingerprint) +} + +// handleNoRecords handles the case where no fingerprint records are found for a token. +// A new system is created if the token is a valid universal token. +func (acr *agentConnectRequest) handleNoRecords(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) { + var fpRecord ws.FingerprintRecord + + if !acr.isUniversalToken || acr.userId == "" { + return fpRecord, errors.New("no matching fingerprints") + } + + return acr.createNewSystemForUniversalToken(agentFingerprint) +} + +// handleSingleRecord handles the case with a single fingerprint record. It validates +// the agent's fingerprint against the stored one, or sets it on first connect. +func (acr *agentConnectRequest) handleSingleRecord(fpRecord ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) { + // If no current fingerprint, update with new fingerprint (first time connecting) + if fpRecord.Fingerprint == "" { + if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil { + return fpRecord, err + } + // Update the record with the fingerprint that was set + fpRecord.Fingerprint = agentFingerprint.Fingerprint + return fpRecord, nil + } + + // Abort if fingerprint exists but doesn't match (different machine) + if fpRecord.Fingerprint != agentFingerprint.Fingerprint { + return fpRecord, errors.New("fingerprint mismatch") + } + + return fpRecord, nil +} + +// handleMultipleRecordsOrUniversalToken finds a matching fingerprint from multiple records. +// If no match is found and the token is a universal token, a new system is created. +func (acr *agentConnectRequest) handleMultipleRecordsOrUniversalToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) { + // Return existing record with matching fingerprint if found + for i := range fpRecords { + if fpRecords[i].Fingerprint == agentFingerprint.Fingerprint { + return fpRecords[i], nil + } + } + + // No matching fingerprint record found, but it's + // an active universal token so create a new system + if acr.isUniversalToken { + return acr.createNewSystemForUniversalToken(agentFingerprint) + } + + return ws.FingerprintRecord{}, errors.New("fingerprint mismatch") +} + +// createNewSystemForUniversalToken creates a new system and fingerprint record for a universal token. +func (acr *agentConnectRequest) createNewSystemForUniversalToken(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) { + var fpRecord ws.FingerprintRecord + if !acr.isUniversalToken || acr.userId == "" { + return fpRecord, errors.New("invalid token") + } + + fpRecord.Token = acr.token + + systemId, err := acr.createSystem(agentFingerprint) + if err != nil { + return fpRecord, err + } + fpRecord.SystemId = systemId + + // Set the fingerprint for the new system + if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil { + return fpRecord, err + } + + // Update the record with the fingerprint that was set + fpRecord.Fingerprint = agentFingerprint.Fingerprint + + return fpRecord, nil +} + +// createSystem creates a new system record in the database using details from the agent. +func (acr *agentConnectRequest) createSystem(agentFingerprint common.FingerprintResponse) (recordId string, err error) { + systemsCollection, err := acr.hub.FindCachedCollectionByNameOrId("systems") + if err != nil { + return "", err + } + remoteAddr := getRealIP(acr.req) // separate port from address if agentFingerprint.Hostname == "" { - agentFingerprint.Hostname = acr.remoteAddr + agentFingerprint.Hostname = remoteAddr } if agentFingerprint.Port == "" { agentFingerprint.Port = "45876" @@ -179,14 +270,14 @@ func (h *Hub) createSystemFromAgentData(acr *agentConnectRequest, agentFingerpri // create new record systemRecord := core.NewRecord(systemsCollection) systemRecord.Set("name", agentFingerprint.Hostname) - systemRecord.Set("host", acr.remoteAddr) + systemRecord.Set("host", remoteAddr) systemRecord.Set("port", agentFingerprint.Port) systemRecord.Set("users", []string{acr.userId}) - return systemRecord.Id, h.Save(systemRecord) + return systemRecord.Id, acr.hub.Save(systemRecord) } -// SetFingerprint updates the fingerprint for a given record ID. +// SetFingerprint creates or updates a fingerprint record in the database. func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) { // // can't use raw query here because it doesn't trigger SSE var record *core.Record @@ -207,25 +298,8 @@ func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) return h.SaveNoValidate(record) } -func getTokenMap() *expirymap.ExpiryMap[string] { - if tokenMap == nil { - tokenMap = expirymap.New[string](time.Hour) - } - return tokenMap -} - -func checkUniversalToken(acr *agentConnectRequest) (err error) { - if tokenMap == nil { - tokenMap = expirymap.New[string](time.Hour) - } - acr.userId, acr.isUniversalToken = tokenMap.GetOk(acr.token) - if !acr.isUniversalToken { - return errors.New("invalid token") - } - return nil -} - -// getRealIP attempts to extract the real IP address from the request headers. +// getRealIP extracts the client's real IP address from request headers, +// checking common proxy headers before falling back to the remote address. func getRealIP(r *http.Request) string { if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { return ip diff --git a/beszel/internal/hub/agent_connect_test.go b/beszel/internal/hub/agent_connect_test.go index f01d13f..b84bb8a 100644 --- a/beszel/internal/hub/agent_connect_test.go +++ b/beszel/internal/hub/agent_connect_test.go @@ -6,7 +6,6 @@ package hub import ( "beszel/internal/agent" "beszel/internal/common" - "beszel/internal/hub/expirymap" "beszel/internal/hub/ws" "crypto/ed25519" "fmt" @@ -14,6 +13,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" @@ -115,7 +115,7 @@ func TestValidateAgentHeaders(t *testing.T) { { name: "token too long", headers: http.Header{ - "X-Token": []string{string(make([]byte, 513))}, // 513 bytes > 512 limit + "X-Token": []string{strings.Repeat("a", 65)}, "X-Beszel": []string{"0.5.0"}, }, expectError: true, @@ -124,7 +124,8 @@ func TestValidateAgentHeaders(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - token, agentVersion, err := hub.validateAgentHeaders(tc.headers) + acr := &agentConnectRequest{hub: hub} + token, agentVersion, err := acr.validateAgentHeaders(tc.headers) if tc.expectError { assert.Error(t, err) @@ -137,8 +138,8 @@ func TestValidateAgentHeaders(t *testing.T) { } } -// TestGetFingerprintRecord tests the getFingerprintRecord function -func TestGetFingerprintRecord(t *testing.T) { +// TestGetAllFingerprintRecordsByToken tests the getAllFingerprintRecordsByToken function +func TestGetAllFingerprintRecordsByToken(t *testing.T) { hub, testApp, err := createTestHub(t) if err != nil { t.Fatal(err) @@ -168,44 +169,60 @@ func TestGetFingerprintRecord(t *testing.T) { "token": "test-token-123", "fingerprint": "test-fingerprint", }) + for i := range 3 { + systemRecord, _ := createTestRecord(testApp, "systems", map[string]any{ + "name": fmt.Sprintf("test-system-%d", i), + "host": "localhost", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "duplicate-token", + "fingerprint": fmt.Sprintf("test-fingerprint-%d", i), + }) + } if err != nil { t.Fatal(err) } testCases := []struct { - name string - token string - expectError bool - expectedId string + name string + token string + expectedId string + expectLen int }{ { - name: "valid token", - token: "test-token-123", - expectError: false, - expectedId: fingerprintRecord.Id, + name: "valid token", + token: "test-token-123", + expectLen: 1, + expectedId: fingerprintRecord.Id, }, { - name: "invalid token", - token: "invalid-token", - expectError: true, + name: "invalid token", + token: "invalid-token", + expectLen: 0, }, { - name: "empty token", - token: "", - expectError: true, + name: "empty token", + token: "", + expectLen: 0, + }, + { + name: "duplicate token", + token: "duplicate-token", + expectLen: 3, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var recordData ws.FingerprintRecord - err := hub.getFingerprintRecord(tc.token, &recordData) + records := getFingerprintRecordsByToken(tc.token, hub) - if tc.expectError { - assert.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tc.expectedId, recordData.Id) + require.Len(t, records, tc.expectLen) + if tc.expectedId != "" { + assert.Equal(t, tc.expectedId, records[0].Id) } }) } @@ -318,8 +335,11 @@ func TestCreateSystemFromAgentData(t *testing.T) { { name: "successful system creation with all fields", agentConnReq: agentConnectRequest{ - userId: userRecord.Id, - remoteAddr: "192.168.1.100", + hub: hub, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.0.1", + }, }, fingerprint: common.FingerprintResponse{ Hostname: "test-server", @@ -327,15 +347,18 @@ func TestCreateSystemFromAgentData(t *testing.T) { }, expectError: false, expectedName: "test-server", - expectedHost: "192.168.1.100", + expectedHost: "192.168.0.1", // This will be the parsed IP from the mock request expectedPort: "8080", expectedUsers: []string{userRecord.Id}, }, { name: "system creation with default port", agentConnReq: agentConnectRequest{ - userId: userRecord.Id, - remoteAddr: "10.0.0.50", + hub: hub, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.0.1", + }, }, fingerprint: common.FingerprintResponse{ Hostname: "default-port-server", @@ -343,23 +366,26 @@ func TestCreateSystemFromAgentData(t *testing.T) { }, expectError: false, expectedName: "default-port-server", - expectedHost: "10.0.0.50", + expectedHost: "192.168.0.1", // This will be the parsed IP from the mock request expectedPort: "45876", expectedUsers: []string{userRecord.Id}, }, { name: "system creation with empty hostname", agentConnReq: agentConnectRequest{ - userId: userRecord.Id, - remoteAddr: "172.16.0.1", + hub: hub, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.0.1", + }, }, fingerprint: common.FingerprintResponse{ Hostname: "", Port: "9090", }, expectError: false, - expectedName: "172.16.0.1", // Should fall back to host IP when hostname is empty - expectedHost: "172.16.0.1", + expectedName: "192.168.0.1", // Should fall back to host IP when hostname is empty + expectedHost: "192.168.0.1", // This will be the parsed IP from the mock request expectedPort: "9090", expectedUsers: []string{userRecord.Id}, }, @@ -367,7 +393,7 @@ func TestCreateSystemFromAgentData(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - recordId, err := hub.createSystemFromAgentData(&tc.agentConnReq, tc.fingerprint) + recordId, err := tc.agentConnReq.createSystem(tc.fingerprint) if tc.expectError { assert.Error(t, err) @@ -409,11 +435,7 @@ func TestUniversalTokenFlow(t *testing.T) { // Set up universal token in the token map universalToken := "universal-token-123" - // Initialize tokenMap if it doesn't exist - if tokenMap == nil { - tokenMap = expirymap.New[string](time.Hour) - } - tokenMap.Set(universalToken, userRecord.Id, time.Hour) + universalTokenMap.GetMap().Set(universalToken, userRecord.Id, time.Hour) testCases := []struct { name string @@ -447,17 +469,14 @@ func TestUniversalTokenFlow(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var acr agentConnectRequest - acr.token = tc.token + acr := &agentConnectRequest{} - err := checkUniversalToken(&acr) + acr.userId, acr.isUniversalToken = universalTokenMap.GetMap().GetOk(tc.token) if tc.expectError { - assert.Error(t, err) assert.False(t, acr.isUniversalToken) assert.Empty(t, acr.userId) } else { - require.NoError(t, err) assert.Equal(t, tc.expectUniversalAuth, acr.isUniversalToken) if tc.expectUniversalAuth { assert.Equal(t, userRecord.Id, acr.userId) @@ -467,85 +486,6 @@ func TestUniversalTokenFlow(t *testing.T) { } } -// TestAgentDataProtection tests that agent won't send system data before fingerprint verification -func TestAgentDataProtection(t *testing.T) { - // This test verifies the logic in the agent's handleHubRequest method - // Since we can't access private fields directly, we'll test the behavior indirectly - // by creating a mock scenario that simulates the verification flow - - // The key behavior is tested in the agent's handleHubRequest method: - // if !client.hubVerified && msg.Action != common.CheckFingerprint { - // return errors.New("hub not verified") - // } - - // This test documents the expected behavior rather than testing implementation details - t.Run("agent should reject GetData before fingerprint verification", func(t *testing.T) { - // This behavior is enforced by the agent's WebSocket client - // When hubVerified is false and action is GetData, it returns "hub not verified" error - assert.True(t, true, "Agent rejects GetData requests before hub verification") - }) - - t.Run("agent should allow CheckFingerprint before verification", func(t *testing.T) { - // CheckFingerprint action is always allowed regardless of hubVerified status - assert.True(t, true, "Agent allows CheckFingerprint requests before hub verification") - }) -} - -// TestFingerprintResponseFields tests that FingerprintResponse includes hostname and port when requested -func TestFingerprintResponseFields(t *testing.T) { - testCases := []struct { - name string - includeSysInfo bool - expectHostname bool - expectPort bool - description string - }{ - { - name: "include system info", - includeSysInfo: true, - expectHostname: true, - expectPort: true, - description: "Should include hostname and port when requested", - }, - { - name: "exclude system info", - includeSysInfo: false, - expectHostname: false, - expectPort: false, - description: "Should not include hostname and port when not requested", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test the response creation logic as it would be used in the agent - response := &common.FingerprintResponse{ - Fingerprint: "test-fingerprint", - } - - if tc.includeSysInfo { - response.Hostname = "test-hostname" - response.Port = "8080" - } - - // Verify the response structure - assert.NotEmpty(t, response.Fingerprint, "Fingerprint should always be present") - - if tc.expectHostname { - assert.NotEmpty(t, response.Hostname, "Hostname should be present when requested") - } else { - assert.Empty(t, response.Hostname, "Hostname should be empty when not requested") - } - - if tc.expectPort { - assert.NotEmpty(t, response.Port, "Port should be present when requested") - } else { - assert.Empty(t, response.Port, "Port should be empty when not requested") - } - }) - } -} - // TestAgentConnect tests the agentConnect function with various scenarios func TestAgentConnect(t *testing.T) { hub, testApp, err := createTestHub(t) @@ -588,22 +528,25 @@ func TestAgentConnect(t *testing.T) { headers map[string]string expectedStatus int description string + errorMessage string }{ { name: "missing token header", headers: map[string]string{ "X-Beszel": "0.5.0", }, - expectedStatus: http.StatusUnauthorized, + expectedStatus: http.StatusBadRequest, description: "Should fail due to missing token", + errorMessage: "", }, { name: "missing agent version header", headers: map[string]string{ "X-Token": testToken, }, - expectedStatus: http.StatusUnauthorized, + expectedStatus: http.StatusBadRequest, description: "Should fail due to missing agent version", + errorMessage: "", }, { name: "invalid token", @@ -613,6 +556,7 @@ func TestAgentConnect(t *testing.T) { }, expectedStatus: http.StatusUnauthorized, description: "Should fail due to invalid token", + errorMessage: "Invalid token", }, { name: "invalid agent version", @@ -622,6 +566,7 @@ func TestAgentConnect(t *testing.T) { }, expectedStatus: http.StatusUnauthorized, description: "Should fail due to invalid agent version", + errorMessage: "Invalid agent version", }, { name: "valid headers but websocket upgrade will fail in test", @@ -631,6 +576,14 @@ func TestAgentConnect(t *testing.T) { }, expectedStatus: http.StatusInternalServerError, description: "Should pass validation but fail at WebSocket upgrade due to test limitations", + errorMessage: "WebSocket upgrade failed", + }, + { + name: "Token too long", + headers: map[string]string{"X-Token": strings.Repeat("a", 65), "X-Beszel": "0.5.0"}, + expectedStatus: http.StatusBadRequest, + description: "Should reject token exceeding 64 characters", + errorMessage: "", }, } @@ -642,9 +595,15 @@ func TestAgentConnect(t *testing.T) { } recorder := httptest.NewRecorder() - err = hub.agentConnect(req, recorder) + acr := &agentConnectRequest{ + hub: hub, + req: req, + res: recorder, + } + err = acr.agentConnect() assert.Equal(t, tc.expectedStatus, recorder.Code, tc.description) + assert.Equal(t, tc.errorMessage, recorder.Body.String(), tc.description) }) } } @@ -677,7 +636,8 @@ func TestSendResponseError(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() - sendResponseError(recorder, tc.statusCode, tc.message) + acr := &agentConnectRequest{} + acr.sendResponseError(recorder, tc.statusCode, tc.message) assert.Equal(t, tc.expectedStatus, recorder.Code) assert.Equal(t, tc.expectedBody, recorder.Body.String()) @@ -759,7 +719,12 @@ func TestHandleAgentConnect(t *testing.T) { } recorder := httptest.NewRecorder() - err = hub.agentConnect(req, recorder) + acr := &agentConnectRequest{ + hub: hub, + req: req, + res: recorder, + } + err = acr.agentConnect() assert.Equal(t, tc.expectedStatus, recorder.Code, tc.description) }) @@ -773,25 +738,30 @@ func TestAgentWebSocketIntegration(t *testing.T) { require.NoError(t, err) defer testApp.Cleanup() - // Get the hub's SSH key using the proper method + // Get the hub's SSH key hubSigner, err := hub.GetSSHKey("") require.NoError(t, err) goodPubKey := hubSigner.PublicKey() - // Generate WRONG key pair (should be rejected) + // Generate bad key pair (should be rejected) _, badPrivKey, err := ed25519.GenerateKey(nil) require.NoError(t, err) badPubKey, err := ssh.NewPublicKey(badPrivKey.Public().(ed25519.PublicKey)) require.NoError(t, err) - // Create test user once + // Create test user userRecord, err := createTestUser(testApp) require.NoError(t, err) // Create HTTP server with the actual API route ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/beszel/agent-connect" { - hub.agentConnect(r, w) + acr := &agentConnectRequest{ + hub: hub, + req: r, + res: w, + } + acr.agentConnect() } else { http.NotFound(w, r) } @@ -926,10 +896,10 @@ func TestAgentWebSocketIntegration(t *testing.T) { // Wait for connection result maxWait := 2 * time.Second - checkInterval := 100 * time.Millisecond + time.Sleep(20 * time.Millisecond) + checkInterval := 20 * time.Millisecond timeout := time.After(maxWait) - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() + ticker := time.Tick(checkInterval) connectionManager := testAgent.GetConnectionManager() @@ -944,7 +914,7 @@ func TestAgentWebSocketIntegration(t *testing.T) { t.Logf("Connection properly rejected (timeout) - agent state: %d", connectionManager.State) } connectionResult = false - case <-ticker.C: + case <-ticker: if connectionManager.State == agent.WebSocketConnected { if tc.expectConnection { t.Logf("WebSocket connection successful - agent state: %d", connectionManager.State) @@ -999,3 +969,732 @@ func TestAgentWebSocketIntegration(t *testing.T) { }) } } + +// TestMultipleSystemsWithSameUniversalToken tests that multiple systems can share the same universal token +func TestMultipleSystemsWithSameUniversalToken(t *testing.T) { + // Create hub and test app + hub, testApp, err := createTestHub(t) + require.NoError(t, err) + defer testApp.Cleanup() + + // Get the hub's SSH key + hubSigner, err := hub.GetSSHKey("") + require.NoError(t, err) + goodPubKey := hubSigner.PublicKey() + + // Create test user + userRecord, err := createTestUser(testApp) + require.NoError(t, err) + + // Set up universal token in the token map + universalToken := "shared-universal-token-123" + universalTokenMap.GetMap().Set(universalToken, userRecord.Id, time.Hour) + + // Create HTTP server with the actual API route + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/beszel/agent-connect" { + acr := &agentConnectRequest{ + hub: hub, + req: r, + res: w, + } + acr.agentConnect() + } else { + http.NotFound(w, r) + } + })) + defer ts.Close() + + // Test scenarios for universal tokens + testCases := []struct { + name string + agentFingerprint string + expectConnection bool + expectSystemStatus string + expectNewSystem bool // Whether we expect a new system to be created + description string + }{ + { + name: "first system with universal token", + agentFingerprint: "system-1-fingerprint", + expectConnection: true, + expectSystemStatus: "up", + expectNewSystem: true, + description: "First system should create a new system", + }, + { + name: "same system reconnecting with same fingerprint", + agentFingerprint: "system-1-fingerprint", // Same fingerprint as first + expectConnection: true, + expectSystemStatus: "up", + expectNewSystem: false, // Should reuse existing system + description: "Same system should reuse existing system record", + }, + { + name: "different system with same universal token", + agentFingerprint: "system-2-fingerprint", // Different fingerprint + expectConnection: true, + expectSystemStatus: "up", + expectNewSystem: true, // Should create new system + description: "Different system should create a new system record", + }, + } + + var systemCount int + for i, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create unique port for each test + portNum := 46000 + i + + // Create and configure agent + agentDataDir := t.TempDir() + + // Set up agent fingerprint + err = os.WriteFile(filepath.Join(agentDataDir, "fingerprint"), []byte(tc.agentFingerprint), 0644) + require.NoError(t, err) + + testAgent, err := agent.NewAgent(agentDataDir) + require.NoError(t, err) + + // Set up environment variables for the agent + os.Setenv("BESZEL_AGENT_HUB_URL", ts.URL) + os.Setenv("BESZEL_AGENT_TOKEN", universalToken) + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + // Count systems before connection + systemsBefore, err := testApp.FindRecordsByFilter("systems", "users ~ {:userId}", "", -1, 0, map[string]any{"userId": userRecord.Id}) + require.NoError(t, err) + systemsBeforeCount := len(systemsBefore) + + // Start agent in background + done := make(chan error, 1) + go func() { + serverOptions := agent.ServerOptions{ + Network: "tcp", + Addr: fmt.Sprintf("127.0.0.1:%d", portNum), + Keys: []ssh.PublicKey{goodPubKey}, + } + done <- testAgent.Start(serverOptions) + }() + + // Wait for connection result + maxWait := 2 * time.Second + time.Sleep(20 * time.Millisecond) + checkInterval := 20 * time.Millisecond + timeout := time.After(maxWait) + ticker := time.Tick(checkInterval) + + connectionManager := testAgent.GetConnectionManager() + connectionResult := false + + for { + select { + case <-timeout: + if tc.expectConnection { + t.Fatalf("Expected connection to succeed but timed out - agent state: %d", connectionManager.State) + } else { + t.Logf("Connection properly rejected (timeout) - agent state: %d", connectionManager.State) + } + connectionResult = false + case <-ticker: + if connectionManager.State == agent.WebSocketConnected { + if tc.expectConnection { + t.Logf("WebSocket connection successful - agent state: %d", connectionManager.State) + connectionResult = true + } else { + t.Errorf("Unexpected: Connection succeeded when it should have been rejected") + return + } + } + case err := <-done: + if err != nil { + if !tc.expectConnection { + t.Logf("Agent connection properly rejected: %v", err) + connectionResult = false + } else { + t.Fatalf("Agent failed to start: %v", err) + } + } + } + + if connectionResult == tc.expectConnection || connectionResult { + break + } + } + + // Verify system creation/reuse behavior + if tc.expectConnection { + // Count systems after connection + systemsAfter, err := testApp.FindRecordsByFilter("systems", "users ~ {:userId}", "", -1, 0, map[string]any{"userId": userRecord.Id}) + require.NoError(t, err) + systemsAfterCount := len(systemsAfter) + + if tc.expectNewSystem { + // Should have created a new system + systemCount++ + assert.Equal(t, systemsBeforeCount+1, systemsAfterCount, "Should have created a new system") + assert.Equal(t, systemCount, systemsAfterCount, "Total system count should match expected") + } else { + // Should have reused existing system + assert.Equal(t, systemsBeforeCount, systemsAfterCount, "Should not have created a new system") + assert.Equal(t, systemCount, systemsAfterCount, "Total system count should remain the same") + } + + // Verify that a fingerprint record exists for this fingerprint + fingerprints, err := testApp.FindRecordsByFilter("fingerprints", "token = {:token} && fingerprint = {:fingerprint}", "", -1, 0, map[string]any{ + "token": universalToken, + "fingerprint": tc.agentFingerprint, + }) + require.NoError(t, err) + require.Len(t, fingerprints, 1, "Should have exactly one fingerprint record for this token+fingerprint combination") + + fingerprint := fingerprints[0] + assert.Equal(t, universalToken, fingerprint.GetString("token"), "Fingerprint should have the universal token") + assert.Equal(t, tc.agentFingerprint, fingerprint.GetString("fingerprint"), "Fingerprint should match agent's fingerprint") + + // Verify system status + systemId := fingerprint.GetString("system") + system, err := testApp.FindRecordById("systems", systemId) + require.NoError(t, err) + status := system.GetString("status") + assert.Equal(t, tc.expectSystemStatus, status, "System status should match expected value") + + t.Logf("%s - System ID: %s, Status: %s, New System: %v", tc.description, systemId, status, tc.expectNewSystem) + } + }) + } +} + +// TestFindOrCreateSystemForToken tests the findOrCreateSystemForToken function +func TestFindOrCreateSystemForToken(t *testing.T) { + hub, testApp, err := createTestHub(t) + require.NoError(t, err) + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + require.NoError(t, err) + + type testCase struct { + name string + setup func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) + agentFingerprint common.FingerprintResponse + expectError bool + expectNewSystem bool + expectedFingerprint string + description string + } + + testCases := []testCase{ + { + name: "universal token - existing fingerprint match", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "existing-system", + "host": "192.168.1.100", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint record + fpRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "universal-token-123", + "fingerprint": "existing-fingerprint", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "universal-token-123", + isUniversalToken: true, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.1.100", + }, + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord.Id, + SystemId: systemRecord.Id, + Fingerprint: "existing-fingerprint", + Token: "universal-token-123", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "existing-fingerprint", + Hostname: "test-host", + Port: "8080", + }, + expectError: false, + expectNewSystem: false, + expectedFingerprint: "existing-fingerprint", + description: "Should reuse existing system with matching fingerprint", + }, + { + name: "universal token - new fingerprint", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "existing-system-2", + "host": "192.168.1.101", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint record + fpRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "universal-token-123", + "fingerprint": "existing-fingerprint", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "universal-token-123", + isUniversalToken: true, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.1.200", + }, + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord.Id, + SystemId: systemRecord.Id, + Fingerprint: "existing-fingerprint", + Token: "universal-token-123", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "new-fingerprint", + Hostname: "new-host", + Port: "9090", + }, + expectError: false, + expectNewSystem: true, + expectedFingerprint: "new-fingerprint", + description: "Should create new system with different fingerprint", + }, + { + name: "universal token - no existing records", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + acr := agentConnectRequest{ + hub: hub, + token: "universal-token-456", + isUniversalToken: true, + userId: userRecord.Id, + req: &http.Request{ + RemoteAddr: "192.168.1.300", + }, + } + + fpRecords := []ws.FingerprintRecord{} + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "first-fingerprint", + Hostname: "first-host", + Port: "7070", + }, + expectError: false, + expectNewSystem: true, + expectedFingerprint: "first-fingerprint", + description: "Should create new system when no existing records", + }, + { + name: "regular token - empty fingerprint", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "regular-system", + "host": "192.168.1.200", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint record with empty fingerprint + fpRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "regular-token-123", + "fingerprint": "", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "regular-token-123", + isUniversalToken: false, + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord.Id, + SystemId: systemRecord.Id, + Fingerprint: "", + Token: "regular-token-123", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "agent-fingerprint", + Hostname: "agent-host", + Port: "6060", + }, + expectError: false, + expectNewSystem: false, + expectedFingerprint: "agent-fingerprint", + description: "Should update empty fingerprint for regular token", + }, + { + name: "regular token - fingerprint mismatch", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "regular-system-2", + "host": "192.168.1.250", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint record with different fingerprint + fpRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "regular-token-456", + "fingerprint": "different-fingerprint", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "regular-token-456", + isUniversalToken: false, + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord.Id, + SystemId: systemRecord.Id, + Fingerprint: "different-fingerprint", + Token: "regular-token-456", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "agent-fingerprint", + Hostname: "agent-host", + Port: "5050", + }, + expectError: true, + description: "Should reject fingerprint mismatch for regular token", + }, + { + name: "universal token - missing user ID", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + acr := agentConnectRequest{ + hub: hub, + token: "universal-token-789", + isUniversalToken: true, + userId: "", // Missing user ID + req: &http.Request{ + RemoteAddr: "192.168.1.400", + }, + } + + fpRecords := []ws.FingerprintRecord{} + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "some-fingerprint", + Hostname: "some-host", + Port: "4040", + }, + expectError: true, + description: "Should reject universal token without user ID", + }, + { + name: "expired universal token - matching fingerprint", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test systems + systemRecord1, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "expired-system-1", + "host": "192.168.1.500", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + systemRecord2, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "expired-system-2", + "host": "192.168.1.501", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint records + fpRecord1, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord1.Id, + "token": "expired-universal-token-123", + "fingerprint": "expired-fingerprint-1", + }) + require.NoError(t, err) + + fpRecord2, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord2.Id, + "token": "expired-universal-token-123", + "fingerprint": "expired-fingerprint-2", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "expired-universal-token-123", + isUniversalToken: false, // Token is no longer active + userId: "", // No user ID since token is expired + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord1.Id, + SystemId: systemRecord1.Id, + Fingerprint: "expired-fingerprint-1", + Token: "expired-universal-token-123", + }, + { + Id: fpRecord2.Id, + SystemId: systemRecord2.Id, + Fingerprint: "expired-fingerprint-2", + Token: "expired-universal-token-123", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "expired-fingerprint-1", // Matches first record + Hostname: "expired-host", + Port: "3030", + }, + expectError: false, + expectNewSystem: false, + expectedFingerprint: "expired-fingerprint-1", + description: "Should allow connection with expired universal token if fingerprint matches", + }, + { + name: "expired universal token - no matching fingerprint", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "expired-system-3", + "host": "192.168.1.600", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Create fingerprint record + fpRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "expired-universal-token-456", + "fingerprint": "expired-fingerprint-3", + }) + require.NoError(t, err) + + acr := agentConnectRequest{ + hub: hub, + token: "expired-universal-token-456", + isUniversalToken: false, // Token is no longer active + userId: "", // No user ID since token is expired + req: &http.Request{ + RemoteAddr: "192.168.1.600", + }, + } + + fpRecords := []ws.FingerprintRecord{ + { + Id: fpRecord.Id, + SystemId: systemRecord.Id, + Fingerprint: "expired-fingerprint-3", + Token: "expired-universal-token-456", + }, + } + + return acr, fpRecords + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "different-fingerprint", // Doesn't match any existing record + Hostname: "different-host", + Port: "2020", + }, + expectError: true, + description: "Should reject connection with expired universal token if no fingerprint matches", + }, + { + name: "regular token - no existing records", + setup: func(t *testing.T, hub *Hub, testApp *pbtests.TestApp, userRecord *core.Record) (agentConnectRequest, []ws.FingerprintRecord) { + acr := agentConnectRequest{ + hub: hub, + token: "regular-token-no-record", + isUniversalToken: false, + } + return acr, []ws.FingerprintRecord{} + }, + agentFingerprint: common.FingerprintResponse{ + Fingerprint: "some-fingerprint", + }, + expectError: true, + description: "Should reject regular token with no fingerprint record", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + acr, fpRecords := tc.setup(t, hub, testApp, userRecord) + result, err := acr.findOrCreateSystemForToken(fpRecords, tc.agentFingerprint) + + if tc.expectError { + assert.Error(t, err, tc.description) + return + } + + require.NoError(t, err, tc.description) + + // Verify expected fingerprint + if tc.expectedFingerprint != "" { + assert.Equal(t, tc.expectedFingerprint, result.Fingerprint, "Fingerprint should match expected") + } + + // For new systems, verify they were actually created + if tc.expectNewSystem { + assert.NotEmpty(t, result.SystemId, "New system should have a system ID") + + // Verify system was created in database + system, err := testApp.FindRecordById("systems", result.SystemId) + require.NoError(t, err, "New system should exist in database") + + // Verify system properties + assert.Equal(t, tc.agentFingerprint.Hostname, system.GetString("name"), "System name should match hostname") + assert.Equal(t, getRealIP(acr.req), system.GetString("host"), "System host should match remote address") + assert.Equal(t, tc.agentFingerprint.Port, system.GetString("port"), "System port should match agent port") + assert.Equal(t, []string{acr.userId}, system.Get("users"), "System users should match") + } + + t.Logf("%s - Result: SystemId=%s, Fingerprint=%s", tc.description, result.SystemId, result.Fingerprint) + }) + } +} + +// TestGetRealIP tests the getRealIP function +func TestGetRealIP(t *testing.T) { + testCases := []struct { + name string + headers map[string]string + remoteAddr string + expectedIP string + }{ + { + name: "CF-Connecting-IP header", + headers: map[string]string{"CF-Connecting-IP": "192.168.1.1"}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.1", + }, + { + name: "X-Forwarded-For header with single IP", + headers: map[string]string{"X-Forwarded-For": "192.168.1.2"}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.2", + }, + { + name: "X-Forwarded-For header with multiple IPs", + headers: map[string]string{"X-Forwarded-For": "192.168.1.3, 10.0.0.1, 172.16.0.1"}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.3", + }, + { + name: "X-Forwarded-For header with spaces", + headers: map[string]string{"X-Forwarded-For": " 192.168.1.4 "}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.4", + }, + { + name: "No headers, fallback to RemoteAddr with port", + headers: map[string]string{}, + remoteAddr: "192.168.1.5:54321", + expectedIP: "192.168.1.5", + }, + { + name: "No headers, fallback to RemoteAddr without port", + headers: map[string]string{}, + remoteAddr: "192.168.1.6", + expectedIP: "192.168.1.6", + }, + { + name: "Both headers present, CF takes precedence", + headers: map[string]string{"CF-Connecting-IP": "192.168.1.1", "X-Forwarded-For": "192.168.1.2"}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.1", + }, + { + name: "X-Forwarded-For present, takes precedence over RemoteAddr", + headers: map[string]string{"X-Forwarded-For": "192.168.1.2"}, + remoteAddr: "192.168.1.5:54321", + expectedIP: "192.168.1.2", + }, + { + name: "Empty X-Forwarded-For, fallback to RemoteAddr", + headers: map[string]string{"X-Forwarded-For": ""}, + remoteAddr: "192.168.1.7:12345", + expectedIP: "192.168.1.7", + }, + { + name: "Empty CF-Connecting-IP, fallback to X-Forwarded-For", + headers: map[string]string{"CF-Connecting-IP": "", "X-Forwarded-For": "192.168.1.8"}, + remoteAddr: "127.0.0.1:12345", + expectedIP: "192.168.1.8", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + for key, value := range tc.headers { + req.Header.Set(key, value) + } + req.RemoteAddr = tc.remoteAddr + + ip := getRealIP(req) + assert.Equal(t, tc.expectedIP, ip) + }) + } +} diff --git a/beszel/internal/hub/hub.go b/beszel/internal/hub/hub.go index a425f4f..9ad4942 100644 --- a/beszel/internal/hub/hub.go +++ b/beszel/internal/hub/hub.go @@ -259,7 +259,7 @@ func (h *Hub) getUniversalToken(e *core.RequestEvent) error { return apis.NewForbiddenError("Forbidden", nil) } - tokenMap := getTokenMap() + tokenMap := universalTokenMap.GetMap() userID := info.Auth.Id query := e.Request.URL.Query() token := query.Get("token") diff --git a/beszel/internal/hub/hub_test.go b/beszel/internal/hub/hub_test.go index f2439bd..02618ce 100644 --- a/beszel/internal/hub/hub_test.go +++ b/beszel/internal/hub/hub_test.go @@ -254,5 +254,3 @@ func TestGetSSHKey(t *testing.T) { } }) } - -// Helper function to create test records diff --git a/beszel/internal/hub/ws/ws.go b/beszel/internal/hub/ws/ws.go index de05f9d..cdb7216 100644 --- a/beszel/internal/hub/ws/ws.go +++ b/beszel/internal/hub/ws/ws.go @@ -140,11 +140,12 @@ func (ws *WsConn) RequestSystemData(data *system.CombinedData) error { // GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint. func (ws *WsConn) GetFingerprint(token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) { + var clientFingerprint common.FingerprintResponse challenge := []byte(token) signature, err := signer.Sign(nil, challenge) if err != nil { - return common.FingerprintResponse{}, err + return clientFingerprint, err } err = ws.sendMessage(common.HubRequest[any]{ @@ -155,24 +156,19 @@ func (ws *WsConn) GetFingerprint(token string, signer ssh.Signer, needSysInfo bo }, }) if err != nil { - return common.FingerprintResponse{}, err + return clientFingerprint, err } var message *gws.Message - var clientFingerprint common.FingerprintResponse select { case message = <-ws.responseChan: case <-time.After(10 * time.Second): - return common.FingerprintResponse{}, errors.New("request expired") + return clientFingerprint, errors.New("request expired") } defer message.Close() err = cbor.Unmarshal(message.Data.Bytes(), &clientFingerprint) - if err != nil { - return common.FingerprintResponse{}, err - } - - return clientFingerprint, nil + return clientFingerprint, err } // IsConnected returns true if the WebSocket connection is active.