1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use derive_more::{Deref, DerefMut};
4use regex::Regex;
5use serde::{
6 de::{self, DeserializeOwned},
7 Deserialize, Deserializer,
8};
9use serde_json::{Map, Value};
10#[cfg(feature = "experimentation")]
11use strum::IntoEnumIterator;
12use superposition_derives::{IsEmpty, QueryParam};
13
14#[cfg(feature = "experimentation")]
15use crate::database::models::experimentation::ExperimentStatusType;
16use crate::IsEmpty;
17
/// Types that can render themselves as URL query parameters.
pub trait QueryParam {
    /// Returns `self` rendered as `key=value` query parameters joined by `&`.
    fn to_query_param(&self) -> String;
}
21
22pub trait CustomQuery: Sized {
23 type Inner: DeserializeOwned;
24
25 fn regex_pattern() -> &'static str;
26 fn capture_group() -> &'static str;
27
28 fn query_regex() -> Regex {
29 Regex::new(Self::regex_pattern()).unwrap()
30 }
31
32 fn extract_query(query_string: &str) -> Result<Self, String> {
33 let query_list =
34 serde_urlencoded::from_str::<Vec<(String, String)>>(query_string).map_err(
35 |e| format!("Failed to parse query string: {query_string}. Error: {e}"),
36 )?;
37 let mut query_map = HashMap::new();
38 for (key, value) in query_list {
39 let new_val = query_map
40 .get(&key)
41 .map_or_else(|| value.clone(), |v| format!("{v},{value}"));
42 query_map.insert(key, new_val);
43 }
44 let filtered_query = Self::filter_and_transform_query(query_map);
45 let inner =
46 serde_urlencoded::from_str::<Self::Inner>(&filtered_query).map_err(|e| {
47 format!("Failed to parse query string: {query_string}. Error: {e}")
48 })?;
49 Ok(Self::new(inner))
50 }
51
52 fn filter_and_transform_query(query_map: HashMap<String, String>) -> String {
53 let regex = Self::query_regex();
54 query_map
55 .into_iter()
56 .filter_map(|(k, v)| {
57 Self::extract_key(®ex, &k).map(|pk| format!("{pk}={v}"))
58 })
59 .collect::<Vec<String>>()
60 .join("&")
61 }
62
63 fn extract_key(regex: &Regex, key: &str) -> Option<String> {
64 regex
65 .captures(key)
66 .and_then(|captures| captures.name(Self::capture_group()))
67 .map(|m| m.as_str().to_owned())
68 }
69
70 fn extract_non_empty(query_string: &str) -> Self
71 where
72 Self::Inner: Default + IsEmpty,
73 {
74 let res = Self::extract_query(query_string)
75 .ok()
76 .map(|value| value.into_inner())
77 .filter(|value| !value.is_empty())
78 .unwrap_or_default();
79
80 Self::new(res)
81 }
82
83 fn new(inner: Self::Inner) -> Self;
84 fn into_inner(self) -> Self::Inner;
85}
86
87#[derive(Debug, Clone, PartialEq, Deref, DerefMut)]
89pub struct DimensionQuery<T: DeserializeOwned>(pub T);
90
91impl<T> CustomQuery for DimensionQuery<T>
92where
93 T: DeserializeOwned,
94{
95 type Inner = T;
96
97 fn regex_pattern() -> &'static str {
98 r"dimension\[(?<query_name>.*)\]"
99 }
100
101 fn capture_group() -> &'static str {
102 "query_name"
103 }
104
105 fn into_inner(self) -> T {
106 self.0
107 }
108
109 fn new(inner: Self::Inner) -> Self {
110 Self(inner)
111 }
112}
113
#[cfg(all(feature = "server", feature = "result"))]
impl<T> actix_web::FromRequest for DimensionQuery<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Extracts the `dimension[...]` parameters from the request's query string.
    ///
    /// Extraction is synchronous, so the result is handed back as an already-resolved future.
    /// A parse failure is reported as `AppError::BadArgument`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let outcome: Result<Self, Self::Error> = match Self::extract_query(req.query_string()) {
            Ok(query) => Ok(query),
            Err(reason) => Err(crate::result::AppError::BadArgument(reason).into()),
        };
        std::future::ready(outcome)
    }
}
134
135#[derive(Debug, Clone, Deref, DerefMut)]
137pub struct Query<T: DeserializeOwned>(pub T);
138
139impl<T> CustomQuery for Query<T>
140where
141 T: DeserializeOwned,
142{
143 type Inner = T;
144
145 fn regex_pattern() -> &'static str {
146 r"^(?<query_name>[^\[\]]+)$"
147 }
148
149 fn capture_group() -> &'static str {
150 "query_name"
151 }
152
153 fn into_inner(self) -> T {
154 self.0
155 }
156
157 fn new(inner: Self::Inner) -> Self {
158 Self(inner)
159 }
160}
161
#[cfg(all(feature = "server", feature = "result"))]
impl<T> actix_web::FromRequest for Query<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Extracts the plain (bracket-free) parameters from the request's query string.
    ///
    /// The work is synchronous and returned as a ready future. Failures become
    /// `AppError::BadArgument`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let query_string = req.query_string();
        let outcome: Result<Self, Self::Error> = Self::extract_query(query_string)
            .map_err(|reason| crate::result::AppError::BadArgument(reason).into());
        std::future::ready(outcome)
    }
}
182
183#[derive(Deserialize, Deref, DerefMut, Clone, PartialEq, Default)]
185#[cfg_attr(test, derive(Debug))]
186#[serde(from = "HashMap<String,String>")]
187pub struct QueryMap(Map<String, Value>);
188
189impl IsEmpty for QueryMap {
190 fn is_empty(&self) -> bool {
191 self.0.is_empty()
192 }
193}
194
195impl IntoIterator for QueryMap {
196 type Item = (String, Value);
197 type IntoIter = <Map<String, Value> as IntoIterator>::IntoIter;
198
199 fn into_iter(self) -> Self::IntoIter {
200 self.0.into_iter()
201 }
202}
203
204impl From<Map<String, Value>> for QueryMap {
205 fn from(value: Map<String, Value>) -> Self {
206 Self(value)
207 }
208}
209
210impl From<HashMap<String, String>> for QueryMap {
211 fn from(value: HashMap<String, String>) -> Self {
212 let value = value
213 .into_iter()
214 .map(|(key, value)| (key, value.parse().unwrap_or(Value::String(value))))
215 .collect::<Map<_, _>>();
216
217 Self(value)
218 }
219}
220
221impl QueryParam for DimensionQuery<QueryMap> {
222 fn to_query_param(&self) -> String {
223 let parts = self
224 .clone()
225 .into_inner()
226 .iter()
227 .map(|(key, value)| format!("dimension[{key}]={value}"))
228 .collect::<Vec<_>>();
229
230 parts.join("&")
231 }
232}
233
234impl From<Map<String, Value>> for DimensionQuery<QueryMap> {
235 fn from(value: Map<String, Value>) -> Self {
236 Self(QueryMap::from(value))
237 }
238}
239
240impl Default for DimensionQuery<QueryMap> {
241 fn default() -> Self {
242 Self::from(Map::new())
243 }
244}
245
246#[derive(Debug, Clone, PartialEq, IsEmpty, QueryParam)]
247pub struct PaginationParams {
248 pub count: Option<i64>,
249 pub page: Option<i64>,
250 pub all: Option<bool>,
251}
252
253impl PaginationParams {
254 pub fn all_entries() -> Self {
255 Self {
256 count: None,
257 page: None,
258 all: Some(true),
259 }
260 }
261
262 pub fn reset_page(&mut self) {
263 self.page = if let Some(true) = self.all {
264 None
265 } else {
266 Some(1)
267 };
268 }
269}
270
271impl Default for PaginationParams {
272 fn default() -> Self {
273 Self {
274 count: Some(10),
275 page: Some(1),
276 all: None,
277 }
278 }
279}
280
281impl<'de> Deserialize<'de> for PaginationParams {
282 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
283 where
284 D: Deserializer<'de>,
285 {
286 #[derive(Deserialize)]
287 struct Helper {
288 count: Option<i64>,
289 page: Option<i64>,
290 all: Option<bool>,
291 }
292
293 let helper = Helper::deserialize(deserializer)?;
294
295 if helper.all == Some(true) && (helper.count.is_some() || helper.page.is_some()) {
296 return Err(de::Error::custom("When 'all' is true, 'count' and 'page' parameters should not be provided"));
297 }
298
299 if let Some(count) = helper.count {
300 if count <= 0 {
301 return Err(de::Error::custom("Count should be greater than 0."));
302 }
303 }
304
305 if let Some(page) = helper.page {
306 if page <= 0 {
307 return Err(de::Error::custom("Page should be greater than 0."));
308 }
309 }
310
311 Ok(Self {
312 count: helper.count,
313 page: helper.page,
314 all: helper.all,
315 })
316 }
317}
318
319#[derive(Debug, Clone, Deref, PartialEq, Default)]
320#[deref(forward)]
321pub struct CommaSeparatedQParams<T: Display + FromStr>(pub Vec<T>);
322
323impl<T: Display + FromStr> Display for CommaSeparatedQParams<T> {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 let str_arr = self.0.iter().map(|s| s.to_string()).collect::<Vec<_>>();
326 write!(f, "{}", str_arr.join(","))
327 }
328}
329
330impl<'de, T: Display + FromStr> Deserialize<'de> for CommaSeparatedQParams<T> {
331 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
332 where
333 D: Deserializer<'de>,
334 {
335 let items = String::deserialize(deserializer)?
336 .split(',')
337 .map(|item| item.trim().to_string())
338 .map(|s| T::from_str(&s))
339 .collect::<Result<Vec<_>, _>>()
340 .map_err(|_| {
341 serde::de::Error::custom(String::from("Error in converting type"))
342 })?;
343 Ok(Self(items))
344 }
345}
346
347pub type CommaSeparatedStringQParams = CommaSeparatedQParams<String>;
348
#[cfg(feature = "experimentation")]
impl Default for CommaSeparatedQParams<ExperimentStatusType> {
    /// Defaults to every `ExperimentStatusType` variant.
    ///
    /// The variants appear in the order `ExperimentStatusType::iter()` yields them.
    fn default() -> Self {
        let every_status = ExperimentStatusType::iter().collect::<Vec<_>>();
        Self(every_status)
    }
}
355
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{from_value, json, Map, Value};

    use super::QueryMap;

    /// Raw values are parsed as JSON where possible and otherwise kept as strings.
    ///
    /// Quoted literals therefore stay strings, while bare numbers, booleans and `null`
    /// become typed JSON values.
    #[test]
    fn querymap_from_hashmap() {
        let raw: HashMap<String, String> = from_value(json!({
            "key1": "123",
            "key2": "\"123\"",
            "key3": "test",
            "key4": "true",
            "key5": "\"true\"",
            "key6": "null",
            "key7": "\"null\""
        }))
        .unwrap();

        let expected: Map<String, Value> = from_value(json!({
            "key1": 123,
            "key2": "123",
            "key3": "test",
            "key4": true,
            "key5": "true",
            "key6": null,
            "key7": "null"
        }))
        .unwrap();

        assert_eq!(QueryMap::from(raw), QueryMap(expected));
    }
}