tiny_oidc_rp/
client.rs

1// SPDX-License-Identifier: MIT
2use crate::error::AuthenticationFailedError;
3use crate::{Error, IdToken, Provider};
4
/// Value of the OpenID Connect `response_mode` authorization request parameter.
///
/// It selects how the provider hands the authorization response back to the
/// relying party.
///
/// See: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html
#[derive(Debug, Clone)]
pub enum OidcResponseMode {
    /// Default for the "code" flow.
    ///
    /// The authorization code comes back as a query parameter of an HTTP GET.
    Query,
    /// Alternate mode: the authorization code comes back in the form body of
    /// an HTTP POST.
    ///
    /// `form_post` lowers the risk of the authorization code leaking through
    /// the `Referer` HTTP header or HTTP server logs, but keep in mind that a
    /// SameSite session cookie will not be sent along with that POST.
    FormPost,
    /// Intended for single page Web apps.
    ///
    /// The authorization code comes back in the URL fragment of an HTTP GET,
    /// so it is not sent to the server directly.
    Fragment,
}
25
26// response_mode as &str
27impl std::ops::Deref for OidcResponseMode {
28    type Target = str;
29    fn deref(&self) -> &str {
30        match self {
31            Self::Query => "query",
32            Self::FormPost => "form_post",
33            Self::Fragment => "fragment",
34        }
35    }
36}
37
/// Value of the OpenID Connect `prompt` authorization request parameter.
#[derive(Debug, Clone)]
pub enum OidcPrompt {
    /// `prompt=none`.
    ///
    /// Named `NoPrompt` rather than `None` to avoid confusion with
    /// `Option::None`.
    NoPrompt,
    /// `prompt=login`.
    Login,
    /// `prompt=consent`.
    Consent,
    /// `prompt=select_account`.
    SelectAccount,
}
46
47// prompt as &str
48impl std::ops::Deref for OidcPrompt {
49    type Target = str;
50    fn deref(&self) -> &str {
51        match self {
52            Self::NoPrompt => "none",
53            Self::Login => "login",
54            Self::Consent => "consent",
55            Self::SelectAccount => "select_account",
56        }
57    }
58}
59
60/// OpenID Connect relying party client
61#[derive(Clone, Debug)]
62pub struct Client<P: Provider> {
63    client_id: String,
64    client_secret: String,
65    redirect_uri: String,
66    response_mode: OidcResponseMode,
67    provider: P,
68}
69
70impl<P: Provider> Client<P> {
71    /// Create authn URL with query parameter
72    ///
73    /// If you request the user to force re-login, set prompt=Some(Login)
74    pub fn auth_url(&self, session: &Session, prompt: Option<OidcPrompt>) -> url::Url {
75        // append queries to authorize endpoint
76        let mut authurl = self.provider.authorization_endpoint();
77        authurl
78            .query_pairs_mut()
79            .append_pair("scope", "openid profile email")
80            .append_pair("response_type", "code")
81            .append_pair("client_id", &self.client_id)
82            .append_pair("nonce", &session.nonce())
83            .append_pair("state", &session.state())
84            .append_pair("response_mode", &self.response_mode)
85            .append_pair("redirect_uri", &self.redirect_uri)
86            .append_pair("code_challenge_method", "S256")
87            .append_pair("code_challenge", &session.pkce_challenge());
88
89        if let Some(prompt) = prompt {
90            authurl.query_pairs_mut().append_pair("prompt", &prompt);
91        }
92
93        authurl
94    }
95
96    /// Authenticate user with `state`, `code`
97    ///
98    /// `state`, `code` are retrived from HTTP query parameters or form body.
99    /// `session` is retrived from HTTP cookie.
100    ///
101    /// If you need decoding extra claims in ID token,
102    /// specify your own Deserialized type as T.
103    /// Otherwise, set T as ()
104    pub async fn authenticate<T>(
105        &self,
106        state: &str,
107        code: &str,
108        session: &Session,
109    ) -> Result<IdToken<T>, Error>
110    where
111        T: serde::de::DeserializeOwned,
112    {
113        // Check state mismatch (possible CSRF)
114        if state != session.state() {
115            log::warn!("state mismatch");
116            return Err(Error::BadRequest);
117        }
118
119        // Prepare token endpoint request
120        let code_verifier = session.pkce_verifier();
121        let params = vec![
122            ("grant_type", "authorization_code"),
123            ("code", code),
124            ("client_id", &self.client_id),
125            ("client_secret", &self.client_secret),
126            ("redirect_uri", &self.redirect_uri),
127            ("code_verifier", &code_verifier),
128        ];
129
130        // Send POST request to token endpoint
131        let response = reqwest::Client::new()
132            .post(self.provider.token_endpoint().clone())
133            .form(&params)
134            .send()
135            .await?;
136
137        if let Err(err) = response.error_for_status_ref() {
138            // Error, log body
139            let err_body = response.text().await?;
140            log::warn!("Token endpoint returns error {}", err_body);
141
142            Err(err.into())
143        } else {
144            // Ok, decode body as JSON
145            let token_response = response.json::<OidcTokenEndpointResponse>().await?;
146            log::debug!("Token endpoint returns {:?}", token_response);
147
148            // Decode ID Token string.
149            //   Skip JWS signature validation here,
150            //   because code flow can trust issuer by TLS server certificate validation
151            let id_token = IdToken::<T>::decode_without_jws_validation(&token_response.id_token)?;
152
153            self.validate_claims(&id_token, session)?;
154            Ok(id_token)
155        }
156    }
157
158    /// Validate ID token claims
159    /// See also [OpenID connect spec 3.1.3.7. ID Token Validation](https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
160    fn validate_claims<T>(
161        &self,
162        id_token: &IdToken<T>,
163        session: &Session,
164    ) -> Result<(), AuthenticationFailedError> {
165        use std::time::SystemTime;
166
167        if !self.provider.validate_iss(&id_token.iss) {
168            log::info!("Invalid iss {}", id_token.iss);
169            return Err(AuthenticationFailedError::ClaimValidationError);
170        }
171
172        if id_token.aud != self.client_id {
173            log::info!("Invalid aud {}", id_token.aud);
174            return Err(AuthenticationFailedError::ClaimValidationError);
175        }
176
177        if &id_token.nonce != &session.nonce() {
178            log::info!("Invalid nonce {}", id_token.nonce);
179            return Err(AuthenticationFailedError::ClaimValidationError);
180        }
181
182        let now = SystemTime::now()
183            .duration_since(SystemTime::UNIX_EPOCH)
184            .map_or(0, |t| t.as_secs());
185        if id_token.iat > now + 60 || now > id_token.exp {
186            // token expired
187            log::info!(
188                "Invalid iat {} or exp {} : now = {}",
189                id_token.iat,
190                id_token.exp,
191                now
192            );
193            return Err(AuthenticationFailedError::ClaimValidationError);
194        }
195
196        Ok(())
197    }
198}
199
200/// Setup Client
201pub struct ClientBuilder<P: Provider> {
202    client_id: Option<String>,
203    client_secret: Option<String>,
204    redirect_uri: Option<String>,
205    response_mode: OidcResponseMode,
206    provider: P,
207}
208
209impl<P: Provider> ClientBuilder<P> {
210    /// Client builder from OpenID connect Provider
211    pub(crate) fn from_provider(provider: P) -> Self {
212        Self {
213            client_id: None,
214            client_secret: None,
215            redirect_uri: None,
216            response_mode: OidcResponseMode::Query,
217            provider,
218        }
219    }
220
221    /// Build OpenID connect Client
222    pub fn build(self) -> Option<Client<P>> {
223        match self {
224            Self {
225                client_id: Some(client_id),
226                client_secret: Some(client_secret),
227                redirect_uri: Some(redirect_uri),
228                response_mode,
229                provider,
230            } => Some(Client {
231                client_id,
232                client_secret,
233                redirect_uri,
234                response_mode,
235                provider,
236            }),
237            _ => {
238                // Some elements are not initialized.
239                None
240            }
241        }
242    }
243
244    /// Client ID
245    pub fn client_id(self, client_id: &str) -> Self {
246        let mut builder = self;
247        builder.client_id = Some(client_id.to_string());
248        builder
249    }
250
251    /// Client secret
252    pub fn client_secret(self, client_secret: &str) -> Self {
253        let mut builder = self;
254        builder.client_secret = Some(client_secret.to_string());
255        builder
256    }
257
258    /// Redirect URI
259    pub fn redirect_uri(self, redirect_uri: &str) -> Self {
260        let mut builder = self;
261        builder.redirect_uri = Some(redirect_uri.to_string());
262        builder
263    }
264
265    /// Response mode
266    pub fn response_mode(self, response_mode: OidcResponseMode) -> Self {
267        let mut builder = self;
268        builder.response_mode = response_mode;
269        builder
270    }
271}
272
/// A single OpenID Connect login session.
pub struct Session {
    /// Backing bytes of the session, split into four 36-byte regions:
    ///
    /// * `0..36`    - key
    /// * `36..72`   - state
    /// * `72..108`  - nonce
    /// * `108..144` - pkce_verifier
    rand_bytes: [u8; 144],
}
278
279impl Session {
280    /// Start new OpenID connect session
281    pub fn new_session() -> Result<Session, crate::Error> {
282        // Make random bytes
283        let mut rand_bytes = [0u8; 144];
284        getrandom::fill(&mut rand_bytes).map_err(|e| {
285            log::error!("getrandom() failed with {:?}", e);
286            crate::Error::InternalError
287        })?;
288        Ok(Session { rand_bytes })
289    }
290
291    /// Serialize session and returns (key, value) pair.
292    /// Implementer should store `key` in browser session cookie or local storage,
293    /// and store `(key,value)` pair in server side database.
294    /// Both `key` and `value` is URL safe string
295    pub fn save_session(&self) -> (String, String) {
296        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
297        return (self.key(), URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..]));
298    }
299
300    /// Deserialize session saved by `save_session()`
301    /// Implementer should get session key from cookie,
302    /// and load session_value from server side database.
303    pub fn load_session(
304        session_key: &str,
305        session_value: &str,
306    ) -> Result<Self, base64::DecodeSliceError> {
307        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
308        let mut rand_bytes = [0u8; 144];
309
310        // Decode key & value
311        URL_SAFE_NO_PAD.decode_slice(session_key, &mut rand_bytes[..36])?;
312        URL_SAFE_NO_PAD.decode_slice(session_value, &mut rand_bytes[36..])?;
313
314        Ok(Self { rand_bytes })
315    }
316
317    /// Base64Url(key) -> 48 chars
318    pub fn key(&self) -> String {
319        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
320        URL_SAFE_NO_PAD.encode(&self.rand_bytes[..36])
321    }
322
323    /// Base64Url(state) -> 48 chars
324    fn state(&self) -> String {
325        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
326        URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..72])
327    }
328
329    /// Base64Url(nonce) -> 48 chars
330    fn nonce(&self) -> String {
331        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
332        URL_SAFE_NO_PAD.encode(&self.rand_bytes[72..108])
333    }
334
335    /// PKCE code_challenge in Base64 string
336    fn pkce_challenge(&self) -> String {
337        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
338        use sha2::{Digest, Sha256};
339
340        // PKCE code_challenge=Base64Url(SHA256(pkce_verifier))
341        let challenge_byte = Sha256::digest(&self.pkce_verifier().as_bytes());
342
343        URL_SAFE_NO_PAD.encode(&challenge_byte)
344    }
345
346    /// PKCE code_verifier in Base64 string
347    fn pkce_verifier(&self) -> String {
348        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
349        // code_verifier = Base64Url(pkce_verifier)
350        URL_SAFE_NO_PAD.encode(&self.rand_bytes[108..144])
351    }
352}
353
354/// Response body JSON from token endpoint
355#[derive(Debug, serde::Deserialize)]
356struct OidcTokenEndpointResponse {
357    // access_token: Option<String>,
358    id_token: String,
359}