shield_oauth/
provider.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use async_trait::async_trait;
use shield::{
    Provider, ProviderError, Response, Session, ShieldError, SignInCallbackRequest, SignInRequest,
    SignOutRequest, Subprovider, User,
};

use crate::{storage::OauthStorage, subprovider::OauthSubprovider};

/// Identifier reported by [`Provider::id`] for [`OauthProvider`].
pub const OAUTH_PROVIDER_ID: &str = "oauth";

/// Shield provider for OAuth sign-in.
///
/// Subproviders come from two sources: a statically configured list and a storage backend.
/// `U` is the user type the storage backend is parameterized over.
pub struct OauthProvider<U>
where
    U: User,
{
    /// Statically configured subproviders; consulted before `storage` when looking up by id.
    subproviders: Vec<OauthSubprovider>,
    /// Backend used to list and look up additional subproviders at runtime.
    storage: Box<dyn OauthStorage<U>>,
}

impl<U: User> OauthProvider<U> {
    /// Creates a provider backed by `storage`, with no statically configured subproviders.
    ///
    /// Use [`OauthProvider::with_subproviders`] to register static subproviders.
    #[must_use]
    pub fn new<S: OauthStorage<U> + 'static>(storage: S) -> Self {
        Self {
            subproviders: vec![],
            storage: Box::new(storage),
        }
    }

    /// Sets the statically configured subproviders, returning the updated provider.
    ///
    /// This replaces any subproviders set by an earlier call; it does not merge them.
    // `#[must_use]`: this consumes `self`, so ignoring the result drops the provider.
    #[must_use]
    pub fn with_subproviders<I: IntoIterator<Item = OauthSubprovider>>(
        mut self,
        subproviders: I,
    ) -> Self {
        self.subproviders = subproviders.into_iter().collect();
        self
    }

    /// Resolves a subprovider by id.
    ///
    /// The statically configured subproviders are searched first; only if none matches is
    /// the storage backend queried.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage lookup, and returns
    /// [`ProviderError::SubproviderNotFound`] (converted into [`ShieldError`]) when neither
    /// source has a subprovider with `subprovider_id`.
    async fn oauth_subprovider_by_id(
        &self,
        subprovider_id: &str,
    ) -> Result<OauthSubprovider, ShieldError> {
        // Static subproviders take precedence over stored ones with the same id.
        if let Some(subprovider) = self
            .subproviders
            .iter()
            .find(|subprovider| subprovider.id == subprovider_id)
        {
            return Ok(subprovider.clone());
        }

        self.storage
            .oauth_subprovider_by_id(subprovider_id)
            .await?
            .ok_or_else(|| ProviderError::SubproviderNotFound(subprovider_id.to_owned()).into())
    }
}

/// [`Provider`] implementation for OAuth.
///
/// Each request handler first resolves the requested subprovider; the OAuth flows
/// themselves are not implemented yet and panic via `todo!`.
#[async_trait]
impl<U> Provider for OauthProvider<U>
where
    U: User,
{
    fn id(&self) -> String {
        String::from(OAUTH_PROVIDER_ID)
    }

    /// Lists the static subproviders followed by those returned by storage.
    // NOTE(review): an id present both statically and in storage is listed twice here,
    // while lookup by id returns only the static one — confirm this is intended.
    async fn subproviders(&self) -> Result<Vec<Box<dyn Subprovider>>, ShieldError> {
        // Building the iterator over the static list has no side effects, so fetching the
        // stored subproviders first does not change the observable order of work.
        let stored = self.storage.oauth_subproviders().await?;

        let mut boxed: Vec<Box<dyn Subprovider>> = Vec::new();
        boxed.extend(
            self.subproviders
                .iter()
                .cloned()
                .chain(stored)
                .map(|subprovider| Box::new(subprovider) as Box<dyn Subprovider>),
        );

        Ok(boxed)
    }

    /// Looks up a subprovider by id, checking static subproviders before storage.
    // NOTE(review): this never yields `Ok(None)` — an unknown id surfaces as the
    // `SubproviderNotFound` error from `oauth_subprovider_by_id`. Confirm callers expect that.
    async fn subprovider_by_id(
        &self,
        subprovider_id: &str,
    ) -> Result<Option<Box<dyn Subprovider>>, ShieldError> {
        let subprovider = self.oauth_subprovider_by_id(subprovider_id).await?;
        let boxed: Box<dyn Subprovider> = Box::new(subprovider);

        Ok(Some(boxed))
    }

    /// Starts OAuth sign-in; fails with `SubproviderMissing` when no subprovider id is given.
    async fn sign_in(
        &self,
        request: SignInRequest,
        _session: Session,
    ) -> Result<Response, ShieldError> {
        let _subprovider = if let Some(subprovider_id) = request.subprovider_id {
            self.oauth_subprovider_by_id(&subprovider_id).await?
        } else {
            return Err(ProviderError::SubproviderMissing.into());
        };

        todo!("oauth sign in")
    }

    /// Handles the OAuth sign-in callback; fails with `SubproviderMissing` when no
    /// subprovider id is given.
    async fn sign_in_callback(
        &self,
        request: SignInCallbackRequest,
        _session: Session,
    ) -> Result<Response, ShieldError> {
        let _subprovider = if let Some(subprovider_id) = request.subprovider_id {
            self.oauth_subprovider_by_id(&subprovider_id).await?
        } else {
            return Err(ProviderError::SubproviderMissing.into());
        };

        todo!("oauth sign in callback")
    }

    /// Signs out of OAuth; fails with `SubproviderMissing` when no subprovider id is given.
    async fn sign_out(
        &self,
        request: SignOutRequest,
        _session: Session,
    ) -> Result<Response, ShieldError> {
        let _subprovider = if let Some(subprovider_id) = request.subprovider_id {
            self.oauth_subprovider_by_id(&subprovider_id).await?
        } else {
            return Err(ProviderError::SubproviderMissing.into());
        };

        todo!("oauth sign out")
    }
}