rustauth_oauth/oauth2/
request.rs1use std::collections::BTreeMap;
2
3use base64::engine::general_purpose::STANDARD;
4use base64::Engine;
5use url::form_urlencoded::{byte_serialize, Serializer};
6
7use super::error::OAuthError;
8use super::http::OAuthHttpClient;
9use super::tokens::{get_primary_client_id, ProviderOptions};
10
/// OAuth parameter names that are treated as protected.
///
/// NOTE(review): presumably caller-supplied extra parameters must not be
/// allowed to override these — confirm at the call sites of
/// `is_protected_oauth_param`. Membership checks are exact string matches.
pub(crate) const PROTECTED_OAUTH_PARAMS: &[&str] = &[
    // Redirect-flow parameters.
    "state",
    "response_type",
    "redirect_uri",
    // Authorization code and PKCE (RFC 7636).
    "code",
    "code_verifier",
    "code_challenge",
    "code_challenge_method",
    // Grant selection and refresh.
    "grant_type",
    "refresh_token",
    // Client identification and authentication.
    "client_id",
    "client_secret",
    "client_key",
    "client_assertion",
    "client_assertion_type",
];
33
34pub(crate) fn is_protected_oauth_param(key: &str) -> bool {
37 PROTECTED_OAUTH_PARAMS.contains(&key)
38}
39
/// How client credentials are presented to the token endpoint.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ClientAuthentication {
    /// Send `client_id` / `client_secret` as form-body fields.
    /// This is the default mode.
    #[default]
    Post,
    /// Send the credentials in an HTTP `authorization: Basic …` header.
    Basic,
}
46
/// A form-encoded request destined for an OAuth endpoint.
///
/// Note: the derived `Default` produces a request with *no* headers, whereas
/// `OAuthFormRequest::new()` pre-populates `content-type` and `accept`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct OAuthFormRequest {
    /// Form fields in insertion order; the same key may appear more than once.
    pub body: Vec<(String, String)>,
    /// HTTP headers keyed by name. The helper methods on this type store and
    /// look up names in lowercase; direct writes to this map are not normalized.
    pub headers: BTreeMap<String, String>,
}
52
53impl OAuthFormRequest {
54 pub fn new() -> Self {
55 Self {
56 body: Vec::new(),
57 headers: BTreeMap::from([
58 (
59 "content-type".to_owned(),
60 "application/x-www-form-urlencoded".to_owned(),
61 ),
62 ("accept".to_owned(), "application/json".to_owned()),
63 ]),
64 }
65 }
66
67 pub(crate) fn push_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
68 self.body.push((key.into(), value.into()));
69 }
70
71 pub(crate) fn set_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
72 let key = key.into();
73 self.body.retain(|(existing, _)| existing != &key);
74 self.body.push((key, value.into()));
75 }
76
77 pub fn has_body(&self, key: &str) -> bool {
78 self.body.iter().any(|(existing, _)| existing == key)
79 }
80
81 pub fn form_value(&self, key: &str) -> Option<&str> {
82 self.body
83 .iter()
84 .find(|(existing, _)| existing == key)
85 .map(|(_, value)| value.as_str())
86 }
87
88 pub fn form_values(&self, key: &str) -> Vec<&str> {
89 self.body
90 .iter()
91 .filter(|(existing, _)| existing == key)
92 .map(|(_, value)| value.as_str())
93 .collect()
94 }
95
96 pub fn header(&self, key: &str) -> Option<&str> {
97 self.headers
98 .get(&key.to_ascii_lowercase())
99 .map(String::as_str)
100 }
101
102 pub(crate) fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
103 self.headers
104 .insert(key.into().to_ascii_lowercase(), value.into());
105 }
106
107 pub fn to_form_urlencoded(&self) -> String {
108 let mut serializer = Serializer::new(String::new());
109 for (key, value) in &self.body {
110 serializer.append_pair(key, value);
111 }
112 serializer.finish()
113 }
114}
115
116pub(crate) fn apply_client_authentication(
117 request: &mut OAuthFormRequest,
118 options: &ProviderOptions,
119 authentication: ClientAuthentication,
120 require_secret: bool,
121) -> Result<(), OAuthError> {
122 let primary_client_id = get_primary_client_id(&options.client_id);
123 let client_secret = non_empty_secret(options);
124
125 match authentication {
126 ClientAuthentication::Basic => {
127 let client_id = primary_client_id.ok_or_else(|| {
128 OAuthError::InvalidClientAuthentication(
129 "HTTP Basic authentication requires client_id".to_owned(),
130 )
131 })?;
132 let client_secret = if require_secret {
133 client_secret.ok_or(OAuthError::MissingOption("client_secret"))?
134 } else {
135 client_secret.unwrap_or("")
136 };
137 let credentials = STANDARD.encode(format!(
142 "{}:{}",
143 form_encode_credential(client_id),
144 form_encode_credential(client_secret)
145 ));
146 request.set_header("authorization", format!("Basic {credentials}"));
147 }
148 ClientAuthentication::Post => {
149 if let Some(client_id) = primary_client_id {
150 request.set_body("client_id", client_id);
151 }
152 if let Some(client_secret) = client_secret {
153 request.set_body("client_secret", client_secret);
154 } else if require_secret {
155 return Err(OAuthError::MissingOption("client_secret"));
156 }
157 }
158 }
159
160 Ok(())
161}
162
163fn non_empty_secret(options: &ProviderOptions) -> Option<&str> {
164 options.client_secret_str()
165}
166
167fn form_encode_credential(value: &str) -> String {
171 byte_serialize(value.as_bytes()).collect()
172}
173
174pub(crate) async fn post_form_with_client(
175 token_endpoint: &str,
176 request: OAuthFormRequest,
177 client: &OAuthHttpClient,
178) -> Result<serde_json::Value, OAuthError> {
179 client.post_form(token_endpoint, request).await
180}