diff --git a/beszel/internal/entities/system/system.go b/beszel/internal/entities/system/system.go index 2d71cf6..316377d 100644 --- a/beszel/internal/entities/system/system.go +++ b/beszel/internal/entities/system/system.go @@ -1,5 +1,7 @@ package system +// TODO: this is confusing, make common package with common/types common/helpers etc + import ( "beszel/internal/entities/container" "time" diff --git a/beszel/internal/hub/config.go b/beszel/internal/hub/config.go index c59ff5f..90904a8 100644 --- a/beszel/internal/hub/config.go +++ b/beszel/internal/hub/config.go @@ -22,7 +22,7 @@ type Config struct { type SystemConfig struct { Name string `yaml:"name"` Host string `yaml:"host"` - Port uint16 `yaml:"port"` + Port uint16 `yaml:"port,omitempty"` Users []string `yaml:"users"` } diff --git a/beszel/internal/hub/systems/systems.go b/beszel/internal/hub/systems/systems.go new file mode 100644 index 0000000..879dab1 --- /dev/null +++ b/beszel/internal/hub/systems/systems.go @@ -0,0 +1,434 @@ +package systems + +import ( + "beszel/internal/entities/system" + "context" + "fmt" + "net" + "strings" + "time" + + "github.com/goccy/go-json" + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tools/store" + "golang.org/x/crypto/ssh" +) + +const ( + up string = "up" + down string = "down" + paused string = "paused" + pending string = "pending" + + interval int = 60_000 + + sessionTimeout = 4 * time.Second +) + +type SystemManager struct { + hub hubLike + systems *store.Store[string, *System] + sshConfig *ssh.ClientConfig +} + +type System struct { + Id string `db:"id"` + Host string `db:"host"` + Port string `db:"port"` + Status string `db:"status"` + manager *SystemManager + client *ssh.Client + data *system.CombinedData + ctx context.Context + cancel context.CancelFunc +} + +type hubLike interface { + core.App + GetSSHKey() ([]byte, error) + HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error + HandleStatusAlerts(status string, systemRecord *core.Record) error +} + +func NewSystemManager(hub hubLike) *SystemManager { + return &SystemManager{ + systems: store.New(map[string]*System{}), + hub: hub, + } +} + +// Initialize initializes the system manager. +// It binds the event hooks and starts updating existing systems. +func (sm *SystemManager) Initialize() error { + sm.bindEventHooks() + // ssh setup + key, err := sm.hub.GetSSHKey() + if err != nil { + return err + } + if err := sm.createSSHClientConfig(key); err != nil { + return err + } + // start updating existing systems + var systems []*System + err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems) + if err != nil || len(systems) == 0 { + return err + } + go func() { + // time between initial system updates + delta := interval / max(1, len(systems)) + delta = min(delta, 2_000) + sleepTime := time.Duration(delta) * time.Millisecond + for _, system := range systems { + time.Sleep(sleepTime) + _ = sm.AddSystem(system) + } + }() + return nil +} + +func (sm *SystemManager) bindEventHooks() { + sm.hub.OnRecordCreate("systems").BindFunc(sm.onRecordCreate) + sm.hub.OnRecordAfterCreateSuccess("systems").BindFunc(sm.onRecordAfterCreateSuccess) + sm.hub.OnRecordUpdate("systems").BindFunc(sm.onRecordUpdate) + sm.hub.OnRecordAfterUpdateSuccess("systems").BindFunc(sm.onRecordAfterUpdateSuccess) + sm.hub.OnRecordAfterDeleteSuccess("systems").BindFunc(sm.onRecordAfterDeleteSuccess) +} + +// Runs before the record is committed to the database +func (sm *SystemManager) onRecordCreate(e *core.RecordEvent) error { + e.Record.Set("info", system.Info{}) + e.Record.Set("status", pending) + return e.Next() +} + +// Runs after the record is committed to the database +func (sm *SystemManager) onRecordAfterCreateSuccess(e *core.RecordEvent) error { + if err := sm.AddRecord(e.Record); err != nil { + sm.hub.Logger().Error("Error adding record", "err", err) + } + return e.Next() +} + +// Runs before the record is updated +func (sm *SystemManager) onRecordUpdate(e *core.RecordEvent) error { + if e.Record.GetString("status") == paused { + e.Record.Set("info", system.Info{}) + } + return e.Next() +} + +// Runs after the record is updated +func (sm *SystemManager) onRecordAfterUpdateSuccess(e *core.RecordEvent) error { + newStatus := e.Record.GetString("status") + switch newStatus { + case paused: + sm.RemoveSystem(e.Record.Id) + return e.Next() + case pending: + if err := sm.AddRecord(e.Record); err != nil { + sm.hub.Logger().Error("Error adding record", "err", err) + } + return e.Next() + } + system, ok := sm.systems.GetOk(e.Record.Id) + if !ok { + return sm.AddRecord(e.Record) + } + prevStatus := system.Status + system.Status = newStatus + // system alerts if system is up + if system.Status == up { + if err := sm.hub.HandleSystemAlerts(e.Record, system.data); err != nil { + sm.hub.Logger().Error("Error handling system alerts", "err", err) + } + } + if (system.Status == down && prevStatus == up) || (system.Status == up && prevStatus == down) { + if err := sm.hub.HandleStatusAlerts(system.Status, e.Record); err != nil { + sm.hub.Logger().Error("Error handling status alerts", "err", err) + } + } + return e.Next() +} + +// Runs after the record is deleted +func (sm *SystemManager) onRecordAfterDeleteSuccess(e *core.RecordEvent) error { + sm.RemoveSystem(e.Record.Id) + return e.Next() +} + +// AddSystem adds a system to the manager +func (sm *SystemManager) AddSystem(sys *System) error { + if sm.systems.Has(sys.Id) { + return fmt.Errorf("system exists") + } + if sys.Id == "" || sys.Host == "" { + return fmt.Errorf("system is missing required fields") + } + sys.manager = sm + sys.ctx, sys.cancel = context.WithCancel(context.Background()) + sys.data = &system.CombinedData{} + sm.systems.Set(sys.Id, sys) + go sys.StartUpdater() + return nil +} + +// RemoveSystem removes a system from the manager +func (sm *SystemManager) RemoveSystem(systemID string) error { + system, ok := sm.systems.GetOk(systemID) + if !ok { + return fmt.Errorf("system not found") + } + // cancel the context to signal stop + if system.cancel != nil { + system.cancel() + } + system.resetSSHClient() + sm.systems.Remove(systemID) + return nil +} + +// AddRecord adds a record to the system manager. +// It first removes any existing system with the same ID, then creates a new System +// instance from the record data and adds it to the manager. +// This function is typically called when a new system is created or when an existing +// system's status changes to pending. +func (sm *SystemManager) AddRecord(record *core.Record) (err error) { + _ = sm.RemoveSystem(record.Id) + system := &System{ + Id: record.Id, + Status: record.GetString("status"), + Host: record.GetString("host"), + Port: record.GetString("port"), + } + return sm.AddSystem(system) +} + +// StartUpdater starts the system updater. +// It first fetches the data from the agent then updates the records. +// If the data is not found or the system is down, it sets the system down. +func (sys *System) StartUpdater() { + if sys.data == nil { + sys.data = &system.CombinedData{} + } + if err := sys.update(); err != nil { + _ = sys.setDown(err) + } + + c := time.Tick(time.Duration(interval) * time.Millisecond) + + for { + select { + case <-sys.ctx.Done(): + return + case <-c: + err := sys.update() + if err != nil { + _ = sys.setDown(err) + } + } + } +} + +// update updates the system data and records. +// It first fetches the data from the agent then updates the records. +func (sys *System) update() error { + _, err := sys.fetchDataFromAgent() + if err == nil { + _, err = sys.createRecords() + } + return err +} + +// createRecords updates the system record and adds system_stats and container_stats records +func (sys *System) createRecords() (*core.Record, error) { + systemRecord, err := sys.getRecord() + if err != nil { + return nil, err + } + hub := sys.manager.hub + systemRecord.Set("status", up) + systemRecord.Set("info", sys.data.Info) + if err := hub.SaveNoValidate(systemRecord); err != nil { + return nil, err + } + // add system_stats and container_stats records + systemStats, err := hub.FindCachedCollectionByNameOrId("system_stats") + if err != nil { + return nil, err + } + systemStatsRecord := core.NewRecord(systemStats) + systemStatsRecord.Set("system", systemRecord.Id) + systemStatsRecord.Set("stats", sys.data.Stats) + systemStatsRecord.Set("type", "1m") + if err := hub.SaveNoValidate(systemStatsRecord); err != nil { + return nil, err + } + // add new container_stats record + if len(sys.data.Containers) > 0 { + containerStats, err := hub.FindCachedCollectionByNameOrId("container_stats") + if err != nil { + return nil, err + } + containerStatsRecord := core.NewRecord(containerStats) + containerStatsRecord.Set("system", systemRecord.Id) + containerStatsRecord.Set("stats", sys.data.Containers) + containerStatsRecord.Set("type", "1m") + if err := hub.SaveNoValidate(containerStatsRecord); err != nil { + return nil, err + } + } + return systemRecord, nil +} + +// getRecord retrieves the system record from the database. +// If the record is not found or the system is paused, it removes the system from the manager. +func (sys *System) getRecord() (*core.Record, error) { + record, err := sys.manager.hub.FindRecordById("systems", sys.Id) + if err != nil || record == nil { + _ = sys.manager.RemoveSystem(sys.Id) + return nil, err + } + return record, nil +} + +// setDown marks a system as down in the database. +// It takes the original error that caused the system to go down and returns any error +// encountered during the process of updating the system status. +func (sys *System) setDown(OriginalError error) error { + if sys.Status == down { + return nil + } + record, err := sys.getRecord() + if err != nil { + return err + } + sys.manager.hub.Logger().Error("System down", "system", record.GetString("name"), "err", OriginalError) + record.Set("status", down) + err = sys.manager.hub.SaveNoValidate(record) + if err != nil { + return err + } + return nil +} + +// fetchDataFromAgent fetches the data from the agent. +// It first creates a new SSH client if it doesn't exist or the system is down. +// Then it creates a new SSH session and fetches the data from the agent. +// If the data is not found or the system is down, it sets the system down. +func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) { + maxRetries := 1 + for attempt := 0; attempt <= maxRetries; attempt++ { + if sys.client == nil || sys.Status == down { + if err := sys.createSSHClient(); err != nil { + return nil, err + } + } + + session, err := sys.createSessionWithTimeout(4 * time.Second) + if err != nil { + if attempt >= maxRetries { + return nil, err + } + sys.manager.hub.Logger().Warn("Session closed. Retrying...", "host", sys.Host, "port", sys.Port, "err", err) + sys.resetSSHClient() + continue + } + defer session.Close() + + stdout, err := session.StdoutPipe() + if err != nil { + return nil, err + } + if err := session.Shell(); err != nil { + return nil, err + } + + // this is initialized in startUpdater, should never be nil + *sys.data = system.CombinedData{} + if err := json.NewDecoder(stdout).Decode(sys.data); err != nil { + return nil, err + } + // wait for the session to complete + if err := session.Wait(); err != nil { + return nil, err + } + return sys.data, nil + } + + // this should never be reached due to the return in the loop + return nil, fmt.Errorf("failed to fetch data") +} + +func (sm *SystemManager) createSSHClientConfig(key []byte) error { + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return err + } + sm.sshConfig = &ssh.ClientConfig{ + User: "u", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: sessionTimeout, + } + return nil +} + +// createSSHClient creates a new SSH client for the system +func (s *System) createSSHClient() error { + network := "tcp" + host := s.Host + if strings.HasPrefix(host, "/") { + network = "unix" + } else { + host = net.JoinHostPort(host, s.Port) + } + var err error + s.client, err = ssh.Dial(network, host, s.manager.sshConfig) + if err != nil { + return err + } + return nil +} + +// createSessionWithTimeout creates a new SSH session with a timeout to avoid hanging +// in case of network issues +func (sys *System) createSessionWithTimeout(timeout time.Duration) (*ssh.Session, error) { + if sys.client == nil { + return nil, fmt.Errorf("client not initialized") + } + + ctx, cancel := context.WithTimeout(sys.ctx, timeout) + defer cancel() + + sessionChan := make(chan *ssh.Session, 1) + errChan := make(chan error, 1) + + go func() { + if session, err := sys.client.NewSession(); err != nil { + errChan <- err + } else { + sessionChan <- session + } + }() + + select { + case session := <-sessionChan: + return session, nil + case err := <-errChan: + return nil, err + case <-ctx.Done(): + return nil, fmt.Errorf("timeout") + } +} + +// resetSSHClient closes the SSH connection and resets the client to nil +func (sys *System) resetSSHClient() { + if sys.client != nil { + sys.client.Close() + } + sys.client = nil +} diff --git a/beszel/internal/hub/systems/systems_test.go b/beszel/internal/hub/systems/systems_test.go new file mode 100644 index 0000000..aa4d803 --- /dev/null +++ b/beszel/internal/hub/systems/systems_test.go @@ -0,0 +1,440 @@ +//go:build testing +// +build testing + +package systems_test + +import ( + "beszel/internal/entities/container" + "beszel/internal/entities/system" + "beszel/internal/hub/systems" + "beszel/internal/tests" + "fmt" + "sync" + "testing" + "time" + + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createTestSystem creates a test system record with a unique host name +// and returns the created record and any error +func createTestSystem(t *testing.T, hub *tests.TestHub, options map[string]any) (*core.Record, error) { + collection, err := hub.FindCachedCollectionByNameOrId("systems") + if err != nil { + return nil, err + } + + // get user record + var firstUser *core.Record + users, err := hub.FindAllRecords("users", dbx.NewExp("id != ''")) + if err != nil { + t.Fatal(err) + } + if len(users) > 0 { + firstUser = users[0] + } + // Generate a unique host name to ensure we're adding a new system + uniqueHost := fmt.Sprintf("test-host-%d.example.com", time.Now().UnixNano()) + + // Create the record + record := core.NewRecord(collection) + record.Set("name", uniqueHost) + record.Set("host", uniqueHost) + record.Set("port", "45876") + record.Set("status", "pending") + record.Set("users", []string{firstUser.Id}) + + // Apply any custom options + for key, value := range options { + record.Set(key, value) + } + + // Save the record to the database + err = hub.Save(record) + if err != nil { + return nil, err + } + + return record, nil +} + +func TestSystemManagerIntegration(t *testing.T) { + // Create a test hub + hub, err := tests.NewTestHub() + if err != nil { + t.Fatal(err) + } + defer hub.Cleanup() + + // Create independent system manager + sm := systems.NewSystemManager(hub) + assert.NotNil(t, sm) + + // Test initialization + sm.Initialize() + + // Test collection existence. todo: move to hub package tests + t.Run("CollectionExistence", func(t *testing.T) { + // Verify that required collections exist + systems, err := hub.FindCachedCollectionByNameOrId("systems") + require.NoError(t, err) + assert.NotNil(t, systems) + + systemStats, err := hub.FindCachedCollectionByNameOrId("system_stats") + require.NoError(t, err) + assert.NotNil(t, systemStats) + + containerStats, err := hub.FindCachedCollectionByNameOrId("container_stats") + require.NoError(t, err) + assert.NotNil(t, containerStats) + }) + + // Test adding a system record + t.Run("AddRecord", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + // Get the count before adding the system + countBefore := sm.GetSystemCount() + + // record should be pending on create + hub.OnRecordCreate("systems").BindFunc(func(e *core.RecordEvent) error { + record := e.Record + if record.GetString("name") == "welcometoarcoampm" { + assert.Equal(t, "pending", e.Record.GetString("status"), "System status should be 'pending'") + wg.Done() + } + return e.Next() + }) + + // record should be down on update + hub.OnRecordAfterUpdateSuccess("systems").BindFunc(func(e *core.RecordEvent) error { + record := e.Record + if record.GetString("name") == "welcometoarcoampm" { + assert.Equal(t, "down", e.Record.GetString("status"), "System status should be 'pending'") + wg.Done() + } + return e.Next() + }) + // Create a test system with the first user assigned + record, err := createTestSystem(t, hub, map[string]any{ + "name": "welcometoarcoampm", + "host": "localhost", + "port": "33914", + }) + require.NoError(t, err) + + wg.Wait() + + // system should be down if grabbed from the store + assert.Equal(t, "down", sm.GetSystemStatusFromStore(record.Id), "System status should be 'down'") + + // Check that the system count increased + countAfter := sm.GetSystemCount() + assert.Equal(t, countBefore+1, countAfter, "System count should increase after adding a system via event hook") + + // Verify the system was added by checking if it exists + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") + + // Verify the system host and port + host, port := sm.GetSystemHostPort(record.Id) + assert.Equal(t, record.Get("host"), host, "System host should match") + assert.Equal(t, record.Get("port"), port, "System port should match") + + // Verify the system is in the list of all system IDs + ids := sm.GetAllSystemIDs() + assert.Contains(t, ids, record.Id, "System ID should be in the list of all system IDs") + + // Verify the system was added by checking if removing it works + err = sm.RemoveSystem(record.Id) + assert.NoError(t, err, "System should exist and be removable") + + // Verify the system no longer exists + assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal") + + // Verify the system is not in the list of all system IDs + newIds := sm.GetAllSystemIDs() + assert.NotContains(t, newIds, record.Id, "System ID should not be in the list of all system IDs after removal") + + }) + + t.Run("RemoveSystem", func(t *testing.T) { + // Get the count before adding the system + countBefore := sm.GetSystemCount() + + // Create a test system record + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Verify the system count increased + countAfterAdd := sm.GetSystemCount() + assert.Equal(t, countBefore+1, countAfterAdd, "System count should increase after adding a system via event hook") + + // Verify the system exists + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") + + // Remove the system + err = sm.RemoveSystem(record.Id) + assert.NoError(t, err) + + // Check that the system count decreased + countAfterRemove := sm.GetSystemCount() + assert.Equal(t, countAfterAdd-1, countAfterRemove, "System count should decrease after removing a system") + + // Verify the system no longer exists + assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after removal") + + // Verify the system is not in the list of all system IDs + ids := sm.GetAllSystemIDs() + assert.NotContains(t, ids, record.Id, "System ID should not be in the list of all system IDs after removal") + + // Verify the system status is empty + status := sm.GetSystemStatusFromStore(record.Id) + assert.Equal(t, "", status, "System status should be empty after removal") + + // Try to remove it again - should return an error since it's already removed + err = sm.RemoveSystem(record.Id) + assert.Error(t, err) + }) + + t.Run("NewRecordPending", func(t *testing.T) { + // Create a test system + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Add the record to the system manager + err = sm.AddRecord(record) + require.NoError(t, err) + + // Test filtering records by status - should be "pending" now + filter := "status = 'pending'" + pendingSystems, err := hub.FindRecordsByFilter("systems", filter, "-created", 0, 0, nil) + require.NoError(t, err) + assert.GreaterOrEqual(t, len(pendingSystems), 1) + }) + + t.Run("SystemStatusUpdate", func(t *testing.T) { + // Create a test system record + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Add the record to the system manager + err = sm.AddRecord(record) + require.NoError(t, err) + + // Test status changes + initialStatus := sm.GetSystemStatusFromStore(record.Id) + + // Set a new status + sm.SetSystemStatusInDB(record.Id, "up") + + // Verify status was updated + newStatus := sm.GetSystemStatusFromStore(record.Id) + assert.Equal(t, "up", newStatus, "System status should be updated to 'up'") + assert.NotEqual(t, initialStatus, newStatus, "Status should have changed") + + // Verify the database was updated + updatedRecord, err := hub.FindRecordById("systems", record.Id) + require.NoError(t, err) + assert.Equal(t, "up", updatedRecord.Get("status"), "Database status should match") + }) + + t.Run("HandleSystemData", func(t *testing.T) { + // Create a test system record + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Create test system data + testData := &system.CombinedData{ + Info: system.Info{ + Hostname: "data-test.example.com", + KernelVersion: "5.15.0-generic", + Cores: 4, + Threads: 8, + CpuModel: "Test CPU", + Uptime: 3600, + Cpu: 25.5, + MemPct: 40.2, + DiskPct: 60.0, + Bandwidth: 100.0, + AgentVersion: "1.0.0", + }, + Stats: system.Stats{ + Cpu: 25.5, + Mem: 16384.0, + MemUsed: 6553.6, + MemPct: 40.0, + DiskTotal: 1024000.0, + DiskUsed: 614400.0, + DiskPct: 60.0, + NetworkSent: 1024.0, + NetworkRecv: 2048.0, + }, + Containers: []*container.Stats{}, + } + + // Test handling system data. todo: move to hub/alerts package tests + err = hub.HandleSystemAlerts(record, testData) + assert.NoError(t, err) + }) + + t.Run("ErrorHandling", func(t *testing.T) { + // Try to add a non-existent record + nonExistentId := "non_existent_id" + err := sm.RemoveSystem(nonExistentId) + assert.Error(t, err) + + // Try to add a system with invalid host + system := &systems.System{ + Host: "", + } + err = sm.AddSystem(system) + assert.Error(t, err) + }) + + t.Run("DeleteRecord", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + runs := 0 + + hub.OnRecordUpdate("systems").BindFunc(func(e *core.RecordEvent) error { + runs++ + record := e.Record + if record.GetString("name") == "deadflagblues" { + if runs == 1 { + assert.Equal(t, "up", e.Record.GetString("status"), "System status should be 'up'") + wg.Done() + } else if runs == 2 { + assert.Equal(t, "paused", e.Record.GetString("status"), "System status should be 'paused'") + wg.Done() + } + } + return e.Next() + }) + + // Create a test system record + record, err := createTestSystem(t, hub, map[string]any{ + "name": "deadflagblues", + }) + require.NoError(t, err) + + // Verify the system exists + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") + + // set the status manually to up + sm.SetSystemStatusInDB(record.Id, "up") + + // verify the status is up + assert.Equal(t, "up", sm.GetSystemStatusFromStore(record.Id), "System status should be 'up'") + + // Set the status to "paused" which should cause it to be deleted from the store + sm.SetSystemStatusInDB(record.Id, "paused") + + wg.Wait() + + // Verify the system no longer exists + assert.False(t, sm.HasSystem(record.Id), "System should not exist in the store after deletion") + }) + + t.Run("ConcurrentOperations", func(t *testing.T) { + // Create a test system + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Run concurrent operations + const goroutines = 5 + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := range goroutines { + go func(i int) { + defer wg.Done() + + // Alternate between different operations + switch i % 3 { + case 0: + status := fmt.Sprintf("status-%d", i) + sm.SetSystemStatusInDB(record.Id, status) + case 1: + _ = sm.GetSystemStatusFromStore(record.Id) + case 2: + _, _ = sm.GetSystemHostPort(record.Id) + } + }(i) + } + + wg.Wait() + + // Verify system still exists and is in a valid state + assert.True(t, sm.HasSystem(record.Id), "System should still exist after concurrent operations") + status := sm.GetSystemStatusFromStore(record.Id) + assert.NotEmpty(t, status, "System should have a status after concurrent operations") + }) + + t.Run("ContextCancellation", func(t *testing.T) { + // Create a test system record + record, err := createTestSystem(t, hub, map[string]any{}) + require.NoError(t, err) + + // Verify the system exists in the store + assert.True(t, sm.HasSystem(record.Id), "System should exist in the store") + + // Store the original context and cancel function + originalCtx, originalCancel, err := sm.GetSystemContextFromStore(record.Id) + assert.NoError(t, err) + + // Ensure the context is not nil + assert.NotNil(t, originalCtx, "System context should not be nil") + assert.NotNil(t, originalCancel, "System cancel function should not be nil") + + // Cancel the context + originalCancel() + + // Wait a short time for cancellation to propagate + time.Sleep(10 * time.Millisecond) + + // Verify the context is done + select { + case <-originalCtx.Done(): + // Context was properly cancelled + default: + t.Fatal("Context was not cancelled") + } + + // Verify the system is still in the store (cancellation shouldn't remove it) + assert.True(t, sm.HasSystem(record.Id), "System should still exist after context cancellation") + + // Explicitly remove the system + err = sm.RemoveSystem(record.Id) + assert.NoError(t, err, "RemoveSystem should succeed") + + // Verify the system is removed + assert.False(t, sm.HasSystem(record.Id), "System should be removed after RemoveSystem") + + // Try to remove it again - should return an error + err = sm.RemoveSystem(record.Id) + assert.Error(t, err, "RemoveSystem should fail for non-existent system") + + // Add the system back + err = sm.AddRecord(record) + require.NoError(t, err, "AddRecord should succeed") + + // Verify the system is back in the store + assert.True(t, sm.HasSystem(record.Id), "System should exist after re-adding") + + // Verify a new context was created + newCtx, newCancel, err := sm.GetSystemContextFromStore(record.Id) + assert.NoError(t, err) + assert.NotNil(t, newCtx, "New system context should not be nil") + assert.NotNil(t, newCancel, "New system cancel function should not be nil") + assert.NotEqual(t, originalCtx, newCtx, "New context should be different from original") + + // Clean up + err = sm.RemoveSystem(record.Id) + assert.NoError(t, err) + }) +} diff --git a/beszel/internal/hub/systems/systems_test_helpers.go b/beszel/internal/hub/systems/systems_test_helpers.go new file mode 100644 index 0000000..9822343 --- /dev/null +++ b/beszel/internal/hub/systems/systems_test_helpers.go @@ -0,0 +1,117 @@ +//go:build testing +// +build testing + +package systems + +import ( + entities "beszel/internal/entities/system" + "context" + "fmt" +) + +// GetSystemCount returns the number of systems in the store +func (sm *SystemManager) GetSystemCount() int { + return sm.systems.Length() +} + +// HasSystem checks if a system with the given ID exists in the store +func (sm *SystemManager) HasSystem(systemID string) bool { + return sm.systems.Has(systemID) +} + +// GetSystemStatusFromStore returns the status of a system with the given ID +// Returns an empty string if the system doesn't exist +func (sm *SystemManager) GetSystemStatusFromStore(systemID string) string { + sys, ok := sm.systems.GetOk(systemID) + if !ok { + return "" + } + return sys.Status +} + +// GetSystemContextFromStore returns the context and cancel function for a system +func (sm *SystemManager) GetSystemContextFromStore(systemID string) (context.Context, context.CancelFunc, error) { + sys, ok := sm.systems.GetOk(systemID) + if !ok { + return nil, nil, fmt.Errorf("no system") + } + return sys.ctx, sys.cancel, nil +} + +// GetSystemFromStore returns a store from the system +func (sm *SystemManager) GetSystemFromStore(systemID string) (*System, error) { + sys, ok := sm.systems.GetOk(systemID) + if !ok { + return nil, fmt.Errorf("no system") + } + return sys, nil +} + +// GetAllSystemIDs returns a slice of all system IDs in the store +func (sm *SystemManager) GetAllSystemIDs() []string { + data := sm.systems.GetAll() + ids := make([]string, 0, len(data)) + for id := range data { + ids = append(ids, id) + } + return ids +} + +// GetSystemData returns the combined data for a system with the given ID +// Returns nil if the system doesn't exist +// This method is intended for testing +func (sm *SystemManager) GetSystemData(systemID string) *entities.CombinedData { + sys, ok := sm.systems.GetOk(systemID) + if !ok { + return nil + } + return sys.data +} + +// GetSystemHostPort returns the host and port for a system with the given ID +// Returns empty strings if the system doesn't exist +func (sm *SystemManager) GetSystemHostPort(systemID string) (string, string) { + sys, ok := sm.systems.GetOk(systemID) + if !ok { + return "", "" + } + return sys.Host, sys.Port +} + +// DisableAutoUpdater disables the automatic updater for a system +// This is intended for testing +// Returns false if the system doesn't exist +// func (sm *SystemManager) DisableAutoUpdater(systemID string) bool { +// sys, ok := sm.systems.GetOk(systemID) +// if !ok { +// return false +// } +// if sys.cancel != nil { +// sys.cancel() +// sys.cancel = nil +// } +// return true +// } + +// SetSystemStatusInDB sets the status of a system directly and updates the database record +// This is intended for testing +// Returns false if the system doesn't exist +func (sm *SystemManager) SetSystemStatusInDB(systemID string, status string) bool { + if !sm.HasSystem(systemID) { + return false + } + + // Update the database record + record, err := sm.hub.FindRecordById("systems", systemID) + if err != nil { + return false + } + + record.Set("status", status) + err = sm.hub.Save(record) + if err != nil { + return false + } + + return true +} diff --git a/beszel/internal/tests/hub.go b/beszel/internal/tests/hub.go new file mode 100644 index 0000000..9ac3556 --- /dev/null +++ b/beszel/internal/tests/hub.go @@ -0,0 +1,58 @@ +// Package tests provides helpers for testing the application. +package tests + +import ( + "beszel/internal/hub" + + "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/tests" + + _ "github.com/pocketbase/pocketbase/migrations" +) + +// TestHub is a wrapper hub instance used for testing. +type TestHub struct { + core.App + *tests.TestApp + *hub.Hub +} + +// NewTestHub creates and initializes a test application instance. +// +// It is the caller's responsibility to call app.Cleanup() when the app is no longer needed. +func NewTestHub(optTestDataDir ...string) (*TestHub, error) { + var testDataDir string + if len(optTestDataDir) > 0 { + testDataDir = optTestDataDir[0] + } + + return NewTestHubWithConfig(core.BaseAppConfig{ + DataDir: testDataDir, + EncryptionEnv: "pb_test_env", + }) +} + +// NewTestHubWithConfig creates and initializes a test application instance +// from the provided config. +// +// If config.DataDir is not set it fallbacks to the default internal test data directory. +// +// config.DataDir is cloned for each new test application instance. +// +// It is the caller's responsibility to call app.Cleanup() when the app is no longer needed. +func NewTestHubWithConfig(config core.BaseAppConfig) (*TestHub, error) { + testApp, err := tests.NewTestAppWithConfig(config) + if err != nil { + return nil, err + } + + hub := hub.NewHub(testApp) + + t := &TestHub{ + App: testApp, + TestApp: testApp, + Hub: hub, + } + + return t, nil +}