refactor(hub): embed pocketbase fields in hub struct

This commit is contained in:
henrygd
2025-02-11 19:17:20 -05:00
parent 3376a97bea
commit 2ab2cc83de
6 changed files with 90 additions and 99 deletions

View File

@@ -1,29 +1,10 @@
package main package main
import ( import (
"beszel"
"beszel/internal/hub" "beszel/internal/hub"
_ "beszel/migrations" _ "beszel/migrations"
"github.com/pocketbase/pocketbase"
"github.com/spf13/cobra"
) )
func main() { func main() {
app := pocketbase.NewWithConfig(pocketbase.Config{ hub.NewHub().Run()
DefaultDataDir: beszel.AppName + "_data",
})
app.RootCmd.Version = beszel.Version
app.RootCmd.Use = beszel.AppName
app.RootCmd.Short = ""
// add update command
app.RootCmd.AddCommand(&cobra.Command{
Use: "update",
Short: "Update " + beszel.AppName + " to the latest version",
Run: hub.Update,
})
hub.NewHub(app).Run()
} }

View File

@@ -12,7 +12,6 @@ import (
"github.com/containrrr/shoutrrr" "github.com/containrrr/shoutrrr"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/mailer" "github.com/pocketbase/pocketbase/tools/mailer"
@@ -21,7 +20,7 @@ import (
) )
type AlertManager struct { type AlertManager struct {
app *pocketbase.PocketBase app core.App
} }
type AlertMessageData struct { type AlertMessageData struct {
@@ -61,7 +60,7 @@ type SystemAlertData struct {
descriptor string // override descriptor in notification body (for temp sensor, disk partition, etc) descriptor string // override descriptor in notification body (for temp sensor, disk partition, etc)
} }
func NewAlertManager(app *pocketbase.PocketBase) *AlertManager { func NewAlertManager(app core.App) *AlertManager {
return &AlertManager{ return &AlertManager{
app: app, app: app,
} }
@@ -167,7 +166,6 @@ func (am *AlertManager) HandleSystemAlerts(systemRecord *core.Record, systemInfo
)). )).
OrderBy("created"). OrderBy("created").
All(&systemStats) All(&systemStats)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -28,7 +28,7 @@ type SystemConfig struct {
// Syncs systems with the config.yml file // Syncs systems with the config.yml file
func (h *Hub) syncSystemsWithConfig() error { func (h *Hub) syncSystemsWithConfig() error {
configPath := filepath.Join(h.app.DataDir(), "config.yml") configPath := filepath.Join(h.DataDir(), "config.yml")
configData, err := os.ReadFile(configPath) configData, err := os.ReadFile(configPath)
if err != nil { if err != nil {
return nil return nil
@@ -49,7 +49,7 @@ func (h *Hub) syncSystemsWithConfig() error {
// Create a map of email to user ID // Create a map of email to user ID
userEmailToID := make(map[string]string) userEmailToID := make(map[string]string)
users, err := h.app.FindAllRecords("users", dbx.NewExp("id != ''")) users, err := h.FindAllRecords("users", dbx.NewExp("id != ''"))
if err != nil { if err != nil {
return err return err
} }
@@ -84,7 +84,7 @@ func (h *Hub) syncSystemsWithConfig() error {
} }
// Get existing systems // Get existing systems
existingSystems, err := h.app.FindAllRecords("systems", dbx.NewExp("id != ''")) existingSystems, err := h.FindAllRecords("systems", dbx.NewExp("id != ''"))
if err != nil { if err != nil {
return err return err
} }
@@ -104,13 +104,13 @@ func (h *Hub) syncSystemsWithConfig() error {
existingSystem.Set("name", sysConfig.Name) existingSystem.Set("name", sysConfig.Name)
existingSystem.Set("users", sysConfig.Users) existingSystem.Set("users", sysConfig.Users)
existingSystem.Set("port", sysConfig.Port) existingSystem.Set("port", sysConfig.Port)
if err := h.app.Save(existingSystem); err != nil { if err := h.Save(existingSystem); err != nil {
return err return err
} }
delete(existingSystemsMap, key) delete(existingSystemsMap, key)
} else { } else {
// Create new system // Create new system
systemsCollection, err := h.app.FindCollectionByNameOrId("systems") systemsCollection, err := h.FindCollectionByNameOrId("systems")
if err != nil { if err != nil {
return fmt.Errorf("failed to find systems collection: %v", err) return fmt.Errorf("failed to find systems collection: %v", err)
} }
@@ -121,7 +121,7 @@ func (h *Hub) syncSystemsWithConfig() error {
newSystem.Set("users", sysConfig.Users) newSystem.Set("users", sysConfig.Users)
newSystem.Set("info", system.Info{}) newSystem.Set("info", system.Info{})
newSystem.Set("status", "pending") newSystem.Set("status", "pending")
if err := h.app.Save(newSystem); err != nil { if err := h.Save(newSystem); err != nil {
return fmt.Errorf("failed to create new system: %v", err) return fmt.Errorf("failed to create new system: %v", err)
} }
} }
@@ -129,7 +129,7 @@ func (h *Hub) syncSystemsWithConfig() error {
// Delete systems not in config // Delete systems not in config
for _, system := range existingSystemsMap { for _, system := range existingSystemsMap {
if err := h.app.Delete(system); err != nil { if err := h.Delete(system); err != nil {
return err return err
} }
} }
@@ -141,7 +141,7 @@ func (h *Hub) syncSystemsWithConfig() error {
// Generates content for the config.yml file as a YAML string // Generates content for the config.yml file as a YAML string
func (h *Hub) generateConfigYAML() (string, error) { func (h *Hub) generateConfigYAML() (string, error) {
// Fetch all systems from the database // Fetch all systems from the database
systems, err := h.app.FindRecordsByFilter("systems", "id != ''", "name", -1, 0) systems, err := h.FindRecordsByFilter("systems", "id != ''", "name", -1, 0)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -194,7 +194,7 @@ func (h *Hub) generateConfigYAML() (string, error) {
// New helper function to get a map of user IDs to emails // New helper function to get a map of user IDs to emails
func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) { func (h *Hub) getUserEmailMap(userIDs []string) (map[string]string, error) {
users, err := h.app.FindRecordsByIds("users", userIDs) users, err := h.FindRecordsByIds("users", userIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -27,11 +27,12 @@ import (
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/plugins/migratecmd" "github.com/pocketbase/pocketbase/plugins/migratecmd"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Hub struct { type Hub struct {
app *pocketbase.PocketBase *pocketbase.PocketBase
sshClientConfig *ssh.ClientConfig sshClientConfig *ssh.ClientConfig
pubKey string pubKey string
am *alerts.AlertManager am *alerts.AlertManager
@@ -42,15 +43,28 @@ type Hub struct {
appURL string appURL string
} }
func NewHub(app *pocketbase.PocketBase) *Hub { // NewHub creates a new Hub instance with default configuration
hub := &Hub{ func NewHub() *Hub {
app: app, var hub Hub
am: alerts.NewAlertManager(app), hub.PocketBase = pocketbase.NewWithConfig(pocketbase.Config{
um: users.NewUserManager(app), DefaultDataDir: beszel.AppName + "_data",
rm: records.NewRecordManager(app), })
}
hub.RootCmd.Version = beszel.Version
hub.RootCmd.Use = beszel.AppName
hub.RootCmd.Short = ""
// add update command
hub.RootCmd.AddCommand(&cobra.Command{
Use: "update",
Short: "Update " + beszel.AppName + " to the latest version",
Run: Update,
})
hub.am = alerts.NewAlertManager(hub)
hub.um = users.NewUserManager(hub)
hub.rm = records.NewRecordManager(hub)
hub.appURL, _ = GetEnv("APP_URL") hub.appURL, _ = GetEnv("APP_URL")
return hub return &hub
} }
// GetEnv retrieves an environment variable with a "BESZEL_HUB_" prefix, or falls back to the unprefixed key. // GetEnv retrieves an environment variable with a "BESZEL_HUB_" prefix, or falls back to the unprefixed key.
@@ -67,21 +81,21 @@ func (h *Hub) Run() {
isGoRun := strings.HasPrefix(os.Args[0], os.TempDir()) isGoRun := strings.HasPrefix(os.Args[0], os.TempDir())
// enable auto creation of migration files when making collection changes in the Admin UI // enable auto creation of migration files when making collection changes in the Admin UI
migratecmd.MustRegister(h.app, h.app.RootCmd, migratecmd.Config{ migratecmd.MustRegister(h, h.RootCmd, migratecmd.Config{
// (the isGoRun check is to enable it only during development) // (the isGoRun check is to enable it only during development)
Automigrate: isGoRun, Automigrate: isGoRun,
Dir: "../../migrations", Dir: "../../migrations",
}) })
// initial setup // initial setup
h.app.OnServe().BindFunc(func(se *core.ServeEvent) error { h.OnServe().BindFunc(func(se *core.ServeEvent) error {
// create ssh client config // create ssh client config
err := h.createSSHClientConfig() err := h.createSSHClientConfig()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// set general settings // set general settings
settings := h.app.Settings() settings := h.Settings()
// batch requests (for global alerts) // batch requests (for global alerts)
settings.Batch.Enabled = true settings.Batch.Enabled = true
// set URL if BASE_URL env is set // set URL if BASE_URL env is set
@@ -89,7 +103,7 @@ func (h *Hub) Run() {
settings.Meta.AppURL = h.appURL settings.Meta.AppURL = h.appURL
} }
// set auth settings // set auth settings
usersCollection, err := h.app.FindCollectionByNameOrId("users") usersCollection, err := h.FindCollectionByNameOrId("users")
if err != nil { if err != nil {
return err return err
} }
@@ -108,7 +122,7 @@ func (h *Hub) Run() {
} else { } else {
usersCollection.CreateRule = nil usersCollection.CreateRule = nil
} }
if err := h.app.Save(usersCollection); err != nil { if err := h.Save(usersCollection); err != nil {
return err return err
} }
// sync systems with config // sync systems with config
@@ -117,7 +131,7 @@ func (h *Hub) Run() {
}) })
// serve web ui // serve web ui
h.app.OnServe().BindFunc(func(se *core.ServeEvent) error { h.OnServe().BindFunc(func(se *core.ServeEvent) error {
switch isGoRun { switch isGoRun {
case true: case true:
proxy := httputil.NewSingleHostReverseProxy(&url.URL{ proxy := httputil.NewSingleHostReverseProxy(&url.URL{
@@ -163,14 +177,14 @@ func (h *Hub) Run() {
}) })
// set up scheduled jobs / ticker for system updates // set up scheduled jobs / ticker for system updates
h.app.OnServe().BindFunc(func(se *core.ServeEvent) error { h.OnServe().BindFunc(func(se *core.ServeEvent) error {
// 15 second ticker for system updates // 15 second ticker for system updates
go h.startSystemUpdateTicker() go h.startSystemUpdateTicker()
// set up cron jobs // set up cron jobs
// delete old records once every hour // delete old records once every hour
h.app.Cron().MustAdd("delete old records", "8 * * * *", h.rm.DeleteOldRecords) h.Cron().MustAdd("delete old records", "8 * * * *", h.rm.DeleteOldRecords)
// create longer records every 10 minutes // create longer records every 10 minutes
h.app.Cron().MustAdd("create longer records", "*/10 * * * *", func() { h.Cron().MustAdd("create longer records", "*/10 * * * *", func() {
if systemStats, containerStats, err := h.getCollections(); err == nil { if systemStats, containerStats, err := h.getCollections(); err == nil {
h.rm.CreateLongerRecords([]*core.Collection{systemStats, containerStats}) h.rm.CreateLongerRecords([]*core.Collection{systemStats, containerStats})
} }
@@ -179,7 +193,7 @@ func (h *Hub) Run() {
}) })
// custom api routes // custom api routes
h.app.OnServe().BindFunc(func(se *core.ServeEvent) error { h.OnServe().BindFunc(func(se *core.ServeEvent) error {
// returns public key // returns public key
se.Router.GET("/api/beszel/getkey", func(e *core.RequestEvent) error { se.Router.GET("/api/beszel/getkey", func(e *core.RequestEvent) error {
info, _ := e.RequestInfo() info, _ := e.RequestInfo()
@@ -190,7 +204,7 @@ func (h *Hub) Run() {
}) })
// check if first time setup on login page // check if first time setup on login page
se.Router.GET("/api/beszel/first-run", func(e *core.RequestEvent) error { se.Router.GET("/api/beszel/first-run", func(e *core.RequestEvent) error {
total, err := h.app.CountRecords("users") total, err := h.CountRecords("users")
return e.JSON(http.StatusOK, map[string]bool{"firstRun": err == nil && total == 0}) return e.JSON(http.StatusOK, map[string]bool{"firstRun": err == nil && total == 0})
}) })
// send test notification // send test notification
@@ -198,31 +212,31 @@ func (h *Hub) Run() {
// API endpoint to get config.yml content // API endpoint to get config.yml content
se.Router.GET("/api/beszel/config-yaml", h.getYamlConfig) se.Router.GET("/api/beszel/config-yaml", h.getYamlConfig)
// create first user endpoint only needed if no users exist // create first user endpoint only needed if no users exist
if totalUsers, _ := h.app.CountRecords("users"); totalUsers == 0 { if totalUsers, _ := h.CountRecords("users"); totalUsers == 0 {
se.Router.POST("/api/beszel/create-user", h.um.CreateFirstUser) se.Router.POST("/api/beszel/create-user", h.um.CreateFirstUser)
} }
return se.Next() return se.Next()
}) })
// system creation defaults // system creation defaults
h.app.OnRecordCreate("systems").BindFunc(func(e *core.RecordEvent) error { h.OnRecordCreate("systems").BindFunc(func(e *core.RecordEvent) error {
e.Record.Set("info", system.Info{}) e.Record.Set("info", system.Info{})
e.Record.Set("status", "pending") e.Record.Set("status", "pending")
return e.Next() return e.Next()
}) })
// immediately create connection for new systems // immediately create connection for new systems
h.app.OnRecordAfterCreateSuccess("systems").BindFunc(func(e *core.RecordEvent) error { h.OnRecordAfterCreateSuccess("systems").BindFunc(func(e *core.RecordEvent) error {
go h.updateSystem(e.Record) go h.updateSystem(e.Record)
return e.Next() return e.Next()
}) })
// handle default values for user / user_settings creation // handle default values for user / user_settings creation
h.app.OnRecordCreate("users").BindFunc(h.um.InitializeUserRole) h.OnRecordCreate("users").BindFunc(h.um.InitializeUserRole)
h.app.OnRecordCreate("user_settings").BindFunc(h.um.InitializeUserSettings) h.OnRecordCreate("user_settings").BindFunc(h.um.InitializeUserSettings)
// empty info for systems that are paused // empty info for systems that are paused
h.app.OnRecordUpdate("systems").BindFunc(func(e *core.RecordEvent) error { h.OnRecordUpdate("systems").BindFunc(func(e *core.RecordEvent) error {
if e.Record.GetString("status") == "paused" { if e.Record.GetString("status") == "paused" {
e.Record.Set("info", system.Info{}) e.Record.Set("info", system.Info{})
} }
@@ -230,7 +244,7 @@ func (h *Hub) Run() {
}) })
// do things after a systems record is updated // do things after a systems record is updated
h.app.OnRecordAfterUpdateSuccess("systems").BindFunc(func(e *core.RecordEvent) error { h.OnRecordAfterUpdateSuccess("systems").BindFunc(func(e *core.RecordEvent) error {
newRecord := e.Record.Fresh() newRecord := e.Record.Fresh()
oldRecord := newRecord.Original() oldRecord := newRecord.Original()
newStatus := newRecord.GetString("status") newStatus := newRecord.GetString("status")
@@ -250,12 +264,12 @@ func (h *Hub) Run() {
}) })
// if system is deleted, close connection // if system is deleted, close connection
h.app.OnRecordAfterDeleteSuccess("systems").BindFunc(func(e *core.RecordEvent) error { h.OnRecordAfterDeleteSuccess("systems").BindFunc(func(e *core.RecordEvent) error {
h.deleteSystemConnection(e.Record) h.deleteSystemConnection(e.Record)
return e.Next() return e.Next()
}) })
if err := h.app.Start(); err != nil { if err := h.Start(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
@@ -268,7 +282,7 @@ func (h *Hub) startSystemUpdateTicker() {
} }
func (h *Hub) updateSystems() { func (h *Hub) updateSystems() {
records, err := h.app.FindRecordsByFilter( records, err := h.FindRecordsByFilter(
"2hz5ncl8tizk5nx", // systems collection "2hz5ncl8tizk5nx", // systems collection
"status != 'paused'", // filter "status != 'paused'", // filter
"updated", // sort "updated", // sort
@@ -277,7 +291,7 @@ func (h *Hub) updateSystems() {
) )
// log.Println("records", len(records)) // log.Println("records", len(records))
if err != nil || len(records) == 0 { if err != nil || len(records) == 0 {
// h.app.Logger().Error("Failed to query systems") // h.Logger().Error("Failed to query systems")
return return
} }
fiftySecondsAgo := time.Now().UTC().Add(-50 * time.Second) fiftySecondsAgo := time.Now().UTC().Add(-50 * time.Second)
@@ -302,52 +316,52 @@ func (h *Hub) updateSystem(record *core.Record) {
var err error var err error
// check if system connection exists // check if system connection exists
if existingClient, ok := h.app.Store().GetOk(record.Id); ok { if existingClient, ok := h.Store().GetOk(record.Id); ok {
client = existingClient.(*ssh.Client) client = existingClient.(*ssh.Client)
} else { } else {
// create system connection // create system connection
client, err = h.createSystemConnection(record) client, err = h.createSystemConnection(record)
if err != nil { if err != nil {
if record.GetString("status") != "down" { if record.GetString("status") != "down" {
h.app.Logger().Error("Failed to connect:", "err", err.Error(), "system", record.GetString("host"), "port", record.GetString("port")) h.Logger().Error("Failed to connect:", "err", err.Error(), "system", record.GetString("host"), "port", record.GetString("port"))
h.updateSystemStatus(record, "down") h.updateSystemStatus(record, "down")
} }
return return
} }
h.app.Store().Set(record.Id, client) h.Store().Set(record.Id, client)
} }
// get system stats from agent // get system stats from agent
var systemData system.CombinedData var systemData system.CombinedData
if err := h.requestJsonFromAgent(client, &systemData); err != nil { if err := h.requestJsonFromAgent(client, &systemData); err != nil {
if err.Error() == "bad client" { if err.Error() == "bad client" {
// if previous connection was closed, try again // if previous connection was closed, try again
h.app.Logger().Error("Existing SSH connection closed. Retrying...", "host", record.GetString("host"), "port", record.GetString("port")) h.Logger().Error("Existing SSH connection closed. Retrying...", "host", record.GetString("host"), "port", record.GetString("port"))
h.deleteSystemConnection(record) h.deleteSystemConnection(record)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
h.updateSystem(record) h.updateSystem(record)
return return
} }
h.app.Logger().Error("Failed to get system stats: ", "err", err.Error()) h.Logger().Error("Failed to get system stats: ", "err", err.Error())
h.updateSystemStatus(record, "down") h.updateSystemStatus(record, "down")
return return
} }
// update system record // update system record
record.Set("status", "up") record.Set("status", "up")
record.Set("info", systemData.Info) record.Set("info", systemData.Info)
if err := h.app.SaveNoValidate(record); err != nil { if err := h.SaveNoValidate(record); err != nil {
h.app.Logger().Error("Failed to update record: ", "err", err.Error()) h.Logger().Error("Failed to update record: ", "err", err.Error())
} }
// add system_stats and container_stats records // add system_stats and container_stats records
if systemStats, containerStats, err := h.getCollections(); err != nil { if systemStats, containerStats, err := h.getCollections(); err != nil {
h.app.Logger().Error("Failed to get collections: ", "err", err.Error()) h.Logger().Error("Failed to get collections: ", "err", err.Error())
} else { } else {
// add new system_stats record // add new system_stats record
systemStatsRecord := core.NewRecord(systemStats) systemStatsRecord := core.NewRecord(systemStats)
systemStatsRecord.Set("system", record.Id) systemStatsRecord.Set("system", record.Id)
systemStatsRecord.Set("stats", systemData.Stats) systemStatsRecord.Set("stats", systemData.Stats)
systemStatsRecord.Set("type", "1m") systemStatsRecord.Set("type", "1m")
if err := h.app.SaveNoValidate(systemStatsRecord); err != nil { if err := h.SaveNoValidate(systemStatsRecord); err != nil {
h.app.Logger().Error("Failed to save record: ", "err", err.Error()) h.Logger().Error("Failed to save record: ", "err", err.Error())
} }
// add new container_stats record // add new container_stats record
if len(systemData.Containers) > 0 { if len(systemData.Containers) > 0 {
@@ -355,29 +369,29 @@ func (h *Hub) updateSystem(record *core.Record) {
containerStatsRecord.Set("system", record.Id) containerStatsRecord.Set("system", record.Id)
containerStatsRecord.Set("stats", systemData.Containers) containerStatsRecord.Set("stats", systemData.Containers)
containerStatsRecord.Set("type", "1m") containerStatsRecord.Set("type", "1m")
if err := h.app.SaveNoValidate(containerStatsRecord); err != nil { if err := h.SaveNoValidate(containerStatsRecord); err != nil {
h.app.Logger().Error("Failed to save record: ", "err", err.Error()) h.Logger().Error("Failed to save record: ", "err", err.Error())
} }
} }
} }
// system info alerts // system info alerts
if err := h.am.HandleSystemAlerts(record, systemData.Info, systemData.Stats.Temperatures, systemData.Stats.ExtraFs); err != nil { if err := h.am.HandleSystemAlerts(record, systemData.Info, systemData.Stats.Temperatures, systemData.Stats.ExtraFs); err != nil {
h.app.Logger().Error("System alerts error", "err", err.Error()) h.Logger().Error("System alerts error", "err", err.Error())
} }
} }
// return system_stats and container_stats collections // return system_stats and container_stats collections
func (h *Hub) getCollections() (*core.Collection, *core.Collection, error) { func (h *Hub) getCollections() (*core.Collection, *core.Collection, error) {
if h.systemStats == nil { if h.systemStats == nil {
systemStats, err := h.app.FindCollectionByNameOrId("system_stats") systemStats, err := h.FindCollectionByNameOrId("system_stats")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
h.systemStats = systemStats h.systemStats = systemStats
} }
if h.containerStats == nil { if h.containerStats == nil {
containerStats, err := h.app.FindCollectionByNameOrId("container_stats") containerStats, err := h.FindCollectionByNameOrId("container_stats")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -390,19 +404,19 @@ func (h *Hub) getCollections() (*core.Collection, *core.Collection, error) {
func (h *Hub) updateSystemStatus(record *core.Record, status string) { func (h *Hub) updateSystemStatus(record *core.Record, status string) {
if record.Fresh().GetString("status") != status { if record.Fresh().GetString("status") != status {
record.Set("status", status) record.Set("status", status)
if err := h.app.SaveNoValidate(record); err != nil { if err := h.SaveNoValidate(record); err != nil {
h.app.Logger().Error("Failed to update record: ", "err", err.Error()) h.Logger().Error("Failed to update record: ", "err", err.Error())
} }
} }
} }
// delete system connection from map and close connection // delete system connection from map and close connection
func (h *Hub) deleteSystemConnection(record *core.Record) { func (h *Hub) deleteSystemConnection(record *core.Record) {
if client, ok := h.app.Store().GetOk(record.Id); ok { if client, ok := h.Store().GetOk(record.Id); ok {
if sshClient := client.(*ssh.Client); sshClient != nil { if sshClient := client.(*ssh.Client); sshClient != nil {
sshClient.Close() sshClient.Close()
} }
h.app.Store().Remove(record.Id) h.Store().Remove(record.Id)
} }
} }
@@ -417,7 +431,7 @@ func (h *Hub) createSystemConnection(record *core.Record) (*ssh.Client, error) {
func (h *Hub) createSSHClientConfig() error { func (h *Hub) createSSHClientConfig() error {
key, err := h.getSSHKey() key, err := h.getSSHKey()
if err != nil { if err != nil {
h.app.Logger().Error("Failed to get SSH key: ", "err", err.Error()) h.Logger().Error("Failed to get SSH key: ", "err", err.Error())
return err return err
} }
@@ -494,11 +508,11 @@ func newSessionWithTimeout(client *ssh.Client, timeout time.Duration) (*ssh.Sess
} }
func (h *Hub) getSSHKey() ([]byte, error) { func (h *Hub) getSSHKey() ([]byte, error) {
dataDir := h.app.DataDir() dataDir := h.DataDir()
// check if the key pair already exists // check if the key pair already exists
existingKey, err := os.ReadFile(dataDir + "/id_ed25519") existingKey, err := os.ReadFile(dataDir + "/id_ed25519")
if err == nil { if err == nil {
if pubKey, err := os.ReadFile(h.app.DataDir() + "/id_ed25519.pub"); err == nil { if pubKey, err := os.ReadFile(h.DataDir() + "/id_ed25519.pub"); err == nil {
h.pubKey = strings.TrimSuffix(string(pubKey), "\n") h.pubKey = strings.TrimSuffix(string(pubKey), "\n")
} }
// return existing private key // return existing private key
@@ -508,27 +522,27 @@ func (h *Hub) getSSHKey() ([]byte, error) {
// Generate the Ed25519 key pair // Generate the Ed25519 key pair
pubKey, privKey, err := ed25519.GenerateKey(nil) pubKey, privKey, err := ed25519.GenerateKey(nil)
if err != nil { if err != nil {
// h.app.Logger().Error("Error generating key pair:", "err", err.Error()) // h.Logger().Error("Error generating key pair:", "err", err.Error())
return nil, err return nil, err
} }
// Get the private key in OpenSSH format // Get the private key in OpenSSH format
privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "") privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "")
if err != nil { if err != nil {
// h.app.Logger().Error("Error marshaling private key:", "err", err.Error()) // h.Logger().Error("Error marshaling private key:", "err", err.Error())
return nil, err return nil, err
} }
// Save the private key to a file // Save the private key to a file
privateFile, err := os.Create(dataDir + "/id_ed25519") privateFile, err := os.Create(dataDir + "/id_ed25519")
if err != nil { if err != nil {
// h.app.Logger().Error("Error creating private key file:", "err", err.Error()) // h.Logger().Error("Error creating private key file:", "err", err.Error())
return nil, err return nil, err
} }
defer privateFile.Close() defer privateFile.Close()
if err := pem.Encode(privateFile, privKeyBytes); err != nil { if err := pem.Encode(privateFile, privKeyBytes); err != nil {
// h.app.Logger().Error("Error writing private key to file:", "err", err.Error()) // h.Logger().Error("Error writing private key to file:", "err", err.Error())
return nil, err return nil, err
} }
@@ -552,9 +566,9 @@ func (h *Hub) getSSHKey() ([]byte, error) {
return nil, err return nil, err
} }
h.app.Logger().Info("ed25519 SSH key pair generated successfully.") h.Logger().Info("ed25519 SSH key pair generated successfully.")
h.app.Logger().Info("Private key saved to: " + dataDir + "/id_ed25519") h.Logger().Info("Private key saved to: " + dataDir + "/id_ed25519")
h.app.Logger().Info("Public key saved to: " + dataDir + "/id_ed25519.pub") h.Logger().Info("Public key saved to: " + dataDir + "/id_ed25519.pub")
existingKey, err = os.ReadFile(dataDir + "/id_ed25519") existingKey, err = os.ReadFile(dataDir + "/id_ed25519")
if err == nil { if err == nil {

View File

@@ -10,13 +10,12 @@ import (
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/pocketbase/dbx" "github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/types" "github.com/pocketbase/pocketbase/tools/types"
) )
type RecordManager struct { type RecordManager struct {
app *pocketbase.PocketBase app core.App
} }
type LongerRecordData struct { type LongerRecordData struct {
@@ -35,7 +34,7 @@ type RecordStats []struct {
Stats []byte `db:"stats"` Stats []byte `db:"stats"`
} }
func NewRecordManager(app *pocketbase.PocketBase) *RecordManager { func NewRecordManager(app core.App) *RecordManager {
return &RecordManager{app} return &RecordManager{app}
} }

View File

@@ -6,12 +6,11 @@ import (
"log" "log"
"net/http" "net/http"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
) )
type UserManager struct { type UserManager struct {
app *pocketbase.PocketBase app core.App
} }
type UserSettings struct { type UserSettings struct {
@@ -21,7 +20,7 @@ type UserSettings struct {
// Language string `json:"lang"` // Language string `json:"lang"`
} }
func NewUserManager(app *pocketbase.PocketBase) *UserManager { func NewUserManager(app core.App) *UserManager {
return &UserManager{ return &UserManager{
app: app, app: app,
} }