1use reqwest::{
2 header::{HeaderMap, HeaderValue},
3 Client, Error, Method, Response,
4};
5use serde::{Deserialize, Serialize};
6
7pub mod headers_serde {
8 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
9 use serde::{
10 de::{Deserialize, Error},
11 ser::SerializeSeq,
12 };
13
14 pub fn serialize<S: serde::Serializer>(map: &HeaderMap, s: S) -> Result<S::Ok, S::Error> {
15 struct Bytes<'a>(&'a [u8]);
16
17 impl<'a> serde::Serialize for Bytes<'a> {
18 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
19 where
20 S: serde::Serializer,
21 {
22 s.serialize_bytes(self.0)
23 }
24 }
25
26 let mut seq = s.serialize_seq(Some(map.len()))?;
27 for (k, v) in map.iter() {
28 match v.to_str() {
29 Ok(s) => seq.serialize_element(&(k.as_str(), s))?,
30 Err(_) => seq.serialize_element(&(k.as_str(), &Bytes(v.as_bytes())))?,
31 }
32 }
33 seq.end()
34 }
35
36 pub struct Name(HeaderName);
37
38 pub struct NameVisitor;
39
40 impl<'de> serde::de::Visitor<'de> for NameVisitor {
41 type Value = Name;
42
43 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
44 formatter.write_str("string")
45 }
46
47 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
48 where
49 E: Error,
50 {
51 Ok(Name(HeaderName::from_bytes(v).map_err(E::custom)?))
52 }
53
54 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
55 where
56 E: Error,
57 {
58 self.visit_bytes(v.as_bytes())
59 }
60 }
61
62 impl<'de> Deserialize<'de> for Name {
63 fn deserialize<D>(d: D) -> Result<Self, D::Error>
64 where
65 D: serde::Deserializer<'de>,
66 {
67 d.deserialize_str(NameVisitor)
68 }
69 }
70
71 pub struct Value(HeaderValue);
72
73 pub struct ValueVisitor;
74
75 impl<'de> serde::de::Visitor<'de> for ValueVisitor {
76 type Value = Value;
77
78 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
79 formatter.write_str("string or bytes")
80 }
81
82 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
83 where
84 E: Error,
85 {
86 Ok(Value(HeaderValue::from_bytes(v).map_err(E::custom)?))
87 }
88
89 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
90 where
91 E: Error,
92 {
93 Ok(Value(HeaderValue::from_str(v).map_err(E::custom)?))
94 }
95 }
96
97 impl<'de> Deserialize<'de> for Value {
98 fn deserialize<D>(d: D) -> Result<Self, D::Error>
99 where
100 D: serde::Deserializer<'de>,
101 {
102 d.deserialize_any(ValueVisitor)
103 }
104 }
105
106 pub struct Visitor;
107
108 impl<'de> serde::de::Visitor<'de> for Visitor {
109 type Value = HeaderMap;
110
111 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
112 formatter.write_str("[(string, string|bytes)]")
113 }
114
115 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116 where
117 A: serde::de::SeqAccess<'de>,
118 {
119 let mut map = HeaderMap::with_capacity(seq.size_hint().unwrap_or(0));
120 while let Some((name, value)) = seq.next_element::<(Name, Value)>()? {
121 map.insert(name.0, value.0);
122 }
123 Ok(map)
124 }
125 }
126
127 pub fn deserialize<'de, D: serde::Deserializer<'de>>(d: D) -> Result<HeaderMap, D::Error> {
128 d.deserialize_seq(Visitor)
129 }
130}
131
132pub mod method_serde {
133 use reqwest::Method;
134 use serde::de::{Deserialize, Error};
135
136 pub fn serialize<S: serde::Serializer>(method: &Method, s: S) -> Result<S::Ok, S::Error> {
137 s.serialize_str(method.as_str())
138 }
139
140 pub fn deserialize<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Method, D::Error> {
141 String::deserialize(d)?.parse().map_err(D::Error::custom)
142 }
143}
144
/// A serializable snapshot of a [`Builder`]'s request state.
///
/// Holds everything needed to rebuild the request except the HTTP client;
/// turn it back into a builder with `Builder::from_query`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Query {
    /// HTTP method, serialized as its string form (e.g. `"GET"`).
    #[serde(with = "method_serde")]
    pub method: Method,
    /// Target URL of the request.
    pub url: String,
    /// Optional schema; `Builder::build` sends it as `Accept-Profile` or `Content-Profile`.
    pub schema: Option<String>,
    /// Query-string parameters as ordered key/value pairs.
    pub queries: Vec<(String, String)>,
    /// Request headers, serialized as a sequence of `(name, value)` pairs.
    #[serde(with = "headers_serde")]
    pub headers: HeaderMap,
    /// Optional request body.
    pub body: Option<String>,
    /// `true` when the request was configured with `Builder::rpc`.
    pub is_rpc: bool,
}
158
159impl From<Builder> for Query {
160 fn from(
161 Builder {
162 method,
163 url,
164 schema,
165 queries,
166 headers,
167 body,
168 is_rpc,
169 ..
170 }: Builder,
171 ) -> Self {
172 Self {
173 method,
174 url,
175 schema,
176 queries,
177 headers,
178 body,
179 is_rpc,
180 }
181 }
182}
183
/// Fluent builder for a single PostgREST-style HTTP request.
///
/// Every modifier consumes and returns the builder. Call `build` to obtain a
/// [`reqwest::RequestBuilder`], or `execute` to send the request.
#[derive(Clone, Debug)]
pub struct Builder {
    /// HTTP method; `GET` from `new`, changed by `insert`/`upsert`/`update`/`delete`/`rpc`.
    method: Method,
    /// Target URL; `Builder::new` strips trailing `/` characters.
    url: String,
    /// Optional schema, sent as a profile header by `build`.
    schema: Option<String>,
    /// Query-string parameters in insertion order.
    pub(crate) queries: Vec<(String, String)>,
    /// Headers sent with the request.
    headers: HeaderMap,
    /// Optional request body; `build` sends an empty body when `None`.
    body: Option<String>,
    /// Set to `true` by `rpc`.
    is_rpc: bool,
    /// HTTP client used to build and send the request.
    pub client: Client,
}
199
200impl Builder {
202 pub fn from_query(
203 Query {
204 method,
205 url,
206 schema,
207 queries,
208 headers,
209 body,
210 is_rpc,
211 }: Query,
212 client: Client,
213 ) -> Self {
214 Self {
215 method,
216 url,
217 schema,
218 queries,
219 headers,
220 body,
221 is_rpc,
222 client,
223 }
224 }
225
226 pub fn new<T>(url: T, schema: Option<String>, headers: HeaderMap, client: Client) -> Self
228 where
229 T: Into<String>,
230 {
231 let url = url.into().trim_end_matches('/').to_string();
232
233 let mut builder = Builder {
234 method: Method::GET,
235 url,
236 schema,
237 queries: Vec::new(),
238 headers,
239 body: None,
240 is_rpc: false,
241 client,
242 };
243 builder
244 .headers
245 .insert("Accept", HeaderValue::from_static("application/json"));
246 builder
247 }
248
249 pub fn auth<T>(mut self, token: T) -> Self
262 where
263 T: AsRef<str>,
264 {
265 self.headers.insert(
266 "Authorization",
267 HeaderValue::from_str(&format!("Bearer {}", token.as_ref())).unwrap(),
268 );
269 self
270 }
271
272 pub fn select<T>(mut self, columns: T) -> Self
358 where
359 T: Into<String>,
360 {
361 self.queries.push(("select".to_string(), columns.into()));
362 self
363 }
364
365 pub fn order<T>(mut self, columns: T) -> Self
379 where
380 T: Into<String>,
381 {
382 self.queries.push(("order".to_string(), columns.into()));
383 self
384 }
385
386 pub fn order_with_options<T, U>(
400 mut self,
401 columns: T,
402 foreign_table: Option<U>,
403 ascending: bool,
404 nulls_first: bool,
405 ) -> Self
406 where
407 T: Into<String>,
408 U: Into<String>,
409 {
410 let mut key = "order".to_string();
411 if let Some(foreign_table) = foreign_table {
412 let foreign_table = foreign_table.into();
413 if !foreign_table.is_empty() {
414 key = format!("{}.order", foreign_table);
415 }
416 }
417
418 let mut ascending_string = "desc";
419 if ascending {
420 ascending_string = "asc";
421 }
422
423 let mut nulls_first_string = "nullslast";
424 if nulls_first {
425 nulls_first_string = "nullsfirst";
426 }
427
428 let existing_order = self.queries.iter().find(|(k, _)| k == &key);
429 match existing_order {
430 Some((_, v)) => {
431 let new_order = format!(
432 "{},{}.{}.{}",
433 v,
434 columns.into(),
435 ascending_string,
436 nulls_first_string
437 );
438 self.queries.push((key, new_order));
439 }
440 None => {
441 self.queries.push((
442 key,
443 format!(
444 "{}.{}.{}",
445 columns.into(),
446 ascending_string,
447 nulls_first_string
448 ),
449 ));
450 }
451 }
452 self
453 }
454
455 pub fn limit(mut self, count: usize) -> Self {
469 self.headers
470 .insert("Range-Unit", HeaderValue::from_static("items"));
471 self.headers.insert(
472 "Range",
473 HeaderValue::from_str(&format!("0-{}", count - 1)).unwrap(),
474 );
475 self
476 }
477
478 pub fn foreign_table_limit<T>(mut self, count: usize, foreign_table: T) -> Self
492 where
493 T: Into<String>,
494 {
495 self.queries
496 .push((format!("{}.limit", foreign_table.into()), count.to_string()));
497 self
498 }
499
500 pub fn range(mut self, low: usize, high: usize) -> Self {
515 self.headers
516 .insert("Range-Unit", HeaderValue::from_static("items"));
517 self.headers.insert(
518 "Range",
519 HeaderValue::from_str(&format!("{}-{}", low, high)).unwrap(),
520 );
521 self
522 }
523
524 fn count(mut self, method: &str) -> Self {
525 self.headers
526 .insert("Range-Unit", HeaderValue::from_static("items"));
527 self.headers
529 .insert("Range", HeaderValue::from_static("0-0"));
530 self.headers.insert(
531 "Prefer",
532 HeaderValue::from_str(&format!("count={}", method)).unwrap(),
533 );
534 self
535 }
536
537 pub fn exact_count(self) -> Self {
551 self.count("exact")
552 }
553
554 pub fn planned_count(self) -> Self {
569 self.count("planned")
570 }
571
572 pub fn estimated_count(self) -> Self {
587 self.count("estimated")
588 }
589
590 pub fn single(mut self) -> Self {
604 self.headers.insert(
605 "Accept",
606 HeaderValue::from_static("application/vnd.pgrst.object+json"),
607 );
608 self
609 }
610
611 pub fn insert<T>(mut self, body: T) -> Self
625 where
626 T: Into<String>,
627 {
628 self.method = Method::POST;
629 self.headers
630 .insert("Prefer", HeaderValue::from_static("return=representation"));
631 self.body = Some(body.into());
632 self
633 }
634
635 pub fn upsert<T>(mut self, body: T) -> Self
654 where
655 T: Into<String>,
656 {
657 self.method = Method::POST;
658 self.headers.insert(
659 "Prefer",
660 HeaderValue::from_static("return=representation,resolution=merge-duplicates"),
661 );
662 self.body = Some(body.into());
663 self
664 }
665
666 pub fn on_conflict<T>(mut self, columns: T) -> Self
690 where
691 T: Into<String>,
692 {
693 self.queries
694 .push(("on_conflict".to_string(), columns.into()));
695 self
696 }
697
698 pub fn update<T>(mut self, body: T) -> Self
712 where
713 T: Into<String>,
714 {
715 self.method = Method::PATCH;
716 self.headers
717 .insert("Prefer", HeaderValue::from_static("return=representation"));
718 self.body = Some(body.into());
719 self
720 }
721
722 pub fn delete(mut self) -> Self {
736 self.method = Method::DELETE;
737 self.headers
738 .insert("Prefer", HeaderValue::from_static("return=representation"));
739 self
740 }
741
742 pub fn rpc<T>(mut self, params: T) -> Self
745 where
746 T: Into<String>,
747 {
748 self.method = Method::POST;
749 self.body = Some(params.into());
750 self.is_rpc = true;
751 self
752 }
753
754 pub fn build(mut self) -> reqwest::RequestBuilder {
756 if let Some(schema) = self.schema {
757 let key = match self.method {
758 Method::GET | Method::HEAD => "Accept-Profile",
759 _ => "Content-Profile",
760 };
761 self.headers
762 .insert(key, HeaderValue::from_str(&schema).unwrap());
763 }
764 match self.method {
765 Method::GET | Method::HEAD => {}
766 _ => {
767 self.headers
768 .insert("Content-Type", HeaderValue::from_static("application/json"));
769 }
770 };
771 self.client
772 .request(self.method, self.url)
773 .headers(self.headers)
774 .query(&self.queries)
775 .body(self.body.unwrap_or_default())
776 }
777
778 pub async fn execute(self) -> Result<Response, Error> {
780 self.build().send().await
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    const TABLE_URL: &str = "http://localhost:3000/table";
    const RPC_URL: &str = "http://localhost:3000/rpc";

    /// Fresh builder for `url` with no schema and no extra headers.
    fn builder_for(url: &str) -> Builder {
        Builder::new(url, None, HeaderMap::new(), Client::new())
    }

    #[test]
    fn only_accept_json() {
        let builder = builder_for(TABLE_URL);
        assert_eq!(
            builder.headers.get("Accept").unwrap(),
            HeaderValue::from_static("application/json")
        );
    }

    #[test]
    fn new_trims_trailing_slashes() {
        let builder = builder_for("http://localhost:3000/table///");
        assert_eq!(builder.url, TABLE_URL);
    }

    #[test]
    fn auth_with_token() {
        let builder = builder_for(TABLE_URL).auth("$Up3rS3crET");
        assert_eq!(
            builder.headers.get("Authorization").unwrap(),
            HeaderValue::from_static("Bearer $Up3rS3crET")
        );
    }

    #[test]
    fn select_assert_query() {
        let builder = builder_for(TABLE_URL).select("some_table");
        assert_eq!(builder.method, Method::GET);
        assert!(builder
            .queries
            .contains(&("select".to_string(), "some_table".to_string())));
    }

    #[test]
    fn order_assert_query() {
        let builder = builder_for(TABLE_URL).order("id");
        assert!(builder
            .queries
            .contains(&("order".to_string(), "id".to_string())));
    }

    #[test]
    fn order_with_options_assert_query() {
        let builder =
            builder_for(TABLE_URL).order_with_options("name", Some("cities"), true, false);
        assert!(builder
            .queries
            .contains(&("cities.order".to_string(), "name.asc.nullslast".to_string())));
    }

    #[test]
    fn limit_assert_range_header() {
        let builder = builder_for(TABLE_URL).limit(20);
        assert_eq!(
            builder.headers.get("Range").unwrap(),
            HeaderValue::from_static("0-19")
        );
    }

    #[test]
    fn foreign_table_limit_assert_query() {
        let builder = builder_for(TABLE_URL).foreign_table_limit(20, "some_table");
        assert!(builder
            .queries
            .contains(&("some_table.limit".to_string(), "20".to_string())));
    }

    #[test]
    fn range_assert_range_header() {
        let builder = builder_for(TABLE_URL).range(10, 20);
        assert_eq!(
            builder.headers.get("Range").unwrap(),
            HeaderValue::from_static("10-20")
        );
    }

    #[test]
    fn single_assert_accept_header() {
        let builder = builder_for(TABLE_URL).single();
        assert_eq!(
            builder.headers.get("Accept").unwrap(),
            HeaderValue::from_static("application/vnd.pgrst.object+json")
        );
    }

    #[test]
    fn upsert_assert_prefer_header() {
        let builder = builder_for(TABLE_URL).upsert("ignored");
        assert_eq!(
            builder.headers.get("Prefer").unwrap(),
            HeaderValue::from_static("return=representation,resolution=merge-duplicates")
        );
    }

    #[test]
    fn not_rpc_should_not_have_flag() {
        let builder = builder_for(TABLE_URL).select("ignored");
        assert!(!builder.is_rpc);
    }

    #[test]
    fn rpc_should_have_body_and_flag() {
        let builder = builder_for(RPC_URL).rpc("{\"a\": 1, \"b\": 2}");
        assert_eq!(builder.body.unwrap(), "{\"a\": 1, \"b\": 2}");
        assert!(builder.is_rpc);
    }

    #[test]
    fn schema_sets_profile_header_by_method() {
        let schema = Some("api".to_string());

        let read = Builder::new(TABLE_URL, schema.clone(), HeaderMap::new(), Client::new())
            .build()
            .build()
            .unwrap();
        assert_eq!(
            read.headers().get("Accept-Profile").unwrap(),
            HeaderValue::from_static("api")
        );
        assert!(read.headers().get("Content-Type").is_none());

        let write = Builder::new(TABLE_URL, schema, HeaderMap::new(), Client::new())
            .insert("{}")
            .build()
            .build()
            .unwrap();
        assert_eq!(
            write.headers().get("Content-Profile").unwrap(),
            HeaderValue::from_static("api")
        );
        assert_eq!(
            write.headers().get("Content-Type").unwrap(),
            HeaderValue::from_static("application/json")
        );
    }

    #[test]
    fn chain_filters() {
        let builder = builder_for(TABLE_URL)
            .eq("username", "supabot")
            .neq("message", "hello world")
            .gte("channel_id", "1")
            .select("*");

        let queries = builder.queries;
        assert_eq!(queries.len(), 4);
        assert!(queries.contains(&("username".into(), "eq.supabot".into())));
        assert!(queries.contains(&("message".into(), "neq.hello world".into())));
        assert!(queries.contains(&("channel_id".into(), "gte.1".into())));
    }
}