spacetimedb_client_api/
auth.rs

1use std::time::{Duration, SystemTime};
2
3use axum::extract::{Query, Request, State};
4use axum::middleware::Next;
5use axum::response::IntoResponse;
6use axum_extra::typed_header::TypedHeader;
7use headers::{authorization, HeaderMapExt};
8use http::{request, HeaderValue, StatusCode};
9use serde::{Deserialize, Serialize};
10use spacetimedb::auth::identity::SpacetimeIdentityClaims;
11use spacetimedb::auth::identity::{JwtError, JwtErrorKind};
12use spacetimedb::auth::token_validation::{
13    new_validator, DefaultValidator, TokenSigner, TokenValidationError, TokenValidator,
14};
15use spacetimedb::auth::JwtKeys;
16use spacetimedb::energy::EnergyQuanta;
17use spacetimedb::identity::Identity;
18use uuid::Uuid;
19
20use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
21
22/// Credentials for login for a spacetime identity, represented as a JWT.
23///
24/// This can be passed as a header `Authentication: Bearer $token` or as
25/// a query param `?token=$token`, with the former taking precedence over
26/// the latter.
27#[derive(Clone, Deserialize)]
28pub struct SpacetimeCreds {
29    token: String,
30}
31
32pub const LOCALHOST: &str = "localhost";
33
34impl SpacetimeCreds {
35    /// The JWT token representing these credentials.
36    pub fn token(&self) -> &str {
37        &self.token
38    }
39
40    pub fn from_signed_token(token: String) -> Self {
41        Self { token }
42    }
43
44    pub fn to_header_value(&self) -> HeaderValue {
45        let mut val = HeaderValue::try_from(["Bearer ", self.token()].concat()).unwrap();
46        val.set_sensitive(true);
47        val
48    }
49
50    /// Extract credentials from the headers or else query string of a request.
51    fn from_request_parts(parts: &request::Parts) -> Result<Option<Self>, headers::Error> {
52        let header = parts
53            .headers
54            .typed_try_get::<headers::Authorization<authorization::Bearer>>()?;
55        if let Some(headers::Authorization(bearer)) = header {
56            let token = bearer.token().to_owned();
57            return Ok(Some(SpacetimeCreds { token }));
58        }
59        if let Ok(Query(creds)) = Query::<Self>::try_from_uri(&parts.uri) {
60            return Ok(Some(creds));
61        }
62        Ok(None)
63    }
64}
65
66/// The auth information in a request.
67///
68/// This is inserted as an extension by [`auth_middleware`]; make sure that's applied if you're making expecting
69/// this to be present.
70#[derive(Clone)]
71pub struct SpacetimeAuth {
72    pub creds: SpacetimeCreds,
73    pub identity: Identity,
74    pub subject: String,
75    pub issuer: String,
76}
77
78use jsonwebtoken;
79
80pub struct TokenClaims {
81    pub issuer: String,
82    pub subject: String,
83    pub audience: Vec<String>,
84}
85
86impl From<SpacetimeAuth> for TokenClaims {
87    fn from(claims: SpacetimeAuth) -> Self {
88        Self {
89            issuer: claims.issuer,
90            subject: claims.subject,
91            // This will need to be changed when we care about audiencies.
92            audience: Vec::new(),
93        }
94    }
95}
96
97impl TokenClaims {
98    pub fn new(issuer: String, subject: String) -> Self {
99        Self {
100            issuer,
101            subject,
102            audience: Vec::new(),
103        }
104    }
105
106    // Compute the id from the issuer and subject.
107    pub fn id(&self) -> Identity {
108        Identity::from_claims(&self.issuer, &self.subject)
109    }
110
111    pub fn encode_and_sign_with_expiry(
112        &self,
113        signer: &impl TokenSigner,
114        expiry: Option<Duration>,
115    ) -> Result<String, JwtError> {
116        let iat = SystemTime::now();
117        let exp = expiry.map(|dur| iat + dur);
118        let claims = SpacetimeIdentityClaims {
119            identity: self.id(),
120            subject: self.subject.clone(),
121            issuer: self.issuer.clone(),
122            audience: self.audience.clone(),
123            iat,
124            exp,
125        };
126        signer.sign(&claims)
127    }
128
129    pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<String, JwtError> {
130        self.encode_and_sign_with_expiry(signer, None)
131    }
132}
133
134impl SpacetimeAuth {
135    /// Allocate a new identity, and mint a new token for it.
136    pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result<Self> {
137        // Generate claims with a random subject.
138        let subject = Uuid::new_v4().to_string();
139        let claims = TokenClaims {
140            issuer: ctx.jwt_auth_provider().local_issuer().to_owned(),
141            subject: subject.clone(),
142            // Placeholder audience.
143            audience: vec!["spacetimedb".to_string()],
144        };
145
146        let identity = claims.id();
147        let creds = {
148            let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?;
149            SpacetimeCreds::from_signed_token(token)
150        };
151
152        Ok(Self {
153            creds,
154            identity,
155            subject,
156            issuer: ctx.jwt_auth_provider().local_issuer().to_string(),
157        })
158    }
159
160    /// Get the auth credentials as headers to be returned from an endpoint.
161    pub fn into_headers(self) -> (TypedHeader<SpacetimeIdentity>, TypedHeader<SpacetimeIdentityToken>) {
162        (
163            TypedHeader(SpacetimeIdentity(self.identity)),
164            TypedHeader(SpacetimeIdentityToken(self.creds)),
165        )
166    }
167
168    // Sign a new token with the same claims and a new expiry.
169    // Note that this will not change the issuer, so the private_key might not match.
170    // We do this to create short-lived tokens that we will be able to verify.
171    pub fn re_sign_with_expiry(&self, signer: &impl TokenSigner, expiry: Duration) -> Result<String, JwtError> {
172        TokenClaims::from(self.clone()).encode_and_sign_with_expiry(signer, Some(expiry))
173    }
174}
175
176// JwtAuthProvider is used for signing and verifying JWT tokens.
177pub trait JwtAuthProvider: Sync + Send + TokenSigner {
178    type TV: TokenValidator + Send + Sync;
179    /// Used to validate incoming JWTs.
180    fn validator(&self) -> &Self::TV;
181
182    /// The issuer to use when signing JWTs.
183    fn local_issuer(&self) -> &str;
184
185    /// Return the public key used to verify JWTs, as the bytes of a PEM public key file.
186    ///
187    /// The `/identity/public-key` route calls this method to return the public key to callers.
188    fn public_key_bytes(&self) -> &[u8];
189}
190
191pub struct JwtKeyAuthProvider<TV: TokenValidator + Send + Sync> {
192    keys: JwtKeys,
193    local_issuer: String,
194    validator: TV,
195}
196
197pub type DefaultJwtAuthProvider = JwtKeyAuthProvider<DefaultValidator>;
198
199// Create a new AuthEnvironment using the default caching validator.
200pub fn default_auth_environment(keys: JwtKeys, local_issuer: String) -> JwtKeyAuthProvider<DefaultValidator> {
201    let validator = new_validator(keys.public.clone(), local_issuer.clone());
202    JwtKeyAuthProvider::new(keys, local_issuer, validator)
203}
204
205impl<TV: TokenValidator + Send + Sync> JwtKeyAuthProvider<TV> {
206    fn new(keys: JwtKeys, local_issuer: String, validator: TV) -> Self {
207        Self {
208            keys,
209            local_issuer,
210            validator,
211        }
212    }
213}
214
215impl<TV: TokenValidator + Send + Sync> TokenSigner for JwtKeyAuthProvider<TV> {
216    fn sign<T: Serialize>(&self, claims: &T) -> Result<String, JwtError> {
217        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
218        jsonwebtoken::encode(&header, &claims, &self.keys.private)
219    }
220}
221
222impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV> {
223    type TV = TV;
224
225    fn local_issuer(&self) -> &str {
226        &self.local_issuer
227    }
228
229    fn public_key_bytes(&self) -> &[u8] {
230        &self.keys.public_pem
231    }
232
233    fn validator(&self) -> &Self::TV {
234        &self.validator
235    }
236}
237
#[cfg(test)]
mod tests {
    use crate::auth::TokenClaims;
    use spacetimedb::auth::token_validation::TokenValidator;
    use spacetimedb::auth::JwtKeys;

    /// Round trip: a token signed from `TokenClaims` must validate back to the identity
    /// computed directly from those claims.
    #[tokio::test]
    async fn decode_encoded_token() -> anyhow::Result<()> {
        let keys = JwtKeys::generate()?;

        let claims = TokenClaims {
            issuer: "localhost".to_string(),
            subject: "test-subject".to_string(),
            audience: vec!["spacetimedb".to_string()],
        };
        let expected_identity = claims.id();

        let token = claims.encode_and_sign(&keys.private)?;
        let validated = keys.public.validate_token(&token).await?;

        assert_eq!(validated.identity, expected_identity);
        Ok(())
    }
}
262
263pub struct SpacetimeAuthHeader {
264    auth: Option<SpacetimeAuth>,
265}
266
267#[async_trait::async_trait]
268impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthHeader {
269    type Rejection = AuthorizationRejection;
270    async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
271        let Some(creds) = SpacetimeCreds::from_request_parts(parts)? else {
272            return Ok(Self { auth: None });
273        };
274
275        let claims = state
276            .jwt_auth_provider()
277            .validator()
278            .validate_token(&creds.token)
279            .await
280            .map_err(AuthorizationRejection::Custom)?;
281
282        let auth = SpacetimeAuth {
283            creds,
284            identity: claims.identity,
285            subject: claims.subject,
286            issuer: claims.issuer,
287        };
288        Ok(Self { auth: Some(auth) })
289    }
290}
291
292/// A response by the API signifying that an authorization was rejected with the `reason` for this.
293#[derive(Debug, derive_more::From)]
294pub enum AuthorizationRejection {
295    Jwt(JwtError),
296    Header(headers::Error),
297    Custom(TokenValidationError),
298    Required,
299}
300
301impl IntoResponse for AuthorizationRejection {
302    fn into_response(self) -> axum::response::Response {
303        // Most likely, the server key was rotated.
304        const ROTATED: (StatusCode, &str) = (
305            StatusCode::UNAUTHORIZED,
306            "Authorization failed: token not signed by this instance",
307        );
308        // The JWT is malformed, see SpacetimeCreds for specifics on the format.
309        const INVALID: (StatusCode, &str) = (StatusCode::BAD_REQUEST, "Authorization is invalid: malformed token");
310        // Sensible fallback if no auth header is present.
311        const REQUIRED: (StatusCode, &str) = (StatusCode::UNAUTHORIZED, "Authorization required");
312
313        log::trace!("Authorization rejection: {self:?}");
314
315        match self {
316            AuthorizationRejection::Jwt(e) if *e.kind() == JwtErrorKind::InvalidSignature => ROTATED.into_response(),
317            AuthorizationRejection::Jwt(_) | AuthorizationRejection::Header(_) => INVALID.into_response(),
318            AuthorizationRejection::Custom(msg) => (StatusCode::UNAUTHORIZED, format!("{msg:?}")).into_response(),
319            AuthorizationRejection::Required => REQUIRED.into_response(),
320        }
321    }
322}
323
324impl SpacetimeAuthHeader {
325    pub fn get(self) -> Option<SpacetimeAuth> {
326        self.auth
327    }
328
329    /// Given an authorization header we will try to get the identity and token from the auth header (as JWT).
330    /// If there is no JWT in the auth header we will create a new identity and token and return it.
331    pub async fn get_or_create(
332        self,
333        ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized),
334    ) -> axum::response::Result<SpacetimeAuth> {
335        match self.auth {
336            Some(auth) => Ok(auth),
337            None => SpacetimeAuth::alloc(ctx).await,
338        }
339    }
340}
341
342pub struct SpacetimeAuthRequired(pub SpacetimeAuth);
343
344#[async_trait::async_trait]
345impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthRequired {
346    type Rejection = AuthorizationRejection;
347    async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
348        let auth = SpacetimeAuthHeader::from_request_parts(parts, state).await?;
349        let auth = auth.get().ok_or(AuthorizationRejection::Required)?;
350        Ok(SpacetimeAuthRequired(auth))
351    }
352}
353
354pub struct SpacetimeIdentity(pub Identity);
355impl headers::Header for SpacetimeIdentity {
356    fn name() -> &'static http::HeaderName {
357        static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-identity");
358        &NAME
359    }
360
361    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
362        unimplemented!()
363    }
364
365    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
366        values.extend([self.0.to_hex().as_str().try_into().unwrap()])
367    }
368}
369
370pub struct SpacetimeIdentityToken(pub SpacetimeCreds);
371impl headers::Header for SpacetimeIdentityToken {
372    fn name() -> &'static http::HeaderName {
373        static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-identity-token");
374        &NAME
375    }
376
377    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
378        unimplemented!()
379    }
380
381    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
382        values.extend([self.0.token().try_into().unwrap()])
383    }
384}
385
386pub struct SpacetimeEnergyUsed(pub EnergyQuanta);
387impl headers::Header for SpacetimeEnergyUsed {
388    fn name() -> &'static http::HeaderName {
389        static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-energy-used");
390        &NAME
391    }
392
393    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
394        unimplemented!()
395    }
396
397    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
398        let mut buf = itoa::Buffer::new();
399        let value = buf.format(self.0.get());
400        values.extend([value.try_into().unwrap()]);
401    }
402}
403
404pub struct SpacetimeExecutionDurationMicros(pub Duration);
405impl headers::Header for SpacetimeExecutionDurationMicros {
406    fn name() -> &'static http::HeaderName {
407        static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-execution-duration-micros");
408        &NAME
409    }
410
411    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
412        unimplemented!()
413    }
414
415    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
416        values.extend([(self.0.as_micros() as u64).into()])
417    }
418}
419
420pub async fn anon_auth_middleware<S: ControlStateDelegate + NodeDelegate>(
421    State(worker_ctx): State<S>,
422    auth: SpacetimeAuthHeader,
423    mut req: Request,
424    next: Next,
425) -> axum::response::Result<impl IntoResponse> {
426    let auth = auth.get_or_create(&worker_ctx).await?;
427    req.extensions_mut().insert(auth.clone());
428    let resp = next.run(req).await;
429    Ok((auth.into_headers(), resp))
430}