1use std::{collections::HashMap, path::Path, str::FromStr};
2
3use anyhow::Result;
4use http::{header::HeaderName, HeaderMap, HeaderValue, Method};
5use reqwest::{Client, Response};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use tokio::fs;
9use url::Url;
10
11use crate::{KeyVal, KeyValType};
12
/// `User-Agent` header value used by `RequestContext::send` when the
/// context does not set its own `user_agent`.
const USER_AGENT: &str = "Requester/0.1.0";
14
/// A set of named request profiles, keyed by profile name.
///
/// The map is `#[serde(flatten)]`ed, so in the YAML file read by
/// `RequestConfig::try_load` every top-level key is a profile name whose
/// value is a `RequestContext`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct RequestConfig {
    // Profile name -> request definition.
    #[serde(flatten)]
    ctxs: HashMap<String, RequestContext>,
}
20
/// Everything needed to build and send a single HTTP request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct RequestContext {
    /// HTTP method. Falls back to `Method::default()` when absent on
    /// deserialization, and is omitted when serializing if it equals it.
    #[serde(
        with = "http_serde::method",
        skip_serializing_if = "is_default",
        default
    )]
    pub method: Method,
    /// Target URL. When the encoded `params` are non-empty, `send` replaces
    /// this URL's query string with them.
    pub url: Url,
    /// Query parameters as a JSON value, encoded with `serde_qs` by `send`.
    /// Defaults to an empty object when absent and is omitted from output
    /// when `null` or an empty object. `RequestConfig::try_load` rejects
    /// non-object values.
    #[serde(skip_serializing_if = "is_empty_value", default = "default_params")]
    pub params: Value,
    /// Extra request headers; omitted from output when empty.
    #[serde(skip_serializing_if = "HeaderMap::is_empty", default)]
    #[serde(with = "http_serde::header_map")]
    pub headers: HeaderMap,
    /// Optional request body, sent as JSON by `send`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub body: Option<Value>,
    /// Overrides the `User-Agent` header; the crate default `USER_AGENT` is
    /// used when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub user_agent: Option<String>,
}
40
/// Reports whether `value` is equal to its type's `Default` value.
///
/// Serves as a serde `skip_serializing_if` predicate so that default-valued
/// fields are left out of serialized output.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let baseline = T::default();
    *value == baseline
}
44
45fn is_empty_value(v: &Value) -> bool {
46 v.is_null() || (v.is_object() && v.as_object().unwrap().is_empty())
47}
48
49fn default_params() -> Value {
50 serde_json::json!({})
51}
52
53impl RequestConfig {
54 pub fn new_with_profile(profile: String, ctx: RequestContext) -> Self {
55 let mut ctxs = HashMap::new();
56 ctxs.insert(profile, ctx);
57 Self { ctxs }
58 }
59
60 pub async fn try_load(path: impl AsRef<Path>) -> Result<Self> {
61 let file = fs::read_to_string(path).await?;
62 let config: Self = serde_yaml::from_str(&file)?;
63 for (profile, ctx) in config.ctxs.iter() {
64 if !ctx.params.is_object() {
65 return Err(anyhow::anyhow!(
66 "params must be an object in profile: {}",
67 profile
68 ));
69 }
70 }
71 Ok(config)
72 }
73
74 pub fn get(&self, profile: &str) -> Result<&RequestContext> {
75 self.ctxs.get(profile).ok_or_else(|| {
76 anyhow::anyhow!(
77 "profile {} not found. Available profiles: {:?}.",
78 profile,
79 self.ctxs.keys()
80 )
81 })
82 }
83
84 pub async fn send(&self, profile: &str) -> Result<Response> {
85 let ctx = self.get(profile)?;
86
87 ctx.send().await
88 }
89}
90
91impl RequestContext {
92 pub fn update(&mut self, values: &[KeyVal]) -> Result<()> {
93 for v in values {
94 match v.kv_type {
95 KeyValType::Query => {
96 self.params[&v.key] = serde_json::Value::String(v.val.to_owned());
97 }
98 KeyValType::Header => {
99 self.headers.insert(
100 HeaderName::from_str(&v.key)?,
101 HeaderValue::from_str(&v.val)?,
102 );
103 }
104 KeyValType::Body => {
105 if let Some(body) = self.body.as_mut() {
106 body[&v.key] = serde_json::Value::String(v.val.to_owned())
107 }
108 }
109 }
110 }
111
112 Ok(())
113 }
114
115 pub async fn send(&self) -> Result<Response> {
116 let mut url = self.url.clone();
117 let user_agent = self
118 .user_agent
119 .clone()
120 .unwrap_or_else(|| USER_AGENT.to_string());
121 match url.scheme() {
122 "http" | "https" => {
123 let qs = serde_qs::to_string(&self.params)?;
124 if !qs.is_empty() {
125 url.set_query(Some(&qs));
126 }
127 let client = Client::builder().user_agent(user_agent).build()?;
128
129 let mut builder = client
130 .request(self.method.clone(), url)
131 .headers(self.headers.clone());
132
133 if let Some(body) = &self.body {
134 match self.headers.get(http::header::CONTENT_TYPE) {
135 Some(content_type) => {
136 if content_type.to_str().unwrap().contains("application/json") {
137 builder = builder.json(body);
138 } else {
139 return Err(anyhow::anyhow!(
140 "unsupported content-type: {:?}",
141 content_type
142 ));
143 }
144 }
145 None => {
146 builder = builder.json(body)
148 }
149 }
150 builder = builder.body(serde_json::to_string(body)?);
151 }
152
153 let res = builder.send().await?;
154
155 Ok(res)
156 }
157 _ => Err(anyhow::anyhow!("unsupported scheme")),
158 }
159 }
160}
161
162impl FromStr for RequestContext {
163 type Err = anyhow::Error;
164
165 fn from_str(url: &str) -> std::result::Result<Self, Self::Err> {
166 let mut url = Url::parse(url)?;
167 let qs = url.query_pairs();
168 let mut params = serde_json::Value::Object(Default::default());
169 for (k, v) in qs {
170 let v = serde_json::Value::String(v.to_string());
171 match params.get_mut(&*k) {
172 Some(val) => {
173 if val.is_string() {
174 params[&*k] = serde_json::Value::Array(vec![val.clone(), v]);
175 } else if val.is_array() {
176 val.as_array_mut().unwrap().push(v);
177 } else {
178 panic!("unexpected value: {:?}", val);
179 }
180 }
181 None => {
182 params[&*k] = v;
183 }
184 }
185 }
186
187 url.set_query(None);
188 Ok(RequestContext {
189 method: Method::GET,
190 url,
191 params,
192 headers: HeaderMap::new(),
193 body: None,
194 user_agent: None,
195 })
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check: load the fixture config, send its `rust` profile and
    /// expect HTTP 200. Reads `fixtures/req.yml` and performs a real HTTP
    /// request to whatever URL that profile names.
    #[tokio::test]
    async fn send_request_should_work() {
        let response = RequestConfig::try_load("fixtures/req.yml")
            .await
            .unwrap()
            .send("rust")
            .await
            .unwrap();
        assert_eq!(response.status(), 200);
    }
}