1use std::time::{Duration, SystemTime};
2
3use axum::extract::{Query, Request, State};
4use axum::middleware::Next;
5use axum::response::IntoResponse;
6use axum_extra::typed_header::TypedHeader;
7use headers::{authorization, HeaderMapExt};
8use http::{request, HeaderValue, StatusCode};
9use serde::{Deserialize, Serialize};
10use spacetimedb::auth::identity::SpacetimeIdentityClaims;
11use spacetimedb::auth::identity::{JwtError, JwtErrorKind};
12use spacetimedb::auth::token_validation::{
13 new_validator, DefaultValidator, TokenSigner, TokenValidationError, TokenValidator,
14};
15use spacetimedb::auth::JwtKeys;
16use spacetimedb::energy::EnergyQuanta;
17use spacetimedb::identity::Identity;
18use uuid::Uuid;
19
20use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
21
22#[derive(Clone, Deserialize)]
28pub struct SpacetimeCreds {
29 token: String,
30}
31
/// The literal host name `localhost`.
pub const LOCALHOST: &str = "localhost";
33
34impl SpacetimeCreds {
35 pub fn token(&self) -> &str {
37 &self.token
38 }
39
40 pub fn from_signed_token(token: String) -> Self {
41 Self { token }
42 }
43
44 pub fn to_header_value(&self) -> HeaderValue {
45 let mut val = HeaderValue::try_from(["Bearer ", self.token()].concat()).unwrap();
46 val.set_sensitive(true);
47 val
48 }
49
50 fn from_request_parts(parts: &request::Parts) -> Result<Option<Self>, headers::Error> {
52 let header = parts
53 .headers
54 .typed_try_get::<headers::Authorization<authorization::Bearer>>()?;
55 if let Some(headers::Authorization(bearer)) = header {
56 let token = bearer.token().to_owned();
57 return Ok(Some(SpacetimeCreds { token }));
58 }
59 if let Ok(Query(creds)) = Query::<Self>::try_from_uri(&parts.uri) {
60 return Ok(Some(creds));
61 }
62 Ok(None)
63 }
64}
65
66#[derive(Clone)]
71pub struct SpacetimeAuth {
72 pub creds: SpacetimeCreds,
73 pub identity: Identity,
74 pub subject: String,
75 pub issuer: String,
76}
77
78use jsonwebtoken;
79
/// The claim values from which a signed token is built.
pub struct TokenClaims {
    /// Value for the token's issuer claim.
    pub issuer: String,
    /// Value for the token's subject claim.
    pub subject: String,
    /// Values for the token's audience claim; may be empty.
    pub audience: Vec<String>,
}
85
86impl From<SpacetimeAuth> for TokenClaims {
87 fn from(claims: SpacetimeAuth) -> Self {
88 Self {
89 issuer: claims.issuer,
90 subject: claims.subject,
91 audience: Vec::new(),
93 }
94 }
95}
96
97impl TokenClaims {
98 pub fn new(issuer: String, subject: String) -> Self {
99 Self {
100 issuer,
101 subject,
102 audience: Vec::new(),
103 }
104 }
105
106 pub fn id(&self) -> Identity {
108 Identity::from_claims(&self.issuer, &self.subject)
109 }
110
111 pub fn encode_and_sign_with_expiry(
112 &self,
113 signer: &impl TokenSigner,
114 expiry: Option<Duration>,
115 ) -> Result<String, JwtError> {
116 let iat = SystemTime::now();
117 let exp = expiry.map(|dur| iat + dur);
118 let claims = SpacetimeIdentityClaims {
119 identity: self.id(),
120 subject: self.subject.clone(),
121 issuer: self.issuer.clone(),
122 audience: self.audience.clone(),
123 iat,
124 exp,
125 };
126 signer.sign(&claims)
127 }
128
129 pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<String, JwtError> {
130 self.encode_and_sign_with_expiry(signer, None)
131 }
132}
133
134impl SpacetimeAuth {
135 pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result<Self> {
137 let subject = Uuid::new_v4().to_string();
139 let claims = TokenClaims {
140 issuer: ctx.jwt_auth_provider().local_issuer().to_owned(),
141 subject: subject.clone(),
142 audience: vec!["spacetimedb".to_string()],
144 };
145
146 let identity = claims.id();
147 let creds = {
148 let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?;
149 SpacetimeCreds::from_signed_token(token)
150 };
151
152 Ok(Self {
153 creds,
154 identity,
155 subject,
156 issuer: ctx.jwt_auth_provider().local_issuer().to_string(),
157 })
158 }
159
160 pub fn into_headers(self) -> (TypedHeader<SpacetimeIdentity>, TypedHeader<SpacetimeIdentityToken>) {
162 (
163 TypedHeader(SpacetimeIdentity(self.identity)),
164 TypedHeader(SpacetimeIdentityToken(self.creds)),
165 )
166 }
167
168 pub fn re_sign_with_expiry(&self, signer: &impl TokenSigner, expiry: Duration) -> Result<String, JwtError> {
172 TokenClaims::from(self.clone()).encode_and_sign_with_expiry(signer, Some(expiry))
173 }
174}
175
176pub trait JwtAuthProvider: Sync + Send + TokenSigner {
178 type TV: TokenValidator + Send + Sync;
179 fn validator(&self) -> &Self::TV;
181
182 fn local_issuer(&self) -> &str;
184
185 fn public_key_bytes(&self) -> &[u8];
189}
190
191pub struct JwtKeyAuthProvider<TV: TokenValidator + Send + Sync> {
192 keys: JwtKeys,
193 local_issuer: String,
194 validator: TV,
195}
196
197pub type DefaultJwtAuthProvider = JwtKeyAuthProvider<DefaultValidator>;
198
199pub fn default_auth_environment(keys: JwtKeys, local_issuer: String) -> JwtKeyAuthProvider<DefaultValidator> {
201 let validator = new_validator(keys.public.clone(), local_issuer.clone());
202 JwtKeyAuthProvider::new(keys, local_issuer, validator)
203}
204
205impl<TV: TokenValidator + Send + Sync> JwtKeyAuthProvider<TV> {
206 fn new(keys: JwtKeys, local_issuer: String, validator: TV) -> Self {
207 Self {
208 keys,
209 local_issuer,
210 validator,
211 }
212 }
213}
214
215impl<TV: TokenValidator + Send + Sync> TokenSigner for JwtKeyAuthProvider<TV> {
216 fn sign<T: Serialize>(&self, claims: &T) -> Result<String, JwtError> {
217 let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
218 jsonwebtoken::encode(&header, &claims, &self.keys.private)
219 }
220}
221
222impl<TV: TokenValidator + Send + Sync> JwtAuthProvider for JwtKeyAuthProvider<TV> {
223 type TV = TV;
224
225 fn local_issuer(&self) -> &str {
226 &self.local_issuer
227 }
228
229 fn public_key_bytes(&self) -> &[u8] {
230 &self.keys.public_pem
231 }
232
233 fn validator(&self) -> &Self::TV {
234 &self.validator
235 }
236}
237
#[cfg(test)]
mod tests {
    use crate::auth::TokenClaims;
    use anyhow::Ok;
    use spacetimedb::auth::{token_validation::TokenValidator, JwtKeys};

    /// Signs claims with a freshly generated private key and checks that
    /// validating the token with the matching public key yields the identity
    /// derived from those same claims.
    #[tokio::test]
    async fn decode_encoded_token() -> Result<(), anyhow::Error> {
        let keys = JwtKeys::generate()?;

        let claims = TokenClaims {
            issuer: "localhost".to_string(),
            subject: "test-subject".to_string(),
            audience: vec!["spacetimedb".to_string()],
        };
        let expected_identity = claims.id();

        let signed = claims.encode_and_sign(&keys.private)?;
        let validated = keys.public.validate_token(&signed).await?;

        assert_eq!(validated.identity, expected_identity);
        Ok(())
    }
}
262
263pub struct SpacetimeAuthHeader {
264 auth: Option<SpacetimeAuth>,
265}
266
267#[async_trait::async_trait]
268impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthHeader {
269 type Rejection = AuthorizationRejection;
270 async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
271 let Some(creds) = SpacetimeCreds::from_request_parts(parts)? else {
272 return Ok(Self { auth: None });
273 };
274
275 let claims = state
276 .jwt_auth_provider()
277 .validator()
278 .validate_token(&creds.token)
279 .await
280 .map_err(AuthorizationRejection::Custom)?;
281
282 let auth = SpacetimeAuth {
283 creds,
284 identity: claims.identity,
285 subject: claims.subject,
286 issuer: claims.issuer,
287 };
288 Ok(Self { auth: Some(auth) })
289 }
290}
291
292#[derive(Debug, derive_more::From)]
294pub enum AuthorizationRejection {
295 Jwt(JwtError),
296 Header(headers::Error),
297 Custom(TokenValidationError),
298 Required,
299}
300
301impl IntoResponse for AuthorizationRejection {
302 fn into_response(self) -> axum::response::Response {
303 const ROTATED: (StatusCode, &str) = (
305 StatusCode::UNAUTHORIZED,
306 "Authorization failed: token not signed by this instance",
307 );
308 const INVALID: (StatusCode, &str) = (StatusCode::BAD_REQUEST, "Authorization is invalid: malformed token");
310 const REQUIRED: (StatusCode, &str) = (StatusCode::UNAUTHORIZED, "Authorization required");
312
313 log::trace!("Authorization rejection: {self:?}");
314
315 match self {
316 AuthorizationRejection::Jwt(e) if *e.kind() == JwtErrorKind::InvalidSignature => ROTATED.into_response(),
317 AuthorizationRejection::Jwt(_) | AuthorizationRejection::Header(_) => INVALID.into_response(),
318 AuthorizationRejection::Custom(msg) => (StatusCode::UNAUTHORIZED, format!("{msg:?}")).into_response(),
319 AuthorizationRejection::Required => REQUIRED.into_response(),
320 }
321 }
322}
323
324impl SpacetimeAuthHeader {
325 pub fn get(self) -> Option<SpacetimeAuth> {
326 self.auth
327 }
328
329 pub async fn get_or_create(
332 self,
333 ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized),
334 ) -> axum::response::Result<SpacetimeAuth> {
335 match self.auth {
336 Some(auth) => Ok(auth),
337 None => SpacetimeAuth::alloc(ctx).await,
338 }
339 }
340}
341
342pub struct SpacetimeAuthRequired(pub SpacetimeAuth);
343
344#[async_trait::async_trait]
345impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for SpacetimeAuthRequired {
346 type Rejection = AuthorizationRejection;
347 async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
348 let auth = SpacetimeAuthHeader::from_request_parts(parts, state).await?;
349 let auth = auth.get().ok_or(AuthorizationRejection::Required)?;
350 Ok(SpacetimeAuthRequired(auth))
351 }
352}
353
354pub struct SpacetimeIdentity(pub Identity);
355impl headers::Header for SpacetimeIdentity {
356 fn name() -> &'static http::HeaderName {
357 static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-identity");
358 &NAME
359 }
360
361 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
362 unimplemented!()
363 }
364
365 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
366 values.extend([self.0.to_hex().as_str().try_into().unwrap()])
367 }
368}
369
370pub struct SpacetimeIdentityToken(pub SpacetimeCreds);
371impl headers::Header for SpacetimeIdentityToken {
372 fn name() -> &'static http::HeaderName {
373 static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-identity-token");
374 &NAME
375 }
376
377 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
378 unimplemented!()
379 }
380
381 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
382 values.extend([self.0.token().try_into().unwrap()])
383 }
384}
385
386pub struct SpacetimeEnergyUsed(pub EnergyQuanta);
387impl headers::Header for SpacetimeEnergyUsed {
388 fn name() -> &'static http::HeaderName {
389 static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-energy-used");
390 &NAME
391 }
392
393 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
394 unimplemented!()
395 }
396
397 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
398 let mut buf = itoa::Buffer::new();
399 let value = buf.format(self.0.get());
400 values.extend([value.try_into().unwrap()]);
401 }
402}
403
404pub struct SpacetimeExecutionDurationMicros(pub Duration);
405impl headers::Header for SpacetimeExecutionDurationMicros {
406 fn name() -> &'static http::HeaderName {
407 static NAME: http::HeaderName = http::HeaderName::from_static("spacetime-execution-duration-micros");
408 &NAME
409 }
410
411 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(_values: &mut I) -> Result<Self, headers::Error> {
412 unimplemented!()
413 }
414
415 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
416 values.extend([(self.0.as_micros() as u64).into()])
417 }
418}
419
420pub async fn anon_auth_middleware<S: ControlStateDelegate + NodeDelegate>(
421 State(worker_ctx): State<S>,
422 auth: SpacetimeAuthHeader,
423 mut req: Request,
424 next: Next,
425) -> axum::response::Result<impl IntoResponse> {
426 let auth = auth.get_or_create(&worker_ctx).await?;
427 req.extensions_mut().insert(auth.clone());
428 let resp = next.run(req).await;
429 Ok((auth.into_headers(), resp))
430}