diff --git a/beszel/internal/agent/client.go b/beszel/internal/agent/client.go index 6dbcf1d..601a7de 100644 --- a/beszel/internal/agent/client.go +++ b/beszel/internal/agent/client.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "os" "path" "strings" "time" @@ -53,9 +54,9 @@ func newWebSocketClient(agent *Agent) (client *WebSocketClient, err error) { return nil, errors.New("invalid hub URL") } // get registration token - client.token, _ = GetEnv("TOKEN") - if client.token == "" { - return nil, errors.New("TOKEN environment variable not set") + client.token, err = getToken() + if err != nil { + return nil, err } client.agent = agent @@ -65,6 +66,27 @@ func newWebSocketClient(agent *Agent) (client *WebSocketClient, err error) { return client, nil } +// getToken returns the token for the WebSocket client. +// It first checks the TOKEN environment variable, then the TOKEN_FILE environment variable. +// If neither is set, it returns an error. +func getToken() (string, error) { + // get token from env var + token, _ := GetEnv("TOKEN") + if token != "" { + return token, nil + } + // get token from file + tokenFile, _ := GetEnv("TOKEN_FILE") + if tokenFile == "" { + return "", errors.New("must set TOKEN or TOKEN_FILE") + } + tokenBytes, err := os.ReadFile(tokenFile) + if err != nil { + return "", err + } + return string(tokenBytes), nil +} + // getOptions returns the WebSocket client options, creating them if necessary. // It configures the connection URL, TLS settings, and authentication headers. func (client *WebSocketClient) getOptions() *gws.ClientOption { diff --git a/beszel/internal/agent/client_test.go b/beszel/internal/agent/client_test.go index cd90038..bb5e253 100644 --- a/beszel/internal/agent/client_test.go +++ b/beszel/internal/agent/client_test.go @@ -61,7 +61,7 @@ func TestNewWebSocketClient(t *testing.T) { hubURL: "http://localhost:8080", token: "", expectError: true, - errorMsg: "TOKEN environment variable not set", + errorMsg: "must set TOKEN or TOKEN_FILE", }, } @@ -389,3 +389,150 @@ func TestWebSocketClient_ConnectRateLimit(t *testing.T) { err = client.Connect() assert.Error(t, err, "Connection should fail but not hang") } + +// TestGetToken tests the getToken function with various scenarios +func TestGetToken(t *testing.T) { + unsetEnvVars := func() { + os.Unsetenv("BESZEL_AGENT_TOKEN") + os.Unsetenv("TOKEN") + os.Unsetenv("BESZEL_AGENT_TOKEN_FILE") + os.Unsetenv("TOKEN_FILE") + } + + t.Run("token from TOKEN environment variable", func(t *testing.T) { + unsetEnvVars() + + // Set TOKEN env var + expectedToken := "test-token-from-env" + os.Setenv("TOKEN", expectedToken) + defer os.Unsetenv("TOKEN") + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, expectedToken, token) + }) + + t.Run("token from BESZEL_AGENT_TOKEN environment variable", func(t *testing.T) { + unsetEnvVars() + + // Set BESZEL_AGENT_TOKEN env var (should take precedence) + expectedToken := "test-token-from-beszel-env" + os.Setenv("BESZEL_AGENT_TOKEN", expectedToken) + defer os.Unsetenv("BESZEL_AGENT_TOKEN") + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, expectedToken, token) + }) + + t.Run("token from TOKEN_FILE", func(t *testing.T) { + unsetEnvVars() + + // Create a temporary token file + expectedToken := "test-token-from-file" + tokenFile, err := os.CreateTemp("", "token-test-*.txt") + require.NoError(t, err) + defer os.Remove(tokenFile.Name()) + + _, err = tokenFile.WriteString(expectedToken) + require.NoError(t, err) + tokenFile.Close() + + // Set TOKEN_FILE env var + os.Setenv("TOKEN_FILE", tokenFile.Name()) + defer os.Unsetenv("TOKEN_FILE") + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, expectedToken, token) + }) + + t.Run("token from BESZEL_AGENT_TOKEN_FILE", func(t *testing.T) { + unsetEnvVars() + + // Create a temporary token file + expectedToken := "test-token-from-beszel-file" + tokenFile, err := os.CreateTemp("", "token-test-*.txt") + require.NoError(t, err) + defer os.Remove(tokenFile.Name()) + + _, err = tokenFile.WriteString(expectedToken) + require.NoError(t, err) + tokenFile.Close() + + // Set BESZEL_AGENT_TOKEN_FILE env var (should take precedence) + os.Setenv("BESZEL_AGENT_TOKEN_FILE", tokenFile.Name()) + defer os.Unsetenv("BESZEL_AGENT_TOKEN_FILE") + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, expectedToken, token) + }) + + t.Run("TOKEN takes precedence over TOKEN_FILE", func(t *testing.T) { + unsetEnvVars() + + // Create a temporary token file + fileToken := "token-from-file" + tokenFile, err := os.CreateTemp("", "token-test-*.txt") + require.NoError(t, err) + defer os.Remove(tokenFile.Name()) + + _, err = tokenFile.WriteString(fileToken) + require.NoError(t, err) + tokenFile.Close() + + // Set both TOKEN and TOKEN_FILE + envToken := "token-from-env" + os.Setenv("TOKEN", envToken) + os.Setenv("TOKEN_FILE", tokenFile.Name()) + defer func() { + os.Unsetenv("TOKEN") + os.Unsetenv("TOKEN_FILE") + }() + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, envToken, token, "TOKEN should take precedence over TOKEN_FILE") + }) + + t.Run("error when neither TOKEN nor TOKEN_FILE is set", func(t *testing.T) { + unsetEnvVars() + + token, err := getToken() + assert.Error(t, err) + assert.Equal(t, "", token) + assert.Contains(t, err.Error(), "must set TOKEN or TOKEN_FILE") + }) + + t.Run("error when TOKEN_FILE points to non-existent file", func(t *testing.T) { + unsetEnvVars() + + // Set TOKEN_FILE to a non-existent file + os.Setenv("TOKEN_FILE", "/non/existent/file.txt") + defer os.Unsetenv("TOKEN_FILE") + + token, err := getToken() + assert.Error(t, err) + assert.Equal(t, "", token) + assert.Contains(t, err.Error(), "no such file or directory") + }) + + t.Run("handles empty token file", func(t *testing.T) { + unsetEnvVars() + + // Create an empty token file + tokenFile, err := os.CreateTemp("", "token-test-*.txt") + require.NoError(t, err) + defer os.Remove(tokenFile.Name()) + tokenFile.Close() + + // Set TOKEN_FILE env var + os.Setenv("TOKEN_FILE", tokenFile.Name()) + defer os.Unsetenv("TOKEN_FILE") + + token, err := getToken() + assert.NoError(t, err) + assert.Equal(t, "", token, "Empty file should return empty string") + }) +}