1pub mod client;
3
4mod config;
5mod error;
6mod ops;
7mod state;
8
9pub use client::*;
10pub use error::*;
11pub use ops::*;
12pub use state::LoginState;
13
14pub(crate) use config::*;
15
16use crate::context::{Authentication, OAuth2Context, Reason};
17use gloo_storage::{SessionStorage, Storage};
18use gloo_timers::callback::Timeout;
19use gloo_utils::{history, window};
20use js_sys::Date;
21use log::error;
22use num_traits::cast::ToPrimitive;
23use reqwest::Url;
24use state::*;
25use std::{cmp::min, collections::HashMap, fmt::Debug, time::Duration};
26use tokio::sync::mpsc::{Receiver, Sender, channel};
27use wasm_bindgen::JsValue;
28use wasm_bindgen_futures::spawn_local;
29use yew::Callback;
30
31#[derive(Debug, Clone, Default)]
76#[non_exhaustive]
77pub struct LoginOptions {
78 pub query: HashMap<String, String>,
80
81 pub redirect_url: Option<Url>,
85
86 pub post_login_redirect_callback: Option<Callback<String>>,
94}
95
96impl LoginOptions {
97 pub fn new() -> Self {
98 LoginOptions::default()
99 }
100
101 pub fn with_query(mut self, query: impl IntoIterator<Item = (String, String)>) -> Self {
103 self.query = HashMap::from_iter(query);
104 self
105 }
106
107 pub fn extend_query(mut self, query: impl IntoIterator<Item = (String, String)>) -> Self {
109 self.query.extend(query);
110 self
111 }
112
113 pub fn add_query(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
115 self.query.insert(key.into(), value.into());
116 self
117 }
118
119 pub fn with_redirect_url(mut self, redirect_url: impl Into<Url>) -> Self {
121 self.redirect_url = Some(redirect_url.into());
122 self
123 }
124
125 pub fn with_redirect_callback(mut self, redirect_callback: Callback<String>) -> Self {
127 self.post_login_redirect_callback = Some(redirect_callback);
128 self
129 }
130
131 #[cfg(feature = "yew-nested-router")]
133 pub fn with_nested_router_redirect(mut self) -> Self {
134 let callback = Callback::from(|url: String| {
135 if yew_nested_router::History::push_state(JsValue::null(), &url).is_err() {
136 error!("Unable to redirect");
137 }
138 });
139
140 self.post_login_redirect_callback = Some(callback);
141 self
142 }
143}
144
145#[non_exhaustive]
149#[derive(Debug, Clone, Default, PartialEq, Eq)]
150pub struct LogoutOptions {
151 pub target: Option<Url>,
155}
156
157impl LogoutOptions {
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 pub fn with_target(mut self, target: impl Into<Url>) -> Self {
163 self.target = Some(target.into());
164 self
165 }
166}
167
168#[doc(hidden)]
169pub enum Msg<C>
170where
171 C: Client,
172{
173 Configure(AgentConfiguration<C>),
174 StartLogin(Option<LoginOptions>),
175 Logout(Option<LogoutOptions>),
176 Refresh,
177}
178
179#[derive(Clone, Debug)]
181pub struct Agent<C>
182where
183 C: Client,
184{
185 tx: Sender<Msg<C>>,
186}
187
188impl<C> Agent<C>
189where
190 C: Client,
191{
192 pub fn new<F>(state_callback: F) -> Self
193 where
194 F: Fn(OAuth2Context) + 'static,
195 {
196 let (tx, rx) = channel(128);
197
198 let inner = InnerAgent::new(tx.clone(), state_callback);
199 inner.spawn(rx);
200
201 Self { tx }
202 }
203}
204
205#[doc(hidden)]
206pub struct InnerAgent<C>
207where
208 C: Client,
209{
210 tx: Sender<Msg<C>>,
211 state_callback: Callback<OAuth2Context>,
212 config: Option<InnerConfig>,
213 client: Option<C>,
214 state: OAuth2Context,
215 session_state: Option<C::SessionState>,
216 timeout: Option<Timeout>,
217}
218
219#[doc(hidden)]
220#[derive(Clone, Debug)]
221pub struct InnerConfig {
222 scopes: Vec<String>,
223 grace_period: Duration,
224 max_expiration: Option<Duration>,
225 audience: Option<String>,
226 default_login_options: Option<LoginOptions>,
227 default_logout_options: Option<LogoutOptions>,
228}
229
230impl<C> InnerAgent<C>
231where
232 C: Client,
233{
234 pub fn new<F>(tx: Sender<Msg<C>>, state_callback: F) -> Self
235 where
236 F: Fn(OAuth2Context) + 'static,
237 {
238 Self {
239 tx,
240 state_callback: Callback::from(state_callback),
241 client: None,
242 config: None,
243 state: OAuth2Context::NotInitialized,
244 session_state: None,
245 timeout: None,
246 }
247 }
248
249 fn spawn(self, rx: Receiver<Msg<C>>) {
250 spawn_local(async move {
251 self.run(rx).await;
252 })
253 }
254
255 async fn run(mut self, mut rx: Receiver<Msg<C>>) {
256 loop {
257 match rx.recv().await {
258 Some(msg) => self.process(msg).await,
259 None => {
260 log::debug!("Agent channel closed");
261 break;
262 }
263 }
264 }
265 }
266
267 async fn process(&mut self, msg: Msg<C>) {
268 match msg {
269 Msg::Configure(config) => self.configure(config).await,
270 Msg::StartLogin(login) => {
271 if let Err(err) = self.start_login(login) {
272 log::info!("Failed to start login: {err}");
274 }
275 }
276 Msg::Logout(logout) => self.logout_opts(logout),
277 Msg::Refresh => self.refresh().await,
278 }
279 }
280
281 fn update_state(&mut self, state: OAuth2Context, session_state: Option<C::SessionState>) {
282 log::debug!("update state: {state:?}");
283
284 if let OAuth2Context::Authenticated(Authentication {
285 expires: Some(expires),
286 ..
287 }) = &state
288 {
289 let grace = self
290 .config
291 .as_ref()
292 .map(|c| c.grace_period)
293 .unwrap_or_default();
294
295 let mut expires = *expires;
296 if let Some(max) = self.config.as_ref().and_then(|cfg| cfg.max_expiration) {
297 expires = min(expires, max.as_secs());
299 }
300
301 let now = Date::now() / 1000f64;
303 let diff = expires as f64 - now - grace.as_secs_f64();
305
306 let tx = self.tx.clone();
307 if diff > 0f64 {
308 let millis = (diff * 1000f64).to_i32().unwrap_or(i32::MAX);
310 log::debug!("Starting timeout for: {millis}ms",);
311 self.timeout = Some(Timeout::new(millis as u32, move || {
312 let _ = tx.try_send(Msg::Refresh);
313 }));
314 } else {
315 let _ = tx.try_send(Msg::Refresh);
317 }
318 } else {
319 self.timeout = None;
320 }
321
322 self.notify_state(state.clone());
323
324 self.state = state;
325 self.session_state = session_state;
326 }
327
328 fn notify_state(&self, state: OAuth2Context) {
329 self.state_callback.emit(state);
330 }
331
332 async fn configured(&mut self, outcome: Result<(C, InnerConfig), OAuth2Error>) {
334 match outcome {
335 Ok((client, config)) => {
336 log::debug!("Client created");
337
338 self.client = Some(client);
339 self.config = Some(config);
340
341 if matches!(self.state, OAuth2Context::NotInitialized) {
342 let detected = self.detect_state().await;
343 log::debug!("Detected state: {detected:?}");
344 match detected {
345 Ok(true) => {
346 if let Err(e) = self.post_login_redirect() {
347 error!("Post-login redirect failed: {e}");
348 }
349 }
350 Ok(false) => {
351 self.update_state(
352 OAuth2Context::NotAuthenticated {
353 reason: Reason::NewSession,
354 },
355 None,
356 );
357 }
358 Err(err) => {
359 self.update_state(err.into(), None);
360 }
361 }
362 }
363 }
364 Err(err) => {
365 log::debug!("Failed to configure client: {err}");
366 if matches!(self.state, OAuth2Context::NotInitialized) {
367 self.update_state(err.into(), None);
368 }
369 }
370 }
371 }
372
373 async fn make_client(config: AgentConfiguration<C>) -> Result<(C, InnerConfig), OAuth2Error> {
374 let AgentConfiguration {
375 config,
376 scopes,
377 grace_period,
378 audience,
379 default_login_options,
380 default_logout_options,
381 max_expiration,
382 } = config;
383
384 let client = C::from_config(config).await?;
385
386 let inner = InnerConfig {
387 scopes,
388 grace_period,
389 audience,
390 default_login_options,
391 default_logout_options,
392 max_expiration,
393 };
394
395 Ok((client, inner))
396 }
397
398 async fn detect_state(&mut self) -> Result<bool, OAuth2Error> {
403 let client = self.client.as_ref().ok_or(OAuth2Error::NotInitialized)?;
404
405 let state = if let Some(state) = Self::find_query_state() {
406 state
407 } else {
408 return Ok(false);
410 };
411
412 log::debug!("Found state: {state:?}",);
413
414 if let Some(error) = state.error {
415 log::info!("Login error from server: {error}");
416
417 Self::cleanup_url();
419
420 return Err(OAuth2Error::LoginResult(error));
422 }
423
424 if let Some(code) = state.code {
425 Self::cleanup_url();
427
428 match state.state {
429 None => {
430 return Err(OAuth2Error::LoginResult(
431 "Missing state from server".to_string(),
432 ));
433 }
434 Some(state) => {
435 let stored_state = get_from_store(STORAGE_KEY_CSRF_TOKEN)?;
436
437 if state != stored_state {
438 return Err(OAuth2Error::LoginResult("State mismatch".to_string()));
439 }
440 }
441 }
442
443 let state: C::LoginState =
444 SessionStorage::get(STORAGE_KEY_LOGIN_STATE).map_err(|err| {
445 OAuth2Error::Storage(format!("Failed to load login state: {err}"))
446 })?;
447
448 log::debug!("Login state: {state:?}");
449
450 let redirect_url = get_from_store(STORAGE_KEY_REDIRECT_URL)?;
451 log::debug!("Redirect URL: {redirect_url}");
452 let redirect_url = Url::parse(&redirect_url).map_err(|err| {
453 OAuth2Error::LoginResult(format!("Failed to parse redirect URL: {err}"))
454 })?;
455
456 let client = client.clone().set_redirect_uri(redirect_url);
457
458 let result = client.exchange_code(code, state).await;
459 self.update_state_from_result(result);
460
461 Ok(true)
462 } else {
463 log::debug!("Neither an error nor a code. Continue without applying state.");
464 Ok(false)
465 }
466 }
467
468 fn post_login_redirect(&self) -> Result<(), OAuth2Error> {
469 let config = self.config.as_ref().ok_or(OAuth2Error::NotInitialized)?;
470 let Some(redirect_callback) = config
471 .default_login_options
472 .as_ref()
473 .and_then(|opts| opts.post_login_redirect_callback.clone())
474 else {
475 return Ok(());
476 };
477 let Some(url) = get_from_store_optional(STORAGE_KEY_POST_LOGIN_URL)? else {
478 return Ok(());
479 };
480 SessionStorage::delete(STORAGE_KEY_POST_LOGIN_URL);
481 redirect_callback.emit(url);
482
483 Ok(())
484 }
485
486 fn update_state_from_result(
487 &mut self,
488 result: Result<(OAuth2Context, C::SessionState), OAuth2Error>,
489 ) {
490 match result {
491 Ok((state, session_state)) => {
492 self.update_state(state, Some(session_state));
493 }
494 Err(err) => {
495 self.update_state(err.into(), None);
496 }
497 }
498 }
499
500 async fn refresh(&mut self) {
501 let (client, session_state) =
502 if let (Some(client), Some(session_state)) = (&self.client, &self.session_state) {
503 (client.clone(), session_state.clone())
504 } else {
505 self.update_state(
507 OAuth2Context::NotAuthenticated {
508 reason: Reason::Expired,
509 },
510 None,
511 );
512 return;
513 };
514
515 if let OAuth2Context::Authenticated(Authentication {
516 refresh_token: Some(refresh_token),
517 ..
518 }) = &self.state
519 {
520 log::debug!("Triggering refresh");
521
522 let result = client
523 .exchange_refresh_token(refresh_token.clone(), session_state)
524 .await;
525
526 if let Err(err) = &result {
527 log::warn!("Failed to refresh token: {err}");
528 }
529
530 self.update_state_from_result(result);
531 }
532 }
533
534 fn find_query_state() -> Option<State> {
536 if let Ok(url) = Self::current_url() {
537 let query: HashMap<_, _> = url.query_pairs().collect();
538
539 Some(State {
540 code: query.get("code").map(ToString::to_string),
541 state: query.get("state").map(ToString::to_string),
542 error: query.get("error").map(ToString::to_string),
543 })
544 } else {
545 None
546 }
547 }
548
549 fn current_url() -> Result<Url, String> {
550 let href = window().location().href().map_err(|err| {
551 err.as_string()
552 .unwrap_or_else(|| "unable to get current location".to_string())
553 })?;
554 Url::parse(&href).map_err(|err| err.to_string())
555 }
556
557 fn cleanup_url() {
558 if let Ok(mut url) = Self::current_url() {
559 url.set_query(None);
560 let state = history().state().unwrap_or(JsValue::NULL);
561 history()
562 .replace_state_with_url(&state, "", Some(url.as_str()))
563 .ok();
564 }
565 }
566
567 async fn configure(&mut self, config: AgentConfiguration<C>) {
568 self.configured(Self::make_client(config).await).await;
569 }
570
571 fn start_login(&mut self, options: Option<LoginOptions>) -> Result<(), OAuth2Error> {
572 let client = self.client.as_ref().ok_or(OAuth2Error::NotInitialized)?;
573 let config = self.config.as_ref().ok_or(OAuth2Error::NotInitialized)?;
574
575 let options =
576 options.unwrap_or_else(|| config.default_login_options.clone().unwrap_or_default());
577
578 let current_url = Self::current_url().map_err(OAuth2Error::StartLogin)?;
579
580 let redirect_url = options
582 .redirect_url
583 .or_else(|| {
584 config
585 .default_login_options
586 .as_ref()
587 .and_then(|opts| opts.redirect_url.clone())
588 })
589 .unwrap_or_else(|| current_url.clone());
590
591 if redirect_url != current_url {
592 SessionStorage::set(STORAGE_KEY_POST_LOGIN_URL, current_url)
593 .map_err(|err| OAuth2Error::StartLogin(err.to_string()))?;
594 }
595
596 let login_context = client.make_login_context(config, redirect_url.clone())?;
597
598 SessionStorage::set(STORAGE_KEY_CSRF_TOKEN, login_context.csrf_token)
599 .map_err(|err| OAuth2Error::StartLogin(err.to_string()))?;
600
601 SessionStorage::set(STORAGE_KEY_LOGIN_STATE, login_context.state)
602 .map_err(|err| OAuth2Error::StartLogin(err.to_string()))?;
603
604 SessionStorage::set(STORAGE_KEY_REDIRECT_URL, redirect_url)
605 .map_err(|err| OAuth2Error::StartLogin(err.to_string()))?;
606
607 let mut login_url = login_context.url;
608
609 login_url.query_pairs_mut().extend_pairs(options.query);
610
611 window()
614 .location()
615 .set_href(login_url.as_str())
616 .map_err(|err| {
617 OAuth2Error::StartLogin(
618 err.as_string()
619 .unwrap_or_else(|| "Unable to navigate to login page".to_string()),
620 )
621 })?;
622
623 Ok(())
624 }
625
626 fn logout_opts(&mut self, options: Option<LogoutOptions>) {
627 if let Some(client) = &self.client {
628 if let Some(session_state) = self.session_state.clone() {
629 log::debug!("Notify client of logout");
632 let options = options
633 .or_else(|| {
634 self.config
635 .as_ref()
636 .and_then(|config| config.default_logout_options.clone())
637 })
638 .unwrap_or_default();
639 client.logout(session_state, options);
640 }
641 }
642
643 self.update_state(
647 OAuth2Context::NotAuthenticated {
648 reason: Reason::Logout,
649 },
650 None,
651 );
652 }
653}
654
655impl<C> OAuth2Operations<C> for Agent<C>
656where
657 C: Client,
658{
659 fn configure(&self, config: AgentConfiguration<C>) -> Result<(), Error> {
660 self.tx
661 .try_send(Msg::Configure(config))
662 .map_err(|_| Error::NoAgent)
663 }
664
665 fn start_login(&self) -> Result<(), Error> {
666 self.tx
667 .try_send(Msg::StartLogin(None))
668 .map_err(|_| Error::NoAgent)
669 }
670
671 fn start_login_opts(&self, options: LoginOptions) -> Result<(), Error> {
672 self.tx
673 .try_send(Msg::StartLogin(Some(options)))
674 .map_err(|_| Error::NoAgent)
675 }
676
677 fn logout(&self) -> Result<(), Error> {
678 self.tx
679 .try_send(Msg::Logout(None))
680 .map_err(|_| Error::NoAgent)
681 }
682
683 fn logout_opts(&self, options: LogoutOptions) -> Result<(), Error> {
684 self.tx
685 .try_send(Msg::Logout(Some(options)))
686 .map_err(|_| Error::NoAgent)
687 }
688}