Skip to main content

tapis_authenticator/
client.rs

1use crate::apis::{
2    admin_api, clients_api, configuration, health_check_api, metadata_api, profiles_api,
3    tokens_api, Error,
4};
5use crate::models;
6use http::header::{HeaderMap, HeaderValue};
7use reqwest::{Client, Request, Response};
8use reqwest_middleware::{ClientBuilder, Middleware, Next, Result as MiddlewareResult};
9use std::sync::Arc;
10use tapis_core::TokenProvider;
11
12tokio::task_local! {
13    /// Extra headers to inject into every request within a [`with_headers`] scope.
14    static EXTRA_HEADERS: HeaderMap;
15}
16
17/// Run an async call with additional HTTP headers injected into every request
18/// made within the future `f`. Headers are scoped to this task only, so
19/// concurrent calls with different headers are safe.
20pub async fn with_headers<F, T>(headers: HeaderMap, f: F) -> T
21where
22    F: std::future::Future<Output = T>,
23{
24    EXTRA_HEADERS.scope(headers, f).await
25}
26
27#[derive(Debug)]
28struct LoggingMiddleware;
29
30#[derive(Debug)]
31struct HeaderInjectionMiddleware;
32
33#[async_trait::async_trait]
34impl Middleware for LoggingMiddleware {
35    async fn handle(
36        &self,
37        req: Request,
38        extensions: &mut http::Extensions,
39        next: Next<'_>,
40    ) -> MiddlewareResult<Response> {
41        let method = req.method().clone();
42        let url = req.url().clone();
43        println!("Tapis SDK request: {} {}", method, url);
44        next.run(req, extensions).await
45    }
46}
47
48#[async_trait::async_trait]
49impl Middleware for HeaderInjectionMiddleware {
50    async fn handle(
51        &self,
52        mut req: Request,
53        extensions: &mut http::Extensions,
54        next: Next<'_>,
55    ) -> MiddlewareResult<Response> {
56        let _ = EXTRA_HEADERS.try_with(|headers| {
57            for (k, v) in headers {
58                req.headers_mut().insert(k, v.clone());
59            }
60        });
61        next.run(req, extensions).await
62    }
63}
64
/// Check that `tracking_id` is a well-formed `X-Tapis-Tracking-ID` value.
///
/// Rules, checked in this order:
/// 1. only ASCII characters;
/// 2. at most 126 characters long;
/// 3. exactly one `.`, splitting it into `<namespace>.<unique_identifier>`;
/// 4. neither part empty (the ID may not start or end with `.`);
/// 5. the namespace uses only ASCII letters, digits and `_`;
/// 6. the unique identifier uses only ASCII letters, digits and `-`.
///
/// # Errors
/// Returns a message describing the first rule that is violated.
fn validate_tracking_id(tracking_id: &str) -> Result<(), String> {
    const MAX_LEN: usize = 126;

    if !tracking_id.is_ascii() {
        return Err("X-Tapis-Tracking-ID must be an entirely ASCII string.".to_string());
    }
    // The limit is inclusive (a 126-character ID is accepted), so the message
    // says "at most" rather than "less than".
    if tracking_id.len() > MAX_LEN {
        return Err("X-Tapis-Tracking-ID must be at most 126 characters.".to_string());
    }
    // Exactly one '.' means there is a first one and none after it.
    let (namespace, unique_id) = match tracking_id.split_once('.') {
        Some((ns, id)) if !id.contains('.') => (ns, id),
        _ => return Err("X-Tapis-Tracking-ID must contain exactly one '.' (format: <namespace>.<unique_identifier>).".to_string()),
    };
    if namespace.is_empty() || unique_id.is_empty() {
        return Err("X-Tapis-Tracking-ID cannot start or end with '.'.".to_string());
    }
    // The input is already known to be ASCII, so the ASCII-only predicates are exact.
    if !namespace.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Err("X-Tapis-Tracking-ID namespace must contain only alphanumeric characters and underscores.".to_string());
    }
    if !unique_id.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
        return Err("X-Tapis-Tracking-ID unique identifier must contain only alphanumeric characters and hyphens.".to_string());
    }
    Ok(())
}
87
88#[derive(Debug)]
89struct TrackingIdMiddleware;
90
91#[async_trait::async_trait]
92impl Middleware for TrackingIdMiddleware {
93    async fn handle(
94        &self,
95        mut req: Request,
96        extensions: &mut http::Extensions,
97        next: Next<'_>,
98    ) -> MiddlewareResult<Response> {
99        let tracking_key = req
100            .headers()
101            .keys()
102            .find(|k| {
103                let s = k.as_str();
104                s.eq_ignore_ascii_case("x-tapis-tracking-id")
105                    || s.eq_ignore_ascii_case("x_tapis_tracking_id")
106            })
107            .cloned();
108        if let Some(key) = tracking_key {
109            let tracking_id = req
110                .headers()
111                .get(&key)
112                .and_then(|v| v.to_str().ok())
113                .map(|s| s.to_owned());
114            if let Some(id) = tracking_id {
115                req.headers_mut().remove(&key);
116                validate_tracking_id(&id)
117                    .map_err(|e| reqwest_middleware::Error::Middleware(anyhow::anyhow!(e)))?;
118                let name = reqwest::header::HeaderName::from_static("x-tapis-tracking-id");
119                let value = reqwest::header::HeaderValue::from_str(&id)
120                    .map_err(|e| reqwest_middleware::Error::Middleware(anyhow::anyhow!(e)))?;
121                req.headers_mut().insert(name, value);
122            }
123        }
124        next.run(req, extensions).await
125    }
126}
127
/// Decode base64url text into raw bytes; padding is optional.
///
/// Lenient by design: `=` is ignored wherever it appears, and the
/// standard-alphabet symbols `+` and `/` are accepted as aliases for `-` and
/// `_`. Returns `None` if any other non-alphabet byte is present, or if the
/// input ends with a single dangling symbol (which cannot encode a whole byte).
fn decode_base64url(s: &str) -> Option<Vec<u8>> {
    /// Map one alphabet symbol to its 6-bit value.
    fn sextet(symbol: u8) -> Option<u8> {
        let value = match symbol {
            b'A'..=b'Z' => symbol - b'A',
            b'a'..=b'z' => symbol - b'a' + 26,
            b'0'..=b'9' => symbol - b'0' + 52,
            b'-' | b'+' => 62,
            b'_' | b'/' => 63,
            _ => return None,
        };
        Some(value)
    }

    let symbols: Vec<u8> = s.bytes().filter(|&b| b != b'=').collect();
    let mut decoded = Vec::with_capacity(symbols.len() * 3 / 4 + 1);
    for group in symbols.chunks(4) {
        // A lone trailing symbol carries only 6 bits: not enough for a byte.
        if group.len() < 2 {
            return None;
        }
        // Pack up to four sextets, big-endian, into the low 24 bits of `bits`.
        let mut bits: u32 = 0;
        for (pos, &symbol) in group.iter().enumerate() {
            bits |= u32::from(sextet(symbol)?) << (18 - 6 * pos);
        }
        // n sextets yield n - 1 whole bytes; leftover partial bits are dropped.
        let whole_bytes = group.len() - 1;
        decoded.extend_from_slice(&bits.to_be_bytes()[1..=whole_bytes]);
    }
    Some(decoded)
}
159
160/// Extract the `exp` (expiration) claim from a JWT without verifying the signature.
161fn extract_jwt_exp(token: &str) -> Option<i64> {
162    let payload_b64 = token.split('.').nth(1)?;
163    let bytes = decode_base64url(payload_b64)?;
164    let claims: serde_json::Value = serde_json::from_slice(&bytes).ok()?;
165    claims.get("exp")?.as_i64()
166}
167
168struct RefreshMiddleware {
169    token_provider: Arc<dyn TokenProvider>,
170}
171
172#[async_trait::async_trait]
173impl Middleware for RefreshMiddleware {
174    async fn handle(
175        &self,
176        mut req: Request,
177        extensions: &mut http::Extensions,
178        next: Next<'_>,
179    ) -> MiddlewareResult<Response> {
180        let is_token_endpoint = {
181            let url = req.url().as_str();
182            url.contains("/oauth2/tokens") || url.contains("/v3/tokens")
183        };
184        if !is_token_endpoint {
185            let needs_refresh = req
186                .headers()
187                .get("x-tapis-token")
188                .and_then(|v| v.to_str().ok())
189                .and_then(extract_jwt_exp)
190                .map(|exp| {
191                    let now = std::time::SystemTime::now()
192                        .duration_since(std::time::UNIX_EPOCH)
193                        .map(|d| d.as_secs() as i64)
194                        .unwrap_or(0);
195                    exp - now < 5
196                })
197                .unwrap_or(false);
198            if needs_refresh {
199                if let Some(new_token) = self.token_provider.get_token().await {
200                    let value = HeaderValue::from_str(&new_token)
201                        .map_err(|e| reqwest_middleware::Error::Middleware(anyhow::anyhow!(e)))?;
202                    req.headers_mut().insert("x-tapis-token", value);
203                }
204            }
205        }
206        next.run(req, extensions).await
207    }
208}
209
210#[derive(Clone)]
211pub struct TapisAuthenticator {
212    config: Arc<configuration::Configuration>,
213    pub admin: AdminClient,
214    pub clients: ClientsClient,
215    pub health_check: HealthCheckClient,
216    pub metadata: MetadataClient,
217    pub profiles: ProfilesClient,
218    pub tokens: TokensClient,
219}
220
221impl TapisAuthenticator {
222    pub fn new(
223        base_url: &str,
224        jwt_token: Option<&str>,
225    ) -> Result<Self, Box<dyn std::error::Error>> {
226        Self::build(base_url, jwt_token, None)
227    }
228
229    /// Create a client with a [`TokenProvider`] for automatic token refresh.
230    /// `RefreshMiddleware` is added to the middleware chain and will call
231    /// `provider.get_token()` transparently whenever the JWT is about to expire.
232    pub fn with_token_provider(
233        base_url: &str,
234        jwt_token: Option<&str>,
235        provider: Arc<dyn TokenProvider>,
236    ) -> Result<Self, Box<dyn std::error::Error>> {
237        Self::build(base_url, jwt_token, Some(provider))
238    }
239
240    fn build(
241        base_url: &str,
242        jwt_token: Option<&str>,
243        token_provider: Option<Arc<dyn TokenProvider>>,
244    ) -> Result<Self, Box<dyn std::error::Error>> {
245        let mut headers = HeaderMap::new();
246        if let Some(token) = jwt_token {
247            headers.insert("X-Tapis-Token", HeaderValue::from_str(token)?);
248        }
249
250        let reqwest_client = Client::builder().default_headers(headers).build()?;
251
252        let mut builder = ClientBuilder::new(reqwest_client)
253            .with(LoggingMiddleware)
254            .with(HeaderInjectionMiddleware)
255            .with(TrackingIdMiddleware);
256
257        if let Some(provider) = token_provider {
258            builder = builder.with(RefreshMiddleware {
259                token_provider: provider,
260            });
261        }
262
263        let client = builder.build();
264
265        let config = Arc::new(configuration::Configuration {
266            base_path: base_url.to_string(),
267            client,
268            ..Default::default()
269        });
270
271        Ok(Self {
272            config: config.clone(),
273            admin: AdminClient {
274                config: config.clone(),
275            },
276            clients: ClientsClient {
277                config: config.clone(),
278            },
279            health_check: HealthCheckClient {
280                config: config.clone(),
281            },
282            metadata: MetadataClient {
283                config: config.clone(),
284            },
285            profiles: ProfilesClient {
286                config: config.clone(),
287            },
288            tokens: TokensClient {
289                config: config.clone(),
290            },
291        })
292    }
293
294    pub fn config(&self) -> &configuration::Configuration {
295        &self.config
296    }
297}
298
299#[derive(Clone)]
300pub struct AdminClient {
301    config: Arc<configuration::Configuration>,
302}
303
304impl AdminClient {
305    pub async fn get_config(
306        &self,
307    ) -> Result<models::GetConfig200Response, Error<admin_api::GetConfigError>> {
308        admin_api::get_config(&self.config).await
309    }
310
311    pub async fn update_config(
312        &self,
313        new_tenant_config: models::NewTenantConfig,
314    ) -> Result<models::GetConfig200Response, Error<admin_api::UpdateConfigError>> {
315        admin_api::update_config(&self.config, new_tenant_config).await
316    }
317}
318
319#[derive(Clone)]
320pub struct ClientsClient {
321    config: Arc<configuration::Configuration>,
322}
323
324impl ClientsClient {
325    pub async fn create_client(
326        &self,
327        new_client: models::NewClient,
328    ) -> Result<models::CreateClient201Response, Error<clients_api::CreateClientError>> {
329        clients_api::create_client(&self.config, new_client).await
330    }
331
332    pub async fn delete_client(
333        &self,
334        client_id: &str,
335    ) -> Result<models::DeleteClient200Response, Error<clients_api::DeleteClientError>> {
336        clients_api::delete_client(&self.config, client_id).await
337    }
338
339    pub async fn get_client(
340        &self,
341        client_id: &str,
342    ) -> Result<models::CreateClient201Response, Error<clients_api::GetClientError>> {
343        clients_api::get_client(&self.config, client_id).await
344    }
345
346    pub async fn list_clients(
347        &self,
348        limit: Option<i32>,
349        offset: Option<i32>,
350    ) -> Result<models::ListClients200Response, Error<clients_api::ListClientsError>> {
351        clients_api::list_clients(&self.config, limit, offset).await
352    }
353
354    pub async fn update_client(
355        &self,
356        client_id: &str,
357        update_client: models::UpdateClient,
358    ) -> Result<models::CreateClient201Response, Error<clients_api::UpdateClientError>> {
359        clients_api::update_client(&self.config, client_id, update_client).await
360    }
361}
362
363#[derive(Clone)]
364pub struct HealthCheckClient {
365    config: Arc<configuration::Configuration>,
366}
367
368impl HealthCheckClient {
369    pub async fn hello(
370        &self,
371    ) -> Result<models::BasicResponse, Error<health_check_api::HelloError>> {
372        health_check_api::hello(&self.config).await
373    }
374
375    pub async fn ready(
376        &self,
377    ) -> Result<models::BasicResponse, Error<health_check_api::ReadyError>> {
378        health_check_api::ready(&self.config).await
379    }
380}
381
382#[derive(Clone)]
383pub struct MetadataClient {
384    config: Arc<configuration::Configuration>,
385}
386
387impl MetadataClient {
388    pub async fn get_server_metadata(
389        &self,
390    ) -> Result<models::GetServerMetadata200Response, Error<metadata_api::GetServerMetadataError>>
391    {
392        metadata_api::get_server_metadata(&self.config).await
393    }
394}
395
396#[derive(Clone)]
397pub struct ProfilesClient {
398    config: Arc<configuration::Configuration>,
399}
400
401impl ProfilesClient {
402    pub async fn get_profile(
403        &self,
404        username: &str,
405    ) -> Result<models::GetUserinfo200Response, Error<profiles_api::GetProfileError>> {
406        profiles_api::get_profile(&self.config, username).await
407    }
408
409    pub async fn get_userinfo(
410        &self,
411    ) -> Result<models::GetUserinfo200Response, Error<profiles_api::GetUserinfoError>> {
412        profiles_api::get_userinfo(&self.config).await
413    }
414
415    pub async fn list_profiles(
416        &self,
417        limit: Option<i32>,
418        offset: Option<i32>,
419    ) -> Result<models::ListProfiles200Response, Error<profiles_api::ListProfilesError>> {
420        profiles_api::list_profiles(&self.config, limit, offset).await
421    }
422}
423
424#[derive(Clone)]
425pub struct TokensClient {
426    config: Arc<configuration::Configuration>,
427}
428
429impl TokensClient {
430    pub async fn create_token(
431        &self,
432        new_token: models::NewToken,
433    ) -> Result<models::CreateToken201Response, Error<tokens_api::CreateTokenError>> {
434        tokens_api::create_token(&self.config, new_token).await
435    }
436
437    pub async fn create_v2_token(
438        &self,
439        v2_token: models::V2Token,
440    ) -> Result<models::CreateV2Token200Response, Error<tokens_api::CreateV2TokenError>> {
441        tokens_api::create_v2_token(&self.config, v2_token).await
442    }
443
444    pub async fn generate_device_code(
445        &self,
446        new_device_code: models::NewDeviceCode,
447    ) -> Result<models::GenerateDeviceCode200Response, Error<tokens_api::GenerateDeviceCodeError>>
448    {
449        tokens_api::generate_device_code(&self.config, new_device_code).await
450    }
451
452    pub async fn revoke_token(
453        &self,
454        revoke_token_request: models::RevokeTokenRequest,
455    ) -> Result<models::BasicResponse, Error<tokens_api::RevokeTokenError>> {
456        tokens_api::revoke_token(&self.config, revoke_token_request).await
457    }
458}