wae_authentication/jwt/
service.rs1use crate::jwt::{
4 AccessTokenClaims, JwtAlgorithm, JwtClaims, JwtConfig, JwtHeader, RefreshTokenClaims, decode_jwt, encode_jwt,
5};
6use chrono::Utc;
7use serde::{Serialize, de::DeserializeOwned};
8use wae_types::WaeError;
9
10pub type JwtResult<T> = Result<T, WaeError>;
12
13#[derive(Debug, Clone)]
15pub struct JwtService {
16 config: JwtConfig,
17}
18
19impl JwtService {
20 pub fn new(config: JwtConfig) -> JwtResult<Self> {
25 Ok(Self { config })
26 }
27
28 fn algorithm_name(&self) -> &'static str {
30 match self.config.algorithm {
31 JwtAlgorithm::HS256 => "HS256",
32 JwtAlgorithm::HS384 => "HS384",
33 JwtAlgorithm::HS512 => "HS512",
34 _ => unimplemented!("Only HS256, HS384, HS512 are supported in custom implementation"),
35 }
36 }
37
38 pub fn generate_token<T: Serialize>(&self, claims: &T) -> JwtResult<String> {
43 let header = JwtHeader::new(self.algorithm_name());
44 let secret = self.config.secret.as_bytes();
45 encode_jwt(&header, claims, secret).map_err(|e| e.into())
46 }
47
48 pub fn verify_token<T: DeserializeOwned + 'static>(&self, token: &str) -> JwtResult<T> {
53 let secret = self.config.secret.as_bytes();
54 let claims: T = decode_jwt(token, secret, true)?;
55
56 let now = Utc::now().timestamp();
57 let leeway = self.config.leeway_seconds;
58
59 if let Some(validated_claims) = (&claims as &dyn std::any::Any).downcast_ref::<JwtClaims>() {
60 if validated_claims.exp + leeway < now {
61 return Err(WaeError::token_expired());
62 }
63 if let Some(nbf) = validated_claims.nbf {
64 if nbf - leeway > now {
65 return Err(WaeError::token_not_valid_yet());
66 }
67 }
68 if self.config.validate_issuer {
69 if let Some(ref issuer) = self.config.issuer {
70 if validated_claims.iss.as_ref() != Some(issuer) {
71 return Err(WaeError::invalid_claim("issuer"));
72 }
73 }
74 }
75 if self.config.validate_audience {
76 if let Some(ref audience) = self.config.audience {
77 if validated_claims.aud.as_ref() != Some(audience) {
78 return Err(WaeError::invalid_audience());
79 }
80 }
81 }
82 }
83
84 Ok(claims)
85 }
86
87 pub fn generate_access_token(&self, user_id: &str, claims: AccessTokenClaims) -> JwtResult<String> {
93 let mut jwt_claims = JwtClaims::new(user_id, self.config.access_token_expires_in());
94
95 if let Some(ref issuer) = self.config.issuer {
96 jwt_claims = jwt_claims.with_issuer(issuer.clone());
97 }
98
99 if let Some(ref audience) = self.config.audience {
100 jwt_claims = jwt_claims.with_audience(audience.clone());
101 }
102
103 jwt_claims.custom.insert("user_id".to_string(), serde_json::to_value(&claims.user_id).unwrap());
104
105 if let Some(ref username) = claims.username {
106 jwt_claims.custom.insert("username".to_string(), serde_json::to_value(username).unwrap());
107 }
108
109 if !claims.roles.is_empty() {
110 jwt_claims.custom.insert("roles".to_string(), serde_json::to_value(&claims.roles).unwrap());
111 }
112
113 if !claims.permissions.is_empty() {
114 jwt_claims.custom.insert("permissions".to_string(), serde_json::to_value(&claims.permissions).unwrap());
115 }
116
117 if let Some(ref session_id) = claims.session_id {
118 jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
119 }
120
121 self.generate_token(&jwt_claims)
122 }
123
124 pub fn verify_access_token(&self, token: &str) -> JwtResult<JwtClaims> {
129 self.verify_token(token)
130 }
131
132 pub fn generate_refresh_token(&self, user_id: &str, session_id: &str, version: u32) -> JwtResult<String> {
139 let mut jwt_claims = JwtClaims::new(user_id, self.config.refresh_token_expires_in());
140
141 if let Some(ref issuer) = self.config.issuer {
142 jwt_claims = jwt_claims.with_issuer(issuer.clone());
143 }
144
145 jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
146 jwt_claims.custom.insert("version".to_string(), serde_json::to_value(version).unwrap());
147 jwt_claims.custom.insert("token_type".to_string(), serde_json::to_value("refresh").unwrap());
148
149 self.generate_token(&jwt_claims)
150 }
151
152 pub fn verify_refresh_token(&self, token: &str) -> JwtResult<RefreshTokenClaims> {
157 let claims: JwtClaims = self.verify_token(token)?;
158
159 let token_type: Option<String> = claims.custom.get("token_type").and_then(|v| serde_json::from_value(v.clone()).ok());
160
161 if token_type.as_deref() != Some("refresh") {
162 return Err(WaeError::invalid_token("not a refresh token"));
163 }
164
165 let session_id: String = claims
166 .custom
167 .get("session_id")
168 .and_then(|v| serde_json::from_value(v.clone()).ok())
169 .ok_or_else(|| WaeError::missing_claim("session_id"))?;
170
171 let version: u32 = claims.custom.get("version").and_then(|v| serde_json::from_value(v.clone()).ok()).unwrap_or(0);
172
173 Ok(RefreshTokenClaims { user_id: claims.sub, session_id, version })
174 }
175
176 pub fn decode_unchecked<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
181 let secret = self.config.secret.as_bytes();
182 decode_jwt(token, secret, false).map_err(|e| e.into())
183 }
184
185 pub fn get_remaining_ttl(&self, token: &str) -> JwtResult<i64> {
190 let claims: JwtClaims = self.verify_token(token)?;
191 let now = Utc::now().timestamp();
192 let remaining = claims.exp - now;
193 Ok(remaining.max(0))
194 }
195
196 pub fn is_token_expiring_soon(&self, token: &str, threshold_seconds: i64) -> JwtResult<bool> {
202 let remaining = self.get_remaining_ttl(token)?;
203 Ok(remaining < threshold_seconds)
204 }
205
206 pub fn config(&self) -> &JwtConfig {
208 &self.config
209 }
210}
211
/// Access token plus refresh token, as returned by
/// [`JwtService::generate_token_pair`] and [`JwtService::rotate_token_pair`].
///
/// NOTE(review): the derived `Debug` prints both tokens verbatim — avoid logging
/// this value with `{:?}`.
#[derive(Clone, Debug)]
pub struct TokenPair {
    /// Signed access token.
    pub access_token: String,
    /// Signed refresh token.
    pub refresh_token: String,
    /// Authorization scheme for the access token (`"Bearer"` when built by `JwtService`).
    pub token_type: String,
    /// Access-token lifetime, as reported by `JwtConfig::access_token_expires_in()`.
    pub expires_in: i64,
    /// Refresh-token lifetime, as reported by `JwtConfig::refresh_token_expires_in()`.
    pub refresh_expires_in: i64,
}
230
231impl JwtService {
232 pub fn generate_token_pair(
239 &self,
240 user_id: &str,
241 access_claims: AccessTokenClaims,
242 session_id: &str,
243 ) -> JwtResult<TokenPair> {
244 let access_token = self.generate_access_token(user_id, access_claims)?;
245 let refresh_token = self.generate_refresh_token(user_id, session_id, 0)?;
246
247 Ok(TokenPair {
248 access_token,
249 refresh_token,
250 token_type: "Bearer".to_string(),
251 expires_in: self.config.access_token_expires_in(),
252 refresh_expires_in: self.config.refresh_token_expires_in(),
253 })
254 }
255
256 pub fn rotate_token_pair(&self, refresh_token: &str, new_access_claims: Option<AccessTokenClaims>) -> JwtResult<TokenPair> {
264 let refresh_claims = self.verify_refresh_token(refresh_token)?;
265
266 let new_version = refresh_claims.version + 1;
267 let access_claims = new_access_claims.unwrap_or_else(|| {
268 AccessTokenClaims::new(refresh_claims.user_id.clone()).with_session_id(refresh_claims.session_id.clone())
269 });
270
271 let access_token = self.generate_access_token(&refresh_claims.user_id, access_claims)?;
272 let new_refresh_token =
273 self.generate_refresh_token(&refresh_claims.user_id, &refresh_claims.session_id, new_version)?;
274
275 Ok(TokenPair {
276 access_token,
277 refresh_token: new_refresh_token,
278 token_type: "Bearer".to_string(),
279 expires_in: self.config.access_token_expires_in(),
280 refresh_expires_in: self.config.refresh_token_expires_in(),
281 })
282 }
283}
284
285pub fn default_jwt_service() -> JwtResult<JwtService> {
287 JwtService::new(JwtConfig::default())
288}
289
290pub fn jwt_service_with_secret(secret: impl Into<String>) -> JwtResult<JwtService> {
292 JwtService::new(JwtConfig::new(secret))
293}