Skip to main content

stack_auth/
oauth_strategy.rs

1use cts_common::{Crn, CtsServiceDiscovery, Region, ServiceDiscovery};
2use tracing::warn;
3
4#[cfg(not(target_arch = "wasm32"))]
5use stack_profile::ProfileStore;
6
7use crate::auto_refresh::AutoRefresh;
8use crate::oauth_refresher::OAuthRefresher;
9use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken, Token};
10
11/// An [`AuthStrategy`] that uses OAuth refresh tokens to maintain a valid access token.
12///
13/// # Construction
14///
15/// Use [`OAuthStrategy::with_token`] with a token obtained from a device code flow
16/// (or any other OAuth flow) for in-memory caching only. Use
17/// [`OAuthStrategy::with_profile`] to load a token from disk and persist
18/// refreshed tokens back to the store.
19///
20/// # Example
21///
22/// ```no_run
23/// use stack_auth::{OAuthStrategy, Token};
24/// use cts_common::Region;
25///
26/// # fn run(token: Token) -> Result<(), Box<dyn std::error::Error>> {
27/// let region = Region::aws("ap-southeast-2")?;
28/// let strategy = OAuthStrategy::with_token(region, "my-client-id", token).build()?;
29/// # Ok(())
30/// # }
31/// ```
32pub struct OAuthStrategy {
33    crn: Option<Crn>,
34    inner: AutoRefresh<OAuthRefresher>,
35}
36
37impl OAuthStrategy {
38    /// Return a builder for configuring an `OAuthStrategy` from a token.
39    ///
40    /// The token's `region` and `client_id` fields are set before caching.
41    /// No token store is used — tokens are not persisted to disk.
42    pub fn with_token(
43        region: Region,
44        client_id: impl Into<String>,
45        token: Token,
46    ) -> OAuthStrategyBuilder {
47        OAuthStrategyBuilder {
48            source: OAuthTokenSource::Token {
49                region,
50                client_id: client_id.into(),
51                token,
52            },
53            base_url_override: None,
54        }
55    }
56
57    /// Return a builder for configuring an `OAuthStrategy` from a profile store.
58    ///
59    /// The token is loaded from the store when [`OAuthStrategyBuilder::build`] is called.
60    /// The builder allows further configuration (e.g. overriding the base URL) before building.
61    ///
62    /// The token must have `region` and `client_id` set (as saved by
63    /// [`DeviceCodeStrategy`](crate::DeviceCodeStrategy) or a prior
64    /// `OAuthStrategy`). The store is used for persisting refreshed tokens.
65    #[cfg(not(target_arch = "wasm32"))]
66    pub fn with_profile(store: ProfileStore) -> OAuthStrategyBuilder {
67        OAuthStrategyBuilder {
68            source: OAuthTokenSource::Store(store),
69            base_url_override: None,
70        }
71    }
72
73    /// Return the workspace CRN, if one was extracted from the token at build time.
74    pub fn workspace_crn(&self) -> Option<&Crn> {
75        self.crn.as_ref()
76    }
77}
78
79impl AuthStrategy for &OAuthStrategy {
80    async fn get_token(self) -> Result<ServiceToken, AuthError> {
81        Ok(self.inner.get_token().await?)
82    }
83}
84
85/// Where the initial OAuth token comes from.
86enum OAuthTokenSource {
87    /// A token provided directly (in-memory only, no store).
88    Token {
89        region: Region,
90        client_id: String,
91        token: Token,
92    },
93    /// A token loaded from a persistent store.
94    #[cfg(not(target_arch = "wasm32"))]
95    Store(ProfileStore),
96}
97
98/// Builder for [`OAuthStrategy`].
99///
100/// Created via [`OAuthStrategy::with_token`] or [`OAuthStrategy::with_profile`].
101pub struct OAuthStrategyBuilder {
102    source: OAuthTokenSource,
103    base_url_override: Option<url::Url>,
104}
105
106impl OAuthStrategyBuilder {
107    /// Override the base URL resolved by service discovery.
108    ///
109    /// Useful for pointing at a local or mock auth server during testing.
110    #[cfg(any(test, feature = "test-utils"))]
111    pub fn base_url(mut self, url: url::Url) -> Self {
112        self.base_url_override = Some(url);
113        self
114    }
115
116    /// Build the [`OAuthStrategy`].
117    ///
118    /// Resolves the base URL via service discovery unless overridden with
119    /// `base_url` (available when the `test-utils` feature is enabled).
120    pub fn build(self) -> Result<OAuthStrategy, AuthError> {
121        match self.source {
122            OAuthTokenSource::Token {
123                region,
124                client_id,
125                mut token,
126            } => {
127                let base_url = match self.base_url_override {
128                    Some(url) => url,
129                    None => crate::cts_base_url_from_env()?
130                        .unwrap_or(CtsServiceDiscovery::endpoint(region)?),
131                };
132                // Derive CRN from the explicit region parameter and the token's
133                // workspace claim. We can't use token.workspace_crn() here
134                // because set_region() hasn't been called on the token yet.
135                let crn = token
136                    .workspace_id()
137                    .map(|ws| Crn::new(region, ws))
138                    .map_err(|e| {
139                        warn!("Could not extract workspace CRN from token: {e}");
140                        e
141                    })
142                    .ok();
143                let region_id = region.identifier();
144                let device_instance_id = token.device_instance_id().map(String::from);
145                token.set_region(&region_id);
146                token.set_client_id(&client_id);
147                let refresher = OAuthRefresher::new(
148                    None,
149                    ensure_trailing_slash(base_url),
150                    &client_id,
151                    &region_id,
152                    device_instance_id,
153                );
154                Ok(OAuthStrategy {
155                    crn,
156                    inner: AutoRefresh::with_token(refresher, token),
157                })
158            }
159            #[cfg(not(target_arch = "wasm32"))]
160            OAuthTokenSource::Store(store) => {
161                let ws_store = store.current_workspace_store()?;
162                let token: Token = ws_store.load_profile()?;
163
164                let region_str = token
165                    .region()
166                    .ok_or(AuthError::NotAuthenticated)?
167                    .to_string();
168                let client_id = token
169                    .client_id()
170                    .ok_or(AuthError::NotAuthenticated)?
171                    .to_string();
172                let crn = token
173                    .workspace_crn()
174                    .map_err(|e| {
175                        warn!("Could not extract workspace CRN from token: {e}");
176                        e
177                    })
178                    .ok();
179                let device_instance_id = token.device_instance_id().map(String::from);
180
181                let base_url = match self.base_url_override {
182                    Some(url) => url,
183                    None => crate::cts_base_url_from_env()?.unwrap_or(token.issuer()?),
184                };
185
186                let refresher = OAuthRefresher::new(
187                    Some(ws_store),
188                    ensure_trailing_slash(base_url),
189                    &client_id,
190                    &region_str,
191                    device_instance_id,
192                );
193                Ok(OAuthStrategy {
194                    crn,
195                    inner: AutoRefresh::with_token(refresher, token),
196                })
197            }
198        }
199    }
200}