mirror of
https://github.com/fankes/beszel.git
synced 2025-10-19 01:39:34 +08:00
rename /src to /internal (sorry i'll fix the prs)
This commit is contained in:
321
internal/hub/agent_connect.go
Normal file
321
internal/hub/agent_connect.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/henrygd/beszel/src/common"
|
||||
"github.com/henrygd/beszel/src/hub/expirymap"
|
||||
"github.com/henrygd/beszel/src/hub/ws"
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/lxzan/gws"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
)
|
||||
|
||||
// agentConnectRequest holds information related to an agent's connection attempt.
|
||||
type agentConnectRequest struct {
|
||||
hub *Hub
|
||||
req *http.Request
|
||||
res http.ResponseWriter
|
||||
token string
|
||||
agentSemVer semver.Version
|
||||
// isUniversalToken is true if the token is a universal token.
|
||||
isUniversalToken bool
|
||||
// userId is the user ID associated with the universal token.
|
||||
userId string
|
||||
}
|
||||
|
||||
// universalTokenMap stores active universal tokens and their associated user IDs.
|
||||
var universalTokenMap tokenMap
|
||||
|
||||
type tokenMap struct {
|
||||
store *expirymap.ExpiryMap[string]
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// getMap returns the expirymap, creating it if necessary.
|
||||
func (tm *tokenMap) GetMap() *expirymap.ExpiryMap[string] {
|
||||
tm.once.Do(func() {
|
||||
tm.store = expirymap.New[string](time.Hour)
|
||||
})
|
||||
return tm.store
|
||||
}
|
||||
|
||||
// handleAgentConnect is the HTTP handler for an agent's connection request.
|
||||
func (h *Hub) handleAgentConnect(e *core.RequestEvent) error {
|
||||
agentRequest := agentConnectRequest{req: e.Request, res: e.Response, hub: h}
|
||||
_ = agentRequest.agentConnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
// agentConnect validates agent credentials and upgrades the connection to a WebSocket.
|
||||
func (acr *agentConnectRequest) agentConnect() (err error) {
|
||||
var agentVersion string
|
||||
|
||||
acr.token, agentVersion, err = acr.validateAgentHeaders(acr.req.Header)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusBadRequest, "")
|
||||
}
|
||||
|
||||
// Check if token is an active universal token
|
||||
acr.userId, acr.isUniversalToken = universalTokenMap.GetMap().GetOk(acr.token)
|
||||
|
||||
// Find matching fingerprint records for this token
|
||||
fpRecords := getFingerprintRecordsByToken(acr.token, acr.hub)
|
||||
if len(fpRecords) == 0 && !acr.isUniversalToken {
|
||||
// Invalid token - no records found and not a universal token
|
||||
return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid token")
|
||||
}
|
||||
|
||||
// Validate agent version
|
||||
acr.agentSemVer, err = semver.Parse(agentVersion)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusUnauthorized, "Invalid agent version")
|
||||
}
|
||||
|
||||
// Upgrade connection to WebSocket
|
||||
conn, err := ws.GetUpgrader().Upgrade(acr.res, acr.req)
|
||||
if err != nil {
|
||||
return acr.sendResponseError(acr.res, http.StatusInternalServerError, "WebSocket upgrade failed")
|
||||
}
|
||||
|
||||
go acr.verifyWsConn(conn, fpRecords)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyWsConn verifies the WebSocket connection using the agent's fingerprint and
|
||||
// SSH key signature, then adds the system to the system manager.
|
||||
func (acr *agentConnectRequest) verifyWsConn(conn *gws.Conn, fpRecords []ws.FingerprintRecord) (err error) {
|
||||
wsConn := ws.NewWsConnection(conn)
|
||||
|
||||
// must set wsConn in connection store before the read loop
|
||||
conn.Session().Store("wsConn", wsConn)
|
||||
|
||||
// make sure connection is closed if there is an error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
wsConn.Close([]byte(err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
go conn.ReadLoop()
|
||||
|
||||
signer, err := acr.hub.GetSSHKey("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
agentFingerprint, err := wsConn.GetFingerprint(acr.token, signer, acr.isUniversalToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Find or create the appropriate system for this token and fingerprint
|
||||
fpRecord, err := acr.findOrCreateSystemForToken(fpRecords, agentFingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return acr.hub.sm.AddWebSocketSystem(fpRecord.SystemId, acr.agentSemVer, wsConn)
|
||||
}
|
||||
|
||||
// validateAgentHeaders extracts and validates the token and agent version from HTTP headers.
|
||||
func (acr *agentConnectRequest) validateAgentHeaders(headers http.Header) (string, string, error) {
|
||||
token := headers.Get("X-Token")
|
||||
agentVersion := headers.Get("X-Beszel")
|
||||
|
||||
if agentVersion == "" || token == "" || len(token) > 64 {
|
||||
return "", "", errors.New("")
|
||||
}
|
||||
return token, agentVersion, nil
|
||||
}
|
||||
|
||||
// sendResponseError writes an HTTP error response.
|
||||
func (acr *agentConnectRequest) sendResponseError(res http.ResponseWriter, code int, message string) error {
|
||||
res.WriteHeader(code)
|
||||
if message != "" {
|
||||
res.Write([]byte(message))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFingerprintRecordsByToken retrieves all fingerprint records associated with a given token.
|
||||
func getFingerprintRecordsByToken(token string, h *Hub) []ws.FingerprintRecord {
|
||||
var records []ws.FingerprintRecord
|
||||
// All will populate empty slice even on error
|
||||
_ = h.DB().NewQuery("SELECT id, system, fingerprint, token FROM fingerprints WHERE token = {:token}").
|
||||
Bind(dbx.Params{
|
||||
"token": token,
|
||||
}).
|
||||
All(&records)
|
||||
return records
|
||||
}
|
||||
|
||||
// findOrCreateSystemForToken finds an existing system matching the token and fingerprint,
|
||||
// or creates a new one for a universal token.
|
||||
func (acr *agentConnectRequest) findOrCreateSystemForToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// No records - only valid for active universal tokens
|
||||
if len(fpRecords) == 0 {
|
||||
return acr.handleNoRecords(agentFingerprint)
|
||||
}
|
||||
|
||||
// Single record - handle as regular token
|
||||
if len(fpRecords) == 1 && !acr.isUniversalToken {
|
||||
return acr.handleSingleRecord(fpRecords[0], agentFingerprint)
|
||||
}
|
||||
|
||||
// Multiple records or universal token - look for matching fingerprint
|
||||
return acr.handleMultipleRecordsOrUniversalToken(fpRecords, agentFingerprint)
|
||||
}
|
||||
|
||||
// handleNoRecords handles the case where no fingerprint records are found for a token.
|
||||
// A new system is created if the token is a valid universal token.
|
||||
func (acr *agentConnectRequest) handleNoRecords(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
var fpRecord ws.FingerprintRecord
|
||||
|
||||
if !acr.isUniversalToken || acr.userId == "" {
|
||||
return fpRecord, errors.New("no matching fingerprints")
|
||||
}
|
||||
|
||||
return acr.createNewSystemForUniversalToken(agentFingerprint)
|
||||
}
|
||||
|
||||
// handleSingleRecord handles the case with a single fingerprint record. It validates
|
||||
// the agent's fingerprint against the stored one, or sets it on first connect.
|
||||
func (acr *agentConnectRequest) handleSingleRecord(fpRecord ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// If no current fingerprint, update with new fingerprint (first time connecting)
|
||||
if fpRecord.Fingerprint == "" {
|
||||
if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
// Update the record with the fingerprint that was set
|
||||
fpRecord.Fingerprint = agentFingerprint.Fingerprint
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// Abort if fingerprint exists but doesn't match (different machine)
|
||||
if fpRecord.Fingerprint != agentFingerprint.Fingerprint {
|
||||
return fpRecord, errors.New("fingerprint mismatch")
|
||||
}
|
||||
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// handleMultipleRecordsOrUniversalToken finds a matching fingerprint from multiple records.
|
||||
// If no match is found and the token is a universal token, a new system is created.
|
||||
func (acr *agentConnectRequest) handleMultipleRecordsOrUniversalToken(fpRecords []ws.FingerprintRecord, agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
// Return existing record with matching fingerprint if found
|
||||
for i := range fpRecords {
|
||||
if fpRecords[i].Fingerprint == agentFingerprint.Fingerprint {
|
||||
return fpRecords[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
// No matching fingerprint record found, but it's
|
||||
// an active universal token so create a new system
|
||||
if acr.isUniversalToken {
|
||||
return acr.createNewSystemForUniversalToken(agentFingerprint)
|
||||
}
|
||||
|
||||
return ws.FingerprintRecord{}, errors.New("fingerprint mismatch")
|
||||
}
|
||||
|
||||
// createNewSystemForUniversalToken creates a new system and fingerprint record for a universal token.
|
||||
func (acr *agentConnectRequest) createNewSystemForUniversalToken(agentFingerprint common.FingerprintResponse) (ws.FingerprintRecord, error) {
|
||||
var fpRecord ws.FingerprintRecord
|
||||
if !acr.isUniversalToken || acr.userId == "" {
|
||||
return fpRecord, errors.New("invalid token")
|
||||
}
|
||||
|
||||
fpRecord.Token = acr.token
|
||||
|
||||
systemId, err := acr.createSystem(agentFingerprint)
|
||||
if err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
fpRecord.SystemId = systemId
|
||||
|
||||
// Set the fingerprint for the new system
|
||||
if err := acr.hub.SetFingerprint(&fpRecord, agentFingerprint.Fingerprint); err != nil {
|
||||
return fpRecord, err
|
||||
}
|
||||
|
||||
// Update the record with the fingerprint that was set
|
||||
fpRecord.Fingerprint = agentFingerprint.Fingerprint
|
||||
|
||||
return fpRecord, nil
|
||||
}
|
||||
|
||||
// createSystem creates a new system record in the database using details from the agent.
|
||||
func (acr *agentConnectRequest) createSystem(agentFingerprint common.FingerprintResponse) (recordId string, err error) {
|
||||
systemsCollection, err := acr.hub.FindCachedCollectionByNameOrId("systems")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
remoteAddr := getRealIP(acr.req)
|
||||
// separate port from address
|
||||
if agentFingerprint.Hostname == "" {
|
||||
agentFingerprint.Hostname = remoteAddr
|
||||
}
|
||||
if agentFingerprint.Port == "" {
|
||||
agentFingerprint.Port = "45876"
|
||||
}
|
||||
// create new record
|
||||
systemRecord := core.NewRecord(systemsCollection)
|
||||
systemRecord.Set("name", agentFingerprint.Hostname)
|
||||
systemRecord.Set("host", remoteAddr)
|
||||
systemRecord.Set("port", agentFingerprint.Port)
|
||||
systemRecord.Set("users", []string{acr.userId})
|
||||
|
||||
return systemRecord.Id, acr.hub.Save(systemRecord)
|
||||
}
|
||||
|
||||
// SetFingerprint creates or updates a fingerprint record in the database.
|
||||
func (h *Hub) SetFingerprint(fpRecord *ws.FingerprintRecord, fingerprint string) (err error) {
|
||||
// // can't use raw query here because it doesn't trigger SSE
|
||||
var record *core.Record
|
||||
switch fpRecord.Id {
|
||||
case "":
|
||||
// create new record for universal token
|
||||
collection, _ := h.FindCachedCollectionByNameOrId("fingerprints")
|
||||
record = core.NewRecord(collection)
|
||||
record.Set("system", fpRecord.SystemId)
|
||||
default:
|
||||
record, err = h.FindRecordById("fingerprints", fpRecord.Id)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record.Set("token", fpRecord.Token)
|
||||
record.Set("fingerprint", fingerprint)
|
||||
return h.SaveNoValidate(record)
|
||||
}
|
||||
|
||||
// getRealIP extracts the client's real IP address from request headers,
|
||||
// checking common proxy headers before falling back to the remote address.
|
||||
func getRealIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// X-Forwarded-For can contain a comma-separated list: "client_ip, proxy1, proxy2"
|
||||
// Take the first one
|
||||
ips := strings.Split(ip, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
// Fallback to RemoteAddr
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
Reference in New Issue
Block a user