xreq_lib/
req.rs

1use std::{collections::HashMap, path::Path, str::FromStr};
2
3use anyhow::Result;
4use http::{header::HeaderName, HeaderMap, HeaderValue, Method};
5use reqwest::{Client, Response};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tokio::fs;
9use url::Url;
10
11use crate::{KeyVal, KeyValType};
12
/// User agent sent by `RequestContext::send` when the profile does not set `user_agent`.
const USER_AGENT: &str = "Requester/0.1.0";
14
/// A set of named request profiles.
///
/// `ctxs` is `#[serde(flatten)]`ed and is the only field, so every top-level key of the
/// serialized document is a profile name whose value is that profile's `RequestContext`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct RequestConfig {
    /// Profile name -> request context for that profile.
    #[serde(flatten)]
    ctxs: HashMap<String, RequestContext>,
}
20
21#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
22pub struct RequestContext {
23    #[serde(
24        with = "http_serde::method",
25        skip_serializing_if = "is_default",
26        default
27    )]
28    pub method: Method,
29    pub url: Url,
30    #[serde(skip_serializing_if = "is_empty_value", default = "default_params")]
31    pub params: Value,
32    #[serde(skip_serializing_if = "HeaderMap::is_empty", default)]
33    #[serde(with = "http_serde::header_map")]
34    pub headers: HeaderMap,
35    #[serde(skip_serializing_if = "Option::is_none", default)]
36    pub body: Option<Value>,
37    #[serde(skip_serializing_if = "Option::is_none", default)]
38    pub user_agent: Option<String>,
39}
40
/// Serde `skip_serializing_if` predicate: true when `t` equals `T::default()`.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let fallback = T::default();
    t.eq(&fallback)
}
44
45fn is_empty_value(v: &Value) -> bool {
46    v.is_null() || (v.is_object() && v.as_object().unwrap().is_empty())
47}
48
49fn default_params() -> Value {
50    serde_json::json!({})
51}
52
53impl RequestConfig {
54    pub fn new_with_profile(profile: String, ctx: RequestContext) -> Self {
55        let mut ctxs = HashMap::new();
56        ctxs.insert(profile, ctx);
57        Self { ctxs }
58    }
59
60    pub async fn try_load(path: impl AsRef<Path>) -> Result<Self> {
61        let file = fs::read_to_string(path).await?;
62        let config: Self = serde_yaml::from_str(&file)?;
63        for (profile, ctx) in config.ctxs.iter() {
64            if !ctx.params.is_object() {
65                return Err(anyhow::anyhow!(
66                    "params must be an object in profile: {}",
67                    profile
68                ));
69            }
70        }
71        Ok(config)
72    }
73
74    pub fn get(&self, profile: &str) -> Result<&RequestContext> {
75        self.ctxs.get(profile).ok_or_else(|| {
76            anyhow::anyhow!(
77                "profile {} not found. Available profiles: {:?}.",
78                profile,
79                self.ctxs.keys()
80            )
81        })
82    }
83
84    pub async fn send(&self, profile: &str) -> Result<Response> {
85        let ctx = self.get(profile)?;
86
87        ctx.send().await
88    }
89}
90
91impl RequestContext {
92    pub fn update(&mut self, values: &[KeyVal]) -> Result<()> {
93        for v in values {
94            match v.kv_type {
95                KeyValType::Query => {
96                    self.params[&v.key] = serde_json::Value::String(v.val.to_owned());
97                }
98                KeyValType::Header => {
99                    self.headers.insert(
100                        HeaderName::from_str(&v.key)?,
101                        HeaderValue::from_str(&v.val)?,
102                    );
103                }
104                KeyValType::Body => {
105                    if let Some(body) = self.body.as_mut() {
106                        body[&v.key] = serde_json::Value::String(v.val.to_owned())
107                    }
108                }
109            }
110        }
111
112        Ok(())
113    }
114
115    pub async fn send(&self) -> Result<Response> {
116        let mut url = self.url.clone();
117        let user_agent = self
118            .user_agent
119            .clone()
120            .unwrap_or_else(|| USER_AGENT.to_string());
121        match url.scheme() {
122            "http" | "https" => {
123                let qs = serde_qs::to_string(&self.params)?;
124                if !qs.is_empty() {
125                    url.set_query(Some(&qs));
126                }
127                let client = Client::builder().user_agent(user_agent).build()?;
128
129                let mut builder = client
130                    .request(self.method.clone(), url)
131                    .headers(self.headers.clone());
132
133                if let Some(body) = &self.body {
134                    match self.headers.get(http::header::CONTENT_TYPE) {
135                        Some(content_type) => {
136                            if content_type.to_str().unwrap().contains("application/json") {
137                                builder = builder.json(body);
138                            } else {
139                                return Err(anyhow::anyhow!(
140                                    "unsupported content-type: {:?}",
141                                    content_type
142                                ));
143                            }
144                        }
145                        None => {
146                            // TODO (tchen): here we just assume the content-type is json
147                            builder = builder.json(body)
148                        }
149                    }
150                    builder = builder.body(serde_json::to_string(body)?);
151                }
152
153                let res = builder.send().await?;
154
155                Ok(res)
156            }
157            _ => Err(anyhow::anyhow!("unsupported scheme")),
158        }
159    }
160}
161
162impl FromStr for RequestContext {
163    type Err = anyhow::Error;
164
165    fn from_str(url: &str) -> std::result::Result<Self, Self::Err> {
166        let mut url = Url::parse(url)?;
167        let qs = url.query_pairs();
168        let mut params = serde_json::Value::Object(Default::default());
169        for (k, v) in qs {
170            let v = serde_json::Value::String(v.to_string());
171            match params.get_mut(&*k) {
172                Some(val) => {
173                    if val.is_string() {
174                        params[&*k] = serde_json::Value::Array(vec![val.clone(), v]);
175                    } else if val.is_array() {
176                        val.as_array_mut().unwrap().push(v);
177                    } else {
178                        panic!("unexpected value: {:?}", val);
179                    }
180                }
181                None => {
182                    params[&*k] = v;
183                }
184            }
185        }
186
187        url.set_query(None);
188        Ok(RequestContext {
189            method: Method::GET,
190            url,
191            params,
192            headers: HeaderMap::new(),
193            body: None,
194            user_agent: None,
195        })
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    // Loads `fixtures/req.yml` and sends its "rust" profile for real. This
    // presumably needs network access, depending on the fixture's URL.
    #[tokio::test]
    async fn send_request_should_work() {
        let cfg = RequestConfig::try_load("fixtures/req.yml").await.unwrap();
        let response = cfg.send("rust").await.unwrap();
        let status = response.status();
        assert_eq!(status, 200);
    }
}