// shield_oidc/method.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use shield::{Action, Method, ShieldError, User, erased_method};
5
6use crate::{
7    actions::{OidcSignInAction, OidcSignInCallbackAction, OidcSignOutAction},
8    options::OidcOptions,
9    provider::OidcProvider,
10    session::OidcSession,
11    storage::OidcStorage,
12};
13
/// Identifier of the OpenID Connect method; this is the value [`Method::id`] returns for
/// [`OidcMethod`].
pub const OIDC_METHOD_ID: &str = "oidc";
15
16pub struct OidcMethod<U: User> {
17    options: OidcOptions,
18    providers: Vec<OidcProvider>,
19    storage: Arc<dyn OidcStorage<U>>,
20}
21
22impl<U: User> OidcMethod<U> {
23    pub fn new<S: OidcStorage<U> + 'static>(storage: S) -> Self {
24        Self {
25            options: OidcOptions::default(),
26            providers: vec![],
27            storage: Arc::new(storage),
28        }
29    }
30
31    pub fn with_options(mut self, options: OidcOptions) -> Self {
32        self.options = options;
33        self
34    }
35
36    pub fn with_providers<I: IntoIterator<Item = OidcProvider>>(mut self, providers: I) -> Self {
37        self.providers = providers.into_iter().collect();
38        self
39    }
40
41    async fn oidc_provider_by_id_or_slug(
42        &self,
43        provider_id: &str,
44    ) -> Result<Option<OidcProvider>, ShieldError> {
45        if let Some(provider) = self
46            .providers
47            .iter()
48            .find(|provider| provider.id == provider_id)
49        {
50            return Ok(Some(provider.clone()));
51        }
52
53        if let Some(provider) = self
54            .storage
55            .oidc_provider_by_id_or_slug(provider_id)
56            .await?
57        {
58            return Ok(Some(provider));
59        }
60
61        Ok(None)
62    }
63}
64
65#[async_trait]
66impl<U: User + 'static> Method for OidcMethod<U> {
67    type Provider = OidcProvider;
68    type Session = OidcSession;
69
70    fn id(&self) -> String {
71        OIDC_METHOD_ID.to_owned()
72    }
73
74    fn actions(&self) -> Vec<Box<dyn Action<Self::Provider, Self::Session>>> {
75        vec![
76            Box::new(OidcSignInAction),
77            Box::new(OidcSignInCallbackAction::new(
78                self.options.clone(),
79                self.storage.clone(),
80            )),
81            Box::new(OidcSignOutAction),
82        ]
83    }
84
85    async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
86        Ok(self
87            .providers
88            .iter()
89            .cloned()
90            .chain(self.storage.oidc_providers().await?)
91            .collect())
92    }
93
94    async fn provider_by_id(
95        &self,
96        provider_id: Option<&str>,
97    ) -> Result<Option<Self::Provider>, ShieldError> {
98        if let Some(provider_id) = provider_id {
99            self.oidc_provider_by_id_or_slug(provider_id).await
100        } else {
101            Ok(None)
102        }
103    }
104}
105
106erased_method!(OidcMethod, <U: User>);