diff --git a/beszel/internal/agent/agent.go b/beszel/internal/agent/agent.go index 80993e6..688a036 100644 --- a/beszel/internal/agent/agent.go +++ b/beszel/internal/agent/agent.go @@ -4,34 +4,55 @@ package agent import ( "beszel" "beszel/internal/entities/system" + "crypto/sha256" + "encoding/hex" "log/slog" "os" + "path/filepath" "strings" "sync" "time" + + "github.com/gliderlabs/ssh" + "github.com/shirou/gopsutil/v4/host" + gossh "golang.org/x/crypto/ssh" ) type Agent struct { - sync.Mutex // Used to lock agent while collecting data - debug bool // true if LOG_LEVEL is set to debug - zfs bool // true if system has arcstats - memCalc string // Memory calculation formula - fsNames []string // List of filesystem device names being monitored - fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem - netInterfaces map[string]struct{} // Stores all valid network interfaces - netIoStats system.NetIoStats // Keeps track of bandwidth usage - dockerManager *dockerManager // Manages Docker API requests - sensorConfig *SensorConfig // Sensors config - systemInfo system.Info // Host system info - gpuManager *GPUManager // Manages GPU data - cache *SessionCache // Cache for system stats based on primary session ID + sync.Mutex // Used to lock agent while collecting data + debug bool // true if LOG_LEVEL is set to debug + zfs bool // true if system has arcstats + memCalc string // Memory calculation formula + fsNames []string // List of filesystem device names being monitored + fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem + netInterfaces map[string]struct{} // Stores all valid network interfaces + netIoStats system.NetIoStats // Keeps track of bandwidth usage + dockerManager *dockerManager // Manages Docker API requests + sensorConfig *SensorConfig // Sensors config + systemInfo system.Info // Host system info + gpuManager *GPUManager // Manages GPU data + cache *SessionCache // Cache for system stats based on primary session ID + connectionManager *ConnectionManager // Channel to signal connection events + server *ssh.Server // SSH server + dataDir string // Directory for persisting data + keys []gossh.PublicKey // SSH public keys } -func NewAgent() *Agent { - agent := &Agent{ +// NewAgent creates a new agent with the given data directory for persisting data. +// If the data directory is not set, it will attempt to find the optimal directory. 
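+//
+// A minimal usage sketch (the data directory path and key source are
+// illustrative, not prescribed defaults):
+//
+//	keys, _ := ParseKeys(os.Getenv("KEY"))
+//	a, err := NewAgent("/var/lib/beszel-agent")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	err = a.Start(ServerOptions{Network: "tcp", Addr: ":45876", Keys: keys})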
+func NewAgent(dataDir string) (agent *Agent, err error) { + agent = &Agent{ fsStats: make(map[string]*system.FsStats), cache: NewSessionCache(69 * time.Second), } + + agent.dataDir, err = getDataDir(dataDir) + if err != nil { + slog.Warn("Data directory not found") + } else { + slog.Info("Data directory", "path", agent.dataDir) + } + agent.memCalc, _ = GetEnv("MEM_CALC") agent.sensorConfig = agent.newSensorConfig() // Set up slog with a log level determined by the LOG_LEVEL env var @@ -49,10 +70,19 @@ func NewAgent() *Agent { slog.Debug(beszel.Version) - // initialize system info / docker manager + // initialize system info agent.initializeSystemInfo() + + // initialize connection manager + agent.connectionManager = newConnectionManager(agent) + + // initialize disk info agent.initializeDiskInfo() + + // initialize net io stats agent.initializeNetIoStats() + + // initialize docker manager agent.dockerManager = newDockerManager(agent) // initialize GPU manager @@ -67,7 +97,7 @@ func NewAgent() *Agent { slog.Debug("Stats", "data", agent.gatherStats("")) } - return agent + return agent, nil } // GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key. @@ -115,3 +145,38 @@ func (a *Agent) gatherStats(sessionID string) *system.CombinedData { a.cache.Set(sessionID, cachedData) return cachedData } + +// StartAgent initializes and starts the agent with optional WebSocket connection +func (a *Agent) Start(serverOptions ServerOptions) error { + a.keys = serverOptions.Keys + return a.connectionManager.Start(serverOptions) +} + +func (a *Agent) getFingerprint() string { + // first look for a fingerprint in the data directory + if a.dataDir != "" { + if fp, err := os.ReadFile(filepath.Join(a.dataDir, "fingerprint")); err == nil { + return string(fp) + } + } + + // if no fingerprint is found, generate one + fingerprint, err := host.HostID() + if err != nil || fingerprint == "" { + fingerprint = a.systemInfo.Hostname + a.systemInfo.CpuModel + } + + // hash fingerprint + sum := sha256.Sum256([]byte(fingerprint)) + fingerprint = hex.EncodeToString(sum[:24]) + + // save fingerprint to data directory + if a.dataDir != "" { + err = os.WriteFile(filepath.Join(a.dataDir, "fingerprint"), []byte(fingerprint), 0644) + if err != nil { + slog.Warn("Failed to save fingerprint", "err", err) + } + } + + return fingerprint +} diff --git a/beszel/internal/agent/agent_test_helpers.go b/beszel/internal/agent/agent_test_helpers.go new file mode 100644 index 0000000..7b3fffa --- /dev/null +++ b/beszel/internal/agent/agent_test_helpers.go @@ -0,0 +1,9 @@ +//go:build testing +// +build testing + +package agent + +// TESTING ONLY: GetConnectionManager is a helper function to get the connection manager for testing. +func (a *Agent) GetConnectionManager() *ConnectionManager { + return a.connectionManager +} diff --git a/beszel/internal/agent/client.go b/beszel/internal/agent/client.go new file mode 100644 index 0000000..6dbcf1d --- /dev/null +++ b/beszel/internal/agent/client.go @@ -0,0 +1,243 @@ +package agent + +import ( + "beszel" + "beszel/internal/common" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "net/url" + "path" + "strings" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/lxzan/gws" + "golang.org/x/crypto/ssh" +) + +const ( + wsDeadline = 70 * time.Second +) + +// WebSocketClient manages the WebSocket connection between the agent and hub. +// It handles authentication, message routing, and connection lifecycle management. 
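+//
+// Construction is driven by environment variables; newWebSocketClient returns
+// an error if HUB_URL or TOKEN is unset (values below are illustrative):
+//
+//	os.Setenv("BESZEL_AGENT_HUB_URL", "https://hub.example.com")
+//	os.Setenv("BESZEL_AGENT_TOKEN", "example-token")
+//	client, err := newWebSocketClient(agent)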
+type WebSocketClient struct { + gws.BuiltinEventHandler + options *gws.ClientOption // WebSocket client configuration options + agent *Agent // Reference to the parent agent + Conn *gws.Conn // Active WebSocket connection + hubURL *url.URL // Parsed hub URL for connection + token string // Authentication token for hub registration + fingerprint string // System fingerprint for identification + hubRequest *common.HubRequest[cbor.RawMessage] // Reusable request structure for message parsing + lastConnectAttempt time.Time // Timestamp of last connection attempt + hubVerified bool // Whether the hub has been cryptographically verified +} + +// newWebSocketClient creates a new WebSocket client for the given agent. +// It reads configuration from environment variables and validates the hub URL. +func newWebSocketClient(agent *Agent) (client *WebSocketClient, err error) { + hubURLStr, exists := GetEnv("HUB_URL") + if !exists { + return nil, errors.New("HUB_URL environment variable not set") + } + + client = &WebSocketClient{} + + client.hubURL, err = url.Parse(hubURLStr) + if err != nil { + return nil, errors.New("invalid hub URL") + } + // get registration token + client.token, _ = GetEnv("TOKEN") + if client.token == "" { + return nil, errors.New("TOKEN environment variable not set") + } + + client.agent = agent + client.hubRequest = &common.HubRequest[cbor.RawMessage]{} + client.fingerprint = agent.getFingerprint() + + return client, nil +} + +// getOptions returns the WebSocket client options, creating them if necessary. +// It configures the connection URL, TLS settings, and authentication headers. +func (client *WebSocketClient) getOptions() *gws.ClientOption { + if client.options != nil { + return client.options + } + + // update the hub url to use websocket scheme and api path + if client.hubURL.Scheme == "https" { + client.hubURL.Scheme = "wss" + } else { + client.hubURL.Scheme = "ws" + } + client.hubURL.Path = path.Join(client.hubURL.Path, "api/beszel/agent-connect") + + client.options = &gws.ClientOption{ + Addr: client.hubURL.String(), + TlsConfig: &tls.Config{InsecureSkipVerify: true}, + RequestHeader: http.Header{ + "User-Agent": []string{getUserAgent()}, + "X-Token": []string{client.token}, + "X-Beszel": []string{beszel.Version}, + }, + } + return client.options +} + +// Connect establishes a WebSocket connection to the hub. +// It closes any existing connection before attempting to reconnect. +func (client *WebSocketClient) Connect() (err error) { + client.lastConnectAttempt = time.Now() + + // make sure previous connection is closed + client.Close() + + client.Conn, _, err = gws.NewClient(client, client.getOptions()) + if err != nil { + return err + } + + go client.Conn.ReadLoop() + + return nil +} + +// OnOpen handles WebSocket connection establishment. +// It sets a deadline for the connection to prevent hanging. +func (client *WebSocketClient) OnOpen(conn *gws.Conn) { + conn.SetDeadline(time.Now().Add(wsDeadline)) +} + +// OnClose handles WebSocket connection closure. +// It logs the closure reason and notifies the connection manager. +func (client *WebSocketClient) OnClose(conn *gws.Conn, err error) { + slog.Warn("Connection closed", "err", strings.TrimPrefix(err.Error(), "gws: ")) + client.agent.connectionManager.eventChan <- WebSocketDisconnect +} + +// OnMessage handles incoming WebSocket messages from the hub. +// It decodes CBOR messages and routes them to appropriate handlers. 
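+//
+// Frames are CBOR-encoded common.HubRequest values with integer keys
+// (keyasint), so a stats request from the hub looks roughly like this sketch:
+//
+//	payload, _ := cbor.Marshal(common.HubRequest[any]{Action: common.GetData})
+//	// sent by the hub as a binary frame and decoded here into client.hubRequest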
+func (client *WebSocketClient) OnMessage(conn *gws.Conn, message *gws.Message) { + defer message.Close() + conn.SetDeadline(time.Now().Add(wsDeadline)) + + if message.Opcode != gws.OpcodeBinary { + return + } + + if err := cbor.NewDecoder(message.Data).Decode(client.hubRequest); err != nil { + slog.Error("Error parsing message", "err", err) + return + } + if err := client.handleHubRequest(client.hubRequest); err != nil { + slog.Error("Error handling message", "err", err) + } +} + +// OnPing handles WebSocket ping frames. +// It responds with a pong and updates the connection deadline. +func (client *WebSocketClient) OnPing(conn *gws.Conn, message []byte) { + conn.SetDeadline(time.Now().Add(wsDeadline)) + conn.WritePong(message) +} + +// handleAuthChallenge verifies the authenticity of the hub and returns the system's fingerprint. +func (client *WebSocketClient) handleAuthChallenge(msg *common.HubRequest[cbor.RawMessage]) (err error) { + var authRequest common.FingerprintRequest + if err := cbor.Unmarshal(msg.Data, &authRequest); err != nil { + return err + } + + if err := client.verifySignature(authRequest.Signature); err != nil { + return err + } + + client.hubVerified = true + client.agent.connectionManager.eventChan <- WebSocketConnect + + response := &common.FingerprintResponse{ + Fingerprint: client.fingerprint, + } + + if authRequest.NeedSysInfo { + response.Hostname = client.agent.systemInfo.Hostname + serverAddr := client.agent.connectionManager.serverOptions.Addr + _, response.Port, _ = net.SplitHostPort(serverAddr) + } + + return client.sendMessage(response) +} + +// verifySignature verifies the signature of the token using the public keys. +func (client *WebSocketClient) verifySignature(signature []byte) (err error) { + for _, pubKey := range client.agent.keys { + sig := ssh.Signature{ + Format: pubKey.Type(), + Blob: signature, + } + if err = pubKey.Verify([]byte(client.token), &sig); err == nil { + return nil + } + } + return errors.New("invalid signature - check KEY value") +} + +// Close closes the WebSocket connection gracefully. +// This method is safe to call multiple times. +func (client *WebSocketClient) Close() { + if client.Conn != nil { + _ = client.Conn.WriteClose(1000, nil) + } +} + +// handleHubRequest routes the request to the appropriate handler. +// It ensures the hub is verified before processing most requests. +func (client *WebSocketClient) handleHubRequest(msg *common.HubRequest[cbor.RawMessage]) error { + if !client.hubVerified && msg.Action != common.CheckFingerprint { + return errors.New("hub not verified") + } + switch msg.Action { + case common.GetData: + return client.sendSystemData() + case common.CheckFingerprint: + return client.handleAuthChallenge(msg) + } + return nil +} + +// sendSystemData gathers and sends current system statistics to the hub. +func (client *WebSocketClient) sendSystemData() error { + sysStats := client.agent.gatherStats(client.token) + return client.sendMessage(sysStats) +} + +// sendMessage encodes the given data to CBOR and sends it as a binary message over the WebSocket connection to the hub. +func (client *WebSocketClient) sendMessage(data any) error { + bytes, err := cbor.Marshal(data) + if err != nil { + return err + } + return client.Conn.WriteMessage(gws.OpcodeBinary, bytes) +} + +// getUserAgent returns one of two User-Agent strings based on current time. +// This is used to avoid being blocked by Cloudflare or other anti-bot measures. 
+func getUserAgent() string { + const ( + uaBase = "Mozilla/5.0 (%s) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" + uaWindows = "Windows NT 11.0; Win64; x64" + uaMac = "Macintosh; Intel Mac OS X 14_0_0" + ) + if time.Now().UnixNano()%2 == 0 { + return fmt.Sprintf(uaBase, uaWindows) + } + return fmt.Sprintf(uaBase, uaMac) +} diff --git a/beszel/internal/agent/connection_manager.go b/beszel/internal/agent/connection_manager.go new file mode 100644 index 0000000..9801e38 --- /dev/null +++ b/beszel/internal/agent/connection_manager.go @@ -0,0 +1,220 @@ +package agent + +import ( + "beszel/internal/agent/health" + "errors" + "log/slog" + "os" + "os/signal" + "syscall" + "time" +) + +// ConnectionManager manages the connection state and events for the agent. +// It handles both WebSocket and SSH connections, automatically switching between +// them based on availability and managing reconnection attempts. +type ConnectionManager struct { + agent *Agent // Reference to the parent agent + State ConnectionState // Current connection state + eventChan chan ConnectionEvent // Channel for connection events + wsClient *WebSocketClient // WebSocket client for hub communication + serverOptions ServerOptions // Configuration for SSH server + wsTicker *time.Ticker // Ticker for WebSocket connection attempts + isConnecting bool // Prevents multiple simultaneous reconnection attempts +} + +// ConnectionState represents the current connection state of the agent. +type ConnectionState uint8 + +// ConnectionEvent represents connection-related events that can occur. +type ConnectionEvent uint8 + +// Connection states +const ( + Disconnected ConnectionState = iota // No active connection + WebSocketConnected // Connected via WebSocket + SSHConnected // Connected via SSH +) + +// Connection events +const ( + WebSocketConnect ConnectionEvent = iota // WebSocket connection established + WebSocketDisconnect // WebSocket connection lost + SSHConnect // SSH connection established + SSHDisconnect // SSH connection lost +) + +const wsTickerInterval = 10 * time.Second + +// newConnectionManager creates a new connection manager for the given agent. +func newConnectionManager(agent *Agent) *ConnectionManager { + cm := &ConnectionManager{ + agent: agent, + State: Disconnected, + } + return cm +} + +// startWsTicker starts or resets the WebSocket connection attempt ticker. +func (c *ConnectionManager) startWsTicker() { + if c.wsTicker == nil { + c.wsTicker = time.NewTicker(wsTickerInterval) + } else { + c.wsTicker.Reset(wsTickerInterval) + } +} + +// stopWsTicker stops the WebSocket connection attempt ticker. +func (c *ConnectionManager) stopWsTicker() { + if c.wsTicker != nil { + c.wsTicker.Stop() + } +} + +// Start begins connection attempts and enters the main event loop. +// It handles connection events, periodic health updates, and graceful shutdown. 
+func (c *ConnectionManager) Start(serverOptions ServerOptions) error { + if c.eventChan != nil { + return errors.New("already started") + } + + wsClient, err := newWebSocketClient(c.agent) + if err != nil { + slog.Warn("Error creating WebSocket client", "err", err) + } + c.wsClient = wsClient + + c.serverOptions = serverOptions + c.eventChan = make(chan ConnectionEvent, 1) + + // signal handling for shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + c.startWsTicker() + c.connect() + + // update health status immediately and every 90 seconds + _ = health.Update() + healthTicker := time.Tick(90 * time.Second) + + for { + select { + case connectionEvent := <-c.eventChan: + c.handleEvent(connectionEvent) + case <-c.wsTicker.C: + _ = c.startWebSocketConnection() + case <-healthTicker: + _ = health.Update() + case <-sigChan: + slog.Info("Shutting down") + _ = c.agent.StopServer() + c.closeWebSocket() + return health.CleanUp() + } + } +} + +// handleEvent processes connection events and updates the connection state accordingly. +func (c *ConnectionManager) handleEvent(event ConnectionEvent) { + switch event { + case WebSocketConnect: + c.handleStateChange(WebSocketConnected) + case SSHConnect: + c.handleStateChange(SSHConnected) + case WebSocketDisconnect: + if c.State == WebSocketConnected { + c.handleStateChange(Disconnected) + } + case SSHDisconnect: + if c.State == SSHConnected { + c.handleStateChange(Disconnected) + } + } +} + +// handleStateChange updates the connection state and performs necessary actions +// based on the new state, including stopping services and initiating reconnections. +func (c *ConnectionManager) handleStateChange(newState ConnectionState) { + if c.State == newState { + return + } + c.State = newState + switch newState { + case WebSocketConnected: + slog.Info("WebSocket connected", "host", c.wsClient.hubURL.Host) + c.stopWsTicker() + _ = c.agent.StopServer() + c.isConnecting = false + case SSHConnected: + // stop new ws connection attempts + slog.Info("SSH connection established") + c.stopWsTicker() + c.isConnecting = false + case Disconnected: + if c.isConnecting { + // Already handling reconnection, avoid duplicate attempts + return + } + c.isConnecting = true + slog.Warn("Disconnected from hub") + // make sure old ws connection is closed + c.closeWebSocket() + // reconnect + go c.connect() + } +} + +// connect handles the connection logic with proper delays and priority. +// It attempts WebSocket connection first, falling back to SSH server if needed. +func (c *ConnectionManager) connect() { + c.isConnecting = true + defer func() { + c.isConnecting = false + }() + + if c.wsClient != nil && time.Since(c.wsClient.lastConnectAttempt) < 5*time.Second { + time.Sleep(5 * time.Second) + } + + // Try WebSocket first, if it fails, start SSH server + err := c.startWebSocketConnection() + if err != nil && c.State == Disconnected { + c.startSSHServer() + c.startWsTicker() + } +} + +// startWebSocketConnection attempts to establish a WebSocket connection to the hub. 
+func (c *ConnectionManager) startWebSocketConnection() error { + if c.State != Disconnected { + return errors.New("already connected") + } + if c.wsClient == nil { + return errors.New("WebSocket client not initialized") + } + if time.Since(c.wsClient.lastConnectAttempt) < 5*time.Second { + return errors.New("already connecting") + } + + err := c.wsClient.Connect() + if err != nil { + slog.Warn("WebSocket connection failed", "err", err) + c.closeWebSocket() + } + return err +} + +// startSSHServer starts the SSH server if the agent is currently disconnected. +func (c *ConnectionManager) startSSHServer() { + if c.State == Disconnected { + go c.agent.StartServer(c.serverOptions) + } +} + +// closeWebSocket closes the WebSocket connection if it exists. +func (c *ConnectionManager) closeWebSocket() { + if c.wsClient != nil { + c.wsClient.Close() + } +} diff --git a/beszel/internal/agent/connection_manager_test.go b/beszel/internal/agent/connection_manager_test.go new file mode 100644 index 0000000..3224f8c --- /dev/null +++ b/beszel/internal/agent/connection_manager_test.go @@ -0,0 +1,315 @@ +//go:build testing +// +build testing + +package agent + +import ( + "crypto/ed25519" + "fmt" + "net" + "net/url" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func createTestAgent(t *testing.T) *Agent { + dataDir := t.TempDir() + agent, err := NewAgent(dataDir) + require.NoError(t, err) + return agent +} + +func createTestServerOptions(t *testing.T) ServerOptions { + // Generate test key pair + _, privKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + sshPubKey, err := ssh.NewPublicKey(privKey.Public().(ed25519.PublicKey)) + require.NoError(t, err) + + // Find available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + return ServerOptions{ + Network: "tcp", + Addr: fmt.Sprintf("127.0.0.1:%d", port), + Keys: []ssh.PublicKey{sshPubKey}, + } +} + +// TestConnectionManager_NewConnectionManager tests connection manager creation +func TestConnectionManager_NewConnectionManager(t *testing.T) { + agent := createTestAgent(t) + cm := newConnectionManager(agent) + + assert.NotNil(t, cm, "Connection manager should not be nil") + assert.Equal(t, agent, cm.agent, "Agent reference should be set") + assert.Equal(t, Disconnected, cm.State, "Initial state should be Disconnected") + assert.Nil(t, cm.eventChan, "Event channel should be nil initially") + assert.Nil(t, cm.wsClient, "WebSocket client should be nil initially") + assert.Nil(t, cm.wsTicker, "WebSocket ticker should be nil initially") + assert.False(t, cm.isConnecting, "isConnecting should be false initially") +} + +// TestConnectionManager_StateTransitions tests basic state transitions +func TestConnectionManager_StateTransitions(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + initialState := cm.State + cm.wsClient = &WebSocketClient{ + hubURL: &url.URL{ + Host: "localhost:8080", + }, + } + assert.NotNil(t, cm, "Connection manager should not be nil") + assert.Equal(t, Disconnected, initialState, "Initial state should be Disconnected") + + // Test state transitions + cm.handleStateChange(WebSocketConnected) + assert.Equal(t, WebSocketConnected, cm.State, "State should change to WebSocketConnected") + + cm.handleStateChange(SSHConnected) + assert.Equal(t, SSHConnected, cm.State, "State should change to SSHConnected") + + 
cm.handleStateChange(Disconnected) + assert.Equal(t, Disconnected, cm.State, "State should change to Disconnected") + + // Test that same state doesn't trigger changes + cm.State = WebSocketConnected + cm.handleStateChange(WebSocketConnected) + assert.Equal(t, WebSocketConnected, cm.State, "Same state should not trigger change") +} + +// TestConnectionManager_EventHandling tests event handling logic +func TestConnectionManager_EventHandling(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + cm.wsClient = &WebSocketClient{ + hubURL: &url.URL{ + Host: "localhost:8080", + }, + } + + testCases := []struct { + name string + initialState ConnectionState + event ConnectionEvent + expectedState ConnectionState + }{ + { + name: "WebSocket connect from disconnected", + initialState: Disconnected, + event: WebSocketConnect, + expectedState: WebSocketConnected, + }, + { + name: "SSH connect from disconnected", + initialState: Disconnected, + event: SSHConnect, + expectedState: SSHConnected, + }, + { + name: "WebSocket disconnect from connected", + initialState: WebSocketConnected, + event: WebSocketDisconnect, + expectedState: Disconnected, + }, + { + name: "SSH disconnect from connected", + initialState: SSHConnected, + event: SSHDisconnect, + expectedState: Disconnected, + }, + { + name: "WebSocket disconnect from SSH connected (no change)", + initialState: SSHConnected, + event: WebSocketDisconnect, + expectedState: SSHConnected, + }, + { + name: "SSH disconnect from WebSocket connected (no change)", + initialState: WebSocketConnected, + event: SSHDisconnect, + expectedState: WebSocketConnected, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cm.State = tc.initialState + cm.handleEvent(tc.event) + assert.Equal(t, tc.expectedState, cm.State, "State should match expected after event") + }) + } +} + +// TestConnectionManager_TickerManagement tests WebSocket ticker management +func TestConnectionManager_TickerManagement(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + + // Test starting ticker + cm.startWsTicker() + assert.NotNil(t, cm.wsTicker, "Ticker should be created") + + // Test stopping ticker (should not panic) + assert.NotPanics(t, func() { + cm.stopWsTicker() + }, "Stopping ticker should not panic") + + // Test stopping nil ticker (should not panic) + cm.wsTicker = nil + assert.NotPanics(t, func() { + cm.stopWsTicker() + }, "Stopping nil ticker should not panic") + + // Test restarting ticker + cm.startWsTicker() + assert.NotNil(t, cm.wsTicker, "Ticker should be recreated") + + // Test resetting existing ticker + firstTicker := cm.wsTicker + cm.startWsTicker() + assert.Equal(t, firstTicker, cm.wsTicker, "Same ticker instance should be reused") + + cm.stopWsTicker() +} + +// TestConnectionManager_WebSocketConnectionFlow tests WebSocket connection logic +func TestConnectionManager_WebSocketConnectionFlow(t *testing.T) { + if testing.Short() { + t.Skip("Skipping WebSocket connection test in short mode") + } + + agent := createTestAgent(t) + cm := agent.connectionManager + + // Test WebSocket connection without proper environment + err := cm.startWebSocketConnection() + assert.Error(t, err, "WebSocket connection should fail without proper environment") + assert.Equal(t, Disconnected, cm.State, "State should remain Disconnected after failed connection") + + // Test with invalid URL + os.Setenv("BESZEL_AGENT_HUB_URL", "invalid-url") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + 
os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + // Test with missing token + os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080") + os.Unsetenv("BESZEL_AGENT_TOKEN") + + _, err2 := newWebSocketClient(agent) + assert.Error(t, err2, "WebSocket client creation should fail without token") +} + +// TestConnectionManager_ReconnectionLogic tests reconnection prevention logic +func TestConnectionManager_ReconnectionLogic(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + cm.eventChan = make(chan ConnectionEvent, 1) + + // Test that isConnecting flag prevents duplicate reconnection attempts + // Start from connected state, then simulate disconnect + cm.State = WebSocketConnected + cm.isConnecting = false + + // First disconnect should trigger reconnection logic + cm.handleStateChange(Disconnected) + assert.Equal(t, Disconnected, cm.State, "Should change to disconnected") + assert.True(t, cm.isConnecting, "Should set isConnecting flag") +} + +// TestConnectionManager_ConnectWithRateLimit tests connection rate limiting +func TestConnectionManager_ConnectWithRateLimit(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + + // Set up environment for WebSocket client creation + os.Setenv("BESZEL_AGENT_HUB_URL", "ws://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + // Create WebSocket client + wsClient, err := newWebSocketClient(agent) + require.NoError(t, err) + cm.wsClient = wsClient + + // Set recent connection attempt + cm.wsClient.lastConnectAttempt = time.Now() + + // Test that connection is rate limited + err = cm.startWebSocketConnection() + assert.Error(t, err, "Should error due to rate limiting") + assert.Contains(t, err.Error(), "already connecting", "Error should indicate rate limiting") + + // Test connection after rate limit expires + cm.wsClient.lastConnectAttempt = time.Now().Add(-10 * time.Second) + err = cm.startWebSocketConnection() + // This will fail due to no actual server, but should not be rate limited + assert.Error(t, err, "Connection should fail but not due to rate limiting") + assert.NotContains(t, err.Error(), "already connecting", "Error should not indicate rate limiting") +} + +// TestConnectionManager_StartWithInvalidConfig tests starting with invalid configuration +func TestConnectionManager_StartWithInvalidConfig(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + serverOptions := createTestServerOptions(t) + + // Test starting when already started + cm.eventChan = make(chan ConnectionEvent, 5) + err := cm.Start(serverOptions) + assert.Error(t, err, "Should error when starting already started connection manager") +} + +// TestConnectionManager_CloseWebSocket tests WebSocket closing +func TestConnectionManager_CloseWebSocket(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + + // Test closing when no WebSocket client exists + assert.NotPanics(t, func() { + cm.closeWebSocket() + }, "Should not panic when closing nil WebSocket client") + + // Set up environment and create WebSocket client + os.Setenv("BESZEL_AGENT_HUB_URL", "ws://localhost:8080") + os.Setenv("BESZEL_AGENT_TOKEN", "test-token") + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + wsClient, err := newWebSocketClient(agent) + require.NoError(t, err) + cm.wsClient = wsClient + + // Test closing when 
WebSocket client exists + assert.NotPanics(t, func() { + cm.closeWebSocket() + }, "Should not panic when closing WebSocket client") +} + +// TestConnectionManager_ConnectFlow tests the connect method +func TestConnectionManager_ConnectFlow(t *testing.T) { + agent := createTestAgent(t) + cm := agent.connectionManager + + // Test connect without WebSocket client + assert.NotPanics(t, func() { + cm.connect() + }, "Connect should not panic without WebSocket client") +} diff --git a/beszel/internal/agent/server.go b/beszel/internal/agent/server.go index c03283f..3b34122 100644 --- a/beszel/internal/agent/server.go +++ b/beszel/internal/agent/server.go @@ -1,25 +1,44 @@ package agent import ( + "beszel" "beszel/internal/common" + "beszel/internal/entities/system" "encoding/json" + "errors" "fmt" + "io" "log/slog" "net" "os" "strings" + "time" + "github.com/blang/semver" + "github.com/fxamacker/cbor/v2" "github.com/gliderlabs/ssh" gossh "golang.org/x/crypto/ssh" ) +// ServerOptions contains configuration options for starting the SSH server. type ServerOptions struct { - Addr string - Network string - Keys []gossh.PublicKey + Addr string // Network address to listen on (e.g., ":45876" or "/path/to/socket") + Network string // Network type ("tcp" or "unix") + Keys []gossh.PublicKey // SSH public keys for authentication } +// hubVersions caches hub versions by session ID to avoid repeated parsing. +var hubVersions map[string]semver.Version + +// StartServer starts the SSH server with the provided options. +// It configures the server with secure defaults, sets up authentication, +// and begins listening for connections. Returns an error if the server +// is already running or if there's an issue starting the server. func (a *Agent) StartServer(opts ServerOptions) error { + if a.server != nil { + return errors.New("server already started") + } + slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network) if opts.Network == "unix" { @@ -37,7 +56,9 @@ func (a *Agent) StartServer(opts ServerOptions) error { defer ln.Close() // base config (limit to allowed algorithms) - config := &gossh.ServerConfig{} + config := &gossh.ServerConfig{ + ServerVersion: fmt.Sprintf("SSH-2.0-%s_%s", beszel.AppName, beszel.Version), + } config.KeyExchanges = common.DefaultKeyExchanges config.MACs = common.DefaultMACs config.Ciphers = common.DefaultCiphers @@ -45,42 +66,92 @@ func (a *Agent) StartServer(opts ServerOptions) error { // set default handler ssh.Handle(a.handleSession) - server := ssh.Server{ + a.server = &ssh.Server{ ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return config }, // check public key(s) PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { + remoteAddr := ctx.RemoteAddr() for _, pubKey := range opts.Keys { if ssh.KeysEqual(key, pubKey) { + slog.Info("SSH connected", "addr", remoteAddr) return true } } + slog.Warn("Invalid SSH key", "addr", remoteAddr) return false }, // disable pty PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { return false }, - // log failed connections - ConnectionFailedCallback: func(conn net.Conn, err error) { - slog.Warn("Failed connection attempt", "addr", conn.RemoteAddr().String(), "err", err) - }, + // close idle connections after 70 seconds + IdleTimeout: 70 * time.Second, } // Start SSH server on the listener - return server.Serve(ln) + return a.server.Serve(ln) } +// getHubVersion retrieves and caches the hub version for a given session. 
+// It extracts the version from the SSH client version string and caches +// it to avoid repeated parsing. Returns a zero version if parsing fails. +func (a *Agent) getHubVersion(sessionId string, sessionCtx ssh.Context) semver.Version { + if hubVersions == nil { + hubVersions = make(map[string]semver.Version, 1) + } + hubVersion, ok := hubVersions[sessionId] + if ok { + return hubVersion + } + // Extract hub version from SSH client version + clientVersion := sessionCtx.Value(ssh.ContextKeyClientVersion) + if versionStr, ok := clientVersion.(string); ok { + hubVersion, _ = extractHubVersion(versionStr) + } + hubVersions[sessionId] = hubVersion + return hubVersion +} + +// handleSession handles an incoming SSH session by gathering system statistics +// and sending them to the hub. It signals connection events, determines the +// appropriate encoding format based on hub version, and exits with appropriate +// status codes. func (a *Agent) handleSession(s ssh.Session) { - slog.Debug("New session", "client", s.RemoteAddr()) - stats := a.gatherStats(s.Context().SessionID()) - if err := json.NewEncoder(s).Encode(stats); err != nil { + a.connectionManager.eventChan <- SSHConnect + + sessionCtx := s.Context() + sessionID := sessionCtx.SessionID() + + hubVersion := a.getHubVersion(sessionID, sessionCtx) + + stats := a.gatherStats(sessionID) + + err := a.writeToSession(s, stats, hubVersion) + if err != nil { slog.Error("Error encoding stats", "err", err, "stats", stats) s.Exit(1) - return + } else { + s.Exit(0) } - s.Exit(0) +} + +// writeToSession encodes and writes system statistics to the session. +// It chooses between CBOR and JSON encoding based on the hub version, +// using CBOR for newer versions and JSON for legacy compatibility. +func (a *Agent) writeToSession(w io.Writer, stats *system.CombinedData, hubVersion semver.Version) error { + if hubVersion.GTE(beszel.MinVersionCbor) { + return cbor.NewEncoder(w).Encode(stats) + } + return json.NewEncoder(w).Encode(stats) +} + +// extractHubVersion extracts the beszel version from SSH client version string. +// Expected format: "SSH-2.0-beszel_X.Y.Z" or "beszel_X.Y.Z" +func extractHubVersion(versionString string) (semver.Version, error) { + _, after, _ := strings.Cut(versionString, "_") + return semver.Parse(after) } // ParseKeys parses a string containing SSH public keys in authorized_keys format. @@ -103,7 +174,9 @@ func ParseKeys(input string) ([]gossh.PublicKey, error) { return parsedKeys, nil } -// GetAddress gets the address to listen on or connect to from environment variables or default value. +// GetAddress determines the network address to listen on from various sources. +// It checks the provided address, then environment variables (LISTEN, PORT), +// and finally defaults to ":45876". func GetAddress(addr string) string { if addr == "" { addr, _ = GetEnv("LISTEN") @@ -122,7 +195,9 @@ func GetAddress(addr string) string { return addr } -// GetNetwork returns the network type to use based on the address +// GetNetwork determines the network type based on the address format. +// It checks the NETWORK environment variable first, then infers from +// the address format: addresses starting with "/" are "unix", others are "tcp". func GetNetwork(addr string) string { if network, ok := GetEnv("NETWORK"); ok && network != "" { return network @@ -132,3 +207,17 @@ func GetNetwork(addr string) string { } return "tcp" } + +// StopServer stops the SSH server if it's running. 
+// It returns an error if the server is not running or if there's an error stopping it. +func (a *Agent) StopServer() error { + if a.server == nil { + return errors.New("SSH server not running") + } + + slog.Info("Stopping SSH server") + _ = a.server.Close() + a.server = nil + a.connectionManager.eventChan <- SSHDisconnect + return nil +} diff --git a/beszel/internal/agent/server_test.go b/beszel/internal/agent/server_test.go index c9a34f3..5fd1ac6 100644 --- a/beszel/internal/agent/server_test.go +++ b/beszel/internal/agent/server_test.go @@ -1,34 +1,43 @@ package agent import ( + "beszel/internal/entities/container" + "beszel/internal/entities/system" + "context" "crypto/ed25519" + "encoding/json" "fmt" + "net" "os" "path/filepath" "strings" + "sync" "testing" "time" + "github.com/blang/semver" + "github.com/fxamacker/cbor/v2" + "github.com/gliderlabs/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func TestStartServer(t *testing.T) { // Generate a test key pair pubKey, privKey, err := ed25519.GenerateKey(nil) require.NoError(t, err) - signer, err := ssh.NewSignerFromKey(privKey) + signer, err := gossh.NewSignerFromKey(privKey) require.NoError(t, err) - sshPubKey, err := ssh.NewPublicKey(pubKey) + sshPubKey, err := gossh.NewPublicKey(pubKey) require.NoError(t, err) // Generate a different key pair for bad key test badPubKey, badPrivKey, err := ed25519.GenerateKey(nil) require.NoError(t, err) - badSigner, err := ssh.NewSignerFromKey(badPrivKey) + badSigner, err := gossh.NewSignerFromKey(badPrivKey) require.NoError(t, err) - sshBadPubKey, err := ssh.NewPublicKey(badPubKey) + sshBadPubKey, err := gossh.NewPublicKey(badPubKey) require.NoError(t, err) socketFile := filepath.Join(t.TempDir(), "beszel-test.sock") @@ -46,7 +55,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "tcp", Addr: ":45987", - Keys: []ssh.PublicKey{sshPubKey}, + Keys: []gossh.PublicKey{sshPubKey}, }, }, { @@ -54,7 +63,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "tcp4", Addr: "127.0.0.1:45988", - Keys: []ssh.PublicKey{sshPubKey}, + Keys: []gossh.PublicKey{sshPubKey}, }, }, { @@ -62,7 +71,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "tcp6", Addr: "[::1]:45989", - Keys: []ssh.PublicKey{sshPubKey}, + Keys: []gossh.PublicKey{sshPubKey}, }, }, { @@ -70,7 +79,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "unix", Addr: socketFile, - Keys: []ssh.PublicKey{sshPubKey}, + Keys: []gossh.PublicKey{sshPubKey}, }, setup: func() error { // Create a socket file that should be removed @@ -89,7 +98,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "tcp", Addr: ":45987", - Keys: []ssh.PublicKey{sshBadPubKey}, + Keys: []gossh.PublicKey{sshBadPubKey}, }, wantErr: true, errContains: "ssh: handshake failed", @@ -99,7 +108,7 @@ func TestStartServer(t *testing.T) { config: ServerOptions{ Network: "tcp", Addr: ":45987", - Keys: []ssh.PublicKey{sshPubKey}, + Keys: []gossh.PublicKey{sshPubKey}, }, }, } @@ -115,7 +124,8 @@ func TestStartServer(t *testing.T) { defer tt.cleanup() } - agent := NewAgent() + agent, err := NewAgent("") + require.NoError(t, err) // Start server in a goroutine since it blocks errChan := make(chan error, 1) @@ -127,8 +137,7 @@ func TestStartServer(t *testing.T) { time.Sleep(100 * time.Millisecond) // Try to connect to verify server is running - var client *ssh.Client - var err error + var client 
*gossh.Client // Choose the appropriate signer based on the test case testSigner := signer @@ -136,23 +145,23 @@ func TestStartServer(t *testing.T) { testSigner = badSigner } - sshClientConfig := &ssh.ClientConfig{ + sshClientConfig := &gossh.ClientConfig{ User: "a", - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(testSigner), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(testSigner), }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: gossh.InsecureIgnoreHostKey(), Timeout: 4 * time.Second, } switch tt.config.Network { case "unix": - client, err = ssh.Dial("unix", tt.config.Addr, sshClientConfig) + client, err = gossh.Dial("unix", tt.config.Addr, sshClientConfig) default: if !strings.Contains(tt.config.Addr, ":") { tt.config.Addr = ":" + tt.config.Addr } - client, err = ssh.Dial("tcp", tt.config.Addr, sshClientConfig) + client, err = gossh.Dial("tcp", tt.config.Addr, sshClientConfig) } if tt.wantErr { @@ -287,3 +296,310 @@ func TestParseInvalidKey(t *testing.T) { t.Fatalf("Expected error message to contain '%s', got: %v", expectedErrMsg, err) } } + +///////////////////////////////////////////////////////////////// +//////////////////// Hub Version Tests ////////////////////////// +///////////////////////////////////////////////////////////////// + +func TestExtractHubVersion(t *testing.T) { + tests := []struct { + name string + clientVersion string + expectedVersion string + expectError bool + }{ + { + name: "valid beszel client version with underscore", + clientVersion: "SSH-2.0-beszel_0.11.1", + expectedVersion: "0.11.1", + expectError: false, + }, + { + name: "valid beszel client version with beta", + clientVersion: "SSH-2.0-beszel_1.0.0-beta", + expectedVersion: "1.0.0-beta", + expectError: false, + }, + { + name: "valid beszel client version with rc", + clientVersion: "SSH-2.0-beszel_0.12.0-rc1", + expectedVersion: "0.12.0-rc1", + expectError: false, + }, + { + name: "different SSH client", + clientVersion: "SSH-2.0-OpenSSH_8.0", + expectedVersion: "8.0", + expectError: true, + }, + { + name: "malformed version string without underscore", + clientVersion: "SSH-2.0-beszel", + expectError: true, + }, + { + name: "empty version string", + clientVersion: "", + expectError: true, + }, + { + name: "version string with underscore but no version", + clientVersion: "beszel_", + expectedVersion: "", + expectError: true, + }, + { + name: "version with patch and build metadata", + clientVersion: "SSH-2.0-beszel_1.2.3+build.123", + expectedVersion: "1.2.3+build.123", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractHubVersion(tt.clientVersion) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedVersion, result.String()) + }) + } +} + +///////////////////////////////////////////////////////////////// +/////////////// Hub Version Detection Tests //////////////////// +///////////////////////////////////////////////////////////////// + +func TestGetHubVersion(t *testing.T) { + agent, err := NewAgent("") + require.NoError(t, err) + + // Mock SSH context that implements the ssh.Context interface + mockCtx := &mockSSHContext{ + sessionID: "test-session-123", + clientVersion: "SSH-2.0-beszel_0.12.0", + } + + // Test first call - should extract and cache version + version := agent.getHubVersion("test-session-123", mockCtx) + assert.Equal(t, "0.12.0", version.String()) + + // Test second call - should return cached version + mockCtx.clientVersion = 
"SSH-2.0-beszel_0.11.0" // Change version but should still return cached + version = agent.getHubVersion("test-session-123", mockCtx) + assert.Equal(t, "0.12.0", version.String()) // Should still be cached version + + // Test different session - should extract new version + version = agent.getHubVersion("different-session", mockCtx) + assert.Equal(t, "0.11.0", version.String()) + + // Test with invalid version string (non-beszel client) + mockCtx.clientVersion = "SSH-2.0-OpenSSH_8.0" + version = agent.getHubVersion("invalid-session", mockCtx) + assert.Equal(t, "0.0.0", version.String()) // Should be empty version for non-beszel clients + + // Test with no client version + mockCtx.clientVersion = "" + version = agent.getHubVersion("no-version-session", mockCtx) + assert.True(t, version.EQ(semver.Version{})) // Should be empty version +} + +// mockSSHContext implements ssh.Context for testing +type mockSSHContext struct { + context.Context + sync.Mutex + sessionID string + clientVersion string +} + +func (m *mockSSHContext) SessionID() string { + return m.sessionID +} + +func (m *mockSSHContext) ClientVersion() string { + return m.clientVersion +} + +func (m *mockSSHContext) ServerVersion() string { + return "SSH-2.0-beszel_test" +} + +func (m *mockSSHContext) Value(key interface{}) interface{} { + if key == ssh.ContextKeyClientVersion { + return m.clientVersion + } + return nil +} + +func (m *mockSSHContext) User() string { return "test-user" } +func (m *mockSSHContext) RemoteAddr() net.Addr { return nil } +func (m *mockSSHContext) LocalAddr() net.Addr { return nil } +func (m *mockSSHContext) Permissions() *ssh.Permissions { return nil } +func (m *mockSSHContext) SetValue(key, value interface{}) {} + +///////////////////////////////////////////////////////////////// +/////////////// CBOR vs JSON Encoding Tests //////////////////// +///////////////////////////////////////////////////////////////// + +// TestWriteToSessionEncoding tests that writeToSession actually encodes data in the correct format +func TestWriteToSessionEncoding(t *testing.T) { + tests := []struct { + name string + hubVersion string + expectedUsesCbor bool + }{ + { + name: "old hub version should use JSON", + hubVersion: "0.11.1", + expectedUsesCbor: false, + }, + { + name: "non-beta release should use CBOR", + hubVersion: "0.12.0", + expectedUsesCbor: true, + }, + { + name: "even newer hub version should use CBOR", + hubVersion: "0.16.4", + expectedUsesCbor: true, + }, + { + name: "beta version below release threshold should use JSON", + hubVersion: "0.12.0-beta0", + expectedUsesCbor: false, + }, + { + name: "matching beta version should use CBOR", + hubVersion: "0.12.0-beta1", + expectedUsesCbor: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset the global hubVersions map to ensure clean state for each test + hubVersions = nil + + agent, err := NewAgent("") + require.NoError(t, err) + + // Parse the test version + version, err := semver.Parse(tt.hubVersion) + require.NoError(t, err) + + // Create test data to encode + testData := createTestCombinedData() + + var buf strings.Builder + err = agent.writeToSession(&buf, testData, version) + require.NoError(t, err) + + encodedData := buf.String() + require.NotEmpty(t, encodedData) + + // Verify the encoding format by attempting to decode + if tt.expectedUsesCbor { + var decodedCbor system.CombinedData + err = cbor.Unmarshal([]byte(encodedData), &decodedCbor) + assert.NoError(t, err, "Should be valid CBOR data") + + var decodedJson 
system.CombinedData + err = json.Unmarshal([]byte(encodedData), &decodedJson) + assert.Error(t, err, "Should not be valid JSON data") + + assert.Equal(t, testData.Info.Hostname, decodedCbor.Info.Hostname) + assert.Equal(t, testData.Stats.Cpu, decodedCbor.Stats.Cpu) + } else { + // Should be JSON - try to decode as JSON + var decodedJson system.CombinedData + err = json.Unmarshal([]byte(encodedData), &decodedJson) + assert.NoError(t, err, "Should be valid JSON data") + + var decodedCbor system.CombinedData + err = cbor.Unmarshal([]byte(encodedData), &decodedCbor) + assert.Error(t, err, "Should not be valid CBOR data") + + // Verify the decoded JSON data matches our test data + assert.Equal(t, testData.Info.Hostname, decodedJson.Info.Hostname) + assert.Equal(t, testData.Stats.Cpu, decodedJson.Stats.Cpu) + + // Verify it looks like JSON (starts with '{' and contains readable field names) + assert.True(t, strings.HasPrefix(encodedData, "{"), "JSON should start with '{'") + assert.Contains(t, encodedData, `"info"`, "JSON should contain readable field names") + assert.Contains(t, encodedData, `"stats"`, "JSON should contain readable field names") + } + }) + } +} + +// Helper function to create test data for encoding tests +func createTestCombinedData() *system.CombinedData { + return &system.CombinedData{ + Stats: system.Stats{ + Cpu: 25.5, + Mem: 8589934592, // 8GB + MemUsed: 4294967296, // 4GB + MemPct: 50.0, + DiskTotal: 1099511627776, // 1TB + DiskUsed: 549755813888, // 512GB + DiskPct: 50.0, + }, + Info: system.Info{ + Hostname: "test-host", + Cores: 8, + CpuModel: "Test CPU Model", + Uptime: 3600, + AgentVersion: "0.12.0", + Os: system.Linux, + }, + Containers: []*container.Stats{ + { + Name: "test-container", + Cpu: 10.5, + Mem: 1073741824, // 1GB + }, + }, + } +} + +func TestHubVersionCaching(t *testing.T) { + // Reset the global hubVersions map to ensure clean state + hubVersions = nil + + agent, err := NewAgent("") + require.NoError(t, err) + + ctx1 := &mockSSHContext{ + sessionID: "session1", + clientVersion: "SSH-2.0-beszel_0.12.0", + } + ctx2 := &mockSSHContext{ + sessionID: "session2", + clientVersion: "SSH-2.0-beszel_0.11.0", + } + + // First calls should cache the versions + v1 := agent.getHubVersion("session1", ctx1) + v2 := agent.getHubVersion("session2", ctx2) + + assert.Equal(t, "0.12.0", v1.String()) + assert.Equal(t, "0.11.0", v2.String()) + + // Verify caching by changing context but keeping same session ID + ctx1.clientVersion = "SSH-2.0-beszel_0.10.0" + v1Cached := agent.getHubVersion("session1", ctx1) + assert.Equal(t, "0.12.0", v1Cached.String()) // Should still be cached version + + // New session should get new version + ctx3 := &mockSSHContext{ + sessionID: "session3", + clientVersion: "SSH-2.0-beszel_0.13.0", + } + v3 := agent.getHubVersion("session3", ctx3) + assert.Equal(t, "0.13.0", v3.String()) +} diff --git a/beszel/internal/common/common-ssh.go b/beszel/internal/common/common-ssh.go new file mode 100644 index 0000000..cbd0515 --- /dev/null +++ b/beszel/internal/common/common-ssh.go @@ -0,0 +1,10 @@ +package common + +var ( + // Allowed ssh key exchanges + DefaultKeyExchanges = []string{"curve25519-sha256"} + // Allowed ssh macs + DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"} + // Allowed ssh ciphers + DefaultCiphers = []string{"chacha20-poly1305@openssh.com"} +) diff --git a/beszel/internal/common/common-ws.go b/beszel/internal/common/common-ws.go new file mode 100644 index 0000000..374f477 --- /dev/null +++ b/beszel/internal/common/common-ws.go @@ 
-0,0 +1,32 @@ +package common + +type WebSocketAction = uint8 + +// Not implemented yet +// type AgentError = uint8 + +const ( + // Request system data from agent + GetData WebSocketAction = iota + // Check the fingerprint of the agent + CheckFingerprint +) + +// HubRequest defines the structure for requests sent from hub to agent. +type HubRequest[T any] struct { + Action WebSocketAction `cbor:"0,keyasint"` + Data T `cbor:"1,keyasint,omitempty,omitzero"` + // Error AgentError `cbor:"error,omitempty,omitzero"` +} + +type FingerprintRequest struct { + Signature []byte `cbor:"0,keyasint"` + NeedSysInfo bool `cbor:"1,keyasint"` // For universal token system creation +} + +type FingerprintResponse struct { + Fingerprint string `cbor:"0,keyasint"` + // Optional system info for universal token system creation + Hostname string `cbor:"1,keyasint,omitempty,omitzero"` + Port string `cbor:"2,keyasint,omitempty,omitzero"` +} diff --git a/beszel/internal/common/common.go b/beszel/internal/common/common.go deleted file mode 100644 index a9bb868..0000000 --- a/beszel/internal/common/common.go +++ /dev/null @@ -1,7 +0,0 @@ -package common - -var ( - DefaultKeyExchanges = []string{"curve25519-sha256"} - DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"} - DefaultCiphers = []string{"chacha20-poly1305@openssh.com"} -) diff --git a/beszel/internal/entities/container/container.go b/beszel/internal/entities/container/container.go index fdcb96d..c6b6c4b 100644 --- a/beszel/internal/entities/container/container.go +++ b/beszel/internal/entities/container/container.go @@ -34,10 +34,16 @@ type ApiStats struct { MemoryStats MemoryStats `json:"memory_stats"` } -func (s *ApiStats) CalculateCpuPercentLinux(prevCpuUsage [2]uint64) float64 { - cpuDelta := s.CPUStats.CPUUsage.TotalUsage - prevCpuUsage[0] - systemDelta := s.CPUStats.SystemUsage - prevCpuUsage[1] - return float64(cpuDelta) / float64(systemDelta) * 100 +func (s *ApiStats) CalculateCpuPercentLinux(prevCpuContainer uint64, prevCpuSystem uint64) float64 { + cpuDelta := s.CPUStats.CPUUsage.TotalUsage - prevCpuContainer + systemDelta := s.CPUStats.SystemUsage - prevCpuSystem + + // Avoid division by zero and handle first run case + if systemDelta == 0 || prevCpuContainer == 0 { + return 0.0 + } + + return float64(cpuDelta) / float64(systemDelta) * 100.0 } // from: https://github.com/docker/cli/blob/master/cli/command/container/stats_helpers.go#L185 @@ -99,12 +105,14 @@ type prevNetStats struct { // Docker container stats type Stats struct { - Name string `json:"n"` - Cpu float64 `json:"c"` - Mem float64 `json:"m"` - NetworkSent float64 `json:"ns"` - NetworkRecv float64 `json:"nr"` - PrevCpu [2]uint64 `json:"-"` - PrevNet prevNetStats `json:"-"` - PrevRead time.Time `json:"-"` + Name string `json:"n" cbor:"0,keyasint"` + Cpu float64 `json:"c" cbor:"1,keyasint"` + Mem float64 `json:"m" cbor:"2,keyasint"` + NetworkSent float64 `json:"ns" cbor:"3,keyasint"` + NetworkRecv float64 `json:"nr" cbor:"4,keyasint"` + // PrevCpu [2]uint64 `json:"-"` + CpuSystem uint64 `json:"-"` + CpuContainer uint64 `json:"-"` + PrevNet prevNetStats `json:"-"` + PrevReadTime time.Time `json:"-"` } diff --git a/beszel/internal/entities/system/system.go b/beszel/internal/entities/system/system.go index 08c8d6b..25e05ba 100644 --- a/beszel/internal/entities/system/system.go +++ b/beszel/internal/entities/system/system.go @@ -8,38 +8,38 @@ import ( ) type Stats struct { - Cpu float64 `json:"cpu"` - MaxCpu float64 `json:"cpum,omitempty"` - Mem float64 `json:"m"` - MemUsed float64 
`json:"mu"` - MemPct float64 `json:"mp"` - MemBuffCache float64 `json:"mb"` - MemZfsArc float64 `json:"mz,omitempty"` // ZFS ARC memory - Swap float64 `json:"s,omitempty"` - SwapUsed float64 `json:"su,omitempty"` - DiskTotal float64 `json:"d"` - DiskUsed float64 `json:"du"` - DiskPct float64 `json:"dp"` - DiskReadPs float64 `json:"dr"` - DiskWritePs float64 `json:"dw"` - MaxDiskReadPs float64 `json:"drm,omitempty"` - MaxDiskWritePs float64 `json:"dwm,omitempty"` - NetworkSent float64 `json:"ns"` - NetworkRecv float64 `json:"nr"` - MaxNetworkSent float64 `json:"nsm,omitempty"` - MaxNetworkRecv float64 `json:"nrm,omitempty"` - Temperatures map[string]float64 `json:"t,omitempty"` - ExtraFs map[string]*FsStats `json:"efs,omitempty"` - GPUData map[string]GPUData `json:"g,omitempty"` + Cpu float64 `json:"cpu" cbor:"0,keyasint"` + MaxCpu float64 `json:"cpum,omitempty" cbor:"1,keyasint,omitempty"` + Mem float64 `json:"m" cbor:"2,keyasint"` + MemUsed float64 `json:"mu" cbor:"3,keyasint"` + MemPct float64 `json:"mp" cbor:"4,keyasint"` + MemBuffCache float64 `json:"mb" cbor:"5,keyasint"` + MemZfsArc float64 `json:"mz,omitempty" cbor:"6,keyasint,omitempty"` // ZFS ARC memory + Swap float64 `json:"s,omitempty" cbor:"7,keyasint,omitempty"` + SwapUsed float64 `json:"su,omitempty" cbor:"8,keyasint,omitempty"` + DiskTotal float64 `json:"d" cbor:"9,keyasint"` + DiskUsed float64 `json:"du" cbor:"10,keyasint"` + DiskPct float64 `json:"dp" cbor:"11,keyasint"` + DiskReadPs float64 `json:"dr" cbor:"12,keyasint"` + DiskWritePs float64 `json:"dw" cbor:"13,keyasint"` + MaxDiskReadPs float64 `json:"drm,omitempty" cbor:"14,keyasint,omitempty"` + MaxDiskWritePs float64 `json:"dwm,omitempty" cbor:"15,keyasint,omitempty"` + NetworkSent float64 `json:"ns" cbor:"16,keyasint"` + NetworkRecv float64 `json:"nr" cbor:"17,keyasint"` + MaxNetworkSent float64 `json:"nsm,omitempty" cbor:"18,keyasint,omitempty"` + MaxNetworkRecv float64 `json:"nrm,omitempty" cbor:"19,keyasint,omitempty"` + Temperatures map[string]float64 `json:"t,omitempty" cbor:"20,keyasint,omitempty"` + ExtraFs map[string]*FsStats `json:"efs,omitempty" cbor:"21,keyasint,omitempty"` + GPUData map[string]GPUData `json:"g,omitempty" cbor:"22,keyasint,omitempty"` } type GPUData struct { - Name string `json:"n"` + Name string `json:"n" cbor:"0,keyasint"` Temperature float64 `json:"-"` - MemoryUsed float64 `json:"mu,omitempty"` - MemoryTotal float64 `json:"mt,omitempty"` - Usage float64 `json:"u"` - Power float64 `json:"p,omitempty"` + MemoryUsed float64 `json:"mu,omitempty" cbor:"1,keyasint,omitempty"` + MemoryTotal float64 `json:"mt,omitempty" cbor:"2,keyasint,omitempty"` + Usage float64 `json:"u" cbor:"3,keyasint"` + Power float64 `json:"p,omitempty" cbor:"4,keyasint,omitempty"` Count float64 `json:"-"` } @@ -47,14 +47,14 @@ type FsStats struct { Time time.Time `json:"-"` Root bool `json:"-"` Mountpoint string `json:"-"` - DiskTotal float64 `json:"d"` - DiskUsed float64 `json:"du"` + DiskTotal float64 `json:"d" cbor:"0,keyasint"` + DiskUsed float64 `json:"du" cbor:"1,keyasint"` TotalRead uint64 `json:"-"` TotalWrite uint64 `json:"-"` - DiskReadPs float64 `json:"r"` - DiskWritePs float64 `json:"w"` - MaxDiskReadPS float64 `json:"rm,omitempty"` - MaxDiskWritePS float64 `json:"wm,omitempty"` + DiskReadPs float64 `json:"r" cbor:"2,keyasint"` + DiskWritePs float64 `json:"w" cbor:"3,keyasint"` + MaxDiskReadPS float64 `json:"rm,omitempty" cbor:"4,keyasint,omitempty"` + MaxDiskWritePS float64 `json:"wm,omitempty" cbor:"5,keyasint,omitempty"` } type NetIoStats struct { @@ 
-64,7 +64,7 @@ type NetIoStats struct { Name string } -type Os uint8 +type Os = uint8 const ( Linux Os = iota @@ -74,26 +74,26 @@ const ( ) type Info struct { - Hostname string `json:"h"` - KernelVersion string `json:"k,omitempty"` - Cores int `json:"c"` - Threads int `json:"t,omitempty"` - CpuModel string `json:"m"` - Uptime uint64 `json:"u"` - Cpu float64 `json:"cpu"` - MemPct float64 `json:"mp"` - DiskPct float64 `json:"dp"` - Bandwidth float64 `json:"b"` - AgentVersion string `json:"v"` - Podman bool `json:"p,omitempty"` - GpuPct float64 `json:"g,omitempty"` - DashboardTemp float64 `json:"dt,omitempty"` - Os Os `json:"os"` + Hostname string `json:"h" cbor:"0,keyasint"` + KernelVersion string `json:"k,omitempty" cbor:"1,keyasint,omitempty"` + Cores int `json:"c" cbor:"2,keyasint"` + Threads int `json:"t,omitempty" cbor:"3,keyasint,omitempty"` + CpuModel string `json:"m" cbor:"4,keyasint"` + Uptime uint64 `json:"u" cbor:"5,keyasint"` + Cpu float64 `json:"cpu" cbor:"6,keyasint"` + MemPct float64 `json:"mp" cbor:"7,keyasint"` + DiskPct float64 `json:"dp" cbor:"8,keyasint"` + Bandwidth float64 `json:"b" cbor:"9,keyasint"` + AgentVersion string `json:"v" cbor:"10,keyasint"` + Podman bool `json:"p,omitempty" cbor:"11,keyasint,omitempty"` + GpuPct float64 `json:"g,omitempty" cbor:"12,keyasint,omitempty"` + DashboardTemp float64 `json:"dt,omitempty" cbor:"13,keyasint,omitempty"` + Os Os `json:"os" cbor:"14,keyasint"` } // Final data structure to return to the hub type CombinedData struct { - Stats Stats `json:"stats"` - Info Info `json:"info"` - Containers []*container.Stats `json:"container"` + Stats Stats `json:"stats" cbor:"0,keyasint"` + Info Info `json:"info" cbor:"1,keyasint"` + Containers []*container.Stats `json:"container" cbor:"2,keyasint"` } diff --git a/beszel/internal/hub/agent_connect.go b/beszel/internal/hub/agent_connect.go new file mode 100644 index 0000000..11e75a1 --- /dev/null +++ b/beszel/internal/hub/agent_connect.go @@ -0,0 +1,247 @@ +package hub + +import ( + "beszel/internal/common" + "beszel/internal/hub/expirymap" + "beszel/internal/hub/ws" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/blang/semver" + "github.com/lxzan/gws" + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/core" +) + +// tokenMap maps tokens to user IDs for universal tokens +var tokenMap *expirymap.ExpiryMap[string] + +type agentConnectRequest struct { + token string + agentSemVer semver.Version + // for universal token + isUniversalToken bool + userId string + remoteAddr string +} + +// validateAgentHeaders validates the required headers from agent connection requests. +func (h *Hub) validateAgentHeaders(headers http.Header) (string, string, error) { + token := headers.Get("X-Token") + agentVersion := headers.Get("X-Beszel") + + if agentVersion == "" || token == "" || len(token) > 512 { + return "", "", errors.New("") + } + return token, agentVersion, nil +} + +// getFingerprintRecord retrieves fingerprint data from the database by token. +func (h *Hub) getFingerprintRecord(token string, recordData *ws.FingerprintRecord) error { + err := h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}"). + Bind(dbx.Params{ + "token": token, + }). + One(recordData) + return err +} + +// sendResponseError sends an HTTP error response with the given status code and message. 
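The cbor keyasint tags added to the container and system entities above (and to the common request/response types) mean the hub and agent can exchange the same structs as compact integer-keyed CBOR maps instead of string-keyed JSON objects. A minimal round-trip sketch, assuming the fxamacker/cbor/v2 defaults; the throwaway main package is for illustration only:

package main

import (
	"fmt"

	"beszel/internal/common"

	"github.com/fxamacker/cbor/v2"
)

// Round-trip a hub-to-agent request using the integer-keyed tags above.
// Field names never hit the wire, so payloads stay small and renames stay cheap.
func main() {
	req := common.HubRequest[common.FingerprintRequest]{
		Action: common.CheckFingerprint,
		Data: common.FingerprintRequest{
			Signature:   []byte("signed-token"),
			NeedSysInfo: true,
		},
	}
	payload, err := cbor.Marshal(req)
	if err != nil {
		panic(err)
	}

	// The receiver can decode the envelope first and the payload lazily.
	var decoded common.HubRequest[cbor.RawMessage]
	if err := cbor.Unmarshal(payload, &decoded); err != nil {
		panic(err)
	}
	fmt.Println(decoded.Action == common.CheckFingerprint) // true
}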
+func sendResponseError(res http.ResponseWriter, code int, message string) error { + res.WriteHeader(code) + if message != "" { + res.Write([]byte(message)) + } + return nil +} + +// handleAgentConnect handles the incoming connection request from the agent. +func (h *Hub) handleAgentConnect(e *core.RequestEvent) error { + if err := h.agentConnect(e.Request, e.Response); err != nil { + return err + } + return nil +} + +// agentConnect handles agent connection requests, validating credentials and upgrading to WebSocket. +func (h *Hub) agentConnect(req *http.Request, res http.ResponseWriter) (err error) { + var agentConnectRequest agentConnectRequest + var agentVersion string + // check if user agent and token are valid + agentConnectRequest.token, agentVersion, err = h.validateAgentHeaders(req.Header) + if err != nil { + return sendResponseError(res, http.StatusUnauthorized, "") + } + + // Pull fingerprint from database matching token + var fpRecord ws.FingerprintRecord + err = h.getFingerprintRecord(agentConnectRequest.token, &fpRecord) + + // if no existing record, check if token is a universal token + if err != nil { + if err = checkUniversalToken(&agentConnectRequest); err == nil { + // if this is a universal token, set the remote address and new record token + agentConnectRequest.remoteAddr = getRealIP(req) + fpRecord.Token = agentConnectRequest.token + } + } + + // If no matching token, return unauthorized + if err != nil { + return sendResponseError(res, http.StatusUnauthorized, "Invalid token") + } + + // Validate agent version + agentConnectRequest.agentSemVer, err = semver.Parse(agentVersion) + if err != nil { + return sendResponseError(res, http.StatusUnauthorized, "Invalid agent version") + } + + // Upgrade connection to WebSocket + conn, err := ws.GetUpgrader().Upgrade(res, req) + if err != nil { + return sendResponseError(res, http.StatusInternalServerError, "WebSocket upgrade failed") + } + + go h.verifyWsConn(conn, agentConnectRequest, fpRecord) + + return nil +} + +// verifyWsConn verifies the WebSocket connection using agent's fingerprint and SSH key signature. 
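agentConnect above needs only two request headers before it will consider upgrading the connection: X-Token (the per-system or universal token, capped at 512 bytes) and X-Beszel (the agent's semver). A hedged sketch of building such a request; the URL and version string are placeholders, and the real agent dials over WebSocket rather than plain HTTP:

package main

import (
	"fmt"
	"net/http"
)

// newConnectRequest builds the handshake request validateAgentHeaders expects.
// Both values are placeholders; the agent sends its real token and version.
func newConnectRequest(hubURL, token string) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodGet, hubURL+"/api/beszel/agent-connect", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("X-Token", token)    // matched against the fingerprints table or the universal token map
	req.Header.Set("X-Beszel", "0.5.0") // must parse as semver or the hub rejects the connection
	return req, nil
}

func main() {
	req, _ := newConnectRequest("http://localhost:8090", "example-token")
	fmt.Println(req.Header.Get("X-Beszel"))
}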
+func (h *Hub) verifyWsConn(conn *gws.Conn, acr agentConnectRequest, fpRecord ws.FingerprintRecord) (err error) { + wsConn := ws.NewWsConnection(conn) + // must be set before the read loop + conn.Session().Store("wsConn", wsConn) + + // make sure connection is closed if there is an error + defer func() { + if err != nil { + wsConn.Close() + h.Logger().Error("WebSocket error", "error", err, "system", fpRecord.SystemId) + } + }() + + go conn.ReadLoop() + + signer, err := h.GetSSHKey("") + if err != nil { + return err + } + + agentFingerprint, err := wsConn.GetFingerprint(acr.token, signer, acr.isUniversalToken) + if err != nil { + return err + } + + // Create system if using universal token + if acr.isUniversalToken { + if acr.userId == "" { + return errors.New("token user not found") + } + fpRecord.SystemId, err = h.createSystemFromAgentData(&acr, agentFingerprint) + if err != nil { + return fmt.Errorf("failed to create system from universal token: %w", err) + } + } + + switch { + // If no current fingerprint, update with new fingerprint (first time connecting) + case fpRecord.Fingerprint == "": + if err := h.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil { + return err + } + // Abort if fingerprint exists but doesn't match (different machine) + case fpRecord.Fingerprint != agentFingerprint.Fingerprint: + return errors.New("fingerprint mismatch") + } + + return h.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn) +} + +// createSystemFromAgentData creates a new system record using data from the agent +func (h *Hub) createSystemFromAgentData(acr *agentConnectRequest, agentFingerprint common.FingerprintResponse) (recordId string, err error) { + systemsCollection, err := h.FindCollectionByNameOrId("systems") + if err != nil { + return "", fmt.Errorf("failed to find systems collection: %w", err) + } + // separate port from address + if agentFingerprint.Hostname == "" { + agentFingerprint.Hostname = acr.remoteAddr + } + if agentFingerprint.Port == "" { + agentFingerprint.Port = "45876" + } + // create new record + systemRecord := core.NewRecord(systemsCollection) + systemRecord.Set("name", agentFingerprint.Hostname) + systemRecord.Set("host", acr.remoteAddr) + systemRecord.Set("port", agentFingerprint.Port) + systemRecord.Set("users", []string{acr.userId}) + + return systemRecord.Id, h.Save(systemRecord) +} + +// SetFingerprint updates the fingerprint for a given record ID. 
+func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) { + // // can't use raw query here because it doesn't trigger SSE + var record *core.Record + switch fpRecord.Id { + case "": + // create new record for universal token + collection, _ := h.FindCachedCollectionByNameOrId("fingerprints") + record = core.NewRecord(collection) + record.Set("system", fpRecord.SystemId) + default: + record, err = h.FindRecordById("fingerprints", fpRecord.Id) + } + if err != nil { + return err + } + record.Set("token", fpRecord.Token) + record.Set("fingerprint", fingerprint) + return h.SaveNoValidate(record) +} + +func getTokenMap() *expirymap.ExpiryMap[string] { + if tokenMap == nil { + tokenMap = expirymap.New[string](time.Hour) + } + return tokenMap +} + +func checkUniversalToken(acr *agentConnectRequest) (err error) { + if tokenMap == nil { + tokenMap = expirymap.New[string](time.Hour) + } + acr.userId, acr.isUniversalToken = tokenMap.GetOk(acr.token) + if !acr.isUniversalToken { + return errors.New("invalid token") + } + return nil +} + +// getRealIP attempts to extract the real IP address from the request headers. +func getRealIP(r *http.Request) string { + if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + // X-Forwarded-For can contain a comma-separated list: "client_ip, proxy1, proxy2" + // Take the first one + ips := strings.Split(ip, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + // Fallback to RemoteAddr + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/beszel/internal/hub/agent_connect_test.go b/beszel/internal/hub/agent_connect_test.go new file mode 100644 index 0000000..f01d13f --- /dev/null +++ b/beszel/internal/hub/agent_connect_test.go @@ -0,0 +1,1001 @@ +//go:build testing +// +build testing + +package hub + +import ( + "beszel/internal/agent" + "beszel/internal/common" + "beszel/internal/hub/expirymap" + "beszel/internal/hub/ws" + "crypto/ed25519" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/pocketbase/pocketbase/core" + pbtests "github.com/pocketbase/pocketbase/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +// Helper function to create a test hub without import cycle +func createTestHub(t testing.TB) (*Hub, *pbtests.TestApp, error) { + testDataDir := t.TempDir() + testApp, err := pbtests.NewTestApp(testDataDir) + if err != nil { + return nil, nil, err + } + return NewHub(testApp), testApp, nil +} + +// Helper function to create a test record +func createTestRecord(app core.App, collection string, data map[string]any) (*core.Record, error) { + col, err := app.FindCachedCollectionByNameOrId(collection) + if err != nil { + return nil, err + } + record := core.NewRecord(col) + for key, value := range data { + record.Set(key, value) + } + + return record, app.Save(record) +} + +// Helper function to create a test user +func createTestUser(app core.App) (*core.Record, error) { + userRecord, err := createTestRecord(app, "users", map[string]any{ + "email": "test@test.com", + "password": "testtesttest", + }) + return userRecord, err +} + +// TestValidateAgentHeaders tests the validateAgentHeaders function +func TestValidateAgentHeaders(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + 
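getRealIP above prefers CF-Connecting-IP, then the first X-Forwarded-For hop, and only then falls back to RemoteAddr, which matters when the hub sits behind Cloudflare or another reverse proxy. A small extra test in the style of the ones in this file could pin that ordering down; TestGetRealIPPrecedence is hypothetical and not part of this diff:

//go:build testing
// +build testing

package hub

import (
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestGetRealIPPrecedence checks header precedence: CF-Connecting-IP wins,
// then the first X-Forwarded-For entry, then the host part of RemoteAddr.
func TestGetRealIPPrecedence(t *testing.T) {
	req := httptest.NewRequest("GET", "/", nil)
	req.RemoteAddr = "10.0.0.1:12345"
	assert.Equal(t, "10.0.0.1", getRealIP(req))

	req.Header.Set("X-Forwarded-For", "203.0.113.7, 10.0.0.1")
	assert.Equal(t, "203.0.113.7", getRealIP(req))

	req.Header.Set("CF-Connecting-IP", "198.51.100.9")
	assert.Equal(t, "198.51.100.9", getRealIP(req))
}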
testCases := []struct { + name string + headers http.Header + expectError bool + expectedToken string + expectedAgent string + }{ + { + name: "valid headers", + headers: http.Header{ + "X-Token": []string{"valid-token-123"}, + "X-Beszel": []string{"0.5.0"}, + }, + expectError: false, + expectedToken: "valid-token-123", + expectedAgent: "0.5.0", + }, + { + name: "missing token", + headers: http.Header{ + "X-Beszel": []string{"0.5.0"}, + }, + expectError: true, + }, + { + name: "missing agent version", + headers: http.Header{ + "X-Token": []string{"valid-token-123"}, + }, + expectError: true, + }, + { + name: "empty token", + headers: http.Header{ + "X-Token": []string{""}, + "X-Beszel": []string{"0.5.0"}, + }, + expectError: true, + }, + { + name: "empty agent version", + headers: http.Header{ + "X-Token": []string{"valid-token-123"}, + "X-Beszel": []string{""}, + }, + expectError: true, + }, + { + name: "token too long", + headers: http.Header{ + "X-Token": []string{string(make([]byte, 513))}, // 513 bytes > 512 limit + "X-Beszel": []string{"0.5.0"}, + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, agentVersion, err := hub.validateAgentHeaders(tc.headers) + + if tc.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedToken, token) + assert.Equal(t, tc.expectedAgent, agentVersion) + } + }) + } +} + +// TestGetFingerprintRecord tests the getFingerprintRecord function +func TestGetFingerprintRecord(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + // Create test data + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "test-system", + "host": "localhost", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + if err != nil { + t.Fatal(err) + } + + fingerprintRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "test-token-123", + "fingerprint": "test-fingerprint", + }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + token string + expectError bool + expectedId string + }{ + { + name: "valid token", + token: "test-token-123", + expectError: false, + expectedId: fingerprintRecord.Id, + }, + { + name: "invalid token", + token: "invalid-token", + expectError: true, + }, + { + name: "empty token", + token: "", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var recordData ws.FingerprintRecord + err := hub.getFingerprintRecord(tc.token, &recordData) + + if tc.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedId, recordData.Id) + } + }) + } +} + +// TestSetFingerprint tests the SetFingerprint function +func TestSetFingerprint(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "test-system", + "host": "localhost", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + if err != nil { + t.Fatal(err) + } + + // Create fingerprint record + 
fingerprintRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": "test-token-123", + "fingerprint": "", + }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + recordId string + newFingerprint string + expectError bool + }{ + { + name: "successful fingerprint update", + recordId: fingerprintRecord.Id, + newFingerprint: "new-test-fingerprint", + expectError: false, + }, + { + name: "empty fingerprint", + recordId: fingerprintRecord.Id, + newFingerprint: "", + expectError: false, + }, + { + name: "invalid record ID", + recordId: "invalid-id", + newFingerprint: "fingerprint", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := hub.SetFingerprint(&ws.FingerprintRecord{Id: tc.recordId, Token: "test-token-123"}, tc.newFingerprint) + + if tc.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + + // Verify fingerprint was updated + updatedRecord, err := testApp.FindRecordById("fingerprints", tc.recordId) + require.NoError(t, err) + assert.Equal(t, tc.newFingerprint, updatedRecord.GetString("fingerprint")) + } + }) + } +} + +// TestCreateSystemFromAgentData tests the createSystemFromAgentData function +func TestCreateSystemFromAgentData(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + agentConnReq agentConnectRequest + fingerprint common.FingerprintResponse + expectError bool + expectedName string + expectedHost string + expectedPort string + expectedUsers []string + }{ + { + name: "successful system creation with all fields", + agentConnReq: agentConnectRequest{ + userId: userRecord.Id, + remoteAddr: "192.168.1.100", + }, + fingerprint: common.FingerprintResponse{ + Hostname: "test-server", + Port: "8080", + }, + expectError: false, + expectedName: "test-server", + expectedHost: "192.168.1.100", + expectedPort: "8080", + expectedUsers: []string{userRecord.Id}, + }, + { + name: "system creation with default port", + agentConnReq: agentConnectRequest{ + userId: userRecord.Id, + remoteAddr: "10.0.0.50", + }, + fingerprint: common.FingerprintResponse{ + Hostname: "default-port-server", + Port: "", // Empty port should default to 45876 + }, + expectError: false, + expectedName: "default-port-server", + expectedHost: "10.0.0.50", + expectedPort: "45876", + expectedUsers: []string{userRecord.Id}, + }, + { + name: "system creation with empty hostname", + agentConnReq: agentConnectRequest{ + userId: userRecord.Id, + remoteAddr: "172.16.0.1", + }, + fingerprint: common.FingerprintResponse{ + Hostname: "", + Port: "9090", + }, + expectError: false, + expectedName: "172.16.0.1", // Should fall back to host IP when hostname is empty + expectedHost: "172.16.0.1", + expectedPort: "9090", + expectedUsers: []string{userRecord.Id}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + recordId, err := hub.createSystemFromAgentData(&tc.agentConnReq, tc.fingerprint) + + if tc.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.NotEmpty(t, recordId, "Record ID should not be empty") + + // Verify the created system record + systemRecord, err := testApp.FindRecordById("systems", recordId) + require.NoError(t, err) + + assert.Equal(t, tc.expectedName, 
systemRecord.GetString("name")) + assert.Equal(t, tc.expectedHost, systemRecord.GetString("host")) + assert.Equal(t, tc.expectedPort, systemRecord.GetString("port")) + + // Verify users array + users := systemRecord.Get("users") + assert.Equal(t, tc.expectedUsers, users) + }) + } +} + +// TestUniversalTokenFlow tests the complete universal token authentication flow +func TestUniversalTokenFlow(t *testing.T) { + _, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + // Set up universal token in the token map + universalToken := "universal-token-123" + + // Initialize tokenMap if it doesn't exist + if tokenMap == nil { + tokenMap = expirymap.New[string](time.Hour) + } + tokenMap.Set(universalToken, userRecord.Id, time.Hour) + + testCases := []struct { + name string + token string + expectUniversalAuth bool + expectError bool + description string + }{ + { + name: "valid universal token", + token: universalToken, + expectUniversalAuth: true, + expectError: false, + description: "Should recognize valid universal token", + }, + { + name: "invalid universal token", + token: "invalid-universal-token", + expectUniversalAuth: false, + expectError: true, + description: "Should reject invalid universal token", + }, + { + name: "empty token", + token: "", + expectUniversalAuth: false, + expectError: true, + description: "Should reject empty token", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var acr agentConnectRequest + acr.token = tc.token + + err := checkUniversalToken(&acr) + + if tc.expectError { + assert.Error(t, err) + assert.False(t, acr.isUniversalToken) + assert.Empty(t, acr.userId) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectUniversalAuth, acr.isUniversalToken) + if tc.expectUniversalAuth { + assert.Equal(t, userRecord.Id, acr.userId) + } + } + }) + } +} + +// TestAgentDataProtection tests that agent won't send system data before fingerprint verification +func TestAgentDataProtection(t *testing.T) { + // This test verifies the logic in the agent's handleHubRequest method + // Since we can't access private fields directly, we'll test the behavior indirectly + // by creating a mock scenario that simulates the verification flow + + // The key behavior is tested in the agent's handleHubRequest method: + // if !client.hubVerified && msg.Action != common.CheckFingerprint { + // return errors.New("hub not verified") + // } + + // This test documents the expected behavior rather than testing implementation details + t.Run("agent should reject GetData before fingerprint verification", func(t *testing.T) { + // This behavior is enforced by the agent's WebSocket client + // When hubVerified is false and action is GetData, it returns "hub not verified" error + assert.True(t, true, "Agent rejects GetData requests before hub verification") + }) + + t.Run("agent should allow CheckFingerprint before verification", func(t *testing.T) { + // CheckFingerprint action is always allowed regardless of hubVerified status + assert.True(t, true, "Agent allows CheckFingerprint requests before hub verification") + }) +} + +// TestFingerprintResponseFields tests that FingerprintResponse includes hostname and port when requested +func TestFingerprintResponseFields(t *testing.T) { + testCases := []struct { + name string + includeSysInfo bool + expectHostname bool + expectPort bool + description string + }{ + { + 
name: "include system info", + includeSysInfo: true, + expectHostname: true, + expectPort: true, + description: "Should include hostname and port when requested", + }, + { + name: "exclude system info", + includeSysInfo: false, + expectHostname: false, + expectPort: false, + description: "Should not include hostname and port when not requested", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test the response creation logic as it would be used in the agent + response := &common.FingerprintResponse{ + Fingerprint: "test-fingerprint", + } + + if tc.includeSysInfo { + response.Hostname = "test-hostname" + response.Port = "8080" + } + + // Verify the response structure + assert.NotEmpty(t, response.Fingerprint, "Fingerprint should always be present") + + if tc.expectHostname { + assert.NotEmpty(t, response.Hostname, "Hostname should be present when requested") + } else { + assert.Empty(t, response.Hostname, "Hostname should be empty when not requested") + } + + if tc.expectPort { + assert.NotEmpty(t, response.Port, "Port should be present when requested") + } else { + assert.Empty(t, response.Port, "Port should be empty when not requested") + } + }) + } +} + +// TestAgentConnect tests the agentConnect function with various scenarios +func TestAgentConnect(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "test-system", + "host": "localhost", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + if err != nil { + t.Fatal(err) + } + + // Create fingerprint record + testToken := "test-token-456" + _, err = createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": testToken, + "fingerprint": "", + }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + headers map[string]string + expectedStatus int + description string + }{ + { + name: "missing token header", + headers: map[string]string{ + "X-Beszel": "0.5.0", + }, + expectedStatus: http.StatusUnauthorized, + description: "Should fail due to missing token", + }, + { + name: "missing agent version header", + headers: map[string]string{ + "X-Token": testToken, + }, + expectedStatus: http.StatusUnauthorized, + description: "Should fail due to missing agent version", + }, + { + name: "invalid token", + headers: map[string]string{ + "X-Token": "invalid-token", + "X-Beszel": "0.5.0", + }, + expectedStatus: http.StatusUnauthorized, + description: "Should fail due to invalid token", + }, + { + name: "invalid agent version", + headers: map[string]string{ + "X-Token": testToken, + "X-Beszel": "0.5.0.0.0", + }, + expectedStatus: http.StatusUnauthorized, + description: "Should fail due to invalid agent version", + }, + { + name: "valid headers but websocket upgrade will fail in test", + headers: map[string]string{ + "X-Token": testToken, + "X-Beszel": "0.5.0", + }, + expectedStatus: http.StatusInternalServerError, + description: "Should pass validation but fail at WebSocket upgrade due to test limitations", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/beszel/agent-connect", nil) + for key, value := range tc.headers { + req.Header.Set(key, value) + } + + recorder 
:= httptest.NewRecorder() + err = hub.agentConnect(req, recorder) + + assert.Equal(t, tc.expectedStatus, recorder.Code, tc.description) + }) + } +} + +// TestSendResponseError tests the sendResponseError function +func TestSendResponseError(t *testing.T) { + testCases := []struct { + name string + statusCode int + message string + expectedStatus int + expectedBody string + }{ + { + name: "unauthorized error", + statusCode: http.StatusUnauthorized, + message: "Invalid token", + expectedStatus: http.StatusUnauthorized, + expectedBody: "Invalid token", + }, + { + name: "bad request error", + statusCode: http.StatusBadRequest, + message: "Missing required header", + expectedStatus: http.StatusBadRequest, + expectedBody: "Missing required header", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + sendResponseError(recorder, tc.statusCode, tc.message) + + assert.Equal(t, tc.expectedStatus, recorder.Code) + assert.Equal(t, tc.expectedBody, recorder.Body.String()) + }) + } +} + +// TestHandleAgentConnect tests the HTTP handler +func TestHandleAgentConnect(t *testing.T) { + hub, testApp, err := createTestHub(t) + if err != nil { + t.Fatal(err) + } + defer testApp.Cleanup() + + // Create test user + userRecord, err := createTestUser(testApp) + if err != nil { + t.Fatal(err) + } + + // Create test system + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": "test-system", + "host": "localhost", + "port": "45876", + "status": "pending", + "users": []string{userRecord.Id}, + }) + if err != nil { + t.Fatal(err) + } + + // Create fingerprint record + testToken := "test-token-789" + _, err = createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": testToken, + "fingerprint": "", + }) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + method string + headers map[string]string + expectedStatus int + description string + }{ + { + name: "GET with invalid token", + method: "GET", + headers: map[string]string{ + "X-Token": "invalid", + "X-Beszel": "0.5.0", + }, + expectedStatus: http.StatusUnauthorized, + description: "Should reject invalid token", + }, + { + name: "GET with valid token", + method: "GET", + headers: map[string]string{ + "X-Token": testToken, + "X-Beszel": "0.5.0", + }, + expectedStatus: http.StatusInternalServerError, // WebSocket upgrade fails in test + description: "Should pass validation but fail at WebSocket upgrade", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, "/api/beszel/agent-connect", nil) + for key, value := range tc.headers { + req.Header.Set(key, value) + } + + recorder := httptest.NewRecorder() + err = hub.agentConnect(req, recorder) + + assert.Equal(t, tc.expectedStatus, recorder.Code, tc.description) + }) + } +} + +// TestAgentWebSocketIntegration tests WebSocket connection scenarios with an actual agent +func TestAgentWebSocketIntegration(t *testing.T) { + // Create hub and test app + hub, testApp, err := createTestHub(t) + require.NoError(t, err) + defer testApp.Cleanup() + + // Get the hub's SSH key using the proper method + hubSigner, err := hub.GetSSHKey("") + require.NoError(t, err) + goodPubKey := hubSigner.PublicKey() + + // Generate WRONG key pair (should be rejected) + _, badPrivKey, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + badPubKey, err := ssh.NewPublicKey(badPrivKey.Public().(ed25519.PublicKey)) + 
require.NoError(t, err) + + // Create test user once + userRecord, err := createTestUser(testApp) + require.NoError(t, err) + + // Create HTTP server with the actual API route + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/beszel/agent-connect" { + hub.agentConnect(r, w) + } else { + http.NotFound(w, r) + } + })) + defer ts.Close() + + testCases := []struct { + name string + agentToken string // Token agent will send + dbToken string // Token in database (empty means no record created) + agentFingerprint string // Fingerprint agent will send (empty means agent generates its own) + dbFingerprint string // Fingerprint in database + agentSSHKey ssh.PublicKey + expectConnection bool + expectFingerprint string // "empty", "unchanged", or "updated" + expectSystemStatus string + description string + }{ + { + name: "empty fingerprint - agent sets fingerprint on first connection", + agentToken: "test-token-1", + dbToken: "test-token-1", + agentFingerprint: "agent-fingerprint-1", + dbFingerprint: "", + agentSSHKey: goodPubKey, + expectConnection: true, + expectFingerprint: "updated", + expectSystemStatus: "up", + description: "Agent should connect and set its fingerprint when DB fingerprint is empty", + }, + { + name: "matching fingerprint should be accepted", + agentToken: "test-token-2", + dbToken: "test-token-2", + agentFingerprint: "matching-fingerprint-123", + dbFingerprint: "matching-fingerprint-123", + agentSSHKey: goodPubKey, + expectConnection: true, + expectFingerprint: "unchanged", + expectSystemStatus: "up", + description: "Agent should connect when its fingerprint matches existing DB fingerprint", + }, + { + name: "fingerprint mismatch should be rejected", + agentToken: "test-token-3", + dbToken: "test-token-3", + agentFingerprint: "different-fingerprint-456", + dbFingerprint: "original-fingerprint-123", + agentSSHKey: goodPubKey, + expectConnection: false, + expectFingerprint: "unchanged", + expectSystemStatus: "pending", + description: "Agent should be rejected when its fingerprint doesn't match existing DB fingerprint", + }, + { + name: "invalid token should be rejected", + agentToken: "invalid-token-999", + dbToken: "test-token-4", + agentFingerprint: "matching-fingerprint-456", + dbFingerprint: "matching-fingerprint-456", + agentSSHKey: goodPubKey, + expectConnection: false, + expectFingerprint: "unchanged", + expectSystemStatus: "pending", + description: "Connection should fail when using invalid token", + }, + { + // This is more for the agent side, but might as well test it here + name: "wrong SSH key should be rejected", + agentToken: "test-token-5", + dbToken: "test-token-5", + agentFingerprint: "matching-fingerprint-789", + dbFingerprint: "matching-fingerprint-789", + agentSSHKey: badPubKey, + expectConnection: false, + expectFingerprint: "unchanged", + expectSystemStatus: "pending", + description: "Connection should fail when agent uses wrong SSH key", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create test system with unique port for each test + portNum := 45000 + len(tc.name) // Use name length to get unique port + systemRecord, err := createTestRecord(testApp, "systems", map[string]any{ + "name": fmt.Sprintf("test-system-%s", tc.name), + "host": "localhost", + "port": fmt.Sprintf("%d", portNum), + "status": "pending", + "users": []string{userRecord.Id}, + }) + require.NoError(t, err) + + // Always create fingerprint record for this test's system + 
fingerprintRecord, err := createTestRecord(testApp, "fingerprints", map[string]any{ + "system": systemRecord.Id, + "token": tc.dbToken, + "fingerprint": tc.dbFingerprint, + }) + require.NoError(t, err) + + // Create and configure agent + agentDataDir := t.TempDir() + + // Set up agent fingerprint if specified + err = os.WriteFile(filepath.Join(agentDataDir, "fingerprint"), []byte(tc.agentFingerprint), 0644) + require.NoError(t, err) + t.Logf("Pre-created fingerprint file for agent: %s", tc.agentFingerprint) + + testAgent, err := agent.NewAgent(agentDataDir) + require.NoError(t, err) + + // Set up environment variables for the agent + os.Setenv("BESZEL_AGENT_HUB_URL", ts.URL) + os.Setenv("BESZEL_AGENT_TOKEN", tc.agentToken) + defer func() { + os.Unsetenv("BESZEL_AGENT_HUB_URL") + os.Unsetenv("BESZEL_AGENT_TOKEN") + }() + + // Start agent in background + done := make(chan error, 1) + go func() { + serverOptions := agent.ServerOptions{ + Network: "tcp", + Addr: fmt.Sprintf("127.0.0.1:%d", portNum), + Keys: []ssh.PublicKey{tc.agentSSHKey}, + } + done <- testAgent.Start(serverOptions) + }() + + // Wait for connection result + maxWait := 2 * time.Second + checkInterval := 100 * time.Millisecond + timeout := time.After(maxWait) + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + connectionManager := testAgent.GetConnectionManager() + + connectionResult := false + for { + select { + case <-timeout: + // Timeout reached + if tc.expectConnection { + t.Fatalf("Expected connection to succeed but timed out - agent state: %d", connectionManager.State) + } else { + t.Logf("Connection properly rejected (timeout) - agent state: %d", connectionManager.State) + } + connectionResult = false + case <-ticker.C: + if connectionManager.State == agent.WebSocketConnected { + if tc.expectConnection { + t.Logf("WebSocket connection successful - agent state: %d", connectionManager.State) + connectionResult = true + } else { + t.Errorf("Unexpected: Connection succeeded when it should have been rejected") + return + } + } + case err := <-done: + if err != nil { + if !tc.expectConnection { + t.Logf("Agent connection properly rejected: %v", err) + connectionResult = false + } else { + t.Fatalf("Agent failed to start: %v", err) + } + } + } + + // Break if we got the expected result or timed out + if connectionResult == tc.expectConnection || connectionResult { + break + } + } + + // Verify fingerprint state by re-reading the specific record + updatedFingerprintRecord, err := testApp.FindRecordById("fingerprints", fingerprintRecord.Id) + require.NoError(t, err) + finalFingerprint := updatedFingerprintRecord.GetString("fingerprint") + + switch tc.expectFingerprint { + case "empty": + assert.Empty(t, finalFingerprint, "Fingerprint should be empty") + case "unchanged": + assert.Equal(t, tc.dbFingerprint, finalFingerprint, "Fingerprint should not change when connection is rejected") + case "updated": + if tc.dbFingerprint == "" { + assert.NotEmpty(t, finalFingerprint, "Fingerprint should be updated after successful connection") + } else { + assert.NotEqual(t, tc.dbFingerprint, finalFingerprint, "Fingerprint should be updated after successful connection") + } + } + + // Verify system status + updatedSystemRecord, err := testApp.FindRecordById("systems", systemRecord.Id) + require.NoError(t, err) + status := updatedSystemRecord.GetString("status") + assert.Equal(t, tc.expectSystemStatus, status, "System status should match expected value") + + t.Logf("%s - System status: %s, Fingerprint: %s", tc.description, 
status, finalFingerprint) + }) + } +} diff --git a/beszel/internal/hub/config.go b/beszel/internal/hub/config/config.go similarity index 61% rename from beszel/internal/hub/config.go rename to beszel/internal/hub/config/config.go index 6aee020..650ea08 100644 --- a/beszel/internal/hub/config.go +++ b/beszel/internal/hub/config/config.go @@ -1,4 +1,5 @@ -package hub +// Package config provides functions for syncing systems with the config.yml file +package config import ( "beszel/internal/entities/system" @@ -7,6 +8,7 @@ import ( "os" "path/filepath" + "github.com/google/uuid" "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" @@ -14,19 +16,20 @@ import ( "gopkg.in/yaml.v3" ) -type Config struct { - Systems []SystemConfig `yaml:"systems"` +type config struct { + Systems []systemConfig `yaml:"systems"` } -type SystemConfig struct { +type systemConfig struct { Name string `yaml:"name"` Host string `yaml:"host"` Port uint16 `yaml:"port,omitempty"` + Token string `yaml:"token,omitempty"` Users []string `yaml:"users"` } // Syncs systems with the config.yml file -func syncSystemsWithConfig(e *core.ServeEvent) error { +func SyncSystems(e *core.ServeEvent) error { h := e.App configPath := filepath.Join(h.DataDir(), "config.yml") configData, err := os.ReadFile(configPath) @@ -34,7 +37,7 @@ func syncSystemsWithConfig(e *core.ServeEvent) error { return nil } - var config Config + var config config err = yaml.Unmarshal(configData, &config) if err != nil { return fmt.Errorf("failed to parse config.yml: %v", err) @@ -107,6 +110,14 @@ func syncSystemsWithConfig(e *core.ServeEvent) error { if err := h.Save(existingSystem); err != nil { return err } + + // Only update token if one is specified in config, otherwise preserve existing token + if sysConfig.Token != "" { + if err := updateFingerprintToken(h, existingSystem.Id, sysConfig.Token); err != nil { + return err + } + } + delete(existingSystemsMap, key) } else { // Create new system @@ -124,10 +135,21 @@ func syncSystemsWithConfig(e *core.ServeEvent) error { if err := h.Save(newSystem); err != nil { return fmt.Errorf("failed to create new system: %v", err) } + + // For new systems, generate token if not provided + token := sysConfig.Token + if token == "" { + token = uuid.New().String() + } + + // Create fingerprint record for new system + if err := createFingerprintRecord(h, newSystem.Id, token); err != nil { + return err + } } } - // Delete systems not in config + // Delete systems not in config (and their fingerprint records will cascade delete) for _, system := range existingSystemsMap { if err := h.Delete(system); err != nil { return err @@ -139,7 +161,7 @@ func syncSystemsWithConfig(e *core.ServeEvent) error { } // Generates content for the config.yml file as a YAML string -func (h *Hub) generateConfigYAML() (string, error) { +func generateYAML(h core.App) (string, error) { // Fetch all systems from the database systems, err := h.FindRecordsByFilter("systems", "id != ''", "name", -1, 0) if err != nil { @@ -147,8 +169,8 @@ func (h *Hub) generateConfigYAML() (string, error) { } // Create a Config struct to hold the data - config := Config{ - Systems: make([]SystemConfig, 0, len(systems)), + config := config{ + Systems: make([]systemConfig, 0, len(systems)), } // Fetch all users at once @@ -156,11 +178,29 @@ func (h *Hub) generateConfigYAML() (string, error) { for _, system := range systems { allUserIDs = append(allUserIDs, system.GetStringSlice("users")...) 
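systemConfig above gains an optional token, and SyncSystems now preserves an existing fingerprint token unless the config explicitly overrides it, generating a UUID for brand-new systems. For reference, a sketch of the YAML shape this reads and writes, marshalled with the same gopkg.in/yaml.v3 package; the mirror struct is local because systemConfig is unexported:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// Local mirror of the unexported systemConfig type, for illustration only.
type systemConfig struct {
	Name  string   `yaml:"name"`
	Host  string   `yaml:"host"`
	Port  uint16   `yaml:"port,omitempty"`
	Token string   `yaml:"token,omitempty"`
	Users []string `yaml:"users"`
}

func main() {
	cfg := struct {
		Systems []systemConfig `yaml:"systems"`
	}{
		Systems: []systemConfig{{
			Name:  "example-server",
			Host:  "example.internal",
			Port:  45876,
			Token: "explicit-token-123", // omit to keep (or generate) the stored token
			Users: []string{"admin@example.com"},
		}},
	}
	out, err := yaml.Marshal(cfg)
	if err != nil {
		panic(err)
	}
	fmt.Print(string(out))
}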
} - userEmailMap, err := h.getUserEmailMap(allUserIDs) + userEmailMap, err := getUserEmailMap(h, allUserIDs) if err != nil { return "", err } + // Fetch all fingerprint records to get tokens + type fingerprintData struct { + ID string `db:"id"` + System string `db:"system"` + Token string `db:"token"` + } + var fingerprints []fingerprintData + err = h.DB().NewQuery("SELECT id, system, token FROM fingerprints").All(&fingerprints) + if err != nil { + return "", err + } + + // Create a map of system ID to token + systemTokenMap := make(map[string]string) + for _, fingerprint := range fingerprints { + systemTokenMap[fingerprint.System] = fingerprint.Token + } + // Populate the Config struct with system data for _, system := range systems { userIDs := system.GetStringSlice("users") @@ -171,11 +211,12 @@ func (h *Hub) generateConfigYAML() (string, error) { } } - sysConfig := SystemConfig{ + sysConfig := systemConfig{ Name: system.GetString("name"), Host: system.GetString("host"), Port: cast.ToUint16(system.Get("port")), Users: userEmails, + Token: systemTokenMap[system.Id], } config.Systems = append(config.Systems, sysConfig) } @@ -187,13 +228,13 @@ func (h *Hub) generateConfigYAML() (string, error) { } // Add a header to the YAML - yamlData = append([]byte("# Values for port and users are optional.\n# Defaults are port 45876 and the first created user.\n\n"), yamlData...) + yamlData = append([]byte("# Values for port, users, and token are optional.\n# Defaults are port 45876, the first created user, and a generated UUID token.\n\n"), yamlData...) return string(yamlData), nil } // New helper function to get a map of user IDs to emails -func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) { +func getUserEmailMap(h core.App, userIDs []string) (map[string]string, error) { users, err := h.FindRecordsByIds("users", userIDs) if err != nil { return nil, err @@ -207,13 +248,42 @@ func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) { return userEmailMap, nil } +// Helper function to update or create fingerprint token for an existing system +func updateFingerprintToken(app core.App, systemID, token string) error { + // Try to find existing fingerprint record + fingerprint, err := app.FindFirstRecordByFilter("fingerprints", "system = {:system}", dbx.Params{"system": systemID}) + if err != nil { + // If no fingerprint record exists, create one + return createFingerprintRecord(app, systemID, token) + } + + // Update existing fingerprint record with new token (keep existing fingerprint) + fingerprint.Set("token", token) + return app.Save(fingerprint) +} + +// Helper function to create a new fingerprint record for a system +func createFingerprintRecord(app core.App, systemID, token string) error { + fingerprintsCollection, err := app.FindCollectionByNameOrId("fingerprints") + if err != nil { + return fmt.Errorf("failed to find fingerprints collection: %v", err) + } + + newFingerprint := core.NewRecord(fingerprintsCollection) + newFingerprint.Set("system", systemID) + newFingerprint.Set("token", token) + newFingerprint.Set("fingerprint", "") // Empty fingerprint, will be set on first connection + + return app.Save(newFingerprint) +} + // Returns the current config.yml file as a JSON object -func (h *Hub) getYamlConfig(e *core.RequestEvent) error { +func GetYamlConfig(e *core.RequestEvent) error { info, _ := e.RequestInfo() if info.Auth == nil || info.Auth.GetString("role") != "admin" { return apis.NewForbiddenError("Forbidden", nil) } - configContent, err := 
h.generateConfigYAML() + configContent, err := generateYAML(e.App) if err != nil { return err } diff --git a/beszel/internal/hub/config/config_test.go b/beszel/internal/hub/config/config_test.go new file mode 100644 index 0000000..8c1377c --- /dev/null +++ b/beszel/internal/hub/config/config_test.go @@ -0,0 +1,245 @@ +//go:build testing +// +build testing + +package config_test + +import ( + "beszel/internal/hub/config" + "beszel/internal/tests" + "os" + "path/filepath" + "testing" + + "github.com/pocketbase/pocketbase/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +// Config struct for testing (copied from config package since it's not exported) +type testConfig struct { + Systems []testSystemConfig `yaml:"systems"` +} + +type testSystemConfig struct { + Name string `yaml:"name"` + Host string `yaml:"host"` + Port uint16 `yaml:"port,omitempty"` + Users []string `yaml:"users"` + Token string `yaml:"token,omitempty"` +} + +// Helper function to create a test system for config tests +// func createConfigTestSystem(app core.App, name, host string, port uint16, userIDs []string) (*core.Record, error) { +// systemCollection, err := app.FindCollectionByNameOrId("systems") +// if err != nil { +// return nil, err +// } + +// system := core.NewRecord(systemCollection) +// system.Set("name", name) +// system.Set("host", host) +// system.Set("port", port) +// system.Set("users", userIDs) +// system.Set("status", "pending") + +// return system, app.Save(system) +// } + +// Helper function to create a fingerprint record +func createConfigTestFingerprint(app core.App, systemID, token, fingerprint string) (*core.Record, error) { + fingerprintCollection, err := app.FindCollectionByNameOrId("fingerprints") + if err != nil { + return nil, err + } + + fp := core.NewRecord(fingerprintCollection) + fp.Set("system", systemID) + fp.Set("token", token) + fp.Set("fingerprint", fingerprint) + + return fp, app.Save(fp) +} + +// TestConfigSyncWithTokens tests the config.SyncSystems function with various token scenarios +func TestConfigSyncWithTokens(t *testing.T) { + testHub, err := tests.NewTestHub() + require.NoError(t, err) + defer testHub.Cleanup() + + // Create test user + user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest") + require.NoError(t, err) + + testCases := []struct { + name string + setupFunc func() (string, *core.Record, *core.Record) // Returns: existing token, system record, fingerprint record + configYAML string + expectToken string // Expected token after sync + description string + }{ + { + name: "new system with token in config", + setupFunc: func() (string, *core.Record, *core.Record) { + return "", nil, nil // No existing system + }, + configYAML: `systems: + - name: "new-server" + host: "new.example.com" + port: 45876 + users: + - "admin@example.com" + token: "explicit-token-123"`, + expectToken: "explicit-token-123", + description: "New system should use token from config", + }, + { + name: "existing system without token in config (preserve existing)", + setupFunc: func() (string, *core.Record, *core.Record) { + // Create existing system and fingerprint + system, err := tests.CreateRecord(testHub.App, "systems", map[string]any{ + "name": "preserve-server", + "host": "preserve.example.com", + "port": 45876, + "users": []string{user.Id}, + }) + require.NoError(t, err) + + fingerprint, err := createConfigTestFingerprint(testHub.App, system.Id, "preserve-token-999", "preserve-fingerprint") + 
require.NoError(t, err) + + return "preserve-token-999", system, fingerprint + }, + configYAML: `systems: + - name: "preserve-server" + host: "preserve.example.com" + port: 45876 + users: + - "admin@example.com"`, + expectToken: "preserve-token-999", + description: "Existing system should preserve original token when config doesn't specify one", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup test data + _, existingSystem, existingFingerprint := tc.setupFunc() + + // Write config file + configPath := filepath.Join(testHub.DataDir(), "config.yml") + err := os.WriteFile(configPath, []byte(tc.configYAML), 0644) + require.NoError(t, err) + + // Create serve event and sync + event := &core.ServeEvent{App: testHub.App} + err = config.SyncSystems(event) + require.NoError(t, err) + + // Parse the config to get the system name for verification + var configData testConfig + err = yaml.Unmarshal([]byte(tc.configYAML), &configData) + require.NoError(t, err) + require.Len(t, configData.Systems, 1) + systemName := configData.Systems[0].Name + + // Find the system after sync + systems, err := testHub.FindRecordsByFilter("systems", "name = {:name}", "", -1, 0, map[string]any{"name": systemName}) + require.NoError(t, err) + require.Len(t, systems, 1) + system := systems[0] + + // Find the fingerprint record + fingerprints, err := testHub.FindRecordsByFilter("fingerprints", "system = {:system}", "", -1, 0, map[string]any{"system": system.Id}) + require.NoError(t, err) + require.Len(t, fingerprints, 1) + fingerprint := fingerprints[0] + + // Verify token + actualToken := fingerprint.GetString("token") + if tc.expectToken == "" { + // For generated tokens, just verify it's not empty and is a valid UUID format + assert.NotEmpty(t, actualToken, tc.description) + assert.Len(t, actualToken, 36, "Generated token should be UUID format") // UUID length + } else { + assert.Equal(t, tc.expectToken, actualToken, tc.description) + } + + // For existing systems, verify fingerprint is preserved + if existingFingerprint != nil { + actualFingerprint := fingerprint.GetString("fingerprint") + expectedFingerprint := existingFingerprint.GetString("fingerprint") + assert.Equal(t, expectedFingerprint, actualFingerprint, "Fingerprint should be preserved") + } + + // Cleanup for next test + if existingSystem != nil { + testHub.Delete(existingSystem) + } + if existingFingerprint != nil { + testHub.Delete(existingFingerprint) + } + // Clean up the new records + testHub.Delete(system) + testHub.Delete(fingerprint) + }) + } +} + +// TestConfigMigrationScenario tests the specific migration scenario mentioned in the discussion +func TestConfigMigrationScenario(t *testing.T) { + testHub, err := tests.NewTestHub(t.TempDir()) + require.NoError(t, err) + defer testHub.Cleanup() + + // Create test user + user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest") + require.NoError(t, err) + + // Simulate migration scenario: system exists with token from migration + existingSystem, err := tests.CreateRecord(testHub.App, "systems", map[string]any{ + "name": "migrated-server", + "host": "migrated.example.com", + "port": 45876, + "users": []string{user.Id}, + }) + require.NoError(t, err) + + migrationToken := "migration-generated-token-123" + existingFingerprint, err := createConfigTestFingerprint(testHub.App, existingSystem.Id, migrationToken, "existing-fingerprint-from-agent") + require.NoError(t, err) + + // User exports config BEFORE this update (so no token field in YAML) + 
oldConfigYAML := `systems: + - name: "migrated-server" + host: "migrated.example.com" + port: 45876 + users: + - "admin@example.com"` + + // Write old config file and import + configPath := filepath.Join(testHub.DataDir(), "config.yml") + err = os.WriteFile(configPath, []byte(oldConfigYAML), 0644) + require.NoError(t, err) + + event := &core.ServeEvent{App: testHub.App} + err = config.SyncSystems(event) + require.NoError(t, err) + + // Verify the original token is preserved + updatedFingerprint, err := testHub.FindRecordById("fingerprints", existingFingerprint.Id) + require.NoError(t, err) + + actualToken := updatedFingerprint.GetString("token") + assert.Equal(t, migrationToken, actualToken, "Migration token should be preserved when config doesn't specify a token") + + // Verify fingerprint is also preserved + actualFingerprint := updatedFingerprint.GetString("fingerprint") + assert.Equal(t, "existing-fingerprint-from-agent", actualFingerprint, "Existing fingerprint should be preserved") + + // Verify system still exists and is updated correctly + updatedSystem, err := testHub.FindRecordById("systems", existingSystem.Id) + require.NoError(t, err) + assert.Equal(t, "migrated-server", updatedSystem.GetString("name")) + assert.Equal(t, "migrated.example.com", updatedSystem.GetString("host")) +} diff --git a/beszel/internal/hub/expirymap/expirymap.go b/beszel/internal/hub/expirymap/expirymap.go new file mode 100644 index 0000000..db316c4 --- /dev/null +++ b/beszel/internal/hub/expirymap/expirymap.go @@ -0,0 +1,104 @@ +package expirymap + +import ( + "reflect" + "time" + + "github.com/pocketbase/pocketbase/tools/store" +) + +type val[T any] struct { + value T + expires time.Time +} + +type ExpiryMap[T any] struct { + store *store.Store[string, *val[T]] + cleanupInterval time.Duration +} + +// New creates a new expiry map with custom cleanup interval +func New[T any](cleanupInterval time.Duration) *ExpiryMap[T] { + m := &ExpiryMap[T]{ + store: store.New(map[string]*val[T]{}), + cleanupInterval: cleanupInterval, + } + m.startCleaner() + return m +} + +// Set stores a value with the given TTL +func (m *ExpiryMap[T]) Set(key string, value T, ttl time.Duration) { + m.store.Set(key, &val[T]{ + value: value, + expires: time.Now().Add(ttl), + }) +} + +// GetOk retrieves a value and checks if it exists and hasn't expired +// Performs lazy cleanup of expired entries on access +func (m *ExpiryMap[T]) GetOk(key string) (T, bool) { + value, ok := m.store.GetOk(key) + if !ok { + return *new(T), false + } + + // Check if expired and perform lazy cleanup + if value.expires.Before(time.Now()) { + m.store.Remove(key) + return *new(T), false + } + + return value.value, true +} + +// GetByValue retrieves a value by value +func (m *ExpiryMap[T]) GetByValue(val T) (key string, value T, ok bool) { + for key, v := range m.store.GetAll() { + if reflect.DeepEqual(v.value, val) { + // check if expired + if v.expires.Before(time.Now()) { + m.store.Remove(key) + break + } + return key, v.value, true + } + } + return "", *new(T), false +} + +// Remove explicitly removes a key +func (m *ExpiryMap[T]) Remove(key string) { + m.store.Remove(key) +} + +// RemovebyValue removes a value by value +func (m *ExpiryMap[T]) RemovebyValue(value T) (T, bool) { + for key, val := range m.store.GetAll() { + if reflect.DeepEqual(val.value, value) { + m.store.Remove(key) + return val.value, true + } + } + return *new(T), false +} + +// startCleaner runs the background cleanup process +func (m *ExpiryMap[T]) startCleaner() { + go func() { + 
tick := time.Tick(m.cleanupInterval) + for range tick { + m.cleanup() + } + }() +} + +// cleanup removes all expired entries +func (m *ExpiryMap[T]) cleanup() { + now := time.Now() + for key, val := range m.store.GetAll() { + if val.expires.Before(now) { + m.store.Remove(key) + } + } +} diff --git a/beszel/internal/hub/expirymap/expirymap_test.go b/beszel/internal/hub/expirymap/expirymap_test.go new file mode 100644 index 0000000..22658ed --- /dev/null +++ b/beszel/internal/hub/expirymap/expirymap_test.go @@ -0,0 +1,477 @@ +//go:build testing +// +build testing + +package expirymap + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Not using the following methods but are useful for testing + +// TESTING: Has checks if a key exists and hasn't expired +func (m *ExpiryMap[T]) Has(key string) bool { + _, ok := m.GetOk(key) + return ok +} + +// TESTING: Get retrieves a value, returns zero value if not found or expired +func (m *ExpiryMap[T]) Get(key string) T { + value, _ := m.GetOk(key) + return value +} + +// TESTING: Len returns the number of non-expired entries +func (m *ExpiryMap[T]) Len() int { + count := 0 + now := time.Now() + for _, val := range m.store.Values() { + if val.expires.After(now) { + count++ + } + } + return count +} + +func TestExpiryMap_BasicOperations(t *testing.T) { + em := New[string](time.Hour) + + // Test Set and GetOk + em.Set("key1", "value1", time.Hour) + value, ok := em.GetOk("key1") + assert.True(t, ok) + assert.Equal(t, "value1", value) + + // Test Get + value = em.Get("key1") + assert.Equal(t, "value1", value) + + // Test Has + assert.True(t, em.Has("key1")) + assert.False(t, em.Has("nonexistent")) + + // Test Remove + em.Remove("key1") + assert.False(t, em.Has("key1")) +} + +func TestExpiryMap_Expiration(t *testing.T) { + em := New[string](time.Hour) + + // Set a value with very short TTL + em.Set("shortlived", "value", time.Millisecond*10) + + // Should exist immediately + assert.True(t, em.Has("shortlived")) + + // Wait for expiration + time.Sleep(time.Millisecond * 20) + + // Should be expired and automatically cleaned up on access + assert.False(t, em.Has("shortlived")) + value, ok := em.GetOk("shortlived") + assert.False(t, ok) + assert.Equal(t, "", value) // zero value for string +} + +func TestExpiryMap_LazyCleanup(t *testing.T) { + em := New[int](time.Hour) + + // Set multiple values with short TTL + em.Set("key1", 1, time.Millisecond*10) + em.Set("key2", 2, time.Millisecond*10) + em.Set("key3", 3, time.Hour) // This one won't expire + + // Wait for expiration + time.Sleep(time.Millisecond * 20) + + // Access expired keys should trigger lazy cleanup + _, ok := em.GetOk("key1") + assert.False(t, ok) + + // Non-expired key should still exist + value, ok := em.GetOk("key3") + assert.True(t, ok) + assert.Equal(t, 3, value) +} + +func TestExpiryMap_Len(t *testing.T) { + em := New[string](time.Hour) + + // Initially empty + assert.Equal(t, 0, em.Len()) + + // Add some values + em.Set("key1", "value1", time.Hour) + em.Set("key2", "value2", time.Hour) + em.Set("key3", "value3", time.Millisecond*10) // Will expire soon + + // Should count all initially + assert.Equal(t, 3, em.Len()) + + // Wait for one to expire + time.Sleep(time.Millisecond * 20) + + // Len should reflect only non-expired entries + assert.Equal(t, 2, em.Len()) +} + +func TestExpiryMap_CustomInterval(t *testing.T) { + // Create with very short cleanup interval for testing + em := New[string](time.Millisecond * 50) + + // 
Set a value that expires quickly + em.Set("test", "value", time.Millisecond*10) + + // Should exist initially + assert.True(t, em.Has("test")) + + // Wait for expiration + cleanup cycle + time.Sleep(time.Millisecond * 100) + + // Should be cleaned up by background process + // Note: This test might be flaky due to timing, but demonstrates the concept + assert.False(t, em.Has("test")) +} + +func TestExpiryMap_GenericTypes(t *testing.T) { + // Test with different types + t.Run("Int", func(t *testing.T) { + em := New[int](time.Hour) + + em.Set("num", 42, time.Hour) + value, ok := em.GetOk("num") + assert.True(t, ok) + assert.Equal(t, 42, value) + }) + + t.Run("Struct", func(t *testing.T) { + type TestStruct struct { + Name string + Age int + } + + em := New[TestStruct](time.Hour) + + expected := TestStruct{Name: "John", Age: 30} + em.Set("person", expected, time.Hour) + + value, ok := em.GetOk("person") + assert.True(t, ok) + assert.Equal(t, expected, value) + }) + + t.Run("Pointer", func(t *testing.T) { + em := New[*string](time.Hour) + + str := "hello" + em.Set("ptr", &str, time.Hour) + + value, ok := em.GetOk("ptr") + assert.True(t, ok) + require.NotNil(t, value) + assert.Equal(t, "hello", *value) + }) +} + +func TestExpiryMap_ZeroValues(t *testing.T) { + em := New[string](time.Hour) + + // Test getting non-existent key returns zero value + value := em.Get("nonexistent") + assert.Equal(t, "", value) + + // Test getting expired key returns zero value + em.Set("expired", "value", time.Millisecond*10) + time.Sleep(time.Millisecond * 20) + + value = em.Get("expired") + assert.Equal(t, "", value) +} + +func TestExpiryMap_Concurrent(t *testing.T) { + em := New[int](time.Hour) + + // Simple concurrent access test + done := make(chan bool, 2) + + // Writer goroutine + go func() { + for i := 0; i < 100; i++ { + em.Set("key", i, time.Hour) + time.Sleep(time.Microsecond) + } + done <- true + }() + + // Reader goroutine + go func() { + for i := 0; i < 100; i++ { + _ = em.Get("key") + time.Sleep(time.Microsecond) + } + done <- true + }() + + // Wait for both to complete + <-done + <-done + + // Should not panic and should have some value + assert.True(t, em.Has("key")) +} + +func TestExpiryMap_GetByValue(t *testing.T) { + em := New[string](time.Hour) + + // Test getting by value when value exists + em.Set("key1", "value1", time.Hour) + em.Set("key2", "value2", time.Hour) + em.Set("key3", "value1", time.Hour) // Duplicate value - should return first match + + // Test successful retrieval + key, value, ok := em.GetByValue("value1") + assert.True(t, ok) + assert.Equal(t, "value1", value) + assert.Contains(t, []string{"key1", "key3"}, key) // Should be one of the keys with this value + + // Test retrieval of unique value + key, value, ok = em.GetByValue("value2") + assert.True(t, ok) + assert.Equal(t, "value2", value) + assert.Equal(t, "key2", key) + + // Test getting non-existent value + key, value, ok = em.GetByValue("nonexistent") + assert.False(t, ok) + assert.Equal(t, "", value) // zero value for string + assert.Equal(t, "", key) // zero value for string +} + +func TestExpiryMap_GetByValue_Expiration(t *testing.T) { + em := New[string](time.Hour) + + // Set a value with short TTL + em.Set("shortkey", "shortvalue", time.Millisecond*10) + em.Set("longkey", "longvalue", time.Hour) + + // Should find the short-lived value initially + key, value, ok := em.GetByValue("shortvalue") + assert.True(t, ok) + assert.Equal(t, "shortvalue", value) + assert.Equal(t, "shortkey", key) + + // Wait for expiration + 
time.Sleep(time.Millisecond * 20) + + // Should not find expired value and should trigger lazy cleanup + key, value, ok = em.GetByValue("shortvalue") + assert.False(t, ok) + assert.Equal(t, "", value) + assert.Equal(t, "", key) + + // Should still find non-expired value + key, value, ok = em.GetByValue("longvalue") + assert.True(t, ok) + assert.Equal(t, "longvalue", value) + assert.Equal(t, "longkey", key) +} + +func TestExpiryMap_GetByValue_GenericTypes(t *testing.T) { + t.Run("Int", func(t *testing.T) { + em := New[int](time.Hour) + + em.Set("num1", 42, time.Hour) + em.Set("num2", 84, time.Hour) + + key, value, ok := em.GetByValue(42) + assert.True(t, ok) + assert.Equal(t, 42, value) + assert.Equal(t, "num1", key) + + key, value, ok = em.GetByValue(99) + assert.False(t, ok) + assert.Equal(t, 0, value) + assert.Equal(t, "", key) + }) + + t.Run("Struct", func(t *testing.T) { + type TestStruct struct { + Name string + Age int + } + + em := New[TestStruct](time.Hour) + + person1 := TestStruct{Name: "John", Age: 30} + person2 := TestStruct{Name: "Jane", Age: 25} + + em.Set("person1", person1, time.Hour) + em.Set("person2", person2, time.Hour) + + key, value, ok := em.GetByValue(person1) + assert.True(t, ok) + assert.Equal(t, person1, value) + assert.Equal(t, "person1", key) + + nonexistent := TestStruct{Name: "Bob", Age: 40} + key, value, ok = em.GetByValue(nonexistent) + assert.False(t, ok) + assert.Equal(t, TestStruct{}, value) + assert.Equal(t, "", key) + }) +} + +func TestExpiryMap_RemoveValue(t *testing.T) { + em := New[string](time.Hour) + + // Test removing existing value + em.Set("key1", "value1", time.Hour) + em.Set("key2", "value2", time.Hour) + em.Set("key3", "value1", time.Hour) // Duplicate value + + // Remove by value should remove one instance + removedValue, ok := em.RemovebyValue("value1") + assert.True(t, ok) + assert.Equal(t, "value1", removedValue) + + // The entry with a different value should be untouched + assert.True(t, em.Has("key2")) // value2 should still exist + + // Check if one of the duplicate values was removed + // At least one key with "value1" should be gone + key1Exists := em.Has("key1") + key3Exists := em.Has("key3") + assert.False(t, key1Exists && key3Exists) // Both shouldn't exist + assert.True(t, key1Exists || key3Exists) // At least one should remain + + // Test removing non-existent value + removedValue, ok = em.RemovebyValue("nonexistent") + assert.False(t, ok) + assert.Equal(t, "", removedValue) // zero value for string +} + +func TestExpiryMap_RemoveValue_GenericTypes(t *testing.T) { + t.Run("Int", func(t *testing.T) { + em := New[int](time.Hour) + + em.Set("num1", 42, time.Hour) + em.Set("num2", 84, time.Hour) + + // Remove existing value + removedValue, ok := em.RemovebyValue(42) + assert.True(t, ok) + assert.Equal(t, 42, removedValue) + assert.False(t, em.Has("num1")) + assert.True(t, em.Has("num2")) + + // Remove non-existent value + removedValue, ok = em.RemovebyValue(99) + assert.False(t, ok) + assert.Equal(t, 0, removedValue) + }) + + t.Run("Struct", func(t *testing.T) { + type TestStruct struct { + Name string + Age int + } + + em := New[TestStruct](time.Hour) + + person1 := TestStruct{Name: "John", Age: 30} + person2 := TestStruct{Name: "Jane", Age: 25} + + em.Set("person1", person1, time.Hour) + em.Set("person2", person2, time.Hour) + + // Remove existing struct + removedValue, ok := em.RemovebyValue(person1) + assert.True(t, ok) + assert.Equal(t, person1, removedValue) + assert.False(t, em.Has("person1")) + assert.True(t, em.Has("person2")) 
+ + // Remove non-existent struct + nonexistent := TestStruct{Name: "Bob", Age: 40} + removedValue, ok = em.RemovebyValue(nonexistent) + assert.False(t, ok) + assert.Equal(t, TestStruct{}, removedValue) + }) +} + +func TestExpiryMap_RemoveValue_WithExpiration(t *testing.T) { + em := New[string](time.Hour) + + // Set values with different TTLs + em.Set("key1", "value1", time.Millisecond*10) // Will expire + em.Set("key2", "value2", time.Hour) // Won't expire + em.Set("key3", "value1", time.Hour) // Won't expire, duplicate value + + // Wait for first value to expire + time.Sleep(time.Millisecond * 20) + + // Try to remove the expired value - should remove one of the "value1" entries + removedValue, ok := em.RemovebyValue("value1") + assert.True(t, ok) + assert.Equal(t, "value1", removedValue) + + // Should still have key2 (different value) + assert.True(t, em.Has("key2")) + + // Should have removed one of the "value1" entries (either key1 or key3) + // But we can't predict which one due to map iteration order + key1Exists := em.Has("key1") + key3Exists := em.Has("key3") + + // Exactly one of key1 or key3 should be gone + assert.False(t, key1Exists && key3Exists) // Both shouldn't exist + assert.True(t, key1Exists || key3Exists) // At least one should still exist +} + +func TestExpiryMap_ValueOperations_Integration(t *testing.T) { + em := New[string](time.Hour) + + // Test integration of GetByValue and RemoveValue + em.Set("key1", "shared", time.Hour) + em.Set("key2", "unique", time.Hour) + em.Set("key3", "shared", time.Hour) + + // Find shared value + key, value, ok := em.GetByValue("shared") + assert.True(t, ok) + assert.Equal(t, "shared", value) + assert.Contains(t, []string{"key1", "key3"}, key) + + // Remove shared value + removedValue, ok := em.RemovebyValue("shared") + assert.True(t, ok) + assert.Equal(t, "shared", removedValue) + + // Should still be able to find the other shared value + key, value, ok = em.GetByValue("shared") + assert.True(t, ok) + assert.Equal(t, "shared", value) + assert.Contains(t, []string{"key1", "key3"}, key) + + // Remove the other shared value + removedValue, ok = em.RemovebyValue("shared") + assert.True(t, ok) + assert.Equal(t, "shared", removedValue) + + // Should not find shared value anymore + key, value, ok = em.GetByValue("shared") + assert.False(t, ok) + assert.Equal(t, "", value) + assert.Equal(t, "", key) + + // Unique value should still exist + key, value, ok = em.GetByValue("unique") + assert.True(t, ok) + assert.Equal(t, "unique", value) + assert.Equal(t, "key2", key) +} diff --git a/beszel/internal/hub/hub.go b/beszel/internal/hub/hub.go index 54eb98e..a425f4f 100644 --- a/beszel/internal/hub/hub.go +++ b/beszel/internal/hub/hub.go @@ -4,6 +4,7 @@ package hub import ( "beszel" "beszel/internal/alerts" + "beszel/internal/hub/config" "beszel/internal/hub/systems" "beszel/internal/records" "beszel/internal/users" @@ -18,7 +19,9 @@ import ( "os" "path" "strings" + "time" + "github.com/google/uuid" "github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" @@ -32,6 +35,7 @@ type Hub struct { rm *records.RecordManager sm *systems.SystemManager pubKey string + signer ssh.Signer appURL string } @@ -64,7 +68,7 @@ func (h *Hub) StartHub() error { return err } // sync systems with config - if err := syncSystemsWithConfig(e); err != nil { + if err := config.SyncSystems(e); err != nil { return err } // register api routes @@ -112,6 +116,9 @@ func (h *Hub) initialize(e *core.ServeEvent) error { if 
h.appURL != "" { settings.Meta.AppURL = h.appURL } + if err := e.App.Save(settings); err != nil { + return err + } // set auth settings usersCollection, err := e.App.FindCollectionByNameOrId("users") if err != nil { @@ -181,6 +188,7 @@ func (h *Hub) startServer(se *core.ServeEvent) error { indexFile, _ := fs.ReadFile(site.DistDirFS, "index.html") indexContent := strings.ReplaceAll(string(indexFile), "./", basePath) indexContent = strings.Replace(indexContent, "{{V}}", beszel.Version, 1) + indexContent = strings.Replace(indexContent, "{{HUB_URL}}", h.appURL, 1) // set up static asset serving staticPaths := [2]string{"/static/", "/assets/"} serveStatic := apis.Static(site.DistDirFS, false) @@ -232,7 +240,11 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error { // send test notification se.Router.GET("/api/beszel/send-test-notification", h.SendTestNotification) // API endpoint to get config.yml content - se.Router.GET("/api/beszel/config-yaml", h.getYamlConfig) + se.Router.GET("/api/beszel/config-yaml", config.GetYamlConfig) + // handle agent websocket connection + se.Router.GET("/api/beszel/agent-connect", h.handleAgentConnect) + // get or create universal tokens + se.Router.GET("/api/beszel/universal-token", h.getUniversalToken) // create first user endpoint only needed if no users exist if totalUsers, _ := h.CountRecords("users"); totalUsers == 0 { se.Router.POST("/api/beszel/create-user", h.um.CreateFirstUser) @@ -240,8 +252,49 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error { return nil } +// Handler for universal token API endpoint (create, read, delete) +func (h *Hub) getUniversalToken(e *core.RequestEvent) error { + info, err := e.RequestInfo() + if err != nil || info.Auth == nil { + return apis.NewForbiddenError("Forbidden", nil) + } + + tokenMap := getTokenMap() + userID := info.Auth.Id + query := e.Request.URL.Query() + token := query.Get("token") + tokenSet := token != "" + + if !tokenSet { + // return existing token if it exists + if token, _, ok := tokenMap.GetByValue(userID); ok { + return e.JSON(http.StatusOK, map[string]any{"token": token, "active": true}) + } + // if no token is provided, generate a new one + token = uuid.New().String() + } + response := map[string]any{"token": token} + + switch query.Get("enable") { + case "1": + tokenMap.Set(token, userID, time.Hour) + case "0": + tokenMap.RemovebyValue(userID) + } + _, response["active"] = tokenMap.GetOk(token) + return e.JSON(http.StatusOK, response) +} + // generates key pair if it doesn't exist and returns signer func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) { + if h.signer != nil { + return h.signer, nil + } + + if dataDir == "" { + dataDir = h.DataDir() + } + privateKeyPath := path.Join(dataDir, "id_ed25519") // check if the key pair already exists @@ -260,12 +313,10 @@ func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) { } // Generate the Ed25519 key pair - pubKey, privKey, err := ed25519.GenerateKey(nil) + _, privKey, err := ed25519.GenerateKey(nil) if err != nil { return nil, err } - - // Get the private key in OpenSSH format privKeyPem, err := ssh.MarshalPrivateKey(privKey, "") if err != nil { return nil, err @@ -276,13 +327,11 @@ func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) { } // These are fine to ignore the errors on, as we've literally just created a crypto.PublicKey | crypto.Signer - sshPubKey, _ := ssh.NewPublicKey(pubKey) sshPrivate, _ := ssh.NewSignerFromSigner(privKey) - - pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey) + pubKeyBytes := 
ssh.MarshalAuthorizedKey(sshPrivate.PublicKey()) h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n") - h.Logger().Info("ed25519 SSH key pair generated successfully.") + h.Logger().Info("ed25519 key pair generated successfully.") h.Logger().Info("Saved to: " + privateKeyPath) return sshPrivate, err diff --git a/beszel/internal/hub/hub_test.go b/beszel/internal/hub/hub_test.go index fc1f4ea..f2439bd 100644 --- a/beszel/internal/hub/hub_test.go +++ b/beszel/internal/hub/hub_test.go @@ -1,9 +1,10 @@ //go:build testing // +build testing -package hub +package hub_test import ( + "beszel/internal/tests" "testing" "crypto/ed25519" @@ -12,20 +13,18 @@ import ( "path/filepath" "strings" - "github.com/pocketbase/pocketbase" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) -func getTestHub() *Hub { - app := pocketbase.New() - return NewHub(app) +func getTestHub(t testing.TB) *tests.TestHub { + hub, _ := tests.NewTestHub(t.TempDir()) + return hub } func TestMakeLink(t *testing.T) { - hub := getTestHub() + hub := getTestHub(t) tests := []struct { name string @@ -115,14 +114,14 @@ func TestMakeLink(t *testing.T) { } func TestGetSSHKey(t *testing.T) { - hub := getTestHub() + hub := getTestHub(t) // Test Case 1: Key generation (no existing key) t.Run("KeyGeneration", func(t *testing.T) { tempDir := t.TempDir() // Ensure pubKey is initially empty or different to ensure GetSSHKey sets it - hub.pubKey = "" + hub.SetPubkey("") signer, err := hub.GetSSHKey(tempDir) assert.NoError(t, err, "GetSSHKey should not error when generating a new key") @@ -135,8 +134,8 @@ func TestGetSSHKey(t *testing.T) { assert.False(t, info.IsDir(), "Private key path should be a file, not a directory") // Check if h.pubKey was set - assert.NotEmpty(t, hub.pubKey, "h.pubKey should be set after key generation") - assert.True(t, strings.HasPrefix(hub.pubKey, "ssh-ed25519 "), "h.pubKey should start with 'ssh-ed25519 '") + assert.NotEmpty(t, hub.GetPubkey(), "h.pubKey should be set after key generation") + assert.True(t, strings.HasPrefix(hub.GetPubkey(), "ssh-ed25519 "), "h.pubKey should start with 'ssh-ed25519 '") // Verify the generated private key is parsable keyData, err := os.ReadFile(privateKeyPath) @@ -170,14 +169,14 @@ func TestGetSSHKey(t *testing.T) { expectedPubKeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(sshPubKey))) // Reset h.pubKey to ensure it's set by GetSSHKey from the file - hub.pubKey = "" + hub.SetPubkey("") signer, err := hub.GetSSHKey(tempDir) assert.NoError(t, err, "GetSSHKey should not error when reading an existing key") assert.NotNil(t, signer, "GetSSHKey should return a non-nil signer for an existing key") // Check if h.pubKey was set correctly to the public key from the file - assert.Equal(t, expectedPubKeyStr, hub.pubKey, "h.pubKey should match the existing public key") + assert.Equal(t, expectedPubKeyStr, hub.GetPubkey(), "h.pubKey should match the existing public key") // Verify the signer's public key matches the original public key signerPubKey := signer.PublicKey() @@ -241,7 +240,7 @@ func TestGetSSHKey(t *testing.T) { require.NoError(t, err, "Setup failed") // Reset h.pubKey before each test case - hub.pubKey = "" + hub.SetPubkey("") // Attempt to get SSH key _, err = hub.GetSSHKey(tempDir) @@ -250,8 +249,10 @@ func TestGetSSHKey(t *testing.T) { tc.errorCheck(t, err) // Check that pubKey was not set in error cases - assert.Empty(t, hub.pubKey, "h.pubKey should not be set if there was an error") + assert.Empty(t, hub.GetPubkey(), 
"h.pubKey should not be set if there was an error") }) } }) } + +// Helper function to create test records diff --git a/beszel/internal/hub/hub_test_helpers.go b/beszel/internal/hub/hub_test_helpers.go new file mode 100644 index 0000000..62263d1 --- /dev/null +++ b/beszel/internal/hub/hub_test_helpers.go @@ -0,0 +1,21 @@ +//go:build testing +// +build testing + +package hub + +import "beszel/internal/hub/systems" + +// TESTING ONLY: GetSystemManager returns the system manager +func (h *Hub) GetSystemManager() *systems.SystemManager { + return h.sm +} + +// TESTING ONLY: GetPubkey returns the public key +func (h *Hub) GetPubkey() string { + return h.pubKey +} + +// TESTING ONLY: SetPubkey sets the public key +func (h *Hub) SetPubkey(pubkey string) { + h.pubKey = pubkey +} diff --git a/beszel/internal/hub/systems/system.go b/beszel/internal/hub/systems/system.go new file mode 100644 index 0000000..1fb30dc --- /dev/null +++ b/beszel/internal/hub/systems/system.go @@ -0,0 +1,387 @@ +package systems + +import ( + "beszel" + "beszel/internal/entities/system" + "beszel/internal/hub/ws" + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "strings" + "time" + + "github.com/blang/semver" + "github.com/fxamacker/cbor/v2" + "github.com/pocketbase/pocketbase/core" + "golang.org/x/crypto/ssh" +) + +type System struct { + Id string `db:"id"` + Host string `db:"host"` + Port string `db:"port"` + Status string `db:"status"` + manager *SystemManager // Manager that this system belongs to + client *ssh.Client // SSH client for fetching data + data *system.CombinedData // system data from agent + ctx context.Context // Context for stopping the updater + cancel context.CancelFunc // Stops and removes system from updater + WsConn *ws.WsConn // Handler for agent WebSocket connection + agentVersion semver.Version // Agent version + updateTicker *time.Ticker // Ticker for updating the system +} + +func (sm *SystemManager) NewSystem(systemId string) *System { + system := &System{ + Id: systemId, + data: &system.CombinedData{}, + } + system.ctx, system.cancel = system.getContext() + return system +} + +// StartUpdater starts the system updater. +// It first fetches the data from the agent then updates the records. +// If the data is not found or the system is down, it sets the system down. +func (sys *System) StartUpdater() { + // Channel that can be used to set the system down. Currently only used to + // allow a short delay for reconnection after websocket connection is closed. + var downChan chan struct{} + + // Add random jitter to first WebSocket connection to prevent + // clustering if all agents are started at the same time. + // SSH connections during hub startup are already staggered. + var jitter <-chan time.Time + if sys.WsConn != nil { + jitter = getJitter() + // use the websocket connection's down channel to set the system down + downChan = sys.WsConn.DownChan + } else { + // if the system does not have a websocket connection, wait before updating + // to allow the agent to connect via websocket (makes sure fingerprint is set). 
+ time.Sleep(11 * time.Second) + } + + // update immediately if system is not paused (only for ws connections) + // we'll wait a minute before connecting via SSH to prioritize ws connections + if sys.Status != paused && sys.ctx.Err() == nil { + if err := sys.update(); err != nil { + _ = sys.setDown(err) + } + } + + sys.updateTicker = time.NewTicker(time.Duration(interval) * time.Millisecond) + // Go 1.23+ will automatically stop the ticker when the system is garbage collected, however we seem to need this or testing/synctest will block even if calling runtime.GC() + defer sys.updateTicker.Stop() + + for { + select { + case <-sys.ctx.Done(): + return + case <-sys.updateTicker.C: + if err := sys.update(); err != nil { + _ = sys.setDown(err) + } + case <-downChan: + sys.WsConn = nil + downChan = nil + _ = sys.setDown(nil) + case <-jitter: + sys.updateTicker.Reset(time.Duration(interval) * time.Millisecond) + if err := sys.update(); err != nil { + _ = sys.setDown(err) + } + } + } +} + +// update updates the system data and records. +func (sys *System) update() error { + if sys.Status == paused { + sys.handlePaused() + return nil + } + data, err := sys.fetchDataFromAgent() + if err == nil { + _, err = sys.createRecords(data) + } + return err +} + +func (sys *System) handlePaused() { + if sys.WsConn == nil { + // if the system is paused and there's no websocket connection, remove the system + _ = sys.manager.RemoveSystem(sys.Id) + } else { + // Send a ping to the agent to keep the connection alive if the system is paused + if err := sys.WsConn.Ping(); err != nil { + sys.manager.hub.Logger().Warn("Failed to ping agent", "system", sys.Id, "err", err) + _ = sys.manager.RemoveSystem(sys.Id) + } + } +} + +// createRecords updates the system record and adds system_stats and container_stats records +func (sys *System) createRecords(data *system.CombinedData) (*core.Record, error) { + systemRecord, err := sys.getRecord() + if err != nil { + return nil, err + } + hub := sys.manager.hub + // add system_stats and container_stats records + systemStatsCollection, err := hub.FindCachedCollectionByNameOrId("system_stats") + if err != nil { + return nil, err + } + + systemStatsRecord := core.NewRecord(systemStatsCollection) + systemStatsRecord.Set("system", systemRecord.Id) + systemStatsRecord.Set("stats", data.Stats) + systemStatsRecord.Set("type", "1m") + if err := hub.SaveNoValidate(systemStatsRecord); err != nil { + return nil, err + } + // add new container_stats record + if len(data.Containers) > 0 { + containerStatsCollection, err := hub.FindCachedCollectionByNameOrId("container_stats") + if err != nil { + return nil, err + } + containerStatsRecord := core.NewRecord(containerStatsCollection) + containerStatsRecord.Set("system", systemRecord.Id) + containerStatsRecord.Set("stats", data.Containers) + containerStatsRecord.Set("type", "1m") + if err := hub.SaveNoValidate(containerStatsRecord); err != nil { + return nil, err + } + } + // update system record (do this last because it triggers alerts and we need above records to be inserted first) + systemRecord.Set("status", up) + + systemRecord.Set("info", data.Info) + if err := hub.SaveNoValidate(systemRecord); err != nil { + return nil, err + } + return systemRecord, nil +} + +// getRecord retrieves the system record from the database. +// If the record is not found, it removes the system from the manager. 
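+// PocketBase's FindRecordById should surface a missing record as an error, so the nil-record check is only a defensive extra guard.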
+func (sys *System) getRecord() (*core.Record, error) { + record, err := sys.manager.hub.FindRecordById("systems", sys.Id) + if err != nil || record == nil { + _ = sys.manager.RemoveSystem(sys.Id) + return nil, err + } + return record, nil +} + +// setDown marks a system as down in the database. +// It takes the original error that caused the system to go down and returns any error +// encountered during the process of updating the system status. +func (sys *System) setDown(originalError error) error { + if sys.Status == down || sys.Status == paused { + return nil + } + record, err := sys.getRecord() + if err != nil { + return err + } + if originalError != nil { + sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", originalError) + } + record.Set("status", down) + return sys.manager.hub.SaveNoValidate(record) +} + +func (sys *System) getContext() (context.Context, context.CancelFunc) { + if sys.ctx == nil { + sys.ctx, sys.cancel = context.WithCancel(context.Background()) + } + return sys.ctx, sys.cancel +} + +// fetchDataFromAgent attempts to fetch data from the agent, +// prioritizing WebSocket if available. +func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) { + if sys.data == nil { + sys.data = &system.CombinedData{} + } + + if sys.WsConn != nil && sys.WsConn.IsConnected() { + wsData, err := sys.fetchDataViaWebSocket() + if err == nil { + return wsData, nil + } + // close the WebSocket connection if error and try SSH + sys.closeWebSocketConnection() + } + + sshData, err := sys.fetchDataViaSSH() + if err != nil { + return nil, err + } + return sshData, nil +} + +func (sys *System) fetchDataViaWebSocket() (*system.CombinedData, error) { + if sys.WsConn == nil || !sys.WsConn.IsConnected() { + return nil, errors.New("no websocket connection") + } + err := sys.WsConn.RequestSystemData(sys.data) + if err != nil { + return nil, err + } + return sys.data, nil +} + +// fetchDataViaSSH handles fetching data using SSH. +// This function encapsulates the original SSH logic. +// It updates sys.data directly upon successful fetch. +func (sys *System) fetchDataViaSSH() (*system.CombinedData, error) { + maxRetries := 1 + for attempt := 0; attempt <= maxRetries; attempt++ { + if sys.client == nil || sys.Status == down { + if err := sys.createSSHClient(); err != nil { + return nil, err + } + } + + session, err := sys.createSessionWithTimeout(4 * time.Second) + if err != nil { + if attempt >= maxRetries { + return nil, err + } + sys.manager.hub.Logger().Warn("Session closed. 
Retrying...", "host", sys.Host, "port", sys.Port, "err", err) + sys.closeSSHConnection() + // Reset format detection on connection failure - agent might have been upgraded + continue + } + defer session.Close() + + stdout, err := session.StdoutPipe() + if err != nil { + return nil, err + } + if err := session.Shell(); err != nil { + return nil, err + } + + *sys.data = system.CombinedData{} + + if sys.agentVersion.GTE(beszel.MinVersionCbor) { + err = cbor.NewDecoder(stdout).Decode(sys.data) + } else { + err = json.NewDecoder(stdout).Decode(sys.data) + } + + if err != nil { + sys.closeSSHConnection() + if attempt < maxRetries { + continue + } + return nil, err + } + + // wait for the session to complete + if err := session.Wait(); err != nil { + return nil, err + } + + return sys.data, nil + } + + // this should never be reached due to the return in the loop + return nil, fmt.Errorf("failed to fetch data") +} + +// createSSHClient creates a new SSH client for the system +func (s *System) createSSHClient() error { + if s.manager.sshConfig == nil { + if err := s.manager.createSSHClientConfig(); err != nil { + return err + } + } + network := "tcp" + host := s.Host + if strings.HasPrefix(host, "/") { + network = "unix" + } else { + host = net.JoinHostPort(host, s.Port) + } + var err error + s.client, err = ssh.Dial(network, host, s.manager.sshConfig) + if err != nil { + return err + } + s.agentVersion, _ = extractAgentVersion(string(s.client.Conn.ServerVersion())) + return nil +} + +// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging +// in case of network issues +func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) { + if sys.client == nil { + return nil, fmt.Errorf("client not initialized") + } + + ctx, cancel := context.WithTimeout(sys.ctx, timeout) + defer cancel() + + sessionChan := make(chan *ssh.Session, 1) + errChan := make(chan error, 1) + + go func() { + if session, err := sys.client.NewSession(); err != nil { + errChan <- err + } else { + sessionChan <- session + } + }() + + select { + case session := <-sessionChan: + return session, nil + case err := <-errChan: + return nil, err + case <-ctx.Done(): + return nil, fmt.Errorf("timeout") + } +} + +// closeSSHConnection closes the SSH connection but keeps the system in the manager +func (sys *System) closeSSHConnection() { + if sys.client != nil { + sys.client.Close() + sys.client = nil + } +} + +// closeWebSocketConnection closes the WebSocket connection but keeps the system in the manager +// to allow updating via SSH. It will be removed if the WS connection is re-established. +// The system will be set as down a few seconds later if the connection is not re-established. +func (sys *System) closeWebSocketConnection() { + if sys.WsConn != nil { + sys.WsConn.Close() + } +} + +// extractAgentVersion extracts the beszel version from SSH server version string +func extractAgentVersion(versionString string) (semver.Version, error) { + _, after, _ := strings.Cut(versionString, "_") + return semver.Parse(after) +} + +// getJitter returns a channel that will be triggered after a random delay +// between 40% and 90% of the interval. +// This is used to stagger the initial WebSocket connections to prevent clustering. 
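+// For example, with the default 60_000 ms interval the delay lands between 24_000 ms (40%) and just under 54_000 ms (40% plus up to another 50%).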
+func getJitter() <-chan time.Time { + minPercent := 40 + maxPercent := 90 + jitterRange := maxPercent - minPercent + msDelay := (interval * minPercent / 100) + rand.Intn(interval*jitterRange/100) + return time.After(time.Duration(msDelay) * time.Millisecond) +} diff --git a/beszel/internal/hub/systems/system_manager.go b/beszel/internal/hub/systems/system_manager.go new file mode 100644 index 0000000..1ac3431 --- /dev/null +++ b/beszel/internal/hub/systems/system_manager.go @@ -0,0 +1,345 @@ +package systems + +import ( + "beszel" + "beszel/internal/common" + "beszel/internal/entities/system" + "beszel/internal/hub/ws" + "errors" + "fmt" + "time" + + "github.com/blang/semver" + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/store" + "golang.org/x/crypto/ssh" +) + +// System status constants +const ( + up string = "up" // System is online and responding + down string = "down" // System is offline or not responding + paused string = "paused" // System monitoring is paused + pending string = "pending" // System is waiting on initial connection result + + // interval is the default update interval in milliseconds (60 seconds) + interval int = 60_000 + // interval int = 10_000 // Debug interval for faster updates + + // sessionTimeout is the maximum time to wait for SSH connections + sessionTimeout = 4 * time.Second +) + +var ( + // errSystemExists is returned when attempting to add a system that already exists + errSystemExists = errors.New("system exists") +) + +// SystemManager manages a collection of monitored systems and their connections. +// It handles system lifecycle, status updates, and maintains both SSH and WebSocket connections. +type SystemManager struct { + hub hubLike // Hub interface for database and alert operations + systems *store.Store[string, *System] // Thread-safe store of active systems + sshConfig *ssh.ClientConfig // SSH client configuration for system connections +} + +// hubLike defines the interface requirements for the hub dependency. +// It extends core.App with system-specific functionality. +type hubLike interface { + core.App + GetSSHKey(dataDir string) (ssh.Signer, error) + HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error + HandleStatusAlerts(status string, systemRecord *core.Record) error +} + +// NewSystemManager creates a new SystemManager instance with the provided hub. +// The hub must implement the hubLike interface to provide database and alert functionality. +func NewSystemManager(hub hubLike) *SystemManager { + return &SystemManager{ + systems: store.New(map[string]*System{}), + hub: hub, + } +} + +// Initialize sets up the system manager by binding event hooks and starting existing systems. +// It configures SSH client settings and begins monitoring all non-paused systems from the database. +// Systems are started with staggered delays to prevent overwhelming the hub during startup. 
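+// For example, 100 systems are started 600 ms apart (60_000 / 100), while 10 systems are started 2_000 ms apart because of the per-system cap.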
+func (sm *SystemManager) Initialize() error { + sm.bindEventHooks() + + // Initialize SSH client configuration + err := sm.createSSHClientConfig() + if err != nil { + return err + } + + // Load existing systems from database (excluding paused ones) + var systems []*System + err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems) + if err != nil || len(systems) == 0 { + return err + } + + // Start systems in background with staggered timing + go func() { + // Calculate staggered delay between system starts (max 2 seconds per system) + delta := interval / max(1, len(systems)) + delta = min(delta, 2_000) + sleepTime := time.Duration(delta) * time.Millisecond + + for _, system := range systems { + time.Sleep(sleepTime) + _ = sm.AddSystem(system) + } + }() + return nil +} + +// bindEventHooks registers event handlers for system and fingerprint record changes. +// These hooks ensure the system manager stays synchronized with database changes. +func (sm *SystemManager) bindEventHooks() { + sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate) + sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess) + sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate) + sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess) + sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess) + sm.hub.OnRecordAfterUpdateSuccess("fingerprints").BindFunc(sm.onTokenRotated) +} + +// onTokenRotated handles fingerprint token rotation events. +// When a system's authentication token is rotated, any existing WebSocket connection +// must be closed to force re-authentication with the new token. +func (sm *SystemManager) onTokenRotated(e *core.RecordEvent) error { + systemID := e.Record.GetString("system") + system, ok := sm.systems.GetOk(systemID) + if !ok { + return e.Next() + } + // No need to close connection if not connected via websocket + if system.WsConn == nil { + return e.Next() + } + system.setDown(nil) + sm.RemoveSystem(systemID) + return e.Next() +} + +// onRecordCreate is called before a new system record is committed to the database. +// It initializes the record with default values: empty info and pending status. +func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error { + e.Record.Set("info", system.Info{}) + e.Record.Set("status", pending) + return e.Next() +} + +// onRecordAfterCreateSuccess is called after a new system record is successfully created. +// It adds the new system to the manager to begin monitoring. +func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error { + if err := sm.AddRecord(e.Record, nil); err != nil { + e.App.Logger().Error("Error adding record", "err", err) + } + return e.Next() +} + +// onRecordUpdate is called before a system record is updated in the database. +// It clears system info when the status is changed to paused. +func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error { + if e.Record.GetString("status") == paused { + e.Record.Set("info", system.Info{}) + } + return e.Next() +} + +// onRecordAfterUpdateSuccess handles system record updates after they're committed to the database. +// It manages system lifecycle based on status changes and triggers appropriate alerts. 
+// Status transitions are handled as follows: +// - paused: Closes SSH connection and deactivates alerts +// - pending: Starts monitoring (reuses WebSocket if available) +// - up: Triggers system alerts +// - down: Triggers status change alerts +func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error { + newStatus := e.Record.GetString("status") + system, ok := sm.systems.GetOk(e.Record.Id) + // capture the previous status before it is overwritten so the up/down transition check below still sees it + prevStatus := "" + if ok { + prevStatus = system.Status + system.Status = newStatus + } + + switch newStatus { + case paused: + if ok { + // Pause monitoring but keep system in manager for potential resume + system.closeSSHConnection() + } + _ = deactivateAlerts(e.App, e.Record.Id) + return e.Next() + case pending: + // Resume monitoring, preferring existing WebSocket connection + if ok && system.WsConn != nil { + go system.update() + return e.Next() + } + // Start new monitoring session + if err := sm.AddRecord(e.Record, nil); err != nil { + e.App.Logger().Error("Error adding record", "err", err) + } + return e.Next() + } + + // Handle systems not in manager + if !ok { + return sm.AddRecord(e.Record, nil) + } + + // Trigger system alerts when system comes online + if newStatus == up { + if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil { + e.App.Logger().Error("Error handling system alerts", "err", err) + } + } + + // Trigger status change alerts for up/down transitions + if (newStatus == down && prevStatus == up) || (newStatus == up && prevStatus == down) { + if err := sm.hub.HandleStatusAlerts(newStatus, e.Record); err != nil { + e.App.Logger().Error("Error handling status alerts", "err", err) + } + } + return e.Next() +} + +// onRecordAfterDeleteSuccess is called after a system record is successfully deleted. +// It removes the system from the manager and cleans up all associated resources. +func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error { + sm.RemoveSystem(e.Record.Id) + return e.Next() +} + +// AddSystem adds a system to the manager and starts monitoring it. +// It validates required fields, initializes the system context, and starts the update goroutine. +// Returns error if a system with the same ID already exists. +func (sm *SystemManager) AddSystem(sys *System) error { + if sm.systems.Has(sys.Id) { + return errSystemExists + } + if sys.Id == "" || sys.Host == "" { + return errors.New("system missing required fields") + } + + // Initialize system for monitoring + sys.manager = sm + sys.ctx, sys.cancel = sys.getContext() + sys.data = &system.CombinedData{} + sm.systems.Set(sys.Id, sys) + + // Start monitoring in background + go sys.StartUpdater() + return nil +} + +// RemoveSystem removes a system from the manager and cleans up all associated resources. +// It cancels the system's context, closes all connections, and removes it from the store. +// Returns an error if the system is not found. +func (sm *SystemManager) RemoveSystem(systemID string) error { + system, ok := sm.systems.GetOk(systemID) + if !ok { + return errors.New("system not found") + } + + // Stop the update goroutine + if system.cancel != nil { + system.cancel() + } + + // Clean up all connections + system.closeSSHConnection() + system.closeWebSocketConnection() + sm.systems.Remove(systemID) + return nil +} + +// AddRecord creates a System instance from a database record and adds it to the manager. +// If a system with the same ID already exists, it's removed first to ensure clean state. +// If no system instance is provided, a new one is created. 
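+// In both cases the system is populated with the record's status, host, and port before monitoring starts.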
+// This method is typically called when systems are created or their status changes to pending. +func (sm *SystemManager) AddRecord(record *core.Record, system *System) (err error) { + // Remove existing system to ensure clean state + if sm.systems.Has(record.Id) { + _ = sm.RemoveSystem(record.Id) + } + + // Create new system if none provided + if system == nil { + system = sm.NewSystem(record.Id) + } + + // Populate system from record + system.Status = record.GetString("status") + system.Host = record.GetString("host") + system.Port = record.GetString("port") + + return sm.AddSystem(system) +} + +// AddWebSocketSystem creates and adds a system with an established WebSocket connection. +// This method is called when an agent connects via WebSocket with valid authentication. +// The system is immediately added to monitoring with the provided connection and version info. +func (sm *SystemManager) AddWebSocketSystem(systemId string, agentVersion semver.Version, wsConn *ws.WsConn) error { + systemRecord, err := sm.hub.FindRecordById("systems", systemId) + if err != nil { + return err + } + + system := sm.NewSystem(systemId) + system.WsConn = wsConn + system.agentVersion = agentVersion + + if err := sm.AddRecord(systemRecord, system); err != nil { + return err + } + return nil +} + +// createSSHClientConfig initializes the SSH client configuration for connecting to an agent's server +func (sm *SystemManager) createSSHClientConfig() error { + privateKey, err := sm.hub.GetSSHKey("") + if err != nil { + return err + } + + sm.sshConfig = &ssh.ClientConfig{ + User: "u", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(privateKey), + }, + Config: ssh.Config{ + Ciphers: common.DefaultCiphers, + KeyExchanges: common.DefaultKeyExchanges, + MACs: common.DefaultMACs, + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + ClientVersion: fmt.Sprintf("SSH-2.0-%s_%s", beszel.AppName, beszel.Version), + Timeout: sessionTimeout, + } + return nil +} + +// deactivateAlerts finds all triggered alerts for a system and sets them to inactive. +// This is called when a system is paused or goes offline to prevent continued alerts. 
+func deactivateAlerts(app core.App, systemID string) error { + // Note: Direct SQL updates don't trigger SSE, so we use the PocketBase API + // _, err := app.DB().NewQuery(fmt.Sprintf("UPDATE alerts SET triggered = false WHERE system = '%s'", systemID)).Execute() + + alerts, err := app.FindRecordsByFilter("alerts", fmt.Sprintf("system = '%s' && triggered = 1", systemID), "", -1, 0) + if err != nil { + return err + } + + for _, alert := range alerts { + alert.Set("triggered", false) + if err := app.SaveNoValidate(alert); err != nil { + return err + } + } + return nil +} diff --git a/beszel/internal/hub/systems/systems.go b/beszel/internal/hub/systems/systems.go deleted file mode 100644 index e2bb2a2..0000000 --- a/beszel/internal/hub/systems/systems.go +++ /dev/null @@ -1,457 +0,0 @@ -package systems - -import ( - "beszel/internal/common" - "beszel/internal/entities/system" - "context" - "fmt" - "net" - "strings" - "time" - - "github.com/goccy/go-json" - "github.com/pocketbase/pocketbase/core" - "github.com/pocketbase/pocketbase/tools/store" - "golang.org/x/crypto/ssh" -) - -const ( - up string = "up" - down string = "down" - paused string = "paused" - pending string = "pending" - - interval int = 60_000 - - sessionTimeout = 4 * time.Second -) - -type SystemManager struct { - hub hubLike - systems *store.Store[string, *System] - sshConfig *ssh.ClientConfig -} - -type System struct { - Id string `db:"id"` - Host string `db:"host"` - Port string `db:"port"` - Status string `db:"status"` - manager *SystemManager - client *ssh.Client - data *system.CombinedData - ctx context.Context - cancel context.CancelFunc -} - -type hubLike interface { - core.App - GetSSHKey(dataDir string) (ssh.Signer, error) - HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error - HandleStatusAlerts(status string, systemRecord *core.Record) error -} - -func NewSystemManager(hub hubLike) *SystemManager { - return &SystemManager{ - systems: store.New(map[string]*System{}), - hub: hub, - } -} - -// Initialize initializes the system manager. -// It binds the event hooks and starts updating existing systems. 
-func (sm *SystemManager) Initialize() error { - sm.bindEventHooks() - // ssh setup - err := sm.createSSHClientConfig() - if err != nil { - return err - } - // start updating existing systems - var systems []*System - err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems) - if err != nil || len(systems) == 0 { - return err - } - go func() { - // time between initial system updates - delta := interval / max(1, len(systems)) - delta = min(delta, 2_000) - sleepTime := time.Duration(delta) * time.Millisecond - for _, system := range systems { - time.Sleep(sleepTime) - _ = sm.AddSystem(system) - } - }() - return nil -} - -func (sm *SystemManager) bindEventHooks() { - sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate) - sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess) - sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate) - sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess) - sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess) -} - -// Runs before the record is committed to the database -func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error { - e.Record.Set("info", system.Info{}) - e.Record.Set("status", pending) - return e.Next() -} - -// Runs after the record is committed to the database -func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error { - if err := sm.AddRecord(e.Record); err != nil { - e.App.Logger().Error("Error adding record", "err", err) - } - return e.Next() -} - -// Runs before the record is updated -func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error { - if e.Record.GetString("status") == paused { - e.Record.Set("info", system.Info{}) - } - return e.Next() -} - -// Runs after the record is updated -func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error { - newStatus := e.Record.GetString("status") - switch newStatus { - case paused: - _ = sm.RemoveSystem(e.Record.Id) - _ = deactivateAlerts(e.App, e.Record.Id) - return e.Next() - case pending: - if err := sm.AddRecord(e.Record); err != nil { - e.App.Logger().Error("Error adding record", "err", err) - } - return e.Next() - } - system, ok := sm.systems.GetOk(e.Record.Id) - if !ok { - return sm.AddRecord(e.Record) - } - prevStatus := system.Status - system.Status = newStatus - // system alerts if system is up - if system.Status == up { - if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil { - e.App.Logger().Error("Error handling system alerts", "err", err) - } - } - if (system.Status == down && prevStatus == up) || (system.Status == up && prevStatus == down) { - if err := sm.hub.HandleStatusAlerts(system.Status, e.Record); err != nil { - e.App.Logger().Error("Error handling status alerts", "err", err) - } - } - return e.Next() -} - -// Runs after the record is deleted -func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error { - sm.RemoveSystem(e.Record.Id) - return e.Next() -} - -// AddSystem adds a system to the manager -func (sm *SystemManager) AddSystem(sys *System) error { - if sm.systems.Has(sys.Id) { - return fmt.Errorf("system exists") - } - if sys.Id == "" || sys.Host == "" { - return fmt.Errorf("system is missing required fields") - } - sys.manager = sm - sys.ctx, sys.cancel = context.WithCancel(context.Background()) - sys.data = &system.CombinedData{} - sm.systems.Set(sys.Id, sys) - go sys.StartUpdater() - return nil -} - -// 
RemoveSystem removes a system from the manager -func (sm *SystemManager) RemoveSystem(systemID string) error { - system, ok := sm.systems.GetOk(systemID) - if !ok { - return fmt.Errorf("system not found") - } - // cancel the context to signal stop - if system.cancel != nil { - system.cancel() - } - system.resetSSHClient() - sm.systems.Remove(systemID) - return nil -} - -// AddRecord adds a record to the system manager. -// It first removes any existing system with the same ID, then creates a new System -// instance from the record data and adds it to the manager. -// This function is typically called when a new system is created or when an existing -// system's status changes to pending. -func (sm *SystemManager) AddRecord(record *core.Record) (err error) { - _ = sm.RemoveSystem(record.Id) - system := &System{ - Id: record.Id, - Status: record.GetString("status"), - Host: record.GetString("host"), - Port: record.GetString("port"), - } - return sm.AddSystem(system) -} - -// StartUpdater starts the system updater. -// It first fetches the data from the agent then updates the records. -// If the data is not found or the system is down, it sets the system down. -func (sys *System) StartUpdater() { - if sys.data == nil { - sys.data = &system.CombinedData{} - } - if err := sys.update(); err != nil { - _ = sys.setDown(err) - } - - c := time.Tick(time.Duration(interval) * time.Millisecond) - - for { - select { - case <-sys.ctx.Done(): - return - case <-c: - err := sys.update() - if err != nil { - _ = sys.setDown(err) - } - } - } -} - -// update updates the system data and records. -// It first fetches the data from the agent then updates the records. -func (sys *System) update() error { - _, err := sys.fetchDataFromAgent() - if err == nil { - _, err = sys.createRecords() - } - return err -} - -// createRecords updates the system record and adds system_stats and container_stats records -func (sys *System) createRecords() (*core.Record, error) { - systemRecord, err := sys.getRecord() - if err != nil { - return nil, err - } - hub := sys.manager.hub - // add system_stats and container_stats records - systemStats, err := hub.FindCachedCollectionByNameOrId("system_stats") - if err != nil { - return nil, err - } - systemStatsRecord := core.NewRecord(systemStats) - systemStatsRecord.Set("system", systemRecord.Id) - systemStatsRecord.Set("stats", sys.data.Stats) - systemStatsRecord.Set("type", "1m") - if err := hub.SaveNoValidate(systemStatsRecord); err != nil { - return nil, err - } - // add new container_stats record - if len(sys.data.Containers) > 0 { - containerStats, err := hub.FindCachedCollectionByNameOrId("container_stats") - if err != nil { - return nil, err - } - containerStatsRecord := core.NewRecord(containerStats) - containerStatsRecord.Set("system", systemRecord.Id) - containerStatsRecord.Set("stats", sys.data.Containers) - containerStatsRecord.Set("type", "1m") - if err := hub.SaveNoValidate(containerStatsRecord); err != nil { - return nil, err - } - } - // update system record (do this last because it triggers alerts and we need above records to be inserted first) - systemRecord.Set("status", up) - systemRecord.Set("info", sys.data.Info) - if err := hub.SaveNoValidate(systemRecord); err != nil { - return nil, err - } - return systemRecord, nil -} - -// getRecord retrieves the system record from the database. -// If the record is not found or the system is paused, it removes the system from the manager. 
-func (sys *System) getRecord() (*core.Record, error) { - record, err := sys.manager.hub.FindRecordById("systems", sys.Id) - if err != nil || record == nil { - _ = sys.manager.RemoveSystem(sys.Id) - return nil, err - } - return record, nil -} - -// setDown marks a system as down in the database. -// It takes the original error that caused the system to go down and returns any error -// encountered during the process of updating the system status. -func (sys *System) setDown(OriginalError error) error { - if sys.Status == down { - return nil - } - record, err := sys.getRecord() - if err != nil { - return err - } - sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", OriginalError) - record.Set("status", down) - err = sys.manager.hub.SaveNoValidate(record) - if err != nil { - return err - } - return nil -} - -// fetchDataFromAgent fetches the data from the agent. -// It first creates a new SSH client if it doesn't exist or the system is down. -// Then it creates a new SSH session and fetches the data from the agent. -// If the data is not found or the system is down, it sets the system down. -func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) { - maxRetries := 1 - for attempt := 0; attempt <= maxRetries; attempt++ { - if sys.client == nil || sys.Status == down { - if err := sys.createSSHClient(); err != nil { - return nil, err - } - } - - session, err := sys.createSessionWithTimeout(4 * time.Second) - if err != nil { - if attempt >= maxRetries { - return nil, err - } - sys.manager.hub.Logger().Warn("Session closed. Retrying...", "host", sys.Host, "port", sys.Port, "err", err) - sys.resetSSHClient() - continue - } - defer session.Close() - - stdout, err := session.StdoutPipe() - if err != nil { - return nil, err - } - if err := session.Shell(); err != nil { - return nil, err - } - - // this is initialized in startUpdater, should never be nil - *sys.data = system.CombinedData{} - if err := json.NewDecoder(stdout).Decode(sys.data); err != nil { - return nil, err - } - // wait for the session to complete - if err := session.Wait(); err != nil { - return nil, err - } - return sys.data, nil - } - - // this should never be reached due to the return in the loop - return nil, fmt.Errorf("failed to fetch data") -} - -// createSSHClientConfig initializes the ssh config for the system manager -func (sm *SystemManager) createSSHClientConfig() error { - privateKey, err := sm.hub.GetSSHKey(sm.hub.DataDir()) - if err != nil { - return err - } - sm.sshConfig = &ssh.ClientConfig{ - User: "u", - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(privateKey), - }, - Config: ssh.Config{ - Ciphers: common.DefaultCiphers, - KeyExchanges: common.DefaultKeyExchanges, - MACs: common.DefaultMACs, - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Timeout: sessionTimeout, - } - return nil -} - -// createSSHClient creates a new SSH client for the system -func (s *System) createSSHClient() error { - network := "tcp" - host := s.Host - if strings.HasPrefix(host, "/") { - network = "unix" - } else { - host = net.JoinHostPort(host, s.Port) - } - var err error - s.client, err = ssh.Dial(network, host, s.manager.sshConfig) - if err != nil { - return err - } - return nil -} - -// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging -// in case of network issues -func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) { - if sys.client == nil { - return nil, fmt.Errorf("client not initialized") - } - - ctx, cancel := 
context.WithTimeout(sys.ctx, timeout) - defer cancel() - - sessionChan := make(chan *ssh.Session, 1) - errChan := make(chan error, 1) - - go func() { - if session, err := sys.client.NewSession(); err != nil { - errChan <- err - } else { - sessionChan <- session - } - }() - - select { - case session := <-sessionChan: - return session, nil - case err := <-errChan: - return nil, err - case <-ctx.Done(): - return nil, fmt.Errorf("timeout") - } -} - -// resetSSHClient closes the SSH connection and resets the client to nil -func (sys *System) resetSSHClient() { - if sys.client != nil { - sys.client.Close() - } - sys.client = nil -} - -// deactivateAlerts finds all triggered alerts for a system and sets them to false -func deactivateAlerts(app core.App, systemID string) error { - // we can't use an UPDATE query because it doesn't work with realtime updates - // _, err := e.App.DB().NewQuery(fmt.Sprintf("UPDATE alerts SET triggered = false WHERE system = '%s'", e.Record.Id)).Execute() - alerts, err := app.FindRecordsByFilter("alerts", fmt.Sprintf("system = '%s' && triggered = 1", systemID), "", -1, 0) - if err != nil { - return err - } - for _, alert := range alerts { - alert.Set("triggered", false) - if err := app.SaveNoValidate(alert); err != nil { - return err - } - } - return nil -} diff --git a/beszel/internal/hub/systems/systems_test.go b/beszel/internal/hub/systems/systems_test.go index aa4d803..107415c 100644 --- a/beszel/internal/hub/systems/systems_test.go +++ b/beszel/internal/hub/systems/systems_test.go @@ -11,70 +11,133 @@ import ( "fmt" "sync" "testing" + "testing/synctest" "time" - "github.com/pocketbase/dbx" - "github.com/pocketbase/pocketbase/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// createTestSystem creates a test system record with a unique host name -// and returns the created record and any error -func createTestSystem(t *testing.T, hub *tests.TestHub, options map[string]any) (*core.Record, error) { - collection, err := hub.FindCachedCollectionByNameOrId("systems") - if err != nil { - return nil, err - } - - // get user record - var firstUser *core.Record - users, err := hub.FindAllRecords("users", dbx.NewExp("id != ''")) - if err != nil { - t.Fatal(err) - } - if len(users) > 0 { - firstUser = users[0] - } - // Generate a unique host name to ensure we're adding a new system - uniqueHost := fmt.Sprintf("test-host-%d.example.com", time.Now().UnixNano()) - - // Create the record - record := core.NewRecord(collection) - record.Set("name", uniqueHost) - record.Set("host", uniqueHost) - record.Set("port", "45876") - record.Set("status", "pending") - record.Set("users", []string{firstUser.Id}) - - // Apply any custom options - for key, value := range options { - record.Set(key, value) - } - - // Save the record to the database - err = hub.Save(record) - if err != nil { - return nil, err - } - - return record, nil -} - -func TestSystemManagerIntegration(t *testing.T) { - // Create a test hub - hub, err := tests.NewTestHub() +func TestSystemManagerNew(t *testing.T) { + hub, err := tests.NewTestHub(t.TempDir()) if err != nil { t.Fatal(err) } defer hub.Cleanup() + sm := hub.GetSystemManager() - // Create independent system manager - sm := systems.NewSystemManager(hub) + user, err := tests.CreateUser(hub, "test@test.com", "testtesttest") + require.NoError(t, err) + + synctest.Run(func() { + sm.Initialize() + + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "it-was-coney-island", + "host": 
"the-playground-of-the-world", + "port": "33914", + "users": []string{user.Id}, + }) + require.NoError(t, err) + + assert.Equal(t, "pending", record.GetString("status"), "System status should be 'pending'") + assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'") + + // Verify the system host and port + host, port := sm.GetSystemHostPort(record.Id) + assert.Equal(t, record.GetString("host"), host, "System host should match") + assert.Equal(t, record.GetString("port"), port, "System port should match") + + time.Sleep(13 * time.Second) + synctest.Wait() + + assert.Equal(t, "pending", record.Fresh().GetString("status"), "System status should be 'pending'") + // Verify the system was added by checking if it exists + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") + + time.Sleep(10 * time.Second) + synctest.Wait() + + // system should be set to down after 15 seconds (no websocket connection) + assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'") + // make sure the system is down in the db + record, err = hub.FindRecordById("systems", record.Id) + require.NoError(t, err) + assert.Equal(t, "down", record.GetString("status"), "System status should be 'down'") + + assert.Equal(t, 1, sm.GetSystemCount(), "System count should be 1") + + err = sm.RemoveSystem(record.Id) + assert.NoError(t, err) + + assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0") + assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal") + + // let's also make sure a system is removed from the store when the record is deleted + record, err = tests.CreateRecord(hub, "systems", map[string]any{ + "name": "there-was-no-place-like-it", + "host": "in-the-whole-world", + "port": "33914", + "users": []string{user.Id}, + }) + require.NoError(t, err) + + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store after creation") + + time.Sleep(8 * time.Second) + synctest.Wait() + assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'") + + sm.SetSystemStatusInDB(record.Id, "up") + time.Sleep(time.Second) + synctest.Wait() + assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'") + + // make sure the system switches to down after 11 seconds + sm.RemoveSystem(record.Id) + sm.AddRecord(record, nil) + assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'") + time.Sleep(12 * time.Second) + synctest.Wait() + assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'") + + // sm.SetSystemStatusInDB(record.Id, "paused") + // time.Sleep(time.Second) + // synctest.Wait() + // assert.Equal(t, "paused", sm.GetSystemStatusFromStore(record.Id), "System status should be 'paused'") + + // delete the record + err = hub.Delete(record) + require.NoError(t, err) + assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion") + + testOld(t, hub) + + time.Sleep(time.Second) + synctest.Wait() + + for _, systemId := range sm.GetAllSystemIDs() { + err = sm.RemoveSystem(systemId) + require.NoError(t, err) + assert.False(t, sm.HasSystem(systemId), "System should not exist in the store after deletion") + } + + assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0") + + // TODO: test with websocket client + }) +} + +func testOld(t *testing.T, hub *tests.TestHub) { + 
user, err := tests.CreateUser(hub, "test@testy.com", "testtesttest") + require.NoError(t, err) + + sm := hub.GetSystemManager() assert.NotNil(t, sm) - // Test initialization - sm.Initialize() + // error expected when creating a user with a duplicate email + _, err = tests.CreateUser(hub, "test@test.com", "testtesttest") + require.Error(t, err) // Test collection existence. todo: move to hub package tests t.Run("CollectionExistence", func(t *testing.T) { @@ -92,81 +155,17 @@ func TestSystemManagerIntegration(t *testing.T) { assert.NotNil(t, containerStats) }) - // Test adding a system record - t.Run("AddRecord", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - // Get the count before adding the system - countBefore := sm.GetSystemCount() - - // record should be pending on create - hub.OnRecordCreate("systems").BindFunc(func(e *core.RecordEvent) error { - record := e.Record - if record.GetString("name") == "welcometoarcoampm" { - assert.Equal(t, "pending", e.Record.GetString("status"), "System status should be 'pending'") - wg.Done() - } - return e.Next() - }) - - // record should be down on update - hub.OnRecordAfterUpdateSuccess("systems").BindFunc(func(e *core.RecordEvent) error { - record := e.Record - if record.GetString("name") == "welcometoarcoampm" { - assert.Equal(t, "down", e.Record.GetString("status"), "System status should be 'pending'") - wg.Done() - } - return e.Next() - }) - // Create a test system with the first user assigned - record, err := createTestSystem(t, hub, map[string]any{ - "name": "welcometoarcoampm", - "host": "localhost", - "port": "33914", - }) - require.NoError(t, err) - - wg.Wait() - - // system should be down if grabbed from the store - assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'") - - // Check that the system count increased - countAfter := sm.GetSystemCount() - assert.Equal(t, countBefore+1, countAfter, "System count should increase after adding a system via event hook") - - // Verify the system was added by checking if it exists - assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") - - // Verify the system host and port - host, port := sm.GetSystemHostPort(record.Id) - assert.Equal(t, record.Get("host"), host, "System host should match") - assert.Equal(t, record.Get("port"), port, "System port should match") - - // Verify the system is in the list of all system IDs - ids := sm.GetAllSystemIDs() - assert.Contains(t, ids, record.Id, "System ID should be in the list of all system IDs") - - // Verify the system was added by checking if removing it works - err = sm.RemoveSystem(record.Id) - assert.NoError(t, err, "System should exist and be removable") - - // Verify the system no longer exists - assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal") - - // Verify the system is not in the list of all system IDs - newIds := sm.GetAllSystemIDs() - assert.NotContains(t, newIds, record.Id, "System ID should not be in the list of all system IDs after removal") - - }) - t.Run("RemoveSystem", func(t *testing.T) { // Get the count before adding the system countBefore := sm.GetSystemCount() // Create a test system record - record, err := createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "i-even-got-lost-at-coney-island", + "host": "but-they-found-me", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Verify the system count increased @@ -202,11 
+201,16 @@ func TestSystemManagerIntegration(t *testing.T) { t.Run("NewRecordPending", func(t *testing.T) { // Create a test system - record, err := createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "and-you-know", + "host": "i-feel-very-bad", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Add the record to the system manager - err = sm.AddRecord(record) + err = sm.AddRecord(record, nil) require.NoError(t, err) // Test filtering records by status - should be "pending" now @@ -218,11 +222,16 @@ func TestSystemManagerIntegration(t *testing.T) { t.Run("SystemStatusUpdate", func(t *testing.T) { // Create a test system record - record, err := createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "we-used-to-sleep-on-the-beach", + "host": "sleep-overnight-here", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Add the record to the system manager - err = sm.AddRecord(record) + err = sm.AddRecord(record, nil) require.NoError(t, err) // Test status changes @@ -244,7 +253,12 @@ func TestSystemManagerIntegration(t *testing.T) { t.Run("HandleSystemData", func(t *testing.T) { // Create a test system record - record, err := createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "things-changed-you-know", + "host": "they-dont-sleep-anymore-on-the-beach", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Create test system data @@ -295,54 +309,14 @@ func TestSystemManagerIntegration(t *testing.T) { assert.Error(t, err) }) - t.Run("DeleteRecord", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - runs := 0 - - hub.OnRecordUpdate("systems").BindFunc(func(e *core.RecordEvent) error { - runs++ - record := e.Record - if record.GetString("name") == "deadflagblues" { - if runs == 1 { - assert.Equal(t, "up", e.Record.GetString("status"), "System status should be 'up'") - wg.Done() - } else if runs == 2 { - assert.Equal(t, "paused", e.Record.GetString("status"), "System status should be 'paused'") - wg.Done() - } - } - return e.Next() - }) - - // Create a test system record - record, err := createTestSystem(t, hub, map[string]any{ - "name": "deadflagblues", - }) - require.NoError(t, err) - - // Verify the system exists - assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") - - // set the status manually to up - sm.SetSystemStatusInDB(record.Id, "up") - - // verify the status is up - assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'") - - // Set the status to "paused" which should cause it to be deleted from the store - sm.SetSystemStatusInDB(record.Id, "paused") - - wg.Wait() - - // Verify the system no longer exists - assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion") - }) - t.Run("ConcurrentOperations", func(t *testing.T) { // Create a test system - record, err := createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "jfkjahkfajs", + "host": "localhost", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Run concurrent operations @@ -377,7 +351,12 @@ func TestSystemManagerIntegration(t *testing.T) { t.Run("ContextCancellation", func(t *testing.T) { // Create a test system record - record, err := 
createTestSystem(t, hub, map[string]any{}) + record, err := tests.CreateRecord(hub, "systems", map[string]any{ + "name": "lkhsdfsjf", + "host": "localhost", + "port": "33914", + "users": []string{user.Id}, + }) require.NoError(t, err) // Verify the system exists in the store @@ -420,7 +399,7 @@ func TestSystemManagerIntegration(t *testing.T) { assert.Error(t, err, "RemoveSystem should fail for non-existent system") // Add the system back - err = sm.AddRecord(record) + err = sm.AddRecord(record, nil) require.NoError(t, err, "AddRecord should succeed") // Verify the system is back in the store diff --git a/beszel/internal/hub/systems/systems_test_helpers.go b/beszel/internal/hub/systems/systems_test_helpers.go index 9822343..e74aa29 100644 --- a/beszel/internal/hub/systems/systems_test_helpers.go +++ b/beszel/internal/hub/systems/systems_test_helpers.go @@ -9,17 +9,17 @@ import ( "fmt" ) -// GetSystemCount returns the number of systems in the store +// TESTING ONLY: GetSystemCount returns the number of systems in the store func (sm *SystemManager) GetSystemCount() int { return sm.systems.Length() } -// HasSystem checks if a system with the given ID exists in the store +// TESTING ONLY: HasSystem checks if a system with the given ID exists in the store func (sm *SystemManager) HasSystem(systemID string) bool { return sm.systems.Has(systemID) } -// GetSystemStatusFromStore returns the status of a system with the given ID +// TESTING ONLY: GetSystemStatusFromStore returns the status of a system with the given ID // Returns an empty string if the system doesn't exist func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string { sys, ok := sm.systems.GetOk(systemID) @@ -29,7 +29,7 @@ func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string { return sys.Status } -// GetSystemContextFromStore returns the context and cancel function for a system +// TESTING ONLY: GetSystemContextFromStore returns the context and cancel function for a system func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Context, context.CancelFunc, error) { sys, ok := sm.systems.GetOk(systemID) if !ok { @@ -38,7 +38,7 @@ func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Con return sys.ctx, sys.cancel, nil } -// GetSystemFromStore returns a store from the system +// TESTING ONLY: GetSystemFromStore returns a store from the system func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) { sys, ok := sm.systems.GetOk(systemID) if !ok { @@ -47,7 +47,7 @@ func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) { return sys, nil } -// GetAllSystemIDs returns a slice of all system IDs in the store +// TESTING ONLY: GetAllSystemIDs returns a slice of all system IDs in the store func (sm *SystemManager) GetAllSystemIDs() []string { data := sm.systems.GetAll() ids := make([]string, 0, len(data)) @@ -57,7 +57,7 @@ func (sm *SystemManager) GetAllSystemIDs() []string { return ids } -// GetSystemData returns the combined data for a system with the given ID +// TESTING ONLY: GetSystemData returns the combined data for a system with the given ID // Returns nil if the system doesn't exist // This method is intended for testing func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData { @@ -68,7 +68,7 @@ func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData { return sys.data } -// GetSystemHostPort returns the host and port for a system with the given ID +// TESTING ONLY: 
GetSystemHostPort returns the host and port for a system with the given ID // Returns empty strings if the system doesn't exist func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) { sys, ok := sm.systems.GetOk(systemID) @@ -78,22 +78,7 @@ func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) { return sys.Host, sys.Port } -// DisableAutoUpdater disables the automatic updater for a system -// This is intended for testing -// Returns false if the system doesn't exist -// func (sm *SystemManager) DisableAutoUpdater(systemID string) bool { -// sys, ok := sm.systems.GetOk(systemID) -// if !ok { -// return false -// } -// if sys.cancel != nil { -// sys.cancel() -// sys.cancel = nil -// } -// return true -// } - -// SetSystemStatusInDB sets the status of a system directly and updates the database record +// TESTING ONLY: SetSystemStatusInDB sets the status of a system directly and updates the database record // This is intended for testing // Returns false if the system doesn't exist func (sm *SystemManager) SetSystemStatusInDB(systemID string, status string) bool { diff --git a/beszel/internal/hub/ws/ws.go b/beszel/internal/hub/ws/ws.go new file mode 100644 index 0000000..de05f9d --- /dev/null +++ b/beszel/internal/hub/ws/ws.go @@ -0,0 +1,181 @@ +package ws + +import ( + "beszel/internal/common" + "beszel/internal/entities/system" + "errors" + "time" + "weak" + + "github.com/fxamacker/cbor/v2" + "github.com/lxzan/gws" + "golang.org/x/crypto/ssh" +) + +const ( + deadline = 70 * time.Second +) + +// Handler implements the WebSocket event handler for agent connections. +type Handler struct { + gws.BuiltinEventHandler +} + +// WsConn represents a WebSocket connection to an agent. +type WsConn struct { + conn *gws.Conn + responseChan chan *gws.Message + DownChan chan struct{} +} + +// FingerprintRecord is fingerprints collection record data in the hub +type FingerprintRecord struct { + Id string `db:"id"` + SystemId string `db:"system"` + Fingerprint string `db:"fingerprint"` + Token string `db:"token"` +} + +var upgrader *gws.Upgrader + +// GetUpgrader returns a singleton WebSocket upgrader instance. +func GetUpgrader() *gws.Upgrader { + if upgrader != nil { + return upgrader + } + handler := &Handler{} + upgrader = gws.NewUpgrader(handler, &gws.ServerOption{}) + return upgrader +} + +// NewWsConnection creates a new WebSocket connection wrapper. +func NewWsConnection(conn *gws.Conn) *WsConn { + return &WsConn{ + conn: conn, + responseChan: make(chan *gws.Message, 1), + DownChan: make(chan struct{}, 1), + } +} + +// OnOpen sets a deadline for the WebSocket connection. +func (h *Handler) OnOpen(conn *gws.Conn) { + conn.SetDeadline(time.Now().Add(deadline)) +} + +// OnMessage routes incoming WebSocket messages to the response channel. +func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) { + conn.SetDeadline(time.Now().Add(deadline)) + if message.Opcode != gws.OpcodeBinary || message.Data.Len() == 0 { + return + } + wsConn, ok := conn.Session().Load("wsConn") + if !ok { + _ = conn.WriteClose(1000, nil) + return + } + select { + case wsConn.(*WsConn).responseChan <- message: + default: + // close if the connection is not expecting a response + wsConn.(*WsConn).Close() + } +} + +// OnClose handles WebSocket connection closures and triggers system down status after delay. 
+func (h *Handler) OnClose(conn *gws.Conn, err error) { + wsConn, ok := conn.Session().Load("wsConn") + if !ok { + return + } + wsConn.(*WsConn).conn = nil + // wait 5 seconds to allow reconnection before setting system down + // use a weak pointer to avoid keeping references if the system is removed + go func(downChan weak.Pointer[chan struct{}]) { + time.Sleep(5 * time.Second) + downChanValue := downChan.Value() + if downChanValue != nil { + *downChanValue <- struct{}{} + } + }(weak.Make(&wsConn.(*WsConn).DownChan)) +} + +// Close terminates the WebSocket connection gracefully. +func (ws *WsConn) Close() { + if ws.IsConnected() { + ws.conn.WriteClose(1000, nil) + } +} + +// Ping sends a ping frame to keep the connection alive. +func (ws *WsConn) Ping() error { + ws.conn.SetDeadline(time.Now().Add(deadline)) + return ws.conn.WritePing(nil) +} + +// sendMessage encodes data to CBOR and sends it as a binary message to the agent. +func (ws *WsConn) sendMessage(data common.HubRequest[any]) error { + bytes, err := cbor.Marshal(data) + if err != nil { + return err + } + return ws.conn.WriteMessage(gws.OpcodeBinary, bytes) +} + +// RequestSystemData requests system metrics from the agent and unmarshals the response. +func (ws *WsConn) RequestSystemData(data *system.CombinedData) error { + var message *gws.Message + + ws.sendMessage(common.HubRequest[any]{ + Action: common.GetData, + }) + select { + case <-time.After(10 * time.Second): + ws.Close() + return gws.ErrConnClosed + case message = <-ws.responseChan: + } + defer message.Close() + return cbor.Unmarshal(message.Data.Bytes(), data) +} + +// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint. +func (ws *WsConn) GetFingerprint(token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) { + challenge := []byte(token) + + signature, err := signer.Sign(nil, challenge) + if err != nil { + return common.FingerprintResponse{}, err + } + + err = ws.sendMessage(common.HubRequest[any]{ + Action: common.CheckFingerprint, + Data: common.FingerprintRequest{ + Signature: signature.Blob, + NeedSysInfo: needSysInfo, + }, + }) + if err != nil { + return common.FingerprintResponse{}, err + } + + var message *gws.Message + var clientFingerprint common.FingerprintResponse + select { + case message = <-ws.responseChan: + case <-time.After(10 * time.Second): + return common.FingerprintResponse{}, errors.New("request expired") + } + defer message.Close() + + err = cbor.Unmarshal(message.Data.Bytes(), &clientFingerprint) + if err != nil { + return common.FingerprintResponse{}, err + } + + return clientFingerprint, nil +} + +// IsConnected returns true if the WebSocket connection is active. +func (ws *WsConn) IsConnected() bool { + return ws.conn != nil +} diff --git a/beszel/internal/tests/hub.go b/beszel/internal/tests/hub.go index 9ac3556..914f1ee 100644 --- a/beszel/internal/tests/hub.go +++ b/beszel/internal/tests/hub.go @@ -1,3 +1,6 @@ +//go:build testing +// +build testing + // Package tests provides helpers for testing the application. 
package tests @@ -56,3 +59,30 @@ func NewTestHubWithConfig(config core.BaseAppConfig) (*TestHub, error) { return t, nil } + +// Helper function to create a test user for config tests +func CreateUser(app core.App, email string, password string) (*core.Record, error) { + userCollection, err := app.FindCachedCollectionByNameOrId("users") + if err != nil { + return nil, err + } + + user := core.NewRecord(userCollection) + user.Set("email", email) + user.Set("password", password) + + return user, app.Save(user) +} + +// Helper function to create a test record +func CreateRecord(app core.App, collectionName string, fields map[string]any) (*core.Record, error) { + collection, err := app.FindCachedCollectionByNameOrId(collectionName) + if err != nil { + return nil, err + } + + record := core.NewRecord(collection) + record.Load(fields) + + return record, app.Save(record) +} diff --git a/beszel/migrations/collections_snapshot_0_10_2.go b/beszel/migrations/collections_snapshot_0_12_0.go similarity index 82% rename from beszel/migrations/collections_snapshot_0_10_2.go rename to beszel/migrations/collections_snapshot_0_12_0.go index aecbb8a..4ab1f78 100644 --- a/beszel/migrations/collections_snapshot_0_10_2.go +++ b/beszel/migrations/collections_snapshot_0_12_0.go @@ -1,23 +1,14 @@ package migrations import ( + "github.com/google/uuid" "github.com/pocketbase/pocketbase/core" m "github.com/pocketbase/pocketbase/migrations" ) func init() { m.Register(func(app core.App) error { - // delete duplicate alerts - app.DB().NewQuery(` - DELETE FROM alerts - WHERE rowid NOT IN ( - SELECT MAX(rowid) - FROM alerts - GROUP BY user, system, name - ); - `).Execute() - - // import collections + // update collections jsonData := `[ { "id": "elngm8x1l60zi2v", @@ -236,6 +227,88 @@ func init() { ], "system": false }, + { + "id": "pbc_3663931638", + "listRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id", + "viewRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id", + "createRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id && @request.auth.role != \"readonly\"", + "updateRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id && @request.auth.role != \"readonly\"", + "deleteRule": null, + "name": "fingerprints", + "type": "base", + "fields": [ + { + "autogeneratePattern": "[a-z0-9]{9}", + "hidden": false, + "id": "text3208210256", + "max": 15, + "min": 9, + "name": "id", + "pattern": "^[a-z0-9]+$", + "presentable": false, + "primaryKey": true, + "required": true, + "system": true, + "type": "text" + }, + { + "cascadeDelete": true, + "collectionId": "2hz5ncl8tizk5nx", + "hidden": false, + "id": "relation3377271179", + "maxSelect": 1, + "minSelect": 0, + "name": "system", + "presentable": false, + "required": true, + "system": false, + "type": "relation" + }, + { + "autogeneratePattern": "[a-zA-Z9-9]{20}", + "hidden": false, + "id": "text1597481275", + "max": 255, + "min": 9, + "name": "token", + "pattern": "", + "presentable": false, + "primaryKey": false, + "required": true, + "system": false, + "type": "text" + }, + { + "autogeneratePattern": "", + "hidden": false, + "id": "text4228609354", + "max": 255, + "min": 9, + "name": "fingerprint", + "pattern": "", + "presentable": false, + "primaryKey": false, + "required": false, + "system": false, + "type": "text" + }, + { + "hidden": false, + "id": "autodate3332085495", + "name": "updated", + "onCreate": true, + "onUpdate": true, + "presentable": false, + "system": false, + "type": "autodate" + } 
+ ], + "indexes": [ + "CREATE INDEX ` + "`" + `idx_p9qZlu26po` + "`" + ` ON ` + "`" + `fingerprints` + "`" + ` (` + "`" + `token` + "`" + `)", + "CREATE UNIQUE INDEX ` + "`" + `idx_ngboulGMYw` + "`" + ` ON ` + "`" + `fingerprints` + "`" + ` (` + "`" + `system` + "`" + `)" + ], + "system": false + }, { "id": "ej9oowivz8b2mht", "listRule": "@request.auth.id != \"\"", @@ -669,7 +742,38 @@ func init() { } ]` - return app.ImportCollectionsByMarshaledJSON([]byte(jsonData), false) + err := app.ImportCollectionsByMarshaledJSON([]byte(jsonData), false) + if err != nil { + return err + } + + // Get all systems that don't have fingerprint records + var systemIds []string + err = app.DB().NewQuery(` + SELECT s.id FROM systems s + LEFT JOIN fingerprints f ON s.id = f.system + WHERE f.system IS NULL + `).Column(&systemIds) + + if err != nil { + return err + } + // Create fingerprint records with unique UUID tokens for each system + for _, systemId := range systemIds { + token := uuid.New().String() + _, err = app.DB().NewQuery(` + INSERT INTO fingerprints (system, token) + VALUES ({:system}, {:token}) + `).Bind(map[string]any{ + "system": systemId, + "token": token, + }).Execute() + if err != nil { + return err + } + } + + return nil }, func(app core.App) error { return nil }) diff --git a/beszel/site/index.html b/beszel/site/index.html index 8802354..3ebd725 100644 --- a/beszel/site/index.html +++ b/beszel/site/index.html @@ -9,7 +9,8 @@ diff --git a/beszel/site/src/components/add-system.tsx b/beszel/site/src/components/add-system.tsx index 09e6fce..54e5857 100644 --- a/beszel/site/src/components/add-system.tsx +++ b/beszel/site/src/components/add-system.tsx @@ -11,20 +11,27 @@ import { DialogTrigger, } from "@/components/ui/dialog" import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs" -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip" - import { Input } from "@/components/ui/input" import { Label } from "@/components/ui/label" import { $publicKey, pb } from "@/lib/stores" -import { cn, copyToClipboard, isReadOnlyUser, useLocalStorage } from "@/lib/utils" -import { i18n } from "@lingui/core" +import { cn, generateToken, isReadOnlyUser, tokenMap, useLocalStorage } from "@/lib/utils" import { useStore } from "@nanostores/react" -import { ChevronDownIcon, Copy, ExternalLinkIcon, PlusIcon } from "lucide-react" -import { memo, useRef, useState } from "react" -import { basePath, navigate } from "./router" -import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "./ui/dropdown-menu" +import { ChevronDownIcon, ExternalLinkIcon, PlusIcon } from "lucide-react" +import { memo, useEffect, useRef, useState } from "react" +import { $router, basePath, Link, navigate } from "./router" import { SystemRecord } from "@/types" import { AppleIcon, DockerIcon, TuxIcon, WindowsIcon } from "./ui/icons" +import { InputCopy } from "./ui/input-copy" +import { getPagePath } from "@nanostores/router" +import { + copyDockerCompose, + copyDockerRun, + copyLinuxCommand, + copyWindowsCommand, + DropdownItem, + InstallDropdown, +} from "./install-dropdowns" +import { DropdownMenu, DropdownMenuTrigger } from "./ui/dropdown-menu" export function AddSystemButton({ className }: { className?: string }) { const [open, setOpen] = useState(false) @@ -51,44 +58,11 @@ export function AddSystemButton({ className }: { className?: string }) { ) } -function copyDockerCompose(port = "45876", publicKey: string) { - copyToClipboard(`services: - 
beszel-agent: - image: "henrygd/beszel-agent" - container_name: "beszel-agent" - restart: unless-stopped - network_mode: host - volumes: - - /var/run/docker.sock:/var/run/docker.sock:ro - # monitor other disks / partitions by mounting a folder in /extra-filesystems - # - /mnt/disk/.beszel:/extra-filesystems/sda1:ro - environment: - LISTEN: ${port} - KEY: "${publicKey}"`) -} - -function copyDockerRun(port = "45876", publicKey: string) { - copyToClipboard( - `docker run -d --name beszel-agent --network host --restart unless-stopped -v /var/run/docker.sock:/var/run/docker.sock:ro -e KEY="${publicKey}" -e LISTEN=${port} henrygd/beszel-agent:latest` - ) -} - -function copyLinuxCommand(port = "45876", publicKey: string, brew = false) { - let cmd = `curl -sL https://get.beszel.dev${ - brew ? "/brew" : "" - } -o /tmp/install-agent.sh && chmod +x /tmp/install-agent.sh && /tmp/install-agent.sh -p ${port} -k "${publicKey}"` - // brew script does not support --china-mirrors - if (!brew && (i18n.locale + navigator.language).includes("zh-CN")) { - cmd += ` --china-mirrors` - } - copyToClipboard(cmd) -} - -function copyWindowsCommand(port = "45876", publicKey: string) { - copyToClipboard( - `& iwr -useb https://get.beszel.dev -OutFile "$env:TEMP\\install-agent.ps1"; & Powershell -ExecutionPolicy Bypass -File "$env:TEMP\\install-agent.ps1" -Key "${publicKey}" -Port ${port}` - ) -} +/** + * Token to be used for the next system. + * Prevents token changing if user copies config, then closes dialog and opens again. + */ +let nextSystemToken: string | null = null /** * SystemDialog component for adding or editing a system. @@ -96,12 +70,32 @@ function copyWindowsCommand(port = "45876", publicKey: string) { * @param {function} props.setOpen - Function to set the open state of the dialog. * @param {SystemRecord} [props.system] - Optional system record for editing an existing system. */ -export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean) => void; system?: SystemRecord }) => { +export const SystemDialog = ({ setOpen, system }: { setOpen: (open: boolean) => void; system?: SystemRecord }) => { const publicKey = useStore($publicKey) const port = useRef(null) const [hostValue, setHostValue] = useState(system?.host ?? "") const isUnixSocket = hostValue.startsWith("/") const [tab, setTab] = useLocalStorage("as-tab", "docker") + const [token, setToken] = useState(system?.token ?? "") + + useEffect(() => { + ;(async () => { + // if no system, generate a new token + if (!system) { + nextSystemToken ||= generateToken() + return setToken(nextSystemToken) + } + // if system exists,get the token from the fingerprint record + if (tokenMap.has(system.id)) { + return setToken(tokenMap.get(system.id)!) 
+ } + const { token } = await pb.collection("fingerprints").getFirstListItem(`system = "${system.id}"`, { + fields: "token", + }) + tokenMap.set(system.id, token) + setToken(token) + })() + }, [system?.id]) async function handleSubmit(e: SubmitEvent) { e.preventDefault() @@ -113,12 +107,18 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean if (system) { await pb.collection("systems").update(system.id, { ...data, status: "pending" }) } else { - await pb.collection("systems").create(data) + const createdSystem = await pb.collection("systems").create(data) + await pb.collection("fingerprints").create({ + system: createdSystem.id, + token, + }) + // Reset the current token after successful system + // creation so next system gets a new token + nextSystemToken = null } navigate(basePath) - // console.log(record) } catch (e) { - console.log(e) + console.error(e) } } @@ -143,18 +143,37 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean {/* Docker (set tab index to prevent auto focusing content in edit system dialog) */} - + - The agent must be running on the system to connect. Copy the - docker-compose.yml for the agent below. + Copy the + docker-compose.yml content for the agent + below, or register agents automatically with a{" "} + setOpen(false)} + href={getPagePath($router, "settings", { name: "tokens" })} + className="link" + > + universal token + + . {/* Binary */} - + - The agent must be running on the system to connect. Copy the installation command for the agent below. + Copy the installation command for the agent below, or register agents automatically with a{" "} + { + setOpen(false) + }} + href={getPagePath($router, "settings", { name: "tokens" })} + className="link" + > + universal token + + . @@ -190,46 +209,27 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean -
-							[stripped markup: the old public-key field with its Tooltip-wrapped copy Button ("Click to copy"); replaced by the InputCopy fields added below]
+ + + {/* Docker */} copyDockerCompose(isUnixSocket ? hostValue : port.current?.value, publicKey)} + onClick={async () => + copyDockerCompose(isUnixSocket ? hostValue : port.current?.value, publicKey, token) + } icon={} dropdownItems={[ { text: t({ message: "Copy docker run", context: "Button to copy docker run command" }), - onClick: () => copyDockerRun(isUnixSocket ? hostValue : port.current?.value, publicKey), - icons: [], + onClick: async () => + copyDockerRun(isUnixSocket ? hostValue : port.current?.value, publicKey, token), + icons: [DockerIcon], }, ]} /> @@ -239,22 +239,24 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean } - onClick={() => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey)} + onClick={async () => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token)} dropdownItems={[ { text: t({ message: "Homebrew command", context: "Button to copy install command" }), - onClick: () => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, true), - icons: [, ], + onClick: async () => + copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token, true), + icons: [AppleIcon, TuxIcon], }, { text: t({ message: "Windows command", context: "Button to copy install command" }), - onClick: () => copyWindowsCommand(isUnixSocket ? hostValue : port.current?.value, publicKey), - icons: [], + onClick: async () => + copyWindowsCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token), + icons: [WindowsIcon], }, { text: t`Manual setup instructions`, url: "https://beszel.dev/guide/agent-installation#binary", - icons: [], + icons: [ExternalLinkIcon], }, ]} /> @@ -266,20 +268,13 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean ) -}) - -interface DropdownItem { - text: string - onClick?: () => void - url?: string - icons?: React.ReactNode[] } interface CopyButtonProps { text: string onClick: () => void dropdownItems: DropdownItem[] - icon?: React.ReactNode + icon?: React.ReactElement } const CopyButton = memo((props: CopyButtonProps) => { @@ -300,22 +295,7 @@ const CopyButton = memo((props: CopyButtonProps) => { - - {props.dropdownItems.map((item, index) => { - const className = "cursor-pointer flex items-center gap-1.5" - return item.url ? 
( - - - {item.text} {item.icons?.map((icon) => icon)} - - - ) : ( - - {item.text} {item.icons?.map((icon) => icon)} - - ) - })} - + ) diff --git a/beszel/site/src/components/command-palette.tsx b/beszel/site/src/components/command-palette.tsx index 220af51..5dbea0c 100644 --- a/beszel/site/src/components/command-palette.tsx +++ b/beszel/site/src/components/command-palette.tsx @@ -1,6 +1,7 @@ import { BookIcon, DatabaseBackupIcon, + FingerprintIcon, LayoutDashboard, LogsIcon, MailIcon, @@ -40,6 +41,16 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean; return useMemo(() => { const systems = $systems.get() + const SettingsShortcut = ( + + Settings + + ) + const AdminShortcut = ( + + Admin + + ) return ( @@ -93,9 +104,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean; Settings - - Settings - + {SettingsShortcut} Notifications - - Settings - + {SettingsShortcut} + + { + navigate(getPagePath($router, "settings", { name: "tokens" })) + setOpen(false) + }} + > + + + Tokens & Fingerprints + + {SettingsShortcut} Users - - Admin - + {AdminShortcut} { @@ -154,9 +171,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean; Logs - - Admin - + {AdminShortcut} { @@ -168,9 +183,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean; Backups - - Admin - + {AdminShortcut} SMTP settings - - Admin - + {AdminShortcut} diff --git a/beszel/site/src/components/install-dropdowns.tsx b/beszel/site/src/components/install-dropdowns.tsx new file mode 100644 index 0000000..c2a84d3 --- /dev/null +++ b/beszel/site/src/components/install-dropdowns.tsx @@ -0,0 +1,97 @@ +import { memo } from "react" +import { DropdownMenuContent, DropdownMenuItem } from "./ui/dropdown-menu" +import { copyToClipboard, getHubURL } from "@/lib/utils" +import { i18n } from "@lingui/core" + +const isBeta = BESZEL.HUB_VERSION.includes("beta") +const imageTag = isBeta ? ":beta" : "" + +/** + * Get the URL of the script to install the agent. + * @param path - The path to the script (e.g. "/brew"). + * @returns The URL for the script. + */ +const getScriptUrl = (path: string = "") => { + const url = new URL("https://get.beszel.dev") + url.pathname = path + if (isBeta) { + url.searchParams.set("beta", "1") + } + return url.toString() +} + +export function copyDockerCompose(port = "45876", publicKey: string, token: string) { + copyToClipboard(`services: + beszel-agent: + image: henrygd/beszel-agent${imageTag} + container_name: beszel-agent + restart: unless-stopped + network_mode: host + volumes: + - /var/run/docker.sock:/var/run/docker.sock:ro + - ./beszel_agent_data:/var/lib/beszel-agent + # monitor other disks / partitions by mounting a folder in /extra-filesystems + # - /mnt/disk/.beszel:/extra-filesystems/sda1:ro + environment: + LISTEN: ${port} + KEY: '${publicKey}' + TOKEN: ${token} + HUB_URL: ${getHubURL()}`) +} + +export function copyDockerRun(port = "45876", publicKey: string, token: string) { + copyToClipboard( + `docker run -d --name beszel-agent --network host --restart unless-stopped -v /var/run/docker.sock:/var/run/docker.sock:ro -v ./beszel_agent_data:/var/lib/beszel-agent -e KEY="${publicKey}" -e LISTEN=${port} -e TOKEN="${token}" -e HUB_URL="${getHubURL()}" henrygd/beszel-agent${imageTag}` + ) +} + +export function copyLinuxCommand(port = "45876", publicKey: string, token: string, brew = false) { + let cmd = `curl -sL ${getScriptUrl( + brew ? 
"/brew" : "" + )} -o /tmp/install-agent.sh && chmod +x /tmp/install-agent.sh && /tmp/install-agent.sh -p ${port} -k "${publicKey}" -t "${token}" -url "${getHubURL()}"` + // brew script does not support --china-mirrors + if (!brew && (i18n.locale + navigator.language).includes("zh-CN")) { + cmd += ` --china-mirrors` + } + copyToClipboard(cmd) +} + +export function copyWindowsCommand(port = "45876", publicKey: string, token: string) { + copyToClipboard( + `& iwr -useb ${getScriptUrl()} -OutFile "$env:TEMP\\install-agent.ps1"; & Powershell -ExecutionPolicy Bypass -File "$env:TEMP\\install-agent.ps1" -Key "${publicKey}" -Port ${port} -Token "${token}" -Url "${getHubURL()}"` + ) +} + +export interface DropdownItem { + text: string + onClick?: () => void + url?: string + icons?: React.ComponentType>[] +} + +export const InstallDropdown = memo(({ items }: { items: DropdownItem[] }) => { + return ( + + {items.map((item, index) => { + const className = "cursor-pointer flex items-center gap-1.5" + return item.url ? ( + + + {item.text}{" "} + {item.icons?.map((Icon, iconIndex) => ( + + ))} + + + ) : ( + + {item.text}{" "} + {item.icons?.map((Icon, iconIndex) => ( + + ))} + + ) + })} + + ) +}) diff --git a/beszel/site/src/components/routes/settings/layout.tsx b/beszel/site/src/components/routes/settings/layout.tsx index 9236500..0af2758 100644 --- a/beszel/site/src/components/routes/settings/layout.tsx +++ b/beszel/site/src/components/routes/settings/layout.tsx @@ -7,7 +7,7 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com import { useStore } from "@nanostores/react" import { $router } from "@/components/router.tsx" import { getPagePath, redirectPage } from "@nanostores/router" -import { BellIcon, FileSlidersIcon, SettingsIcon } from "lucide-react" +import { BellIcon, FileSlidersIcon, FingerprintIcon, SettingsIcon } from "lucide-react" import { $userSettings, pb } from "@/lib/stores.ts" import { toast } from "@/components/ui/use-toast.ts" import { UserSettings } from "@/types.js" @@ -15,6 +15,7 @@ import General from "./general.tsx" import Notifications from "./notifications.tsx" import ConfigYaml from "./config-yaml.tsx" import { useLingui } from "@lingui/react/macro" +import Fingerprints from "./tokens-fingerprints.tsx" export async function saveSettings(newSettings: Partial) { try { @@ -58,6 +59,12 @@ export default function SettingsLayout() { href: getPagePath($router, "settings", { name: "notifications" }), icon: BellIcon, }, + { + title: t`Tokens & Fingerprints`, + href: getPagePath($router, "settings", { name: "tokens" }), + icon: FingerprintIcon, + // admin: true, + }, { title: t`YAML Config`, href: getPagePath($router, "settings", { name: "config" }), @@ -77,7 +84,7 @@ export default function SettingsLayout() { }, []) return ( - + Settings @@ -89,10 +96,10 @@ export default function SettingsLayout() {
-
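
Note (illustrative, not part of this change set): the ws package above exposes a small surface — NewWsConnection, GetFingerprint, RequestSystemData — and its handlers depend on a "wsConn" session key and an active read loop. The hedged Go sketch below shows one way a hub-side caller could tie those pieces to a stored FingerprintRecord; the verifyAgent helper, the res.Fingerprint field access, and starting ReadLoop here are assumptions for illustration, not the hub's actual wiring.

package hub

import (
	"errors"

	"beszel/internal/entities/system"
	"beszel/internal/hub/ws"

	"github.com/lxzan/gws"
	"golang.org/x/crypto/ssh"
)

// verifyAgent is a hypothetical helper: it authenticates an upgraded WebSocket
// connection against the stored fingerprint record, then pulls one round of data.
func verifyAgent(conn *gws.Conn, rec *ws.FingerprintRecord, signer ssh.Signer) (*system.CombinedData, error) {
	wsConn := ws.NewWsConnection(conn)
	// OnMessage resolves the wrapper via the "wsConn" session key, so store it before sending anything.
	conn.Session().Store("wsConn", wsConn)
	// gws only dispatches OnMessage while a read loop is running.
	go conn.ReadLoop()

	// On first contact there is no stored fingerprint yet, so ask for system info as well.
	needSysInfo := rec.Fingerprint == ""
	res, err := wsConn.GetFingerprint(rec.Token, signer, needSysInfo)
	if err != nil {
		return nil, err
	}
	switch {
	case rec.Fingerprint == "":
		rec.Fingerprint = res.Fingerprint // assumed field name; would be saved back to the fingerprints collection
	case rec.Fingerprint != res.Fingerprint:
		wsConn.Close()
		return nil, errors.New("fingerprint mismatch")
	}

	var data system.CombinedData
	if err := wsConn.RequestSystemData(&data); err != nil {
		return nil, err
	}
	return &data, nil
}

The token-then-fingerprint order mirrors the records created by the migration above: every system gets a UUID token up front, while the fingerprint column stays empty until the first successful agent handshake fills it in.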