1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use core::fmt;
4use derive_more::{Deref, DerefMut};
5use regex::Regex;
6use serde::{
7 de::{self, DeserializeOwned},
8 Deserialize, Deserializer, Serialize,
9};
10use serde_json::{Map, Value};
11#[cfg(feature = "experimentation")]
12use strum::IntoEnumIterator;
13use superposition_derives::IsEmpty;
14
15#[cfg(feature = "experimentation")]
16use crate::database::models::experimentation::ExperimentStatusType;
17use crate::IsEmpty;
18
19pub trait CustomQuery: Sized {
20 type Inner: DeserializeOwned;
21
22 fn regex_pattern() -> &'static str;
23 fn capture_group() -> &'static str;
24
25 fn query_regex() -> Regex {
26 Regex::new(Self::regex_pattern()).unwrap()
27 }
28
29 fn extract_query(query_string: &str) -> Result<Self, String> {
30 let query_map =
31 serde_urlencoded::from_str::<HashMap<String, String>>(query_string).map_err(
32 |e| format!("Failed to parse query string: {query_string}. Error: {e}"),
33 )?;
34 let filtered_query = Self::filter_and_transform_query(query_map);
35 let inner =
36 serde_urlencoded::from_str::<Self::Inner>(&filtered_query).map_err(|e| {
37 format!("Failed to parse query string: {query_string}. Error: {e}")
38 })?;
39 Ok(Self::new(inner))
40 }
41
42 fn filter_and_transform_query(query_map: HashMap<String, String>) -> String {
43 let regex = Self::query_regex();
44 query_map
45 .into_iter()
46 .filter_map(|(k, v)| {
47 Self::extract_key(®ex, &k).map(|pk| format!("{pk}={v}"))
48 })
49 .collect::<Vec<String>>()
50 .join("&")
51 }
52
53 fn extract_key(regex: &Regex, key: &str) -> Option<String> {
54 regex
55 .captures(key)
56 .and_then(|captures| captures.name(Self::capture_group()))
57 .map(|m| m.as_str().to_owned())
58 }
59
60 fn extract_non_empty(query_string: &str) -> Self
61 where
62 Self::Inner: Default + IsEmpty,
63 {
64 let res = Self::extract_query(query_string)
65 .ok()
66 .map(|value| value.into_inner())
67 .filter(|value| !value.is_empty())
68 .unwrap_or_default();
69
70 Self::new(res)
71 }
72
73 fn new(inner: Self::Inner) -> Self;
74 fn into_inner(self) -> Self::Inner;
75}
76
77#[derive(Debug, Clone, PartialEq, Deref, DerefMut)]
79pub struct DimensionQuery<T: DeserializeOwned>(pub T);
80
81impl<T> CustomQuery for DimensionQuery<T>
82where
83 T: DeserializeOwned,
84{
85 type Inner = T;
86
87 fn regex_pattern() -> &'static str {
88 r"dimension\[(?<query_name>.*)\]"
89 }
90
91 fn capture_group() -> &'static str {
92 "query_name"
93 }
94
95 fn into_inner(self) -> T {
96 self.0
97 }
98
99 fn new(inner: Self::Inner) -> Self {
100 Self(inner)
101 }
102}
103
#[cfg(feature = "server")]
impl<T> actix_web::FromRequest for DimensionQuery<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Parses the request's query string synchronously; any parse failure is
    /// reported as `400 Bad Request`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let parsed = Self::extract_query(req.query_string());
        std::future::ready(parsed.map_err(actix_web::error::ErrorBadRequest))
    }
}
123
124#[derive(Debug, Clone)]
126pub struct Query<T: DeserializeOwned>(pub T);
127
128impl<T> CustomQuery for Query<T>
129where
130 T: DeserializeOwned,
131{
132 type Inner = T;
133
134 fn regex_pattern() -> &'static str {
135 r"^(?<query_name>[^\[\]]+)$"
136 }
137
138 fn capture_group() -> &'static str {
139 "query_name"
140 }
141
142 fn into_inner(self) -> T {
143 self.0
144 }
145
146 fn new(inner: Self::Inner) -> Self {
147 Self(inner)
148 }
149}
150
#[cfg(feature = "server")]
impl<T> actix_web::FromRequest for Query<T>
where
    T: DeserializeOwned,
{
    type Error = actix_web::Error;
    type Future = std::future::Ready<Result<Self, Self::Error>>;

    /// Parses the request's query string synchronously; any parse failure is
    /// reported as `400 Bad Request`.
    fn from_request(
        req: &actix_web::HttpRequest,
        _payload: &mut actix_web::dev::Payload,
    ) -> Self::Future {
        let parsed = Self::extract_query(req.query_string());
        std::future::ready(parsed.map_err(actix_web::error::ErrorBadRequest))
    }
}
170
/// JSON object built from flat query-string pairs.
///
/// Deserialization goes through `HashMap<String, String>` (see the `From`
/// impl): each value is parsed as JSON when possible and otherwise kept as a
/// JSON string. Derefs to the wrapped `serde_json::Map`.
#[derive(Deserialize, Deref, DerefMut, Clone, PartialEq, Default)]
#[cfg_attr(test, derive(Debug))]
#[serde(from = "HashMap<String,String>")]
pub struct QueryMap(Map<String, Value>);
176
177impl IsEmpty for QueryMap {
178 fn is_empty(&self) -> bool {
179 self.0.is_empty()
180 }
181}
182
183impl From<Map<String, Value>> for QueryMap {
184 fn from(value: Map<String, Value>) -> Self {
185 Self(value)
186 }
187}
188
189impl From<HashMap<String, String>> for QueryMap {
190 fn from(value: HashMap<String, String>) -> Self {
191 let value = value
192 .into_iter()
193 .map(|(key, value)| (key, value.parse().unwrap_or(Value::String(value))))
194 .collect::<Map<_, _>>();
195
196 Self(value)
197 }
198}
199
200impl Display for DimensionQuery<QueryMap> {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 let parts = self
203 .clone()
204 .into_inner()
205 .iter()
206 .map(|(key, value)| format!("dimension[{key}]={value}"))
207 .collect::<Vec<_>>();
208
209 write!(f, "{}", parts.join("&"))
210 }
211}
212
213impl From<Map<String, Value>> for DimensionQuery<QueryMap> {
214 fn from(value: Map<String, Value>) -> Self {
215 Self(QueryMap::from(value))
216 }
217}
218
219impl Default for DimensionQuery<QueryMap> {
220 fn default() -> Self {
221 Self::from(Map::new())
222 }
223}
224
/// Pagination options read from a query string.
///
/// The `Deserialize` impl rejects non-positive `count`/`page` values and
/// rejects any `count`/`page` when `all` is `true`.
#[derive(Debug, Clone, PartialEq, IsEmpty)]
pub struct PaginationParams {
    /// Requested number of entries (presumably per page); `Default` uses 10.
    pub count: Option<i64>,
    /// Requested page; `Default` and `reset_page` use 1 as the first page.
    pub page: Option<i64>,
    /// `Some(true)` asks for every entry instead of a page (see `all_entries`).
    pub all: Option<bool>,
}
231
232impl PaginationParams {
233 pub fn all_entries() -> Self {
234 Self {
235 count: None,
236 page: None,
237 all: Some(true),
238 }
239 }
240
241 pub fn reset_page(&mut self) {
242 self.page = if let Some(true) = self.all {
243 None
244 } else {
245 Some(1)
246 };
247 }
248}
249
250impl Default for PaginationParams {
251 fn default() -> Self {
252 Self {
253 count: Some(10),
254 page: Some(1),
255 all: None,
256 }
257 }
258}
259
260impl Display for PaginationParams {
261 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262 let mut parts = vec![];
263
264 if let Some(page) = self.page {
265 parts.push(format!("page={}", page));
266 }
267
268 if let Some(count) = self.count {
269 parts.push(format!("count={}", count));
270 }
271
272 if let Some(all) = self.all {
273 parts.push(format!("all={}", all));
274 }
275
276 write!(f, "{}", parts.join("&"))
277 }
278}
279
280impl<'de> Deserialize<'de> for PaginationParams {
281 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
282 where
283 D: Deserializer<'de>,
284 {
285 #[derive(Deserialize)]
286 struct Helper {
287 count: Option<i64>,
288 page: Option<i64>,
289 all: Option<bool>,
290 }
291
292 let helper = Helper::deserialize(deserializer)?;
293
294 if helper.all == Some(true) && (helper.count.is_some() || helper.page.is_some()) {
295 return Err(de::Error::custom("When 'all' is true, 'count' and 'page' parameters should not be provided"));
296 }
297
298 if let Some(count) = helper.count {
299 if count <= 0 {
300 return Err(de::Error::custom("Count should be greater than 0."));
301 }
302 }
303
304 if let Some(page) = helper.page {
305 if page <= 0 {
306 return Err(de::Error::custom("Page should be greater than 0."));
307 }
308 }
309
310 Ok(Self {
311 count: helper.count,
312 page: helper.page,
313 all: helper.all,
314 })
315 }
316}
317
/// A list of values carried in one comma-separated query parameter
/// (for example `a,b,c`).
///
/// `#[deref(forward)]` forwards `Deref` to the inner `Vec`'s target, so the
/// wrapper derefs to `[T]`.
#[derive(Debug, Clone, Deref, PartialEq, Default)]
#[deref(forward)]
pub struct CommaSeparatedQParams<T: Display + FromStr>(pub Vec<T>);
321
322impl<T: Display + FromStr> Display for CommaSeparatedQParams<T> {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 let str_arr = self.0.iter().map(|s| s.to_string()).collect::<Vec<_>>();
325 write!(f, "{}", str_arr.join(","))
326 }
327}
328
329impl<'de, T: Display + FromStr> Deserialize<'de> for CommaSeparatedQParams<T> {
330 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
331 where
332 D: Deserializer<'de>,
333 {
334 let items = String::deserialize(deserializer)?
335 .split(',')
336 .map(|item| item.trim().to_string())
337 .map(|s| T::from_str(&s))
338 .collect::<Result<Vec<_>, _>>()
339 .map_err(|_| {
340 serde::de::Error::custom(String::from("Error in converting type"))
341 })?;
342 Ok(Self(items))
343 }
344}
345
346impl<T: Display + FromStr> Serialize for CommaSeparatedQParams<T> {
347 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
348 where
349 S: serde::Serializer,
350 {
351 serializer.serialize_str(&self.to_string())
352 }
353}
354
/// A comma-separated query parameter whose items are kept as plain strings.
pub type CommaSeparatedStringQParams = CommaSeparatedQParams<String>;
356
#[cfg(feature = "experimentation")]
impl Default for CommaSeparatedQParams<ExperimentStatusType> {
    /// Defaults to every `ExperimentStatusType` variant, in `strum`'s
    /// `IntoEnumIterator` order.
    fn default() -> Self {
        let every_status: Vec<ExperimentStatusType> = ExperimentStatusType::iter().collect();
        CommaSeparatedQParams(every_status)
    }
}
363
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde_json::{json, Value};

    use crate::custom_query::QueryMap;

    #[test]
    fn querymap_from_hashmap() {
        // Raw query values as they arrive after URL decoding.
        let raw: HashMap<String, String> = [
            ("key1", "123"),
            ("key2", "\"123\""),
            ("key3", "test"),
            ("key4", "true"),
            ("key5", "\"true\""),
            ("key6", "null"),
            ("key7", "\"null\""),
        ]
        .iter()
        .map(|&(key, value)| (key.to_owned(), value.to_owned()))
        .collect();

        // Valid JSON is parsed into its typed value; anything else stays a string.
        let expected = match json!({
            "key1": 123,
            "key2": "123",
            "key3": "test",
            "key4": true,
            "key5": "true",
            "key6": null,
            "key7": "null"
        }) {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other}"),
        };

        assert_eq!(QueryMap::from(raw), QueryMap(expected));
    }
}