Skip to main content

s3/auth/
provider.rs

1use std::sync::Arc;
2
3#[cfg(feature = "async")]
4use std::{future::Future, pin::Pin};
5
6#[cfg(all(
7    any(feature = "async", feature = "blocking"),
8    any(feature = "credentials-imds", feature = "credentials-sts")
9))]
10use reqx::advanced::TlsRootStore;
11
12use crate::{Error, Result};
13
14use super::{Auth, Credentials};
15
16#[cfg(any(feature = "async", feature = "blocking"))]
17use super::CredentialsSnapshot;
18#[cfg(feature = "credentials-sts")]
19use super::Region;
20#[cfg(any(feature = "credentials-imds", feature = "credentials-sts"))]
21use super::cache::CachedProvider;
22
/// Boxed `Send` future produced by [`CredentialsProvider::credentials_async`].
///
/// It resolves to a [`CredentialsSnapshot`] or an error. The `'a` lifetime allows the future to
/// borrow from the provider that created it.
#[cfg(feature = "async")]
pub type CredentialsFuture<'a> =
    Pin<Box<dyn Future<Output = Result<CredentialsSnapshot>> + Send + 'a>>;
27
28/// Source of credential snapshots for request signing.
29///
30/// Implement this trait when credentials may rotate over time. If the underlying provider performs
31/// network calls or expensive refreshes, wrap it in [`crate::CachedProvider`] so multiple requests
32/// can share cached credentials and coalesce refresh work.
33pub trait CredentialsProvider: std::fmt::Debug + Send + Sync {
34    /// Returns credentials asynchronously.
35    #[cfg(feature = "async")]
36    fn credentials_async(&self) -> CredentialsFuture<'_>;
37
38    /// Returns credentials in blocking mode.
39    #[cfg(feature = "blocking")]
40    fn credentials_blocking(&self) -> Result<CredentialsSnapshot>;
41}
42
43/// Shared credentials provider trait object.
44pub type DynCredentialsProvider = Arc<dyn CredentialsProvider>;
45
/// Selects the certificate trust roots used by the credential providers' HTTPS requests
/// (IMDS/STS).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CredentialsTlsRootStore {
    /// Defer to the TLS backend's default trust roots. This is the default selection.
    ///
    /// With `rustls` this resolves to the WebPKI roots. With `native-tls` the backend's own
    /// default behavior applies.
    #[default]
    BackendDefault,
    /// Always use the WebPKI root set.
    WebPki,
    /// Verify against the platform/system trust store.
    System,
}
60
61impl CredentialsTlsRootStore {
62    #[cfg(all(
63        any(feature = "async", feature = "blocking"),
64        any(feature = "credentials-imds", feature = "credentials-sts")
65    ))]
66    pub(crate) const fn into_reqx(self) -> TlsRootStore {
67        match self {
68            Self::BackendDefault => TlsRootStore::BackendDefault,
69            Self::WebPki => TlsRootStore::WebPki,
70            Self::System => TlsRootStore::System,
71        }
72    }
73}
74
/// Provider that always returns the same fixed credentials.
///
/// It serves as the source provider when a role is assumed from static credentials. It is
/// gated on an I/O mode as well as `credentials-sts` for two reasons:
/// - its `CredentialsSnapshot` field type is only imported in this module under
///   `async`/`blocking`;
/// - the `assume_role*` constructors that build it also require one of those features.
///
/// Gating on `credentials-sts` alone failed to compile when no I/O mode was enabled.
#[cfg(all(feature = "credentials-sts", any(feature = "async", feature = "blocking")))]
#[derive(Clone, Debug)]
struct StaticCredentialsProvider {
    /// Snapshot built once at construction and cloned for every request.
    snapshot: CredentialsSnapshot,
}

#[cfg(all(feature = "credentials-sts", any(feature = "async", feature = "blocking")))]
impl StaticCredentialsProvider {
    /// Wraps `credentials` in a snapshot that is handed out unchanged on every call.
    fn new(credentials: Credentials) -> Self {
        Self {
            snapshot: CredentialsSnapshot::new(credentials),
        }
    }
}

#[cfg(all(feature = "credentials-sts", any(feature = "async", feature = "blocking")))]
impl CredentialsProvider for StaticCredentialsProvider {
    /// Returns a clone of the stored snapshot; never fails.
    #[cfg(feature = "async")]
    fn credentials_async(&self) -> CredentialsFuture<'_> {
        let snapshot = self.snapshot.clone();
        Box::pin(async move { Ok(snapshot) })
    }

    /// Returns a clone of the stored snapshot; never fails.
    #[cfg(feature = "blocking")]
    fn credentials_blocking(&self) -> Result<CredentialsSnapshot> {
        Ok(self.snapshot.clone())
    }
}
103
/// Provider backed by `crate::credentials::imds` (instance metadata service).
///
/// Each call fetches fresh credentials. Callers in this module wrap it in a cache.
#[cfg(feature = "credentials-imds")]
#[derive(Clone, Copy, Debug)]
struct ImdsProvider {
    /// Trust-root selection handed to the IMDS loader on every fetch.
    tls_root_store: CredentialsTlsRootStore,
}
109
#[cfg(feature = "credentials-imds")]
impl CredentialsProvider for ImdsProvider {
    /// Loads credentials from IMDS without blocking, using the configured trust roots.
    #[cfg(feature = "async")]
    fn credentials_async(&self) -> CredentialsFuture<'_> {
        let Self { tls_root_store } = *self;
        Box::pin(async move {
            crate::credentials::imds::load_async(tls_root_store.into_reqx()).await
        })
    }

    /// Loads credentials from IMDS, blocking the calling thread.
    #[cfg(feature = "blocking")]
    fn credentials_blocking(&self) -> Result<CredentialsSnapshot> {
        let Self { tls_root_store } = *self;
        crate::credentials::imds::load_blocking(tls_root_store.into_reqx())
    }
}
124
/// Provider that exchanges credentials from `source` for role credentials through
/// `crate::credentials::sts::assume_role_*`.
#[cfg(feature = "credentials-sts")]
#[derive(Debug)]
struct StsAssumeRoleProvider {
    /// Region passed to every STS call.
    region: Region,
    /// ARN of the role to assume.
    role_arn: String,
    /// Session name sent with each assume-role request.
    role_session_name: String,
    /// Provider queried for the caller's credentials before each assume-role call.
    source: DynCredentialsProvider,
    /// Trust roots used for the STS HTTPS request.
    tls_root_store: CredentialsTlsRootStore,
}
134
#[cfg(feature = "credentials-sts")]
impl CredentialsProvider for StsAssumeRoleProvider {
    /// Fetches source credentials, then assumes the configured role (non-blocking).
    #[cfg(feature = "async")]
    fn credentials_async(&self) -> CredentialsFuture<'_> {
        Box::pin(async move {
            // Source credentials are re-read on every call so a rotating source is honored.
            let source_snapshot = self.source.credentials_async().await?;
            let source_credentials = source_snapshot.credentials().clone();
            let Self {
                region,
                role_arn,
                role_session_name,
                tls_root_store,
                ..
            } = self;
            crate::credentials::sts::assume_role_async(
                region.clone(),
                role_arn.clone(),
                role_session_name.clone(),
                source_credentials,
                tls_root_store.into_reqx(),
            )
            .await
        })
    }

    /// Fetches source credentials, then assumes the configured role (blocking).
    #[cfg(feature = "blocking")]
    fn credentials_blocking(&self) -> Result<CredentialsSnapshot> {
        let source_snapshot = self.source.credentials_blocking()?;
        let source_credentials = source_snapshot.credentials().clone();
        let Self {
            region,
            role_arn,
            role_session_name,
            tls_root_store,
            ..
        } = self;
        crate::credentials::sts::assume_role_blocking(
            region.clone(),
            role_arn.clone(),
            role_session_name.clone(),
            source_credentials,
            tls_root_store.into_reqx(),
        )
    }
}
164
/// Provider that obtains credentials via STS web-identity federation.
///
/// Its configuration is read from environment variables by
/// `crate::credentials::sts::assume_role_with_web_identity_env_*`.
#[cfg(feature = "credentials-sts")]
#[derive(Clone, Copy, Debug)]
struct StsWebIdentityProvider {
    /// Trust roots used for the STS HTTPS request.
    tls_root_store: CredentialsTlsRootStore,
}
170
#[cfg(feature = "credentials-sts")]
impl CredentialsProvider for StsWebIdentityProvider {
    /// Performs the env-configured web-identity exchange without blocking.
    #[cfg(feature = "async")]
    fn credentials_async(&self) -> CredentialsFuture<'_> {
        let Self { tls_root_store } = *self;
        Box::pin(async move {
            crate::credentials::sts::assume_role_with_web_identity_env_async(
                tls_root_store.into_reqx(),
            )
            .await
        })
    }

    /// Performs the env-configured web-identity exchange, blocking the calling thread.
    #[cfg(feature = "blocking")]
    fn credentials_blocking(&self) -> Result<CredentialsSnapshot> {
        let Self { tls_root_store } = *self;
        crate::credentials::sts::assume_role_with_web_identity_env_blocking(
            tls_root_store.into_reqx(),
        )
    }
}
190
191impl Auth {
192    /// Loads static credentials from standard AWS env vars.
193    ///
194    /// Reads `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and optionally `AWS_SESSION_TOKEN`.
195    pub fn from_env() -> Result<Self> {
196        let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
197            .map_err(|_| Error::invalid_config("missing AWS_ACCESS_KEY_ID"))?;
198        let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
199            .map_err(|_| Error::invalid_config("missing AWS_SECRET_ACCESS_KEY"))?;
200        let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
201
202        let mut creds = Credentials::new(access_key_id, secret_access_key)?;
203        if let Some(token) = session_token {
204            creds = creds.with_session_token(token)?;
205        }
206
207        Ok(Self::Static(creds))
208    }
209
210    /// Uses a dynamic credentials provider.
211    pub fn provider(provider: DynCredentialsProvider) -> Self {
212        Self::Provider(provider)
213    }
214
215    /// Loads credentials from a named profile.
216    #[cfg(feature = "credentials-profile")]
217    pub fn from_profile(profile: impl AsRef<str>) -> Result<Self> {
218        let creds = crate::credentials::profile::load_profile_credentials(profile.as_ref())?;
219        Ok(Self::Static(creds))
220    }
221
222    /// Loads credentials from the profile defined by environment variables.
223    #[cfg(feature = "credentials-profile")]
224    pub fn from_profile_env() -> Result<Self> {
225        Self::from_profile(crate::credentials::profile::profile_from_env())
226    }
227
228    /// Loads IMDS credentials and wraps them in a cached provider.
229    #[cfg(all(feature = "credentials-imds", feature = "async"))]
230    pub async fn from_imds() -> Result<Self> {
231        Self::from_imds_with_tls_root_store(CredentialsTlsRootStore::BackendDefault).await
232    }
233
234    /// Loads IMDS credentials and wraps them in a cached provider.
235    #[cfg(all(feature = "credentials-imds", feature = "async"))]
236    pub async fn from_imds_with_tls_root_store(
237        tls_root_store: CredentialsTlsRootStore,
238    ) -> Result<Self> {
239        let initial = crate::credentials::imds::load_async(tls_root_store.into_reqx()).await?;
240        let provider = CachedProvider::new(ImdsProvider { tls_root_store }).with_initial(initial);
241        Ok(Self::Provider(Arc::new(provider)))
242    }
243
244    /// Loads IMDS credentials and wraps them in a cached provider.
245    #[cfg(all(feature = "credentials-imds", feature = "blocking"))]
246    pub fn from_imds_blocking() -> Result<Self> {
247        Self::from_imds_blocking_with_tls_root_store(CredentialsTlsRootStore::BackendDefault)
248    }
249
250    /// Loads IMDS credentials and wraps them in a cached provider.
251    #[cfg(all(feature = "credentials-imds", feature = "blocking"))]
252    pub fn from_imds_blocking_with_tls_root_store(
253        tls_root_store: CredentialsTlsRootStore,
254    ) -> Result<Self> {
255        let initial = crate::credentials::imds::load_blocking(tls_root_store.into_reqx())?;
256        let provider = CachedProvider::new(ImdsProvider { tls_root_store }).with_initial(initial);
257        Ok(Self::Provider(Arc::new(provider)))
258    }
259
260    /// Assumes a role using static source credentials (async).
261    #[cfg(all(feature = "credentials-sts", feature = "async"))]
262    pub async fn assume_role(
263        region: Region,
264        role_arn: impl Into<String>,
265        role_session_name: impl Into<String>,
266        source_credentials: Credentials,
267    ) -> Result<Self> {
268        Self::assume_role_with_tls_root_store(
269            region,
270            role_arn,
271            role_session_name,
272            source_credentials,
273            CredentialsTlsRootStore::BackendDefault,
274        )
275        .await
276    }
277
278    /// Assumes a role using static source credentials and a specific trust root policy (async).
279    #[cfg(all(feature = "credentials-sts", feature = "async"))]
280    pub async fn assume_role_with_tls_root_store(
281        region: Region,
282        role_arn: impl Into<String>,
283        role_session_name: impl Into<String>,
284        source_credentials: Credentials,
285        tls_root_store: CredentialsTlsRootStore,
286    ) -> Result<Self> {
287        Self::assume_role_with_provider_with_tls_root_store(
288            region,
289            role_arn,
290            role_session_name,
291            Arc::new(StaticCredentialsProvider::new(source_credentials)),
292            tls_root_store,
293        )
294        .await
295    }
296
297    /// Assumes a role using static source credentials (blocking).
298    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
299    pub fn assume_role_blocking(
300        region: Region,
301        role_arn: impl Into<String>,
302        role_session_name: impl Into<String>,
303        source_credentials: Credentials,
304    ) -> Result<Self> {
305        Self::assume_role_blocking_with_tls_root_store(
306            region,
307            role_arn,
308            role_session_name,
309            source_credentials,
310            CredentialsTlsRootStore::BackendDefault,
311        )
312    }
313
314    /// Assumes a role using static source credentials and a specific trust root policy (blocking).
315    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
316    pub fn assume_role_blocking_with_tls_root_store(
317        region: Region,
318        role_arn: impl Into<String>,
319        role_session_name: impl Into<String>,
320        source_credentials: Credentials,
321        tls_root_store: CredentialsTlsRootStore,
322    ) -> Result<Self> {
323        Self::assume_role_with_provider_blocking_with_tls_root_store(
324            region,
325            role_arn,
326            role_session_name,
327            Arc::new(StaticCredentialsProvider::new(source_credentials)),
328            tls_root_store,
329        )
330    }
331
332    /// Loads web identity credentials from env vars (async).
333    #[cfg(all(feature = "credentials-sts", feature = "async"))]
334    pub async fn from_web_identity_env() -> Result<Self> {
335        Self::from_web_identity_env_with_tls_root_store(CredentialsTlsRootStore::BackendDefault)
336            .await
337    }
338
339    /// Loads web identity credentials from env vars and a specific trust root policy (async).
340    #[cfg(all(feature = "credentials-sts", feature = "async"))]
341    pub async fn from_web_identity_env_with_tls_root_store(
342        tls_root_store: CredentialsTlsRootStore,
343    ) -> Result<Self> {
344        let provider = StsWebIdentityProvider { tls_root_store };
345        let initial = provider.credentials_async().await?;
346        let provider = CachedProvider::new(provider).with_initial(initial);
347        Ok(Self::Provider(Arc::new(provider)))
348    }
349
350    /// Loads web identity credentials from env vars (blocking).
351    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
352    pub fn from_web_identity_env_blocking() -> Result<Self> {
353        Self::from_web_identity_env_blocking_with_tls_root_store(
354            CredentialsTlsRootStore::BackendDefault,
355        )
356    }
357
358    /// Loads web identity credentials from env vars and a specific trust root policy (blocking).
359    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
360    pub fn from_web_identity_env_blocking_with_tls_root_store(
361        tls_root_store: CredentialsTlsRootStore,
362    ) -> Result<Self> {
363        let provider = StsWebIdentityProvider { tls_root_store };
364        let initial = provider.credentials_blocking()?;
365        let provider = CachedProvider::new(provider).with_initial(initial);
366        Ok(Self::Provider(Arc::new(provider)))
367    }
368
369    /// Assumes a role using a credentials provider (async).
370    #[cfg(all(feature = "credentials-sts", feature = "async"))]
371    pub async fn assume_role_with_provider(
372        region: Region,
373        role_arn: impl Into<String>,
374        role_session_name: impl Into<String>,
375        source: DynCredentialsProvider,
376    ) -> Result<Self> {
377        Self::assume_role_with_provider_with_tls_root_store(
378            region,
379            role_arn,
380            role_session_name,
381            source,
382            CredentialsTlsRootStore::BackendDefault,
383        )
384        .await
385    }
386
387    /// Assumes a role using a credentials provider and a specific trust root policy (async).
388    #[cfg(all(feature = "credentials-sts", feature = "async"))]
389    pub async fn assume_role_with_provider_with_tls_root_store(
390        region: Region,
391        role_arn: impl Into<String>,
392        role_session_name: impl Into<String>,
393        source: DynCredentialsProvider,
394        tls_root_store: CredentialsTlsRootStore,
395    ) -> Result<Self> {
396        let provider = StsAssumeRoleProvider {
397            region,
398            role_arn: role_arn.into(),
399            role_session_name: role_session_name.into(),
400            source,
401            tls_root_store,
402        };
403        let initial = provider.credentials_async().await?;
404        let provider = CachedProvider::new(provider).with_initial(initial);
405        Ok(Self::Provider(Arc::new(provider)))
406    }
407
408    /// Assumes a role using a credentials provider (blocking).
409    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
410    pub fn assume_role_with_provider_blocking(
411        region: Region,
412        role_arn: impl Into<String>,
413        role_session_name: impl Into<String>,
414        source: DynCredentialsProvider,
415    ) -> Result<Self> {
416        Self::assume_role_with_provider_blocking_with_tls_root_store(
417            region,
418            role_arn,
419            role_session_name,
420            source,
421            CredentialsTlsRootStore::BackendDefault,
422        )
423    }
424
425    /// Assumes a role using a credentials provider and a specific trust root policy (blocking).
426    #[cfg(all(feature = "credentials-sts", feature = "blocking"))]
427    pub fn assume_role_with_provider_blocking_with_tls_root_store(
428        region: Region,
429        role_arn: impl Into<String>,
430        role_session_name: impl Into<String>,
431        source: DynCredentialsProvider,
432        tls_root_store: CredentialsTlsRootStore,
433    ) -> Result<Self> {
434        let provider = StsAssumeRoleProvider {
435            region,
436            role_arn: role_arn.into(),
437            role_session_name: role_session_name.into(),
438            source,
439            tls_root_store,
440        };
441        let initial = provider.credentials_blocking()?;
442        let provider = CachedProvider::new(provider).with_initial(initial);
443        Ok(Self::Provider(Arc::new(provider)))
444    }
445
446    #[cfg(feature = "async")]
447    pub(crate) fn static_credentials(&self) -> Option<&Credentials> {
448        match self {
449            Self::Static(creds) => Some(creds),
450            Self::Anonymous | Self::Provider(_) => None,
451        }
452    }
453
454    #[cfg(feature = "async")]
455    pub(crate) async fn credentials_snapshot_async(&self) -> Result<Option<CredentialsSnapshot>> {
456        match self {
457            Self::Anonymous => Ok(None),
458            Self::Static(creds) => Ok(Some(CredentialsSnapshot::new(creds.clone()))),
459            Self::Provider(provider) => provider.credentials_async().await.map(Some),
460        }
461    }
462
463    #[cfg(feature = "blocking")]
464    pub(crate) fn credentials_snapshot_blocking(&self) -> Result<Option<CredentialsSnapshot>> {
465        match self {
466            Self::Anonymous => Ok(None),
467            Self::Static(creds) => Ok(Some(CredentialsSnapshot::new(creds.clone()))),
468            Self::Provider(provider) => provider.credentials_blocking().map(Some),
469        }
470    }
471}