Skip to main content

supabase_client_query/
update.rs

1use std::marker::PhantomData;
2
3use serde::de::DeserializeOwned;
4
5use supabase_client_core::SupabaseResponse;
6
7use crate::backend::QueryBackend;
8use crate::filter::Filterable;
9use crate::modifier::Modifiable;
10use crate::sql::{FilterCondition, ParamStore, SqlParts};
11
12/// Builder for UPDATE queries. Implements Filterable and Modifiable.
13/// Call `.select()` to add RETURNING clause.
14pub struct UpdateBuilder<T> {
15    pub(crate) backend: QueryBackend,
16    pub(crate) parts: SqlParts,
17    pub(crate) params: ParamStore,
18    pub(crate) _marker: PhantomData<T>,
19}
20
21impl<T> Filterable for UpdateBuilder<T> {
22    fn filters_mut(&mut self) -> &mut Vec<FilterCondition> {
23        &mut self.parts.filters
24    }
25    fn params_mut(&mut self) -> &mut ParamStore {
26        &mut self.params
27    }
28}
29
30impl<T> Modifiable for UpdateBuilder<T> {
31    fn parts_mut(&mut self) -> &mut SqlParts {
32        &mut self.parts
33    }
34}
35
36impl<T> UpdateBuilder<T> {
37    /// Override the schema for this query.
38    ///
39    /// Generates `"schema"."table"` instead of the default schema.
40    pub fn schema(mut self, schema: &str) -> Self {
41        self.parts.schema_override = Some(schema.to_string());
42        self
43    }
44
45    /// Add RETURNING * clause.
46    pub fn select(mut self) -> Self {
47        self.parts.returning = Some("*".to_string());
48        self
49    }
50
51    /// Add RETURNING with specific columns.
52    pub fn select_columns(mut self, columns: &str) -> Self {
53        if columns == "*" || columns.is_empty() {
54            self.parts.returning = Some("*".to_string());
55        } else {
56            let quoted = columns
57                .split(',')
58                .map(|c| {
59                    let c = c.trim();
60                    if c.contains('(') || c.contains('*') || c.contains('"') {
61                        c.to_string()
62                    } else {
63                        format!("\"{}\"", c)
64                    }
65                })
66                .collect::<Vec<_>>()
67                .join(", ");
68            self.parts.returning = Some(quoted);
69        }
70        self
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::QueryBackend;
    use crate::sql::{ParamStore, SqlOperation, SqlParam, SqlParts};
    use serde_json::Value as JsonValue;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Parts and params for `UPDATE public.users SET <column> = 'Bob'`.
    fn set_bob(column: &str) -> (SqlParts, ParamStore) {
        let mut parts = SqlParts::new(SqlOperation::Update, "public", "users");
        let mut params = ParamStore::new();
        let bob = params.push(SqlParam::Text("Bob".to_string()));
        parts.set_clauses.push((column.to_string(), bob));
        (parts, params)
    }

    /// A REST-backed builder pointed at `base_url`, using a dummy API key.
    fn rest_builder(
        base_url: &str,
        parts: SqlParts,
        params: ParamStore,
    ) -> UpdateBuilder<JsonValue> {
        UpdateBuilder {
            backend: QueryBackend::Rest {
                http: reqwest::Client::new(),
                base_url: Arc::from(base_url),
                api_key: Arc::from("test-key"),
                schema: "public".to_string(),
            },
            parts,
            params,
            _marker: PhantomData,
        }
    }

    /// Start a mock server whose `PATCH /rest/v1/users` answers with `response`.
    async fn mock_users_patch(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("PATCH"))
            .and(path("/rest/v1/users"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    fn make_update_builder() -> UpdateBuilder<JsonValue> {
        let (parts, params) = set_bob("name");
        rest_builder("http://localhost", parts, params)
    }

    /// The RETURNING text produced by `select_columns(input)` on a fresh builder.
    fn returning_after(input: &str) -> Option<String> {
        make_update_builder().select_columns(input).parts.returning
    }

    // ---- Builder method tests ----

    #[test]
    fn test_schema_sets_override() {
        let schema_override = make_update_builder().schema("custom").parts.schema_override;
        assert_eq!(schema_override.as_deref(), Some("custom"));
    }

    #[test]
    fn test_select_sets_returning_star() {
        let returning = make_update_builder().select().parts.returning;
        assert_eq!(returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_star() {
        assert_eq!(returning_after("*").as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_empty() {
        assert_eq!(returning_after("").as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_specific() {
        assert_eq!(returning_after("id, name").as_deref(), Some("\"id\", \"name\""));
    }

    #[test]
    fn test_select_columns_complex_expression() {
        assert_eq!(returning_after("count(*)").as_deref(), Some("count(*)"));
    }

    // ---- execute() via wiremock ----

    #[tokio::test]
    async fn test_execute_update_success() {
        let server = mock_users_patch(
            ResponseTemplate::new(200)
                .set_body_json(serde_json::json!([{"id": 1, "name": "Bob"}])),
        )
        .await;

        let (mut parts, params) = set_bob("name");
        parts.returning = Some("*".to_string());

        let resp = rest_builder(&server.uri(), parts, params).execute().await;
        assert!(resp.is_ok());
        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0]["name"], "Bob");
    }

    #[tokio::test]
    async fn test_execute_update_error() {
        let server = mock_users_patch(ResponseTemplate::new(400).set_body_json(
            serde_json::json!({
                "message": "Column not found",
                "code": "42703"
            }),
        ))
        .await;

        let (parts, params) = set_bob("nonexistent");

        let resp = rest_builder(&server.uri(), parts, params).execute().await;
        assert!(resp.is_err());
        let error = resp.error.as_ref().unwrap();
        match error {
            supabase_client_core::SupabaseError::PostgRest { status, message, .. } => {
                assert_eq!(*status, 400);
                assert_eq!(message, "Column not found");
            }
            other => panic!("Expected PostgRest error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_update_no_returning() {
        let server = mock_users_patch(ResponseTemplate::new(204)).await;

        let (parts, params) = set_bob("name");

        let resp = rest_builder(&server.uri(), parts, params).execute().await;
        assert!(resp.is_ok());
        assert!(resp.data.is_empty());
    }
}
253
254// REST-only mode: only DeserializeOwned + Send needed
255#[cfg(not(feature = "direct-sql"))]
256impl<T> UpdateBuilder<T>
257where
258    T: DeserializeOwned + Send,
259{
260    /// Execute the UPDATE query.
261    pub async fn execute(self) -> SupabaseResponse<T> {
262        let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
263        let (url, headers, body) = match crate::postgrest::build_postgrest_update(
264            base_url, &self.parts, &self.params,
265        ) {
266            Ok(r) => r,
267            Err(e) => return SupabaseResponse::error(
268                supabase_client_core::SupabaseError::QueryBuilder(e),
269            ),
270        };
271        crate::postgrest_execute::execute_rest(
272            http, reqwest::Method::PATCH, &url, headers, Some(body), api_key, schema, &self.parts,
273        ).await
274    }
275}
276
// Direct-SQL build: besides deserialization, `T` must be `Unpin` and decodable
// from a Postgres row (`sqlx::FromRow`) for the direct-SQL path.
#[cfg(feature = "direct-sql")]
impl<T> UpdateBuilder<T>
where
    T: DeserializeOwned + Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Execute the UPDATE query.
    ///
    /// With a `DirectSql` backend the statement is passed, with the pool, to
    /// `crate::execute::execute_typed`. With a `Rest` backend a PostgREST `PATCH`
    /// request is built from `self.parts` / `self.params` and sent; if building
    /// that request fails, the error is returned as `SupabaseError::QueryBuilder`
    /// and nothing is sent.
    pub async fn execute(self) -> SupabaseResponse<T> {
        let (http, base_url, api_key, schema) = match &self.backend {
            QueryBackend::DirectSql { pool } => {
                return crate::execute::execute_typed::<T>(pool, &self.parts, &self.params).await;
            }
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                (http, base_url, api_key, schema)
            }
        };

        match crate::postgrest::build_postgrest_update(base_url, &self.parts, &self.params) {
            Ok((url, headers, body)) => {
                crate::postgrest_execute::execute_rest(
                    http,
                    reqwest::Method::PATCH,
                    &url,
                    headers,
                    Some(body),
                    api_key,
                    schema,
                    &self.parts,
                )
                .await
            }
            Err(e) => SupabaseResponse::error(
                supabase_client_core::SupabaseError::QueryBuilder(e),
            ),
        }
    }
}