1use crate::error::{Error, Result};
4use ambient_id::Detector;
5use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone)]
10pub struct IdentityToken {
11 raw: String,
13 claims: TokenClaims,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TokenClaims {
20 pub iss: String,
22 pub sub: String,
24 #[serde(default)]
26 pub aud: Audience,
27 pub exp: u64,
29 #[serde(default)]
31 pub iat: u64,
32 #[serde(default)]
34 pub email: Option<String>,
35 #[serde(default)]
37 pub email_verified: Option<bool>,
38 #[serde(default)]
40 pub federated_claims: Option<FederatedClaims>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, Default)]
45#[serde(untagged)]
46pub enum Audience {
47 #[default]
48 None,
49 Single(String),
50 Multiple(Vec<String>),
51}
52
53impl Audience {
54 pub fn contains(&self, value: &str) -> bool {
56 match self {
57 Audience::None => false,
58 Audience::Single(s) => s == value,
59 Audience::Multiple(v) => v.iter().any(|s| s == value),
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct FederatedClaims {
67 #[serde(default)]
69 pub connector_id: Option<String>,
70 #[serde(default)]
72 pub user_id: Option<String>,
73}
74
75impl IdentityToken {
76 pub fn from_jwt(token: &str) -> Result<Self> {
78 let parts: Vec<&str> = token.split('.').collect();
80 if parts.len() != 3 {
81 return Err(Error::Token("invalid JWT format".to_string()));
82 }
83
84 let payload = URL_SAFE_NO_PAD
86 .decode(parts[1])
87 .map_err(|e| Error::Token(format!("failed to decode payload: {}", e)))?;
88
89 let claims: TokenClaims = serde_json::from_slice(&payload)
91 .map_err(|e| Error::Token(format!("failed to parse claims: {}", e)))?;
92
93 Ok(Self {
94 raw: token.to_string(),
95 claims,
96 })
97 }
98
99 pub async fn detect_ambient() -> Result<Option<Self>> {
104 match Detector::new().detect("sigstore").await {
105 Ok(Some(token)) => Self::from_jwt(token.reveal()).map(Some),
106 Ok(None) => Ok(None),
107 Err(e) => Err(Error::Token(format!(
108 "failed to detect ambient credentials: {}",
109 e
110 ))),
111 }
112 }
113
114 pub fn new(token: impl Into<String>) -> Self {
116 let raw = token.into();
117 let claims = Self::parse_claims(&raw).unwrap_or_else(|_| TokenClaims {
119 iss: String::new(),
120 sub: String::new(),
121 aud: Audience::None,
122 exp: 0,
123 iat: 0,
124 email: None,
125 email_verified: None,
126 federated_claims: None,
127 });
128 Self { raw, claims }
129 }
130
131 fn parse_claims(token: &str) -> Result<TokenClaims> {
132 let parts: Vec<&str> = token.split('.').collect();
133 if parts.len() != 3 {
134 return Err(Error::Token("invalid JWT format".to_string()));
135 }
136 let payload = URL_SAFE_NO_PAD
137 .decode(parts[1])
138 .map_err(|e| Error::Token(format!("failed to decode payload: {}", e)))?;
139 serde_json::from_slice(&payload)
140 .map_err(|e| Error::Token(format!("failed to parse claims: {}", e)))
141 }
142
143 pub fn raw(&self) -> &str {
145 &self.raw
146 }
147
148 pub fn token(&self) -> &str {
150 &self.raw
151 }
152
153 pub fn issuer(&self) -> &str {
155 &self.claims.iss
156 }
157
158 pub fn subject(&self) -> &str {
160 &self.claims.sub
161 }
162
163 pub fn email(&self) -> Option<&str> {
165 self.claims.email.as_deref()
166 }
167
168 pub fn email_verified(&self) -> bool {
170 self.claims.email_verified.unwrap_or(false)
171 }
172
173 pub fn expiration(&self) -> u64 {
175 self.claims.exp
176 }
177
178 pub fn is_expired(&self) -> bool {
180 let now = std::time::SystemTime::now()
181 .duration_since(std::time::UNIX_EPOCH)
182 .unwrap_or_default()
183 .as_secs();
184 self.claims.exp < now
185 }
186
187 pub fn claims(&self) -> &TokenClaims {
189 &self.claims
190 }
191
192 pub fn identity(&self) -> &str {
194 self.claims.email.as_deref().unwrap_or(&self.claims.sub)
195 }
196}
197
/// Issuer URL constants.
///
/// NOTE(review): presumably meant to be compared with the `iss` claim of a
/// token; confirm against callers.
pub mod issuers {
    /// Sigstore OAuth issuer URL.
    pub const SIGSTORE_OAUTH: &str = "https://oauth2.sigstore.dev/auth";
    /// GitHub Actions token issuer URL.
    pub const GITHUB_ACTIONS: &str = "https://token.actions.githubusercontent.com";
    /// Google accounts issuer URL.
    pub const GOOGLE: &str = "https://accounts.google.com";
    /// Microsoft online login issuer URL.
    pub const MICROSOFT: &str = "https://login.microsoftonline.com";
    /// GitLab issuer URL.
    pub const GITLAB: &str = "https://gitlab.com";
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// `Audience::contains` matches exactly the stored audience strings.
    #[test]
    fn test_audience_contains() {
        let single = Audience::Single(String::from("test"));
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let multiple = Audience::Multiple(vec![String::from("a"), String::from("b")]);
        for present in ["a", "b"] {
            assert!(multiple.contains(present));
        }
        assert!(!multiple.contains("c"));
    }

    /// A hand-assembled, unsigned JWT round-trips through `from_jwt`.
    #[test]
    fn test_parse_jwt() {
        let claims_json = r#"{"iss":"https://test.com","sub":"user123","exp":9999999999,"email":"test@example.com"}"#;
        // header.payload.signature; the signature is never inspected by the parser.
        let jwt = [
            URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#),
            URL_SAFE_NO_PAD.encode(claims_json),
            String::from("signature"),
        ]
        .join(".");

        let parsed = IdentityToken::from_jwt(&jwt).unwrap();
        assert_eq!(parsed.issuer(), "https://test.com");
        assert_eq!(parsed.subject(), "user123");
        assert_eq!(parsed.email(), Some("test@example.com"));
        assert!(!parsed.is_expired());
    }
}