shield_oauth/
method.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use shield::{Action, Method, ShieldError, User, erased_method};
5
6use crate::{
7    actions::{OauthSignInAction, OauthSignInCallbackAction, OauthSignOutAction},
8    options::OauthOptions,
9    provider::OauthProvider,
10    session::OauthSession,
11    storage::OauthStorage,
12};
13
14pub const OAUTH_METHOD_ID: &str = "oauth";
15
16pub struct OauthMethod<U: User> {
17    options: OauthOptions,
18    providers: Vec<OauthProvider>,
19    storage: Arc<dyn OauthStorage<U>>,
20}
21
22impl<U: User> OauthMethod<U> {
23    pub fn new<S: OauthStorage<U> + 'static>(storage: S) -> Self {
24        Self {
25            options: OauthOptions::default(),
26            providers: vec![],
27            storage: Arc::new(storage),
28        }
29    }
30
31    pub fn with_options(mut self, options: OauthOptions) -> Self {
32        self.options = options;
33        self
34    }
35
36    pub fn with_providers<I: IntoIterator<Item = OauthProvider>>(mut self, providers: I) -> Self {
37        self.providers = providers.into_iter().collect();
38        self
39    }
40
41    async fn oauth_provider_by_id_or_slug(
42        &self,
43        provider_id: &str,
44    ) -> Result<Option<OauthProvider>, ShieldError> {
45        if let Some(provider) = self
46            .providers
47            .iter()
48            .find(|provider| provider.id == provider_id)
49        {
50            return Ok(Some(provider.clone()));
51        }
52
53        if let Some(provider) = self
54            .storage
55            .oauth_provider_by_id_or_slug(provider_id)
56            .await?
57        {
58            return Ok(Some(provider));
59        }
60
61        Ok(None)
62    }
63}
64
65#[async_trait]
66impl<U: User + 'static> Method for OauthMethod<U> {
67    type Provider = OauthProvider;
68    type Session = OauthSession;
69
70    fn id(&self) -> String {
71        OAUTH_METHOD_ID.to_owned()
72    }
73
74    fn actions(&self) -> Vec<Box<dyn Action<Self::Provider, Self::Session>>> {
75        vec![
76            Box::new(OauthSignInAction),
77            Box::new(OauthSignInCallbackAction::new(
78                self.options.clone(),
79                self.storage.clone(),
80            )),
81            Box::new(OauthSignOutAction),
82        ]
83    }
84
85    async fn providers(&self) -> Result<Vec<Self::Provider>, ShieldError> {
86        Ok(self
87            .providers
88            .iter()
89            .cloned()
90            .chain(self.storage.oauth_providers().await?)
91            .collect())
92    }
93
94    async fn provider_by_id(
95        &self,
96        provider_id: Option<&str>,
97    ) -> Result<Option<Self::Provider>, ShieldError> {
98        if let Some(provider_id) = provider_id {
99            self.oauth_provider_by_id_or_slug(provider_id).await
100        } else {
101            Ok(None)
102        }
103    }
104}
105
106erased_method!(OauthMethod, <U: User>);