spacetimedb_auth/
identity.rs1pub use jsonwebtoken::errors::Error as JwtError;
2pub use jsonwebtoken::errors::ErrorKind as JwtErrorKind;
3pub use jsonwebtoken::{DecodingKey, EncodingKey};
4use serde::Deserializer;
5use serde::{Deserialize, Serialize};
6use spacetimedb_lib::Identity;
7use std::time::SystemTime;
8
9#[serde_with::serde_as]
11#[derive(Debug, Serialize, Deserialize)]
12pub struct SpacetimeIdentityClaims {
13 #[serde(rename = "hex_identity")]
14 pub identity: Identity,
15 #[serde(rename = "sub")]
16 pub subject: String,
17 #[serde(rename = "iss")]
18 pub issuer: String,
19 #[serde(rename = "aud")]
20 pub audience: Vec<String>,
21
22 #[serde_as(as = "serde_with::TimestampSeconds")]
24 pub iat: SystemTime,
25 #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
26 pub exp: Option<SystemTime>,
27}
28
29fn deserialize_audience<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
30where
31 D: Deserializer<'de>,
32{
33 #[derive(Deserialize)]
35 #[serde(untagged)]
36 enum Audience {
37 Single(String),
38 Multiple(Vec<String>),
39 }
40
41 let audience = Audience::deserialize(deserializer)?;
43
44 Ok(match audience {
46 Audience::Single(s) => vec![s],
47 Audience::Multiple(v) => v,
48 })
49}
50
51#[serde_with::serde_as]
54#[derive(Debug, Serialize, Deserialize)]
55pub struct IncomingClaims {
56 #[serde(rename = "hex_identity")]
57 pub identity: Option<Identity>,
58 #[serde(rename = "sub")]
59 pub subject: String,
60 #[serde(rename = "iss")]
61 pub issuer: String,
62 #[serde(rename = "aud", default, deserialize_with = "deserialize_audience")]
63 pub audience: Vec<String>,
64
65 #[serde_as(as = "serde_with::TimestampSeconds")]
67 pub iat: SystemTime,
68 #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
69 pub exp: Option<SystemTime>,
70}
71
72impl TryInto<SpacetimeIdentityClaims> for IncomingClaims {
73 type Error = anyhow::Error;
74
75 fn try_into(self) -> anyhow::Result<SpacetimeIdentityClaims> {
76 if self.issuer.len() > 128 {
78 return Err(anyhow::anyhow!("Issuer too long: {:?}", self.issuer));
79 }
80 if self.subject.len() > 128 {
81 return Err(anyhow::anyhow!("Subject too long: {:?}", self.subject));
82 }
83 if self.issuer.is_empty() {
85 return Err(anyhow::anyhow!("Issuer empty"));
86 }
87 if self.subject.is_empty() {
88 return Err(anyhow::anyhow!("Subject empty"));
89 }
90
91 let computed_identity = Identity::from_claims(&self.issuer, &self.subject);
92 if let Some(token_identity) = self.identity {
94 if token_identity != computed_identity {
95 return Err(anyhow::anyhow!(
96 "Identity mismatch: token identity {:?} does not match computed identity {:?}",
97 token_identity,
98 computed_identity,
99 ));
100 }
101 }
102
103 Ok(SpacetimeIdentityClaims {
104 identity: computed_identity,
105 subject: self.subject,
106 issuer: self.issuer,
107 audience: self.audience,
108 iat: self.iat,
109 exp: self.exp,
110 })
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::{Duration, UNIX_EPOCH};

    const IAT_SECS: u64 = 1693425600;
    const EXP_SECS: u64 = 1693512000;

    /// Checks the `sub`, `iss`, `iat` and `exp` values shared by the JSON fixtures.
    fn assert_common_fields(claims: &IncomingClaims) {
        assert_eq!(claims.subject, "123");
        assert_eq!(claims.issuer, "example.com");
        assert_eq!(claims.iat, UNIX_EPOCH + Duration::from_secs(IAT_SECS));
        assert_eq!(claims.exp, Some(UNIX_EPOCH + Duration::from_secs(EXP_SECS)));
    }

    /// Builds `IncomingClaims` directly (no JSON) to exercise the validation.
    fn incoming(issuer: &str, subject: &str, identity: Option<Identity>) -> IncomingClaims {
        IncomingClaims {
            identity,
            subject: subject.to_owned(),
            issuer: issuer.to_owned(),
            audience: vec!["audience1".to_owned()],
            iat: UNIX_EPOCH + Duration::from_secs(IAT_SECS),
            exp: None,
        }
    }

    /// Runs the `IncomingClaims` -> `SpacetimeIdentityClaims` conversion.
    fn validate(claims: IncomingClaims) -> anyhow::Result<SpacetimeIdentityClaims> {
        claims.try_into()
    }

    #[test]
    fn test_deserialize_audience_single_string() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": "audience1",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        assert_eq!(claims.audience, vec!["audience1"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_multiple_strings() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": ["audience1", "audience2"],
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        assert_eq!(claims.audience, vec!["audience1", "audience2"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_missing_field() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        assert!(claims.audience.is_empty());
        assert_common_fields(&claims);
    }

    #[test]
    fn test_validate_computes_identity_from_issuer_and_subject() {
        let claims = validate(incoming("example.com", "123", None)).unwrap();

        assert_eq!(claims.identity, Identity::from_claims("example.com", "123"));
        assert_eq!(claims.issuer, "example.com");
        assert_eq!(claims.subject, "123");
        assert_eq!(claims.audience, vec!["audience1"]);
    }

    #[test]
    fn test_validate_accepts_matching_token_identity() {
        let token_identity = Some(Identity::from_claims("example.com", "123"));
        let claims = validate(incoming("example.com", "123", token_identity)).unwrap();

        assert_eq!(claims.identity, Identity::from_claims("example.com", "123"));
    }

    #[test]
    fn test_validate_rejects_mismatched_token_identity() {
        let other = Some(Identity::from_claims("other.example.com", "123"));
        let err = validate(incoming("example.com", "123", other)).unwrap_err();

        assert!(err.to_string().starts_with("Identity mismatch"));
    }

    #[test]
    fn test_validate_rejects_empty_issuer_and_subject() {
        let err = validate(incoming("", "123", None)).unwrap_err();
        assert_eq!(err.to_string(), "Issuer empty");

        let err = validate(incoming("example.com", "", None)).unwrap_err();
        assert_eq!(err.to_string(), "Subject empty");
    }

    #[test]
    fn test_validate_enforces_128_byte_limit() {
        let at_limit = "a".repeat(128);
        let over_limit = "a".repeat(129);

        assert!(validate(incoming(&at_limit, &at_limit, None)).is_ok());

        let err = validate(incoming(&over_limit, "123", None)).unwrap_err();
        assert!(err.to_string().starts_with("Issuer too long"));

        let err = validate(incoming("example.com", &over_limit, None)).unwrap_err();
        assert!(err.to_string().starts_with("Subject too long"));
    }
}