Skip to main content

rustauth_oauth/oauth2/
http.rs

1use reqwest::{Client, Response};
2use serde_json::Value;
3use std::sync::OnceLock;
4use std::time::Duration;
5
6use super::error::{oauth_error_description, OAuthError};
7use super::request::OAuthFormRequest;
8use super::ssrf::{ssrf_guarded_client_builder, url_host_is_blocked_ip};
9
10const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
11const DEFAULT_USER_AGENT: &str = concat!("rustauth-oauth/", env!("CARGO_PKG_VERSION"));
12const SENSITIVE_OAUTH_FIELDS: &[&str] = &[
13    "access_token",
14    "refresh_token",
15    "id_token",
16    "client_secret",
17    "client_assertion",
18    "subject_token",
19    "device_code",
20    "code",
21    "token",
22    "authorization",
23];
24
25#[derive(Debug, Clone)]
26pub struct OAuthHttpClient {
27    client: Client,
28    /// When `false`, requests whose URL host is a literal private/internal IP
29    /// are rejected at the request boundary, closing the SSRF gap that the
30    /// custom DNS resolver cannot see (reqwest does not resolve literal IPs).
31    allow_private_ips: bool,
32}
33
/// Construction options consumed by [`OAuthHttpClient::from_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthHttpClientConfig {
    /// Timeout passed to the reqwest client builder. It must be non-zero:
    /// `from_config` rejects a zero duration.
    pub timeout: Duration,
    /// Value for the `User-Agent` header. `None` skips setting one on the
    /// builder.
    pub user_agent: Option<String>,
    /// Defaults to `false`: the client then blocks requests that resolve to
    /// private, loopback, or otherwise non-public IP addresses, to mitigate
    /// SSRF. Set to `true` only for deployments that intentionally talk to
    /// internal identity providers.
    pub allow_private_ips: bool,
}
44
45impl Default for OAuthHttpClientConfig {
46    fn default() -> Self {
47        Self {
48            timeout: DEFAULT_TIMEOUT,
49            user_agent: Some(DEFAULT_USER_AGENT.to_owned()),
50            allow_private_ips: false,
51        }
52    }
53}
54
55impl OAuthHttpClient {
56    /// Wraps a caller-supplied `reqwest::Client`.
57    ///
58    /// Injected clients are treated as explicitly permissive (no
59    /// request-boundary IP guard) because the caller owns the client's SSRF
60    /// policy; this keeps custom clients usable for tests and intentionally
61    /// internal deployments. Use [`OAuthHttpClient::from_config`] to obtain a
62    /// guarded client.
63    pub fn new(client: Client) -> Self {
64        Self {
65            client,
66            allow_private_ips: true,
67        }
68    }
69
70    /// Returns the underlying `reqwest::Client`.
71    ///
72    /// Useful for callers that must issue requests outside the OAuth form-post
73    /// helpers (for example OIDC discovery, JWKS, or userinfo fetches) while
74    /// sharing the same SSRF guard, timeout, and connection pool.
75    pub fn reqwest_client(&self) -> &Client {
76        &self.client
77    }
78
79    pub fn default_client() -> Result<Self, OAuthError> {
80        Self::from_config(OAuthHttpClientConfig::default())
81    }
82
83    pub fn from_config(config: OAuthHttpClientConfig) -> Result<Self, OAuthError> {
84        if config.timeout.is_zero() {
85            return Err(OAuthError::InvalidConfiguration(
86                "HTTP timeout must be greater than zero".to_owned(),
87            ));
88        }
89        let mut builder = if config.allow_private_ips {
90            Client::builder()
91        } else {
92            ssrf_guarded_client_builder()
93        }
94        .timeout(config.timeout);
95        if let Some(user_agent) = config.user_agent {
96            builder = builder.user_agent(user_agent);
97        }
98        builder
99            .build()
100            .map(|client| Self {
101                client,
102                allow_private_ips: config.allow_private_ips,
103            })
104            .map_err(Into::into)
105    }
106
107    /// Rejects request URLs whose host is a literal blocked IP unless this
108    /// client is explicitly permissive. `reqwest` connects to literal-IP URLs
109    /// without consulting the SSRF DNS guard, so this closes that gap.
110    fn ensure_request_url_allowed(&self, url: &str) -> Result<(), OAuthError> {
111        if !self.allow_private_ips && url_host_is_blocked_ip(url) {
112            return Err(OAuthError::BlockedRequestUrl);
113        }
114        Ok(())
115    }
116
117    pub async fn get_bytes(&self, url: &str) -> Result<Vec<u8>, OAuthError> {
118        self.get_bytes_with_headers(url, &[]).await
119    }
120
121    pub async fn get_bytes_with_headers(
122        &self,
123        url: &str,
124        headers: &[(&str, &str)],
125    ) -> Result<Vec<u8>, OAuthError> {
126        self.ensure_request_url_allowed(url)?;
127        let mut builder = self.client.get(url).header("accept", "application/json");
128        for (key, value) in headers {
129            builder = builder.header(*key, *value);
130        }
131        let response = builder.send().await?;
132        response_bytes(response).await
133    }
134
135    pub async fn post_form(
136        &self,
137        token_endpoint: &str,
138        request: OAuthFormRequest,
139    ) -> Result<Value, OAuthError> {
140        self.ensure_request_url_allowed(token_endpoint)?;
141        let mut builder = self.client.post(token_endpoint);
142        for (key, value) in &request.headers {
143            builder = builder.header(key, value);
144        }
145        let response = builder.body(request.to_form_urlencoded()).send().await?;
146        response_json(response).await
147    }
148}
149
150pub fn default_http_client() -> Result<OAuthHttpClient, OAuthError> {
151    static CLIENT: OnceLock<Result<OAuthHttpClient, String>> = OnceLock::new();
152
153    CLIENT
154        .get_or_init(|| OAuthHttpClient::default_client().map_err(|error| error.to_string()))
155        .clone()
156        .map_err(OAuthError::InvalidConfiguration)
157}
158
159async fn response_bytes(response: Response) -> Result<Vec<u8>, OAuthError> {
160    let status = response.status();
161    let bytes = response.bytes().await?;
162    if status.is_success() {
163        return Ok(bytes.to_vec());
164    }
165    Err(http_status_error(status.as_u16(), &bytes))
166}
167
168async fn response_json(response: Response) -> Result<Value, OAuthError> {
169    let status = response.status();
170    let bytes = response.bytes().await?;
171    let value = serde_json::from_slice::<Value>(&bytes);
172    if status.is_success() {
173        return value.map_err(|error| OAuthError::InvalidResponse(error.to_string()));
174    }
175    if let Ok(value) = value {
176        if let Some(error) = value.get("error").and_then(Value::as_str) {
177            return Err(OAuthError::ErrorResponse {
178                error: error.to_owned(),
179                description: oauth_error_description(redact_error_description(
180                    value.get("error_description").and_then(Value::as_str),
181                )),
182                uri: value
183                    .get("error_uri")
184                    .and_then(Value::as_str)
185                    .map(str::to_owned),
186            });
187        }
188    }
189    Err(http_status_error(status.as_u16(), &bytes))
190}
191
192fn http_status_error(status: u16, body: &[u8]) -> OAuthError {
193    OAuthError::HttpStatus {
194        status,
195        body: redact_body(&String::from_utf8_lossy(body)),
196    }
197}
198
199fn redact_body(body: &str) -> String {
200    if let Ok(mut value) = serde_json::from_str::<Value>(body) {
201        redact_json_value(&mut value);
202        return value.to_string();
203    }
204
205    let lower = body.to_ascii_lowercase();
206    if SENSITIVE_OAUTH_FIELDS.iter().any(|key| lower.contains(key))
207        || lower.contains("bearer ")
208        || lower.contains("basic ")
209    {
210        return "<redacted OAuth response body>".to_owned();
211    }
212    body.to_owned()
213}
214
215fn redact_json_value(value: &mut Value) {
216    match value {
217        Value::Object(object) => {
218            for (key, value) in object {
219                if SENSITIVE_OAUTH_FIELDS
220                    .iter()
221                    .any(|sensitive| key.eq_ignore_ascii_case(sensitive))
222                {
223                    *value = Value::String("<redacted>".to_owned());
224                } else {
225                    redact_json_value(value);
226                }
227            }
228        }
229        Value::Array(values) => {
230            for value in values {
231                redact_json_value(value);
232            }
233        }
234        _ => {}
235    }
236}
237
/// Sanitizes an OAuth `error_description` before it is surfaced.
///
/// Returns `None` for `None`. Otherwise, if the lowercased description
/// contains any credential-related marker below, returns the fixed
/// placeholder `"<redacted error_description>"`; if not, returns the
/// description unchanged.
fn redact_error_description(description: Option<&str>) -> Option<String> {
    // Narrower than SENSITIVE_OAUTH_FIELDS: bare "code" and "token" are not
    // listed here, presumably so that ordinary messages such as
    // "invalid token" survive — confirm intent before widening.
    const MARKERS: [&str; 10] = [
        "access_token",
        "refresh_token",
        "id_token",
        "client_secret",
        "client_assertion",
        "subject_token",
        "device_code",
        "authorization",
        "bearer ",
        "basic ",
    ];

    description.map(|text| {
        let lowered = text.to_ascii_lowercase();
        let leaks_credential = MARKERS.iter().any(|marker| lowered.contains(marker));
        if leaks_credential {
            String::from("<redacted error_description>")
        } else {
            text.to_owned()
        }
    })
}