// Source: rustauth_oauth/oauth2/request.rs

1use std::collections::BTreeMap;
2
3use base64::engine::general_purpose::STANDARD;
4use base64::Engine;
5use url::form_urlencoded::{byte_serialize, Serializer};
6
7use super::error::OAuthError;
8use super::http::OAuthHttpClient;
9use super::tokens::{get_primary_client_id, ProviderOptions};
10
/// Parameter names that encode the security invariants of an OAuth flow.
///
/// The entries are grouped by the invariant they protect:
/// - CSRF protection: `state`
/// - flow shape and redirect binding: `response_type`, `redirect_uri`
/// - authorization-code and PKCE binding: `code`, `code_verifier`, `code_challenge`,
///   `code_challenge_method`
/// - token grant: `grant_type`, `refresh_token`
/// - client identification and authentication: `client_id`, `client_secret`, `client_key`,
///   `client_assertion`, `client_assertion_type`
///
/// The generic `additional_params`, `override_params`, and `extra_params` maps must never set
/// or replace any of these keys. Otherwise a provider extension or a caller-controlled value
/// could blank, downgrade, or hijack a request that has already been validated and
/// authenticated.
pub(crate) const PROTECTED_OAUTH_PARAMS: &[&str] = &[
    // CSRF protection.
    "state",
    // Flow shape and redirect binding.
    "response_type",
    "redirect_uri",
    // Authorization code and PKCE binding.
    "code",
    "code_verifier",
    "code_challenge",
    "code_challenge_method",
    // Token grant.
    "grant_type",
    "refresh_token",
    // Client identification and authentication.
    "client_id",
    "client_secret",
    "client_key",
    "client_assertion",
    "client_assertion_type",
];
33
34/// Returns `true` when `key` is a security-critical OAuth parameter that the
35/// generic extension maps are not allowed to set or override.
36pub(crate) fn is_protected_oauth_param(key: &str) -> bool {
37    PROTECTED_OAUTH_PARAMS.contains(&key)
38}
39
/// How the client authenticates itself when calling the token endpoint (RFC 6749 §2.3.1).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ClientAuthentication {
    /// Send the client credentials as `client_id` / `client_secret` form fields in the
    /// request body. This is the default method.
    #[default]
    Post,
    /// Send the client credentials in an HTTP `Authorization: Basic` header.
    Basic,
}
46
/// A form-encoded OAuth request: ordered body fields plus HTTP headers.
///
/// The body is kept as a list rather than a map, so repeated keys and insertion order
/// survive serialization. The derived `Default` yields empty headers. Use
/// `OAuthFormRequest::new` to get the `content-type` and `accept` defaults pre-populated.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct OAuthFormRequest {
    /// Form fields in serialization order; the same key may appear more than once.
    pub body: Vec<(String, String)>,
    /// Header names mapped to their values. Names written through `set_header` are
    /// stored lowercased.
    pub headers: BTreeMap<String, String>,
}
52
53impl OAuthFormRequest {
54    pub fn new() -> Self {
55        Self {
56            body: Vec::new(),
57            headers: BTreeMap::from([
58                (
59                    "content-type".to_owned(),
60                    "application/x-www-form-urlencoded".to_owned(),
61                ),
62                ("accept".to_owned(), "application/json".to_owned()),
63            ]),
64        }
65    }
66
67    pub(crate) fn push_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
68        self.body.push((key.into(), value.into()));
69    }
70
71    pub(crate) fn set_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
72        let key = key.into();
73        self.body.retain(|(existing, _)| existing != &key);
74        self.body.push((key, value.into()));
75    }
76
77    pub fn has_body(&self, key: &str) -> bool {
78        self.body.iter().any(|(existing, _)| existing == key)
79    }
80
81    pub fn form_value(&self, key: &str) -> Option<&str> {
82        self.body
83            .iter()
84            .find(|(existing, _)| existing == key)
85            .map(|(_, value)| value.as_str())
86    }
87
88    pub fn form_values(&self, key: &str) -> Vec<&str> {
89        self.body
90            .iter()
91            .filter(|(existing, _)| existing == key)
92            .map(|(_, value)| value.as_str())
93            .collect()
94    }
95
96    pub fn header(&self, key: &str) -> Option<&str> {
97        self.headers
98            .get(&key.to_ascii_lowercase())
99            .map(String::as_str)
100    }
101
102    pub(crate) fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
103        self.headers
104            .insert(key.into().to_ascii_lowercase(), value.into());
105    }
106
107    pub fn to_form_urlencoded(&self) -> String {
108        let mut serializer = Serializer::new(String::new());
109        for (key, value) in &self.body {
110            serializer.append_pair(key, value);
111        }
112        serializer.finish()
113    }
114}
115
116pub(crate) fn apply_client_authentication(
117    request: &mut OAuthFormRequest,
118    options: &ProviderOptions,
119    authentication: ClientAuthentication,
120    require_secret: bool,
121) -> Result<(), OAuthError> {
122    let primary_client_id = get_primary_client_id(&options.client_id);
123    let client_secret = non_empty_secret(options);
124
125    match authentication {
126        ClientAuthentication::Basic => {
127            let client_id = primary_client_id.ok_or_else(|| {
128                OAuthError::InvalidClientAuthentication(
129                    "HTTP Basic authentication requires client_id".to_owned(),
130                )
131            })?;
132            let client_secret = if require_secret {
133                client_secret.ok_or(OAuthError::MissingOption("client_secret"))?
134            } else {
135                client_secret.unwrap_or("")
136            };
137            // RFC 6749 §2.3.1: the client id and secret are each encoded with the
138            // `application/x-www-form-urlencoded` algorithm before being joined with
139            // `:` and Base64-encoded, so reserved characters (`:`, `+`, `=`, spaces,
140            // non-ASCII bytes, ...) cannot corrupt the decoded Basic credentials.
141            let credentials = STANDARD.encode(format!(
142                "{}:{}",
143                form_encode_credential(client_id),
144                form_encode_credential(client_secret)
145            ));
146            request.set_header("authorization", format!("Basic {credentials}"));
147        }
148        ClientAuthentication::Post => {
149            if let Some(client_id) = primary_client_id {
150                request.set_body("client_id", client_id);
151            }
152            if let Some(client_secret) = client_secret {
153                request.set_body("client_secret", client_secret);
154            } else if require_secret {
155                return Err(OAuthError::MissingOption("client_secret"));
156            }
157        }
158    }
159
160    Ok(())
161}
162
163fn non_empty_secret(options: &ProviderOptions) -> Option<&str> {
164    options.client_secret_str()
165}
166
167/// Encodes a Basic-auth credential component with `application/x-www-form-urlencoded`
168/// rules per RFC 6749 §2.3.1. Unreserved ASCII (including `-`, `_`, `.`, `*`) is left
169/// unchanged, preserving the wire format for simple credentials.
170fn form_encode_credential(value: &str) -> String {
171    byte_serialize(value.as_bytes()).collect()
172}
173
174pub(crate) async fn post_form_with_client(
175    token_endpoint: &str,
176    request: OAuthFormRequest,
177    client: &OAuthHttpClient,
178) -> Result<serde_json::Value, OAuthError> {
179    client.post_form(token_endpoint, request).await
180}