wae_authentication/jwt/
service.rs1use crate::jwt::{AccessTokenClaims, JwtAlgorithm, JwtClaims, JwtConfig, JwtError, JwtResult, RefreshTokenClaims};
4use chrono::Utc;
5use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
6use serde::{Serialize, de::DeserializeOwned};
7
8#[derive(Debug, Clone)]
10pub struct JwtService {
11 config: JwtConfig,
12 encoding_key: EncodingKey,
13 decoding_key: DecodingKey,
14}
15
16impl JwtService {
17 pub fn new(config: JwtConfig) -> JwtResult<Self> {
22 let (encoding_key, decoding_key) = Self::create_keys(&config)?;
23 Ok(Self { config, encoding_key, decoding_key })
24 }
25
26 fn create_keys(config: &JwtConfig) -> JwtResult<(EncodingKey, DecodingKey)> {
27 match config.algorithm {
28 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
29 let encoding_key = EncodingKey::from_secret(config.secret.as_bytes());
30 let decoding_key = DecodingKey::from_secret(config.secret.as_bytes());
31 Ok((encoding_key, decoding_key))
32 }
33 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 | JwtAlgorithm::ES256 | JwtAlgorithm::ES384 => {
34 let encoding_key =
35 EncodingKey::from_rsa_pem(config.secret.as_bytes()).map_err(|e| JwtError::KeyError(e.to_string()))?;
36
37 let public_key = config
38 .public_key
39 .as_ref()
40 .ok_or_else(|| JwtError::KeyError("public key is required for asymmetric algorithms".into()))?;
41
42 let decoding_key =
43 DecodingKey::from_rsa_pem(public_key.as_bytes()).map_err(|e| JwtError::KeyError(e.to_string()))?;
44
45 Ok((encoding_key, decoding_key))
46 }
47 }
48 }
49
50 pub fn generate_token<T: Serialize>(&self, claims: &T) -> JwtResult<String> {
55 let header = Header::new(self.config.algorithm.into());
56 encode(&header, claims, &self.encoding_key).map_err(Into::into)
57 }
58
59 pub fn verify_token<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
64 let mut validation = Validation::new(self.config.algorithm.into());
65
66 if self.config.validate_issuer {
67 if let Some(ref issuer) = self.config.issuer {
68 validation.set_issuer(&[issuer]);
69 }
70 }
71
72 if self.config.validate_audience {
73 if let Some(ref audience) = self.config.audience {
74 validation.set_audience(&[audience]);
75 }
76 }
77
78 validation.leeway = self.config.leeway_seconds as u64;
79
80 let token_data = decode::<T>(token, &self.decoding_key, &validation)?;
81 Ok(token_data.claims)
82 }
83
84 pub fn generate_access_token(&self, user_id: &str, claims: AccessTokenClaims) -> JwtResult<String> {
90 let mut jwt_claims = JwtClaims::new(user_id, self.config.access_token_expires_in());
91
92 if let Some(ref issuer) = self.config.issuer {
93 jwt_claims = jwt_claims.with_issuer(issuer.clone());
94 }
95
96 if let Some(ref audience) = self.config.audience {
97 jwt_claims = jwt_claims.with_audience(audience.clone());
98 }
99
100 jwt_claims.custom.insert("user_id".to_string(), serde_json::to_value(&claims.user_id).unwrap());
101
102 if let Some(ref username) = claims.username {
103 jwt_claims.custom.insert("username".to_string(), serde_json::to_value(username).unwrap());
104 }
105
106 if !claims.roles.is_empty() {
107 jwt_claims.custom.insert("roles".to_string(), serde_json::to_value(&claims.roles).unwrap());
108 }
109
110 if !claims.permissions.is_empty() {
111 jwt_claims.custom.insert("permissions".to_string(), serde_json::to_value(&claims.permissions).unwrap());
112 }
113
114 if let Some(ref session_id) = claims.session_id {
115 jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
116 }
117
118 self.generate_token(&jwt_claims)
119 }
120
121 pub fn verify_access_token(&self, token: &str) -> JwtResult<JwtClaims> {
126 self.verify_token(token)
127 }
128
129 pub fn generate_refresh_token(&self, user_id: &str, session_id: &str, version: u32) -> JwtResult<String> {
136 let mut jwt_claims = JwtClaims::new(user_id, self.config.refresh_token_expires_in());
137
138 if let Some(ref issuer) = self.config.issuer {
139 jwt_claims = jwt_claims.with_issuer(issuer.clone());
140 }
141
142 jwt_claims.custom.insert("session_id".to_string(), serde_json::to_value(session_id).unwrap());
143 jwt_claims.custom.insert("version".to_string(), serde_json::to_value(version).unwrap());
144 jwt_claims.custom.insert("token_type".to_string(), serde_json::to_value("refresh").unwrap());
145
146 self.generate_token(&jwt_claims)
147 }
148
149 pub fn verify_refresh_token(&self, token: &str) -> JwtResult<RefreshTokenClaims> {
154 let claims: JwtClaims = self.verify_token(token)?;
155
156 let token_type: Option<String> = claims.custom.get("token_type").and_then(|v| serde_json::from_value(v.clone()).ok());
157
158 if token_type.as_deref() != Some("refresh") {
159 return Err(JwtError::InvalidToken("not a refresh token".into()));
160 }
161
162 let session_id: String = claims
163 .custom
164 .get("session_id")
165 .and_then(|v| serde_json::from_value(v.clone()).ok())
166 .ok_or_else(|| JwtError::MissingClaim("session_id".into()))?;
167
168 let version: u32 = claims.custom.get("version").and_then(|v| serde_json::from_value(v.clone()).ok()).unwrap_or(0);
169
170 Ok(RefreshTokenClaims { user_id: claims.sub, session_id, version })
171 }
172
173 pub fn decode_unchecked<T: DeserializeOwned>(&self, token: &str) -> JwtResult<T> {
178 let mut validation = Validation::new(self.config.algorithm.into());
179 validation.insecure_disable_signature_validation();
180
181 let token_data = decode::<T>(token, &self.decoding_key, &validation)?;
182 Ok(token_data.claims)
183 }
184
185 pub fn get_remaining_ttl(&self, token: &str) -> JwtResult<i64> {
190 let claims: JwtClaims = self.verify_token(token)?;
191 let now = Utc::now().timestamp();
192 let remaining = claims.exp - now;
193 Ok(remaining.max(0))
194 }
195
196 pub fn is_token_expiring_soon(&self, token: &str, threshold_seconds: i64) -> JwtResult<bool> {
202 let remaining = self.get_remaining_ttl(token)?;
203 Ok(remaining < threshold_seconds)
204 }
205
206 pub fn config(&self) -> &JwtConfig {
208 &self.config
209 }
210}
211
/// Access and refresh tokens issued together, plus their lifetimes.
#[derive(Debug, Clone)]
pub struct TokenPair {
    /// Signed access token.
    pub access_token: String,
    /// Signed refresh token.
    pub refresh_token: String,
    /// Authorization scheme for the access token (`"Bearer"` when built by the service).
    pub token_type: String,
    /// Access token lifetime — presumably in seconds; TODO confirm against `JwtConfig`.
    pub expires_in: i64,
    /// Refresh token lifetime — presumably in seconds; TODO confirm against `JwtConfig`.
    pub refresh_expires_in: i64,
}
230
231impl JwtService {
232 pub fn generate_token_pair(
239 &self,
240 user_id: &str,
241 access_claims: AccessTokenClaims,
242 session_id: &str,
243 ) -> JwtResult<TokenPair> {
244 let access_token = self.generate_access_token(user_id, access_claims)?;
245 let refresh_token = self.generate_refresh_token(user_id, session_id, 0)?;
246
247 Ok(TokenPair {
248 access_token,
249 refresh_token,
250 token_type: "Bearer".to_string(),
251 expires_in: self.config.access_token_expires_in(),
252 refresh_expires_in: self.config.refresh_token_expires_in(),
253 })
254 }
255}
256
257pub fn default_jwt_service() -> JwtResult<JwtService> {
259 JwtService::new(JwtConfig::default())
260}
261
262pub fn jwt_service_with_secret(secret: impl Into<String>) -> JwtResult<JwtService> {
264 JwtService::new(JwtConfig::new(secret))
265}