1use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery, WorkspaceId};
2
3use crate::auto_refresh::AutoRefresh;
4use crate::oidc_refresher::{OidcProvider, OidcRefresher};
5use crate::token_store::{NoStore, TokenStore};
6use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken};
7
8pub struct OidcFederationStrategy<P, S = NoStore> {
50 inner: AutoRefresh<OidcRefresher<P>, S>,
51 expected_workspace: WorkspaceId,
52}
53
54impl<P: OidcProvider> OidcFederationStrategy<P> {
55 pub fn new(
60 region: Region,
61 workspace_id: WorkspaceId,
62 oidc_provider: P,
63 ) -> Result<Self, AuthError> {
64 Self::builder(region, workspace_id, oidc_provider).build()
65 }
66
67 pub fn builder(
69 region: Region,
70 workspace_id: WorkspaceId,
71 oidc_provider: P,
72 ) -> OidcFederationStrategyBuilder<P> {
73 OidcFederationStrategyBuilder {
74 region,
75 workspace_id,
76 oidc_provider,
77 base_url_override: None,
78 token_store: NoStore,
79 }
80 }
81}
82
83impl<P: OidcProvider, S: TokenStore> AuthStrategy for &OidcFederationStrategy<P, S> {
84 async fn get_token(self) -> Result<ServiceToken, AuthError> {
85 let token: ServiceToken = self.inner.get_token().await?;
86 let token_workspace = *token.workspace_id()?;
87 if token_workspace != self.expected_workspace {
88 return Err(AuthError::WorkspaceMismatch {
89 expected_workspace: self.expected_workspace,
90 token_workspace,
91 });
92 }
93 Ok(token)
94 }
95}
96
97pub struct OidcFederationStrategyBuilder<P, S = NoStore> {
101 region: Region,
102 workspace_id: WorkspaceId,
103 oidc_provider: P,
104 base_url_override: Option<url::Url>,
105 token_store: S,
106}
107
108impl<P, S> OidcFederationStrategyBuilder<P, S> {
109 #[cfg(any(test, feature = "test-utils"))]
113 pub fn base_url(mut self, url: url::Url) -> Self {
114 self.base_url_override = Some(url);
115 self
116 }
117
118 pub fn with_token_store<T: TokenStore>(self, store: T) -> OidcFederationStrategyBuilder<P, T> {
131 OidcFederationStrategyBuilder {
132 region: self.region,
133 workspace_id: self.workspace_id,
134 oidc_provider: self.oidc_provider,
135 base_url_override: self.base_url_override,
136 token_store: store,
137 }
138 }
139}
140
141impl<P: OidcProvider, S: TokenStore> OidcFederationStrategyBuilder<P, S> {
142 pub fn build(self) -> Result<OidcFederationStrategy<P, S>, AuthError> {
147 let base_url = match self.base_url_override {
148 Some(url) => url,
149 None => crate::cts_base_url_from_env()?
150 .unwrap_or(CtsServiceDiscovery::endpoint(self.region)?),
151 };
152 let expected_workspace = self.workspace_id;
153 let refresher = OidcRefresher::new(
154 self.oidc_provider,
155 self.workspace_id,
156 ensure_trailing_slash(base_url),
157 );
158 Ok(OidcFederationStrategy {
159 inner: AutoRefresh::with_store(refresher, self.token_store),
160 expected_workspace,
161 })
162 }
163}
164
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use cts_common::Region;
    use mocktail::prelude::*;

    use super::*;
    use crate::oidc_refresher::OidcProviderFn;
    use crate::{InMemoryTokenStore, SecretToken, Token, TokenStore};

    /// Workspace the strategy under test is configured to expect.
    const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
    /// A different workspace, used to provoke mismatches.
    const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_secs()
    }

    /// Mint a JWT whose `workspace` claim is `workspace`, issued now and
    /// expiring one hour later, signed with a fixed test secret.
    fn jwt_with_workspace(workspace: &str) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};

        let issued_at = unix_now();
        let claims = serde_json::json!({
            "iss": "https://cts.example.com/",
            "sub": "CS|test-user",
            "aud": "test-audience",
            "iat": issued_at,
            "exp": issued_at + 3600,
            "workspace": workspace,
            "scope": "",
        });
        let signing_key = EncodingKey::from_secret(b"test-secret");
        encode(&Header::default(), &claims, &signing_key).expect("JWT encode")
    }

    /// Start a mock server named `name` with `mocks` installed.
    async fn serve(name: &'static str, mocks: MockSet) -> MockServer {
        let server = MockServer::new_http(name).with_mocks(mocks);
        server.start().await.expect("mock server start");
        server
    }

    /// Start a mock CTS whose `/api/authorise` answers with a JWT for `workspace`.
    async fn start_mock_server_returning_jwt(workspace: &str) -> MockServer {
        let access_token = jwt_with_workspace(workspace);
        let mut mocks = MockSet::new();
        mocks.mock(move |when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": access_token, "expiry": 3600 }));
        });
        serve("oidc-federation-strategy-workspace-test", mocks).await
    }

    fn test_region() -> Region {
        Region::aws("ap-southeast-2").expect("region parses")
    }

    /// OIDC provider stub that always yields the placeholder token
    /// `header.payload.signature`.
    fn provider() -> OidcProviderFn<impl Fn() -> std::future::Ready<Result<SecretToken, AuthError>>>
    {
        OidcProviderFn::new(|| {
            let secret = SecretToken::new(String::from("header.payload.signature"));
            std::future::ready(Ok(secret))
        })
    }

    /// A token minted for the expected workspace is returned as-is.
    #[tokio::test]
    async fn returns_token_when_workspace_matches() {
        let server = start_mock_server_returning_jwt(EXPECTED_WS).await;

        let strategy =
            OidcFederationStrategy::builder(test_region(), EXPECTED_WS.parse().unwrap(), provider())
                .base_url(server.url(""))
                .build()
                .expect("builder");

        let token = (&strategy).get_token().await.expect("get_token");
        let workspace = token.workspace_id().expect("workspace_id");
        assert_eq!(
            workspace.as_str(),
            EXPECTED_WS,
            "happy-path token should carry the expected workspace",
        );
    }

    /// A freshly fetched token for another workspace is rejected, and the
    /// error reports both the expected and the actual workspace.
    #[tokio::test]
    async fn errors_when_token_workspace_differs() {
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        let outcome = (&strategy).get_token().await;
        let (expected_workspace, token_workspace) = match outcome.expect_err("expected mismatch") {
            AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            } => (expected_workspace, token_workspace),
            other => panic!("expected WorkspaceMismatch, got {other:?}"),
        };
        assert_eq!(expected_workspace.as_str(), EXPECTED_WS);
        assert_eq!(token_workspace.as_str(), TOKEN_WS);
    }

    /// An access token that is not a JWT surfaces as `AuthError::InvalidToken`.
    #[tokio::test]
    async fn errors_with_invalid_token_when_jwt_malformed() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": "not-a-jwt", "expiry": 3600 }));
        });
        let server = serve("oidc-federation-strategy-malformed-test", mocks).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected invalid-token error");
        assert!(
            matches!(err, AuthError::InvalidToken(_)),
            "expected InvalidToken, got {err:?}",
        );
    }

    /// A still-valid token already sitting in the store is also subject to
    /// the workspace check.
    #[tokio::test]
    async fn rejects_stored_token_for_different_workspace() {
        // The authorise endpoint always fails with 500, so the request has to
        // be satisfied from the pre-populated store.
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "store must satisfy the request"}));
        });
        let server = serve("oidc-federation-strategy-store-mismatch-test", mocks).await;

        let stored = Token {
            access_token: SecretToken::new(jwt_with_workspace(TOKEN_WS)),
            token_type: "Bearer".to_string(),
            expires_at: unix_now() + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let store = Arc::new(InMemoryTokenStore::new());
        store.save(&stored).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .with_token_store(Arc::clone(&store))
        .build()
        .expect("builder");

        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch from stored token");
        assert!(
            matches!(err, AuthError::WorkspaceMismatch { .. }),
            "expected WorkspaceMismatch, got {err:?}",
        );
    }

    /// The mismatch is reported on every call, not just the first one.
    #[tokio::test]
    async fn errors_on_each_subsequent_get_token_call() {
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;

        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");

        for call in [1, 2] {
            let Err(err) = (&strategy).get_token().await else {
                panic!("call {call}: expected Err, got Ok");
            };
            assert!(
                matches!(err, AuthError::WorkspaceMismatch { .. }),
                "call {call}: expected WorkspaceMismatch, got {err:?}",
            );
        }
    }
}