Skip to main content

superposition_types/
custom_query.rs

1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use derive_more::{Deref, DerefMut};
4use regex::Regex;
5use serde::{
6    de::{self, DeserializeOwned},
7    Deserialize, Deserializer,
8};
9use serde_json::{Map, Value};
10#[cfg(feature = "experimentation")]
11use strum::IntoEnumIterator;
12use superposition_derives::{IsEmpty, QueryParam};
13
14#[cfg(feature = "experimentation")]
15use crate::database::models::experimentation::ExperimentStatusType;
16use crate::IsEmpty;
17
/// Types that know how to render themselves as a URL query-string fragment.
pub trait QueryParam {
    /// Returns `self` encoded as query-string text (for example one or more
    /// `key=value` pairs joined by `&`).
    fn to_query_param(&self) -> String;
}
21
22pub trait CustomQuery: Sized {
23    type Inner: DeserializeOwned;
24
25    fn regex_pattern() -> &'static str;
26    fn capture_group() -> &'static str;
27
28    fn query_regex() -> Regex {
29        Regex::new(Self::regex_pattern()).unwrap()
30    }
31
32    fn extract_query(query_string: &str) -> Result<Self, String> {
33        let query_list =
34            serde_urlencoded::from_str::<Vec<(String, String)>>(query_string).map_err(
35                |e| format!("Failed to parse query string: {query_string}. Error: {e}"),
36            )?;
37        let mut query_map = HashMap::new();
38        for (key, value) in query_list {
39            let new_val = query_map
40                .get(&key)
41                .map_or_else(|| value.clone(), |v| format!("{v},{value}"));
42            query_map.insert(key, new_val);
43        }
44        let filtered_query = Self::filter_and_transform_query(query_map);
45        let inner =
46            serde_urlencoded::from_str::<Self::Inner>(&filtered_query).map_err(|e| {
47                format!("Failed to parse query string: {query_string}. Error: {e}")
48            })?;
49        Ok(Self::new(inner))
50    }
51
52    fn filter_and_transform_query(query_map: HashMap<String, String>) -> String {
53        let regex = Self::query_regex();
54        query_map
55            .into_iter()
56            .filter_map(|(k, v)| {
57                Self::extract_key(&regex, &k).map(|pk| format!("{pk}={v}"))
58            })
59            .collect::<Vec<String>>()
60            .join("&")
61    }
62
63    fn extract_key(regex: &Regex, key: &str) -> Option<String> {
64        regex
65            .captures(key)
66            .and_then(|captures| captures.name(Self::capture_group()))
67            .map(|m| m.as_str().to_owned())
68    }
69
70    fn extract_non_empty(query_string: &str) -> Self
71    where
72        Self::Inner: Default + IsEmpty,
73    {
74        let res = Self::extract_query(query_string)
75            .ok()
76            .map(|value| value.into_inner())
77            .filter(|value| !value.is_empty())
78            .unwrap_or_default();
79
80        Self::new(res)
81    }
82
83    fn new(inner: Self::Inner) -> Self;
84    fn into_inner(self) -> Self::Inner;
85}
86
87/// Provides struct to extract those query params from the request which are `wrapped` in `dimension[param_name]`
88#[derive(Debug, Clone, PartialEq, Deref, DerefMut)]
89pub struct DimensionQuery<T: DeserializeOwned>(pub T);
90
91impl<T> CustomQuery for DimensionQuery<T>
92where
93    T: DeserializeOwned,
94{
95    type Inner = T;
96
97    fn regex_pattern() -> &'static str {
98        r"dimension\[(?<query_name>.*)\]"
99    }
100
101    fn capture_group() -> &'static str {
102        "query_name"
103    }
104
105    fn into_inner(self) -> T {
106        self.0
107    }
108
109    fn new(inner: Self::Inner) -> Self {
110        Self(inner)
111    }
112}
113
114#[cfg(all(feature = "server", feature = "result"))]
115impl<T> actix_web::FromRequest for DimensionQuery<T>
116where
117    T: DeserializeOwned,
118{
119    type Error = actix_web::Error;
120    type Future = std::future::Ready<Result<Self, Self::Error>>;
121
122    fn from_request(
123        req: &actix_web::HttpRequest,
124        _: &mut actix_web::dev::Payload,
125    ) -> Self::Future {
126        use std::future::ready;
127        ready(
128            Self::extract_query(req.query_string())
129                .map_err(crate::result::AppError::BadArgument)
130                .map_err(Into::into),
131        )
132    }
133}
134
135/// Provides struct to extract those query params from the request which are `not wrapped` in contrusts like `platform[param_name]`
136#[derive(Debug, Clone, Deref, DerefMut)]
137pub struct Query<T: DeserializeOwned>(pub T);
138
139impl<T> CustomQuery for Query<T>
140where
141    T: DeserializeOwned,
142{
143    type Inner = T;
144
145    fn regex_pattern() -> &'static str {
146        r"^(?<query_name>[^\[\]]+)$"
147    }
148
149    fn capture_group() -> &'static str {
150        "query_name"
151    }
152
153    fn into_inner(self) -> T {
154        self.0
155    }
156
157    fn new(inner: Self::Inner) -> Self {
158        Self(inner)
159    }
160}
161
162#[cfg(all(feature = "server", feature = "result"))]
163impl<T> actix_web::FromRequest for Query<T>
164where
165    T: DeserializeOwned,
166{
167    type Error = actix_web::Error;
168    type Future = std::future::Ready<Result<Self, Self::Error>>;
169
170    fn from_request(
171        req: &actix_web::HttpRequest,
172        _: &mut actix_web::dev::Payload,
173    ) -> Self::Future {
174        use std::future::ready;
175        ready(
176            Self::extract_query(req.query_string())
177                .map_err(crate::result::AppError::BadArgument)
178                .map_err(Into::into),
179        )
180    }
181}
182
183/// Provides struct to `Deserialize` `HashMap<String, String>` as `serde_json::Map<String, serde_json::Value>`
184#[derive(Deserialize, Deref, DerefMut, Clone, PartialEq, Default)]
185#[cfg_attr(test, derive(Debug))]
186#[serde(from = "HashMap<String,String>")]
187pub struct QueryMap(Map<String, Value>);
188
189impl IsEmpty for QueryMap {
190    fn is_empty(&self) -> bool {
191        self.0.is_empty()
192    }
193}
194
195impl IntoIterator for QueryMap {
196    type Item = (String, Value);
197    type IntoIter = <Map<String, Value> as IntoIterator>::IntoIter;
198
199    fn into_iter(self) -> Self::IntoIter {
200        self.0.into_iter()
201    }
202}
203
204impl From<Map<String, Value>> for QueryMap {
205    fn from(value: Map<String, Value>) -> Self {
206        Self(value)
207    }
208}
209
210impl From<HashMap<String, String>> for QueryMap {
211    fn from(value: HashMap<String, String>) -> Self {
212        let value = value
213            .into_iter()
214            .map(|(key, value)| (key, value.parse().unwrap_or(Value::String(value))))
215            .collect::<Map<_, _>>();
216
217        Self(value)
218    }
219}
220
221impl QueryParam for DimensionQuery<QueryMap> {
222    fn to_query_param(&self) -> String {
223        let parts = self
224            .clone()
225            .into_inner()
226            .iter()
227            .map(|(key, value)| format!("dimension[{key}]={value}"))
228            .collect::<Vec<_>>();
229
230        parts.join("&")
231    }
232}
233
234impl From<Map<String, Value>> for DimensionQuery<QueryMap> {
235    fn from(value: Map<String, Value>) -> Self {
236        Self(QueryMap::from(value))
237    }
238}
239
240impl Default for DimensionQuery<QueryMap> {
241    fn default() -> Self {
242        Self::from(Map::new())
243    }
244}
245
246#[derive(Debug, Clone, PartialEq, IsEmpty, QueryParam)]
247pub struct PaginationParams {
248    pub count: Option<i64>,
249    pub page: Option<i64>,
250    pub all: Option<bool>,
251}
252
253impl PaginationParams {
254    pub fn all_entries() -> Self {
255        Self {
256            count: None,
257            page: None,
258            all: Some(true),
259        }
260    }
261
262    pub fn reset_page(&mut self) {
263        self.page = if let Some(true) = self.all {
264            None
265        } else {
266            Some(1)
267        };
268    }
269}
270
271impl Default for PaginationParams {
272    fn default() -> Self {
273        Self {
274            count: Some(10),
275            page: Some(1),
276            all: None,
277        }
278    }
279}
280
281impl<'de> Deserialize<'de> for PaginationParams {
282    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283    where
284        D: Deserializer<'de>,
285    {
286        #[derive(Deserialize)]
287        struct Helper {
288            count: Option<i64>,
289            page: Option<i64>,
290            all: Option<bool>,
291        }
292
293        let helper = Helper::deserialize(deserializer)?;
294
295        if helper.all == Some(true) && (helper.count.is_some() || helper.page.is_some()) {
296            return Err(de::Error::custom("When 'all' is true, 'count' and 'page' parameters should not be provided"));
297        }
298
299        if let Some(count) = helper.count {
300            if count <= 0 {
301                return Err(de::Error::custom("Count should be greater than 0."));
302            }
303        }
304
305        if let Some(page) = helper.page {
306            if page <= 0 {
307                return Err(de::Error::custom("Page should be greater than 0."));
308            }
309        }
310
311        Ok(Self {
312            count: helper.count,
313            page: helper.page,
314            all: helper.all,
315        })
316    }
317}
318
319#[derive(Debug, Clone, Deref, PartialEq, Default)]
320#[deref(forward)]
321pub struct CommaSeparatedQParams<T: Display + FromStr>(pub Vec<T>);
322
323impl<T: Display + FromStr> Display for CommaSeparatedQParams<T> {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        let str_arr = self.0.iter().map(|s| s.to_string()).collect::<Vec<_>>();
326        write!(f, "{}", str_arr.join(","))
327    }
328}
329
330impl<'de, T: Display + FromStr> Deserialize<'de> for CommaSeparatedQParams<T> {
331    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
332    where
333        D: Deserializer<'de>,
334    {
335        let items = String::deserialize(deserializer)?
336            .split(',')
337            .map(|item| item.trim().to_string())
338            .map(|s| T::from_str(&s))
339            .collect::<Result<Vec<_>, _>>()
340            .map_err(|_| {
341                serde::de::Error::custom(String::from("Error in converting type"))
342            })?;
343        Ok(Self(items))
344    }
345}
346
347pub type CommaSeparatedStringQParams = CommaSeparatedQParams<String>;
348
349#[cfg(feature = "experimentation")]
350impl Default for CommaSeparatedQParams<ExperimentStatusType> {
351    fn default() -> Self {
352        Self(ExperimentStatusType::iter().collect())
353    }
354}
355
356#[cfg(test)]
357mod tests {
358    use std::collections::HashMap;
359
360    use serde_json::{from_value, json};
361
362    use crate::custom_query::QueryMap;
363
364    #[test]
365    fn querymap_from_hashmap() {
366        let hashmap: HashMap<String, String> = from_value(json!({
367            "key1": "123",
368            "key2": "\"123\"",
369            "key3": "test",
370            "key4": "true",
371            "key5": "\"true\"",
372            "key6": "null",
373            "key7": "\"null\""
374        }))
375        .unwrap();
376
377        let map = json!({
378            "key1": 123,
379            "key2": "123",
380            "key3": "test",
381            "key4": true,
382            "key5": "true",
383            "key6": null,
384            "key7": "null"
385        })
386        .as_object()
387        .unwrap()
388        .clone();
389
390        assert_eq!(QueryMap::from(hashmap), QueryMap(map));
391    }
392}