Refactor SSH configuration and key management

- Restrict to specific key exchanges / MACs / ciphers.
- Refactored GetSSHKey method to return an ssh.Signer instead of byte array.
- Added common package.

Co-authored-by: nhas <jordanatararimu@gmail.com>
This commit is contained in:
henrygd
2025-05-07 20:03:21 -04:00
parent c0a6153a43
commit 63af81666b
4 changed files with 74 additions and 64 deletions

View File

@@ -1,6 +1,7 @@
package agent
import (
"beszel/internal/common"
"encoding/json"
"fmt"
"log/slog"
@@ -19,8 +20,6 @@ type ServerOptions struct {
}
func (a *Agent) StartServer(opts ServerOptions) error {
ssh.Handle(a.handleSession)
slog.Info("Starting SSH server", "addr", opts.Addr, "network", opts.Network)
if opts.Network == "unix" {
@@ -37,17 +36,40 @@ func (a *Agent) StartServer(opts ServerOptions) error {
}
defer ln.Close()
// Start SSH server on the listener
return ssh.Serve(ln, nil, ssh.NoPty(),
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
// base config (limit to allowed algorithms)
config := &gossh.ServerConfig{}
config.KeyExchanges = common.DefaultKeyExchanges
config.MACs = common.DefaultMACs
config.Ciphers = common.DefaultCiphers
// set default handler
ssh.Handle(a.handleSession)
server := ssh.Server{
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return config
},
// check public key(s)
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
for _, pubKey := range opts.Keys {
if ssh.KeysEqual(key, pubKey) {
return true
}
}
return false
}),
)
},
// disable pty
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
return false
},
// log failed connections
ConnectionFailedCallback: func(conn net.Conn, err error) {
slog.Warn("Failed connection attempt", "addr", conn.RemoteAddr().String(), "err", err)
},
}
// Start SSH server on the listener
return server.Serve(ln)
}
func (a *Agent) handleSession(s ssh.Session) {
@@ -56,6 +78,7 @@ func (a *Agent) handleSession(s ssh.Session) {
if err := json.NewEncoder(s).Encode(stats); err != nil {
slog.Error("Error encoding stats", "err", err, "stats", stats)
s.Exit(1)
return
}
s.Exit(0)
}

View File

@@ -0,0 +1,7 @@
package common
var (
DefaultKeyExchanges = []string{"curve25519-sha256"}
DefaultMACs = []string{"hmac-sha2-256-etm@openssh.com"}
DefaultCiphers = []string{"chacha20-poly1305@openssh.com"}
)

View File

@@ -10,11 +10,13 @@ import (
"beszel/site"
"crypto/ed25519"
"encoding/pem"
"fmt"
"io/fs"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"strings"
"github.com/pocketbase/pocketbase"
@@ -239,73 +241,47 @@ func (h *Hub) registerApiRoutes(se *core.ServeEvent) error {
return nil
}
// generates key pair if it doesn't exist and returns private key bytes
func (h *Hub) GetSSHKey() ([]byte, error) {
dataDir := h.DataDir()
// generates key pair if it doesn't exist and returns signer
func (h *Hub) GetSSHKey() (ssh.Signer, error) {
privateKeyPath := path.Join(h.DataDir(), "id_ed25519")
// check if the key pair already exists
existingKey, err := os.ReadFile(dataDir + "/id_ed25519")
existingKey, err := os.ReadFile(privateKeyPath)
if err == nil {
if pubKey, err := os.ReadFile(h.DataDir() + "/id_ed25519.pub"); err == nil {
h.pubKey = strings.TrimSuffix(string(pubKey), "\n")
private, err := ssh.ParsePrivateKey(existingKey)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %s", err)
}
// return existing private key
return existingKey, nil
pubKeyBytes := ssh.MarshalAuthorizedKey(private.PublicKey())
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
return private, nil
}
// Generate the Ed25519 key pair
pubKey, privKey, err := ed25519.GenerateKey(nil)
if err != nil {
// h.Logger().Error("Error generating key pair:", "err", err.Error())
return nil, err
}
// Get the private key in OpenSSH format
privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "")
if err != nil {
// h.Logger().Error("Error marshaling private key:", "err", err.Error())
return nil, err
}
// Save the private key to a file
privateFile, err := os.Create(dataDir + "/id_ed25519")
if err != nil {
// h.Logger().Error("Error creating private key file:", "err", err.Error())
return nil, err
}
defer privateFile.Close()
if err := pem.Encode(privateFile, privKeyBytes); err != nil {
// h.Logger().Error("Error writing private key to file:", "err", err.Error())
return nil, err
}
// Generate the public key in OpenSSH format
publicKey, err := ssh.NewPublicKey(pubKey)
privKeyPem, err := ssh.MarshalPrivateKey(privKey, "")
if err != nil {
return nil, err
}
pubKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
if err := os.WriteFile(privateKeyPath, pem.EncodeToMemory(privKeyPem), 0600); err != nil {
return nil, fmt.Errorf("failed to write private key to %q: err: %w", privateKeyPath, err)
}
// These are fine to ignore the errors on, as we've literally just created a crypto.PublicKey | crypto.Signer
sshPubKey, _ := ssh.NewPublicKey(pubKey)
sshPrivate, _ := ssh.NewSignerFromSigner(privKey)
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
h.pubKey = strings.TrimSuffix(string(pubKeyBytes), "\n")
// Save the public key to a file
publicFile, err := os.Create(dataDir + "/id_ed25519.pub")
if err != nil {
return nil, err
}
defer publicFile.Close()
if _, err := publicFile.Write(pubKeyBytes); err != nil {
return nil, err
}
h.Logger().Info("ed25519 SSH key pair generated successfully.")
h.Logger().Info("Private key saved to: " + dataDir + "/id_ed25519")
h.Logger().Info("Public key saved to: " + dataDir + "/id_ed25519.pub")
h.Logger().Info("Saved to: " + privateKeyPath)
existingKey, err = os.ReadFile(dataDir + "/id_ed25519")
if err == nil {
return existingKey, nil
}
return nil, err
return sshPrivate, err
}

View File

@@ -1,6 +1,7 @@
package systems
import (
"beszel/internal/common"
"beszel/internal/entities/system"
"context"
"fmt"
@@ -45,7 +46,7 @@ type System struct {
type hubLike interface {
core.App
GetSSHKey() ([]byte, error)
GetSSHKey() (ssh.Signer, error)
HandleSystemAlerts(systemRecord *core.Record, data *system.CombinedData) error
HandleStatusAlerts(status string, systemRecord *core.Record) error
}
@@ -62,13 +63,10 @@ func NewSystemManager(hub hubLike) *SystemManager {
func (sm *SystemManager) Initialize() error {
sm.bindEventHooks()
// ssh setup
key, err := sm.hub.GetSSHKey()
err := sm.createSSHClientConfig()
if err != nil {
return err
}
if err := sm.createSSHClientConfig(key); err != nil {
return err
}
// start updating existing systems
var systems []*System
err = sm.hub.DB().NewQuery("SELECT id, host, port, status FROM systems WHERE status != 'paused'").All(&systems)
@@ -362,15 +360,21 @@ func (sys *System) fetchDataFromAgent() (*system.CombinedData, error) {
return nil, fmt.Errorf("failed to fetch data")
}
func (sm *SystemManager) createSSHClientConfig(key []byte) error {
signer, err := ssh.ParsePrivateKey(key)
// createSSHClientConfig initializes the ssh config for the system manager
func (sm *SystemManager) createSSHClientConfig() error {
privateKey, err := sm.hub.GetSSHKey()
if err != nil {
return err
}
sm.sshConfig = &ssh.ClientConfig{
User: "u",
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
ssh.PublicKeys(privateKey),
},
Config: ssh.Config{
Ciphers: common.DefaultCiphers,
KeyExchanges: common.DefaultKeyExchanges,
MACs: common.DefaultMACs,
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: sessionTimeout,