diff --git a/internal/hub/hub.go b/internal/hub/hub.go index 0d93ea2..89fb4a9 100644 --- a/internal/hub/hub.go +++ b/internal/hub/hub.go @@ -69,6 +69,8 @@ func (h *Hub) StartHub() error { if err := config.SyncSystems(e); err != nil { return err } + // register middlewares + h.registerMiddlewares(e) // register api routes if err := h.registerApiRoutes(e); err != nil { return err @@ -171,6 +173,41 @@ func (h *Hub) registerCronJobs(_ *core.ServeEvent) error { return nil } +// custom middlewares +func (h *Hub) registerMiddlewares(se *core.ServeEvent) { + // authenticate with trusted header + if trustedHeader, _ := GetEnv("TRUSTED_AUTH_HEADER"); trustedHeader != "" { + se.Router.BindFunc(func(e *core.RequestEvent) error { + if e.Auth != nil { + return e.Next() + } + trustedEmail := e.Request.Header.Get(trustedHeader) + if trustedEmail == "" { + return e.Next() + } + isAuthRefresh := e.Request.URL.Path == "/api/collections/users/auth-refresh" && e.Request.Method == http.MethodPost + if !isAuthRefresh { + authRecord, err := e.App.FindAuthRecordByEmail("users", trustedEmail) + if err == nil { + e.Auth = authRecord + } + return e.Next() + } + // if auth refresh endpoint, find user record directly and generate token + user, err := e.App.FindFirstRecordByData("users", "email", trustedEmail) + if err != nil { + return e.Next() + } + e.Auth = user + // need to set the authorization header for the client sdk to pick up the token + if token, err := user.NewAuthToken(); err == nil { + e.Request.Header.Set("Authorization", token) + } + return e.Next() + }) + } +} + // custom api routes func (h *Hub) registerApiRoutes(se *core.ServeEvent) error { // auth protected routes diff --git a/internal/hub/hub_test.go b/internal/hub/hub_test.go index ea8febb..47b5300 100644 --- a/internal/hub/hub_test.go +++ b/internal/hub/hub_test.go @@ -711,3 +711,63 @@ func TestCreateUserEndpointAvailability(t *testing.T) { scenario.Test(t) }) } + +func TestTrustedHeaderMiddleware(t *testing.T) { + var hubs []*beszelTests.TestHub + + defer func() { + defer os.Unsetenv("TRUSTED_AUTH_HEADER") + for _, hub := range hubs { + hub.Cleanup() + } + }() + + os.Setenv("TRUSTED_AUTH_HEADER", "X-Beszel-Trusted") + + testAppFactory := func(t testing.TB) *pbTests.TestApp { + hub, _ := beszelTests.NewTestHub(t.TempDir()) + hubs = append(hubs, hub) + hub.StartHub() + return hub.TestApp + } + + scenarios := []beszelTests.ApiScenario{ + { + Name: "GET /getkey - without trusted header should fail", + Method: http.MethodGet, + URL: "/api/beszel/getkey", + ExpectedStatus: 401, + ExpectedContent: []string{"requires valid"}, + TestAppFactory: testAppFactory, + }, + { + Name: "GET /getkey - with trusted header should fail if no matching user", + Method: http.MethodGet, + URL: "/api/beszel/getkey", + Headers: map[string]string{ + "X-Beszel-Trusted": "user@test.com", + }, + ExpectedStatus: 401, + ExpectedContent: []string{"requires valid"}, + TestAppFactory: testAppFactory, + }, + { + Name: "GET /getkey - with trusted header should succeed", + Method: http.MethodGet, + URL: "/api/beszel/getkey", + Headers: map[string]string{ + "X-Beszel-Trusted": "user@test.com", + }, + ExpectedStatus: 200, + ExpectedContent: []string{"\"key\":", "\"v\":"}, + TestAppFactory: testAppFactory, + BeforeTestFunc: func(t testing.TB, app *pbTests.TestApp, e *core.ServeEvent) { + beszelTests.CreateUser(app, "user@test.com", "password123") + }, + }, + } + + for _, scenario := range scenarios { + scenario.Test(t) + } +} diff --git a/internal/site/src/main.tsx b/internal/site/src/main.tsx index 9b44411..70ab33e 100644 --- a/internal/site/src/main.tsx +++ b/internal/site/src/main.tsx @@ -74,6 +74,19 @@ const Layout = () => { document.documentElement.dir = direction }, [direction]) + // biome-ignore lint/correctness/useExhaustiveDependencies: only run on mount + useEffect(() => { + // refresh auth if not authenticated (required for trusted auth header) + if (!authenticated) { + pb.collection("users") + .authRefresh() + .then((res) => { + pb.authStore.save(res.token, res.record) + $authenticated.set(!!pb.authStore.isValid) + }) + } + }, []) + return ( {!authenticated ? (