// supabase_rust/db.rs

1use reqwest::{Error, Method, Response};
2use serde::Serialize;
3
4use crate::Supabase;
5
6/// Error type for database operations.
7#[derive(Debug)]
8pub enum DbError {
9    /// Failed to serialize data to JSON.
10    Serialization(serde_json::Error),
11    /// HTTP request failed.
12    Request(Error),
13}
14
15impl std::fmt::Display for DbError {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        match self {
18            Self::Serialization(e) => write!(f, "serialization error: {e}"),
19            Self::Request(e) => write!(f, "request error: {e}"),
20        }
21    }
22}
23
24impl std::error::Error for DbError {
25    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
26        match self {
27            Self::Serialization(e) => Some(e),
28            Self::Request(e) => Some(e),
29        }
30    }
31}
32
33impl From<serde_json::Error> for DbError {
34    fn from(err: serde_json::Error) -> Self {
35        Self::Serialization(err)
36    }
37}
38
39impl From<Error> for DbError {
40    fn from(err: Error) -> Self {
41        Self::Request(err)
42    }
43}
44
45/// Query builder for PostgREST database operations.
46///
47/// Provides a fluent API for constructing and executing database queries.
48pub struct QueryBuilder<'a> {
49    client: &'a Supabase,
50    table: String,
51    query_params: Vec<(String, String)>,
52    method: Method,
53    body: Option<String>,
54}
55
56impl<'a> QueryBuilder<'a> {
57    /// Creates a new QueryBuilder for the specified table.
58    pub(crate) fn new(client: &'a Supabase, table: impl Into<String>) -> Self {
59        Self {
60            client,
61            table: table.into(),
62            query_params: Vec::new(),
63            method: Method::GET,
64            body: None,
65        }
66    }
67
68    /// Specifies which columns to select.
69    ///
70    /// Pass `"*"` to select all columns, or a comma-separated list of column names.
71    pub fn select(mut self, columns: impl Into<String>) -> Self {
72        self.query_params.push(("select".into(), columns.into()));
73        self.method = Method::GET;
74        self
75    }
76
77    /// Prepares an insert operation with the provided data.
78    ///
79    /// Data will be serialized to JSON. Call `execute()` to run the query.
80    pub fn insert<T: Serialize>(mut self, data: &T) -> Result<Self, serde_json::Error> {
81        self.method = Method::POST;
82        self.body = Some(serde_json::to_string(data)?);
83        Ok(self)
84    }
85
86    /// Prepares an update operation with the provided data.
87    ///
88    /// Should be combined with filter methods to target specific rows.
89    pub fn update<T: Serialize>(mut self, data: &T) -> Result<Self, serde_json::Error> {
90        self.method = Method::PATCH;
91        self.body = Some(serde_json::to_string(data)?);
92        Ok(self)
93    }
94
95    /// Prepares a delete operation.
96    ///
97    /// Should be combined with filter methods to target specific rows.
98    pub fn delete(mut self) -> Self {
99        self.method = Method::DELETE;
100        self
101    }
102
103    /// Filter: column equals value (`col=eq.val`).
104    pub fn eq(self, column: impl Into<String>, value: impl Into<String>) -> Self {
105        self.add_filter(column, "eq", value)
106    }
107
108    /// Filter: column not equals value (`col=neq.val`).
109    pub fn neq(self, column: impl Into<String>, value: impl Into<String>) -> Self {
110        self.add_filter(column, "neq", value)
111    }
112
113    /// Filter: column greater than value (`col=gt.val`).
114    pub fn gt(self, column: impl Into<String>, value: impl Into<String>) -> Self {
115        self.add_filter(column, "gt", value)
116    }
117
118    /// Filter: column greater than or equal to value (`col=gte.val`).
119    pub fn gte(self, column: impl Into<String>, value: impl Into<String>) -> Self {
120        self.add_filter(column, "gte", value)
121    }
122
123    /// Filter: column less than value (`col=lt.val`).
124    pub fn lt(self, column: impl Into<String>, value: impl Into<String>) -> Self {
125        self.add_filter(column, "lt", value)
126    }
127
128    /// Filter: column less than or equal to value (`col=lte.val`).
129    pub fn lte(self, column: impl Into<String>, value: impl Into<String>) -> Self {
130        self.add_filter(column, "lte", value)
131    }
132
133    /// Filter: column matches pattern (`col=like.pattern`).
134    ///
135    /// Use `*` as wildcard character.
136    pub fn like(self, column: impl Into<String>, pattern: impl Into<String>) -> Self {
137        self.add_filter(column, "like", pattern)
138    }
139
140    /// Filter: column matches pattern case-insensitively (`col=ilike.pattern`).
141    ///
142    /// Use `*` as wildcard character.
143    pub fn ilike(self, column: impl Into<String>, pattern: impl Into<String>) -> Self {
144        self.add_filter(column, "ilike", pattern)
145    }
146
147    /// Filter: column value is in the provided list (`col=in.(v1,v2,...)`).
148    pub fn in_<I, S>(mut self, column: impl Into<String>, values: I) -> Self
149    where
150        I: IntoIterator<Item = S>,
151        S: AsRef<str>,
152    {
153        let values_str: Vec<_> = values.into_iter().map(|s| s.as_ref().to_string()).collect();
154        self.query_params
155            .push((column.into(), format!("in.({})", values_str.join(","))));
156        self
157    }
158
159    /// Filter: column is null (`col=is.null`).
160    pub fn is_null(mut self, column: impl Into<String>) -> Self {
161        self.query_params.push((column.into(), "is.null".into()));
162        self
163    }
164
165    /// Filter: column is not null (`col=not.is.null`).
166    pub fn not_null(mut self, column: impl Into<String>) -> Self {
167        self.query_params.push((column.into(), "not.is.null".into()));
168        self
169    }
170
171    /// Orders results by the specified column.
172    ///
173    /// Use `"column"` for ascending or `"column.desc"` for descending.
174    pub fn order(mut self, column: impl Into<String>) -> Self {
175        self.query_params.push(("order".into(), column.into()));
176        self
177    }
178
179    /// Limits the number of rows returned.
180    pub fn limit(mut self, count: usize) -> Self {
181        self.query_params.push(("limit".into(), count.to_string()));
182        self
183    }
184
185    /// Offsets the results by the specified number of rows.
186    pub fn offset(mut self, count: usize) -> Self {
187        self.query_params.push(("offset".into(), count.to_string()));
188        self
189    }
190
191    /// Executes the query and returns the response.
192    pub async fn execute(self) -> Result<Response, Error> {
193        let url = format!("{}/rest/v1/{}", self.client.url, self.table);
194
195        let mut request = self
196            .client
197            .client
198            .request(self.method, &url)
199            .header("apikey", &self.client.api_key)
200            .header("Content-Type", "application/json");
201
202        if let Some(ref token) = self.client.bearer_token {
203            request = request.bearer_auth(token);
204        }
205
206        if !self.query_params.is_empty() {
207            request = request.query(&self.query_params);
208        }
209
210        if let Some(body) = self.body {
211            request = request.body(body);
212        }
213
214        request.send().await
215    }
216
217    fn add_filter(
218        mut self,
219        column: impl Into<String>,
220        op: &str,
221        value: impl Into<String>,
222    ) -> Self {
223        self.query_params
224            .push((column.into(), format!("{op}.{}", value.into())));
225        self
226    }
227}
228
229impl Supabase {
230    /// Creates a QueryBuilder for the specified table.
231    ///
232    /// This is the entry point for all database operations.
233    ///
234    /// # Examples
235    ///
236    /// ```ignore
237    /// // Select all from users
238    /// client.from("users").select("*").execute().await?;
239    ///
240    /// // Select with filter
241    /// client.from("users").select("id,name").eq("status", "active").execute().await?;
242    ///
243    /// // Insert
244    /// client.from("users").insert(&user_data)?.execute().await?;
245    ///
246    /// // Update
247    /// client.from("users").update(&updates)?.eq("id", "123").execute().await?;
248    ///
249    /// // Delete
250    /// client.from("users").delete().eq("id", "123").execute().await?;
251    /// ```
252    pub fn from(&self, table: impl Into<String>) -> QueryBuilder<'_> {
253        QueryBuilder::new(self, table)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Builds the client used by every test via `Supabase::new(None, None, None)`.
    fn client() -> Supabase {
        Supabase::new(None, None, None)
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct TestItem {
        name: String,
        value: i32,
    }

    /// Shared check for the network-backed tests.
    ///
    /// A response must be either successful or `401`. A transport-level error is
    /// reported and the test is treated as skipped.
    fn assert_success_or_unauthorized(result: Result<reqwest::Response, reqwest::Error>) {
        let resp = match result {
            Ok(resp) => resp,
            Err(e) => {
                println!("Test skipped due to network error: {e}");
                return;
            }
        };
        let status = resp.status();
        assert!(status.is_success() || status.as_u16() == 401);
    }

    #[tokio::test]
    async fn test_select() {
        let supabase = client();
        let result = supabase.from("test_items").select("*").execute().await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_select_columns() {
        let supabase = client();
        let result = supabase.from("test_items").select("id,name").execute().await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_select_with_filter() {
        let supabase = client();
        let result = supabase
            .from("test_items")
            .select("*")
            .eq("name", "test")
            .execute()
            .await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_insert() {
        let supabase = client();
        let row = TestItem {
            name: "test_item".into(),
            value: 42,
        };

        let builder = supabase
            .from("test_items")
            .insert(&row)
            .expect("serialization should succeed");
        assert_success_or_unauthorized(builder.execute().await);
    }

    #[tokio::test]
    async fn test_update() {
        let supabase = client();
        let patch = serde_json::json!({ "value": 100 });

        let builder = supabase
            .from("test_items")
            .update(&patch)
            .expect("serialization should succeed")
            .eq("name", "test_item");
        assert_success_or_unauthorized(builder.execute().await);
    }

    #[tokio::test]
    async fn test_delete() {
        let supabase = client();
        let result = supabase
            .from("test_items")
            .delete()
            .eq("name", "test_item")
            .execute()
            .await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_select_with_order_and_limit() {
        let supabase = client();
        let result = supabase
            .from("test_items")
            .select("*")
            .order("id.desc")
            .limit(10)
            .execute()
            .await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_select_with_multiple_filters() {
        let supabase = client();
        let result = supabase
            .from("test_items")
            .select("*")
            .gte("value", "10")
            .lte("value", "100")
            .execute()
            .await;
        assert_success_or_unauthorized(result);
    }

    #[tokio::test]
    async fn test_in_filter() {
        let supabase = client();
        let result = supabase
            .from("test_items")
            .select("*")
            .in_("id", ["1", "2", "3"])
            .execute()
            .await;
        assert_success_or_unauthorized(result);
    }

    #[test]
    fn test_db_error_display() {
        // Any JSON parse failure yields a `serde_json::Error` to wrap.
        let json_err = serde_json::from_str::<i32>("invalid").unwrap_err();
        let rendered = DbError::Serialization(json_err).to_string();
        assert!(rendered.contains("serialization error"));
    }
}