zeebe_rs/
oauth.rs

1use oauth2::{
2    AuthUrl, ClientId, ClientSecret, EndpointNotSet, EndpointSet, TokenResponse, TokenUrl,
3    basic::BasicClient, reqwest,
4};
5use std::{
6    sync::{Arc, Mutex},
7    time::{Duration, SystemTime},
8};
9use thiserror::Error;
10use tokio::{
11    sync::Notify,
12    time::{error::Elapsed, timeout},
13};
14use tonic::{Status, metadata::MetadataValue, service::Interceptor};
15use tracing::error;
16
/// How often, in seconds, the background task checks whether the cached token needs refreshing.
const OAUTH_REFRESH_INTERVAL_SEC: u64 = 10;
/// How many seconds before the server-reported lifetime ends a cached token is treated as
/// expired, so that it is replaced before the server starts rejecting it.
const OAUTH_REFRESH_MARGIN_SEC: u64 = 15;

// The refresh task only notices expiry on its next tick, up to one interval late. If the
// margin did not exceed the interval, an expired token could still be sent to the server.
const _: () = assert!(
    OAUTH_REFRESH_MARGIN_SEC > OAUTH_REFRESH_INTERVAL_SEC,
    "OAUTH_REFRESH_MARGIN_SEC must exceed OAUTH_REFRESH_INTERVAL_SEC"
);
19
/// How the client credentials are presented to the OAuth token endpoint.
///
/// Converted into [`oauth2::AuthType`] when the OAuth client is built.
#[derive(Clone, Debug)]
pub enum AuthType {
    /// Send `client_id` / `client_secret` as parameters in the request body.
    RequestBody,
    /// Send the credentials using the HTTP Basic authentication scheme.
    BasicAuth,
}
25
26impl From<AuthType> for oauth2::AuthType {
27    fn from(value: AuthType) -> oauth2::AuthType {
28        match value {
29            AuthType::RequestBody => oauth2::AuthType::RequestBody,
30            AuthType::BasicAuth => oauth2::AuthType::BasicAuth,
31        }
32    }
33}
34
35#[derive(Debug, Clone)]
36pub struct OAuthConfig {
37    pub client_id: String,
38    pub client_secret: String,
39    pub auth_url: String,
40    pub audience: String,
41    pub auth_type: AuthType,
42}
43
44impl OAuthConfig {
45    pub(crate) fn new(
46        client_id: String,
47        client_secret: String,
48        auth_url: String,
49        audience: String,
50        auth_type: Option<AuthType>,
51    ) -> Self {
52        let mut auth_type = auth_type;
53        if auth_type.is_none() {
54            auth_type = Some(AuthType::RequestBody)
55        }
56
57        OAuthConfig {
58            client_id,
59            client_secret,
60            auth_url,
61            audience,
62            auth_type: auth_type.unwrap(),
63        }
64    }
65}
66
67/// Represents the different types of errors that can occur during OAuth operations.
68///
69/// The `OAuthError` enum encapsulates various error scenarios that may arise during the OAuth process,
70/// such as acquiring token locks, making requests, handling timeouts, or token unavailability.
71///
72/// # Variants
73///
74/// - `LockUnavailable`
75///   Indicates a failure to acquire a lock for the token. This typically occurs when a lock
76///   mechanism is being used to ensure safe concurrent access to OAuth tokens.
77///   - Fields:
78///     - `String`: A message providing additional context about the lock failure.
79///
80/// - `Request`
81///   Represents a failure that occurred during an OAuth-related request.
82///   - Fields:
83///     - `String`: A message describing the nature of the request failure.
84///
85/// - `Timeout`
86///   Indicates that a timeout occurred during an OAuth operation. This variant wraps the `Elapsed`
87///   type from the `tokio` crate, which provides details about the timeout event.
88///   - Source: `Elapsed`
89///
90/// - `TokenUnavailable`
91///   Indicates that a required token was not available. This error typically occurs when attempting
92///   to retrieve or use an OAuth token that has not been issued or is otherwise inaccessible.
93#[derive(Error, Debug)]
94pub enum OAuthError {
95    #[error("failed to acquire token lock")]
96    LockUnavailable(String),
97
98    #[error("failed request")]
99    Request(String),
100
101    #[error("timeout")]
102    Timeout(#[from] Elapsed),
103
104    #[error("token unavailable")]
105    TokenUnavailable,
106}
107
108impl<T> From<std::sync::PoisonError<T>> for OAuthError {
109    fn from(err: std::sync::PoisonError<T>) -> Self {
110        OAuthError::LockUnavailable(err.to_string())
111    }
112}
113
/// The most recently fetched access token together with its refresh deadline.
///
/// `Debug` is implemented by hand so the bearer token never ends up in logs or panic
/// messages (for example through the derived `Debug` of a struct that contains it).
#[derive(Clone)]
struct CachedToken {
    /// The bearer token value attached to outgoing requests.
    secret: String,
    /// Wall-clock time at or after which the token is considered expired.
    expire_at: SystemTime,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("secret", &"[redacted]")
            .field("expire_at", &self.expire_at)
            .finish()
    }
}
119
120#[derive(Clone, Debug)]
121pub(crate) struct OAuthProvider {
122    client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
123    reqwest_client: reqwest::Client,
124    audience: String,
125    request_timeout: Duration,
126    cached_token: Arc<Mutex<Option<CachedToken>>>,
127    token_refreshed: Arc<Notify>,
128}
129
130impl OAuthProvider {
131    fn new(config: OAuthConfig, request_timeout: Duration) -> Self {
132        let audience = config.audience.clone();
133        let client = BasicClient::new(ClientId::new(config.client_id))
134            .set_client_secret(ClientSecret::new(config.client_secret))
135            .set_auth_uri(AuthUrl::new(config.auth_url.clone()).unwrap())
136            .set_token_uri(TokenUrl::new(config.auth_url.clone()).unwrap())
137            .set_auth_type(config.auth_type.into());
138
139        let reqwest_client = reqwest::ClientBuilder::new()
140            .redirect(reqwest::redirect::Policy::none())
141            .build()
142            .expect("Client should build");
143
144        OAuthProvider {
145            client,
146            reqwest_client,
147            audience,
148            request_timeout,
149            cached_token: Arc::new(Mutex::new(None)),
150            token_refreshed: Arc::new(Notify::new()),
151        }
152    }
153
154    async fn token_refreshed(&self) {
155        self.token_refreshed.notified().await;
156    }
157
158    fn read_token(&self) -> Result<String, OAuthError> {
159        let cached_token = self.cached_token.lock()?;
160
161        if let Some(cached_token) = &*cached_token {
162            return Ok(cached_token.secret.clone());
163        }
164
165        Err(OAuthError::TokenUnavailable)
166    }
167
168    fn cached_token_is_expired(&self) -> Result<bool, OAuthError> {
169        let lock = self.cached_token.lock()?;
170
171        let expired = if let Some(token) = lock.as_ref() {
172            token.expire_at <= SystemTime::now()
173        } else {
174            true
175        };
176
177        Ok(expired)
178    }
179
180    fn run(self: Arc<Self>, refresh_interval: Duration) {
181        tokio::task::spawn(async move {
182            let mut interval = tokio::time::interval(refresh_interval);
183            loop {
184                interval.tick().await;
185                if let Err(error) = self.refresh_token().await {
186                    error!("Failed to refresh OAuth token: {}", error)
187                }
188            }
189        });
190    }
191
192    async fn refresh_token(&self) -> Result<(), OAuthError> {
193        if !self.cached_token_is_expired()? {
194            return Ok(());
195        }
196
197        let token_request = self
198            .client
199            .exchange_client_credentials()
200            .add_extra_param("audience", self.audience.clone());
201
202        let result = match timeout(
203            self.request_timeout,
204            token_request.request_async(&self.reqwest_client),
205        )
206        .await
207        {
208            Ok(Ok(response)) => response,
209            Ok(Err(err)) => return Err(OAuthError::Request(err.to_string())),
210            Err(err) => return Err(OAuthError::Timeout(err)),
211        };
212
213        let expiry = std::time::SystemTime::now() + result.expires_in().unwrap_or_default()
214            - Duration::from_secs(OAUTH_REFRESH_MARGIN_SEC);
215        let new_token = CachedToken {
216            secret: result.access_token().secret().to_owned(),
217            expire_at: expiry,
218        };
219
220        let _ = self.cached_token.lock()?.replace(new_token);
221        self.token_refreshed.notify_one();
222
223        Ok(())
224    }
225}
226
227#[derive(Clone, Debug, Default)]
228pub(crate) struct OAuthInterceptor {
229    oauth_provider: Option<Arc<OAuthProvider>>,
230}
231
232impl OAuthInterceptor {
233    pub(crate) fn new(oauth_config: OAuthConfig, auth_timeout: Duration) -> Self {
234        let provider = Arc::new(OAuthProvider::new(oauth_config, auth_timeout));
235
236        provider
237            .clone()
238            .run(Duration::from_secs(OAUTH_REFRESH_INTERVAL_SEC));
239
240        OAuthInterceptor {
241            oauth_provider: Some(provider),
242        }
243    }
244
245    pub(crate) async fn auth_initialized(&self) {
246        if let Some(provider) = &self.oauth_provider {
247            provider.token_refreshed().await;
248        }
249    }
250}
251
252impl Interceptor for OAuthInterceptor {
253    fn call(
254        &mut self,
255        mut request: tonic::Request<()>,
256    ) -> std::result::Result<tonic::Request<()>, Status> {
257        if let Some(oauth_client) = &mut self.oauth_provider {
258            let token = match oauth_client.read_token() {
259                Ok(token) => token,
260                Err(err) => {
261                    return Err(tonic::Status::unauthenticated(format!(
262                        "{}: {}",
263                        "failed to get token", err
264                    )));
265                }
266            };
267
268            request.metadata_mut().insert(
269                "authorization",
270                MetadataValue::try_from(&format!("Bearer {}", token)).map_err(|_| {
271                    tonic::Status::unauthenticated(format!(
272                        "{}: {}",
273                        "token is not a valid header value", token
274                    ))
275                })?,
276            );
277        }
278
279        Ok(request)
280    }
281}