1use crate::error::AuthenticationFailedError;
3use crate::{Error, IdToken, Provider};
4
/// Value of the OAuth 2.0 `response_mode` authorization-request parameter.
///
/// The enum dereferences to the exact string sent on the wire, so it can be
/// passed anywhere a `&str` is expected.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OidcResponseMode {
    /// Sent as `query`.
    Query,
    /// Sent as `form_post` (OAuth 2.0 Form Post Response Mode).
    FormPost,
    /// Sent as `fragment`.
    Fragment,
}

impl std::ops::Deref for OidcResponseMode {
    type Target = str;

    /// Returns the wire value of this response mode.
    fn deref(&self) -> &str {
        match self {
            Self::Query => "query",
            Self::FormPost => "form_post",
            Self::Fragment => "fragment",
        }
    }
}
37
/// Value of the OpenID Connect `prompt` authorization-request parameter
/// (OIDC Core 1.0 §3.1.2.1).
///
/// The enum dereferences to the exact string sent on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OidcPrompt {
    /// Sent as `none`.
    NoPrompt,
    /// Sent as `login`.
    Login,
    /// Sent as `consent`.
    Consent,
    /// Sent as `select_account`.
    SelectAccount,
}

impl std::ops::Deref for OidcPrompt {
    type Target = str;

    /// Returns the wire value of this prompt.
    fn deref(&self) -> &str {
        match self {
            Self::NoPrompt => "none",
            Self::Login => "login",
            Self::Consent => "consent",
            Self::SelectAccount => "select_account",
        }
    }
}
59
60#[derive(Clone, Debug)]
62pub struct Client<P: Provider> {
63 client_id: String,
64 client_secret: String,
65 redirect_uri: String,
66 response_mode: OidcResponseMode,
67 provider: P,
68}
69
70impl<P: Provider> Client<P> {
71 pub fn auth_url(&self, session: &Session, prompt: Option<OidcPrompt>) -> url::Url {
75 let mut authurl = self.provider.authorization_endpoint();
77 authurl
78 .query_pairs_mut()
79 .append_pair("scope", "openid profile email")
80 .append_pair("response_type", "code")
81 .append_pair("client_id", &self.client_id)
82 .append_pair("nonce", &session.nonce())
83 .append_pair("state", &session.state())
84 .append_pair("response_mode", &self.response_mode)
85 .append_pair("redirect_uri", &self.redirect_uri)
86 .append_pair("code_challenge_method", "S256")
87 .append_pair("code_challenge", &session.pkce_challenge());
88
89 if let Some(prompt) = prompt {
90 authurl.query_pairs_mut().append_pair("prompt", &prompt);
91 }
92
93 authurl
94 }
95
96 pub async fn authenticate<T>(
105 &self,
106 state: &str,
107 code: &str,
108 session: &Session,
109 ) -> Result<IdToken<T>, Error>
110 where
111 T: serde::de::DeserializeOwned,
112 {
113 if state != session.state() {
115 log::warn!("state mismatch");
116 return Err(Error::BadRequest);
117 }
118
119 let code_verifier = session.pkce_verifier();
121 let params = vec![
122 ("grant_type", "authorization_code"),
123 ("code", code),
124 ("client_id", &self.client_id),
125 ("client_secret", &self.client_secret),
126 ("redirect_uri", &self.redirect_uri),
127 ("code_verifier", &code_verifier),
128 ];
129
130 let response = reqwest::Client::new()
132 .post(self.provider.token_endpoint().clone())
133 .form(¶ms)
134 .send()
135 .await?;
136
137 if let Err(err) = response.error_for_status_ref() {
138 let err_body = response.text().await?;
140 log::warn!("Token endpoint returns error {}", err_body);
141
142 Err(err.into())
143 } else {
144 let token_response = response.json::<OidcTokenEndpointResponse>().await?;
146 log::debug!("Token endpoint returns {:?}", token_response);
147
148 let id_token = IdToken::<T>::decode_without_jws_validation(&token_response.id_token)?;
152
153 self.validate_claims(&id_token, session)?;
154 Ok(id_token)
155 }
156 }
157
158 fn validate_claims<T>(
161 &self,
162 id_token: &IdToken<T>,
163 session: &Session,
164 ) -> Result<(), AuthenticationFailedError> {
165 use std::time::SystemTime;
166
167 if !self.provider.validate_iss(&id_token.iss) {
168 log::info!("Invalid iss {}", id_token.iss);
169 return Err(AuthenticationFailedError::ClaimValidationError);
170 }
171
172 if id_token.aud != self.client_id {
173 log::info!("Invalid aud {}", id_token.aud);
174 return Err(AuthenticationFailedError::ClaimValidationError);
175 }
176
177 if &id_token.nonce != &session.nonce() {
178 log::info!("Invalid nonce {}", id_token.nonce);
179 return Err(AuthenticationFailedError::ClaimValidationError);
180 }
181
182 let now = SystemTime::now()
183 .duration_since(SystemTime::UNIX_EPOCH)
184 .map_or(0, |t| t.as_secs());
185 if id_token.iat > now + 60 || now > id_token.exp {
186 log::info!(
188 "Invalid iat {} or exp {} : now = {}",
189 id_token.iat,
190 id_token.exp,
191 now
192 );
193 return Err(AuthenticationFailedError::ClaimValidationError);
194 }
195
196 Ok(())
197 }
198}
199
200pub struct ClientBuilder<P: Provider> {
202 client_id: Option<String>,
203 client_secret: Option<String>,
204 redirect_uri: Option<String>,
205 response_mode: OidcResponseMode,
206 provider: P,
207}
208
209impl<P: Provider> ClientBuilder<P> {
210 pub(crate) fn from_provider(provider: P) -> Self {
212 Self {
213 client_id: None,
214 client_secret: None,
215 redirect_uri: None,
216 response_mode: OidcResponseMode::Query,
217 provider,
218 }
219 }
220
221 pub fn build(self) -> Option<Client<P>> {
223 match self {
224 Self {
225 client_id: Some(client_id),
226 client_secret: Some(client_secret),
227 redirect_uri: Some(redirect_uri),
228 response_mode,
229 provider,
230 } => Some(Client {
231 client_id,
232 client_secret,
233 redirect_uri,
234 response_mode,
235 provider,
236 }),
237 _ => {
238 None
240 }
241 }
242 }
243
244 pub fn client_id(self, client_id: &str) -> Self {
246 let mut builder = self;
247 builder.client_id = Some(client_id.to_string());
248 builder
249 }
250
251 pub fn client_secret(self, client_secret: &str) -> Self {
253 let mut builder = self;
254 builder.client_secret = Some(client_secret.to_string());
255 builder
256 }
257
258 pub fn redirect_uri(self, redirect_uri: &str) -> Self {
260 let mut builder = self;
261 builder.redirect_uri = Some(redirect_uri.to_string());
262 builder
263 }
264
265 pub fn response_mode(self, response_mode: OidcResponseMode) -> Self {
267 let mut builder = self;
268 builder.response_mode = response_mode;
269 builder
270 }
271}
272
/// Random secrets backing a single login attempt.
///
/// `rand_bytes` is four consecutive 36-byte parts: the session key, the OAuth
/// `state`, the OIDC `nonce`, and the PKCE code verifier.
pub struct Session {
    rand_bytes: [u8; 144],
}
278
279impl Session {
280 pub fn new_session() -> Result<Session, crate::Error> {
282 let mut rand_bytes = [0u8; 144];
284 getrandom::fill(&mut rand_bytes).map_err(|e| {
285 log::error!("getrandom() failed with {:?}", e);
286 crate::Error::InternalError
287 })?;
288 Ok(Session { rand_bytes })
289 }
290
291 pub fn save_session(&self) -> (String, String) {
296 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
297 return (self.key(), URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..]));
298 }
299
300 pub fn load_session(
304 session_key: &str,
305 session_value: &str,
306 ) -> Result<Self, base64::DecodeSliceError> {
307 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
308 let mut rand_bytes = [0u8; 144];
309
310 URL_SAFE_NO_PAD.decode_slice(session_key, &mut rand_bytes[..36])?;
312 URL_SAFE_NO_PAD.decode_slice(session_value, &mut rand_bytes[36..])?;
313
314 Ok(Self { rand_bytes })
315 }
316
317 pub fn key(&self) -> String {
319 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
320 URL_SAFE_NO_PAD.encode(&self.rand_bytes[..36])
321 }
322
323 fn state(&self) -> String {
325 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
326 URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..72])
327 }
328
329 fn nonce(&self) -> String {
331 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
332 URL_SAFE_NO_PAD.encode(&self.rand_bytes[72..108])
333 }
334
335 fn pkce_challenge(&self) -> String {
337 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
338 use sha2::{Digest, Sha256};
339
340 let challenge_byte = Sha256::digest(&self.pkce_verifier().as_bytes());
342
343 URL_SAFE_NO_PAD.encode(&challenge_byte)
344 }
345
346 fn pkce_verifier(&self) -> String {
348 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
349 URL_SAFE_NO_PAD.encode(&self.rand_bytes[108..144])
351 }
352}
353
354#[derive(Debug, serde::Deserialize)]
356struct OidcTokenEndpointResponse {
357 id_token: String,
359}