1use oauth2::{
2 AuthUrl, ClientId, ClientSecret, EndpointNotSet, EndpointSet, TokenResponse, TokenUrl,
3 basic::BasicClient, reqwest,
4};
5use std::{
6 sync::{Arc, Mutex},
7 time::{Duration, SystemTime},
8};
9use thiserror::Error;
10use tokio::{
11 sync::Notify,
12 time::{error::Elapsed, timeout},
13};
14use tonic::{Status, metadata::MetadataValue, service::Interceptor};
15use tracing::error;
16
/// Period, in seconds, at which the background task checks whether the cached token
/// needs to be refreshed.
const OAUTH_REFRESH_INTERVAL_SEC: u64 = 10;
/// Seconds subtracted from the server-reported token lifetime, so a token is treated as
/// expired this long before it actually expires.
const OAUTH_REFRESH_MARGIN_SEC: u64 = 15;
19
/// How the client credentials are presented to the token endpoint.
///
/// Each variant maps onto the `oauth2::AuthType` variant of the same name.
#[derive(Clone, Debug)]
pub enum AuthType {
    /// Credentials are sent in the request body (`oauth2::AuthType::RequestBody`).
    RequestBody,
    /// Credentials are sent using HTTP Basic authentication (`oauth2::AuthType::BasicAuth`).
    BasicAuth,
}
25
26impl From<AuthType> for oauth2::AuthType {
27 fn from(value: AuthType) -> oauth2::AuthType {
28 match value {
29 AuthType::RequestBody => oauth2::AuthType::RequestBody,
30 AuthType::BasicAuth => oauth2::AuthType::BasicAuth,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
36pub struct OAuthConfig {
37 pub client_id: String,
38 pub client_secret: String,
39 pub auth_url: String,
40 pub audience: String,
41 pub auth_type: AuthType,
42}
43
44impl OAuthConfig {
45 pub(crate) fn new(
46 client_id: String,
47 client_secret: String,
48 auth_url: String,
49 audience: String,
50 auth_type: Option<AuthType>,
51 ) -> Self {
52 let mut auth_type = auth_type;
53 if auth_type.is_none() {
54 auth_type = Some(AuthType::RequestBody)
55 }
56
57 OAuthConfig {
58 client_id,
59 client_secret,
60 auth_url,
61 audience,
62 auth_type: auth_type.unwrap(),
63 }
64 }
65}
66
67#[derive(Error, Debug)]
94pub enum OAuthError {
95 #[error("failed to acquire token lock")]
96 LockUnavailable(String),
97
98 #[error("failed request")]
99 Request(String),
100
101 #[error("timeout")]
102 Timeout(#[from] Elapsed),
103
104 #[error("token unavailable")]
105 TokenUnavailable,
106}
107
108impl<T> From<std::sync::PoisonError<T>> for OAuthError {
109 fn from(err: std::sync::PoisonError<T>) -> Self {
110 OAuthError::LockUnavailable(err.to_string())
111 }
112}
113
/// Most recently fetched access token, together with the time at which it should be
/// treated as expired (already reduced by the refresh margin when stored).
#[derive(Clone)]
struct CachedToken {
    /// Raw bearer token. Redacted from `Debug` output.
    secret: String,
    /// Point in time after which the token is considered expired.
    expire_at: SystemTime,
}

// Hand-written `Debug` so the bearer token never appears in logs, including when an
// enclosing type that holds a `CachedToken` derives `Debug`.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("secret", &"<redacted>")
            .field("expire_at", &self.expire_at)
            .finish()
    }
}
119
120#[derive(Clone, Debug)]
121pub(crate) struct OAuthProvider {
122 client: BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,
123 reqwest_client: reqwest::Client,
124 audience: String,
125 request_timeout: Duration,
126 cached_token: Arc<Mutex<Option<CachedToken>>>,
127 token_refreshed: Arc<Notify>,
128}
129
130impl OAuthProvider {
131 fn new(config: OAuthConfig, request_timeout: Duration) -> Self {
132 let audience = config.audience.clone();
133 let client = BasicClient::new(ClientId::new(config.client_id))
134 .set_client_secret(ClientSecret::new(config.client_secret))
135 .set_auth_uri(AuthUrl::new(config.auth_url.clone()).unwrap())
136 .set_token_uri(TokenUrl::new(config.auth_url.clone()).unwrap())
137 .set_auth_type(config.auth_type.into());
138
139 let reqwest_client = reqwest::ClientBuilder::new()
140 .redirect(reqwest::redirect::Policy::none())
141 .build()
142 .expect("Client should build");
143
144 OAuthProvider {
145 client,
146 reqwest_client,
147 audience,
148 request_timeout,
149 cached_token: Arc::new(Mutex::new(None)),
150 token_refreshed: Arc::new(Notify::new()),
151 }
152 }
153
154 async fn token_refreshed(&self) {
155 self.token_refreshed.notified().await;
156 }
157
158 fn read_token(&self) -> Result<String, OAuthError> {
159 let cached_token = self.cached_token.lock()?;
160
161 if let Some(cached_token) = &*cached_token {
162 return Ok(cached_token.secret.clone());
163 }
164
165 Err(OAuthError::TokenUnavailable)
166 }
167
168 fn cached_token_is_expired(&self) -> Result<bool, OAuthError> {
169 let lock = self.cached_token.lock()?;
170
171 let expired = if let Some(token) = lock.as_ref() {
172 token.expire_at <= SystemTime::now()
173 } else {
174 true
175 };
176
177 Ok(expired)
178 }
179
180 fn run(self: Arc<Self>, refresh_interval: Duration) {
181 tokio::task::spawn(async move {
182 let mut interval = tokio::time::interval(refresh_interval);
183 loop {
184 interval.tick().await;
185 if let Err(error) = self.refresh_token().await {
186 error!("Failed to refresh OAuth token: {}", error)
187 }
188 }
189 });
190 }
191
192 async fn refresh_token(&self) -> Result<(), OAuthError> {
193 if !self.cached_token_is_expired()? {
194 return Ok(());
195 }
196
197 let token_request = self
198 .client
199 .exchange_client_credentials()
200 .add_extra_param("audience", self.audience.clone());
201
202 let result = match timeout(
203 self.request_timeout,
204 token_request.request_async(&self.reqwest_client),
205 )
206 .await
207 {
208 Ok(Ok(response)) => response,
209 Ok(Err(err)) => return Err(OAuthError::Request(err.to_string())),
210 Err(err) => return Err(OAuthError::Timeout(err)),
211 };
212
213 let expiry = std::time::SystemTime::now() + result.expires_in().unwrap_or_default()
214 - Duration::from_secs(OAUTH_REFRESH_MARGIN_SEC);
215 let new_token = CachedToken {
216 secret: result.access_token().secret().to_owned(),
217 expire_at: expiry,
218 };
219
220 let _ = self.cached_token.lock()?.replace(new_token);
221 self.token_refreshed.notify_one();
222
223 Ok(())
224 }
225}
226
227#[derive(Clone, Debug, Default)]
228pub(crate) struct OAuthInterceptor {
229 oauth_provider: Option<Arc<OAuthProvider>>,
230}
231
232impl OAuthInterceptor {
233 pub(crate) fn new(oauth_config: OAuthConfig, auth_timeout: Duration) -> Self {
234 let provider = Arc::new(OAuthProvider::new(oauth_config, auth_timeout));
235
236 provider
237 .clone()
238 .run(Duration::from_secs(OAUTH_REFRESH_INTERVAL_SEC));
239
240 OAuthInterceptor {
241 oauth_provider: Some(provider),
242 }
243 }
244
245 pub(crate) async fn auth_initialized(&self) {
246 if let Some(provider) = &self.oauth_provider {
247 provider.token_refreshed().await;
248 }
249 }
250}
251
252impl Interceptor for OAuthInterceptor {
253 fn call(
254 &mut self,
255 mut request: tonic::Request<()>,
256 ) -> std::result::Result<tonic::Request<()>, Status> {
257 if let Some(oauth_client) = &mut self.oauth_provider {
258 let token = match oauth_client.read_token() {
259 Ok(token) => token,
260 Err(err) => {
261 return Err(tonic::Status::unauthenticated(format!(
262 "{}: {}",
263 "failed to get token", err
264 )));
265 }
266 };
267
268 request.metadata_mut().insert(
269 "authorization",
270 MetadataValue::try_from(&format!("Bearer {}", token)).map_err(|_| {
271 tonic::Status::unauthenticated(format!(
272 "{}: {}",
273 "token is not a valid header value", token
274 ))
275 })?,
276 );
277 }
278
279 Ok(request)
280 }
281}