Skip to main content

stack_auth/
oauth_strategy.rs

1use cts_common::{Crn, CtsServiceDiscovery, Region, ServiceDiscovery};
2use tracing::warn;
3
4use stack_profile::ProfileStore;
5
6use crate::auto_refresh::AutoRefresh;
7use crate::oauth_refresher::OAuthRefresher;
8use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken, Token};
9
10/// An [`AuthStrategy`] that uses OAuth refresh tokens to maintain a valid access token.
11///
12/// # Construction
13///
14/// Use [`OAuthStrategy::with_token`] with a token obtained from a device code flow
15/// (or any other OAuth flow) for in-memory caching only. Use
16/// [`OAuthStrategy::with_profile`] to load a token from disk and persist
17/// refreshed tokens back to the store.
18///
19/// # Example
20///
21/// ```no_run
22/// use stack_auth::{OAuthStrategy, Token};
23/// use cts_common::Region;
24///
25/// # fn run(token: Token) -> Result<(), Box<dyn std::error::Error>> {
26/// let region = Region::aws("ap-southeast-2")?;
27/// let strategy = OAuthStrategy::with_token(region, "my-client-id", token).build()?;
28/// # Ok(())
29/// # }
30/// ```
31pub struct OAuthStrategy {
32    crn: Option<Crn>,
33    inner: AutoRefresh<OAuthRefresher>,
34}
35
36impl OAuthStrategy {
37    /// Return a builder for configuring an `OAuthStrategy` from a token.
38    ///
39    /// The token's `region` and `client_id` fields are set before caching.
40    /// No token store is used — tokens are not persisted to disk.
41    pub fn with_token(
42        region: Region,
43        client_id: impl Into<String>,
44        token: Token,
45    ) -> OAuthStrategyBuilder {
46        OAuthStrategyBuilder {
47            source: OAuthTokenSource::Token {
48                region,
49                client_id: client_id.into(),
50                token,
51            },
52            base_url_override: None,
53        }
54    }
55
56    /// Return a builder for configuring an `OAuthStrategy` from a profile store.
57    ///
58    /// The token is loaded from the store when [`OAuthStrategyBuilder::build`] is called.
59    /// The builder allows further configuration (e.g. overriding the base URL) before building.
60    ///
61    /// The token must have `region` and `client_id` set (as saved by
62    /// [`DeviceCodeStrategy`](crate::DeviceCodeStrategy) or a prior
63    /// `OAuthStrategy`). The store is used for persisting refreshed tokens.
64    pub fn with_profile(store: ProfileStore) -> OAuthStrategyBuilder {
65        OAuthStrategyBuilder {
66            source: OAuthTokenSource::Store(store),
67            base_url_override: None,
68        }
69    }
70
71    /// Return the workspace CRN, if one was extracted from the token at build time.
72    pub fn workspace_crn(&self) -> Option<&Crn> {
73        self.crn.as_ref()
74    }
75}
76
77impl AuthStrategy for &OAuthStrategy {
78    async fn get_token(self) -> Result<ServiceToken, AuthError> {
79        Ok(self.inner.get_token().await?)
80    }
81}
82
83/// Where the initial OAuth token comes from.
84enum OAuthTokenSource {
85    /// A token provided directly (in-memory only, no store).
86    Token {
87        region: Region,
88        client_id: String,
89        token: Token,
90    },
91    /// A token loaded from a persistent store.
92    Store(ProfileStore),
93}
94
95/// Builder for [`OAuthStrategy`].
96///
97/// Created via [`OAuthStrategy::with_token`] or [`OAuthStrategy::with_profile`].
98pub struct OAuthStrategyBuilder {
99    source: OAuthTokenSource,
100    base_url_override: Option<url::Url>,
101}
102
103impl OAuthStrategyBuilder {
104    /// Override the base URL resolved by service discovery.
105    ///
106    /// Useful for pointing at a local or mock auth server during testing.
107    #[cfg(any(test, feature = "test-utils"))]
108    pub fn base_url(mut self, url: url::Url) -> Self {
109        self.base_url_override = Some(url);
110        self
111    }
112
113    /// Build the [`OAuthStrategy`].
114    ///
115    /// Resolves the base URL via service discovery unless overridden with
116    /// `base_url` (available when the `test-utils` feature is enabled).
117    pub fn build(self) -> Result<OAuthStrategy, AuthError> {
118        match self.source {
119            OAuthTokenSource::Token {
120                region,
121                client_id,
122                mut token,
123            } => {
124                let base_url = match self.base_url_override {
125                    Some(url) => url,
126                    None => crate::cts_base_url_from_env()?
127                        .unwrap_or(CtsServiceDiscovery::endpoint(region)?),
128                };
129                // Derive CRN from the explicit region parameter and the token's
130                // workspace claim. We can't use token.workspace_crn() here
131                // because set_region() hasn't been called on the token yet.
132                let crn = token
133                    .workspace_id()
134                    .map(|ws| Crn::new(region, ws))
135                    .map_err(|e| {
136                        warn!("Could not extract workspace CRN from token: {e}");
137                        e
138                    })
139                    .ok();
140                let region_id = region.identifier();
141                let device_instance_id = token.device_instance_id().map(String::from);
142                token.set_region(&region_id);
143                token.set_client_id(&client_id);
144                let refresher = OAuthRefresher::new(
145                    None,
146                    ensure_trailing_slash(base_url),
147                    &client_id,
148                    &region_id,
149                    device_instance_id,
150                );
151                Ok(OAuthStrategy {
152                    crn,
153                    inner: AutoRefresh::with_token(refresher, token),
154                })
155            }
156            OAuthTokenSource::Store(store) => {
157                let token: Token = store.load_profile()?;
158
159                let region_str = token
160                    .region()
161                    .ok_or(AuthError::NotAuthenticated)?
162                    .to_string();
163                let client_id = token
164                    .client_id()
165                    .ok_or(AuthError::NotAuthenticated)?
166                    .to_string();
167                let crn = token
168                    .workspace_crn()
169                    .map_err(|e| {
170                        warn!("Could not extract workspace CRN from token: {e}");
171                        e
172                    })
173                    .ok();
174                let device_instance_id = token.device_instance_id().map(String::from);
175
176                let base_url = match self.base_url_override {
177                    Some(url) => url,
178                    None => crate::cts_base_url_from_env()?.unwrap_or(token.issuer()?),
179                };
180
181                let refresher = OAuthRefresher::new(
182                    Some(store),
183                    ensure_trailing_slash(base_url),
184                    &client_id,
185                    &region_str,
186                    device_instance_id,
187                );
188                Ok(OAuthStrategy {
189                    crn,
190                    inner: AutoRefresh::with_token(refresher, token),
191                })
192            }
193        }
194    }
195}