// supabase_client_functions/client.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use tracing::debug;
6use url::Url;
7
8use crate::error::{FunctionsApiErrorResponse, FunctionsError};
9use crate::types::*;
10
11/// HTTP client for Supabase Edge Functions.
12///
13/// Communicates with Edge Functions at `/functions/v1/{function_name}`.
14///
15/// # Example
16/// ```ignore
17/// use supabase_client_functions::{FunctionsClient, InvokeOptions};
18/// use serde_json::json;
19///
20/// let client = FunctionsClient::new("https://your-project.supabase.co", "your-anon-key")?;
21/// let response = client.invoke("hello", InvokeOptions::new()
22///     .body(json!({"name": "World"}))
23/// ).await?;
24/// let data: serde_json::Value = response.json()?;
25/// ```
26#[derive(Debug, Clone)]
27pub struct FunctionsClient {
28    http: reqwest::Client,
29    base_url: Url,
30    api_key: String,
31    /// Overridden auth token (if set via `set_auth`).
32    auth_override: Arc<RwLock<Option<String>>>,
33}
34
35impl FunctionsClient {
36    /// Create a new Edge Functions client.
37    ///
38    /// `supabase_url` is the project URL (e.g., `https://your-project.supabase.co`).
39    /// `api_key` is the Supabase anon or service_role key.
40    pub fn new(supabase_url: &str, api_key: &str) -> Result<Self, FunctionsError> {
41        let base = supabase_url.trim_end_matches('/');
42        let base_url = Url::parse(&format!("{}/functions/v1", base))?;
43
44        let mut default_headers = HeaderMap::new();
45        default_headers.insert(
46            "apikey",
47            HeaderValue::from_str(api_key)
48                .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid API key header: {}", e)))?,
49        );
50        default_headers.insert(
51            reqwest::header::AUTHORIZATION,
52            HeaderValue::from_str(&format!("Bearer {}", api_key))
53                .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid auth header: {}", e)))?,
54        );
55
56        let http = reqwest::Client::builder()
57            .default_headers(default_headers)
58            .build()
59            .map_err(FunctionsError::Http)?;
60
61        Ok(Self {
62            http,
63            base_url,
64            api_key: api_key.to_string(),
65            auth_override: Arc::new(RwLock::new(None)),
66        })
67    }
68
69    /// Get the base URL for the functions endpoint.
70    pub fn base_url(&self) -> &Url {
71        &self.base_url
72    }
73
74    /// Get the API key used by this client.
75    pub fn api_key(&self) -> &str {
76        &self.api_key
77    }
78
79    /// Update the default auth token for function invocations.
80    ///
81    /// Subsequent invocations will use `Bearer <token>` unless overridden per-request.
82    ///
83    /// Mirrors `supabase.functions.setAuth(token)`.
84    pub fn set_auth(&self, token: &str) {
85        let mut auth = self.auth_override.write().unwrap();
86        *auth = Some(token.to_string());
87    }
88
89    /// Invoke an Edge Function.
90    ///
91    /// # Arguments
92    /// * `function_name` - The name of the deployed function.
93    /// * `options` - Invocation options (body, method, headers, region, etc.).
94    ///
95    /// # Errors
96    /// * [`FunctionsError::RelayError`] if Supabase infrastructure returned an error (x-relay-error: true).
97    /// * [`FunctionsError::HttpError`] if the function returned a non-2xx status.
98    /// * [`FunctionsError::Http`] on network failure.
99    pub async fn invoke(
100        &self,
101        function_name: &str,
102        options: InvokeOptions,
103    ) -> Result<FunctionResponse, FunctionsError> {
104        let url = format!("{}/{}", self.base_url, function_name);
105        debug!(function = function_name, method = %options.method, "Invoking edge function");
106
107        // Build the request with the correct HTTP method
108        let mut request = match options.method {
109            HttpMethod::Get => self.http.get(&url),
110            HttpMethod::Post => self.http.post(&url),
111            HttpMethod::Put => self.http.put(&url),
112            HttpMethod::Patch => self.http.patch(&url),
113            HttpMethod::Delete => self.http.delete(&url),
114            HttpMethod::Options => self.http.request(reqwest::Method::OPTIONS, &url),
115            HttpMethod::Head => self.http.head(&url),
116        };
117
118        // Override Authorization: per-request first, then client-level set_auth, then default (from reqwest default headers)
119        if let Some(ref auth) = options.authorization {
120            request = request.header(
121                reqwest::header::AUTHORIZATION,
122                HeaderValue::from_str(auth)
123                    .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid authorization header: {}", e)))?,
124            );
125        } else if let Some(ref token) = *self.auth_override.read().unwrap() {
126            request = request.header(
127                reqwest::header::AUTHORIZATION,
128                HeaderValue::from_str(&format!("Bearer {}", token))
129                    .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid auth override header: {}", e)))?,
130            );
131        }
132
133        // Set region header if specified
134        if let Some(ref region) = options.region {
135            request = request.header("x-region", region.to_string());
136        }
137
138        // Add custom headers
139        for (key, value) in &options.headers {
140            let header_name = HeaderName::from_bytes(key.as_bytes())
141                .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid header name '{}': {}", key, e)))?;
142            let header_value = HeaderValue::from_str(value)
143                .map_err(|e| FunctionsError::InvalidConfig(format!("Invalid header value for '{}': {}", key, e)))?;
144            request = request.header(header_name, header_value);
145        }
146
147        // Set Content-Type and body
148        match options.body {
149            InvokeBody::Json(value) => {
150                let ct = options.content_type.as_deref().unwrap_or("application/json");
151                request = request
152                    .header(reqwest::header::CONTENT_TYPE, ct)
153                    .body(serde_json::to_vec(&value)?);
154            }
155            InvokeBody::Bytes(bytes) => {
156                let ct = options
157                    .content_type
158                    .as_deref()
159                    .unwrap_or("application/octet-stream");
160                request = request
161                    .header(reqwest::header::CONTENT_TYPE, ct)
162                    .body(bytes);
163            }
164            InvokeBody::Text(text) => {
165                let ct = options.content_type.as_deref().unwrap_or("text/plain");
166                request = request
167                    .header(reqwest::header::CONTENT_TYPE, ct)
168                    .body(text);
169            }
170            InvokeBody::None => {
171                if let Some(ct) = options.content_type {
172                    request = request.header(reqwest::header::CONTENT_TYPE, ct);
173                }
174            }
175        }
176
177        // Send the request
178        let response = request.send().await?;
179
180        // Collect response headers (lowercased keys)
181        let status = response.status().as_u16();
182        let is_relay_error = response
183            .headers()
184            .get("x-relay-error")
185            .and_then(|v| v.to_str().ok())
186            .map(|v| v == "true")
187            .unwrap_or(false);
188
189        let mut resp_headers = HashMap::new();
190        for (name, value) in response.headers() {
191            if let Ok(v) = value.to_str() {
192                resp_headers.insert(name.as_str().to_string(), v.to_string());
193            }
194        }
195
196        // Read response body
197        let body = response.bytes().await?.to_vec();
198
199        // Check for errors
200        if is_relay_error {
201            let message = parse_error_message(&body);
202            debug!(status, message = %message, "Relay error from edge function");
203            return Err(FunctionsError::RelayError { status, message });
204        }
205
206        if status >= 400 {
207            let message = parse_error_message(&body);
208            debug!(status, message = %message, "HTTP error from edge function");
209            return Err(FunctionsError::HttpError { status, message });
210        }
211
212        Ok(FunctionResponse::new(status, resp_headers, body))
213    }
214}
215
216/// Try to parse an error message from the response body (JSON first, then plain text).
217fn parse_error_message(body: &[u8]) -> String {
218    if let Ok(api_err) = serde_json::from_slice::<FunctionsApiErrorResponse>(body) {
219        return api_err.error_message();
220    }
221    String::from_utf8_lossy(body).into_owned()
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    use wiremock::matchers::{body_string_contains, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // ─── Helpers ─────────────────────────────────────────────

    /// Project URL shared by the tests that never send a request.
    const PROJECT_URL: &str = "https://example.supabase.co";

    /// Helper: create a FunctionsClient pointing at the given mock server.
    fn mock_client(server: &MockServer) -> FunctionsClient {
        let uri = server.uri();
        FunctionsClient::new(&uri, "test-anon-key").unwrap()
    }

    /// Helper: start a mock server, mount `mock` on it, and invoke `function`
    /// with `opts` through a client pointed at that server.
    async fn invoke_against(
        mock: Mock,
        function: &str,
        opts: InvokeOptions,
    ) -> Result<FunctionResponse, FunctionsError> {
        let server = MockServer::start().await;
        mock.mount(&server).await;
        let client = mock_client(&server);
        client.invoke(function, opts).await
    }

    // ─── Construction & Accessors ────────────────────────────

    #[test]
    fn client_new_ok() {
        assert!(FunctionsClient::new(PROJECT_URL, "test-key").is_ok());
    }

    #[test]
    fn client_base_url() {
        let client = FunctionsClient::new(PROJECT_URL, "test-key").unwrap();
        let base = client.base_url();
        assert_eq!(base.path(), "/functions/v1");
    }

    #[test]
    fn client_base_url_trailing_slash() {
        let client = FunctionsClient::new("https://example.supabase.co/", "test-key").unwrap();
        let base = client.base_url();
        assert_eq!(base.path(), "/functions/v1");
    }

    #[test]
    fn client_api_key() {
        let client = FunctionsClient::new(PROJECT_URL, "my-key").unwrap();
        assert_eq!(client.api_key(), "my-key");
    }

    // ─── Error Message Parsing ───────────────────────────────

    #[test]
    fn parse_error_message_json() {
        let parsed = parse_error_message(br#"{"message":"Function not found"}"#);
        assert_eq!(parsed, "Function not found");
    }

    #[test]
    fn parse_error_message_plain_text() {
        let parsed = parse_error_message(b"Something went wrong");
        assert_eq!(parsed, "Something went wrong");
    }

    // ─── set_auth ────────────────────────────────────────────

    #[test]
    fn set_auth_updates_override() {
        let client = FunctionsClient::new(PROJECT_URL, "test-key").unwrap();
        assert!(client.auth_override.read().unwrap().is_none());

        client.set_auth("new-token");

        let stored = client.auth_override.read().unwrap();
        assert_eq!(stored.as_deref(), Some("new-token"));
    }

    #[test]
    fn set_auth_clone_shares_state() {
        let client = FunctionsClient::new(PROJECT_URL, "test-key").unwrap();
        let clone = client.clone();

        client.set_auth("shared-token");

        let seen_by_clone = clone.auth_override.read().unwrap();
        assert_eq!(seen_by_clone.as_deref(), Some("shared-token"));
    }

    // ─── Wiremock Tests ──────────────────────────────────────

    #[tokio::test]
    async fn wiremock_invoke_json_body_success() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/hello"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(serde_json::json!({"message": "ok"})),
            );
        let opts = InvokeOptions::new().body(serde_json::json!({"name": "World"}));

        let resp = invoke_against(mock, "hello", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
        let val: serde_json::Value = resp.json().unwrap();
        assert_eq!(val["message"], "ok");
    }

    #[tokio::test]
    async fn wiremock_invoke_relay_error() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/broken"))
            .respond_with(
                ResponseTemplate::new(500)
                    .insert_header("x-relay-error", "true")
                    .set_body_json(serde_json::json!({"message": "Function not found"})),
            );

        let err = invoke_against(mock, "broken", InvokeOptions::new())
            .await
            .unwrap_err();
        match err {
            FunctionsError::RelayError { status, message } => {
                assert_eq!(status, 500);
                assert_eq!(message, "Function not found");
            }
            other => panic!("Expected RelayError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn wiremock_invoke_http_4xx_error() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/missing"))
            .respond_with(
                ResponseTemplate::new(404).set_body_json(serde_json::json!({"message": "Not Found"})),
            );

        let err = invoke_against(mock, "missing", InvokeOptions::new())
            .await
            .unwrap_err();
        match err {
            FunctionsError::HttpError { status, message } => {
                assert_eq!(status, 404);
                assert_eq!(message, "Not Found");
            }
            other => panic!("Expected HttpError, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn wiremock_invoke_auth_override_header() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/secure"))
            .and(header("authorization", "Bearer user-jwt-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"ok": true})));
        let opts = InvokeOptions::new().authorization("Bearer user-jwt-token");

        let resp = invoke_against(mock, "secure", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_region_header() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/regional"))
            .and(header("x-region", "us-east-1"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().region(FunctionRegion::UsEast1);

        let resp = invoke_against(mock, "regional", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_custom_headers() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/custom"))
            .and(header("x-custom-one", "alpha"))
            .and(header("x-custom-two", "beta"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new()
            .header("x-custom-one", "alpha")
            .header("x-custom-two", "beta");

        let resp = invoke_against(mock, "custom", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_body_json() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/echo"))
            .and(header("content-type", "application/json"))
            .and(body_string_contains("\"key\""))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().body(serde_json::json!({"key": "value"}));

        let resp = invoke_against(mock, "echo", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_body_bytes() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/upload"))
            .and(header("content-type", "application/octet-stream"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().body_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);

        let resp = invoke_against(mock, "upload", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_body_text() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/text"))
            .and(header("content-type", "text/plain"))
            .and(body_string_contains("hello world"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().body_text("hello world");

        let resp = invoke_against(mock, "text", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_body_none() {
        let mock = Mock::given(method("POST"))
            .and(path("/functions/v1/empty"))
            .respond_with(ResponseTemplate::new(204));
        // Freshly built options carry no body.
        let opts = InvokeOptions::new();

        let resp = invoke_against(mock, "empty", opts).await.unwrap();
        assert_eq!(resp.status(), 204);
    }

    #[tokio::test]
    async fn wiremock_invoke_method_get() {
        let mock = Mock::given(method("GET"))
            .and(path("/functions/v1/data"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({"items": []})));
        let opts = InvokeOptions::new().method(HttpMethod::Get);

        let resp = invoke_against(mock, "data", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_method_put() {
        let mock = Mock::given(method("PUT"))
            .and(path("/functions/v1/update"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().method(HttpMethod::Put);

        let resp = invoke_against(mock, "update", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_method_delete() {
        let mock = Mock::given(method("DELETE"))
            .and(path("/functions/v1/remove"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().method(HttpMethod::Delete);

        let resp = invoke_against(mock, "remove", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[tokio::test]
    async fn wiremock_invoke_method_patch() {
        let mock = Mock::given(method("PATCH"))
            .and(path("/functions/v1/patch"))
            .respond_with(ResponseTemplate::new(200));
        let opts = InvokeOptions::new().method(HttpMethod::Patch);

        let resp = invoke_against(mock, "patch", opts).await.unwrap();
        assert_eq!(resp.status(), 200);
    }
}