tower_biscuit_auth/
lib.rs

1use arc_swap::ArcSwap;
2use std::{collections::BTreeMap, fmt, sync::Arc};
3
4use biscuit_auth::{error::Token, Authorizer, Biscuit, PublicKey};
5
6// mod http;
7
8#[derive(Clone)]
9pub struct BiscuitAuth<E> {
10    auth_info: Arc<ArcSwap<AuthConfig>>,
11    extractor: E,
12}
13
14impl<E> BiscuitAuth<E> {
15    pub fn new(auth_info: AuthConfig, extractor: E) -> Self {
16        Self {
17            auth_info: Arc::new(ArcSwap::from_pointee(auth_info)),
18            extractor,
19        }
20    }
21
22    pub fn update(&self, auth_info: AuthConfig) {
23        self.auth_info.store(Arc::new(auth_info))
24    }
25
26    pub fn to_auth_info(&self) -> AuthConfig {
27        self.auth_info.load().as_ref().clone()
28    }
29
30    pub fn check<R>(&self, request: &R, extractor: &E) -> Result<(), BiscuitAuthError>
31    where
32        E: AuthExtract<Request = R>,
33    {
34        let auth_info_guard: arc_swap::Guard<_, _> = self.auth_info.load();
35        let auth_info: &AuthConfig = &auth_info_guard;
36
37        // We play some weird error-handling games here so that we don't
38        // accidentally emit sensitive information in errors, which are
39        // likely to end up in logs somewhere.
40        let try_auth = || -> Result<(), BiscuitAuthError> {
41            let biscuit = auth_info.biscuit(&extractor.auth_token(request)?)?;
42
43            let mut authorizer = auth_info.authorizer.clone();
44            authorizer.add_token(&biscuit)?;
45
46            extractor.extract_context(request, &mut authorizer)?;
47
48            authorizer.authorize()?;
49
50            Ok(())
51        };
52
53        try_auth().map_err(|err| match auth_info.error_mode {
54            ErrorMode::Secure => BiscuitAuthError::Unknown,
55            ErrorMode::Verbose => err,
56        })
57    }
58}
59
60type TowerError = tower::BoxError;
61
62impl<Request, Extract> tower::filter::Predicate<Request> for BiscuitAuth<Extract>
63where
64    Extract: AuthExtract<Request = Request>,
65{
66    type Request = Request;
67
68    fn check(&mut self, request: Request) -> Result<Self::Request, TowerError> {
69        BiscuitAuth::check(self, &request, &self.extractor)?;
70        Ok(request)
71    }
72}
73
74#[derive(Debug)]
75pub enum BiscuitAuthError {
76    Unknown,
77    Other(tower::BoxError),
78    Failure(Token),
79}
80
81impl fmt::Display for BiscuitAuthError {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            BiscuitAuthError::Unknown => f.write_str("unauthorized"),
85            BiscuitAuthError::Other(err) => {
86                f.write_str("other error: ")?;
87                err.fmt(f)
88            }
89            BiscuitAuthError::Failure(err_token) => {
90                f.write_str("verification failure: ")?;
91                err_token.fmt(f)
92            }
93        }
94    }
95}
96
97impl std::error::Error for BiscuitAuthError {}
98
99impl From<Token> for BiscuitAuthError {
100    fn from(err: Token) -> Self {
101        Self::Failure(err)
102    }
103}
104
/// Controls how much detail a failed authorization check reveals.
///
/// `PartialEq`/`Eq` are derived so a mode can be compared directly
/// (`mode == ErrorMode::Secure`) instead of requiring a `match`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorMode {
    /// Every failure is reported as `BiscuitAuthError::Unknown`.
    Secure,
    /// Failures are reported with their underlying error.
    Verbose,
}
110
111pub trait AuthExtract {
112    type Request;
113
114    /// How to find the
115    fn auth_token(&self, req: &Self::Request) -> Result<Vec<u8>, BiscuitAuthError>;
116
117    /// Use the information in the request to add any relevant infoformation
118    /// to the authorizer, such as if the request is a read or write request,
119    /// or the specific resource the request is trying to access.
120    fn extract_context(
121        &self,
122        _req: &Self::Request,
123        _authorizer: &mut Authorizer,
124    ) -> Result<(), Token> {
125        Ok(())
126    }
127}
128
/// Empty placeholder type; nothing else in this file refers to it.
pub struct AuthContext {}
130
131#[derive(Clone, Debug)]
132pub struct RootKeys {
133    base: PublicKey,
134    by_id: BTreeMap<u32, PublicKey>,
135}
136
137impl RootKeys {
138    pub fn new(base: PublicKey) -> Self {
139        Self {
140            base,
141            by_id: BTreeMap::new(),
142        }
143    }
144}
145
146#[derive(Clone)]
147pub struct AuthConfig {
148    pub root_pubkeys: RootKeys,
149    pub authorizer: Authorizer<'static>,
150    pub error_mode: ErrorMode,
151}
152
153impl AuthConfig {
154    pub fn new(pubkey: RootKeys, authorizer: Authorizer<'static>, error_mode: ErrorMode) -> Self {
155        Self {
156            root_pubkeys: pubkey,
157            authorizer,
158            error_mode,
159        }
160    }
161
162    fn biscuit(&self, token: &[u8]) -> Result<Biscuit, Token> {
163        Biscuit::from(token, |id| self.pubkey_by_id(id))
164    }
165
166    fn pubkey_by_id(&self, id: Option<u32>) -> PublicKey {
167        match id {
168            None => self.root_pubkeys.base,
169            Some(id) => self
170                .root_pubkeys
171                .by_id
172                .get(&id)
173                .copied()
174                .unwrap_or(self.root_pubkeys.base),
175        }
176    }
177}
178
179impl fmt::Debug for AuthConfig {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        // TODO: Why does this warn here? Can we fix the warning?
182        #[allow(dead_code)]
183        #[derive(Debug)]
184        struct InnerAuthConfig<'a> {
185            root_keys: &'a RootKeys,
186            error_mode: &'a ErrorMode,
187        }
188
189        let inner = InnerAuthConfig {
190            root_keys: &self.root_pubkeys,
191            error_mode: &self.error_mode,
192        };
193
194        inner.fmt(f)
195    }
196}