walker_common/sender/
mod.rs

1//! Send data off to a remote API
2
3pub mod provider;
4
5mod error;
6
7pub use error::*;
8
9use crate::sender::provider::{TokenInjector, TokenProvider};
10use anyhow::Context;
11use reqwest::{header, IntoUrl, Method, RequestBuilder};
12use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
13
14#[derive(Clone)]
15pub struct HttpSender {
16    client: reqwest::Client,
17    provider: Arc<dyn TokenProvider>,
18    query_parameters: HashMap<String, String>,
19}
20
/// Options for the [`HttpSender`].
///
/// Create with [`HttpSenderOptions::new`] (or [`Default`]) and refine with the builder-style
/// methods, each of which consumes and returns `self`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct HttpSenderOptions {
    /// Timeout for establishing a connection; `None` leaves the client default untouched.
    pub connect_timeout: Option<Duration>,
    /// Timeout for whole requests; `None` leaves the client default untouched.
    pub timeout: Option<Duration>,
    /// Paths of PEM certificates to trust in addition to the default root certificates.
    pub additional_root_certificates: Vec<PathBuf>,
    /// Disable TLS certificate and hostname validation (insecure).
    pub tls_insecure: bool,
    /// Query parameters added to every request.
    pub query_parameters: HashMap<String, String>,
}

impl HttpSenderOptions {
    /// Create a new set of options with every value at its default.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the connect timeout. Accepts a [`Duration`], or `None` to clear it.
    pub fn connect_timeout(mut self, connect_timeout: impl Into<Option<Duration>>) -> Self {
        self.connect_timeout = connect_timeout.into();
        self
    }

    /// Set the request timeout. Accepts a [`Duration`], or `None` to clear it.
    ///
    /// This only affects [`HttpSenderOptions::timeout`]; the connect timeout is left as is.
    pub fn timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
        self.timeout = timeout.into();
        self
    }

    /// Replace all query parameters with the given key/value pairs.
    pub fn query_parameters<I>(mut self, query_parameters: I) -> Self
    where
        I: IntoIterator<Item = (String, String)>,
    {
        self.query_parameters = HashMap::from_iter(query_parameters);
        self
    }

    /// Add a single query parameter, replacing any existing value for the same key.
    pub fn add_query_parameters(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.query_parameters.insert(key.into(), value.into());
        self
    }

    /// Add the given key/value pairs to the existing query parameters.
    ///
    /// Keys that already exist have their value replaced.
    pub fn extend_query_parameters<I>(mut self, query_parameters: I) -> Self
    where
        I: IntoIterator<Item = (String, String)>,
    {
        self.query_parameters.extend(query_parameters);
        self
    }

    /// Replace all additional root certificate paths.
    pub fn additional_root_certificates<I>(mut self, additional_root_certificates: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        self.additional_root_certificates = Vec::from_iter(additional_root_certificates);
        self
    }

    /// Append a single additional root certificate path.
    pub fn add_additional_root_certificate(
        mut self,
        additional_root_certificate: impl Into<PathBuf>,
    ) -> Self {
        self.additional_root_certificates
            .push(additional_root_certificate.into());
        self
    }

    /// Append several additional root certificate paths.
    pub fn extend_additional_root_certificate<I>(mut self, additional_root_certificates: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        self.additional_root_certificates
            .extend(additional_root_certificates);
        self
    }

    /// Enable or disable insecure TLS (no certificate or hostname validation).
    pub fn tls_insecure(mut self, tls_insecure: bool) -> Self {
        self.tls_insecure = tls_insecure;
        self
    }
}
106
/// Value of the `User-Agent` default header set on the client built by [`HttpSender::new`]:
/// `CSAF-Walker/` followed by the crate version (`CARGO_PKG_VERSION`) at compile time.
const USER_AGENT: &str = concat!("CSAF-Walker/", env!("CARGO_PKG_VERSION"));
108
109impl HttpSender {
110    pub async fn new<P>(provider: P, options: HttpSenderOptions) -> Result<Self, anyhow::Error>
111    where
112        P: TokenProvider + 'static,
113    {
114        let mut headers = header::HeaderMap::new();
115        headers.insert("User-Agent", header::HeaderValue::from_static(USER_AGENT));
116
117        let mut client = reqwest::ClientBuilder::new().default_headers(headers);
118
119        if let Some(connect_timeout) = options.connect_timeout {
120            client = client.connect_timeout(connect_timeout);
121        }
122
123        if let Some(timeout) = options.timeout {
124            client = client.timeout(timeout);
125        }
126
127        for cert in options.additional_root_certificates {
128            let cert = std::fs::read(&cert)
129                .with_context(|| format!("Reading certificate: {}", cert.display()))?;
130            let cert = reqwest::tls::Certificate::from_pem(&cert)?;
131            client = client.add_root_certificate(cert);
132        }
133
134        if options.tls_insecure {
135            log::warn!("Disabling TLS validation");
136            client = client
137                .danger_accept_invalid_hostnames(true)
138                .danger_accept_invalid_certs(true);
139        }
140
141        Ok(Self {
142            client: client.build()?,
143            provider: Arc::new(provider),
144            query_parameters: options.query_parameters,
145        })
146    }
147
148    /// build a new request, injecting the token
149    pub async fn request<U: IntoUrl>(
150        &self,
151        method: Method,
152        url: U,
153    ) -> Result<RequestBuilder, Error> {
154        self.client
155            .request(method, url)
156            .query(
157                &self
158                    .query_parameters
159                    .iter()
160                    .map(|(key, value)| (key.clone(), value.clone()))
161                    .collect::<Vec<(String, String)>>(),
162            )
163            .inject_token(&self.provider)
164            .await
165    }
166}