Skip to main content

supabase_client_query/
rpc.rs

1use std::marker::PhantomData;
2
3use serde::de::DeserializeOwned;
4use serde_json::Value as JsonValue;
5
6use supabase_client_core::{Row, SupabaseError, SupabaseResponse};
7
8use crate::backend::QueryBackend;
9#[cfg(feature = "direct-sql")]
10use crate::sql::{ParamStore, SqlParam};
11use crate::sql::{SqlParts, SqlOperation, validate_identifier};
12
13/// Builder for RPC (function call) queries.
14pub struct RpcBuilder {
15    backend: QueryBackend,
16    schema: String,
17    function: String,
18    args: JsonValue,
19    rollback: bool,
20    #[cfg(feature = "direct-sql")]
21    params: ParamStore,
22    #[cfg(feature = "direct-sql")]
23    named_params: Vec<(String, usize)>,
24}
25
26impl RpcBuilder {
27    pub fn new(
28        backend: QueryBackend,
29        schema: String,
30        function: String,
31        args: JsonValue,
32    ) -> Result<Self, SupabaseError> {
33        validate_identifier(&function, "Function")?;
34
35        #[cfg(feature = "direct-sql")]
36        let (params, named_params) = {
37            let mut param_store = ParamStore::new();
38            let mut named = Vec::new();
39
40            if let JsonValue::Object(ref map) = args {
41                for (key, value) in map {
42                    validate_identifier(key, "Parameter")?;
43                    let sql_param = json_value_to_param(value.clone());
44                    let idx = param_store.push(sql_param);
45                    named.push((key.clone(), idx));
46                }
47            } else if !args.is_null() {
48                return Err(SupabaseError::query_builder(
49                    "RPC arguments must be a JSON object or null",
50                ));
51            }
52
53            (param_store, named)
54        };
55
56        #[cfg(not(feature = "direct-sql"))]
57        {
58            if let JsonValue::Object(ref map) = args {
59                for key in map.keys() {
60                    validate_identifier(key, "Parameter")?;
61                }
62            } else if !args.is_null() {
63                return Err(SupabaseError::query_builder(
64                    "RPC arguments must be a JSON object or null",
65                ));
66            }
67        }
68
69        Ok(Self {
70            backend,
71            schema,
72            function,
73            args,
74            rollback: false,
75            #[cfg(feature = "direct-sql")]
76            params,
77            #[cfg(feature = "direct-sql")]
78            named_params,
79        })
80    }
81
82    /// Run the RPC call inside a transaction that is rolled back.
83    /// Useful for testing or dry-run scenarios.
84    pub fn rollback(mut self) -> Self {
85        self.rollback = true;
86        self
87    }
88
89    #[cfg(feature = "direct-sql")]
90    fn build_sql(&self) -> Result<String, SupabaseError> {
91        validate_identifier(&self.schema, "Schema")?;
92        validate_identifier(&self.function, "Function")?;
93
94        if self.named_params.is_empty() {
95            Ok(format!(
96                "SELECT * FROM \"{}\".\"{}\"()",
97                self.schema, self.function
98            ))
99        } else {
100            let param_list: Vec<String> = self
101                .named_params
102                .iter()
103                .map(|(name, idx)| format!("\"{}\" := ${}", name, idx))
104                .collect();
105            Ok(format!(
106                "SELECT * FROM \"{}\".\"{}\"({})",
107                self.schema,
108                self.function,
109                param_list.join(", ")
110            ))
111        }
112    }
113}
114
115// REST-only mode: no sqlx needed
116#[cfg(not(feature = "direct-sql"))]
117impl RpcBuilder {
118    /// Execute the RPC call and return dynamic rows.
119    pub async fn execute(self) -> SupabaseResponse<Row> {
120        let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
121        let (url, headers, body) = crate::postgrest::build_postgrest_rpc(
122            base_url, &self.function, &self.args, self.rollback,
123        );
124        let parts = SqlParts::new(SqlOperation::Select, &self.schema, &self.function);
125        crate::postgrest_execute::execute_rest(
126            http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &parts,
127        ).await
128    }
129}
130
// Direct-SQL mode: dispatch on backend variant
#[cfg(feature = "direct-sql")]
impl RpcBuilder {
    /// Execute the RPC call and return dynamic rows.
    ///
    /// * REST backend: the request is built by `build_postgrest_rpc` (which
    ///   receives the rollback flag) and sent as an HTTP `POST`.
    /// * Direct-SQL backend: the function is called as
    ///   `SELECT * FROM "schema"."function"("name" := $n, ...)` with the
    ///   pre-bound parameters, and every result row is converted into a
    ///   [`Row`]. If [`rollback`](Self::rollback) was requested, the query runs
    ///   inside a transaction that is rolled back after the rows are fetched,
    ///   so the function's writes are discarded, as in the REST dry run.
    ///
    /// Identifier-validation, parameter-binding and database errors are
    /// returned as an error response.
    pub async fn execute(self) -> SupabaseResponse<Row> {
        match &self.backend {
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                let (url, headers, body) = crate::postgrest::build_postgrest_rpc(
                    base_url, &self.function, &self.args, self.rollback,
                );
                let parts = SqlParts::new(SqlOperation::Select, &self.schema, &self.function);
                crate::postgrest_execute::execute_rest(
                    http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &parts,
                )
                .await
            }
            QueryBackend::DirectSql { pool } => {
                let sql = match self.build_sql() {
                    Ok(s) => s,
                    Err(e) => return SupabaseResponse::error(e),
                };

                tracing::debug!(sql = %sql, "Executing RPC call");

                let args = match crate::execute::bind_params(&self.params) {
                    Ok(a) => a,
                    Err(e) => return SupabaseResponse::error(e),
                };

                let fetched: Result<Vec<sqlx::postgres::PgRow>, sqlx::Error> = if self.rollback {
                    // Honour `.rollback()` on the direct-SQL path too: run the call in a
                    // transaction and roll it back. If the query fails, `tx` is dropped
                    // uncommitted, which sqlx also rolls back.
                    async move {
                        let mut tx = pool.as_ref().begin().await?;
                        let rows = sqlx::query_with(&sql, args).fetch_all(&mut *tx).await?;
                        tx.rollback().await?;
                        Ok::<_, sqlx::Error>(rows)
                    }
                    .await
                } else {
                    sqlx::query_with(&sql, args).fetch_all(pool.as_ref()).await
                };

                match fetched {
                    Ok(rows) => {
                        let data: Vec<Row> = rows.iter().map(Self::pg_row_to_row).collect();
                        SupabaseResponse::ok(data)
                    }
                    Err(e) => SupabaseResponse::error(SupabaseError::Database(e)),
                }
            }
        }
    }

    /// Convert one Postgres result row into a dynamic [`Row`].
    ///
    /// Each column is decoded by trying, in order, the Rust types `JsonValue`,
    /// `String`, `i64`, `i32`, `i16`, `f64`, `f32` and `bool`. Non-finite floats
    /// and any column that matches none of these (e.g. SQL `NULL`s or
    /// unsupported types) become `JsonValue::Null`.
    fn pg_row_to_row(pg_row: &sqlx::postgres::PgRow) -> Row {
        use sqlx::{Column, Row as PgRowTrait};

        let mut row = Row::new();
        for col in pg_row.columns() {
            let name = col.name();
            let value = if let Ok(v) = pg_row.try_get::<JsonValue, _>(name) {
                v
            } else if let Ok(v) = pg_row.try_get::<String, _>(name) {
                JsonValue::String(v)
            } else if let Ok(v) = pg_row.try_get::<i64, _>(name) {
                JsonValue::Number(v.into())
            } else if let Ok(v) = pg_row.try_get::<i32, _>(name) {
                JsonValue::Number(v.into())
            } else if let Ok(v) = pg_row.try_get::<i16, _>(name) {
                JsonValue::Number(v.into())
            } else if let Ok(v) = pg_row.try_get::<f64, _>(name) {
                serde_json::Number::from_f64(v).map_or(JsonValue::Null, JsonValue::Number)
            } else if let Ok(v) = pg_row.try_get::<f32, _>(name) {
                // Widen via the shortest decimal form so e.g. `0.1::real` becomes
                // `0.1` rather than `0.10000000149011612`.
                v.to_string()
                    .parse::<f64>()
                    .ok()
                    .and_then(serde_json::Number::from_f64)
                    .map_or(JsonValue::Null, JsonValue::Number)
            } else if let Ok(v) = pg_row.try_get::<bool, _>(name) {
                JsonValue::Bool(v)
            } else {
                JsonValue::Null
            };
            row.set(name, value);
        }
        row
    }
}
199
200/// Typed RPC builder that deserializes results into `T`.
201pub struct TypedRpcBuilder<T> {
202    backend: QueryBackend,
203    schema: String,
204    function: String,
205    args: JsonValue,
206    rollback: bool,
207    #[cfg(feature = "direct-sql")]
208    params: ParamStore,
209    #[cfg(feature = "direct-sql")]
210    named_params: Vec<(String, usize)>,
211    _marker: PhantomData<T>,
212}
213
214impl<T> TypedRpcBuilder<T>
215where
216    T: DeserializeOwned + Send,
217{
218    pub fn new(
219        backend: QueryBackend,
220        schema: String,
221        function: String,
222        args: JsonValue,
223    ) -> Result<Self, SupabaseError> {
224        validate_identifier(&function, "Function")?;
225
226        #[cfg(feature = "direct-sql")]
227        let (params, named_params) = {
228            let mut param_store = ParamStore::new();
229            let mut named = Vec::new();
230
231            if let JsonValue::Object(ref map) = args {
232                for (key, value) in map {
233                    validate_identifier(key, "Parameter")?;
234                    let sql_param = json_value_to_param(value.clone());
235                    let idx = param_store.push(sql_param);
236                    named.push((key.clone(), idx));
237                }
238            } else if !args.is_null() {
239                return Err(SupabaseError::query_builder(
240                    "RPC arguments must be a JSON object or null",
241                ));
242            }
243
244            (param_store, named)
245        };
246
247        #[cfg(not(feature = "direct-sql"))]
248        {
249            if let JsonValue::Object(ref map) = args {
250                for key in map.keys() {
251                    validate_identifier(key, "Parameter")?;
252                }
253            } else if !args.is_null() {
254                return Err(SupabaseError::query_builder(
255                    "RPC arguments must be a JSON object or null",
256                ));
257            }
258        }
259
260        Ok(Self {
261            backend,
262            schema,
263            function,
264            args,
265            rollback: false,
266            #[cfg(feature = "direct-sql")]
267            params,
268            #[cfg(feature = "direct-sql")]
269            named_params,
270            _marker: PhantomData,
271        })
272    }
273
274    /// Run the RPC call inside a transaction that is rolled back.
275    pub fn rollback(mut self) -> Self {
276        self.rollback = true;
277        self
278    }
279
280    #[cfg(feature = "direct-sql")]
281    fn build_sql(&self) -> Result<String, SupabaseError> {
282        validate_identifier(&self.schema, "Schema")?;
283        validate_identifier(&self.function, "Function")?;
284
285        if self.named_params.is_empty() {
286            Ok(format!(
287                "SELECT * FROM \"{}\".\"{}\"()",
288                self.schema, self.function
289            ))
290        } else {
291            let param_list: Vec<String> = self
292                .named_params
293                .iter()
294                .map(|(name, idx)| format!("\"{}\" := ${}", name, idx))
295                .collect();
296            Ok(format!(
297                "SELECT * FROM \"{}\".\"{}\"({})",
298                self.schema,
299                self.function,
300                param_list.join(", ")
301            ))
302        }
303    }
304}
305
306// REST-only mode: only DeserializeOwned + Send needed
307#[cfg(not(feature = "direct-sql"))]
308impl<T> TypedRpcBuilder<T>
309where
310    T: DeserializeOwned + Send,
311{
312    /// Execute the typed RPC call.
313    pub async fn execute(self) -> SupabaseResponse<T> {
314        let QueryBackend::Rest { ref http, ref base_url, ref api_key, ref schema } = self.backend;
315        let (url, headers, body) = crate::postgrest::build_postgrest_rpc(
316            base_url, &self.function, &self.args, self.rollback,
317        );
318        let parts = SqlParts::new(SqlOperation::Select, &self.schema, &self.function);
319        crate::postgrest_execute::execute_rest(
320            http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &parts,
321        ).await
322    }
323}
324
// Direct-SQL mode: additional FromRow + Unpin bounds
#[cfg(feature = "direct-sql")]
impl<T> TypedRpcBuilder<T>
where
    T: DeserializeOwned + Send + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Execute the typed RPC call.
    ///
    /// * REST backend: the request is built by `build_postgrest_rpc` (which
    ///   receives the rollback flag) and sent as an HTTP `POST`.
    /// * Direct-SQL backend: the function is called as
    ///   `SELECT * FROM "schema"."function"("name" := $n, ...)` with the
    ///   pre-bound parameters, and each row is mapped to `T` via `FromRow`.
    ///   If [`rollback`](Self::rollback) was requested, the query runs inside a
    ///   transaction that is rolled back after the rows are fetched, so the
    ///   function's writes are discarded, as in the REST dry run.
    ///
    /// Identifier-validation, parameter-binding and database errors are
    /// returned as an error response.
    pub async fn execute(self) -> SupabaseResponse<T> {
        match &self.backend {
            QueryBackend::Rest { http, base_url, api_key, schema } => {
                let (url, headers, body) = crate::postgrest::build_postgrest_rpc(
                    base_url, &self.function, &self.args, self.rollback,
                );
                let parts = SqlParts::new(SqlOperation::Select, &self.schema, &self.function);
                crate::postgrest_execute::execute_rest(
                    http, reqwest::Method::POST, &url, headers, Some(body), api_key, schema, &parts,
                )
                .await
            }
            QueryBackend::DirectSql { pool } => {
                let sql = match self.build_sql() {
                    Ok(s) => s,
                    Err(e) => return SupabaseResponse::error(e),
                };

                tracing::debug!(sql = %sql, "Executing typed RPC call");

                let args = match crate::execute::bind_params(&self.params) {
                    Ok(a) => a,
                    Err(e) => return SupabaseResponse::error(e),
                };

                let fetched: Result<Vec<T>, sqlx::Error> = if self.rollback {
                    // Honour `.rollback()` on the direct-SQL path too: run the call in a
                    // transaction and roll it back. If the query fails, `tx` is dropped
                    // uncommitted, which sqlx also rolls back.
                    async move {
                        let mut tx = pool.as_ref().begin().await?;
                        let data = sqlx::query_as_with::<_, T, _>(&sql, args)
                            .fetch_all(&mut *tx)
                            .await?;
                        tx.rollback().await?;
                        Ok::<_, sqlx::Error>(data)
                    }
                    .await
                } else {
                    sqlx::query_as_with::<_, T, _>(&sql, args)
                        .fetch_all(pool.as_ref())
                        .await
                };

                match fetched {
                    Ok(data) => SupabaseResponse::ok(data),
                    Err(e) => SupabaseResponse::error(SupabaseError::Database(e)),
                }
            }
        }
    }
}
367
/// Convert one JSON RPC argument into the SQL parameter that gets bound for it.
///
/// * `null` → `Null`, booleans → `Bool`, strings → `Text`.
/// * Integers representable as `i64` → `I32` when they fit in `i32`, else `I64`.
/// * Other numbers → `F64`, or their textual form if no `f64` view exists.
/// * Arrays and objects → `Json`.
#[cfg(feature = "direct-sql")]
fn json_value_to_param(value: JsonValue) -> SqlParam {
    match value {
        JsonValue::Null => SqlParam::Null,
        JsonValue::Bool(flag) => SqlParam::Bool(flag),
        JsonValue::String(text) => SqlParam::Text(text),
        JsonValue::Number(num) => match (num.as_i64(), num.as_f64()) {
            (Some(int), _) => {
                let i32_range = i64::from(i32::MIN)..=i64::from(i32::MAX);
                if i32_range.contains(&int) {
                    SqlParam::I32(int as i32)
                } else {
                    SqlParam::I64(int)
                }
            }
            (None, Some(float)) => SqlParam::F64(float),
            (None, None) => SqlParam::Text(num.to_string()),
        },
        compound @ (JsonValue::Array(_) | JsonValue::Object(_)) => SqlParam::Json(compound),
    }
}