Refactor SSH configuration and key management

- Restrict to specific key exchanges / MACs / ciphers.
- Refactored GetSSHKey method to return an ssh.Signer instead of byte array.
- Added common package.

Co-authored-by: nhas <jordanatararimu@gmail.com>
This commit is contained in:
henrygd
2025-05-07 20:03:21 -04:00
parent c0a6153a43
commit 63af81666b
4 changed files with 74 additions and 64 deletions

View File

@@ -1,6 +1,7 @@
package agent package agent
import ( import (
"beszel/internal/common"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -19,8 +20,6 @@ type ServerOptions struct {
} }
func (a *Agent) StartServer(opts ServerOptions) error { func (a *Agent) StartServer(opts ServerOptions) error {
ssh.Handle(a.handleSession)
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network) slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
if opts.Network == "unix" { if opts.Network == "unix" {
@@ -37,17 +36,40 @@ func (a *Agent) StartServer(opts ServerOptions) error {
} }
defer ln.Close() defer ln.Close()
// Start SSH server on the listener // base config (limit to allowed algorithms)
return ssh.Serve(ln, nil, ssh.NoPty(), config := &gossh.ServerConfig{}
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { config.KeyExchanges = common.DefaultKeyExchanges
config.MACs = common.DefaultMACs
config.Ciphers = common.DefaultCiphers
// set default handler
ssh.Handle(a.handleSession)
server := ssh.Server{
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return config
},
// check public key(s)
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
for _, pubKey := range opts.Keys { for _, pubKey := range opts.Keys {
if ssh.KeysEqual(key, pubKey) { if ssh.KeysEqual(key, pubKey) {
return true return true
} }
} }
return false return false
}), },
) // disable pty
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return false
},
// log failed connections
ConnectionFailedCallback: func(conn net.Conn, err error) {
slog.Warn("Failed connection attempt", "addr", conn.RemoteAddr().String(), "err", err)
},
}
// Start SSH server on the listener
return server.Serve(ln)
} }
func (a *Agent) handleSession(s ssh.Session) { func (a *Agent) handleSession(s ssh.Session) {
@@ -56,6 +78,7 @@ func (a *Agent) handleSession(s ssh.Session) {
if err := json.NewEncoder(s).Encode(stats); err != nil { if err := json.NewEncoder(s).Encode(stats); err != nil {
slog.Error("Error encoding stats", "err", err, "stats", stats) slog.Error("Error encoding stats", "err", err, "stats", stats)
s.Exit(1) s.Exit(1)
return
} }
s.Exit(0) s.Exit(0)
} }

View File

@@ -0,0 +1,7 @@
package common
var (
DefaultKeyExchanges = []string{"curve25519-sha256"}
DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"}
DefaultCiphers = []string{"chacha20-poly1305@openssh.com"}
)

View File

@@ -10,11 +10,13 @@ import (
"beszel/site" "beszel/site"
"crypto/ed25519" "crypto/ed25519"
"encoding/pem" "encoding/pem"
"fmt"
"io/fs" "io/fs"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os" "os"
"path"
"strings" "strings"
"github.com/pocketbase/pocketbase" "github.com/pocketbase/pocketbase"
@@ -239,73 +241,47 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error {
return nil return nil
} }
// generates key pair if it doesn't exist and returns private key bytes // generates key pair if it doesn't exist and returns signer
func (h *Hub) GetSSHKey() ([]byte, error) { func (h *Hub) GetSSHKey() (ssh.Signer, error) {
dataDir := h.DataDir() privateKeyPath := path.Join(h.DataDir(), "id_ed25519")
// check if the key pair already exists // check if the key pair already exists
existingKey, err := os.ReadFile(dataDir + "/id_ed25519") existingKey, err := os.ReadFile(privateKeyPath)
if err == nil { if err == nil {
if pubKey, err := os.ReadFile(h.DataDir() + "/id_ed25519.pub"); err == nil { private, err := ssh.ParsePrivateKey(existingKey)
h.pubKey = strings.TrimSuffix(string(pubKey), "\n") if err != nil {
return nil, fmt.Errorf("failed to parse private key: %s", err)
} }
// return existing private key pubKeyBytes := ssh.MarshalAuthorizedKey(private.PublicKey())
return existingKey, nil h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
return private, nil
} }
// Generate the Ed25519 key pair // Generate the Ed25519 key pair
pubKey, privKey, err := ed25519.GenerateKey(nil) pubKey, privKey, err := ed25519.GenerateKey(nil)
if err != nil { if err != nil {
// h.Logger().Error("Error generating key pair:", "err", err.Error())
return nil, err return nil, err
} }
// Get the private key in OpenSSH format // Get the private key in OpenSSH format
privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "") privKeyPem, err := ssh.MarshalPrivateKey(privKey, "")
if err != nil {
// h.Logger().Error("Error marshaling private key:", "err", err.Error())
return nil, err
}
// Save the private key to a file
privateFile, err := os.Create(dataDir + "/id_ed25519")
if err != nil {
// h.Logger().Error("Error creating private key file:", "err", err.Error())
return nil, err
}
defer privateFile.Close()
if err := pem.Encode(privateFile, privKeyBytes); err != nil {
// h.Logger().Error("Error writing private key to file:", "err", err.Error())
return nil, err
}
// Generate the public key in OpenSSH format
publicKey, err := ssh.NewPublicKey(pubKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pubKeyBytes := ssh.MarshalAuthorizedKey(publicKey) if err := os.WriteFile(privateKeyPath, pem.EncodeToMemory(privKeyPem), 0600); err != nil {
return nil, fmt.Errorf("failed to write private key to %q: err: %w", privateKeyPath, err)
}
// These are fine to ignore the errors on, as we've literally just created a crypto.PublicKey | crypto.Signer
sshPubKey, _ := ssh.NewPublicKey(pubKey)
sshPrivate, _ := ssh.NewSignerFromSigner(privKey)
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n") h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
// Save the public key to a file
publicFile, err := os.Create(dataDir + "/id_ed25519.pub")
if err != nil {
return nil, err
}
defer publicFile.Close()
if _, err := publicFile.Write(pubKeyBytes); err != nil {
return nil, err
}
h.Logger().Info("ed25519 SSH key pair generated successfully.") h.Logger().Info("ed25519 SSH key pair generated successfully.")
h.Logger().Info("Private key saved to: " + dataDir + "/id_ed25519") h.Logger().Info("Saved to: " + privateKeyPath)
h.Logger().Info("Public key saved to: " + dataDir + "/id_ed25519.pub")
existingKey, err = os.ReadFile(dataDir + "/id_ed25519") return sshPrivate, err
if err == nil {
return existingKey, nil
}
return nil, err
} }

View File

@@ -1,6 +1,7 @@
package systems package systems
import ( import (
"beszel/internal/common"
"beszel/internal/entities/system" "beszel/internal/entities/system"
"context" "context"
"fmt" "fmt"
@@ -45,7 +46,7 @@ type System struct {
type hubLike interface { type hubLike interface {
core.App core.App
GetSSHKey() ([]byte, error) GetSSHKey() (ssh.Signer, error)
HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error
HandleStatusAlerts(status string, systemRecord *core.Record) error HandleStatusAlerts(status string, systemRecord *core.Record) error
} }
@@ -62,13 +63,10 @@ func NewSystemManager(hub hubLike) *SystemManager {
func (sm *SystemManager) Initialize() error { func (sm *SystemManager) Initialize() error {
sm.bindEventHooks() sm.bindEventHooks()
// ssh setup // ssh setup
key, err := sm.hub.GetSSHKey() err := sm.createSSHClientConfig()
if err != nil { if err != nil {
return err return err
} }
if err := sm.createSSHClientConfig(key); err != nil {
return err
}
// start updating existing systems // start updating existing systems
var systems []*System var systems []*System
err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems) err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems)
@@ -362,15 +360,21 @@ func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) {
return nil, fmt.Errorf("failed to fetch data") return nil, fmt.Errorf("failed to fetch data")
} }
func (sm *SystemManager) createSSHClientConfig(key []byte) error { // createSSHClientConfig initializes the ssh config for the system manager
signer, err := ssh.ParsePrivateKey(key) func (sm *SystemManager) createSSHClientConfig() error {
privateKey, err := sm.hub.GetSSHKey()
if err != nil { if err != nil {
return err return err
} }
sm.sshConfig = &ssh.ClientConfig{ sm.sshConfig = &ssh.ClientConfig{
User: "u", User: "u",
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer), ssh.PublicKeys(privateKey),
},
Config: ssh.Config{
Ciphers: common.DefaultCiphers,
KeyExchanges: common.DefaultKeyExchanges,
MACs: common.DefaultMACs,
}, },
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: sessionTimeout, Timeout: sessionTimeout,