1use std::marker::PhantomData;
2
3use serde::de::DeserializeOwned;
4
5use supabase_client_core::SupabaseResponse;
6
7use crate::backend::QueryBackend;
8use crate::filter::Filterable;
9use crate::modifier::Modifiable;
10use crate::sql::{FilterCondition, ParamStore, SqlParts};
11
/// Builder for an update query.
///
/// The builder accumulates the query description in `parts` and the bound
/// values in `params`; `execute` then runs it through the configured
/// [`QueryBackend`]. `T` is the row type the response data is decoded into.
pub struct UpdateBuilder<T> {
    /// Backend the query is executed against when `execute` is called.
    pub(crate) backend: QueryBackend,
    /// Query description: SET clauses, filters, schema override and the
    /// columns to report back.
    pub(crate) parts: SqlParts,
    /// Values bound to the query; `parts` refers to them by index.
    pub(crate) params: ParamStore,
    /// Ties the builder to its row type `T` without storing a value of it.
    pub(crate) _marker: PhantomData<T>,
}
20
21impl<T> Filterable for UpdateBuilder<T> {
22 fn filters_mut(&mut self) -> &mut Vec<FilterCondition> {
23 &mut self.parts.filters
24 }
25 fn params_mut(&mut self) -> &mut ParamStore {
26 &mut self.params
27 }
28}
29
30impl<T> Modifiable for UpdateBuilder<T> {
31 fn parts_mut(&mut self) -> &mut SqlParts {
32 &mut self.parts
33 }
34}
35
36impl<T> UpdateBuilder<T> {
37 pub fn schema(mut self, schema: &str) -> Self {
41 self.parts.schema_override = Some(schema.to_string());
42 self
43 }
44
45 pub fn select(mut self) -> Self {
47 self.parts.returning = Some("*".to_string());
48 self
49 }
50
51 pub fn select_columns(mut self, columns: &str) -> Self {
53 if columns == "*" || columns.is_empty() {
54 self.parts.returning = Some("*".to_string());
55 } else {
56 let quoted = columns
57 .split(',')
58 .map(|c| {
59 let c = c.trim();
60 if c.contains('(') || c.contains('*') || c.contains('"') {
61 c.to_string()
62 } else {
63 format!("\"{}\"", c)
64 }
65 })
66 .collect::<Vec<_>>()
67 .join(", ");
68 self.parts.returning = Some(quoted);
69 }
70 self
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::QueryBackend;
    use crate::sql::{ParamStore, SqlOperation, SqlParam, SqlParts};
    use serde_json::Value as JsonValue;
    use std::marker::PhantomData;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Parts and parameters for an update of `public.users` that sets
    /// `column` to the text value `"Bob"`.
    fn update_parts(column: &str) -> (SqlParts, ParamStore) {
        let mut parts = SqlParts::new(SqlOperation::Update, "public", "users");
        let mut params = ParamStore::new();
        let idx = params.push(SqlParam::Text("Bob".to_string()));
        parts.set_clauses.push((column.to_string(), idx));
        (parts, params)
    }

    /// Wraps `parts`/`params` in a builder that talks REST to `base_url`.
    fn rest_builder(
        base_url: &str,
        parts: SqlParts,
        params: ParamStore,
    ) -> UpdateBuilder<JsonValue> {
        UpdateBuilder {
            backend: QueryBackend::Rest {
                http: reqwest::Client::new(),
                base_url: Arc::from(base_url),
                api_key: Arc::from("test-key"),
                schema: "public".to_string(),
            },
            parts,
            params,
            _marker: PhantomData,
        }
    }

    /// Builder used by the tests that never send a request.
    fn make_update_builder() -> UpdateBuilder<JsonValue> {
        let (parts, params) = update_parts("name");
        rest_builder("http://localhost", parts, params)
    }

    /// Starts a mock server answering `PATCH /rest/v1/users` with `response`.
    async fn mock_patch_users(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("PATCH"))
            .and(path("/rest/v1/users"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    #[test]
    fn test_schema_sets_override() {
        let builder = make_update_builder().schema("custom");
        assert_eq!(builder.parts.schema_override.as_deref(), Some("custom"));
    }

    #[test]
    fn test_select_sets_returning_star() {
        let builder = make_update_builder().select();
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_star() {
        let builder = make_update_builder().select_columns("*");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_empty() {
        let builder = make_update_builder().select_columns("");
        assert_eq!(builder.parts.returning.as_deref(), Some("*"));
    }

    #[test]
    fn test_select_columns_specific() {
        let builder = make_update_builder().select_columns("id, name");
        assert_eq!(builder.parts.returning.as_deref(), Some("\"id\", \"name\""));
    }

    #[test]
    fn test_select_columns_complex_expression() {
        let builder = make_update_builder().select_columns("count(*)");
        assert_eq!(builder.parts.returning.as_deref(), Some("count(*)"));
    }

    #[tokio::test]
    async fn test_execute_update_success() {
        let rows = serde_json::json!([{"id": 1, "name": "Bob"}]);
        let mock_server =
            mock_patch_users(ResponseTemplate::new(200).set_body_json(rows)).await;

        let (mut parts, params) = update_parts("name");
        parts.returning = Some("*".to_string());

        let builder = rest_builder(mock_server.uri().as_str(), parts, params);
        let resp = builder.execute().await;

        assert!(resp.is_ok());
        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0]["name"], "Bob");
    }

    #[tokio::test]
    async fn test_execute_update_error() {
        let error_body = serde_json::json!({
            "message": "Column not found",
            "code": "42703"
        });
        let mock_server =
            mock_patch_users(ResponseTemplate::new(400).set_body_json(error_body)).await;

        let (parts, params) = update_parts("nonexistent");
        let builder = rest_builder(mock_server.uri().as_str(), parts, params);
        let resp = builder.execute().await;

        assert!(resp.is_err());
        match resp.error.as_ref().unwrap() {
            supabase_client_core::SupabaseError::PostgRest { status, message, .. } => {
                assert_eq!(*status, 400);
                assert_eq!(message, "Column not found");
            }
            other => panic!("Expected PostgRest error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_update_no_returning() {
        // 204 with no body: nothing was asked to be reported back.
        let mock_server = mock_patch_users(ResponseTemplate::new(204)).await;

        let (parts, params) = update_parts("name");
        let builder = rest_builder(mock_server.uri().as_str(), parts, params);
        let resp = builder.execute().await;

        assert!(resp.is_ok());
        assert!(resp.data.is_empty());
    }
}
253
254#[cfg(not(feature = "direct-sql"))]
256impl<T> UpdateBuilder<T>
257where
258 T: DeserializeOwned + Send,
259{
260 pub async fn execute(self) -> SupabaseResponse<T> {
262 let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
263 let (url, headers, body) = match crate::postgrest::build_postgrest_update(
264 base_url, &self.parts, &self.params,
265 ) {
266 Ok(r) => r,
267 Err(e) => return SupabaseResponse::error(
268 supabase_client_core::SupabaseError::QueryBuilder(e),
269 ),
270 };
271 crate::postgrest_execute::execute_rest(
272 http, reqwest::Method::PATCH, &url, headers, Some(body), api_key, schema, &self.parts,
273 ).await
274 }
275}
276
#[cfg(feature = "direct-sql")]
impl<T> UpdateBuilder<T>
where
    T: DeserializeOwned + Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Runs the update on whichever backend the builder was created with.
    ///
    /// * `QueryBackend::DirectSql` — executes the query on the connection pool
    ///   and decodes the rows into `T`.
    /// * `QueryBackend::Rest` — sends an HTTP `PATCH`; if the request cannot
    ///   be built, nothing is sent and the response carries a
    ///   `SupabaseError::QueryBuilder` error.
    pub async fn execute(self) -> SupabaseResponse<T> {
        let (http, base_url, api_key, schema) = match &self.backend {
            QueryBackend::DirectSql { pool } => {
                return crate::execute::execute_typed::<T>(pool, &self.parts, &self.params).await;
            }
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                (http, base_url, api_key, schema)
            }
        };

        match crate::postgrest::build_postgrest_update(base_url, &self.parts, &self.params) {
            Err(err) => {
                SupabaseResponse::error(supabase_client_core::SupabaseError::QueryBuilder(err))
            }
            Ok((url, headers, body)) => {
                crate::postgrest_execute::execute_rest(
                    http,
                    reqwest::Method::PATCH,
                    &url,
                    headers,
                    Some(body),
                    api_key,
                    schema,
                    &self.parts,
                )
                .await
            }
        }
    }
}