mirror of
https://github.com/fankes/beszel.git
synced 2025-10-18 17:29:28 +08:00
- Add version exchange between hub and agent. - Introduce ConnectionManager for managing WebSocket and SSH connections. - Implement fingerprint generation and storage in agent. - Create expiry map package to store universal tokens. - Update config.yml configuration to include tokens. - Enhance system management with new methods for handling system states and alerts. - Update front-end components to support token / fingerprint management features. - Introduce utility functions for token generation and hub URL retrieval. Co-authored-by: nhas <jordanatararimu@gmail.com>
This commit is contained in:
@@ -4,34 +4,55 @@ package agent
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/entities/system"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/shirou/gopsutil/v4/host"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
sync.Mutex // Used to lock agent while collecting data
|
||||
debug bool // true if LOG_LEVEL is set to debug
|
||||
zfs bool // true if system has arcstats
|
||||
memCalc string // Memory calculation formula
|
||||
fsNames []string // List of filesystem device names being monitored
|
||||
fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem
|
||||
netInterfaces map[string]struct{} // Stores all valid network interfaces
|
||||
netIoStats system.NetIoStats // Keeps track of bandwidth usage
|
||||
dockerManager *dockerManager // Manages Docker API requests
|
||||
sensorConfig *SensorConfig // Sensors config
|
||||
systemInfo system.Info // Host system info
|
||||
gpuManager *GPUManager // Manages GPU data
|
||||
cache *SessionCache // Cache for system stats based on primary session ID
|
||||
sync.Mutex // Used to lock agent while collecting data
|
||||
debug bool // true if LOG_LEVEL is set to debug
|
||||
zfs bool // true if system has arcstats
|
||||
memCalc string // Memory calculation formula
|
||||
fsNames []string // List of filesystem device names being monitored
|
||||
fsStats map[string]*system.FsStats // Keeps track of disk stats for each filesystem
|
||||
netInterfaces map[string]struct{} // Stores all valid network interfaces
|
||||
netIoStats system.NetIoStats // Keeps track of bandwidth usage
|
||||
dockerManager *dockerManager // Manages Docker API requests
|
||||
sensorConfig *SensorConfig // Sensors config
|
||||
systemInfo system.Info // Host system info
|
||||
gpuManager *GPUManager // Manages GPU data
|
||||
cache *SessionCache // Cache for system stats based on primary session ID
|
||||
connectionManager *ConnectionManager // Channel to signal connection events
|
||||
server *ssh.Server // SSH server
|
||||
dataDir string // Directory for persisting data
|
||||
keys []gossh.PublicKey // SSH public keys
|
||||
}
|
||||
|
||||
func NewAgent() *Agent {
|
||||
agent := &Agent{
|
||||
// NewAgent creates a new agent with the given data directory for persisting data.
|
||||
// If the data directory is not set, it will attempt to find the optimal directory.
|
||||
func NewAgent(dataDir string) (agent *Agent, err error) {
|
||||
agent = &Agent{
|
||||
fsStats: make(map[string]*system.FsStats),
|
||||
cache: NewSessionCache(69 * time.Second),
|
||||
}
|
||||
|
||||
agent.dataDir, err = getDataDir(dataDir)
|
||||
if err != nil {
|
||||
slog.Warn("Data directory not found")
|
||||
} else {
|
||||
slog.Info("Data directory", "path", agent.dataDir)
|
||||
}
|
||||
|
||||
agent.memCalc, _ = GetEnv("MEM_CALC")
|
||||
agent.sensorConfig = agent.newSensorConfig()
|
||||
// Set up slog with a log level determined by the LOG_LEVEL env var
|
||||
@@ -49,10 +70,19 @@ func NewAgent() *Agent {
|
||||
|
||||
slog.Debug(beszel.Version)
|
||||
|
||||
// initialize system info / docker manager
|
||||
// initialize system info
|
||||
agent.initializeSystemInfo()
|
||||
|
||||
// initialize connection manager
|
||||
agent.connectionManager = newConnectionManager(agent)
|
||||
|
||||
// initialize disk info
|
||||
agent.initializeDiskInfo()
|
||||
|
||||
// initialize net io stats
|
||||
agent.initializeNetIoStats()
|
||||
|
||||
// initialize docker manager
|
||||
agent.dockerManager = newDockerManager(agent)
|
||||
|
||||
// initialize GPU manager
|
||||
@@ -67,7 +97,7 @@ func NewAgent() *Agent {
|
||||
slog.Debug("Stats", "data", agent.gatherStats(""))
|
||||
}
|
||||
|
||||
return agent
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
// GetEnv retrieves an environment variable with a "BESZEL_AGENT_" prefix, or falls back to the unprefixed key.
|
||||
@@ -115,3 +145,38 @@ func (a *Agent) gatherStats(sessionID string) *system.CombinedData {
|
||||
a.cache.Set(sessionID, cachedData)
|
||||
return cachedData
|
||||
}
|
||||
|
||||
// StartAgent initializes and starts the agent with optional WebSocket connection
|
||||
func (a *Agent) Start(serverOptions ServerOptions) error {
|
||||
a.keys = serverOptions.Keys
|
||||
return a.connectionManager.Start(serverOptions)
|
||||
}
|
||||
|
||||
func (a *Agent) getFingerprint() string {
|
||||
// first look for a fingerprint in the data directory
|
||||
if a.dataDir != "" {
|
||||
if fp, err := os.ReadFile(filepath.Join(a.dataDir, "fingerprint")); err == nil {
|
||||
return string(fp)
|
||||
}
|
||||
}
|
||||
|
||||
// if no fingerprint is found, generate one
|
||||
fingerprint, err := host.HostID()
|
||||
if err != nil || fingerprint == "" {
|
||||
fingerprint = a.systemInfo.Hostname + a.systemInfo.CpuModel
|
||||
}
|
||||
|
||||
// hash fingerprint
|
||||
sum := sha256.Sum256([]byte(fingerprint))
|
||||
fingerprint = hex.EncodeToString(sum[:24])
|
||||
|
||||
// save fingerprint to data directory
|
||||
if a.dataDir != "" {
|
||||
err = os.WriteFile(filepath.Join(a.dataDir, "fingerprint"), []byte(fingerprint), 0644)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to save fingerprint", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
return fingerprint
|
||||
}
|
||||
|
9
beszel/internal/agent/agent_test_helpers.go
Normal file
9
beszel/internal/agent/agent_test_helpers.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package agent
|
||||
|
||||
// TESTING ONLY: GetConnectionManager is a helper function to get the connection manager for testing.
|
||||
func (a *Agent) GetConnectionManager() *ConnectionManager {
|
||||
return a.connectionManager
|
||||
}
|
243
beszel/internal/agent/client.go
Normal file
243
beszel/internal/agent/client.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/common"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/lxzan/gws"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
wsDeadline = 70 * time.Second
|
||||
)
|
||||
|
||||
// WebSocketClient manages the WebSocket connection between the agent and hub.
|
||||
// It handles authentication, message routing, and connection lifecycle management.
|
||||
type WebSocketClient struct {
|
||||
gws.BuiltinEventHandler
|
||||
options *gws.ClientOption // WebSocket client configuration options
|
||||
agent *Agent // Reference to the parent agent
|
||||
Conn *gws.Conn // Active WebSocket connection
|
||||
hubURL *url.URL // Parsed hub URL for connection
|
||||
token string // Authentication token for hub registration
|
||||
fingerprint string // System fingerprint for identification
|
||||
hubRequest *common.HubRequest[cbor.RawMessage] // Reusable request structure for message parsing
|
||||
lastConnectAttempt time.Time // Timestamp of last connection attempt
|
||||
hubVerified bool // Whether the hub has been cryptographically verified
|
||||
}
|
||||
|
||||
// newWebSocketClient creates a new WebSocket client for the given agent.
|
||||
// It reads configuration from environment variables and validates the hub URL.
|
||||
func newWebSocketClient(agent *Agent) (client *WebSocketClient, err error) {
|
||||
hubURLStr, exists := GetEnv("HUB_URL")
|
||||
if !exists {
|
||||
return nil, errors.New("HUB_URL environment variable not set")
|
||||
}
|
||||
|
||||
client = &WebSocketClient{}
|
||||
|
||||
client.hubURL, err = url.Parse(hubURLStr)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid hub URL")
|
||||
}
|
||||
// get registration token
|
||||
client.token, _ = GetEnv("TOKEN")
|
||||
if client.token == "" {
|
||||
return nil, errors.New("TOKEN environment variable not set")
|
||||
}
|
||||
|
||||
client.agent = agent
|
||||
client.hubRequest = &common.HubRequest[cbor.RawMessage]{}
|
||||
client.fingerprint = agent.getFingerprint()
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// getOptions returns the WebSocket client options, creating them if necessary.
|
||||
// It configures the connection URL, TLS settings, and authentication headers.
|
||||
func (client *WebSocketClient) getOptions() *gws.ClientOption {
|
||||
if client.options != nil {
|
||||
return client.options
|
||||
}
|
||||
|
||||
// update the hub url to use websocket scheme and api path
|
||||
if client.hubURL.Scheme == "https" {
|
||||
client.hubURL.Scheme = "wss"
|
||||
} else {
|
||||
client.hubURL.Scheme = "ws"
|
||||
}
|
||||
client.hubURL.Path = path.Join(client.hubURL.Path, "api/beszel/agent-connect")
|
||||
|
||||
client.options = &gws.ClientOption{
|
||||
Addr: client.hubURL.String(),
|
||||
TlsConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
RequestHeader: http.Header{
|
||||
"User-Agent": []string{getUserAgent()},
|
||||
"X-Token": []string{client.token},
|
||||
"X-Beszel": []string{beszel.Version},
|
||||
},
|
||||
}
|
||||
return client.options
|
||||
}
|
||||
|
||||
// Connect establishes a WebSocket connection to the hub.
|
||||
// It closes any existing connection before attempting to reconnect.
|
||||
func (client *WebSocketClient) Connect() (err error) {
|
||||
client.lastConnectAttempt = time.Now()
|
||||
|
||||
// make sure previous connection is closed
|
||||
client.Close()
|
||||
|
||||
client.Conn, _, err = gws.NewClient(client, client.getOptions())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go client.Conn.ReadLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnOpen handles WebSocket connection establishment.
|
||||
// It sets a deadline for the connection to prevent hanging.
|
||||
func (client *WebSocketClient) OnOpen(conn *gws.Conn) {
|
||||
conn.SetDeadline(time.Now().Add(wsDeadline))
|
||||
}
|
||||
|
||||
// OnClose handles WebSocket connection closure.
|
||||
// It logs the closure reason and notifies the connection manager.
|
||||
func (client *WebSocketClient) OnClose(conn *gws.Conn, err error) {
|
||||
slog.Warn("Connection closed", "err", strings.TrimPrefix(err.Error(), "gws: "))
|
||||
client.agent.connectionManager.eventChan <- WebSocketDisconnect
|
||||
}
|
||||
|
||||
// OnMessage handles incoming WebSocket messages from the hub.
|
||||
// It decodes CBOR messages and routes them to appropriate handlers.
|
||||
func (client *WebSocketClient) OnMessage(conn *gws.Conn, message *gws.Message) {
|
||||
defer message.Close()
|
||||
conn.SetDeadline(time.Now().Add(wsDeadline))
|
||||
|
||||
if message.Opcode != gws.OpcodeBinary {
|
||||
return
|
||||
}
|
||||
|
||||
if err := cbor.NewDecoder(message.Data).Decode(client.hubRequest); err != nil {
|
||||
slog.Error("Error parsing message", "err", err)
|
||||
return
|
||||
}
|
||||
if err := client.handleHubRequest(client.hubRequest); err != nil {
|
||||
slog.Error("Error handling message", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// OnPing handles WebSocket ping frames.
|
||||
// It responds with a pong and updates the connection deadline.
|
||||
func (client *WebSocketClient) OnPing(conn *gws.Conn, message []byte) {
|
||||
conn.SetDeadline(time.Now().Add(wsDeadline))
|
||||
conn.WritePong(message)
|
||||
}
|
||||
|
||||
// handleAuthChallenge verifies the authenticity of the hub and returns the system's fingerprint.
|
||||
func (client *WebSocketClient) handleAuthChallenge(msg *common.HubRequest[cbor.RawMessage]) (err error) {
|
||||
var authRequest common.FingerprintRequest
|
||||
if err := cbor.Unmarshal(msg.Data, &authRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := client.verifySignature(authRequest.Signature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.hubVerified = true
|
||||
client.agent.connectionManager.eventChan <- WebSocketConnect
|
||||
|
||||
response := &common.FingerprintResponse{
|
||||
Fingerprint: client.fingerprint,
|
||||
}
|
||||
|
||||
if authRequest.NeedSysInfo {
|
||||
response.Hostname = client.agent.systemInfo.Hostname
|
||||
serverAddr := client.agent.connectionManager.serverOptions.Addr
|
||||
_, response.Port, _ = net.SplitHostPort(serverAddr)
|
||||
}
|
||||
|
||||
return client.sendMessage(response)
|
||||
}
|
||||
|
||||
// verifySignature verifies the signature of the token using the public keys.
|
||||
func (client *WebSocketClient) verifySignature(signature []byte) (err error) {
|
||||
for _, pubKey := range client.agent.keys {
|
||||
sig := ssh.Signature{
|
||||
Format: pubKey.Type(),
|
||||
Blob: signature,
|
||||
}
|
||||
if err = pubKey.Verify([]byte(client.token), &sig); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("invalid signature - check KEY value")
|
||||
}
|
||||
|
||||
// Close closes the WebSocket connection gracefully.
|
||||
// This method is safe to call multiple times.
|
||||
func (client *WebSocketClient) Close() {
|
||||
if client.Conn != nil {
|
||||
_ = client.Conn.WriteClose(1000, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// handleHubRequest routes the request to the appropriate handler.
|
||||
// It ensures the hub is verified before processing most requests.
|
||||
func (client *WebSocketClient) handleHubRequest(msg *common.HubRequest[cbor.RawMessage]) error {
|
||||
if !client.hubVerified && msg.Action != common.CheckFingerprint {
|
||||
return errors.New("hub not verified")
|
||||
}
|
||||
switch msg.Action {
|
||||
case common.GetData:
|
||||
return client.sendSystemData()
|
||||
case common.CheckFingerprint:
|
||||
return client.handleAuthChallenge(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendSystemData gathers and sends current system statistics to the hub.
|
||||
func (client *WebSocketClient) sendSystemData() error {
|
||||
sysStats := client.agent.gatherStats(client.token)
|
||||
return client.sendMessage(sysStats)
|
||||
}
|
||||
|
||||
// sendMessage encodes the given data to CBOR and sends it as a binary message over the WebSocket connection to the hub.
|
||||
func (client *WebSocketClient) sendMessage(data any) error {
|
||||
bytes, err := cbor.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Conn.WriteMessage(gws.OpcodeBinary, bytes)
|
||||
}
|
||||
|
||||
// getUserAgent returns one of two User-Agent strings based on current time.
|
||||
// This is used to avoid being blocked by Cloudflare or other anti-bot measures.
|
||||
func getUserAgent() string {
|
||||
const (
|
||||
uaBase = "Mozilla/5.0 (%s) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
|
||||
uaWindows = "Windows NT 11.0; Win64; x64"
|
||||
uaMac = "Macintosh; Intel Mac OS X 14_0_0"
|
||||
)
|
||||
if time.Now().UnixNano()%2 == 0 {
|
||||
return fmt.Sprintf(uaBase, uaWindows)
|
||||
}
|
||||
return fmt.Sprintf(uaBase, uaMac)
|
||||
}
|
220
beszel/internal/agent/connection_manager.go
Normal file
220
beszel/internal/agent/connection_manager.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"beszel/internal/agent/health"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionManager manages the connection state and events for the agent.
|
||||
// It handles both WebSocket and SSH connections, automatically switching between
|
||||
// them based on availability and managing reconnection attempts.
|
||||
type ConnectionManager struct {
|
||||
agent *Agent // Reference to the parent agent
|
||||
State ConnectionState // Current connection state
|
||||
eventChan chan ConnectionEvent // Channel for connection events
|
||||
wsClient *WebSocketClient // WebSocket client for hub communication
|
||||
serverOptions ServerOptions // Configuration for SSH server
|
||||
wsTicker *time.Ticker // Ticker for WebSocket connection attempts
|
||||
isConnecting bool // Prevents multiple simultaneous reconnection attempts
|
||||
}
|
||||
|
||||
// ConnectionState represents the current connection state of the agent.
|
||||
type ConnectionState uint8
|
||||
|
||||
// ConnectionEvent represents connection-related events that can occur.
|
||||
type ConnectionEvent uint8
|
||||
|
||||
// Connection states
|
||||
const (
|
||||
Disconnected ConnectionState = iota // No active connection
|
||||
WebSocketConnected // Connected via WebSocket
|
||||
SSHConnected // Connected via SSH
|
||||
)
|
||||
|
||||
// Connection events
|
||||
const (
|
||||
WebSocketConnect ConnectionEvent = iota // WebSocket connection established
|
||||
WebSocketDisconnect // WebSocket connection lost
|
||||
SSHConnect // SSH connection established
|
||||
SSHDisconnect // SSH connection lost
|
||||
)
|
||||
|
||||
const wsTickerInterval = 10 * time.Second
|
||||
|
||||
// newConnectionManager creates a new connection manager for the given agent.
|
||||
func newConnectionManager(agent *Agent) *ConnectionManager {
|
||||
cm := &ConnectionManager{
|
||||
agent: agent,
|
||||
State: Disconnected,
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
// startWsTicker starts or resets the WebSocket connection attempt ticker.
|
||||
func (c *ConnectionManager) startWsTicker() {
|
||||
if c.wsTicker == nil {
|
||||
c.wsTicker = time.NewTicker(wsTickerInterval)
|
||||
} else {
|
||||
c.wsTicker.Reset(wsTickerInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// stopWsTicker stops the WebSocket connection attempt ticker.
|
||||
func (c *ConnectionManager) stopWsTicker() {
|
||||
if c.wsTicker != nil {
|
||||
c.wsTicker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins connection attempts and enters the main event loop.
|
||||
// It handles connection events, periodic health updates, and graceful shutdown.
|
||||
func (c *ConnectionManager) Start(serverOptions ServerOptions) error {
|
||||
if c.eventChan != nil {
|
||||
return errors.New("already started")
|
||||
}
|
||||
|
||||
wsClient, err := newWebSocketClient(c.agent)
|
||||
if err != nil {
|
||||
slog.Warn("Error creating WebSocket client", "err", err)
|
||||
}
|
||||
c.wsClient = wsClient
|
||||
|
||||
c.serverOptions = serverOptions
|
||||
c.eventChan = make(chan ConnectionEvent, 1)
|
||||
|
||||
// signal handling for shutdown
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
c.startWsTicker()
|
||||
c.connect()
|
||||
|
||||
// update health status immediately and every 90 seconds
|
||||
_ = health.Update()
|
||||
healthTicker := time.Tick(90 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case connectionEvent := <-c.eventChan:
|
||||
c.handleEvent(connectionEvent)
|
||||
case <-c.wsTicker.C:
|
||||
_ = c.startWebSocketConnection()
|
||||
case <-healthTicker:
|
||||
_ = health.Update()
|
||||
case <-sigChan:
|
||||
slog.Info("Shutting down")
|
||||
_ = c.agent.StopServer()
|
||||
c.closeWebSocket()
|
||||
return health.CleanUp()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleEvent processes connection events and updates the connection state accordingly.
|
||||
func (c *ConnectionManager) handleEvent(event ConnectionEvent) {
|
||||
switch event {
|
||||
case WebSocketConnect:
|
||||
c.handleStateChange(WebSocketConnected)
|
||||
case SSHConnect:
|
||||
c.handleStateChange(SSHConnected)
|
||||
case WebSocketDisconnect:
|
||||
if c.State == WebSocketConnected {
|
||||
c.handleStateChange(Disconnected)
|
||||
}
|
||||
case SSHDisconnect:
|
||||
if c.State == SSHConnected {
|
||||
c.handleStateChange(Disconnected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStateChange updates the connection state and performs necessary actions
|
||||
// based on the new state, including stopping services and initiating reconnections.
|
||||
func (c *ConnectionManager) handleStateChange(newState ConnectionState) {
|
||||
if c.State == newState {
|
||||
return
|
||||
}
|
||||
c.State = newState
|
||||
switch newState {
|
||||
case WebSocketConnected:
|
||||
slog.Info("WebSocket connected", "host", c.wsClient.hubURL.Host)
|
||||
c.stopWsTicker()
|
||||
_ = c.agent.StopServer()
|
||||
c.isConnecting = false
|
||||
case SSHConnected:
|
||||
// stop new ws connection attempts
|
||||
slog.Info("SSH connection established")
|
||||
c.stopWsTicker()
|
||||
c.isConnecting = false
|
||||
case Disconnected:
|
||||
if c.isConnecting {
|
||||
// Already handling reconnection, avoid duplicate attempts
|
||||
return
|
||||
}
|
||||
c.isConnecting = true
|
||||
slog.Warn("Disconnected from hub")
|
||||
// make sure old ws connection is closed
|
||||
c.closeWebSocket()
|
||||
// reconnect
|
||||
go c.connect()
|
||||
}
|
||||
}
|
||||
|
||||
// connect handles the connection logic with proper delays and priority.
|
||||
// It attempts WebSocket connection first, falling back to SSH server if needed.
|
||||
func (c *ConnectionManager) connect() {
|
||||
c.isConnecting = true
|
||||
defer func() {
|
||||
c.isConnecting = false
|
||||
}()
|
||||
|
||||
if c.wsClient != nil && time.Since(c.wsClient.lastConnectAttempt) < 5*time.Second {
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
// Try WebSocket first, if it fails, start SSH server
|
||||
err := c.startWebSocketConnection()
|
||||
if err != nil && c.State == Disconnected {
|
||||
c.startSSHServer()
|
||||
c.startWsTicker()
|
||||
}
|
||||
}
|
||||
|
||||
// startWebSocketConnection attempts to establish a WebSocket connection to the hub.
|
||||
func (c *ConnectionManager) startWebSocketConnection() error {
|
||||
if c.State != Disconnected {
|
||||
return errors.New("already connected")
|
||||
}
|
||||
if c.wsClient == nil {
|
||||
return errors.New("WebSocket client not initialized")
|
||||
}
|
||||
if time.Since(c.wsClient.lastConnectAttempt) < 5*time.Second {
|
||||
return errors.New("already connecting")
|
||||
}
|
||||
|
||||
err := c.wsClient.Connect()
|
||||
if err != nil {
|
||||
slog.Warn("WebSocket connection failed", "err", err)
|
||||
c.closeWebSocket()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// startSSHServer starts the SSH server if the agent is currently disconnected.
|
||||
func (c *ConnectionManager) startSSHServer() {
|
||||
if c.State == Disconnected {
|
||||
go c.agent.StartServer(c.serverOptions)
|
||||
}
|
||||
}
|
||||
|
||||
// closeWebSocket closes the WebSocket connection if it exists.
|
||||
func (c *ConnectionManager) closeWebSocket() {
|
||||
if c.wsClient != nil {
|
||||
c.wsClient.Close()
|
||||
}
|
||||
}
|
315
beszel/internal/agent/connection_manager_test.go
Normal file
315
beszel/internal/agent/connection_manager_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func createTestAgent(t *testing.T) *Agent {
|
||||
dataDir := t.TempDir()
|
||||
agent, err := NewAgent(dataDir)
|
||||
require.NoError(t, err)
|
||||
return agent
|
||||
}
|
||||
|
||||
func createTestServerOptions(t *testing.T) ServerOptions {
|
||||
// Generate test key pair
|
||||
_, privKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
sshPubKey, err := ssh.NewPublicKey(privKey.Public().(ed25519.PublicKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find available port
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
return ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: fmt.Sprintf("127.0.0.1:%d", port),
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionManager_NewConnectionManager tests connection manager creation
|
||||
func TestConnectionManager_NewConnectionManager(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := newConnectionManager(agent)
|
||||
|
||||
assert.NotNil(t, cm, "Connection manager should not be nil")
|
||||
assert.Equal(t, agent, cm.agent, "Agent reference should be set")
|
||||
assert.Equal(t, Disconnected, cm.State, "Initial state should be Disconnected")
|
||||
assert.Nil(t, cm.eventChan, "Event channel should be nil initially")
|
||||
assert.Nil(t, cm.wsClient, "WebSocket client should be nil initially")
|
||||
assert.Nil(t, cm.wsTicker, "WebSocket ticker should be nil initially")
|
||||
assert.False(t, cm.isConnecting, "isConnecting should be false initially")
|
||||
}
|
||||
|
||||
// TestConnectionManager_StateTransitions tests basic state transitions
|
||||
func TestConnectionManager_StateTransitions(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
initialState := cm.State
|
||||
cm.wsClient = &WebSocketClient{
|
||||
hubURL: &url.URL{
|
||||
Host: "localhost:8080",
|
||||
},
|
||||
}
|
||||
assert.NotNil(t, cm, "Connection manager should not be nil")
|
||||
assert.Equal(t, Disconnected, initialState, "Initial state should be Disconnected")
|
||||
|
||||
// Test state transitions
|
||||
cm.handleStateChange(WebSocketConnected)
|
||||
assert.Equal(t, WebSocketConnected, cm.State, "State should change to WebSocketConnected")
|
||||
|
||||
cm.handleStateChange(SSHConnected)
|
||||
assert.Equal(t, SSHConnected, cm.State, "State should change to SSHConnected")
|
||||
|
||||
cm.handleStateChange(Disconnected)
|
||||
assert.Equal(t, Disconnected, cm.State, "State should change to Disconnected")
|
||||
|
||||
// Test that same state doesn't trigger changes
|
||||
cm.State = WebSocketConnected
|
||||
cm.handleStateChange(WebSocketConnected)
|
||||
assert.Equal(t, WebSocketConnected, cm.State, "Same state should not trigger change")
|
||||
}
|
||||
|
||||
// TestConnectionManager_EventHandling tests event handling logic
|
||||
func TestConnectionManager_EventHandling(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
cm.wsClient = &WebSocketClient{
|
||||
hubURL: &url.URL{
|
||||
Host: "localhost:8080",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState ConnectionState
|
||||
event ConnectionEvent
|
||||
expectedState ConnectionState
|
||||
}{
|
||||
{
|
||||
name: "WebSocket connect from disconnected",
|
||||
initialState: Disconnected,
|
||||
event: WebSocketConnect,
|
||||
expectedState: WebSocketConnected,
|
||||
},
|
||||
{
|
||||
name: "SSH connect from disconnected",
|
||||
initialState: Disconnected,
|
||||
event: SSHConnect,
|
||||
expectedState: SSHConnected,
|
||||
},
|
||||
{
|
||||
name: "WebSocket disconnect from connected",
|
||||
initialState: WebSocketConnected,
|
||||
event: WebSocketDisconnect,
|
||||
expectedState: Disconnected,
|
||||
},
|
||||
{
|
||||
name: "SSH disconnect from connected",
|
||||
initialState: SSHConnected,
|
||||
event: SSHDisconnect,
|
||||
expectedState: Disconnected,
|
||||
},
|
||||
{
|
||||
name: "WebSocket disconnect from SSH connected (no change)",
|
||||
initialState: SSHConnected,
|
||||
event: WebSocketDisconnect,
|
||||
expectedState: SSHConnected,
|
||||
},
|
||||
{
|
||||
name: "SSH disconnect from WebSocket connected (no change)",
|
||||
initialState: WebSocketConnected,
|
||||
event: SSHDisconnect,
|
||||
expectedState: WebSocketConnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cm.State = tc.initialState
|
||||
cm.handleEvent(tc.event)
|
||||
assert.Equal(t, tc.expectedState, cm.State, "State should match expected after event")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionManager_TickerManagement tests WebSocket ticker management
|
||||
func TestConnectionManager_TickerManagement(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
|
||||
// Test starting ticker
|
||||
cm.startWsTicker()
|
||||
assert.NotNil(t, cm.wsTicker, "Ticker should be created")
|
||||
|
||||
// Test stopping ticker (should not panic)
|
||||
assert.NotPanics(t, func() {
|
||||
cm.stopWsTicker()
|
||||
}, "Stopping ticker should not panic")
|
||||
|
||||
// Test stopping nil ticker (should not panic)
|
||||
cm.wsTicker = nil
|
||||
assert.NotPanics(t, func() {
|
||||
cm.stopWsTicker()
|
||||
}, "Stopping nil ticker should not panic")
|
||||
|
||||
// Test restarting ticker
|
||||
cm.startWsTicker()
|
||||
assert.NotNil(t, cm.wsTicker, "Ticker should be recreated")
|
||||
|
||||
// Test resetting existing ticker
|
||||
firstTicker := cm.wsTicker
|
||||
cm.startWsTicker()
|
||||
assert.Equal(t, firstTicker, cm.wsTicker, "Same ticker instance should be reused")
|
||||
|
||||
cm.stopWsTicker()
|
||||
}
|
||||
|
||||
// TestConnectionManager_WebSocketConnectionFlow tests WebSocket connection logic
|
||||
func TestConnectionManager_WebSocketConnectionFlow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping WebSocket connection test in short mode")
|
||||
}
|
||||
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
|
||||
// Test WebSocket connection without proper environment
|
||||
err := cm.startWebSocketConnection()
|
||||
assert.Error(t, err, "WebSocket connection should fail without proper environment")
|
||||
assert.Equal(t, Disconnected, cm.State, "State should remain Disconnected after failed connection")
|
||||
|
||||
// Test with invalid URL
|
||||
os.Setenv("BESZEL_AGENT_HUB_URL", "invalid-url")
|
||||
os.Setenv("BESZEL_AGENT_TOKEN", "test-token")
|
||||
defer func() {
|
||||
os.Unsetenv("BESZEL_AGENT_HUB_URL")
|
||||
os.Unsetenv("BESZEL_AGENT_TOKEN")
|
||||
}()
|
||||
|
||||
// Test with missing token
|
||||
os.Setenv("BESZEL_AGENT_HUB_URL", "http://localhost:8080")
|
||||
os.Unsetenv("BESZEL_AGENT_TOKEN")
|
||||
|
||||
_, err2 := newWebSocketClient(agent)
|
||||
assert.Error(t, err2, "WebSocket client creation should fail without token")
|
||||
}
|
||||
|
||||
// TestConnectionManager_ReconnectionLogic tests reconnection prevention logic
|
||||
func TestConnectionManager_ReconnectionLogic(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
cm.eventChan = make(chan ConnectionEvent, 1)
|
||||
|
||||
// Test that isConnecting flag prevents duplicate reconnection attempts
|
||||
// Start from connected state, then simulate disconnect
|
||||
cm.State = WebSocketConnected
|
||||
cm.isConnecting = false
|
||||
|
||||
// First disconnect should trigger reconnection logic
|
||||
cm.handleStateChange(Disconnected)
|
||||
assert.Equal(t, Disconnected, cm.State, "Should change to disconnected")
|
||||
assert.True(t, cm.isConnecting, "Should set isConnecting flag")
|
||||
}
|
||||
|
||||
// TestConnectionManager_ConnectWithRateLimit tests connection rate limiting
|
||||
func TestConnectionManager_ConnectWithRateLimit(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
|
||||
// Set up environment for WebSocket client creation
|
||||
os.Setenv("BESZEL_AGENT_HUB_URL", "ws://localhost:8080")
|
||||
os.Setenv("BESZEL_AGENT_TOKEN", "test-token")
|
||||
defer func() {
|
||||
os.Unsetenv("BESZEL_AGENT_HUB_URL")
|
||||
os.Unsetenv("BESZEL_AGENT_TOKEN")
|
||||
}()
|
||||
|
||||
// Create WebSocket client
|
||||
wsClient, err := newWebSocketClient(agent)
|
||||
require.NoError(t, err)
|
||||
cm.wsClient = wsClient
|
||||
|
||||
// Set recent connection attempt
|
||||
cm.wsClient.lastConnectAttempt = time.Now()
|
||||
|
||||
// Test that connection is rate limited
|
||||
err = cm.startWebSocketConnection()
|
||||
assert.Error(t, err, "Should error due to rate limiting")
|
||||
assert.Contains(t, err.Error(), "already connecting", "Error should indicate rate limiting")
|
||||
|
||||
// Test connection after rate limit expires
|
||||
cm.wsClient.lastConnectAttempt = time.Now().Add(-10 * time.Second)
|
||||
err = cm.startWebSocketConnection()
|
||||
// This will fail due to no actual server, but should not be rate limited
|
||||
assert.Error(t, err, "Connection should fail but not due to rate limiting")
|
||||
assert.NotContains(t, err.Error(), "already connecting", "Error should not indicate rate limiting")
|
||||
}
|
||||
|
||||
// TestConnectionManager_StartWithInvalidConfig tests starting with invalid configuration
|
||||
func TestConnectionManager_StartWithInvalidConfig(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
serverOptions := createTestServerOptions(t)
|
||||
|
||||
// Test starting when already started
|
||||
cm.eventChan = make(chan ConnectionEvent, 5)
|
||||
err := cm.Start(serverOptions)
|
||||
assert.Error(t, err, "Should error when starting already started connection manager")
|
||||
}
|
||||
|
||||
// TestConnectionManager_CloseWebSocket tests WebSocket closing
|
||||
func TestConnectionManager_CloseWebSocket(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
|
||||
// Test closing when no WebSocket client exists
|
||||
assert.NotPanics(t, func() {
|
||||
cm.closeWebSocket()
|
||||
}, "Should not panic when closing nil WebSocket client")
|
||||
|
||||
// Set up environment and create WebSocket client
|
||||
os.Setenv("BESZEL_AGENT_HUB_URL", "ws://localhost:8080")
|
||||
os.Setenv("BESZEL_AGENT_TOKEN", "test-token")
|
||||
defer func() {
|
||||
os.Unsetenv("BESZEL_AGENT_HUB_URL")
|
||||
os.Unsetenv("BESZEL_AGENT_TOKEN")
|
||||
}()
|
||||
|
||||
wsClient, err := newWebSocketClient(agent)
|
||||
require.NoError(t, err)
|
||||
cm.wsClient = wsClient
|
||||
|
||||
// Test closing when WebSocket client exists
|
||||
assert.NotPanics(t, func() {
|
||||
cm.closeWebSocket()
|
||||
}, "Should not panic when closing WebSocket client")
|
||||
}
|
||||
|
||||
// TestConnectionManager_ConnectFlow tests the connect method
|
||||
func TestConnectionManager_ConnectFlow(t *testing.T) {
|
||||
agent := createTestAgent(t)
|
||||
cm := agent.connectionManager
|
||||
|
||||
// Test connect without WebSocket client
|
||||
assert.NotPanics(t, func() {
|
||||
cm.connect()
|
||||
}, "Connect should not panic without WebSocket client")
|
||||
}
|
@@ -1,25 +1,44 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/common"
|
||||
"beszel/internal/entities/system"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gliderlabs/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// ServerOptions contains configuration options for starting the SSH server.
|
||||
type ServerOptions struct {
|
||||
Addr string
|
||||
Network string
|
||||
Keys []gossh.PublicKey
|
||||
Addr string // Network address to listen on (e.g., ":45876" or "/path/to/socket")
|
||||
Network string // Network type ("tcp" or "unix")
|
||||
Keys []gossh.PublicKey // SSH public keys for authentication
|
||||
}
|
||||
|
||||
// hubVersions caches hub versions by session ID to avoid repeated parsing.
|
||||
var hubVersions map[string]semver.Version
|
||||
|
||||
// StartServer starts the SSH server with the provided options.
|
||||
// It configures the server with secure defaults, sets up authentication,
|
||||
// and begins listening for connections. Returns an error if the server
|
||||
// is already running or if there's an issue starting the server.
|
||||
func (a *Agent) StartServer(opts ServerOptions) error {
|
||||
if a.server != nil {
|
||||
return errors.New("server already started")
|
||||
}
|
||||
|
||||
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
|
||||
|
||||
if opts.Network == "unix" {
|
||||
@@ -37,7 +56,9 @@ func (a *Agent) StartServer(opts ServerOptions) error {
|
||||
defer ln.Close()
|
||||
|
||||
// base config (limit to allowed algorithms)
|
||||
config := &gossh.ServerConfig{}
|
||||
config := &gossh.ServerConfig{
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-%s_%s", beszel.AppName, beszel.Version),
|
||||
}
|
||||
config.KeyExchanges = common.DefaultKeyExchanges
|
||||
config.MACs = common.DefaultMACs
|
||||
config.Ciphers = common.DefaultCiphers
|
||||
@@ -45,42 +66,92 @@ func (a *Agent) StartServer(opts ServerOptions) error {
|
||||
// set default handler
|
||||
ssh.Handle(a.handleSession)
|
||||
|
||||
server := ssh.Server{
|
||||
a.server = &ssh.Server{
|
||||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
|
||||
return config
|
||||
},
|
||||
// check public key(s)
|
||||
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
remoteAddr := ctx.RemoteAddr()
|
||||
for _, pubKey := range opts.Keys {
|
||||
if ssh.KeysEqual(key, pubKey) {
|
||||
slog.Info("SSH connected", "addr", remoteAddr)
|
||||
return true
|
||||
}
|
||||
}
|
||||
slog.Warn("Invalid SSH key", "addr", remoteAddr)
|
||||
return false
|
||||
},
|
||||
// disable pty
|
||||
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
|
||||
return false
|
||||
},
|
||||
// log failed connections
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
slog.Warn("Failed connection attempt", "addr", conn.RemoteAddr().String(), "err", err)
|
||||
},
|
||||
// close idle connections after 70 seconds
|
||||
IdleTimeout: 70 * time.Second,
|
||||
}
|
||||
|
||||
// Start SSH server on the listener
|
||||
return server.Serve(ln)
|
||||
return a.server.Serve(ln)
|
||||
}
|
||||
|
||||
// getHubVersion retrieves and caches the hub version for a given session.
|
||||
// It extracts the version from the SSH client version string and caches
|
||||
// it to avoid repeated parsing. Returns a zero version if parsing fails.
|
||||
func (a *Agent) getHubVersion(sessionId string, sessionCtx ssh.Context) semver.Version {
|
||||
if hubVersions == nil {
|
||||
hubVersions = make(map[string]semver.Version, 1)
|
||||
}
|
||||
hubVersion, ok := hubVersions[sessionId]
|
||||
if ok {
|
||||
return hubVersion
|
||||
}
|
||||
// Extract hub version from SSH client version
|
||||
clientVersion := sessionCtx.Value(ssh.ContextKeyClientVersion)
|
||||
if versionStr, ok := clientVersion.(string); ok {
|
||||
hubVersion, _ = extractHubVersion(versionStr)
|
||||
}
|
||||
hubVersions[sessionId] = hubVersion
|
||||
return hubVersion
|
||||
}
|
||||
|
||||
// handleSession handles an incoming SSH session by gathering system statistics
|
||||
// and sending them to the hub. It signals connection events, determines the
|
||||
// appropriate encoding format based on hub version, and exits with appropriate
|
||||
// status codes.
|
||||
func (a *Agent) handleSession(s ssh.Session) {
|
||||
slog.Debug("New session", "client", s.RemoteAddr())
|
||||
stats := a.gatherStats(s.Context().SessionID())
|
||||
if err := json.NewEncoder(s).Encode(stats); err != nil {
|
||||
a.connectionManager.eventChan <- SSHConnect
|
||||
|
||||
sessionCtx := s.Context()
|
||||
sessionID := sessionCtx.SessionID()
|
||||
|
||||
hubVersion := a.getHubVersion(sessionID, sessionCtx)
|
||||
|
||||
stats := a.gatherStats(sessionID)
|
||||
|
||||
err := a.writeToSession(s, stats, hubVersion)
|
||||
if err != nil {
|
||||
slog.Error("Error encoding stats", "err", err, "stats", stats)
|
||||
s.Exit(1)
|
||||
return
|
||||
} else {
|
||||
s.Exit(0)
|
||||
}
|
||||
s.Exit(0)
|
||||
}
|
||||
|
||||
// writeToSession encodes and writes system statistics to the session.
|
||||
// It chooses between CBOR and JSON encoding based on the hub version,
|
||||
// using CBOR for newer versions and JSON for legacy compatibility.
|
||||
func (a *Agent) writeToSession(w io.Writer, stats *system.CombinedData, hubVersion semver.Version) error {
|
||||
if hubVersion.GTE(beszel.MinVersionCbor) {
|
||||
return cbor.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
return json.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
|
||||
// extractHubVersion extracts the beszel version from SSH client version string.
|
||||
// Expected format: "SSH-2.0-beszel_X.Y.Z" or "beszel_X.Y.Z"
|
||||
func extractHubVersion(versionString string) (semver.Version, error) {
|
||||
_, after, _ := strings.Cut(versionString, "_")
|
||||
return semver.Parse(after)
|
||||
}
|
||||
|
||||
// ParseKeys parses a string containing SSH public keys in authorized_keys format.
|
||||
@@ -103,7 +174,9 @@ func ParseKeys(input string) ([]gossh.PublicKey, error) {
|
||||
return parsedKeys, nil
|
||||
}
|
||||
|
||||
// GetAddress gets the address to listen on or connect to from environment variables or default value.
|
||||
// GetAddress determines the network address to listen on from various sources.
|
||||
// It checks the provided address, then environment variables (LISTEN, PORT),
|
||||
// and finally defaults to ":45876".
|
||||
func GetAddress(addr string) string {
|
||||
if addr == "" {
|
||||
addr, _ = GetEnv("LISTEN")
|
||||
@@ -122,7 +195,9 @@ func GetAddress(addr string) string {
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetNetwork returns the network type to use based on the address
|
||||
// GetNetwork determines the network type based on the address format.
|
||||
// It checks the NETWORK environment variable first, then infers from
|
||||
// the address format: addresses starting with "/" are "unix", others are "tcp".
|
||||
func GetNetwork(addr string) string {
|
||||
if network, ok := GetEnv("NETWORK"); ok && network != "" {
|
||||
return network
|
||||
@@ -132,3 +207,17 @@ func GetNetwork(addr string) string {
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
// StopServer stops the SSH server if it's running.
|
||||
// It returns an error if the server is not running or if there's an error stopping it.
|
||||
func (a *Agent) StopServer() error {
|
||||
if a.server == nil {
|
||||
return errors.New("SSH server not running")
|
||||
}
|
||||
|
||||
slog.Info("Stopping SSH server")
|
||||
_ = a.server.Close()
|
||||
a.server = nil
|
||||
a.connectionManager.eventChan <- SSHDisconnect
|
||||
return nil
|
||||
}
|
||||
|
@@ -1,34 +1,43 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"beszel/internal/entities/container"
|
||||
"beszel/internal/entities/system"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestStartServer(t *testing.T) {
|
||||
// Generate a test key pair
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
signer, err := ssh.NewSignerFromKey(privKey)
|
||||
signer, err := gossh.NewSignerFromKey(privKey)
|
||||
require.NoError(t, err)
|
||||
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||
sshPubKey, err := gossh.NewPublicKey(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a different key pair for bad key test
|
||||
badPubKey, badPrivKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
badSigner, err := ssh.NewSignerFromKey(badPrivKey)
|
||||
badSigner, err := gossh.NewSignerFromKey(badPrivKey)
|
||||
require.NoError(t, err)
|
||||
sshBadPubKey, err := ssh.NewPublicKey(badPubKey)
|
||||
sshBadPubKey, err := gossh.NewPublicKey(badPubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
socketFile := filepath.Join(t.TempDir(), "beszel-test.sock")
|
||||
@@ -46,7 +55,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: ":45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
Keys: []gossh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -54,7 +63,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "tcp4",
|
||||
Addr: "127.0.0.1:45988",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
Keys: []gossh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -62,7 +71,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "tcp6",
|
||||
Addr: "[::1]:45989",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
Keys: []gossh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -70,7 +79,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "unix",
|
||||
Addr: socketFile,
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
Keys: []gossh.PublicKey{sshPubKey},
|
||||
},
|
||||
setup: func() error {
|
||||
// Create a socket file that should be removed
|
||||
@@ -89,7 +98,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: ":45987",
|
||||
Keys: []ssh.PublicKey{sshBadPubKey},
|
||||
Keys: []gossh.PublicKey{sshBadPubKey},
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "ssh: handshake failed",
|
||||
@@ -99,7 +108,7 @@ func TestStartServer(t *testing.T) {
|
||||
config: ServerOptions{
|
||||
Network: "tcp",
|
||||
Addr: ":45987",
|
||||
Keys: []ssh.PublicKey{sshPubKey},
|
||||
Keys: []gossh.PublicKey{sshPubKey},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -115,7 +124,8 @@ func TestStartServer(t *testing.T) {
|
||||
defer tt.cleanup()
|
||||
}
|
||||
|
||||
agent := NewAgent()
|
||||
agent, err := NewAgent("")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start server in a goroutine since it blocks
|
||||
errChan := make(chan error, 1)
|
||||
@@ -127,8 +137,7 @@ func TestStartServer(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to connect to verify server is running
|
||||
var client *ssh.Client
|
||||
var err error
|
||||
var client *gossh.Client
|
||||
|
||||
// Choose the appropriate signer based on the test case
|
||||
testSigner := signer
|
||||
@@ -136,23 +145,23 @@ func TestStartServer(t *testing.T) {
|
||||
testSigner = badSigner
|
||||
}
|
||||
|
||||
sshClientConfig := &ssh.ClientConfig{
|
||||
sshClientConfig := &gossh.ClientConfig{
|
||||
User: "a",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(testSigner),
|
||||
Auth: []gossh.AuthMethod{
|
||||
gossh.PublicKeys(testSigner),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 4 * time.Second,
|
||||
}
|
||||
|
||||
switch tt.config.Network {
|
||||
case "unix":
|
||||
client, err = ssh.Dial("unix", tt.config.Addr, sshClientConfig)
|
||||
client, err = gossh.Dial("unix", tt.config.Addr, sshClientConfig)
|
||||
default:
|
||||
if !strings.Contains(tt.config.Addr, ":") {
|
||||
tt.config.Addr = ":" + tt.config.Addr
|
||||
}
|
||||
client, err = ssh.Dial("tcp", tt.config.Addr, sshClientConfig)
|
||||
client, err = gossh.Dial("tcp", tt.config.Addr, sshClientConfig)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
@@ -287,3 +296,310 @@ func TestParseInvalidKey(t *testing.T) {
|
||||
t.Fatalf("Expected error message to contain '%s', got: %v", expectedErrMsg, err)
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
//////////////////// Hub Version Tests //////////////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
func TestExtractHubVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientVersion string
|
||||
expectedVersion string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid beszel client version with underscore",
|
||||
clientVersion: "SSH-2.0-beszel_0.11.1",
|
||||
expectedVersion: "0.11.1",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid beszel client version with beta",
|
||||
clientVersion: "SSH-2.0-beszel_1.0.0-beta",
|
||||
expectedVersion: "1.0.0-beta",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid beszel client version with rc",
|
||||
clientVersion: "SSH-2.0-beszel_0.12.0-rc1",
|
||||
expectedVersion: "0.12.0-rc1",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "different SSH client",
|
||||
clientVersion: "SSH-2.0-OpenSSH_8.0",
|
||||
expectedVersion: "8.0",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "malformed version string without underscore",
|
||||
clientVersion: "SSH-2.0-beszel",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty version string",
|
||||
clientVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "version string with underscore but no version",
|
||||
clientVersion: "beszel_",
|
||||
expectedVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "version with patch and build metadata",
|
||||
clientVersion: "SSH-2.0-beszel_1.2.3+build.123",
|
||||
expectedVersion: "1.2.3+build.123",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := extractHubVersion(tt.clientVersion)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedVersion, result.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
/////////////// Hub Version Detection Tests ////////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
func TestGetHubVersion(t *testing.T) {
|
||||
agent, err := NewAgent("")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mock SSH context that implements the ssh.Context interface
|
||||
mockCtx := &mockSSHContext{
|
||||
sessionID: "test-session-123",
|
||||
clientVersion: "SSH-2.0-beszel_0.12.0",
|
||||
}
|
||||
|
||||
// Test first call - should extract and cache version
|
||||
version := agent.getHubVersion("test-session-123", mockCtx)
|
||||
assert.Equal(t, "0.12.0", version.String())
|
||||
|
||||
// Test second call - should return cached version
|
||||
mockCtx.clientVersion = "SSH-2.0-beszel_0.11.0" // Change version but should still return cached
|
||||
version = agent.getHubVersion("test-session-123", mockCtx)
|
||||
assert.Equal(t, "0.12.0", version.String()) // Should still be cached version
|
||||
|
||||
// Test different session - should extract new version
|
||||
version = agent.getHubVersion("different-session", mockCtx)
|
||||
assert.Equal(t, "0.11.0", version.String())
|
||||
|
||||
// Test with invalid version string (non-beszel client)
|
||||
mockCtx.clientVersion = "SSH-2.0-OpenSSH_8.0"
|
||||
version = agent.getHubVersion("invalid-session", mockCtx)
|
||||
assert.Equal(t, "0.0.0", version.String()) // Should be empty version for non-beszel clients
|
||||
|
||||
// Test with no client version
|
||||
mockCtx.clientVersion = ""
|
||||
version = agent.getHubVersion("no-version-session", mockCtx)
|
||||
assert.True(t, version.EQ(semver.Version{})) // Should be empty version
|
||||
}
|
||||
|
||||
// mockSSHContext implements ssh.Context for testing
|
||||
type mockSSHContext struct {
|
||||
context.Context
|
||||
sync.Mutex
|
||||
sessionID string
|
||||
clientVersion string
|
||||
}
|
||||
|
||||
func (m *mockSSHContext) SessionID() string {
|
||||
return m.sessionID
|
||||
}
|
||||
|
||||
func (m *mockSSHContext) ClientVersion() string {
|
||||
return m.clientVersion
|
||||
}
|
||||
|
||||
func (m *mockSSHContext) ServerVersion() string {
|
||||
return "SSH-2.0-beszel_test"
|
||||
}
|
||||
|
||||
func (m *mockSSHContext) Value(key interface{}) interface{} {
|
||||
if key == ssh.ContextKeyClientVersion {
|
||||
return m.clientVersion
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSSHContext) User() string { return "test-user" }
|
||||
func (m *mockSSHContext) RemoteAddr() net.Addr { return nil }
|
||||
func (m *mockSSHContext) LocalAddr() net.Addr { return nil }
|
||||
func (m *mockSSHContext) Permissions() *ssh.Permissions { return nil }
|
||||
func (m *mockSSHContext) SetValue(key, value interface{}) {}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
/////////////// CBOR vs JSON Encoding Tests ////////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
// TestWriteToSessionEncoding tests that writeToSession actually encodes data in the correct format
|
||||
func TestWriteToSessionEncoding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hubVersion string
|
||||
expectedUsesCbor bool
|
||||
}{
|
||||
{
|
||||
name: "old hub version should use JSON",
|
||||
hubVersion: "0.11.1",
|
||||
expectedUsesCbor: false,
|
||||
},
|
||||
{
|
||||
name: "non-beta release should use CBOR",
|
||||
hubVersion: "0.12.0",
|
||||
expectedUsesCbor: true,
|
||||
},
|
||||
{
|
||||
name: "even newer hub version should use CBOR",
|
||||
hubVersion: "0.16.4",
|
||||
expectedUsesCbor: true,
|
||||
},
|
||||
{
|
||||
name: "beta version below release threshold should use JSON",
|
||||
hubVersion: "0.12.0-beta0",
|
||||
expectedUsesCbor: false,
|
||||
},
|
||||
{
|
||||
name: "matching beta version should use CBOR",
|
||||
hubVersion: "0.12.0-beta1",
|
||||
expectedUsesCbor: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset the global hubVersions map to ensure clean state for each test
|
||||
hubVersions = nil
|
||||
|
||||
agent, err := NewAgent("")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the test version
|
||||
version, err := semver.Parse(tt.hubVersion)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test data to encode
|
||||
testData := createTestCombinedData()
|
||||
|
||||
var buf strings.Builder
|
||||
err = agent.writeToSession(&buf, testData, version)
|
||||
require.NoError(t, err)
|
||||
|
||||
encodedData := buf.String()
|
||||
require.NotEmpty(t, encodedData)
|
||||
|
||||
// Verify the encoding format by attempting to decode
|
||||
if tt.expectedUsesCbor {
|
||||
var decodedCbor system.CombinedData
|
||||
err = cbor.Unmarshal([]byte(encodedData), &decodedCbor)
|
||||
assert.NoError(t, err, "Should be valid CBOR data")
|
||||
|
||||
var decodedJson system.CombinedData
|
||||
err = json.Unmarshal([]byte(encodedData), &decodedJson)
|
||||
assert.Error(t, err, "Should not be valid JSON data")
|
||||
|
||||
assert.Equal(t, testData.Info.Hostname, decodedCbor.Info.Hostname)
|
||||
assert.Equal(t, testData.Stats.Cpu, decodedCbor.Stats.Cpu)
|
||||
} else {
|
||||
// Should be JSON - try to decode as JSON
|
||||
var decodedJson system.CombinedData
|
||||
err = json.Unmarshal([]byte(encodedData), &decodedJson)
|
||||
assert.NoError(t, err, "Should be valid JSON data")
|
||||
|
||||
var decodedCbor system.CombinedData
|
||||
err = cbor.Unmarshal([]byte(encodedData), &decodedCbor)
|
||||
assert.Error(t, err, "Should not be valid CBOR data")
|
||||
|
||||
// Verify the decoded JSON data matches our test data
|
||||
assert.Equal(t, testData.Info.Hostname, decodedJson.Info.Hostname)
|
||||
assert.Equal(t, testData.Stats.Cpu, decodedJson.Stats.Cpu)
|
||||
|
||||
// Verify it looks like JSON (starts with '{' and contains readable field names)
|
||||
assert.True(t, strings.HasPrefix(encodedData, "{"), "JSON should start with '{'")
|
||||
assert.Contains(t, encodedData, `"info"`, "JSON should contain readable field names")
|
||||
assert.Contains(t, encodedData, `"stats"`, "JSON should contain readable field names")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create test data for encoding tests
|
||||
func createTestCombinedData() *system.CombinedData {
|
||||
return &system.CombinedData{
|
||||
Stats: system.Stats{
|
||||
Cpu: 25.5,
|
||||
Mem: 8589934592, // 8GB
|
||||
MemUsed: 4294967296, // 4GB
|
||||
MemPct: 50.0,
|
||||
DiskTotal: 1099511627776, // 1TB
|
||||
DiskUsed: 549755813888, // 512GB
|
||||
DiskPct: 50.0,
|
||||
},
|
||||
Info: system.Info{
|
||||
Hostname: "test-host",
|
||||
Cores: 8,
|
||||
CpuModel: "Test CPU Model",
|
||||
Uptime: 3600,
|
||||
AgentVersion: "0.12.0",
|
||||
Os: system.Linux,
|
||||
},
|
||||
Containers: []*container.Stats{
|
||||
{
|
||||
Name: "test-container",
|
||||
Cpu: 10.5,
|
||||
Mem: 1073741824, // 1GB
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestHubVersionCaching(t *testing.T) {
|
||||
// Reset the global hubVersions map to ensure clean state
|
||||
hubVersions = nil
|
||||
|
||||
agent, err := NewAgent("")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx1 := &mockSSHContext{
|
||||
sessionID: "session1",
|
||||
clientVersion: "SSH-2.0-beszel_0.12.0",
|
||||
}
|
||||
ctx2 := &mockSSHContext{
|
||||
sessionID: "session2",
|
||||
clientVersion: "SSH-2.0-beszel_0.11.0",
|
||||
}
|
||||
|
||||
// First calls should cache the versions
|
||||
v1 := agent.getHubVersion("session1", ctx1)
|
||||
v2 := agent.getHubVersion("session2", ctx2)
|
||||
|
||||
assert.Equal(t, "0.12.0", v1.String())
|
||||
assert.Equal(t, "0.11.0", v2.String())
|
||||
|
||||
// Verify caching by changing context but keeping same session ID
|
||||
ctx1.clientVersion = "SSH-2.0-beszel_0.10.0"
|
||||
v1Cached := agent.getHubVersion("session1", ctx1)
|
||||
assert.Equal(t, "0.12.0", v1Cached.String()) // Should still be cached version
|
||||
|
||||
// New session should get new version
|
||||
ctx3 := &mockSSHContext{
|
||||
sessionID: "session3",
|
||||
clientVersion: "SSH-2.0-beszel_0.13.0",
|
||||
}
|
||||
v3 := agent.getHubVersion("session3", ctx3)
|
||||
assert.Equal(t, "0.13.0", v3.String())
|
||||
}
|
||||
|
10
beszel/internal/common/common-ssh.go
Normal file
10
beszel/internal/common/common-ssh.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package common
|
||||
|
||||
var (
|
||||
// Allowed ssh key exchanges
|
||||
DefaultKeyExchanges = []string{"curve25519-sha256"}
|
||||
// Allowed ssh macs
|
||||
DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"}
|
||||
// Allowed ssh ciphers
|
||||
DefaultCiphers = []string{"chacha20-poly1305@openssh.com"}
|
||||
)
|
32
beszel/internal/common/common-ws.go
Normal file
32
beszel/internal/common/common-ws.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package common
|
||||
|
||||
type WebSocketAction = uint8
|
||||
|
||||
// Not implemented yet
|
||||
// type AgentError = uint8
|
||||
|
||||
const (
|
||||
// Request system data from agent
|
||||
GetData WebSocketAction = iota
|
||||
// Check the fingerprint of the agent
|
||||
CheckFingerprint
|
||||
)
|
||||
|
||||
// HubRequest defines the structure for requests sent from hub to agent.
|
||||
type HubRequest[T any] struct {
|
||||
Action WebSocketAction `cbor:"0,keyasint"`
|
||||
Data T `cbor:"1,keyasint,omitempty,omitzero"`
|
||||
// Error AgentError `cbor:"error,omitempty,omitzero"`
|
||||
}
|
||||
|
||||
type FingerprintRequest struct {
|
||||
Signature []byte `cbor:"0,keyasint"`
|
||||
NeedSysInfo bool `cbor:"1,keyasint"` // For universal token system creation
|
||||
}
|
||||
|
||||
type FingerprintResponse struct {
|
||||
Fingerprint string `cbor:"0,keyasint"`
|
||||
// Optional system info for universal token system creation
|
||||
Hostname string `cbor:"1,keyasint,omitempty,omitzero"`
|
||||
Port string `cbor:"2,keyasint,omitempty,omitzero"`
|
||||
}
|
@@ -1,7 +0,0 @@
|
||||
package common
|
||||
|
||||
var (
|
||||
DefaultKeyExchanges = []string{"curve25519-sha256"}
|
||||
DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"}
|
||||
DefaultCiphers = []string{"chacha20-poly1305@openssh.com"}
|
||||
)
|
@@ -34,10 +34,16 @@ type ApiStats struct {
|
||||
MemoryStats MemoryStats `json:"memory_stats"`
|
||||
}
|
||||
|
||||
func (s *ApiStats) CalculateCpuPercentLinux(prevCpuUsage [2]uint64) float64 {
|
||||
cpuDelta := s.CPUStats.CPUUsage.TotalUsage - prevCpuUsage[0]
|
||||
systemDelta := s.CPUStats.SystemUsage - prevCpuUsage[1]
|
||||
return float64(cpuDelta) / float64(systemDelta) * 100
|
||||
func (s *ApiStats) CalculateCpuPercentLinux(prevCpuContainer uint64, prevCpuSystem uint64) float64 {
|
||||
cpuDelta := s.CPUStats.CPUUsage.TotalUsage - prevCpuContainer
|
||||
systemDelta := s.CPUStats.SystemUsage - prevCpuSystem
|
||||
|
||||
// Avoid division by zero and handle first run case
|
||||
if systemDelta == 0 || prevCpuContainer == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return float64(cpuDelta) / float64(systemDelta) * 100.0
|
||||
}
|
||||
|
||||
// from: https://github.com/docker/cli/blob/master/cli/command/container/stats_helpers.go#L185
|
||||
@@ -99,12 +105,14 @@ type prevNetStats struct {
|
||||
|
||||
// Docker container stats
|
||||
type Stats struct {
|
||||
Name string `json:"n"`
|
||||
Cpu float64 `json:"c"`
|
||||
Mem float64 `json:"m"`
|
||||
NetworkSent float64 `json:"ns"`
|
||||
NetworkRecv float64 `json:"nr"`
|
||||
PrevCpu [2]uint64 `json:"-"`
|
||||
PrevNet prevNetStats `json:"-"`
|
||||
PrevRead time.Time `json:"-"`
|
||||
Name string `json:"n" cbor:"0,keyasint"`
|
||||
Cpu float64 `json:"c" cbor:"1,keyasint"`
|
||||
Mem float64 `json:"m" cbor:"2,keyasint"`
|
||||
NetworkSent float64 `json:"ns" cbor:"3,keyasint"`
|
||||
NetworkRecv float64 `json:"nr" cbor:"4,keyasint"`
|
||||
// PrevCpu [2]uint64 `json:"-"`
|
||||
CpuSystem uint64 `json:"-"`
|
||||
CpuContainer uint64 `json:"-"`
|
||||
PrevNet prevNetStats `json:"-"`
|
||||
PrevReadTime time.Time `json:"-"`
|
||||
}
|
||||
|
@@ -8,38 +8,38 @@ import (
|
||||
)
|
||||
|
||||
type Stats struct {
|
||||
Cpu float64 `json:"cpu"`
|
||||
MaxCpu float64 `json:"cpum,omitempty"`
|
||||
Mem float64 `json:"m"`
|
||||
MemUsed float64 `json:"mu"`
|
||||
MemPct float64 `json:"mp"`
|
||||
MemBuffCache float64 `json:"mb"`
|
||||
MemZfsArc float64 `json:"mz,omitempty"` // ZFS ARC memory
|
||||
Swap float64 `json:"s,omitempty"`
|
||||
SwapUsed float64 `json:"su,omitempty"`
|
||||
DiskTotal float64 `json:"d"`
|
||||
DiskUsed float64 `json:"du"`
|
||||
DiskPct float64 `json:"dp"`
|
||||
DiskReadPs float64 `json:"dr"`
|
||||
DiskWritePs float64 `json:"dw"`
|
||||
MaxDiskReadPs float64 `json:"drm,omitempty"`
|
||||
MaxDiskWritePs float64 `json:"dwm,omitempty"`
|
||||
NetworkSent float64 `json:"ns"`
|
||||
NetworkRecv float64 `json:"nr"`
|
||||
MaxNetworkSent float64 `json:"nsm,omitempty"`
|
||||
MaxNetworkRecv float64 `json:"nrm,omitempty"`
|
||||
Temperatures map[string]float64 `json:"t,omitempty"`
|
||||
ExtraFs map[string]*FsStats `json:"efs,omitempty"`
|
||||
GPUData map[string]GPUData `json:"g,omitempty"`
|
||||
Cpu float64 `json:"cpu" cbor:"0,keyasint"`
|
||||
MaxCpu float64 `json:"cpum,omitempty" cbor:"1,keyasint,omitempty"`
|
||||
Mem float64 `json:"m" cbor:"2,keyasint"`
|
||||
MemUsed float64 `json:"mu" cbor:"3,keyasint"`
|
||||
MemPct float64 `json:"mp" cbor:"4,keyasint"`
|
||||
MemBuffCache float64 `json:"mb" cbor:"5,keyasint"`
|
||||
MemZfsArc float64 `json:"mz,omitempty" cbor:"6,keyasint,omitempty"` // ZFS ARC memory
|
||||
Swap float64 `json:"s,omitempty" cbor:"7,keyasint,omitempty"`
|
||||
SwapUsed float64 `json:"su,omitempty" cbor:"8,keyasint,omitempty"`
|
||||
DiskTotal float64 `json:"d" cbor:"9,keyasint"`
|
||||
DiskUsed float64 `json:"du" cbor:"10,keyasint"`
|
||||
DiskPct float64 `json:"dp" cbor:"11,keyasint"`
|
||||
DiskReadPs float64 `json:"dr" cbor:"12,keyasint"`
|
||||
DiskWritePs float64 `json:"dw" cbor:"13,keyasint"`
|
||||
MaxDiskReadPs float64 `json:"drm,omitempty" cbor:"14,keyasint,omitempty"`
|
||||
MaxDiskWritePs float64 `json:"dwm,omitempty" cbor:"15,keyasint,omitempty"`
|
||||
NetworkSent float64 `json:"ns" cbor:"16,keyasint"`
|
||||
NetworkRecv float64 `json:"nr" cbor:"17,keyasint"`
|
||||
MaxNetworkSent float64 `json:"nsm,omitempty" cbor:"18,keyasint,omitempty"`
|
||||
MaxNetworkRecv float64 `json:"nrm,omitempty" cbor:"19,keyasint,omitempty"`
|
||||
Temperatures map[string]float64 `json:"t,omitempty" cbor:"20,keyasint,omitempty"`
|
||||
ExtraFs map[string]*FsStats `json:"efs,omitempty" cbor:"21,keyasint,omitempty"`
|
||||
GPUData map[string]GPUData `json:"g,omitempty" cbor:"22,keyasint,omitempty"`
|
||||
}
|
||||
|
||||
type GPUData struct {
|
||||
Name string `json:"n"`
|
||||
Name string `json:"n" cbor:"0,keyasint"`
|
||||
Temperature float64 `json:"-"`
|
||||
MemoryUsed float64 `json:"mu,omitempty"`
|
||||
MemoryTotal float64 `json:"mt,omitempty"`
|
||||
Usage float64 `json:"u"`
|
||||
Power float64 `json:"p,omitempty"`
|
||||
MemoryUsed float64 `json:"mu,omitempty" cbor:"1,keyasint,omitempty"`
|
||||
MemoryTotal float64 `json:"mt,omitempty" cbor:"2,keyasint,omitempty"`
|
||||
Usage float64 `json:"u" cbor:"3,keyasint"`
|
||||
Power float64 `json:"p,omitempty" cbor:"4,keyasint,omitempty"`
|
||||
Count float64 `json:"-"`
|
||||
}
|
||||
|
||||
@@ -47,14 +47,14 @@ type FsStats struct {
|
||||
Time time.Time `json:"-"`
|
||||
Root bool `json:"-"`
|
||||
Mountpoint string `json:"-"`
|
||||
DiskTotal float64 `json:"d"`
|
||||
DiskUsed float64 `json:"du"`
|
||||
DiskTotal float64 `json:"d" cbor:"0,keyasint"`
|
||||
DiskUsed float64 `json:"du" cbor:"1,keyasint"`
|
||||
TotalRead uint64 `json:"-"`
|
||||
TotalWrite uint64 `json:"-"`
|
||||
DiskReadPs float64 `json:"r"`
|
||||
DiskWritePs float64 `json:"w"`
|
||||
MaxDiskReadPS float64 `json:"rm,omitempty"`
|
||||
MaxDiskWritePS float64 `json:"wm,omitempty"`
|
||||
DiskReadPs float64 `json:"r" cbor:"2,keyasint"`
|
||||
DiskWritePs float64 `json:"w" cbor:"3,keyasint"`
|
||||
MaxDiskReadPS float64 `json:"rm,omitempty" cbor:"4,keyasint,omitempty"`
|
||||
MaxDiskWritePS float64 `json:"wm,omitempty" cbor:"5,keyasint,omitempty"`
|
||||
}
|
||||
|
||||
type NetIoStats struct {
|
||||
@@ -64,7 +64,7 @@ type NetIoStats struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Os uint8
|
||||
type Os = uint8
|
||||
|
||||
const (
|
||||
Linux Os = iota
|
||||
@@ -74,26 +74,26 @@ const (
|
||||
)
|
||||
|
||||
type Info struct {
|
||||
Hostname string `json:"h"`
|
||||
KernelVersion string `json:"k,omitempty"`
|
||||
Cores int `json:"c"`
|
||||
Threads int `json:"t,omitempty"`
|
||||
CpuModel string `json:"m"`
|
||||
Uptime uint64 `json:"u"`
|
||||
Cpu float64 `json:"cpu"`
|
||||
MemPct float64 `json:"mp"`
|
||||
DiskPct float64 `json:"dp"`
|
||||
Bandwidth float64 `json:"b"`
|
||||
AgentVersion string `json:"v"`
|
||||
Podman bool `json:"p,omitempty"`
|
||||
GpuPct float64 `json:"g,omitempty"`
|
||||
DashboardTemp float64 `json:"dt,omitempty"`
|
||||
Os Os `json:"os"`
|
||||
Hostname string `json:"h" cbor:"0,keyasint"`
|
||||
KernelVersion string `json:"k,omitempty" cbor:"1,keyasint,omitempty"`
|
||||
Cores int `json:"c" cbor:"2,keyasint"`
|
||||
Threads int `json:"t,omitempty" cbor:"3,keyasint,omitempty"`
|
||||
CpuModel string `json:"m" cbor:"4,keyasint"`
|
||||
Uptime uint64 `json:"u" cbor:"5,keyasint"`
|
||||
Cpu float64 `json:"cpu" cbor:"6,keyasint"`
|
||||
MemPct float64 `json:"mp" cbor:"7,keyasint"`
|
||||
DiskPct float64 `json:"dp" cbor:"8,keyasint"`
|
||||
Bandwidth float64 `json:"b" cbor:"9,keyasint"`
|
||||
AgentVersion string `json:"v" cbor:"10,keyasint"`
|
||||
Podman bool `json:"p,omitempty" cbor:"11,keyasint,omitempty"`
|
||||
GpuPct float64 `json:"g,omitempty" cbor:"12,keyasint,omitempty"`
|
||||
DashboardTemp float64 `json:"dt,omitempty" cbor:"13,keyasint,omitempty"`
|
||||
Os Os `json:"os" cbor:"14,keyasint"`
|
||||
}
|
||||
|
||||
// Final data structure to return to the hub
|
||||
type CombinedData struct {
|
||||
Stats Stats `json:"stats"`
|
||||
Info Info `json:"info"`
|
||||
Containers []*container.Stats `json:"container"`
|
||||
Stats Stats `json:"stats" cbor:"0,keyasint"`
|
||||
Info Info `json:"info" cbor:"1,keyasint"`
|
||||
Containers []*container.Stats `json:"container" cbor:"2,keyasint"`
|
||||
}
|
||||
|
247
beszel/internal/hub/agent_connect.go
Normal file
247
beszel/internal/hub/agent_connect.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"beszel/internal/common"
|
||||
"beszel/internal/hub/expirymap"
|
||||
"beszel/internal/hub/ws"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/lxzan/gws"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// tokenMap maps tokens to user IDs for universal tokens
|
||||
var tokenMap *expirymap.ExpiryMap[string]
|
||||
|
||||
type agentConnectRequest struct {
|
||||
token string
|
||||
agentSemVer semver.Version
|
||||
// for universal token
|
||||
isUniversalToken bool
|
||||
userId string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
// validateAgentHeaders validates the required headers from agent connection requests.
|
||||
func (h *Hub) validateAgentHeaders(headers http.Header) (string, string, error) {
|
||||
token := headers.Get("X-Token")
|
||||
agentVersion := headers.Get("X-Beszel")
|
||||
|
||||
if agentVersion == "" || token == "" || len(token) > 512 {
|
||||
return "", "", errors.New("")
|
||||
}
|
||||
return token, agentVersion, nil
|
||||
}
|
||||
|
||||
// getFingerprintRecord retrieves fingerprint data from the database by token.
|
||||
func (h *Hub) getFingerprintRecord(token string, recordData *ws.FingerprintRecord) error {
|
||||
err := h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}").
|
||||
Bind(dbx.Params{
|
||||
"token": token,
|
||||
}).
|
||||
One(recordData)
|
||||
return err
|
||||
}
|
||||
|
||||
// sendResponseError sends an HTTP error response with the given status code and message.
|
||||
func sendResponseError(res http.ResponseWriter, code int, message string) error {
|
||||
res.WriteHeader(code)
|
||||
if message != "" {
|
||||
res.Write([]byte(message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAgentConnect handles the incoming connection request from the agent.
|
||||
func (h *Hub) handleAgentConnect(e *core.RequestEvent) error {
|
||||
if err := h.agentConnect(e.Request, e.Response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// agentConnect handles agent connection requests, validating credentials and upgrading to WebSocket.
|
||||
func (h *Hub) agentConnect(req *http.Request, res http.ResponseWriter) (err error) {
|
||||
var agentConnectRequest agentConnectRequest
|
||||
var agentVersion string
|
||||
// check if user agent and token are valid
|
||||
agentConnectRequest.token, agentVersion, err = h.validateAgentHeaders(req.Header)
|
||||
if err != nil {
|
||||
return sendResponseError(res, http.StatusUnauthorized, "")
|
||||
}
|
||||
|
||||
// Pull fingerprint from database matching token
|
||||
var fpRecord ws.FingerprintRecord
|
||||
err = h.getFingerprintRecord(agentConnectRequest.token, &fpRecord)
|
||||
|
||||
// if no existing record, check if token is a universal token
|
||||
if err != nil {
|
||||
if err = checkUniversalToken(&agentConnectRequest); err == nil {
|
||||
// if this is a universal token, set the remote address and new record token
|
||||
agentConnectRequest.remoteAddr = getRealIP(req)
|
||||
fpRecord.Token = agentConnectRequest.token
|
||||
}
|
||||
}
|
||||
|
||||
// If no matching token, return unauthorized
|
||||
if err != nil {
|
||||
return sendResponseError(res, http.StatusUnauthorized, "Invalid token")
|
||||
}
|
||||
|
||||
// Validate agent version
|
||||
agentConnectRequest.agentSemVer, err = semver.Parse(agentVersion)
|
||||
if err != nil {
|
||||
return sendResponseError(res, http.StatusUnauthorized, "Invalid agent version")
|
||||
}
|
||||
|
||||
// Upgrade connection to WebSocket
|
||||
conn, err := ws.GetUpgrader().Upgrade(res, req)
|
||||
if err != nil {
|
||||
return sendResponseError(res, http.StatusInternalServerError, "WebSocket upgrade failed")
|
||||
}
|
||||
|
||||
go h.verifyWsConn(conn, agentConnectRequest, fpRecord)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyWsConn verifies the WebSocket connection using agent's fingerprint and SSH key signature.
|
||||
func (h *Hub) verifyWsConn(conn *gws.Conn, acr agentConnectRequest, fpRecord ws.FingerprintRecord) (err error) {
|
||||
wsConn := ws.NewWsConnection(conn)
|
||||
// must be set before the read loop
|
||||
conn.Session().Store("wsConn", wsConn)
|
||||
|
||||
// make sure connection is closed if there is an error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
wsConn.Close()
|
||||
h.Logger().Error("WebSocket error", "error", err, "system", fpRecord.SystemId)
|
||||
}
|
||||
}()
|
||||
|
||||
go conn.ReadLoop()
|
||||
|
||||
signer, err := h.GetSSHKey("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentFingerprint, err := wsConn.GetFingerprint(acr.token, signer, acr.isUniversalToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create system if using universal token
|
||||
if acr.isUniversalToken {
|
||||
if acr.userId == "" {
|
||||
return errors.New("token user not found")
|
||||
}
|
||||
fpRecord.SystemId, err = h.createSystemFromAgentData(&acr, agentFingerprint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create system from universal token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
// If no current fingerprint, update with new fingerprint (first time connecting)
|
||||
case fpRecord.Fingerprint == "":
|
||||
if err := h.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
||||
return err
|
||||
}
|
||||
// Abort if fingerprint exists but doesn't match (different machine)
|
||||
case fpRecord.Fingerprint != agentFingerprint.Fingerprint:
|
||||
return errors.New("fingerprint mismatch")
|
||||
}
|
||||
|
||||
return h.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn)
|
||||
}
|
||||
|
||||
// createSystemFromAgentData creates a new system record using data from the agent
|
||||
func (h *Hub) createSystemFromAgentData(acr *agentConnectRequest, agentFingerprint common.FingerprintResponse) (recordId string, err error) {
|
||||
systemsCollection, err := h.FindCollectionByNameOrId("systems")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to find systems collection: %w", err)
|
||||
}
|
||||
// separate port from address
|
||||
if agentFingerprint.Hostname == "" {
|
||||
agentFingerprint.Hostname = acr.remoteAddr
|
||||
}
|
||||
if agentFingerprint.Port == "" {
|
||||
agentFingerprint.Port = "45876"
|
||||
}
|
||||
// create new record
|
||||
systemRecord := core.NewRecord(systemsCollection)
|
||||
systemRecord.Set("name", agentFingerprint.Hostname)
|
||||
systemRecord.Set("host", acr.remoteAddr)
|
||||
systemRecord.Set("port", agentFingerprint.Port)
|
||||
systemRecord.Set("users", []string{acr.userId})
|
||||
|
||||
return systemRecord.Id, h.Save(systemRecord)
|
||||
}
|
||||
|
||||
// SetFingerprint updates the fingerprint for a given record ID.
|
||||
func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) {
|
||||
// // can't use raw query here because it doesn't trigger SSE
|
||||
var record *core.Record
|
||||
switch fpRecord.Id {
|
||||
case "":
|
||||
// create new record for universal token
|
||||
collection, _ := h.FindCachedCollectionByNameOrId("fingerprints")
|
||||
record = core.NewRecord(collection)
|
||||
record.Set("system", fpRecord.SystemId)
|
||||
default:
|
||||
record, err = h.FindRecordById("fingerprints", fpRecord.Id)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record.Set("token", fpRecord.Token)
|
||||
record.Set("fingerprint", fingerprint)
|
||||
return h.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
func getTokenMap() *expirymap.ExpiryMap[string] {
|
||||
if tokenMap == nil {
|
||||
tokenMap = expirymap.New[string](time.Hour)
|
||||
}
|
||||
return tokenMap
|
||||
}
|
||||
|
||||
func checkUniversalToken(acr *agentConnectRequest) (err error) {
|
||||
if tokenMap == nil {
|
||||
tokenMap = expirymap.New[string](time.Hour)
|
||||
}
|
||||
acr.userId, acr.isUniversalToken = tokenMap.GetOk(acr.token)
|
||||
if !acr.isUniversalToken {
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRealIP attempts to extract the real IP address from the request headers.
|
||||
func getRealIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// X-Forwarded-For can contain a comma-separated list: "client_ip, proxy1, proxy2"
|
||||
// Take the first one
|
||||
ips := strings.Split(ip, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
// Fallback to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
1001
beszel/internal/hub/agent_connect_test.go
Normal file
1001
beszel/internal/hub/agent_connect_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
||||
package hub
|
||||
// Package config provides functions for syncing systems with the config.yml file
|
||||
package config
|
||||
|
||||
import (
|
||||
"beszel/internal/entities/system"
|
||||
@@ -7,6 +8,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
@@ -14,19 +16,20 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Systems []SystemConfig `yaml:"systems"`
|
||||
type config struct {
|
||||
Systems []systemConfig `yaml:"systems"`
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
type systemConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port,omitempty"`
|
||||
Token string `yaml:"token,omitempty"`
|
||||
Users []string `yaml:"users"`
|
||||
}
|
||||
|
||||
// Syncs systems with the config.yml file
|
||||
func syncSystemsWithConfig(e *core.ServeEvent) error {
|
||||
func SyncSystems(e *core.ServeEvent) error {
|
||||
h := e.App
|
||||
configPath := filepath.Join(h.DataDir(), "config.yml")
|
||||
configData, err := os.ReadFile(configPath)
|
||||
@@ -34,7 +37,7 @@ func syncSystemsWithConfig(e *core.ServeEvent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var config Config
|
||||
var config config
|
||||
err = yaml.Unmarshal(configData, &config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse config.yml: %v", err)
|
||||
@@ -107,6 +110,14 @@ func syncSystemsWithConfig(e *core.ServeEvent) error {
|
||||
if err := h.Save(existingSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only update token if one is specified in config, otherwise preserve existing token
|
||||
if sysConfig.Token != "" {
|
||||
if err := updateFingerprintToken(h, existingSystem.Id, sysConfig.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
delete(existingSystemsMap, key)
|
||||
} else {
|
||||
// Create new system
|
||||
@@ -124,10 +135,21 @@ func syncSystemsWithConfig(e *core.ServeEvent) error {
|
||||
if err := h.Save(newSystem); err != nil {
|
||||
return fmt.Errorf("failed to create new system: %v", err)
|
||||
}
|
||||
|
||||
// For new systems, generate token if not provided
|
||||
token := sysConfig.Token
|
||||
if token == "" {
|
||||
token = uuid.New().String()
|
||||
}
|
||||
|
||||
// Create fingerprint record for new system
|
||||
if err := createFingerprintRecord(h, newSystem.Id, token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete systems not in config
|
||||
// Delete systems not in config (and their fingerprint records will cascade delete)
|
||||
for _, system := range existingSystemsMap {
|
||||
if err := h.Delete(system); err != nil {
|
||||
return err
|
||||
@@ -139,7 +161,7 @@ func syncSystemsWithConfig(e *core.ServeEvent) error {
|
||||
}
|
||||
|
||||
// Generates content for the config.yml file as a YAML string
|
||||
func (h *Hub) generateConfigYAML() (string, error) {
|
||||
func generateYAML(h core.App) (string, error) {
|
||||
// Fetch all systems from the database
|
||||
systems, err := h.FindRecordsByFilter("systems", "id != ''", "name", -1, 0)
|
||||
if err != nil {
|
||||
@@ -147,8 +169,8 @@ func (h *Hub) generateConfigYAML() (string, error) {
|
||||
}
|
||||
|
||||
// Create a Config struct to hold the data
|
||||
config := Config{
|
||||
Systems: make([]SystemConfig, 0, len(systems)),
|
||||
config := config{
|
||||
Systems: make([]systemConfig, 0, len(systems)),
|
||||
}
|
||||
|
||||
// Fetch all users at once
|
||||
@@ -156,11 +178,29 @@ func (h *Hub) generateConfigYAML() (string, error) {
|
||||
for _, system := range systems {
|
||||
allUserIDs = append(allUserIDs, system.GetStringSlice("users")...)
|
||||
}
|
||||
userEmailMap, err := h.getUserEmailMap(allUserIDs)
|
||||
userEmailMap, err := getUserEmailMap(h, allUserIDs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Fetch all fingerprint records to get tokens
|
||||
type fingerprintData struct {
|
||||
ID string `db:"id"`
|
||||
System string `db:"system"`
|
||||
Token string `db:"token"`
|
||||
}
|
||||
var fingerprints []fingerprintData
|
||||
err = h.DB().NewQuery("SELECT id, system, token FROM fingerprints").All(&fingerprints)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a map of system ID to token
|
||||
systemTokenMap := make(map[string]string)
|
||||
for _, fingerprint := range fingerprints {
|
||||
systemTokenMap[fingerprint.System] = fingerprint.Token
|
||||
}
|
||||
|
||||
// Populate the Config struct with system data
|
||||
for _, system := range systems {
|
||||
userIDs := system.GetStringSlice("users")
|
||||
@@ -171,11 +211,12 @@ func (h *Hub) generateConfigYAML() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
sysConfig := SystemConfig{
|
||||
sysConfig := systemConfig{
|
||||
Name: system.GetString("name"),
|
||||
Host: system.GetString("host"),
|
||||
Port: cast.ToUint16(system.Get("port")),
|
||||
Users: userEmails,
|
||||
Token: systemTokenMap[system.Id],
|
||||
}
|
||||
config.Systems = append(config.Systems, sysConfig)
|
||||
}
|
||||
@@ -187,13 +228,13 @@ func (h *Hub) generateConfigYAML() (string, error) {
|
||||
}
|
||||
|
||||
// Add a header to the YAML
|
||||
yamlData = append([]byte("# Values for port and users are optional.\n# Defaults are port 45876 and the first created user.\n\n"), yamlData...)
|
||||
yamlData = append([]byte("# Values for port, users, and token are optional.\n# Defaults are port 45876, the first created user, and a generated UUID token.\n\n"), yamlData...)
|
||||
|
||||
return string(yamlData), nil
|
||||
}
|
||||
|
||||
// New helper function to get a map of user IDs to emails
|
||||
func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) {
|
||||
func getUserEmailMap(h core.App, userIDs []string) (map[string]string, error) {
|
||||
users, err := h.FindRecordsByIds("users", userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -207,13 +248,42 @@ func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) {
|
||||
return userEmailMap, nil
|
||||
}
|
||||
|
||||
// Helper function to update or create fingerprint token for an existing system
|
||||
func updateFingerprintToken(app core.App, systemID, token string) error {
|
||||
// Try to find existing fingerprint record
|
||||
fingerprint, err := app.FindFirstRecordByFilter("fingerprints", "system = {:system}", dbx.Params{"system": systemID})
|
||||
if err != nil {
|
||||
// If no fingerprint record exists, create one
|
||||
return createFingerprintRecord(app, systemID, token)
|
||||
}
|
||||
|
||||
// Update existing fingerprint record with new token (keep existing fingerprint)
|
||||
fingerprint.Set("token", token)
|
||||
return app.Save(fingerprint)
|
||||
}
|
||||
|
||||
// Helper function to create a new fingerprint record for a system
|
||||
func createFingerprintRecord(app core.App, systemID, token string) error {
|
||||
fingerprintsCollection, err := app.FindCollectionByNameOrId("fingerprints")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find fingerprints collection: %v", err)
|
||||
}
|
||||
|
||||
newFingerprint := core.NewRecord(fingerprintsCollection)
|
||||
newFingerprint.Set("system", systemID)
|
||||
newFingerprint.Set("token", token)
|
||||
newFingerprint.Set("fingerprint", "") // Empty fingerprint, will be set on first connection
|
||||
|
||||
return app.Save(newFingerprint)
|
||||
}
|
||||
|
||||
// Returns the current config.yml file as a JSON object
|
||||
func (h *Hub) getYamlConfig(e *core.RequestEvent) error {
|
||||
func GetYamlConfig(e *core.RequestEvent) error {
|
||||
info, _ := e.RequestInfo()
|
||||
if info.Auth == nil || info.Auth.GetString("role") != "admin" {
|
||||
return apis.NewForbiddenError("Forbidden", nil)
|
||||
}
|
||||
configContent, err := h.generateConfigYAML()
|
||||
configContent, err := generateYAML(e.App)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
245
beszel/internal/hub/config/config_test.go
Normal file
245
beszel/internal/hub/config/config_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package config_test
|
||||
|
||||
import (
|
||||
"beszel/internal/hub/config"
|
||||
"beszel/internal/tests"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config struct for testing (copied from config package since it's not exported)
|
||||
type testConfig struct {
|
||||
Systems []testSystemConfig `yaml:"systems"`
|
||||
}
|
||||
|
||||
type testSystemConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Host string `yaml:"host"`
|
||||
Port uint16 `yaml:"port,omitempty"`
|
||||
Users []string `yaml:"users"`
|
||||
Token string `yaml:"token,omitempty"`
|
||||
}
|
||||
|
||||
// Helper function to create a test system for config tests
|
||||
// func createConfigTestSystem(app core.App, name, host string, port uint16, userIDs []string) (*core.Record, error) {
|
||||
// systemCollection, err := app.FindCollectionByNameOrId("systems")
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// system := core.NewRecord(systemCollection)
|
||||
// system.Set("name", name)
|
||||
// system.Set("host", host)
|
||||
// system.Set("port", port)
|
||||
// system.Set("users", userIDs)
|
||||
// system.Set("status", "pending")
|
||||
|
||||
// return system, app.Save(system)
|
||||
// }
|
||||
|
||||
// Helper function to create a fingerprint record
|
||||
func createConfigTestFingerprint(app core.App, systemID, token, fingerprint string) (*core.Record, error) {
|
||||
fingerprintCollection, err := app.FindCollectionByNameOrId("fingerprints")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fp := core.NewRecord(fingerprintCollection)
|
||||
fp.Set("system", systemID)
|
||||
fp.Set("token", token)
|
||||
fp.Set("fingerprint", fingerprint)
|
||||
|
||||
return fp, app.Save(fp)
|
||||
}
|
||||
|
||||
// TestConfigSyncWithTokens tests the config.SyncSystems function with various token scenarios
|
||||
func TestConfigSyncWithTokens(t *testing.T) {
|
||||
testHub, err := tests.NewTestHub()
|
||||
require.NoError(t, err)
|
||||
defer testHub.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupFunc func() (string, *core.Record, *core.Record) // Returns: existing token, system record, fingerprint record
|
||||
configYAML string
|
||||
expectToken string // Expected token after sync
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "new system with token in config",
|
||||
setupFunc: func() (string, *core.Record, *core.Record) {
|
||||
return "", nil, nil // No existing system
|
||||
},
|
||||
configYAML: `systems:
|
||||
- name: "new-server"
|
||||
host: "new.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"
|
||||
token: "explicit-token-123"`,
|
||||
expectToken: "explicit-token-123",
|
||||
description: "New system should use token from config",
|
||||
},
|
||||
{
|
||||
name: "existing system without token in config (preserve existing)",
|
||||
setupFunc: func() (string, *core.Record, *core.Record) {
|
||||
// Create existing system and fingerprint
|
||||
system, err := tests.CreateRecord(testHub.App, "systems", map[string]any{
|
||||
"name": "preserve-server",
|
||||
"host": "preserve.example.com",
|
||||
"port": 45876,
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
fingerprint, err := createConfigTestFingerprint(testHub.App, system.Id, "preserve-token-999", "preserve-fingerprint")
|
||||
require.NoError(t, err)
|
||||
|
||||
return "preserve-token-999", system, fingerprint
|
||||
},
|
||||
configYAML: `systems:
|
||||
- name: "preserve-server"
|
||||
host: "preserve.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"`,
|
||||
expectToken: "preserve-token-999",
|
||||
description: "Existing system should preserve original token when config doesn't specify one",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Setup test data
|
||||
_, existingSystem, existingFingerprint := tc.setupFunc()
|
||||
|
||||
// Write config file
|
||||
configPath := filepath.Join(testHub.DataDir(), "config.yml")
|
||||
err := os.WriteFile(configPath, []byte(tc.configYAML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create serve event and sync
|
||||
event := &core.ServeEvent{App: testHub.App}
|
||||
err = config.SyncSystems(event)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the config to get the system name for verification
|
||||
var configData testConfig
|
||||
err = yaml.Unmarshal([]byte(tc.configYAML), &configData)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, configData.Systems, 1)
|
||||
systemName := configData.Systems[0].Name
|
||||
|
||||
// Find the system after sync
|
||||
systems, err := testHub.FindRecordsByFilter("systems", "name = {:name}", "", -1, 0, map[string]any{"name": systemName})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, systems, 1)
|
||||
system := systems[0]
|
||||
|
||||
// Find the fingerprint record
|
||||
fingerprints, err := testHub.FindRecordsByFilter("fingerprints", "system = {:system}", "", -1, 0, map[string]any{"system": system.Id})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, fingerprints, 1)
|
||||
fingerprint := fingerprints[0]
|
||||
|
||||
// Verify token
|
||||
actualToken := fingerprint.GetString("token")
|
||||
if tc.expectToken == "" {
|
||||
// For generated tokens, just verify it's not empty and is a valid UUID format
|
||||
assert.NotEmpty(t, actualToken, tc.description)
|
||||
assert.Len(t, actualToken, 36, "Generated token should be UUID format") // UUID length
|
||||
} else {
|
||||
assert.Equal(t, tc.expectToken, actualToken, tc.description)
|
||||
}
|
||||
|
||||
// For existing systems, verify fingerprint is preserved
|
||||
if existingFingerprint != nil {
|
||||
actualFingerprint := fingerprint.GetString("fingerprint")
|
||||
expectedFingerprint := existingFingerprint.GetString("fingerprint")
|
||||
assert.Equal(t, expectedFingerprint, actualFingerprint, "Fingerprint should be preserved")
|
||||
}
|
||||
|
||||
// Cleanup for next test
|
||||
if existingSystem != nil {
|
||||
testHub.Delete(existingSystem)
|
||||
}
|
||||
if existingFingerprint != nil {
|
||||
testHub.Delete(existingFingerprint)
|
||||
}
|
||||
// Clean up the new records
|
||||
testHub.Delete(system)
|
||||
testHub.Delete(fingerprint)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigMigrationScenario tests the specific migration scenario mentioned in the discussion
|
||||
func TestConfigMigrationScenario(t *testing.T) {
|
||||
testHub, err := tests.NewTestHub(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
defer testHub.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user, err := tests.CreateUser(testHub.App, "admin@example.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate migration scenario: system exists with token from migration
|
||||
existingSystem, err := tests.CreateRecord(testHub.App, "systems", map[string]any{
|
||||
"name": "migrated-server",
|
||||
"host": "migrated.example.com",
|
||||
"port": 45876,
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
migrationToken := "migration-generated-token-123"
|
||||
existingFingerprint, err := createConfigTestFingerprint(testHub.App, existingSystem.Id, migrationToken, "existing-fingerprint-from-agent")
|
||||
require.NoError(t, err)
|
||||
|
||||
// User exports config BEFORE this update (so no token field in YAML)
|
||||
oldConfigYAML := `systems:
|
||||
- name: "migrated-server"
|
||||
host: "migrated.example.com"
|
||||
port: 45876
|
||||
users:
|
||||
- "admin@example.com"`
|
||||
|
||||
// Write old config file and import
|
||||
configPath := filepath.Join(testHub.DataDir(), "config.yml")
|
||||
err = os.WriteFile(configPath, []byte(oldConfigYAML), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
event := &core.ServeEvent{App: testHub.App}
|
||||
err = config.SyncSystems(event)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the original token is preserved
|
||||
updatedFingerprint, err := testHub.FindRecordById("fingerprints", existingFingerprint.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualToken := updatedFingerprint.GetString("token")
|
||||
assert.Equal(t, migrationToken, actualToken, "Migration token should be preserved when config doesn't specify a token")
|
||||
|
||||
// Verify fingerprint is also preserved
|
||||
actualFingerprint := updatedFingerprint.GetString("fingerprint")
|
||||
assert.Equal(t, "existing-fingerprint-from-agent", actualFingerprint, "Existing fingerprint should be preserved")
|
||||
|
||||
// Verify system still exists and is updated correctly
|
||||
updatedSystem, err := testHub.FindRecordById("systems", existingSystem.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "migrated-server", updatedSystem.GetString("name"))
|
||||
assert.Equal(t, "migrated.example.com", updatedSystem.GetString("host"))
|
||||
}
|
104
beszel/internal/hub/expirymap/expirymap.go
Normal file
104
beszel/internal/hub/expirymap/expirymap.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package expirymap
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
)
|
||||
|
||||
type val[T any] struct {
|
||||
value T
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
type ExpiryMap[T any] struct {
|
||||
store *store.Store[string, *val[T]]
|
||||
cleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// New creates a new expiry map with custom cleanup interval
|
||||
func New[T any](cleanupInterval time.Duration) *ExpiryMap[T] {
|
||||
m := &ExpiryMap[T]{
|
||||
store: store.New(map[string]*val[T]{}),
|
||||
cleanupInterval: cleanupInterval,
|
||||
}
|
||||
m.startCleaner()
|
||||
return m
|
||||
}
|
||||
|
||||
// Set stores a value with the given TTL
|
||||
func (m *ExpiryMap[T]) Set(key string, value T, ttl time.Duration) {
|
||||
m.store.Set(key, &val[T]{
|
||||
value: value,
|
||||
expires: time.Now().Add(ttl),
|
||||
})
|
||||
}
|
||||
|
||||
// GetOk retrieves a value and checks if it exists and hasn't expired
|
||||
// Performs lazy cleanup of expired entries on access
|
||||
func (m *ExpiryMap[T]) GetOk(key string) (T, bool) {
|
||||
value, ok := m.store.GetOk(key)
|
||||
if !ok {
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
// Check if expired and perform lazy cleanup
|
||||
if value.expires.Before(time.Now()) {
|
||||
m.store.Remove(key)
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
return value.value, true
|
||||
}
|
||||
|
||||
// GetByValue retrieves a value by value
|
||||
func (m *ExpiryMap[T]) GetByValue(val T) (key string, value T, ok bool) {
|
||||
for key, v := range m.store.GetAll() {
|
||||
if reflect.DeepEqual(v.value, val) {
|
||||
// check if expired
|
||||
if v.expires.Before(time.Now()) {
|
||||
m.store.Remove(key)
|
||||
break
|
||||
}
|
||||
return key, v.value, true
|
||||
}
|
||||
}
|
||||
return "", *new(T), false
|
||||
}
|
||||
|
||||
// Remove explicitly removes a key
|
||||
func (m *ExpiryMap[T]) Remove(key string) {
|
||||
m.store.Remove(key)
|
||||
}
|
||||
|
||||
// RemovebyValue removes a value by value
|
||||
func (m *ExpiryMap[T]) RemovebyValue(value T) (T, bool) {
|
||||
for key, val := range m.store.GetAll() {
|
||||
if reflect.DeepEqual(val.value, value) {
|
||||
m.store.Remove(key)
|
||||
return val.value, true
|
||||
}
|
||||
}
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
// startCleaner runs the background cleanup process
|
||||
func (m *ExpiryMap[T]) startCleaner() {
|
||||
go func() {
|
||||
tick := time.Tick(m.cleanupInterval)
|
||||
for range tick {
|
||||
m.cleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanup removes all expired entries
|
||||
func (m *ExpiryMap[T]) cleanup() {
|
||||
now := time.Now()
|
||||
for key, val := range m.store.GetAll() {
|
||||
if val.expires.Before(now) {
|
||||
m.store.Remove(key)
|
||||
}
|
||||
}
|
||||
}
|
477
beszel/internal/hub/expirymap/expirymap_test.go
Normal file
477
beszel/internal/hub/expirymap/expirymap_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package expirymap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Not using the following methods but are useful for testing
|
||||
|
||||
// TESTING: Has checks if a key exists and hasn't expired
|
||||
func (m *ExpiryMap[T]) Has(key string) bool {
|
||||
_, ok := m.GetOk(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// TESTING: Get retrieves a value, returns zero value if not found or expired
|
||||
func (m *ExpiryMap[T]) Get(key string) T {
|
||||
value, _ := m.GetOk(key)
|
||||
return value
|
||||
}
|
||||
|
||||
// TESTING: Len returns the number of non-expired entries
|
||||
func (m *ExpiryMap[T]) Len() int {
|
||||
count := 0
|
||||
now := time.Now()
|
||||
for _, val := range m.store.Values() {
|
||||
if val.expires.After(now) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func TestExpiryMap_BasicOperations(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test Set and GetOk
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
value, ok := em.GetOk("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// Test Get
|
||||
value = em.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// Test Has
|
||||
assert.True(t, em.Has("key1"))
|
||||
assert.False(t, em.Has("nonexistent"))
|
||||
|
||||
// Test Remove
|
||||
em.Remove("key1")
|
||||
assert.False(t, em.Has("key1"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_Expiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set a value with very short TTL
|
||||
em.Set("shortlived", "value", time.Millisecond*10)
|
||||
|
||||
// Should exist immediately
|
||||
assert.True(t, em.Has("shortlived"))
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Should be expired and automatically cleaned up on access
|
||||
assert.False(t, em.Has("shortlived"))
|
||||
value, ok := em.GetOk("shortlived")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_LazyCleanup(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
// Set multiple values with short TTL
|
||||
em.Set("key1", 1, time.Millisecond*10)
|
||||
em.Set("key2", 2, time.Millisecond*10)
|
||||
em.Set("key3", 3, time.Hour) // This one won't expire
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Access expired keys should trigger lazy cleanup
|
||||
_, ok := em.GetOk("key1")
|
||||
assert.False(t, ok)
|
||||
|
||||
// Non-expired key should still exist
|
||||
value, ok := em.GetOk("key3")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 3, value)
|
||||
}
|
||||
|
||||
func TestExpiryMap_Len(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Initially empty
|
||||
assert.Equal(t, 0, em.Len())
|
||||
|
||||
// Add some values
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value3", time.Millisecond*10) // Will expire soon
|
||||
|
||||
// Should count all initially
|
||||
assert.Equal(t, 3, em.Len())
|
||||
|
||||
// Wait for one to expire
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Len should reflect only non-expired entries
|
||||
assert.Equal(t, 2, em.Len())
|
||||
}
|
||||
|
||||
func TestExpiryMap_CustomInterval(t *testing.T) {
|
||||
// Create with very short cleanup interval for testing
|
||||
em := New[string](time.Millisecond * 50)
|
||||
|
||||
// Set a value that expires quickly
|
||||
em.Set("test", "value", time.Millisecond*10)
|
||||
|
||||
// Should exist initially
|
||||
assert.True(t, em.Has("test"))
|
||||
|
||||
// Wait for expiration + cleanup cycle
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
// Should be cleaned up by background process
|
||||
// Note: This test might be flaky due to timing, but demonstrates the concept
|
||||
assert.False(t, em.Has("test"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_GenericTypes(t *testing.T) {
|
||||
// Test with different types
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num", 42, time.Hour)
|
||||
value, ok := em.GetOk("num")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, value)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
expected := TestStruct{Name: "John", Age: 30}
|
||||
em.Set("person", expected, time.Hour)
|
||||
|
||||
value, ok := em.GetOk("person")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, expected, value)
|
||||
})
|
||||
|
||||
t.Run("Pointer", func(t *testing.T) {
|
||||
em := New[*string](time.Hour)
|
||||
|
||||
str := "hello"
|
||||
em.Set("ptr", &str, time.Hour)
|
||||
|
||||
value, ok := em.GetOk("ptr")
|
||||
assert.True(t, ok)
|
||||
require.NotNil(t, value)
|
||||
assert.Equal(t, "hello", *value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_ZeroValues(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test getting non-existent key returns zero value
|
||||
value := em.Get("nonexistent")
|
||||
assert.Equal(t, "", value)
|
||||
|
||||
// Test getting expired key returns zero value
|
||||
em.Set("expired", "value", time.Millisecond*10)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
value = em.Get("expired")
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestExpiryMap_Concurrent(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
// Simple concurrent access test
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Writer goroutine
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
em.Set("key", i, time.Hour)
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Reader goroutine
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = em.Get("key")
|
||||
time.Sleep(time.Microsecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for both to complete
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// Should not panic and should have some value
|
||||
assert.True(t, em.Has("key"))
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test getting by value when value exists
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value1", time.Hour) // Duplicate value - should return first match
|
||||
|
||||
// Test successful retrieval
|
||||
key, value, ok := em.GetByValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key) // Should be one of the keys with this value
|
||||
|
||||
// Test retrieval of unique value
|
||||
key, value, ok = em.GetByValue("value2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value2", value)
|
||||
assert.Equal(t, "key2", key)
|
||||
|
||||
// Test getting non-existent value
|
||||
key, value, ok = em.GetByValue("nonexistent")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value) // zero value for string
|
||||
assert.Equal(t, "", key) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue_Expiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set a value with short TTL
|
||||
em.Set("shortkey", "shortvalue", time.Millisecond*10)
|
||||
em.Set("longkey", "longvalue", time.Hour)
|
||||
|
||||
// Should find the short-lived value initially
|
||||
key, value, ok := em.GetByValue("shortvalue")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shortvalue", value)
|
||||
assert.Equal(t, "shortkey", key)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Should not find expired value and should trigger lazy cleanup
|
||||
key, value, ok = em.GetByValue("shortvalue")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value)
|
||||
assert.Equal(t, "", key)
|
||||
|
||||
// Should still find non-expired value
|
||||
key, value, ok = em.GetByValue("longvalue")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "longvalue", value)
|
||||
assert.Equal(t, "longkey", key)
|
||||
}
|
||||
|
||||
func TestExpiryMap_GetByValue_GenericTypes(t *testing.T) {
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num1", 42, time.Hour)
|
||||
em.Set("num2", 84, time.Hour)
|
||||
|
||||
key, value, ok := em.GetByValue(42)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, value)
|
||||
assert.Equal(t, "num1", key)
|
||||
|
||||
key, value, ok = em.GetByValue(99)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, value)
|
||||
assert.Equal(t, "", key)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
person1 := TestStruct{Name: "John", Age: 30}
|
||||
person2 := TestStruct{Name: "Jane", Age: 25}
|
||||
|
||||
em.Set("person1", person1, time.Hour)
|
||||
em.Set("person2", person2, time.Hour)
|
||||
|
||||
key, value, ok := em.GetByValue(person1)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, person1, value)
|
||||
assert.Equal(t, "person1", key)
|
||||
|
||||
nonexistent := TestStruct{Name: "Bob", Age: 40}
|
||||
key, value, ok = em.GetByValue(nonexistent)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, TestStruct{}, value)
|
||||
assert.Equal(t, "", key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test removing existing value
|
||||
em.Set("key1", "value1", time.Hour)
|
||||
em.Set("key2", "value2", time.Hour)
|
||||
em.Set("key3", "value1", time.Hour) // Duplicate value
|
||||
|
||||
// Remove by value should remove one instance
|
||||
removedValue, ok := em.RemovebyValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", removedValue)
|
||||
|
||||
// Should still have the other instance or value2
|
||||
assert.True(t, em.Has("key2")) // value2 should still exist
|
||||
|
||||
// Check if one of the duplicate values was removed
|
||||
// At least one key with "value1" should be gone
|
||||
key1Exists := em.Has("key1")
|
||||
key3Exists := em.Has("key3")
|
||||
assert.False(t, key1Exists && key3Exists) // Both shouldn't exist
|
||||
assert.True(t, key1Exists || key3Exists) // At least one should be gone
|
||||
|
||||
// Test removing non-existent value
|
||||
removedValue, ok = em.RemovebyValue("nonexistent")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", removedValue) // zero value for string
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue_GenericTypes(t *testing.T) {
|
||||
t.Run("Int", func(t *testing.T) {
|
||||
em := New[int](time.Hour)
|
||||
|
||||
em.Set("num1", 42, time.Hour)
|
||||
em.Set("num2", 84, time.Hour)
|
||||
|
||||
// Remove existing value
|
||||
removedValue, ok := em.RemovebyValue(42)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, removedValue)
|
||||
assert.False(t, em.Has("num1"))
|
||||
assert.True(t, em.Has("num2"))
|
||||
|
||||
// Remove non-existent value
|
||||
removedValue, ok = em.RemovebyValue(99)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, removedValue)
|
||||
})
|
||||
|
||||
t.Run("Struct", func(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
em := New[TestStruct](time.Hour)
|
||||
|
||||
person1 := TestStruct{Name: "John", Age: 30}
|
||||
person2 := TestStruct{Name: "Jane", Age: 25}
|
||||
|
||||
em.Set("person1", person1, time.Hour)
|
||||
em.Set("person2", person2, time.Hour)
|
||||
|
||||
// Remove existing struct
|
||||
removedValue, ok := em.RemovebyValue(person1)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, person1, removedValue)
|
||||
assert.False(t, em.Has("person1"))
|
||||
assert.True(t, em.Has("person2"))
|
||||
|
||||
// Remove non-existent struct
|
||||
nonexistent := TestStruct{Name: "Bob", Age: 40}
|
||||
removedValue, ok = em.RemovebyValue(nonexistent)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, TestStruct{}, removedValue)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiryMap_RemoveValue_WithExpiration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Set values with different TTLs
|
||||
em.Set("key1", "value1", time.Millisecond*10) // Will expire
|
||||
em.Set("key2", "value2", time.Hour) // Won't expire
|
||||
em.Set("key3", "value1", time.Hour) // Won't expire, duplicate value
|
||||
|
||||
// Wait for first value to expire
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
// Try to remove the expired value - should remove one of the "value1" entries
|
||||
removedValue, ok := em.RemovebyValue("value1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", removedValue)
|
||||
|
||||
// Should still have key2 (different value)
|
||||
assert.True(t, em.Has("key2"))
|
||||
|
||||
// Should have removed one of the "value1" entries (either key1 or key3)
|
||||
// But we can't predict which one due to map iteration order
|
||||
key1Exists := em.Has("key1")
|
||||
key3Exists := em.Has("key3")
|
||||
|
||||
// Exactly one of key1 or key3 should be gone
|
||||
assert.False(t, key1Exists && key3Exists) // Both shouldn't exist
|
||||
assert.True(t, key1Exists || key3Exists) // At least one should still exist
|
||||
}
|
||||
|
||||
func TestExpiryMap_ValueOperations_Integration(t *testing.T) {
|
||||
em := New[string](time.Hour)
|
||||
|
||||
// Test integration of GetByValue and RemoveValue
|
||||
em.Set("key1", "shared", time.Hour)
|
||||
em.Set("key2", "unique", time.Hour)
|
||||
em.Set("key3", "shared", time.Hour)
|
||||
|
||||
// Find shared value
|
||||
key, value, ok := em.GetByValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key)
|
||||
|
||||
// Remove shared value
|
||||
removedValue, ok := em.RemovebyValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", removedValue)
|
||||
|
||||
// Should still be able to find the other shared value
|
||||
key, value, ok = em.GetByValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", value)
|
||||
assert.Contains(t, []string{"key1", "key3"}, key)
|
||||
|
||||
// Remove the other shared value
|
||||
removedValue, ok = em.RemovebyValue("shared")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "shared", removedValue)
|
||||
|
||||
// Should not find shared value anymore
|
||||
key, value, ok = em.GetByValue("shared")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value)
|
||||
assert.Equal(t, "", key)
|
||||
|
||||
// Unique value should still exist
|
||||
key, value, ok = em.GetByValue("unique")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "unique", value)
|
||||
assert.Equal(t, "key2", key)
|
||||
}
|
@@ -4,6 +4,7 @@ package hub
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/alerts"
|
||||
"beszel/internal/hub/config"
|
||||
"beszel/internal/hub/systems"
|
||||
"beszel/internal/records"
|
||||
"beszel/internal/users"
|
||||
@@ -18,7 +19,9 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/pocketbase/pocketbase/apis"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
@@ -32,6 +35,7 @@ type Hub struct {
|
||||
rm *records.RecordManager
|
||||
sm *systems.SystemManager
|
||||
pubKey string
|
||||
signer ssh.Signer
|
||||
appURL string
|
||||
}
|
||||
|
||||
@@ -64,7 +68,7 @@ func (h *Hub) StartHub() error {
|
||||
return err
|
||||
}
|
||||
// sync systems with config
|
||||
if err := syncSystemsWithConfig(e); err != nil {
|
||||
if err := config.SyncSystems(e); err != nil {
|
||||
return err
|
||||
}
|
||||
// register api routes
|
||||
@@ -112,6 +116,9 @@ func (h *Hub) initialize(e *core.ServeEvent) error {
|
||||
if h.appURL != "" {
|
||||
settings.Meta.AppURL = h.appURL
|
||||
}
|
||||
if err := e.App.Save(settings); err != nil {
|
||||
return err
|
||||
}
|
||||
// set auth settings
|
||||
usersCollection, err := e.App.FindCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
@@ -181,6 +188,7 @@ func (h *Hub) startServer(se *core.ServeEvent) error {
|
||||
indexFile, _ := fs.ReadFile(site.DistDirFS, "index.html")
|
||||
indexContent := strings.ReplaceAll(string(indexFile), "./", basePath)
|
||||
indexContent = strings.Replace(indexContent, "{{V}}", beszel.Version, 1)
|
||||
indexContent = strings.Replace(indexContent, "{{HUB_URL}}", h.appURL, 1)
|
||||
// set up static asset serving
|
||||
staticPaths := [2]string{"/static/", "/assets/"}
|
||||
serveStatic := apis.Static(site.DistDirFS, false)
|
||||
@@ -232,7 +240,11 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error {
|
||||
// send test notification
|
||||
se.Router.GET("/api/beszel/send-test-notification", h.SendTestNotification)
|
||||
// API endpoint to get config.yml content
|
||||
se.Router.GET("/api/beszel/config-yaml", h.getYamlConfig)
|
||||
se.Router.GET("/api/beszel/config-yaml", config.GetYamlConfig)
|
||||
// handle agent websocket connection
|
||||
se.Router.GET("/api/beszel/agent-connect", h.handleAgentConnect)
|
||||
// get or create universal tokens
|
||||
se.Router.GET("/api/beszel/universal-token", h.getUniversalToken)
|
||||
// create first user endpoint only needed if no users exist
|
||||
if totalUsers, _ := h.CountRecords("users"); totalUsers == 0 {
|
||||
se.Router.POST("/api/beszel/create-user", h.um.CreateFirstUser)
|
||||
@@ -240,8 +252,49 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handler for universal token API endpoint (create, read, delete)
|
||||
func (h *Hub) getUniversalToken(e *core.RequestEvent) error {
|
||||
info, err := e.RequestInfo()
|
||||
if err != nil || info.Auth == nil {
|
||||
return apis.NewForbiddenError("Forbidden", nil)
|
||||
}
|
||||
|
||||
tokenMap := getTokenMap()
|
||||
userID := info.Auth.Id
|
||||
query := e.Request.URL.Query()
|
||||
token := query.Get("token")
|
||||
tokenSet := token != ""
|
||||
|
||||
if !tokenSet {
|
||||
// return existing token if it exists
|
||||
if token, _, ok := tokenMap.GetByValue(userID); ok {
|
||||
return e.JSON(http.StatusOK, map[string]any{"token": token, "active": true})
|
||||
}
|
||||
// if no token is provided, generate a new one
|
||||
token = uuid.New().String()
|
||||
}
|
||||
response := map[string]any{"token": token}
|
||||
|
||||
switch query.Get("enable") {
|
||||
case "1":
|
||||
tokenMap.Set(token, userID, time.Hour)
|
||||
case "0":
|
||||
tokenMap.RemovebyValue(userID)
|
||||
}
|
||||
_, response["active"] = tokenMap.GetOk(token)
|
||||
return e.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// generates key pair if it doesn't exist and returns signer
|
||||
func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) {
|
||||
if h.signer != nil {
|
||||
return h.signer, nil
|
||||
}
|
||||
|
||||
if dataDir == "" {
|
||||
dataDir = h.DataDir()
|
||||
}
|
||||
|
||||
privateKeyPath := path.Join(dataDir, "id_ed25519")
|
||||
|
||||
// check if the key pair already exists
|
||||
@@ -260,12 +313,10 @@ func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) {
|
||||
}
|
||||
|
||||
// Generate the Ed25519 key pair
|
||||
pubKey, privKey, err := ed25519.GenerateKey(nil)
|
||||
_, privKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the private key in OpenSSH format
|
||||
privKeyPem, err := ssh.MarshalPrivateKey(privKey, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -276,13 +327,11 @@ func (h *Hub) GetSSHKey(dataDir string) (ssh.Signer, error) {
|
||||
}
|
||||
|
||||
// These are fine to ignore the errors on, as we've literally just created a crypto.PublicKey | crypto.Signer
|
||||
sshPubKey, _ := ssh.NewPublicKey(pubKey)
|
||||
sshPrivate, _ := ssh.NewSignerFromSigner(privKey)
|
||||
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPrivate.PublicKey())
|
||||
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
|
||||
|
||||
h.Logger().Info("ed25519 SSH key pair generated successfully.")
|
||||
h.Logger().Info("ed25519 key pair generated successfully.")
|
||||
h.Logger().Info("Saved to: " + privateKeyPath)
|
||||
|
||||
return sshPrivate, err
|
||||
|
@@ -1,9 +1,10 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package hub
|
||||
package hub_test
|
||||
|
||||
import (
|
||||
"beszel/internal/tests"
|
||||
"testing"
|
||||
|
||||
"crypto/ed25519"
|
||||
@@ -12,20 +13,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func getTestHub() *Hub {
|
||||
app := pocketbase.New()
|
||||
return NewHub(app)
|
||||
func getTestHub(t testing.TB) *tests.TestHub {
|
||||
hub, _ := tests.NewTestHub(t.TempDir())
|
||||
return hub
|
||||
}
|
||||
|
||||
func TestMakeLink(t *testing.T) {
|
||||
hub := getTestHub()
|
||||
hub := getTestHub(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -115,14 +114,14 @@ func TestMakeLink(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetSSHKey(t *testing.T) {
|
||||
hub := getTestHub()
|
||||
hub := getTestHub(t)
|
||||
|
||||
// Test Case 1: Key generation (no existing key)
|
||||
t.Run("KeyGeneration", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Ensure pubKey is initially empty or different to ensure GetSSHKey sets it
|
||||
hub.pubKey = ""
|
||||
hub.SetPubkey("")
|
||||
|
||||
signer, err := hub.GetSSHKey(tempDir)
|
||||
assert.NoError(t, err, "GetSSHKey should not error when generating a new key")
|
||||
@@ -135,8 +134,8 @@ func TestGetSSHKey(t *testing.T) {
|
||||
assert.False(t, info.IsDir(), "Private key path should be a file, not a directory")
|
||||
|
||||
// Check if h.pubKey was set
|
||||
assert.NotEmpty(t, hub.pubKey, "h.pubKey should be set after key generation")
|
||||
assert.True(t, strings.HasPrefix(hub.pubKey, "ssh-ed25519 "), "h.pubKey should start with 'ssh-ed25519 '")
|
||||
assert.NotEmpty(t, hub.GetPubkey(), "h.pubKey should be set after key generation")
|
||||
assert.True(t, strings.HasPrefix(hub.GetPubkey(), "ssh-ed25519 "), "h.pubKey should start with 'ssh-ed25519 '")
|
||||
|
||||
// Verify the generated private key is parsable
|
||||
keyData, err := os.ReadFile(privateKeyPath)
|
||||
@@ -170,14 +169,14 @@ func TestGetSSHKey(t *testing.T) {
|
||||
expectedPubKeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(sshPubKey)))
|
||||
|
||||
// Reset h.pubKey to ensure it's set by GetSSHKey from the file
|
||||
hub.pubKey = ""
|
||||
hub.SetPubkey("")
|
||||
|
||||
signer, err := hub.GetSSHKey(tempDir)
|
||||
assert.NoError(t, err, "GetSSHKey should not error when reading an existing key")
|
||||
assert.NotNil(t, signer, "GetSSHKey should return a non-nil signer for an existing key")
|
||||
|
||||
// Check if h.pubKey was set correctly to the public key from the file
|
||||
assert.Equal(t, expectedPubKeyStr, hub.pubKey, "h.pubKey should match the existing public key")
|
||||
assert.Equal(t, expectedPubKeyStr, hub.GetPubkey(), "h.pubKey should match the existing public key")
|
||||
|
||||
// Verify the signer's public key matches the original public key
|
||||
signerPubKey := signer.PublicKey()
|
||||
@@ -241,7 +240,7 @@ func TestGetSSHKey(t *testing.T) {
|
||||
require.NoError(t, err, "Setup failed")
|
||||
|
||||
// Reset h.pubKey before each test case
|
||||
hub.pubKey = ""
|
||||
hub.SetPubkey("")
|
||||
|
||||
// Attempt to get SSH key
|
||||
_, err = hub.GetSSHKey(tempDir)
|
||||
@@ -250,8 +249,10 @@ func TestGetSSHKey(t *testing.T) {
|
||||
tc.errorCheck(t, err)
|
||||
|
||||
// Check that pubKey was not set in error cases
|
||||
assert.Empty(t, hub.pubKey, "h.pubKey should not be set if there was an error")
|
||||
assert.Empty(t, hub.GetPubkey(), "h.pubKey should not be set if there was an error")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to create test records
|
||||
|
21
beszel/internal/hub/hub_test_helpers.go
Normal file
21
beszel/internal/hub/hub_test_helpers.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
package hub
|
||||
|
||||
import "beszel/internal/hub/systems"
|
||||
|
||||
// TESTING ONLY: GetSystemManager returns the system manager
|
||||
func (h *Hub) GetSystemManager() *systems.SystemManager {
|
||||
return h.sm
|
||||
}
|
||||
|
||||
// TESTING ONLY: GetPubkey returns the public key
|
||||
func (h *Hub) GetPubkey() string {
|
||||
return h.pubKey
|
||||
}
|
||||
|
||||
// TESTING ONLY: SetPubkey sets the public key
|
||||
func (h *Hub) SetPubkey(pubkey string) {
|
||||
h.pubKey = pubkey
|
||||
}
|
387
beszel/internal/hub/systems/system.go
Normal file
387
beszel/internal/hub/systems/system.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/entities/system"
|
||||
"beszel/internal/hub/ws"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type System struct {
|
||||
Id string `db:"id"`
|
||||
Host string `db:"host"`
|
||||
Port string `db:"port"`
|
||||
Status string `db:"status"`
|
||||
manager *SystemManager // Manager that this system belongs to
|
||||
client *ssh.Client // SSH client for fetching data
|
||||
data *system.CombinedData // system data from agent
|
||||
ctx context.Context // Context for stopping the updater
|
||||
cancel context.CancelFunc // Stops and removes system from updater
|
||||
WsConn *ws.WsConn // Handler for agent WebSocket connection
|
||||
agentVersion semver.Version // Agent version
|
||||
updateTicker *time.Ticker // Ticker for updating the system
|
||||
}
|
||||
|
||||
func (sm *SystemManager) NewSystem(systemId string) *System {
|
||||
system := &System{
|
||||
Id: systemId,
|
||||
data: &system.CombinedData{},
|
||||
}
|
||||
system.ctx, system.cancel = system.getContext()
|
||||
return system
|
||||
}
|
||||
|
||||
// StartUpdater starts the system updater.
|
||||
// It first fetches the data from the agent then updates the records.
|
||||
// If the data is not found or the system is down, it sets the system down.
|
||||
func (sys *System) StartUpdater() {
|
||||
// Channel that can be used to set the system down. Currently only used to
|
||||
// allow a short delay for reconnection after websocket connection is closed.
|
||||
var downChan chan struct{}
|
||||
|
||||
// Add random jitter to first WebSocket connection to prevent
|
||||
// clustering if all agents are started at the same time.
|
||||
// SSH connections during hub startup are already staggered.
|
||||
var jitter <-chan time.Time
|
||||
if sys.WsConn != nil {
|
||||
jitter = getJitter()
|
||||
// use the websocket connection's down channel to set the system down
|
||||
downChan = sys.WsConn.DownChan
|
||||
} else {
|
||||
// if the system does not have a websocket connection, wait before updating
|
||||
// to allow the agent to connect via websocket (makes sure fingerprint is set).
|
||||
time.Sleep(11 * time.Second)
|
||||
}
|
||||
|
||||
// update immediately if system is not paused (only for ws connections)
|
||||
// we'll wait a minute before connecting via SSH to prioritize ws connections
|
||||
if sys.Status != paused && sys.ctx.Err() == nil {
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
}
|
||||
|
||||
sys.updateTicker = time.NewTicker(time.Duration(interval) * time.Millisecond)
|
||||
// Go 1.23+ will automatically stop the ticker when the system is garbage collected, however we seem to need this or testing/synctest will block even if calling runtime.GC()
|
||||
defer sys.updateTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sys.ctx.Done():
|
||||
return
|
||||
case <-sys.updateTicker.C:
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
case <-downChan:
|
||||
sys.WsConn = nil
|
||||
downChan = nil
|
||||
_ = sys.setDown(nil)
|
||||
case <-jitter:
|
||||
sys.updateTicker.Reset(time.Duration(interval) * time.Millisecond)
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the system data and records.
|
||||
func (sys *System) update() error {
|
||||
if sys.Status == paused {
|
||||
sys.handlePaused()
|
||||
return nil
|
||||
}
|
||||
data, err := sys.fetchDataFromAgent()
|
||||
if err == nil {
|
||||
_, err = sys.createRecords(data)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (sys *System) handlePaused() {
|
||||
if sys.WsConn == nil {
|
||||
// if the system is paused and there's no websocket connection, remove the system
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
} else {
|
||||
// Send a ping to the agent to keep the connection alive if the system is paused
|
||||
if err := sys.WsConn.Ping(); err != nil {
|
||||
sys.manager.hub.Logger().Warn("Failed to ping agent", "system", sys.Id, "err", err)
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createRecords updates the system record and adds system_stats and container_stats records
|
||||
func (sys *System) createRecords(data *system.CombinedData) (*core.Record, error) {
|
||||
systemRecord, err := sys.getRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hub := sys.manager.hub
|
||||
// add system_stats and container_stats records
|
||||
systemStatsCollection, err := hub.FindCachedCollectionByNameOrId("system_stats")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
systemStatsRecord := core.NewRecord(systemStatsCollection)
|
||||
systemStatsRecord.Set("system", systemRecord.Id)
|
||||
systemStatsRecord.Set("stats", data.Stats)
|
||||
systemStatsRecord.Set("type", "1m")
|
||||
if err := hub.SaveNoValidate(systemStatsRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// add new container_stats record
|
||||
if len(data.Containers) > 0 {
|
||||
containerStatsCollection, err := hub.FindCachedCollectionByNameOrId("container_stats")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
containerStatsRecord := core.NewRecord(containerStatsCollection)
|
||||
containerStatsRecord.Set("system", systemRecord.Id)
|
||||
containerStatsRecord.Set("stats", data.Containers)
|
||||
containerStatsRecord.Set("type", "1m")
|
||||
if err := hub.SaveNoValidate(containerStatsRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// update system record (do this last because it triggers alerts and we need above records to be inserted first)
|
||||
systemRecord.Set("status", up)
|
||||
|
||||
systemRecord.Set("info", data.Info)
|
||||
if err := hub.SaveNoValidate(systemRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return systemRecord, nil
|
||||
}
|
||||
|
||||
// getRecord retrieves the system record from the database.
|
||||
// If the record is not found, it removes the system from the manager.
|
||||
func (sys *System) getRecord() (*core.Record, error) {
|
||||
record, err := sys.manager.hub.FindRecordById("systems", sys.Id)
|
||||
if err != nil || record == nil {
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// setDown marks a system as down in the database.
|
||||
// It takes the original error that caused the system to go down and returns any error
|
||||
// encountered during the process of updating the system status.
|
||||
func (sys *System) setDown(originalError error) error {
|
||||
if sys.Status == down || sys.Status == paused {
|
||||
return nil
|
||||
}
|
||||
record, err := sys.getRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if originalError != nil {
|
||||
sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", originalError)
|
||||
}
|
||||
record.Set("status", down)
|
||||
return sys.manager.hub.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
func (sys *System) getContext() (context.Context, context.CancelFunc) {
|
||||
if sys.ctx == nil {
|
||||
sys.ctx, sys.cancel = context.WithCancel(context.Background())
|
||||
}
|
||||
return sys.ctx, sys.cancel
|
||||
}
|
||||
|
||||
// fetchDataFromAgent attempts to fetch data from the agent,
|
||||
// prioritizing WebSocket if available.
|
||||
func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) {
|
||||
if sys.data == nil {
|
||||
sys.data = &system.CombinedData{}
|
||||
}
|
||||
|
||||
if sys.WsConn != nil && sys.WsConn.IsConnected() {
|
||||
wsData, err := sys.fetchDataViaWebSocket()
|
||||
if err == nil {
|
||||
return wsData, nil
|
||||
}
|
||||
// close the WebSocket connection if error and try SSH
|
||||
sys.closeWebSocketConnection()
|
||||
}
|
||||
|
||||
sshData, err := sys.fetchDataViaSSH()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshData, nil
|
||||
}
|
||||
|
||||
func (sys *System) fetchDataViaWebSocket() (*system.CombinedData, error) {
|
||||
if sys.WsConn == nil || !sys.WsConn.IsConnected() {
|
||||
return nil, errors.New("no websocket connection")
|
||||
}
|
||||
err := sys.WsConn.RequestSystemData(sys.data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sys.data, nil
|
||||
}
|
||||
|
||||
// fetchDataViaSSH handles fetching data using SSH.
|
||||
// This function encapsulates the original SSH logic.
|
||||
// It updates sys.data directly upon successful fetch.
|
||||
func (sys *System) fetchDataViaSSH() (*system.CombinedData, error) {
|
||||
maxRetries := 1
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if sys.client == nil || sys.Status == down {
|
||||
if err := sys.createSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
session, err := sys.createSessionWithTimeout(4 * time.Second)
|
||||
if err != nil {
|
||||
if attempt >= maxRetries {
|
||||
return nil, err
|
||||
}
|
||||
sys.manager.hub.Logger().Warn("Session closed. Retrying...", "host", sys.Host, "port", sys.Port, "err", err)
|
||||
sys.closeSSHConnection()
|
||||
// Reset format detection on connection failure - agent might have been upgraded
|
||||
continue
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := session.Shell(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
*sys.data = system.CombinedData{}
|
||||
|
||||
if sys.agentVersion.GTE(beszel.MinVersionCbor) {
|
||||
err = cbor.NewDecoder(stdout).Decode(sys.data)
|
||||
} else {
|
||||
err = json.NewDecoder(stdout).Decode(sys.data)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
sys.closeSSHConnection()
|
||||
if attempt < maxRetries {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// wait for the session to complete
|
||||
if err := session.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sys.data, nil
|
||||
}
|
||||
|
||||
// this should never be reached due to the return in the loop
|
||||
return nil, fmt.Errorf("failed to fetch data")
|
||||
}
|
||||
|
||||
// createSSHClient creates a new SSH client for the system
|
||||
func (s *System) createSSHClient() error {
|
||||
if s.manager.sshConfig == nil {
|
||||
if err := s.manager.createSSHClientConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
network := "tcp"
|
||||
host := s.Host
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
} else {
|
||||
host = net.JoinHostPort(host, s.Port)
|
||||
}
|
||||
var err error
|
||||
s.client, err = ssh.Dial(network, host, s.manager.sshConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.agentVersion, _ = extractAgentVersion(string(s.client.Conn.ServerVersion()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging
|
||||
// in case of network issues
|
||||
func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) {
|
||||
if sys.client == nil {
|
||||
return nil, fmt.Errorf("client not initialized")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(sys.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
sessionChan := make(chan *ssh.Session, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if session, err := sys.client.NewSession(); err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
sessionChan <- session
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case session := <-sessionChan:
|
||||
return session, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// closeSSHConnection closes the SSH connection but keeps the system in the manager
|
||||
func (sys *System) closeSSHConnection() {
|
||||
if sys.client != nil {
|
||||
sys.client.Close()
|
||||
sys.client = nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeWebSocketConnection closes the WebSocket connection but keeps the system in the manager
|
||||
// to allow updating via SSH. It will be removed if the WS connection is re-established.
|
||||
// The system will be set as down a few seconds later if the connection is not re-established.
|
||||
func (sys *System) closeWebSocketConnection() {
|
||||
if sys.WsConn != nil {
|
||||
sys.WsConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// extractAgentVersion extracts the beszel version from SSH server version string
|
||||
func extractAgentVersion(versionString string) (semver.Version, error) {
|
||||
_, after, _ := strings.Cut(versionString, "_")
|
||||
return semver.Parse(after)
|
||||
}
|
||||
|
||||
// getJitter returns a channel that will be triggered after a random delay
|
||||
// between 40% and 90% of the interval.
|
||||
// This is used to stagger the initial WebSocket connections to prevent clustering.
|
||||
func getJitter() <-chan time.Time {
|
||||
minPercent := 40
|
||||
maxPercent := 90
|
||||
jitterRange := maxPercent - minPercent
|
||||
msDelay := (interval * minPercent / 100) + rand.Intn(interval*jitterRange/100)
|
||||
return time.After(time.Duration(msDelay) * time.Millisecond)
|
||||
}
|
345
beszel/internal/hub/systems/system_manager.go
Normal file
345
beszel/internal/hub/systems/system_manager.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"beszel"
|
||||
"beszel/internal/common"
|
||||
"beszel/internal/entities/system"
|
||||
"beszel/internal/hub/ws"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// System status constants
|
||||
const (
|
||||
up string = "up" // System is online and responding
|
||||
down string = "down" // System is offline or not responding
|
||||
paused string = "paused" // System monitoring is paused
|
||||
pending string = "pending" // System is waiting on initial connection result
|
||||
|
||||
// interval is the default update interval in milliseconds (60 seconds)
|
||||
interval int = 60_000
|
||||
// interval int = 10_000 // Debug interval for faster updates
|
||||
|
||||
// sessionTimeout is the maximum time to wait for SSH connections
|
||||
sessionTimeout = 4 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// errSystemExists is returned when attempting to add a system that already exists
|
||||
errSystemExists = errors.New("system exists")
|
||||
)
|
||||
|
||||
// SystemManager manages a collection of monitored systems and their connections.
|
||||
// It handles system lifecycle, status updates, and maintains both SSH and WebSocket connections.
|
||||
type SystemManager struct {
|
||||
hub hubLike // Hub interface for database and alert operations
|
||||
systems *store.Store[string, *System] // Thread-safe store of active systems
|
||||
sshConfig *ssh.ClientConfig // SSH client configuration for system connections
|
||||
}
|
||||
|
||||
// hubLike defines the interface requirements for the hub dependency.
|
||||
// It extends core.App with system-specific functionality.
|
||||
type hubLike interface {
|
||||
core.App
|
||||
GetSSHKey(dataDir string) (ssh.Signer, error)
|
||||
HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error
|
||||
HandleStatusAlerts(status string, systemRecord *core.Record) error
|
||||
}
|
||||
|
||||
// NewSystemManager creates a new SystemManager instance with the provided hub.
|
||||
// The hub must implement the hubLike interface to provide database and alert functionality.
|
||||
func NewSystemManager(hub hubLike) *SystemManager {
|
||||
return &SystemManager{
|
||||
systems: store.New(map[string]*System{}),
|
||||
hub: hub,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize sets up the system manager by binding event hooks and starting existing systems.
|
||||
// It configures SSH client settings and begins monitoring all non-paused systems from the database.
|
||||
// Systems are started with staggered delays to prevent overwhelming the hub during startup.
|
||||
func (sm *SystemManager) Initialize() error {
|
||||
sm.bindEventHooks()
|
||||
|
||||
// Initialize SSH client configuration
|
||||
err := sm.createSSHClientConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load existing systems from database (excluding paused ones)
|
||||
var systems []*System
|
||||
err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems)
|
||||
if err != nil || len(systems) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start systems in background with staggered timing
|
||||
go func() {
|
||||
// Calculate staggered delay between system starts (max 2 seconds per system)
|
||||
delta := interval / max(1, len(systems))
|
||||
delta = min(delta, 2_000)
|
||||
sleepTime := time.Duration(delta) * time.Millisecond
|
||||
|
||||
for _, system := range systems {
|
||||
time.Sleep(sleepTime)
|
||||
_ = sm.AddSystem(system)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindEventHooks registers event handlers for system and fingerprint record changes.
|
||||
// These hooks ensure the system manager stays synchronized with database changes.
|
||||
func (sm *SystemManager) bindEventHooks() {
|
||||
sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate)
|
||||
sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess)
|
||||
sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate)
|
||||
sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess)
|
||||
sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess)
|
||||
sm.hub.OnRecordAfterUpdateSuccess("fingerprints").BindFunc(sm.onTokenRotated)
|
||||
}
|
||||
|
||||
// onTokenRotated handles fingerprint token rotation events.
|
||||
// When a system's authentication token is rotated, any existing WebSocket connection
|
||||
// must be closed to force re-authentication with the new token.
|
||||
func (sm *SystemManager) onTokenRotated(e *core.RecordEvent) error {
|
||||
systemID := e.Record.GetString("system")
|
||||
system, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return e.Next()
|
||||
}
|
||||
// No need to close connection if not connected via websocket
|
||||
if system.WsConn == nil {
|
||||
return e.Next()
|
||||
}
|
||||
system.setDown(nil)
|
||||
sm.RemoveSystem(systemID)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordCreate is called before a new system record is committed to the database.
|
||||
// It initializes the record with default values: empty info and pending status.
|
||||
func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error {
|
||||
e.Record.Set("info", system.Info{})
|
||||
e.Record.Set("status", pending)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterCreateSuccess is called after a new system record is successfully created.
|
||||
// It adds the new system to the manager to begin monitoring.
|
||||
func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error {
|
||||
if err := sm.AddRecord(e.Record, nil); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordUpdate is called before a system record is updated in the database.
|
||||
// It clears system info when the status is changed to paused.
|
||||
func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error {
|
||||
if e.Record.GetString("status") == paused {
|
||||
e.Record.Set("info", system.Info{})
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterUpdateSuccess handles system record updates after they're committed to the database.
|
||||
// It manages system lifecycle based on status changes and triggers appropriate alerts.
|
||||
// Status transitions are handled as follows:
|
||||
// - paused: Closes SSH connection and deactivates alerts
|
||||
// - pending: Starts monitoring (reuses WebSocket if available)
|
||||
// - up: Triggers system alerts
|
||||
// - down: Triggers status change alerts
|
||||
func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error {
|
||||
newStatus := e.Record.GetString("status")
|
||||
system, ok := sm.systems.GetOk(e.Record.Id)
|
||||
if ok {
|
||||
system.Status = newStatus
|
||||
}
|
||||
|
||||
switch newStatus {
|
||||
case paused:
|
||||
if ok {
|
||||
// Pause monitoring but keep system in manager for potential resume
|
||||
system.closeSSHConnection()
|
||||
}
|
||||
_ = deactivateAlerts(e.App, e.Record.Id)
|
||||
return e.Next()
|
||||
case pending:
|
||||
// Resume monitoring, preferring existing WebSocket connection
|
||||
if ok && system.WsConn != nil {
|
||||
go system.update()
|
||||
return e.Next()
|
||||
}
|
||||
// Start new monitoring session
|
||||
if err := sm.AddRecord(e.Record, nil); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Handle systems not in manager
|
||||
if !ok {
|
||||
return sm.AddRecord(e.Record, nil)
|
||||
}
|
||||
|
||||
prevStatus := system.Status
|
||||
|
||||
// Trigger system alerts when system comes online
|
||||
if newStatus == up {
|
||||
if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil {
|
||||
e.App.Logger().Error("Error handling system alerts", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger status change alerts for up/down transitions
|
||||
if (newStatus == down && prevStatus == up) || (newStatus == up && prevStatus == down) {
|
||||
if err := sm.hub.HandleStatusAlerts(newStatus, e.Record); err != nil {
|
||||
e.App.Logger().Error("Error handling status alerts", "err", err)
|
||||
}
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// onRecordAfterDeleteSuccess is called after a system record is successfully deleted.
|
||||
// It removes the system from the manager and cleans up all associated resources.
|
||||
func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error {
|
||||
sm.RemoveSystem(e.Record.Id)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// AddSystem adds a system to the manager and starts monitoring it.
|
||||
// It validates required fields, initializes the system context, and starts the update goroutine.
|
||||
// Returns error if a system with the same ID already exists.
|
||||
func (sm *SystemManager) AddSystem(sys *System) error {
|
||||
if sm.systems.Has(sys.Id) {
|
||||
return errSystemExists
|
||||
}
|
||||
if sys.Id == "" || sys.Host == "" {
|
||||
return errors.New("system missing required fields")
|
||||
}
|
||||
|
||||
// Initialize system for monitoring
|
||||
sys.manager = sm
|
||||
sys.ctx, sys.cancel = sys.getContext()
|
||||
sys.data = &system.CombinedData{}
|
||||
sm.systems.Set(sys.Id, sys)
|
||||
|
||||
// Start monitoring in background
|
||||
go sys.StartUpdater()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSystem removes a system from the manager and cleans up all associated resources.
|
||||
// It cancels the system's context, closes all connections, and removes it from the store.
|
||||
// Returns an error if the system is not found.
|
||||
func (sm *SystemManager) RemoveSystem(systemID string) error {
|
||||
system, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return errors.New("system not found")
|
||||
}
|
||||
|
||||
// Stop the update goroutine
|
||||
if system.cancel != nil {
|
||||
system.cancel()
|
||||
}
|
||||
|
||||
// Clean up all connections
|
||||
system.closeSSHConnection()
|
||||
system.closeWebSocketConnection()
|
||||
sm.systems.Remove(systemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRecord creates a System instance from a database record and adds it to the manager.
|
||||
// If a system with the same ID already exists, it's removed first to ensure clean state.
|
||||
// If no system instance is provided, a new one is created.
|
||||
// This method is typically called when systems are created or their status changes to pending.
|
||||
func (sm *SystemManager) AddRecord(record *core.Record, system *System) (err error) {
|
||||
// Remove existing system to ensure clean state
|
||||
if sm.systems.Has(record.Id) {
|
||||
_ = sm.RemoveSystem(record.Id)
|
||||
}
|
||||
|
||||
// Create new system if none provided
|
||||
if system == nil {
|
||||
system = sm.NewSystem(record.Id)
|
||||
}
|
||||
|
||||
// Populate system from record
|
||||
system.Status = record.GetString("status")
|
||||
system.Host = record.GetString("host")
|
||||
system.Port = record.GetString("port")
|
||||
|
||||
return sm.AddSystem(system)
|
||||
}
|
||||
|
||||
// AddWebSocketSystem creates and adds a system with an established WebSocket connection.
|
||||
// This method is called when an agent connects via WebSocket with valid authentication.
|
||||
// The system is immediately added to monitoring with the provided connection and version info.
|
||||
func (sm *SystemManager) AddWebSocketSystem(systemId string, agentVersion semver.Version, wsConn *ws.WsConn) error {
|
||||
systemRecord, err := sm.hub.FindRecordById("systems", systemId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
system := sm.NewSystem(systemId)
|
||||
system.WsConn = wsConn
|
||||
system.agentVersion = agentVersion
|
||||
|
||||
if err := sm.AddRecord(systemRecord, system); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSSHClientConfig initializes the SSH client configuration for connecting to an agent's server
|
||||
func (sm *SystemManager) createSSHClientConfig() error {
|
||||
privateKey, err := sm.hub.GetSSHKey("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sm.sshConfig = &ssh.ClientConfig{
|
||||
User: "u",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(privateKey),
|
||||
},
|
||||
Config: ssh.Config{
|
||||
Ciphers: common.DefaultCiphers,
|
||||
KeyExchanges: common.DefaultKeyExchanges,
|
||||
MACs: common.DefaultMACs,
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
ClientVersion: fmt.Sprintf("SSH-2.0-%s_%s", beszel.AppName, beszel.Version),
|
||||
Timeout: sessionTimeout,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deactivateAlerts finds all triggered alerts for a system and sets them to inactive.
|
||||
// This is called when a system is paused or goes offline to prevent continued alerts.
|
||||
func deactivateAlerts(app core.App, systemID string) error {
|
||||
// Note: Direct SQL updates don't trigger SSE, so we use the PocketBase API
|
||||
// _, err := app.DB().NewQuery(fmt.Sprintf("UPDATE alerts SET triggered = false WHERE system = '%s'", systemID)).Execute()
|
||||
|
||||
alerts, err := app.FindRecordsByFilter("alerts", fmt.Sprintf("system = '%s' && triggered = 1", systemID), "", -1, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, alert := range alerts {
|
||||
alert.Set("triggered", false)
|
||||
if err := app.SaveNoValidate(alert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,457 +0,0 @@
|
||||
package systems
|
||||
|
||||
import (
|
||||
"beszel/internal/common"
|
||||
"beszel/internal/entities/system"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/tools/store"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
up string = "up"
|
||||
down string = "down"
|
||||
paused string = "paused"
|
||||
pending string = "pending"
|
||||
|
||||
interval int = 60_000
|
||||
|
||||
sessionTimeout = 4 * time.Second
|
||||
)
|
||||
|
||||
type SystemManager struct {
|
||||
hub hubLike
|
||||
systems *store.Store[string, *System]
|
||||
sshConfig *ssh.ClientConfig
|
||||
}
|
||||
|
||||
type System struct {
|
||||
Id string `db:"id"`
|
||||
Host string `db:"host"`
|
||||
Port string `db:"port"`
|
||||
Status string `db:"status"`
|
||||
manager *SystemManager
|
||||
client *ssh.Client
|
||||
data *system.CombinedData
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type hubLike interface {
|
||||
core.App
|
||||
GetSSHKey(dataDir string) (ssh.Signer, error)
|
||||
HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error
|
||||
HandleStatusAlerts(status string, systemRecord *core.Record) error
|
||||
}
|
||||
|
||||
func NewSystemManager(hub hubLike) *SystemManager {
|
||||
return &SystemManager{
|
||||
systems: store.New(map[string]*System{}),
|
||||
hub: hub,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize initializes the system manager.
|
||||
// It binds the event hooks and starts updating existing systems.
|
||||
func (sm *SystemManager) Initialize() error {
|
||||
sm.bindEventHooks()
|
||||
// ssh setup
|
||||
err := sm.createSSHClientConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// start updating existing systems
|
||||
var systems []*System
|
||||
err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems)
|
||||
if err != nil || len(systems) == 0 {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
// time between initial system updates
|
||||
delta := interval / max(1, len(systems))
|
||||
delta = min(delta, 2_000)
|
||||
sleepTime := time.Duration(delta) * time.Millisecond
|
||||
for _, system := range systems {
|
||||
time.Sleep(sleepTime)
|
||||
_ = sm.AddSystem(system)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SystemManager) bindEventHooks() {
|
||||
sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate)
|
||||
sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess)
|
||||
sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate)
|
||||
sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess)
|
||||
sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess)
|
||||
}
|
||||
|
||||
// Runs before the record is committed to the database
|
||||
func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error {
|
||||
e.Record.Set("info", system.Info{})
|
||||
e.Record.Set("status", pending)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Runs after the record is committed to the database
|
||||
func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error {
|
||||
if err := sm.AddRecord(e.Record); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Runs before the record is updated
|
||||
func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error {
|
||||
if e.Record.GetString("status") == paused {
|
||||
e.Record.Set("info", system.Info{})
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Runs after the record is updated
|
||||
func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error {
|
||||
newStatus := e.Record.GetString("status")
|
||||
switch newStatus {
|
||||
case paused:
|
||||
_ = sm.RemoveSystem(e.Record.Id)
|
||||
_ = deactivateAlerts(e.App, e.Record.Id)
|
||||
return e.Next()
|
||||
case pending:
|
||||
if err := sm.AddRecord(e.Record); err != nil {
|
||||
e.App.Logger().Error("Error adding record", "err", err)
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
system, ok := sm.systems.GetOk(e.Record.Id)
|
||||
if !ok {
|
||||
return sm.AddRecord(e.Record)
|
||||
}
|
||||
prevStatus := system.Status
|
||||
system.Status = newStatus
|
||||
// system alerts if system is up
|
||||
if system.Status == up {
|
||||
if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil {
|
||||
e.App.Logger().Error("Error handling system alerts", "err", err)
|
||||
}
|
||||
}
|
||||
if (system.Status == down && prevStatus == up) || (system.Status == up && prevStatus == down) {
|
||||
if err := sm.hub.HandleStatusAlerts(system.Status, e.Record); err != nil {
|
||||
e.App.Logger().Error("Error handling status alerts", "err", err)
|
||||
}
|
||||
}
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// Runs after the record is deleted
|
||||
func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error {
|
||||
sm.RemoveSystem(e.Record.Id)
|
||||
return e.Next()
|
||||
}
|
||||
|
||||
// AddSystem adds a system to the manager
|
||||
func (sm *SystemManager) AddSystem(sys *System) error {
|
||||
if sm.systems.Has(sys.Id) {
|
||||
return fmt.Errorf("system exists")
|
||||
}
|
||||
if sys.Id == "" || sys.Host == "" {
|
||||
return fmt.Errorf("system is missing required fields")
|
||||
}
|
||||
sys.manager = sm
|
||||
sys.ctx, sys.cancel = context.WithCancel(context.Background())
|
||||
sys.data = &system.CombinedData{}
|
||||
sm.systems.Set(sys.Id, sys)
|
||||
go sys.StartUpdater()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSystem removes a system from the manager
|
||||
func (sm *SystemManager) RemoveSystem(systemID string) error {
|
||||
system, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
return fmt.Errorf("system not found")
|
||||
}
|
||||
// cancel the context to signal stop
|
||||
if system.cancel != nil {
|
||||
system.cancel()
|
||||
}
|
||||
system.resetSSHClient()
|
||||
sm.systems.Remove(systemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRecord adds a record to the system manager.
|
||||
// It first removes any existing system with the same ID, then creates a new System
|
||||
// instance from the record data and adds it to the manager.
|
||||
// This function is typically called when a new system is created or when an existing
|
||||
// system's status changes to pending.
|
||||
func (sm *SystemManager) AddRecord(record *core.Record) (err error) {
|
||||
_ = sm.RemoveSystem(record.Id)
|
||||
system := &System{
|
||||
Id: record.Id,
|
||||
Status: record.GetString("status"),
|
||||
Host: record.GetString("host"),
|
||||
Port: record.GetString("port"),
|
||||
}
|
||||
return sm.AddSystem(system)
|
||||
}
|
||||
|
||||
// StartUpdater starts the system updater.
|
||||
// It first fetches the data from the agent then updates the records.
|
||||
// If the data is not found or the system is down, it sets the system down.
|
||||
func (sys *System) StartUpdater() {
|
||||
if sys.data == nil {
|
||||
sys.data = &system.CombinedData{}
|
||||
}
|
||||
if err := sys.update(); err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
|
||||
c := time.Tick(time.Duration(interval) * time.Millisecond)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sys.ctx.Done():
|
||||
return
|
||||
case <-c:
|
||||
err := sys.update()
|
||||
if err != nil {
|
||||
_ = sys.setDown(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update updates the system data and records.
|
||||
// It first fetches the data from the agent then updates the records.
|
||||
func (sys *System) update() error {
|
||||
_, err := sys.fetchDataFromAgent()
|
||||
if err == nil {
|
||||
_, err = sys.createRecords()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// createRecords updates the system record and adds system_stats and container_stats records
|
||||
func (sys *System) createRecords() (*core.Record, error) {
|
||||
systemRecord, err := sys.getRecord()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hub := sys.manager.hub
|
||||
// add system_stats and container_stats records
|
||||
systemStats, err := hub.FindCachedCollectionByNameOrId("system_stats")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
systemStatsRecord := core.NewRecord(systemStats)
|
||||
systemStatsRecord.Set("system", systemRecord.Id)
|
||||
systemStatsRecord.Set("stats", sys.data.Stats)
|
||||
systemStatsRecord.Set("type", "1m")
|
||||
if err := hub.SaveNoValidate(systemStatsRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// add new container_stats record
|
||||
if len(sys.data.Containers) > 0 {
|
||||
containerStats, err := hub.FindCachedCollectionByNameOrId("container_stats")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
containerStatsRecord := core.NewRecord(containerStats)
|
||||
containerStatsRecord.Set("system", systemRecord.Id)
|
||||
containerStatsRecord.Set("stats", sys.data.Containers)
|
||||
containerStatsRecord.Set("type", "1m")
|
||||
if err := hub.SaveNoValidate(containerStatsRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// update system record (do this last because it triggers alerts and we need above records to be inserted first)
|
||||
systemRecord.Set("status", up)
|
||||
systemRecord.Set("info", sys.data.Info)
|
||||
if err := hub.SaveNoValidate(systemRecord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return systemRecord, nil
|
||||
}
|
||||
|
||||
// getRecord retrieves the system record from the database.
|
||||
// If the record is not found or the system is paused, it removes the system from the manager.
|
||||
func (sys *System) getRecord() (*core.Record, error) {
|
||||
record, err := sys.manager.hub.FindRecordById("systems", sys.Id)
|
||||
if err != nil || record == nil {
|
||||
_ = sys.manager.RemoveSystem(sys.Id)
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// setDown marks a system as down in the database.
|
||||
// It takes the original error that caused the system to go down and returns any error
|
||||
// encountered during the process of updating the system status.
|
||||
func (sys *System) setDown(OriginalError error) error {
|
||||
if sys.Status == down {
|
||||
return nil
|
||||
}
|
||||
record, err := sys.getRecord()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", OriginalError)
|
||||
record.Set("status", down)
|
||||
err = sys.manager.hub.SaveNoValidate(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchDataFromAgent fetches the data from the agent.
|
||||
// It first creates a new SSH client if it doesn't exist or the system is down.
|
||||
// Then it creates a new SSH session and fetches the data from the agent.
|
||||
// If the data is not found or the system is down, it sets the system down.
|
||||
func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) {
|
||||
maxRetries := 1
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if sys.client == nil || sys.Status == down {
|
||||
if err := sys.createSSHClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
session, err := sys.createSessionWithTimeout(4 * time.Second)
|
||||
if err != nil {
|
||||
if attempt >= maxRetries {
|
||||
return nil, err
|
||||
}
|
||||
sys.manager.hub.Logger().Warn("Session closed. Retrying...", "host", sys.Host, "port", sys.Port, "err", err)
|
||||
sys.resetSSHClient()
|
||||
continue
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := session.Shell(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this is initialized in startUpdater, should never be nil
|
||||
*sys.data = system.CombinedData{}
|
||||
if err := json.NewDecoder(stdout).Decode(sys.data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// wait for the session to complete
|
||||
if err := session.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sys.data, nil
|
||||
}
|
||||
|
||||
// this should never be reached due to the return in the loop
|
||||
return nil, fmt.Errorf("failed to fetch data")
|
||||
}
|
||||
|
||||
// createSSHClientConfig initializes the ssh config for the system manager
|
||||
func (sm *SystemManager) createSSHClientConfig() error {
|
||||
privateKey, err := sm.hub.GetSSHKey(sm.hub.DataDir())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sm.sshConfig = &ssh.ClientConfig{
|
||||
User: "u",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(privateKey),
|
||||
},
|
||||
Config: ssh.Config{
|
||||
Ciphers: common.DefaultCiphers,
|
||||
KeyExchanges: common.DefaultKeyExchanges,
|
||||
MACs: common.DefaultMACs,
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: sessionTimeout,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSSHClient creates a new SSH client for the system
|
||||
func (s *System) createSSHClient() error {
|
||||
network := "tcp"
|
||||
host := s.Host
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
} else {
|
||||
host = net.JoinHostPort(host, s.Port)
|
||||
}
|
||||
var err error
|
||||
s.client, err = ssh.Dial(network, host, s.manager.sshConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging
|
||||
// in case of network issues
|
||||
func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) {
|
||||
if sys.client == nil {
|
||||
return nil, fmt.Errorf("client not initialized")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(sys.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
sessionChan := make(chan *ssh.Session, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if session, err := sys.client.NewSession(); err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
sessionChan <- session
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case session := <-sessionChan:
|
||||
return session, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// resetSSHClient closes the SSH connection and resets the client to nil
|
||||
func (sys *System) resetSSHClient() {
|
||||
if sys.client != nil {
|
||||
sys.client.Close()
|
||||
}
|
||||
sys.client = nil
|
||||
}
|
||||
|
||||
// deactivateAlerts finds all triggered alerts for a system and sets them to false
|
||||
func deactivateAlerts(app core.App, systemID string) error {
|
||||
// we can't use an UPDATE query because it doesn't work with realtime updates
|
||||
// _, err := e.App.DB().NewQuery(fmt.Sprintf("UPDATE alerts SET triggered = false WHERE system = '%s'", e.Record.Id)).Execute()
|
||||
alerts, err := app.FindRecordsByFilter("alerts", fmt.Sprintf("system = '%s' && triggered = 1", systemID), "", -1, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, alert := range alerts {
|
||||
alert.Set("triggered", false)
|
||||
if err := app.SaveNoValidate(alert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -11,70 +11,133 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createTestSystem creates a test system record with a unique host name
|
||||
// and returns the created record and any error
|
||||
func createTestSystem(t *testing.T, hub *tests.TestHub, options map[string]any) (*core.Record, error) {
|
||||
collection, err := hub.FindCachedCollectionByNameOrId("systems")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get user record
|
||||
var firstUser *core.Record
|
||||
users, err := hub.FindAllRecords("users", dbx.NewExp("id != ''"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(users) > 0 {
|
||||
firstUser = users[0]
|
||||
}
|
||||
// Generate a unique host name to ensure we're adding a new system
|
||||
uniqueHost := fmt.Sprintf("test-host-%d.example.com", time.Now().UnixNano())
|
||||
|
||||
// Create the record
|
||||
record := core.NewRecord(collection)
|
||||
record.Set("name", uniqueHost)
|
||||
record.Set("host", uniqueHost)
|
||||
record.Set("port", "45876")
|
||||
record.Set("status", "pending")
|
||||
record.Set("users", []string{firstUser.Id})
|
||||
|
||||
// Apply any custom options
|
||||
for key, value := range options {
|
||||
record.Set(key, value)
|
||||
}
|
||||
|
||||
// Save the record to the database
|
||||
err = hub.Save(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func TestSystemManagerIntegration(t *testing.T) {
|
||||
// Create a test hub
|
||||
hub, err := tests.NewTestHub()
|
||||
func TestSystemManagerNew(t *testing.T) {
|
||||
hub, err := tests.NewTestHub(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer hub.Cleanup()
|
||||
sm := hub.GetSystemManager()
|
||||
|
||||
// Create independent system manager
|
||||
sm := systems.NewSystemManager(hub)
|
||||
user, err := tests.CreateUser(hub, "test@test.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
synctest.Run(func() {
|
||||
sm.Initialize()
|
||||
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "it-was-coney-island",
|
||||
"host": "the-playground-of-the-world",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "pending", record.GetString("status"), "System status should be 'pending'")
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
|
||||
// Verify the system host and port
|
||||
host, port := sm.GetSystemHostPort(record.Id)
|
||||
assert.Equal(t, record.GetString("host"), host, "System host should match")
|
||||
assert.Equal(t, record.GetString("port"), port, "System port should match")
|
||||
|
||||
time.Sleep(13 * time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
assert.Equal(t, "pending", record.Fresh().GetString("status"), "System status should be 'pending'")
|
||||
// Verify the system was added by checking if it exists
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
// system should be set to down after 15 seconds (no websocket connection)
|
||||
assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'")
|
||||
// make sure the system is down in the db
|
||||
record, err = hub.FindRecordById("systems", record.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "down", record.GetString("status"), "System status should be 'down'")
|
||||
|
||||
assert.Equal(t, 1, sm.GetSystemCount(), "System count should be 1")
|
||||
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0")
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal")
|
||||
|
||||
// let's also make sure a system is removed from the store when the record is deleted
|
||||
record, err = tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "there-was-no-place-like-it",
|
||||
"host": "in-the-whole-world",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store after creation")
|
||||
|
||||
time.Sleep(8 * time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
|
||||
sm.SetSystemStatusInDB(record.Id, "up")
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'")
|
||||
|
||||
// make sure the system switches to down after 11 seconds
|
||||
sm.RemoveSystem(record.Id)
|
||||
sm.AddRecord(record, nil)
|
||||
assert.Equal(t, "pending", sm.GetSystemStatusFromStore(record.Id), "System status should be 'pending'")
|
||||
time.Sleep(12 * time.Second)
|
||||
synctest.Wait()
|
||||
assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'")
|
||||
|
||||
// sm.SetSystemStatusInDB(record.Id, "paused")
|
||||
// time.Sleep(time.Second)
|
||||
// synctest.Wait()
|
||||
// assert.Equal(t, "paused", sm.GetSystemStatusFromStore(record.Id), "System status should be 'paused'")
|
||||
|
||||
// delete the record
|
||||
err = hub.Delete(record)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion")
|
||||
|
||||
testOld(t, hub)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
synctest.Wait()
|
||||
|
||||
for _, systemId := range sm.GetAllSystemIDs() {
|
||||
err = sm.RemoveSystem(systemId)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sm.HasSystem(systemId), "System should not exist in the store after deletion")
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, sm.GetSystemCount(), "System count should be 0")
|
||||
|
||||
// TODO: test with websocket client
|
||||
})
|
||||
}
|
||||
|
||||
func testOld(t *testing.T, hub *tests.TestHub) {
|
||||
user, err := tests.CreateUser(hub, "test@testy.com", "testtesttest")
|
||||
require.NoError(t, err)
|
||||
|
||||
sm := hub.GetSystemManager()
|
||||
assert.NotNil(t, sm)
|
||||
|
||||
// Test initialization
|
||||
sm.Initialize()
|
||||
// error expected when creating a user with a duplicate email
|
||||
_, err = tests.CreateUser(hub, "test@test.com", "testtesttest")
|
||||
require.Error(t, err)
|
||||
|
||||
// Test collection existence. todo: move to hub package tests
|
||||
t.Run("CollectionExistence", func(t *testing.T) {
|
||||
@@ -92,81 +155,17 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
assert.NotNil(t, containerStats)
|
||||
})
|
||||
|
||||
// Test adding a system record
|
||||
t.Run("AddRecord", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Get the count before adding the system
|
||||
countBefore := sm.GetSystemCount()
|
||||
|
||||
// record should be pending on create
|
||||
hub.OnRecordCreate("systems").BindFunc(func(e *core.RecordEvent) error {
|
||||
record := e.Record
|
||||
if record.GetString("name") == "welcometoarcoampm" {
|
||||
assert.Equal(t, "pending", e.Record.GetString("status"), "System status should be 'pending'")
|
||||
wg.Done()
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// record should be down on update
|
||||
hub.OnRecordAfterUpdateSuccess("systems").BindFunc(func(e *core.RecordEvent) error {
|
||||
record := e.Record
|
||||
if record.GetString("name") == "welcometoarcoampm" {
|
||||
assert.Equal(t, "down", e.Record.GetString("status"), "System status should be 'pending'")
|
||||
wg.Done()
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
// Create a test system with the first user assigned
|
||||
record, err := createTestSystem(t, hub, map[string]any{
|
||||
"name": "welcometoarcoampm",
|
||||
"host": "localhost",
|
||||
"port": "33914",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// system should be down if grabbed from the store
|
||||
assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'")
|
||||
|
||||
// Check that the system count increased
|
||||
countAfter := sm.GetSystemCount()
|
||||
assert.Equal(t, countBefore+1, countAfter, "System count should increase after adding a system via event hook")
|
||||
|
||||
// Verify the system was added by checking if it exists
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
// Verify the system host and port
|
||||
host, port := sm.GetSystemHostPort(record.Id)
|
||||
assert.Equal(t, record.Get("host"), host, "System host should match")
|
||||
assert.Equal(t, record.Get("port"), port, "System port should match")
|
||||
|
||||
// Verify the system is in the list of all system IDs
|
||||
ids := sm.GetAllSystemIDs()
|
||||
assert.Contains(t, ids, record.Id, "System ID should be in the list of all system IDs")
|
||||
|
||||
// Verify the system was added by checking if removing it works
|
||||
err = sm.RemoveSystem(record.Id)
|
||||
assert.NoError(t, err, "System should exist and be removable")
|
||||
|
||||
// Verify the system no longer exists
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal")
|
||||
|
||||
// Verify the system is not in the list of all system IDs
|
||||
newIds := sm.GetAllSystemIDs()
|
||||
assert.NotContains(t, newIds, record.Id, "System ID should not be in the list of all system IDs after removal")
|
||||
|
||||
})
|
||||
|
||||
t.Run("RemoveSystem", func(t *testing.T) {
|
||||
// Get the count before adding the system
|
||||
countBefore := sm.GetSystemCount()
|
||||
|
||||
// Create a test system record
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "i-even-got-lost-at-coney-island",
|
||||
"host": "but-they-found-me",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the system count increased
|
||||
@@ -202,11 +201,16 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
|
||||
t.Run("NewRecordPending", func(t *testing.T) {
|
||||
// Create a test system
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "and-you-know",
|
||||
"host": "i-feel-very-bad",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add the record to the system manager
|
||||
err = sm.AddRecord(record)
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test filtering records by status - should be "pending" now
|
||||
@@ -218,11 +222,16 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
|
||||
t.Run("SystemStatusUpdate", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "we-used-to-sleep-on-the-beach",
|
||||
"host": "sleep-overnight-here",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add the record to the system manager
|
||||
err = sm.AddRecord(record)
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test status changes
|
||||
@@ -244,7 +253,12 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
|
||||
t.Run("HandleSystemData", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "things-changed-you-know",
|
||||
"host": "they-dont-sleep-anymore-on-the-beach",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test system data
|
||||
@@ -295,54 +309,14 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("DeleteRecord", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
runs := 0
|
||||
|
||||
hub.OnRecordUpdate("systems").BindFunc(func(e *core.RecordEvent) error {
|
||||
runs++
|
||||
record := e.Record
|
||||
if record.GetString("name") == "deadflagblues" {
|
||||
if runs == 1 {
|
||||
assert.Equal(t, "up", e.Record.GetString("status"), "System status should be 'up'")
|
||||
wg.Done()
|
||||
} else if runs == 2 {
|
||||
assert.Equal(t, "paused", e.Record.GetString("status"), "System status should be 'paused'")
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
return e.Next()
|
||||
})
|
||||
|
||||
// Create a test system record
|
||||
record, err := createTestSystem(t, hub, map[string]any{
|
||||
"name": "deadflagblues",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the system exists
|
||||
assert.True(t, sm.HasSystem(record.Id), "System should exist in the store")
|
||||
|
||||
// set the status manually to up
|
||||
sm.SetSystemStatusInDB(record.Id, "up")
|
||||
|
||||
// verify the status is up
|
||||
assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'")
|
||||
|
||||
// Set the status to "paused" which should cause it to be deleted from the store
|
||||
sm.SetSystemStatusInDB(record.Id, "paused")
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify the system no longer exists
|
||||
assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion")
|
||||
})
|
||||
|
||||
t.Run("ConcurrentOperations", func(t *testing.T) {
|
||||
// Create a test system
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "jfkjahkfajs",
|
||||
"host": "localhost",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run concurrent operations
|
||||
@@ -377,7 +351,12 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
|
||||
t.Run("ContextCancellation", func(t *testing.T) {
|
||||
// Create a test system record
|
||||
record, err := createTestSystem(t, hub, map[string]any{})
|
||||
record, err := tests.CreateRecord(hub, "systems", map[string]any{
|
||||
"name": "lkhsdfsjf",
|
||||
"host": "localhost",
|
||||
"port": "33914",
|
||||
"users": []string{user.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the system exists in the store
|
||||
@@ -420,7 +399,7 @@ func TestSystemManagerIntegration(t *testing.T) {
|
||||
assert.Error(t, err, "RemoveSystem should fail for non-existent system")
|
||||
|
||||
// Add the system back
|
||||
err = sm.AddRecord(record)
|
||||
err = sm.AddRecord(record, nil)
|
||||
require.NoError(t, err, "AddRecord should succeed")
|
||||
|
||||
// Verify the system is back in the store
|
||||
|
@@ -9,17 +9,17 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GetSystemCount returns the number of systems in the store
|
||||
// TESTING ONLY: GetSystemCount returns the number of systems in the store
|
||||
func (sm *SystemManager) GetSystemCount() int {
|
||||
return sm.systems.Length()
|
||||
}
|
||||
|
||||
// HasSystem checks if a system with the given ID exists in the store
|
||||
// TESTING ONLY: HasSystem checks if a system with the given ID exists in the store
|
||||
func (sm *SystemManager) HasSystem(systemID string) bool {
|
||||
return sm.systems.Has(systemID)
|
||||
}
|
||||
|
||||
// GetSystemStatusFromStore returns the status of a system with the given ID
|
||||
// TESTING ONLY: GetSystemStatusFromStore returns the status of a system with the given ID
|
||||
// Returns an empty string if the system doesn't exist
|
||||
func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
@@ -29,7 +29,7 @@ func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string {
|
||||
return sys.Status
|
||||
}
|
||||
|
||||
// GetSystemContextFromStore returns the context and cancel function for a system
|
||||
// TESTING ONLY: GetSystemContextFromStore returns the context and cancel function for a system
|
||||
func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Context, context.CancelFunc, error) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
@@ -38,7 +38,7 @@ func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Con
|
||||
return sys.ctx, sys.cancel, nil
|
||||
}
|
||||
|
||||
// GetSystemFromStore returns a store from the system
|
||||
// TESTING ONLY: GetSystemFromStore returns a store from the system
|
||||
func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
if !ok {
|
||||
@@ -47,7 +47,7 @@ func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) {
|
||||
return sys, nil
|
||||
}
|
||||
|
||||
// GetAllSystemIDs returns a slice of all system IDs in the store
|
||||
// TESTING ONLY: GetAllSystemIDs returns a slice of all system IDs in the store
|
||||
func (sm *SystemManager) GetAllSystemIDs() []string {
|
||||
data := sm.systems.GetAll()
|
||||
ids := make([]string, 0, len(data))
|
||||
@@ -57,7 +57,7 @@ func (sm *SystemManager) GetAllSystemIDs() []string {
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetSystemData returns the combined data for a system with the given ID
|
||||
// TESTING ONLY: GetSystemData returns the combined data for a system with the given ID
|
||||
// Returns nil if the system doesn't exist
|
||||
// This method is intended for testing
|
||||
func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData {
|
||||
@@ -68,7 +68,7 @@ func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData {
|
||||
return sys.data
|
||||
}
|
||||
|
||||
// GetSystemHostPort returns the host and port for a system with the given ID
|
||||
// TESTING ONLY: GetSystemHostPort returns the host and port for a system with the given ID
|
||||
// Returns empty strings if the system doesn't exist
|
||||
func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) {
|
||||
sys, ok := sm.systems.GetOk(systemID)
|
||||
@@ -78,22 +78,7 @@ func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) {
|
||||
return sys.Host, sys.Port
|
||||
}
|
||||
|
||||
// DisableAutoUpdater disables the automatic updater for a system
|
||||
// This is intended for testing
|
||||
// Returns false if the system doesn't exist
|
||||
// func (sm *SystemManager) DisableAutoUpdater(systemID string) bool {
|
||||
// sys, ok := sm.systems.GetOk(systemID)
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
// if sys.cancel != nil {
|
||||
// sys.cancel()
|
||||
// sys.cancel = nil
|
||||
// }
|
||||
// return true
|
||||
// }
|
||||
|
||||
// SetSystemStatusInDB sets the status of a system directly and updates the database record
|
||||
// TESTING ONLY: SetSystemStatusInDB sets the status of a system directly and updates the database record
|
||||
// This is intended for testing
|
||||
// Returns false if the system doesn't exist
|
||||
func (sm *SystemManager) SetSystemStatusInDB(systemID string, status string) bool {
|
||||
|
181
beszel/internal/hub/ws/ws.go
Normal file
181
beszel/internal/hub/ws/ws.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"beszel/internal/common"
|
||||
"beszel/internal/entities/system"
|
||||
"errors"
|
||||
"time"
|
||||
"weak"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/lxzan/gws"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
deadline = 70 * time.Second
|
||||
)
|
||||
|
||||
// Handler implements the WebSocket event handler for agent connections.
|
||||
type Handler struct {
|
||||
gws.BuiltinEventHandler
|
||||
}
|
||||
|
||||
// WsConn represents a WebSocket connection to an agent.
|
||||
type WsConn struct {
|
||||
conn *gws.Conn
|
||||
responseChan chan *gws.Message
|
||||
DownChan chan struct{}
|
||||
}
|
||||
|
||||
// FingerprintRecord is fingerprints collection record data in the hub
|
||||
type FingerprintRecord struct {
|
||||
Id string `db:"id"`
|
||||
SystemId string `db:"system"`
|
||||
Fingerprint string `db:"fingerprint"`
|
||||
Token string `db:"token"`
|
||||
}
|
||||
|
||||
var upgrader *gws.Upgrader
|
||||
|
||||
// GetUpgrader returns a singleton WebSocket upgrader instance.
|
||||
func GetUpgrader() *gws.Upgrader {
|
||||
if upgrader != nil {
|
||||
return upgrader
|
||||
}
|
||||
handler := &Handler{}
|
||||
upgrader = gws.NewUpgrader(handler, &gws.ServerOption{})
|
||||
return upgrader
|
||||
}
|
||||
|
||||
// NewWsConnection creates a new WebSocket connection wrapper.
|
||||
func NewWsConnection(conn *gws.Conn) *WsConn {
|
||||
return &WsConn{
|
||||
conn: conn,
|
||||
responseChan: make(chan *gws.Message, 1),
|
||||
DownChan: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// OnOpen sets a deadline for the WebSocket connection.
|
||||
func (h *Handler) OnOpen(conn *gws.Conn) {
|
||||
conn.SetDeadline(time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
// OnMessage routes incoming WebSocket messages to the response channel.
|
||||
func (h *Handler) OnMessage(conn *gws.Conn, message *gws.Message) {
|
||||
conn.SetDeadline(time.Now().Add(deadline))
|
||||
if message.Opcode != gws.OpcodeBinary || message.Data.Len() == 0 {
|
||||
return
|
||||
}
|
||||
wsConn, ok := conn.Session().Load("wsConn")
|
||||
if !ok {
|
||||
_ = conn.WriteClose(1000, nil)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case wsConn.(*WsConn).responseChan <- message:
|
||||
default:
|
||||
// close if the connection is not expecting a response
|
||||
wsConn.(*WsConn).Close()
|
||||
}
|
||||
}
|
||||
|
||||
// OnClose handles WebSocket connection closures and triggers system down status after delay.
|
||||
func (h *Handler) OnClose(conn *gws.Conn, err error) {
|
||||
wsConn, ok := conn.Session().Load("wsConn")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
wsConn.(*WsConn).conn = nil
|
||||
// wait 5 seconds to allow reconnection before setting system down
|
||||
// use a weak pointer to avoid keeping references if the system is removed
|
||||
go func(downChan weak.Pointer[chan struct{}]) {
|
||||
time.Sleep(5 * time.Second)
|
||||
downChanValue := downChan.Value()
|
||||
if downChanValue != nil {
|
||||
*downChanValue <- struct{}{}
|
||||
}
|
||||
}(weak.Make(&wsConn.(*WsConn).DownChan))
|
||||
}
|
||||
|
||||
// Close terminates the WebSocket connection gracefully.
|
||||
func (ws *WsConn) Close() {
|
||||
if ws.IsConnected() {
|
||||
ws.conn.WriteClose(1000, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Ping sends a ping frame to keep the connection alive.
|
||||
func (ws *WsConn) Ping() error {
|
||||
ws.conn.SetDeadline(time.Now().Add(deadline))
|
||||
return ws.conn.WritePing(nil)
|
||||
}
|
||||
|
||||
// sendMessage encodes data to CBOR and sends it as a binary message to the agent.
|
||||
func (ws *WsConn) sendMessage(data common.HubRequest[any]) error {
|
||||
bytes, err := cbor.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ws.conn.WriteMessage(gws.OpcodeBinary, bytes)
|
||||
}
|
||||
|
||||
// RequestSystemData requests system metrics from the agent and unmarshals the response.
|
||||
func (ws *WsConn) RequestSystemData(data *system.CombinedData) error {
|
||||
var message *gws.Message
|
||||
|
||||
ws.sendMessage(common.HubRequest[any]{
|
||||
Action: common.GetData,
|
||||
})
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
ws.Close()
|
||||
return gws.ErrConnClosed
|
||||
case message = <-ws.responseChan:
|
||||
}
|
||||
defer message.Close()
|
||||
return cbor.Unmarshal(message.Data.Bytes(), data)
|
||||
}
|
||||
|
||||
// GetFingerprint authenticates with the agent using SSH signature and returns the agent's fingerprint.
|
||||
func (ws *WsConn) GetFingerprint(token string, signer ssh.Signer, needSysInfo bool) (common.FingerprintResponse, error) {
|
||||
challenge := []byte(token)
|
||||
|
||||
signature, err := signer.Sign(nil, challenge)
|
||||
if err != nil {
|
||||
return common.FingerprintResponse{}, err
|
||||
}
|
||||
|
||||
err = ws.sendMessage(common.HubRequest[any]{
|
||||
Action: common.CheckFingerprint,
|
||||
Data: common.FingerprintRequest{
|
||||
Signature: signature.Blob,
|
||||
NeedSysInfo: needSysInfo,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return common.FingerprintResponse{}, err
|
||||
}
|
||||
|
||||
var message *gws.Message
|
||||
var clientFingerprint common.FingerprintResponse
|
||||
select {
|
||||
case message = <-ws.responseChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return common.FingerprintResponse{}, errors.New("request expired")
|
||||
}
|
||||
defer message.Close()
|
||||
|
||||
err = cbor.Unmarshal(message.Data.Bytes(), &clientFingerprint)
|
||||
if err != nil {
|
||||
return common.FingerprintResponse{}, err
|
||||
}
|
||||
|
||||
return clientFingerprint, nil
|
||||
}
|
||||
|
||||
// IsConnected returns true if the WebSocket connection is active.
|
||||
func (ws *WsConn) IsConnected() bool {
|
||||
return ws.conn != nil
|
||||
}
|
@@ -1,3 +1,6 @@
|
||||
//go:build testing
|
||||
// +build testing
|
||||
|
||||
// Package tests provides helpers for testing the application.
|
||||
package tests
|
||||
|
||||
@@ -56,3 +59,30 @@ func NewTestHubWithConfig(config core.BaseAppConfig) (*TestHub, error) {
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Helper function to create a test user for config tests
|
||||
func CreateUser(app core.App, email string, password string) (*core.Record, error) {
|
||||
userCollection, err := app.FindCachedCollectionByNameOrId("users")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := core.NewRecord(userCollection)
|
||||
user.Set("email", email)
|
||||
user.Set("password", password)
|
||||
|
||||
return user, app.Save(user)
|
||||
}
|
||||
|
||||
// Helper function to create a test record
|
||||
func CreateRecord(app core.App, collectionName string, fields map[string]any) (*core.Record, error) {
|
||||
collection, err := app.FindCachedCollectionByNameOrId(collectionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record := core.NewRecord(collection)
|
||||
record.Load(fields)
|
||||
|
||||
return record, app.Save(record)
|
||||
}
|
||||
|
@@ -1,23 +1,14 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
m "github.com/pocketbase/pocketbase/migrations"
|
||||
)
|
||||
|
||||
func init() {
|
||||
m.Register(func(app core.App) error {
|
||||
// delete duplicate alerts
|
||||
app.DB().NewQuery(`
|
||||
DELETE FROM alerts
|
||||
WHERE rowid NOT IN (
|
||||
SELECT MAX(rowid)
|
||||
FROM alerts
|
||||
GROUP BY user, system, name
|
||||
);
|
||||
`).Execute()
|
||||
|
||||
// import collections
|
||||
// update collections
|
||||
jsonData := `[
|
||||
{
|
||||
"id": "elngm8x1l60zi2v",
|
||||
@@ -236,6 +227,88 @@ func init() {
|
||||
],
|
||||
"system": false
|
||||
},
|
||||
{
|
||||
"id": "pbc_3663931638",
|
||||
"listRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id",
|
||||
"viewRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id",
|
||||
"createRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id && @request.auth.role != \"readonly\"",
|
||||
"updateRule": "@request.auth.id != \"\" && system.users.id ?= @request.auth.id && @request.auth.role != \"readonly\"",
|
||||
"deleteRule": null,
|
||||
"name": "fingerprints",
|
||||
"type": "base",
|
||||
"fields": [
|
||||
{
|
||||
"autogeneratePattern": "[a-z0-9]{9}",
|
||||
"hidden": false,
|
||||
"id": "text3208210256",
|
||||
"max": 15,
|
||||
"min": 9,
|
||||
"name": "id",
|
||||
"pattern": "^[a-z0-9]+$",
|
||||
"presentable": false,
|
||||
"primaryKey": true,
|
||||
"required": true,
|
||||
"system": true,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"cascadeDelete": true,
|
||||
"collectionId": "2hz5ncl8tizk5nx",
|
||||
"hidden": false,
|
||||
"id": "relation3377271179",
|
||||
"maxSelect": 1,
|
||||
"minSelect": 0,
|
||||
"name": "system",
|
||||
"presentable": false,
|
||||
"required": true,
|
||||
"system": false,
|
||||
"type": "relation"
|
||||
},
|
||||
{
|
||||
"autogeneratePattern": "[a-zA-Z9-9]{20}",
|
||||
"hidden": false,
|
||||
"id": "text1597481275",
|
||||
"max": 255,
|
||||
"min": 9,
|
||||
"name": "token",
|
||||
"pattern": "",
|
||||
"presentable": false,
|
||||
"primaryKey": false,
|
||||
"required": true,
|
||||
"system": false,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"autogeneratePattern": "",
|
||||
"hidden": false,
|
||||
"id": "text4228609354",
|
||||
"max": 255,
|
||||
"min": 9,
|
||||
"name": "fingerprint",
|
||||
"pattern": "",
|
||||
"presentable": false,
|
||||
"primaryKey": false,
|
||||
"required": false,
|
||||
"system": false,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"hidden": false,
|
||||
"id": "autodate3332085495",
|
||||
"name": "updated",
|
||||
"onCreate": true,
|
||||
"onUpdate": true,
|
||||
"presentable": false,
|
||||
"system": false,
|
||||
"type": "autodate"
|
||||
}
|
||||
],
|
||||
"indexes": [
|
||||
"CREATE INDEX ` + "`" + `idx_p9qZlu26po` + "`" + ` ON ` + "`" + `fingerprints` + "`" + ` (` + "`" + `token` + "`" + `)",
|
||||
"CREATE UNIQUE INDEX ` + "`" + `idx_ngboulGMYw` + "`" + ` ON ` + "`" + `fingerprints` + "`" + ` (` + "`" + `system` + "`" + `)"
|
||||
],
|
||||
"system": false
|
||||
},
|
||||
{
|
||||
"id": "ej9oowivz8b2mht",
|
||||
"listRule": "@request.auth.id != \"\"",
|
||||
@@ -669,7 +742,38 @@ func init() {
|
||||
}
|
||||
]`
|
||||
|
||||
return app.ImportCollectionsByMarshaledJSON([]byte(jsonData), false)
|
||||
err := app.ImportCollectionsByMarshaledJSON([]byte(jsonData), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get all systems that don't have fingerprint records
|
||||
var systemIds []string
|
||||
err = app.DB().NewQuery(`
|
||||
SELECT s.id FROM systems s
|
||||
LEFT JOIN fingerprints f ON s.id = f.system
|
||||
WHERE f.system IS NULL
|
||||
`).Column(&systemIds)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Create fingerprint records with unique UUID tokens for each system
|
||||
for _, systemId := range systemIds {
|
||||
token := uuid.New().String()
|
||||
_, err = app.DB().NewQuery(`
|
||||
INSERT INTO fingerprints (system, token)
|
||||
VALUES ({:system}, {:token})
|
||||
`).Bind(map[string]any{
|
||||
"system": systemId,
|
||||
"token": token,
|
||||
}).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func(app core.App) error {
|
||||
return nil
|
||||
})
|
@@ -9,7 +9,8 @@
|
||||
<script>
|
||||
globalThis.BESZEL = {
|
||||
BASE_PATH: "%BASE_URL%",
|
||||
HUB_VERSION: "{{V}}"
|
||||
HUB_VERSION: "{{V}}",
|
||||
HUB_URL: "{{HUB_URL}}"
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
|
@@ -11,20 +11,27 @@ import {
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog"
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"
|
||||
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import { $publicKey, pb } from "@/lib/stores"
|
||||
import { cn, copyToClipboard, isReadOnlyUser, useLocalStorage } from "@/lib/utils"
|
||||
import { i18n } from "@lingui/core"
|
||||
import { cn, generateToken, isReadOnlyUser, tokenMap, useLocalStorage } from "@/lib/utils"
|
||||
import { useStore } from "@nanostores/react"
|
||||
import { ChevronDownIcon, Copy, ExternalLinkIcon, PlusIcon } from "lucide-react"
|
||||
import { memo, useRef, useState } from "react"
|
||||
import { basePath, navigate } from "./router"
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "./ui/dropdown-menu"
|
||||
import { ChevronDownIcon, ExternalLinkIcon, PlusIcon } from "lucide-react"
|
||||
import { memo, useEffect, useRef, useState } from "react"
|
||||
import { $router, basePath, Link, navigate } from "./router"
|
||||
import { SystemRecord } from "@/types"
|
||||
import { AppleIcon, DockerIcon, TuxIcon, WindowsIcon } from "./ui/icons"
|
||||
import { InputCopy } from "./ui/input-copy"
|
||||
import { getPagePath } from "@nanostores/router"
|
||||
import {
|
||||
copyDockerCompose,
|
||||
copyDockerRun,
|
||||
copyLinuxCommand,
|
||||
copyWindowsCommand,
|
||||
DropdownItem,
|
||||
InstallDropdown,
|
||||
} from "./install-dropdowns"
|
||||
import { DropdownMenu, DropdownMenuTrigger } from "./ui/dropdown-menu"
|
||||
|
||||
export function AddSystemButton({ className }: { className?: string }) {
|
||||
const [open, setOpen] = useState(false)
|
||||
@@ -51,44 +58,11 @@ export function AddSystemButton({ className }: { className?: string }) {
|
||||
)
|
||||
}
|
||||
|
||||
function copyDockerCompose(port = "45876", publicKey: string) {
|
||||
copyToClipboard(`services:
|
||||
beszel-agent:
|
||||
image: "henrygd/beszel-agent"
|
||||
container_name: "beszel-agent"
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
# monitor other disks / partitions by mounting a folder in /extra-filesystems
|
||||
# - /mnt/disk/.beszel:/extra-filesystems/sda1:ro
|
||||
environment:
|
||||
LISTEN: ${port}
|
||||
KEY: "${publicKey}"`)
|
||||
}
|
||||
|
||||
function copyDockerRun(port = "45876", publicKey: string) {
|
||||
copyToClipboard(
|
||||
`docker run -d --name beszel-agent --network host --restart unless-stopped -v /var/run/docker.sock:/var/run/docker.sock:ro -e KEY="${publicKey}" -e LISTEN=${port} henrygd/beszel-agent:latest`
|
||||
)
|
||||
}
|
||||
|
||||
function copyLinuxCommand(port = "45876", publicKey: string, brew = false) {
|
||||
let cmd = `curl -sL https://get.beszel.dev${
|
||||
brew ? "/brew" : ""
|
||||
} -o /tmp/install-agent.sh && chmod +x /tmp/install-agent.sh && /tmp/install-agent.sh -p ${port} -k "${publicKey}"`
|
||||
// brew script does not support --china-mirrors
|
||||
if (!brew && (i18n.locale + navigator.language).includes("zh-CN")) {
|
||||
cmd += ` --china-mirrors`
|
||||
}
|
||||
copyToClipboard(cmd)
|
||||
}
|
||||
|
||||
function copyWindowsCommand(port = "45876", publicKey: string) {
|
||||
copyToClipboard(
|
||||
`& iwr -useb https://get.beszel.dev -OutFile "$env:TEMP\\install-agent.ps1"; & Powershell -ExecutionPolicy Bypass -File "$env:TEMP\\install-agent.ps1" -Key "${publicKey}" -Port ${port}`
|
||||
)
|
||||
}
|
||||
/**
|
||||
* Token to be used for the next system.
|
||||
* Prevents token changing if user copies config, then closes dialog and opens again.
|
||||
*/
|
||||
let nextSystemToken: string | null = null
|
||||
|
||||
/**
|
||||
* SystemDialog component for adding or editing a system.
|
||||
@@ -96,12 +70,32 @@ function copyWindowsCommand(port = "45876", publicKey: string) {
|
||||
* @param {function} props.setOpen - Function to set the open state of the dialog.
|
||||
* @param {SystemRecord} [props.system] - Optional system record for editing an existing system.
|
||||
*/
|
||||
export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean) => void; system?: SystemRecord }) => {
|
||||
export const SystemDialog = ({ setOpen, system }: { setOpen: (open: boolean) => void; system?: SystemRecord }) => {
|
||||
const publicKey = useStore($publicKey)
|
||||
const port = useRef<HTMLInputElement>(null)
|
||||
const [hostValue, setHostValue] = useState(system?.host ?? "")
|
||||
const isUnixSocket = hostValue.startsWith("/")
|
||||
const [tab, setTab] = useLocalStorage("as-tab", "docker")
|
||||
const [token, setToken] = useState(system?.token ?? "")
|
||||
|
||||
useEffect(() => {
|
||||
;(async () => {
|
||||
// if no system, generate a new token
|
||||
if (!system) {
|
||||
nextSystemToken ||= generateToken()
|
||||
return setToken(nextSystemToken)
|
||||
}
|
||||
// if system exists,get the token from the fingerprint record
|
||||
if (tokenMap.has(system.id)) {
|
||||
return setToken(tokenMap.get(system.id)!)
|
||||
}
|
||||
const { token } = await pb.collection("fingerprints").getFirstListItem(`system = "${system.id}"`, {
|
||||
fields: "token",
|
||||
})
|
||||
tokenMap.set(system.id, token)
|
||||
setToken(token)
|
||||
})()
|
||||
}, [system?.id])
|
||||
|
||||
async function handleSubmit(e: SubmitEvent) {
|
||||
e.preventDefault()
|
||||
@@ -113,12 +107,18 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean
|
||||
if (system) {
|
||||
await pb.collection("systems").update(system.id, { ...data, status: "pending" })
|
||||
} else {
|
||||
await pb.collection("systems").create(data)
|
||||
const createdSystem = await pb.collection("systems").create(data)
|
||||
await pb.collection("fingerprints").create({
|
||||
system: createdSystem.id,
|
||||
token,
|
||||
})
|
||||
// Reset the current token after successful system
|
||||
// creation so next system gets a new token
|
||||
nextSystemToken = null
|
||||
}
|
||||
navigate(basePath)
|
||||
// console.log(record)
|
||||
} catch (e) {
|
||||
console.log(e)
|
||||
console.error(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,18 +143,37 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean
|
||||
</DialogHeader>
|
||||
{/* Docker (set tab index to prevent auto focusing content in edit system dialog) */}
|
||||
<TabsContent value="docker" tabIndex={-1}>
|
||||
<DialogDescription className="mb-4 leading-normal w-0 min-w-full">
|
||||
<DialogDescription className="mb-3 leading-relaxed w-0 min-w-full">
|
||||
<Trans>
|
||||
The agent must be running on the system to connect. Copy the
|
||||
<code className="bg-muted px-1 rounded-sm leading-3">docker-compose.yml</code> for the agent below.
|
||||
Copy the
|
||||
<code className="bg-muted px-1 rounded-sm leading-3">docker-compose.yml</code> content for the agent
|
||||
below, or register agents automatically with a{" "}
|
||||
<Link
|
||||
onClick={() => setOpen(false)}
|
||||
href={getPagePath($router, "settings", { name: "tokens" })}
|
||||
className="link"
|
||||
>
|
||||
universal token
|
||||
</Link>
|
||||
.
|
||||
</Trans>
|
||||
</DialogDescription>
|
||||
</TabsContent>
|
||||
{/* Binary */}
|
||||
<TabsContent value="binary" tabIndex={-1}>
|
||||
<DialogDescription className="mb-4 leading-normal w-0 min-w-full">
|
||||
<DialogDescription className="mb-3 leading-relaxed w-0 min-w-full">
|
||||
<Trans>
|
||||
The agent must be running on the system to connect. Copy the installation command for the agent below.
|
||||
Copy the installation command for the agent below, or register agents automatically with a{" "}
|
||||
<Link
|
||||
onClick={() => {
|
||||
setOpen(false)
|
||||
}}
|
||||
href={getPagePath($router, "settings", { name: "tokens" })}
|
||||
className="link"
|
||||
>
|
||||
universal token
|
||||
</Link>
|
||||
.
|
||||
</Trans>
|
||||
</DialogDescription>
|
||||
</TabsContent>
|
||||
@@ -190,46 +209,27 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean
|
||||
<Label htmlFor="pkey" className="xs:text-end whitespace-pre">
|
||||
<Trans comment="Use 'Key' if your language requires many more characters">Public Key</Trans>
|
||||
</Label>
|
||||
<div className="relative">
|
||||
<Input readOnly id="pkey" value={publicKey} required></Input>
|
||||
<div
|
||||
className={
|
||||
"h-6 w-24 bg-gradient-to-r rtl:bg-gradient-to-l from-transparent to-background to-65% absolute top-2 end-1 pointer-events-none"
|
||||
}
|
||||
></div>
|
||||
<TooltipProvider delayDuration={100}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant={"link"}
|
||||
className="absolute end-0 top-0"
|
||||
onClick={() => copyToClipboard(publicKey)}
|
||||
>
|
||||
<Copy className="size-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>
|
||||
<Trans>Click to copy</Trans>
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<InputCopy value={publicKey} id="pkey" name="pkey" />
|
||||
<Label htmlFor="tkn" className="xs:text-end whitespace-pre">
|
||||
<Trans>Token</Trans>
|
||||
</Label>
|
||||
<InputCopy value={token} id="tkn" name="tkn" />
|
||||
</div>
|
||||
<DialogFooter className="flex justify-end gap-x-2 gap-y-3 flex-col mt-5">
|
||||
{/* Docker */}
|
||||
<TabsContent value="docker" className="contents">
|
||||
<CopyButton
|
||||
text={t({ message: "Copy docker compose", context: "Button to copy docker compose file content" })}
|
||||
onClick={() => copyDockerCompose(isUnixSocket ? hostValue : port.current?.value, publicKey)}
|
||||
onClick={async () =>
|
||||
copyDockerCompose(isUnixSocket ? hostValue : port.current?.value, publicKey, token)
|
||||
}
|
||||
icon={<DockerIcon className="size-4 -me-0.5" />}
|
||||
dropdownItems={[
|
||||
{
|
||||
text: t({ message: "Copy docker run", context: "Button to copy docker run command" }),
|
||||
onClick: () => copyDockerRun(isUnixSocket ? hostValue : port.current?.value, publicKey),
|
||||
icons: [<DockerIcon className="size-4" />],
|
||||
onClick: async () =>
|
||||
copyDockerRun(isUnixSocket ? hostValue : port.current?.value, publicKey, token),
|
||||
icons: [DockerIcon],
|
||||
},
|
||||
]}
|
||||
/>
|
||||
@@ -239,22 +239,24 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean
|
||||
<CopyButton
|
||||
text={t`Copy Linux command`}
|
||||
icon={<TuxIcon className="size-4" />}
|
||||
onClick={() => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey)}
|
||||
onClick={async () => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token)}
|
||||
dropdownItems={[
|
||||
{
|
||||
text: t({ message: "Homebrew command", context: "Button to copy install command" }),
|
||||
onClick: () => copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, true),
|
||||
icons: [<AppleIcon className="size-4" />, <TuxIcon className="w-4 h-4" />],
|
||||
onClick: async () =>
|
||||
copyLinuxCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token, true),
|
||||
icons: [AppleIcon, TuxIcon],
|
||||
},
|
||||
{
|
||||
text: t({ message: "Windows command", context: "Button to copy install command" }),
|
||||
onClick: () => copyWindowsCommand(isUnixSocket ? hostValue : port.current?.value, publicKey),
|
||||
icons: [<WindowsIcon className="size-4" />],
|
||||
onClick: async () =>
|
||||
copyWindowsCommand(isUnixSocket ? hostValue : port.current?.value, publicKey, token),
|
||||
icons: [WindowsIcon],
|
||||
},
|
||||
{
|
||||
text: t`Manual setup instructions`,
|
||||
url: "https://beszel.dev/guide/agent-installation#binary",
|
||||
icons: [<ExternalLinkIcon className="size-4" />],
|
||||
icons: [ExternalLinkIcon],
|
||||
},
|
||||
]}
|
||||
/>
|
||||
@@ -266,20 +268,13 @@ export const SystemDialog = memo(({ setOpen, system }: { setOpen: (open: boolean
|
||||
</Tabs>
|
||||
</DialogContent>
|
||||
)
|
||||
})
|
||||
|
||||
interface DropdownItem {
|
||||
text: string
|
||||
onClick?: () => void
|
||||
url?: string
|
||||
icons?: React.ReactNode[]
|
||||
}
|
||||
|
||||
interface CopyButtonProps {
|
||||
text: string
|
||||
onClick: () => void
|
||||
dropdownItems: DropdownItem[]
|
||||
icon?: React.ReactNode
|
||||
icon?: React.ReactElement
|
||||
}
|
||||
|
||||
const CopyButton = memo((props: CopyButtonProps) => {
|
||||
@@ -300,22 +295,7 @@ const CopyButton = memo((props: CopyButtonProps) => {
|
||||
<ChevronDownIcon />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
{props.dropdownItems.map((item, index) => {
|
||||
const className = "cursor-pointer flex items-center gap-1.5"
|
||||
return item.url ? (
|
||||
<DropdownMenuItem key={index} asChild>
|
||||
<a href={item.url} className={className} target="_blank" rel="noopener noreferrer">
|
||||
{item.text} {item.icons?.map((icon) => icon)}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
) : (
|
||||
<DropdownMenuItem key={index} onClick={item.onClick} className={className}>
|
||||
{item.text} {item.icons?.map((icon) => icon)}
|
||||
</DropdownMenuItem>
|
||||
)
|
||||
})}
|
||||
</DropdownMenuContent>
|
||||
<InstallDropdown items={props.dropdownItems} />
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import {
|
||||
BookIcon,
|
||||
DatabaseBackupIcon,
|
||||
FingerprintIcon,
|
||||
LayoutDashboard,
|
||||
LogsIcon,
|
||||
MailIcon,
|
||||
@@ -40,6 +41,16 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
|
||||
return useMemo(() => {
|
||||
const systems = $systems.get()
|
||||
const SettingsShortcut = (
|
||||
<CommandShortcut>
|
||||
<Trans>Settings</Trans>
|
||||
</CommandShortcut>
|
||||
)
|
||||
const AdminShortcut = (
|
||||
<CommandShortcut>
|
||||
<Trans>Admin</Trans>
|
||||
</CommandShortcut>
|
||||
)
|
||||
return (
|
||||
<CommandDialog open={open} onOpenChange={setOpen}>
|
||||
<CommandInput placeholder={t`Search for systems or settings...`} />
|
||||
@@ -93,9 +104,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>Settings</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Settings</Trans>
|
||||
</CommandShortcut>
|
||||
{SettingsShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
keywords={["alerts"]}
|
||||
@@ -108,9 +117,19 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>Notifications</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Settings</Trans>
|
||||
</CommandShortcut>
|
||||
{SettingsShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
onSelect={() => {
|
||||
navigate(getPagePath($router, "settings", { name: "tokens" }))
|
||||
setOpen(false)
|
||||
}}
|
||||
>
|
||||
<FingerprintIcon className="me-2 h-4 w-4" />
|
||||
<span>
|
||||
<Trans>Tokens & Fingerprints</Trans>
|
||||
</span>
|
||||
{SettingsShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
keywords={["help", "oauth", "oidc"]}
|
||||
@@ -140,9 +159,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>Users</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Admin</Trans>
|
||||
</CommandShortcut>
|
||||
{AdminShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
onSelect={() => {
|
||||
@@ -154,9 +171,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>Logs</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Admin</Trans>
|
||||
</CommandShortcut>
|
||||
{AdminShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
onSelect={() => {
|
||||
@@ -168,9 +183,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>Backups</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Admin</Trans>
|
||||
</CommandShortcut>
|
||||
{AdminShortcut}
|
||||
</CommandItem>
|
||||
<CommandItem
|
||||
keywords={["email"]}
|
||||
@@ -183,9 +196,7 @@ export default memo(function CommandPalette({ open, setOpen }: { open: boolean;
|
||||
<span>
|
||||
<Trans>SMTP settings</Trans>
|
||||
</span>
|
||||
<CommandShortcut>
|
||||
<Trans>Admin</Trans>
|
||||
</CommandShortcut>
|
||||
{AdminShortcut}
|
||||
</CommandItem>
|
||||
</CommandGroup>
|
||||
</>
|
||||
|
97
beszel/site/src/components/install-dropdowns.tsx
Normal file
97
beszel/site/src/components/install-dropdowns.tsx
Normal file
@@ -0,0 +1,97 @@
|
||||
import { memo } from "react"
|
||||
import { DropdownMenuContent, DropdownMenuItem } from "./ui/dropdown-menu"
|
||||
import { copyToClipboard, getHubURL } from "@/lib/utils"
|
||||
import { i18n } from "@lingui/core"
|
||||
|
||||
const isBeta = BESZEL.HUB_VERSION.includes("beta")
|
||||
const imageTag = isBeta ? ":beta" : ""
|
||||
|
||||
/**
|
||||
* Get the URL of the script to install the agent.
|
||||
* @param path - The path to the script (e.g. "/brew").
|
||||
* @returns The URL for the script.
|
||||
*/
|
||||
const getScriptUrl = (path: string = "") => {
|
||||
const url = new URL("https://get.beszel.dev")
|
||||
url.pathname = path
|
||||
if (isBeta) {
|
||||
url.searchParams.set("beta", "1")
|
||||
}
|
||||
return url.toString()
|
||||
}
|
||||
|
||||
export function copyDockerCompose(port = "45876", publicKey: string, token: string) {
|
||||
copyToClipboard(`services:
|
||||
beszel-agent:
|
||||
image: henrygd/beszel-agent${imageTag}
|
||||
container_name: beszel-agent
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./beszel_agent_data:/var/lib/beszel-agent
|
||||
# monitor other disks / partitions by mounting a folder in /extra-filesystems
|
||||
# - /mnt/disk/.beszel:/extra-filesystems/sda1:ro
|
||||
environment:
|
||||
LISTEN: ${port}
|
||||
KEY: '${publicKey}'
|
||||
TOKEN: ${token}
|
||||
HUB_URL: ${getHubURL()}`)
|
||||
}
|
||||
|
||||
export function copyDockerRun(port = "45876", publicKey: string, token: string) {
|
||||
copyToClipboard(
|
||||
`docker run -d --name beszel-agent --network host --restart unless-stopped -v /var/run/docker.sock:/var/run/docker.sock:ro -v ./beszel_agent_data:/var/lib/beszel-agent -e KEY="${publicKey}" -e LISTEN=${port} -e TOKEN="${token}" -e HUB_URL="${getHubURL()}" henrygd/beszel-agent${imageTag}`
|
||||
)
|
||||
}
|
||||
|
||||
export function copyLinuxCommand(port = "45876", publicKey: string, token: string, brew = false) {
|
||||
let cmd = `curl -sL ${getScriptUrl(
|
||||
brew ? "/brew" : ""
|
||||
)} -o /tmp/install-agent.sh && chmod +x /tmp/install-agent.sh && /tmp/install-agent.sh -p ${port} -k "${publicKey}" -t "${token}" -url "${getHubURL()}"`
|
||||
// brew script does not support --china-mirrors
|
||||
if (!brew && (i18n.locale + navigator.language).includes("zh-CN")) {
|
||||
cmd += ` --china-mirrors`
|
||||
}
|
||||
copyToClipboard(cmd)
|
||||
}
|
||||
|
||||
export function copyWindowsCommand(port = "45876", publicKey: string, token: string) {
|
||||
copyToClipboard(
|
||||
`& iwr -useb ${getScriptUrl()} -OutFile "$env:TEMP\\install-agent.ps1"; & Powershell -ExecutionPolicy Bypass -File "$env:TEMP\\install-agent.ps1" -Key "${publicKey}" -Port ${port} -Token "${token}" -Url "${getHubURL()}"`
|
||||
)
|
||||
}
|
||||
|
||||
export interface DropdownItem {
|
||||
text: string
|
||||
onClick?: () => void
|
||||
url?: string
|
||||
icons?: React.ComponentType<React.SVGProps<SVGSVGElement>>[]
|
||||
}
|
||||
|
||||
export const InstallDropdown = memo(({ items }: { items: DropdownItem[] }) => {
|
||||
return (
|
||||
<DropdownMenuContent align="end">
|
||||
{items.map((item, index) => {
|
||||
const className = "cursor-pointer flex items-center gap-1.5"
|
||||
return item.url ? (
|
||||
<DropdownMenuItem key={index} asChild>
|
||||
<a href={item.url} className={className} target="_blank" rel="noopener noreferrer">
|
||||
{item.text}{" "}
|
||||
{item.icons?.map((Icon, iconIndex) => (
|
||||
<Icon key={iconIndex} className="size-4" />
|
||||
))}
|
||||
</a>
|
||||
</DropdownMenuItem>
|
||||
) : (
|
||||
<DropdownMenuItem key={index} onClick={item.onClick} className={className}>
|
||||
{item.text}{" "}
|
||||
{item.icons?.map((Icon, iconIndex) => (
|
||||
<Icon key={iconIndex} className="size-4" />
|
||||
))}
|
||||
</DropdownMenuItem>
|
||||
)
|
||||
})}
|
||||
</DropdownMenuContent>
|
||||
)
|
||||
})
|
@@ -7,7 +7,7 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com
|
||||
import { useStore } from "@nanostores/react"
|
||||
import { $router } from "@/components/router.tsx"
|
||||
import { getPagePath, redirectPage } from "@nanostores/router"
|
||||
import { BellIcon, FileSlidersIcon, SettingsIcon } from "lucide-react"
|
||||
import { BellIcon, FileSlidersIcon, FingerprintIcon, SettingsIcon } from "lucide-react"
|
||||
import { $userSettings, pb } from "@/lib/stores.ts"
|
||||
import { toast } from "@/components/ui/use-toast.ts"
|
||||
import { UserSettings } from "@/types.js"
|
||||
@@ -15,6 +15,7 @@ import General from "./general.tsx"
|
||||
import Notifications from "./notifications.tsx"
|
||||
import ConfigYaml from "./config-yaml.tsx"
|
||||
import { useLingui } from "@lingui/react/macro"
|
||||
import Fingerprints from "./tokens-fingerprints.tsx"
|
||||
|
||||
export async function saveSettings(newSettings: Partial<UserSettings>) {
|
||||
try {
|
||||
@@ -58,6 +59,12 @@ export default function SettingsLayout() {
|
||||
href: getPagePath($router, "settings", { name: "notifications" }),
|
||||
icon: BellIcon,
|
||||
},
|
||||
{
|
||||
title: t`Tokens & Fingerprints`,
|
||||
href: getPagePath($router, "settings", { name: "tokens" }),
|
||||
icon: FingerprintIcon,
|
||||
// admin: true,
|
||||
},
|
||||
{
|
||||
title: t`YAML Config`,
|
||||
href: getPagePath($router, "settings", { name: "config" }),
|
||||
@@ -77,7 +84,7 @@ export default function SettingsLayout() {
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<Card className="pt-5 px-4 pb-8 sm:pt-6 sm:px-7">
|
||||
<Card className="pt-5 px-4 pb-8 min-h-96 sm:pt-6 sm:px-7">
|
||||
<CardHeader className="p-0">
|
||||
<CardTitle className="mb-1">
|
||||
<Trans>Settings</Trans>
|
||||
@@ -89,10 +96,10 @@ export default function SettingsLayout() {
|
||||
<CardContent className="p-0">
|
||||
<Separator className="hidden md:block my-5" />
|
||||
<div className="flex flex-col gap-3.5 md:flex-row md:gap-5 lg:gap-10">
|
||||
<aside className="md:w-48 w-full">
|
||||
<aside className="md:max-w-44 min-w-40">
|
||||
<SidebarNav items={sidebarNavItems} />
|
||||
</aside>
|
||||
<div className="flex-1">
|
||||
<div className="flex-1 min-w-0">
|
||||
{/* @ts-ignore */}
|
||||
<SettingsContent name={page?.params?.name ?? "general"} />
|
||||
</div>
|
||||
@@ -112,5 +119,7 @@ function SettingsContent({ name }: { name: string }) {
|
||||
return <Notifications userSettings={userSettings} />
|
||||
case "config":
|
||||
return <ConfigYaml />
|
||||
case "tokens":
|
||||
return <Fingerprints />
|
||||
}
|
||||
}
|
||||
|
@@ -31,9 +31,9 @@ export function SidebarNav({ className, items, ...props }: SidebarNavProps) {
|
||||
if (item.admin && !isAdmin()) return null
|
||||
return (
|
||||
<SelectItem key={item.href} value={item.href}>
|
||||
<span className="flex items-center gap-2">
|
||||
<span className="flex items-center gap-2 truncate">
|
||||
{item.icon && <item.icon className="h-4 w-4" />}
|
||||
{item.title}
|
||||
<span className="truncate">{item.title}</span>
|
||||
</span>
|
||||
</SelectItem>
|
||||
)
|
||||
@@ -55,13 +55,12 @@ export function SidebarNav({ className, items, ...props }: SidebarNavProps) {
|
||||
href={item.href}
|
||||
className={cn(
|
||||
buttonVariants({ variant: "ghost" }),
|
||||
"flex items-center gap-3",
|
||||
page?.path === item.href ? "bg-muted hover:bg-muted" : "hover:bg-muted/50",
|
||||
"justify-start"
|
||||
"flex items-center gap-3 justify-start truncate",
|
||||
page?.path === item.href ? "bg-muted hover:bg-muted" : "hover:bg-muted/50"
|
||||
)}
|
||||
>
|
||||
{item.icon && <item.icon className="h-4 w-4" />}
|
||||
{item.title}
|
||||
{item.icon && <item.icon className="h-4 w-4 shrink-0" />}
|
||||
<span className="truncate">{item.title}</span>
|
||||
</Link>
|
||||
)
|
||||
})}
|
||||
|
@@ -0,0 +1,352 @@
|
||||
import { Trans } from "@lingui/react/macro"
|
||||
import { t } from "@lingui/core/macro"
|
||||
import { $publicKey, pb } from "@/lib/stores"
|
||||
import { memo, useEffect, useMemo, useState } from "react"
|
||||
import { Table, TableCell, TableHead, TableBody, TableRow, TableHeader } from "@/components/ui/table"
|
||||
import { FingerprintRecord } from "@/types"
|
||||
import {
|
||||
CopyIcon,
|
||||
FingerprintIcon,
|
||||
KeyIcon,
|
||||
MoreHorizontalIcon,
|
||||
RotateCwIcon,
|
||||
ServerIcon,
|
||||
Trash2Icon,
|
||||
} from "lucide-react"
|
||||
import { toast } from "@/components/ui/use-toast"
|
||||
import { cn, copyToClipboard, generateToken, getHubURL, isReadOnlyUser, tokenMap } from "@/lib/utils"
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Separator } from "@/components/ui/separator"
|
||||
import { Switch } from "@/components/ui/switch"
|
||||
import {
|
||||
copyDockerCompose,
|
||||
copyDockerRun,
|
||||
copyLinuxCommand,
|
||||
copyWindowsCommand,
|
||||
DropdownItem,
|
||||
InstallDropdown,
|
||||
} from "@/components/install-dropdowns"
|
||||
import { AppleIcon, DockerIcon, TuxIcon, WindowsIcon } from "@/components/ui/icons"
|
||||
|
||||
const pbFingerprintOptions = {
|
||||
expand: "system",
|
||||
fields: "id,fingerprint,token,system,expand.system.name",
|
||||
}
|
||||
|
||||
const SettingsFingerprintsPage = memo(() => {
|
||||
const [fingerprints, setFingerprints] = useState<FingerprintRecord[]>([])
|
||||
|
||||
// Get fingerprint records on mount
|
||||
useEffect(() => {
|
||||
pb.collection("fingerprints")
|
||||
.getFullList(pbFingerprintOptions)
|
||||
// @ts-ignore
|
||||
.then(setFingerprints)
|
||||
}, [])
|
||||
|
||||
// Subscribe to fingerprint updates
|
||||
useEffect(() => {
|
||||
let unsubscribe: (() => void) | undefined
|
||||
;(async () => {
|
||||
// subscribe to fingerprint updates
|
||||
unsubscribe = await pb.collection("fingerprints").subscribe(
|
||||
"*",
|
||||
(res) => {
|
||||
setFingerprints((currentFingerprints) => {
|
||||
if (res.action === "create") {
|
||||
return [...currentFingerprints, res.record as FingerprintRecord]
|
||||
}
|
||||
if (res.action === "update") {
|
||||
return currentFingerprints.map((fingerprint) => {
|
||||
if (fingerprint.id === res.record.id) {
|
||||
return { ...fingerprint, ...res.record } as FingerprintRecord
|
||||
}
|
||||
return fingerprint
|
||||
})
|
||||
}
|
||||
if (res.action === "delete") {
|
||||
return currentFingerprints.filter((fingerprint) => fingerprint.id !== res.record.id)
|
||||
}
|
||||
return currentFingerprints
|
||||
})
|
||||
},
|
||||
pbFingerprintOptions
|
||||
)
|
||||
})()
|
||||
// unsubscribe on unmount
|
||||
return () => unsubscribe?.()
|
||||
}, [])
|
||||
|
||||
// Update token map whenever fingerprints change
|
||||
useEffect(() => {
|
||||
for (const fingerprint of fingerprints) {
|
||||
tokenMap.set(fingerprint.system, fingerprint.token)
|
||||
}
|
||||
}, [fingerprints])
|
||||
|
||||
return (
|
||||
<>
|
||||
<SectionIntro />
|
||||
<Separator className="my-4" />
|
||||
<SectionUniversalToken />
|
||||
<Separator className="my-4" />
|
||||
<SectionTable fingerprints={fingerprints} />
|
||||
</>
|
||||
)
|
||||
})
|
||||
|
||||
const SectionIntro = memo(() => {
|
||||
return (
|
||||
<div>
|
||||
<h3 className="text-xl font-medium mb-2">
|
||||
<Trans>Tokens & Fingerprints</Trans>
|
||||
</h3>
|
||||
<p className="text-sm text-muted-foreground leading-relaxed">
|
||||
<Trans>Tokens and fingerprints are used to authenticate WebSocket connections to the hub.</Trans>
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground leading-relaxed mt-1.5">
|
||||
<Trans>
|
||||
Tokens allow agents to connect and register. Fingerprints are stable identifiers unique to each system, set on
|
||||
first connection.
|
||||
</Trans>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
const SectionUniversalToken = memo(() => {
|
||||
const [token, setToken] = useState("")
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const [checked, setChecked] = useState(false)
|
||||
|
||||
async function updateToken(enable: number = -1) {
|
||||
// enable: 0 for disable, 1 for enable, -1 (unset) for get current state
|
||||
const data = await pb.send(`/api/beszel/universal-token`, {
|
||||
query: {
|
||||
token,
|
||||
enable,
|
||||
},
|
||||
})
|
||||
setToken(data.token)
|
||||
setChecked(data.active)
|
||||
setIsLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
updateToken()
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h3 className="text-lg font-medium mb-2">
|
||||
<Trans>Universal token</Trans>
|
||||
</h3>
|
||||
<p className="text-sm text-muted-foreground leading-relaxed">
|
||||
<Trans>
|
||||
When enabled, this token allows agents to self-register without prior system creation. Expires after one hour
|
||||
or on hub restart.
|
||||
</Trans>
|
||||
</p>
|
||||
<div className="min-h-16 overflow-auto max-w-full inline-flex items-center gap-5 mt-3 border py-2 pl-5 pr-4 rounded-md">
|
||||
{!isLoading && (
|
||||
<>
|
||||
<Switch
|
||||
defaultChecked={checked}
|
||||
onCheckedChange={(checked) => {
|
||||
updateToken(checked ? 1 : 0)
|
||||
}}
|
||||
/>
|
||||
<span
|
||||
className={cn(
|
||||
"text-sm text-primary opacity-60 transition-opacity",
|
||||
checked ? "opacity-100" : "select-none"
|
||||
)}
|
||||
>
|
||||
{token}
|
||||
</span>
|
||||
<ActionsButtonUniversalToken token={token} checked={checked} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
const ActionsButtonUniversalToken = memo(({ token, checked }: { token: string; checked: boolean }) => {
|
||||
const publicKey = $publicKey.get()
|
||||
const port = "45876"
|
||||
|
||||
const dropdownItems: DropdownItem[] = [
|
||||
{
|
||||
text: "Copy Docker Compose",
|
||||
onClick: () => copyDockerCompose(port, publicKey, token),
|
||||
icons: [DockerIcon],
|
||||
},
|
||||
{
|
||||
text: "Copy Docker Run",
|
||||
onClick: () => copyDockerRun(port, publicKey, token),
|
||||
icons: [DockerIcon],
|
||||
},
|
||||
{
|
||||
text: "Copy Linux Command",
|
||||
onClick: () => copyLinuxCommand(port, publicKey, token),
|
||||
icons: [TuxIcon],
|
||||
},
|
||||
{
|
||||
text: "Copy Brew Command",
|
||||
onClick: () => copyLinuxCommand(port, publicKey, token, true),
|
||||
icons: [TuxIcon, AppleIcon],
|
||||
},
|
||||
{
|
||||
text: "Copy Windows Command",
|
||||
onClick: () => copyWindowsCommand(port, publicKey, token),
|
||||
icons: [WindowsIcon],
|
||||
},
|
||||
]
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
disabled={!checked}
|
||||
className={cn("transition-opacity", !checked && "opacity-50")}
|
||||
>
|
||||
<span className="sr-only">
|
||||
<Trans>Open menu</Trans>
|
||||
</span>
|
||||
<MoreHorizontalIcon className="w-5" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<InstallDropdown items={dropdownItems} />
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
const SectionTable = memo(({ fingerprints = [] }: { fingerprints: FingerprintRecord[] }) => {
|
||||
const isReadOnly = isReadOnlyUser()
|
||||
const headerCols = useMemo(
|
||||
() => [
|
||||
{
|
||||
label: "System",
|
||||
Icon: ServerIcon,
|
||||
w: "11em",
|
||||
},
|
||||
{
|
||||
label: "Token",
|
||||
Icon: KeyIcon,
|
||||
w: "20em",
|
||||
},
|
||||
{
|
||||
label: "Fingerprint",
|
||||
Icon: FingerprintIcon,
|
||||
w: "20em",
|
||||
},
|
||||
],
|
||||
[]
|
||||
)
|
||||
return (
|
||||
<div className="rounded-md border overflow-hidden w-full mt-4">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
{headerCols.map((col) => (
|
||||
<TableHead key={col.label} style={{ minWidth: col.w }}>
|
||||
<span className="flex items-center gap-2">
|
||||
<col.Icon className="size-4" />
|
||||
{col.label}
|
||||
</span>
|
||||
</TableHead>
|
||||
))}
|
||||
{!isReadOnly && (
|
||||
<TableHead className="w-0">
|
||||
<span className="sr-only">
|
||||
<Trans>Actions</Trans>
|
||||
</span>
|
||||
</TableHead>
|
||||
)}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody className="whitespace-pre">
|
||||
{fingerprints.map((fingerprint, i) => (
|
||||
<TableRow key={i}>
|
||||
<TableCell className="font-medium ps-5 py-2.5">{fingerprint.expand.system.name}</TableCell>
|
||||
<TableCell className="font-mono text-[0.95em] py-2.5">{fingerprint.token}</TableCell>
|
||||
<TableCell className="font-mono text-[0.95em] py-2.5">{fingerprint.fingerprint}</TableCell>
|
||||
{!isReadOnly && (
|
||||
<TableCell className="py-2.5 px-4 xl:px-2">
|
||||
<ActionsButtonTable fingerprint={fingerprint} />
|
||||
</TableCell>
|
||||
)}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
async function updateFingerprint(fingerprint: FingerprintRecord, rotateToken = false) {
|
||||
try {
|
||||
await pb.collection("fingerprints").update(fingerprint.id, {
|
||||
fingerprint: "",
|
||||
token: rotateToken ? generateToken() : fingerprint.token,
|
||||
})
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
title: t`Error`,
|
||||
description: error.message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const ActionsButtonTable = memo(({ fingerprint }: { fingerprint: FingerprintRecord }) => {
|
||||
const envVar = `HUB_URL=${getHubURL()}\nTOKEN=${fingerprint.token}`
|
||||
const copyEnv = () => copyToClipboard(envVar)
|
||||
const copyYaml = () => copyToClipboard(envVar.replaceAll("=", ": "))
|
||||
|
||||
return (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size={"icon"} data-nolink>
|
||||
<span className="sr-only">
|
||||
<Trans>Open menu</Trans>
|
||||
</span>
|
||||
<MoreHorizontalIcon className="w-5" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem onClick={copyYaml}>
|
||||
<CopyIcon className="me-2.5 size-4" />
|
||||
<Trans>Copy YAML</Trans>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={copyEnv}>
|
||||
<CopyIcon className="me-2.5 size-4" />
|
||||
<Trans context="Environment variables">Copy env</Trans>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem onSelect={() => updateFingerprint(fingerprint, true)}>
|
||||
<RotateCwIcon className="me-2.5 size-4" />
|
||||
<Trans>Rotate token</Trans>
|
||||
</DropdownMenuItem>
|
||||
{fingerprint.fingerprint && (
|
||||
<DropdownMenuItem onSelect={() => updateFingerprint(fingerprint)}>
|
||||
<Trash2Icon className="me-2.5 size-4" />
|
||||
<Trans>Delete fingerprint</Trans>
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)
|
||||
})
|
||||
|
||||
export default SettingsFingerprintsPage
|
38
beszel/site/src/components/ui/input-copy.tsx
Normal file
38
beszel/site/src/components/ui/input-copy.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import { copyToClipboard } from "@/lib/utils"
|
||||
import { Input } from "./input"
|
||||
import { Trans } from "@lingui/react/macro"
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "./tooltip"
|
||||
import { CopyIcon } from "lucide-react"
|
||||
import { Button } from "./button"
|
||||
|
||||
export function InputCopy({ value, id, name }: { value: string; id: string; name: string }) {
|
||||
return (
|
||||
<div className="relative">
|
||||
<Input readOnly id={id} name={name} value={value} required></Input>
|
||||
<div
|
||||
className={
|
||||
"h-6 w-24 bg-gradient-to-r rtl:bg-gradient-to-l from-transparent to-background to-65% absolute top-2 end-1 pointer-events-none"
|
||||
}
|
||||
></div>
|
||||
<TooltipProvider delayDuration={100} disableHoverableContent>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
variant={"link"}
|
||||
className="absolute end-0 top-0"
|
||||
onClick={() => copyToClipboard(value)}
|
||||
>
|
||||
<CopyIcon className="size-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<p>
|
||||
<Trans>Click to copy</Trans>
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
)
|
||||
}
|
@@ -1,9 +1,9 @@
|
||||
import { t } from "@lingui/core/macro";
|
||||
import { t } from "@lingui/core/macro"
|
||||
import { toast } from "@/components/ui/use-toast"
|
||||
import { type ClassValue, clsx } from "clsx"
|
||||
import { twMerge } from "tailwind-merge"
|
||||
import { $alerts, $copyContent, $systems, $userSettings, pb } from "./stores"
|
||||
import { AlertInfo, AlertRecord, ChartTimeData, ChartTimes, SystemRecord } from "@/types"
|
||||
import { AlertInfo, AlertRecord, ChartTimeData, ChartTimes, FingerprintRecord, SystemRecord } from "@/types"
|
||||
import { RecordModel, RecordSubscription } from "pocketbase"
|
||||
import { WritableAtom } from "nanostores"
|
||||
import { timeDay, timeHour } from "d3-time"
|
||||
@@ -17,13 +17,9 @@ export function cn(...inputs: ClassValue[]) {
|
||||
}
|
||||
|
||||
/** Adds event listener to node and returns function that removes the listener */
|
||||
export function listen<T extends Event = Event>(
|
||||
node: Node,
|
||||
event: string,
|
||||
handler: (event: T) => void
|
||||
) {
|
||||
node.addEventListener(event, handler as EventListener)
|
||||
return () => node.removeEventListener(event, handler as EventListener)
|
||||
export function listen<T extends Event = Event>(node: Node, event: string, handler: (event: T) => void) {
|
||||
node.addEventListener(event, handler as EventListener)
|
||||
return () => node.removeEventListener(event, handler as EventListener)
|
||||
}
|
||||
|
||||
export async function copyToClipboard(content: string) {
|
||||
@@ -355,3 +351,12 @@ export const alertInfo: Record<string, AlertInfo> = {
|
||||
* const hostname = getHostDisplayValue(system) // hostname will be "beszel.sock"
|
||||
*/
|
||||
export const getHostDisplayValue = (system: SystemRecord): string => system.host.slice(system.host.lastIndexOf("/") + 1)
|
||||
|
||||
/** Generate a random token for the agent */
|
||||
export const generateToken = () => crypto?.randomUUID() ?? (performance.now() * Math.random()).toString(16)
|
||||
|
||||
/** Get the hub URL from the global BESZEL object */
|
||||
export const getHubURL = () => BESZEL?.HUB_URL || window.location.origin
|
||||
|
||||
/** Map of system IDs to their corresponding tokens (used to avoid fetching in add-system dialog) */
|
||||
export const tokenMap = new Map<SystemRecord["id"], FingerprintRecord["token"]>()
|
||||
|
13
beszel/site/src/types.d.ts
vendored
13
beszel/site/src/types.d.ts
vendored
@@ -6,6 +6,19 @@ declare global {
|
||||
var BESZEL: {
|
||||
BASE_PATH: string
|
||||
HUB_VERSION: string
|
||||
HUB_URL: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface FingerprintRecord extends RecordModel {
|
||||
id: string
|
||||
system: string
|
||||
fingerprint: string
|
||||
token: string
|
||||
expand: {
|
||||
system: {
|
||||
name: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -15,7 +15,7 @@ export default defineConfig({
|
||||
name: "replace version in index.html during dev",
|
||||
apply: "serve",
|
||||
transformIndexHtml(html) {
|
||||
return html.replace("{{V}}", version)
|
||||
return html.replace("{{V}}", version).replace("{{HUB_URL}}", "")
|
||||
},
|
||||
},
|
||||
],
|
||||
|
@@ -1,6 +1,10 @@
|
||||
package beszel
|
||||
|
||||
import "github.com/blang/semver"
|
||||
|
||||
const (
|
||||
Version = "0.11.1"
|
||||
Version = "0.12.0-beta1"
|
||||
AppName = "beszel"
|
||||
)
|
||||
|
||||
var MinVersionCbor = semver.MustParse("0.12.0-beta1")
|
||||
|
Reference in New Issue
Block a user