walker_common/sender/
mod.rs

1//! Send data off to a remote API
2
3pub mod provider;
4
5mod error;
6pub use error::*;
7
8use crate::{
9    USER_AGENT,
10    sender::provider::{TokenInjector, TokenProvider},
11};
12use anyhow::Context;
13use reqwest::{IntoUrl, Method, RequestBuilder, header};
14use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
15
16#[derive(Clone)]
17pub struct HttpSender {
18    client: reqwest::Client,
19    provider: Arc<dyn TokenProvider>,
20    query_parameters: HashMap<String, String>,
21}
22
/// Options for the [`HttpSender`].
///
/// Start from [`HttpSenderOptions::new`] (or `Default`) and refine the value using the
/// builder-style methods, each of which consumes and returns the options.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct HttpSenderOptions {
    /// Timeout for establishing a connection. When `None`, no connect timeout is configured
    /// on the client.
    pub connect_timeout: Option<Duration>,
    /// Overall request timeout. When `None`, no request timeout is configured on the client.
    pub timeout: Option<Duration>,
    /// Paths to certificate files (PEM) which are added as root certificates.
    pub additional_root_certificates: Vec<PathBuf>,
    /// When `true`, TLS hostname and certificate validation is disabled.
    pub tls_insecure: bool,
    /// Query parameters added to every request.
    pub query_parameters: HashMap<String, String>,
}

impl HttpSenderOptions {
    /// Create a new set of options, using the defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set (or, with `None`, clear) the timeout for establishing a connection.
    pub fn connect_timeout(mut self, connect_timeout: impl Into<Option<Duration>>) -> Self {
        self.connect_timeout = connect_timeout.into();
        self
    }

    /// Set (or, with `None`, clear) the overall request timeout.
    ///
    /// This only affects [`Self::timeout`]; the connect timeout is left untouched.
    pub fn timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
        self.timeout = timeout.into();
        self
    }

    /// Replace all query parameters with the provided ones.
    pub fn query_parameters<I>(mut self, query_parameters: I) -> Self
    where
        I: IntoIterator<Item = (String, String)>,
    {
        self.query_parameters = query_parameters.into_iter().collect();
        self
    }

    /// Add a single query parameter, replacing an existing entry with the same key.
    ///
    /// NOTE(review): the type parameter `I` is not used by any argument and cannot be inferred,
    /// so callers have to name it explicitly (e.g. `::<Vec<(String, String)>>`). It is kept to
    /// stay source compatible with existing callers; consider removing it with the next
    /// breaking release.
    pub fn add_query_parameters<I>(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self
    where
        I: IntoIterator<Item = (String, String)>,
    {
        self.query_parameters.insert(key.into(), value.into());
        self
    }

    /// Add the provided query parameters to the existing ones, replacing entries with the
    /// same key.
    pub fn extend_query_parameters<I>(mut self, query_parameters: I) -> Self
    where
        I: IntoIterator<Item = (String, String)>,
    {
        self.query_parameters.extend(query_parameters);
        self
    }

    /// Replace all additional root certificates with the provided paths.
    pub fn additional_root_certificates<I>(mut self, additional_root_certificates: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        self.additional_root_certificates = additional_root_certificates.into_iter().collect();
        self
    }

    /// Append a single additional root certificate path.
    pub fn add_additional_root_certificate(
        mut self,
        additional_root_certificate: impl Into<PathBuf>,
    ) -> Self {
        self.additional_root_certificates
            .push(additional_root_certificate.into());
        self
    }

    /// Append the provided paths to the existing additional root certificates.
    pub fn extend_additional_root_certificate<I>(mut self, additional_root_certificates: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        self.additional_root_certificates
            .extend(additional_root_certificates);
        self
    }

    /// Enable or disable "insecure" TLS, which turns off hostname and certificate validation.
    pub fn tls_insecure(mut self, tls_insecure: bool) -> Self {
        self.tls_insecure = tls_insecure;
        self
    }
}
108
109impl HttpSender {
110    pub async fn new<P>(provider: P, options: HttpSenderOptions) -> Result<Self, anyhow::Error>
111    where
112        P: TokenProvider + 'static,
113    {
114        let mut headers = header::HeaderMap::new();
115        headers.insert("User-Agent", header::HeaderValue::from_static(USER_AGENT));
116
117        let mut client = reqwest::ClientBuilder::new().default_headers(headers);
118
119        if let Some(connect_timeout) = options.connect_timeout {
120            client = client.connect_timeout(connect_timeout);
121        }
122
123        if let Some(timeout) = options.timeout {
124            client = client.timeout(timeout);
125        }
126
127        for cert in options.additional_root_certificates {
128            let cert = std::fs::read(&cert)
129                .with_context(|| format!("Reading certificate: {}", cert.display()))?;
130            let cert = reqwest::tls::Certificate::from_pem(&cert)?;
131            client = client.add_root_certificate(cert);
132        }
133
134        if options.tls_insecure {
135            log::warn!("Disabling TLS validation");
136            client = client
137                .danger_accept_invalid_hostnames(true)
138                .danger_accept_invalid_certs(true);
139        }
140
141        Ok(Self {
142            client: client.build()?,
143            provider: Arc::new(provider),
144            query_parameters: options.query_parameters,
145        })
146    }
147
148    /// build a new request, injecting the token
149    pub async fn request<U: IntoUrl>(
150        &self,
151        method: Method,
152        url: U,
153    ) -> Result<RequestBuilder, Error> {
154        self.client
155            .request(method, url)
156            .query(
157                &self
158                    .query_parameters
159                    .iter()
160                    .map(|(key, value)| (key.clone(), value.clone()))
161                    .collect::<Vec<(String, String)>>(),
162            )
163            .inject_token(&self.provider)
164            .await
165    }
166}