1use crate::Codec;
17use crate::errors::{JwtError, JwtOperation};
18use crate::jwt::JwtClaims;
19use crate::jwt::validation_result::JwtValidationResult;
20
21use std::sync::Arc;
22
23use tracing::{debug, warn};
24use webgates_core::accounts::Account;
25use webgates_core::authz::access_hierarchy::AccessHierarchy;
26
/// Validates encoded JWTs by decoding them with a shared codec and comparing the decoded
/// issuer against the issuer this service was configured with.
///
/// `Clone` is implemented by hand instead of derived. A derived impl would require
/// `C: Clone`, but the codec is held behind an [`Arc`]. Cloning the service therefore only
/// bumps the reference count and copies the issuer string, whatever `C` is.
#[derive(Debug)]
pub struct JwtValidationService<C> {
    /// Codec used to decode raw token bytes into claims; shared between clones.
    codec: Arc<C>,
    /// Issuer value a decoded token must carry to be considered valid.
    expected_issuer: String,
}

impl<C> Clone for JwtValidationService<C> {
    fn clone(&self) -> Self {
        Self {
            // Shares the same codec instance rather than cloning `C` itself.
            codec: Arc::clone(&self.codec),
            expected_issuer: self.expected_issuer.clone(),
        }
    }
}
40
41pub trait JwtClaimsVerifier<T>: Clone {
48 fn verify_token(&self, token_value: &str) -> std::result::Result<T, JwtError>;
50}
51
52impl<C> JwtValidationService<C> {
53 pub fn new(codec: Arc<C>, expected_issuer: &str) -> Self {
58 Self {
59 codec,
60 expected_issuer: expected_issuer.to_owned(),
61 }
62 }
63}
64
65impl<C, R, G> JwtValidationService<C>
66where
67 C: Codec<Payload = JwtClaims<Account<R, G>>>,
68 R: AccessHierarchy + Eq,
69 G: Eq + Clone,
70{
71 pub fn validate_token(&self, token_value: &str) -> JwtValidationResult<Account<R, G>> {
82 let jwt = match self.codec.decode(token_value.as_bytes()) {
83 Ok(jwt) => jwt,
84 Err(error) => {
85 debug!(error = %error, "JWT token decoding failed");
86 return JwtValidationResult::InvalidToken;
87 }
88 };
89
90 debug!(
91 account_id = %jwt.custom_claims.account_id,
92 issuer = %jwt.registered_claims.issuer,
93 "JWT token decoded successfully"
94 );
95
96 if !jwt.has_issuer(&self.expected_issuer) {
97 warn!(
98 expected_issuer = %self.expected_issuer,
99 actual_issuer = %jwt.registered_claims.issuer,
100 account_id = %jwt.custom_claims.account_id,
101 "JWT issuer validation failed"
102 );
103 return JwtValidationResult::InvalidIssuer {
104 expected: self.expected_issuer.clone(),
105 actual: jwt.registered_claims.issuer,
106 };
107 }
108
109 JwtValidationResult::Valid(jwt)
110 }
111}
112
113impl<C, R, G> JwtClaimsVerifier<JwtClaims<Account<R, G>>> for JwtValidationService<C>
114where
115 C: Codec<Payload = JwtClaims<Account<R, G>>> + Clone,
116 R: AccessHierarchy + Eq,
117 G: Eq + Clone,
118{
119 fn verify_token(
120 &self,
121 token_value: &str,
122 ) -> std::result::Result<JwtClaims<Account<R, G>>, JwtError> {
123 match self.validate_token(token_value) {
124 JwtValidationResult::Valid(jwt) => Ok(jwt),
125 JwtValidationResult::InvalidToken => Err(JwtError::processing(
126 JwtOperation::Validate,
127 "token verification failed",
128 )),
129 JwtValidationResult::InvalidIssuer { expected, actual } => Err(JwtError::processing(
130 JwtOperation::Validate,
131 format!("token issuer mismatch: expected `{expected}`, got `{actual}`"),
132 )),
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{JwtError, JwtOperation};
    use crate::jwt::RegisteredClaims;

    use std::sync::Arc;

    use uuid::Uuid;
    use webgates_core::groups::Group;
    use webgates_core::permissions::Permissions;
    use webgates_core::roles::Role;

    /// Test codec that ignores its input and either fails or returns fixed claims
    /// carrying `mock_issuer` as the issuer.
    #[derive(Clone)]
    struct MockCodec {
        should_fail_decode: bool,
        mock_issuer: String,
    }

    impl MockCodec {
        /// Codec that decodes successfully with issuer `"test-issuer"`.
        fn new() -> Self {
            Self {
                should_fail_decode: false,
                mock_issuer: "test-issuer".to_string(),
            }
        }

        /// Codec whose `decode` always fails.
        fn with_decode_failure() -> Self {
            Self {
                should_fail_decode: true,
                mock_issuer: String::new(),
            }
        }

        /// Codec that decodes successfully with issuer `"different-issuer"`.
        fn with_different_issuer() -> Self {
            Self {
                should_fail_decode: false,
                mock_issuer: "different-issuer".to_string(),
            }
        }
    }

    impl Codec for MockCodec {
        type Payload = JwtClaims<Account<Role, Group>>;

        fn decode(&self, _encoded_value: &[u8]) -> crate::Result<Self::Payload> {
            if self.should_fail_decode {
                return Err(crate::Error::Jwt(JwtError::processing(
                    JwtOperation::Decode,
                    "Mock decode failure",
                )));
            }

            let account = Account {
                account_id: Uuid::now_v7(),
                user_id: "test_user".to_string(),
                roles: vec![Role::User],
                groups: vec![Group::new("engineering")],
                permissions: Permissions::new(),
            };

            let registered_claims = RegisteredClaims {
                issuer: self.mock_issuer.clone(),
                subject: Some("test".to_string()),
                audience: None,
                expiration_time: 9_999_999_999,
                not_before_time: None,
                issued_at_time: 1_000_000_000,
                jwt_id: None,
                session_id: None,
            };

            Ok(JwtClaims {
                custom_claims: account,
                registered_claims,
            })
        }

        fn encode(&self, _payload: &Self::Payload) -> crate::Result<Vec<u8>> {
            Ok(Vec::new())
        }
    }

    #[test]
    fn validation_service_returns_valid_result_for_matching_issuer() {
        let codec = Arc::new(MockCodec::new());
        let service = JwtValidationService::new(codec, "test-issuer");

        let result = service.validate_token("valid-token");

        match result {
            JwtValidationResult::Valid(jwt) => {
                assert_eq!(jwt.custom_claims.user_id, "test_user");
                assert_eq!(jwt.registered_claims.issuer, "test-issuer");
            }
            other => panic!("Expected valid token result, got {other:?}"),
        }
    }

    #[test]
    fn validation_service_returns_invalid_token_when_decoding_fails() {
        let codec = Arc::new(MockCodec::with_decode_failure());
        let service = JwtValidationService::new(codec, "test-issuer");

        let result = service.validate_token("invalid-token");

        assert!(matches!(result, JwtValidationResult::InvalidToken));
    }

    #[test]
    fn validation_service_returns_invalid_issuer_when_issuer_differs() {
        let codec = Arc::new(MockCodec::with_different_issuer());
        let service = JwtValidationService::new(codec, "expected-issuer");

        let result = service.validate_token("valid-token");

        match result {
            JwtValidationResult::InvalidIssuer { expected, actual } => {
                assert_eq!(expected, "expected-issuer");
                assert_eq!(actual, "different-issuer");
            }
            other => panic!("Expected invalid issuer result, got {other:?}"),
        }
    }

    // The tests below cover the `JwtClaimsVerifier` impl, which maps each
    // `JwtValidationResult` variant onto a `Result`.

    #[test]
    fn verifier_returns_claims_for_matching_issuer() {
        let codec = Arc::new(MockCodec::new());
        let service = JwtValidationService::new(codec, "test-issuer");

        match service.verify_token("valid-token") {
            Ok(jwt) => {
                assert_eq!(jwt.custom_claims.user_id, "test_user");
                assert_eq!(jwt.registered_claims.issuer, "test-issuer");
            }
            Err(_) => panic!("Expected verification to succeed"),
        }
    }

    #[test]
    fn verifier_returns_error_when_decoding_fails() {
        let codec = Arc::new(MockCodec::with_decode_failure());
        let service = JwtValidationService::new(codec, "test-issuer");

        assert!(service.verify_token("invalid-token").is_err());
    }

    #[test]
    fn verifier_returns_error_when_issuer_differs() {
        let codec = Arc::new(MockCodec::with_different_issuer());
        let service = JwtValidationService::new(codec, "expected-issuer");

        assert!(service.verify_token("valid-token").is_err());
    }
}