superposition_types/
custom_query.rs

1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use core::fmt;
4use derive_more::{Deref, DerefMut};
5use regex::Regex;
6use serde::{
7    de::{self, DeserializeOwned},
8    Deserialize, Deserializer, Serialize,
9};
10use serde_json::{Map, Value};
11#[cfg(feature = "experimentation")]
12use strum::IntoEnumIterator;
13use superposition_derives::IsEmpty;
14
15#[cfg(feature = "experimentation")]
16use crate::database::models::experimentation::ExperimentStatusType;
17use crate::IsEmpty;
18
19pub trait CustomQuery: Sized {
20    type Inner: DeserializeOwned;
21
22    fn regex_pattern() -> &'static str;
23    fn capture_group() -> &'static str;
24
25    fn query_regex() -> Regex {
26        Regex::new(Self::regex_pattern()).unwrap()
27    }
28
29    fn extract_query(query_string: &str) -> Result<Self, String> {
30        let query_map =
31            serde_urlencoded::from_str::<HashMap<String, String>>(query_string).map_err(
32                |e| format!("Failed to parse query string: {query_string}. Error: {e}"),
33            )?;
34        let filtered_query = Self::filter_and_transform_query(query_map);
35        let inner =
36            serde_urlencoded::from_str::<Self::Inner>(&filtered_query).map_err(|e| {
37                format!("Failed to parse query string: {query_string}. Error: {e}")
38            })?;
39        Ok(Self::new(inner))
40    }
41
42    fn filter_and_transform_query(query_map: HashMap<String, String>) -> String {
43        let regex = Self::query_regex();
44        query_map
45            .into_iter()
46            .filter_map(|(k, v)| {
47                Self::extract_key(&regex, &k).map(|pk| format!("{pk}={v}"))
48            })
49            .collect::<Vec<String>>()
50            .join("&")
51    }
52
53    fn extract_key(regex: &Regex, key: &str) -> Option<String> {
54        regex
55            .captures(key)
56            .and_then(|captures| captures.name(Self::capture_group()))
57            .map(|m| m.as_str().to_owned())
58    }
59
60    fn extract_non_empty(query_string: &str) -> Self
61    where
62        Self::Inner: Default + IsEmpty,
63    {
64        let res = Self::extract_query(query_string)
65            .ok()
66            .map(|value| value.into_inner())
67            .filter(|value| !value.is_empty())
68            .unwrap_or_default();
69
70        Self::new(res)
71    }
72
73    fn new(inner: Self::Inner) -> Self;
74    fn into_inner(self) -> Self::Inner;
75}
76
77/// Provides struct to extract those query params from the request which are `wrapped` in `dimension[param_name]`
78#[derive(Debug, Clone, PartialEq, Deref, DerefMut)]
79pub struct DimensionQuery<T: DeserializeOwned>(pub T);
80
81impl<T> CustomQuery for DimensionQuery<T>
82where
83    T: DeserializeOwned,
84{
85    type Inner = T;
86
87    fn regex_pattern() -> &'static str {
88        r"dimension\[(?<query_name>.*)\]"
89    }
90
91    fn capture_group() -> &'static str {
92        "query_name"
93    }
94
95    fn into_inner(self) -> T {
96        self.0
97    }
98
99    fn new(inner: Self::Inner) -> Self {
100        Self(inner)
101    }
102}
103
#[cfg(feature = "server")]
impl<T> actix_web::FromRequest for DimensionQuery<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Parses the `dimension[...]` parameters of the request's query string
    /// synchronously. Failures are mapped to `400 Bad Request`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let parsed = <Self as CustomQuery>::extract_query(req.query_string());
        std::future::ready(parsed.map_err(actix_web::error::ErrorBadRequest))
    }
}
123
124/// Provides struct to extract those query params from the request which are `not wrapped` in contrusts like `platform[param_name]`
125#[derive(Debug, Clone)]
126pub struct Query<T: DeserializeOwned>(pub T);
127
128impl<T> CustomQuery for Query<T>
129where
130    T: DeserializeOwned,
131{
132    type Inner = T;
133
134    fn regex_pattern() -> &'static str {
135        r"^(?<query_name>[^\[\]]+)$"
136    }
137
138    fn capture_group() -> &'static str {
139        "query_name"
140    }
141
142    fn into_inner(self) -> T {
143        self.0
144    }
145
146    fn new(inner: Self::Inner) -> Self {
147        Self(inner)
148    }
149}
150
#[cfg(feature = "server")]
impl<T> actix_web::FromRequest for Query<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Deserializes the plain (un-bracketed) query parameters of the request.
    /// Any failure becomes a `400 Bad Request` carrying the parse message.
    fn from_request(
        req: &actix_web::HttpRequest,
        _: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let outcome = match Self::extract_query(req.query_string()) {
            Ok(query) => Ok(query),
            Err(message) => Err(actix_web::error::ErrorBadRequest(message)),
        };
        std::future::ready(outcome)
    }
}
170
171/// Provides struct to `Deserialize` `HashMap<String, String>` as `serde_json::Map<String, serde_json::Value>`
172#[derive(Deserialize, Deref, DerefMut, Clone, PartialEq, Default)]
173#[cfg_attr(test, derive(Debug))]
174#[serde(from = "HashMap<String,String>")]
175pub struct QueryMap(Map<String, Value>);
176
177impl IsEmpty for QueryMap {
178    fn is_empty(&self) -> bool {
179        self.0.is_empty()
180    }
181}
182
183impl From<Map<String, Value>> for QueryMap {
184    fn from(value: Map<String, Value>) -> Self {
185        Self(value)
186    }
187}
188
189impl From<HashMap<String, String>> for QueryMap {
190    fn from(value: HashMap<String, String>) -> Self {
191        let value = value
192            .into_iter()
193            .map(|(key, value)| (key, value.parse().unwrap_or(Value::String(value))))
194            .collect::<Map<_, _>>();
195
196        Self(value)
197    }
198}
199
200impl Display for DimensionQuery<QueryMap> {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        let parts = self
203            .clone()
204            .into_inner()
205            .iter()
206            .map(|(key, value)| format!("dimension[{key}]={value}"))
207            .collect::<Vec<_>>();
208
209        write!(f, "{}", parts.join("&"))
210    }
211}
212
213impl From<Map<String, Value>> for DimensionQuery<QueryMap> {
214    fn from(value: Map<String, Value>) -> Self {
215        Self(QueryMap::from(value))
216    }
217}
218
219impl Default for DimensionQuery<QueryMap> {
220    fn default() -> Self {
221        Self::from(Map::new())
222    }
223}
224
/// Pagination options for a listing request.
///
/// The manual `Deserialize` impl rejects `all = true` combined with `count` or
/// `page`. It also rejects a non-positive `count` or `page`.
#[derive(Debug, Clone, PartialEq, IsEmpty)]
pub struct PaginationParams {
    /// Entries per page. Must be greater than 0 when present.
    pub count: Option<i64>,
    /// Page number, starting at 1 (the default and `reset_page` both use `Some(1)`).
    /// Must be greater than 0 when present.
    pub page: Option<i64>,
    /// Set to `Some(true)` by `all_entries()`. Presumably requests every entry
    /// instead of a single page; confirm with the consumers.
    pub all: Option<bool>,
}
231
232impl PaginationParams {
233    pub fn all_entries() -> Self {
234        Self {
235            count: None,
236            page: None,
237            all: Some(true),
238        }
239    }
240
241    pub fn reset_page(&mut self) {
242        self.page = if let Some(true) = self.all {
243            None
244        } else {
245            Some(1)
246        };
247    }
248}
249
250impl Default for PaginationParams {
251    fn default() -> Self {
252        Self {
253            count: Some(10),
254            page: Some(1),
255            all: None,
256        }
257    }
258}
259
260impl Display for PaginationParams {
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        let mut parts = vec![];
263
264        if let Some(page) = self.page {
265            parts.push(format!("page={}", page));
266        }
267
268        if let Some(count) = self.count {
269            parts.push(format!("count={}", count));
270        }
271
272        if let Some(all) = self.all {
273            parts.push(format!("all={}", all));
274        }
275
276        write!(f, "{}", parts.join("&"))
277    }
278}
279
280impl<'de> Deserialize<'de> for PaginationParams {
281    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
282    where
283        D: Deserializer<'de>,
284    {
285        #[derive(Deserialize)]
286        struct Helper {
287            count: Option<i64>,
288            page: Option<i64>,
289            all: Option<bool>,
290        }
291
292        let helper = Helper::deserialize(deserializer)?;
293
294        if helper.all == Some(true) && (helper.count.is_some() || helper.page.is_some()) {
295            return Err(de::Error::custom("When 'all' is true, 'count' and 'page' parameters should not be provided"));
296        }
297
298        if let Some(count) = helper.count {
299            if count <= 0 {
300                return Err(de::Error::custom("Count should be greater than 0."));
301            }
302        }
303
304        if let Some(page) = helper.page {
305            if page <= 0 {
306                return Err(de::Error::custom("Page should be greater than 0."));
307            }
308        }
309
310        Ok(Self {
311            count: helper.count,
312            page: helper.page,
313            all: helper.all,
314        })
315    }
316}
317
/// A list of values carried in a single comma-separated query parameter, e.g. `a,b,c`.
///
/// `#[deref(forward)]` forwards `Deref` through the inner `Vec<T>`, so the wrapper
/// derefs to the slice `[T]`. The derived `Default` (the empty list) is only
/// available when `T: Default`.
#[derive(Debug, Clone, Deref, PartialEq, Default)]
#[deref(forward)]
pub struct CommaSeparatedQParams<T: Display + FromStr>(pub Vec<T>);
321
322impl<T: Display + FromStr> Display for CommaSeparatedQParams<T> {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        let str_arr = self.0.iter().map(|s| s.to_string()).collect::<Vec<_>>();
325        write!(f, "{}", str_arr.join(","))
326    }
327}
328
329impl<'de, T: Display + FromStr> Deserialize<'de> for CommaSeparatedQParams<T> {
330    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
331    where
332        D: Deserializer<'de>,
333    {
334        let items = String::deserialize(deserializer)?
335            .split(',')
336            .map(|item| item.trim().to_string())
337            .map(|s| T::from_str(&s))
338            .collect::<Result<Vec<_>, _>>()
339            .map_err(|_| {
340                serde::de::Error::custom(String::from("Error in converting type"))
341            })?;
342        Ok(Self(items))
343    }
344}
345
346impl<T: Display + FromStr> Serialize for CommaSeparatedQParams<T> {
347    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
348    where
349        S: serde::Serializer,
350    {
351        serializer.serialize_str(&self.to_string())
352    }
353}
354
/// Comma-separated list of plain strings, e.g. `a,b,c`. Each item is trimmed on parse.
pub type CommaSeparatedStringQParams = CommaSeparatedQParams<String>;
356
#[cfg(feature = "experimentation")]
impl Default for CommaSeparatedQParams<ExperimentStatusType> {
    /// Defaults to every `ExperimentStatusType` variant, in the order that
    /// `ExperimentStatusType::iter()` yields them.
    fn default() -> Self {
        let every_status: Vec<ExperimentStatusType> = ExperimentStatusType::iter().collect();
        CommaSeparatedQParams(every_status)
    }
}
363
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::json;

    use crate::custom_query::QueryMap;

    /// A raw value that is valid JSON keeps its JSON type. Anything else, and
    /// every quoted value, ends up as a JSON string.
    #[test]
    fn querymap_from_hashmap() {
        let raw_params: HashMap<String, String> = [
            ("key1", "123"),
            ("key2", "\"123\""),
            ("key3", "test"),
            ("key4", "true"),
            ("key5", "\"true\""),
            ("key6", "null"),
            ("key7", "\"null\""),
        ]
        .iter()
        .map(|&(name, raw)| (name.to_owned(), raw.to_owned()))
        .collect();

        let expected = json!({
            "key1": 123,
            "key2": "123",
            "key3": "test",
            "key4": true,
            "key5": "true",
            "key6": null,
            "key7": "null"
        })
        .as_object()
        .cloned()
        .expect("the expected value is a JSON object literal");

        assert_eq!(QueryMap::from(raw_params), QueryMap(expected));
    }
}