1#![forbid(unsafe_code)]
2
3use async_trait::async_trait;
4use futures::{lock::Mutex, Future};
5use models::{
6 KeyVaultError, KeyVaultErrorResponse, ReadSecretResponse, TokenFromSharedSecretRequestBody,
7 TokenResponse, TokenRetrievalErrorResponse,
8};
9use reqwest::StatusCode;
10use serde::Deserialize;
11use std::{error::Error, fmt::Debug, fmt::Display};
12use thiserror::Error;
13use xand_secrets::{CheckHealthError, ExposeSecret, ReadSecretError, Secret, SecretKeyValueStore};
14
15mod models;
16
/// Base URL of the OAuth token service used by `get_oauth_token_endpoint` when
/// `KeyVaultConfiguration::custom_oauth_token_endpoint` is `None`.
const DEFAULT_OAUTH_TOKEN_ENDPOINT: &str = "https://login.microsoftonline.com";
18
/// Settings needed to obtain an OAuth token and read secrets from one Key Vault.
///
/// Deserialized with kebab-case keys (`tenant-id`, `http-endpoint`, `client-id`,
/// `client-secret`, `azure-resource-id`, `custom-oauth-token-endpoint`).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct KeyVaultConfiguration {
    /// Tenant identifier; URL-encoded into the token endpoint path
    /// (`{token endpoint}/{tenant_id}/oauth2/token`).
    pub tenant_id: String,
    /// Base URL of the Key Vault; `/secrets/...` request paths are appended to it.
    pub http_endpoint: String,
    /// Sent as `client_id` in the client-credentials token request.
    pub client_id: String,
    /// Sent as `client_secret` in the client-credentials token request.
    pub client_secret: Secret<String>,
    /// Sent as `resource` in the token request. Presumably the Key Vault resource URI the token
    /// is scoped to — confirm against deployment configuration.
    pub azure_resource_id: String,
    /// Optional override for the token endpoint base URL; `DEFAULT_OAUTH_TOKEN_ENDPOINT` is
    /// used when absent.
    pub custom_oauth_token_endpoint: Option<String>,
}
29
/// `SecretKeyValueStore` implementation that reads secrets from Azure Key Vault over HTTP.
///
/// The most recently fetched bearer token is cached in `token`. It is replaced when no token
/// has been fetched yet or a request using it fails with an authentication error.
#[derive(Debug)]
pub struct KeyVaultSecretKeyValueStore {
    /// Connection and credential settings supplied at construction.
    config: KeyVaultConfiguration,
    /// HTTP client used for both token requests and Key Vault requests.
    client: reqwest::Client,
    /// Last bearer token obtained from the token endpoint; `None` until the first fetch.
    token: Mutex<Option<Secret<String>>>,
}
36
/// Failure while obtaining a bearer token from the OAuth token endpoint.
///
/// Converted into `ReadSecretError` or `CheckHealthError` by the `From` impls below.
#[derive(Debug)]
enum TokenRetrievalError {
    /// The token endpoint answered 401 Unauthorized or 403 Forbidden.
    Authentication {
        internal_error: Box<dyn Error + Send + Sync>,
    },
    /// A successful response whose body could not be deserialized as a token response.
    BodyParse {
        internal_error: Box<dyn Error + Send + Sync>,
    },
    /// Any other non-success HTTP status from the token endpoint.
    HttpResponse {
        internal_error: TokenRetrievalHttpError,
    },
    /// Sending the token request failed, so no HTTP response was obtained.
    Transport {
        internal_error: Box<dyn Error + Send + Sync>,
    },
}
52
53impl From<TokenRetrievalError> for ReadSecretError {
54 fn from(e: TokenRetrievalError) -> Self {
55 match e {
56 TokenRetrievalError::Authentication { internal_error } => {
57 ReadSecretError::Authentication { internal_error }
58 }
59 TokenRetrievalError::BodyParse { internal_error }
60 | TokenRetrievalError::Transport { internal_error } => {
61 ReadSecretError::Request { internal_error }
62 }
63 TokenRetrievalError::HttpResponse { internal_error } => ReadSecretError::Request {
64 internal_error: Box::new(internal_error),
65 },
66 }
67 }
68}
69
70impl From<TokenRetrievalError> for CheckHealthError {
71 fn from(e: TokenRetrievalError) -> Self {
72 match e {
73 TokenRetrievalError::Authentication { internal_error } => {
74 CheckHealthError::Authentication { internal_error }
75 }
76 TokenRetrievalError::BodyParse { internal_error }
77 | TokenRetrievalError::Transport { internal_error } => {
78 CheckHealthError::RemoteInternal { internal_error }
79 }
80 TokenRetrievalError::HttpResponse { internal_error } => {
81 CheckHealthError::RemoteInternal {
82 internal_error: Box::new(internal_error),
83 }
84 }
85 }
86 }
87}
88
/// Non-success HTTP response from the OAuth token endpoint.
///
/// `Display` is implemented by hand below (no `#[error(...)]` attribute is used).
#[derive(Debug, Error)]
pub struct TokenRetrievalHttpError {
    /// HTTP status returned by the token endpoint.
    pub status_code: StatusCode,
    /// Parsed error body, or `None` if the body could not be deserialized.
    pub response: Option<TokenRetrievalErrorResponse>,
}
94
95impl Display for TokenRetrievalHttpError {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 if let Some(error) = &self.response {
98 write!(
99 f,
100 "Authentication endpoint returned status code {}. Error: {}. Description: {}.",
101 self.status_code, error.error, error.error_description
102 )?;
103 } else {
104 write!(
105 f,
106 "Authentication endpoint returned status code {} with an unexpected error body.",
107 self.status_code
108 )?;
109 }
110
111 Ok(())
112 }
113}
114
/// Non-success HTTP response from the Key Vault REST API.
///
/// `Display` is implemented by hand below (no `#[error(...)]` attribute is used).
#[derive(Debug, Error)]
pub struct KeyVaultHttpError {
    /// HTTP status returned by Key Vault.
    pub status_code: StatusCode,
    /// The `error` object of the response body, or `None` if the body could not be
    /// deserialized as a `KeyVaultErrorResponse`.
    pub response: Option<KeyVaultError>,
}
120
121impl KeyVaultHttpError {
122 async fn from_response(response: reqwest::Response) -> KeyVaultHttpError {
123 KeyVaultHttpError {
124 status_code: response.status(),
125 response: response
126 .json::<KeyVaultErrorResponse>()
127 .await
128 .ok()
129 .map(|r| r.error),
130 }
131 }
132}
133
134impl Display for KeyVaultHttpError {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 if let Some(error) = &self.response {
137 write!(
138 f,
139 "Key Vault returned status code {}. Error code: {}. Message: {}.",
140 self.status_code, error.code, error.message
141 )?;
142 } else {
143 write!(
144 f,
145 "Key Vault returned status code {} with an unexpected error body.",
146 self.status_code
147 )?;
148 }
149
150 Ok(())
151 }
152}
153
154impl KeyVaultSecretKeyValueStore {
155 pub fn create_from_config(config: KeyVaultConfiguration) -> KeyVaultSecretKeyValueStore {
156 KeyVaultSecretKeyValueStore {
157 config,
158 client: reqwest::Client::new(),
159 token: Mutex::new(None),
160 }
161 }
162
163 async fn get_last_known_token(&self) -> Option<Secret<String>> {
164 let token = self.token.lock().await;
165 token.clone()
166 }
167
168 fn get_oauth_token_endpoint(&self) -> String {
169 let token_address = self
170 .config
171 .custom_oauth_token_endpoint
172 .as_ref()
173 .map_or(DEFAULT_OAUTH_TOKEN_ENDPOINT, String::as_str);
174
175 format!(
176 "{}/{}/oauth2/token",
177 token_address,
178 urlencoding::encode(self.config.tenant_id.as_str())
179 )
180 }
181
182 fn get_secret_read_latest_version_endpoint(&self, secret_name: &str) -> String {
183 format!(
184 "{}/secrets/{}?api-version=7.1",
185 self.config.http_endpoint,
186 urlencoding::encode(secret_name),
187 )
188 }
189
190 fn get_secret_list_default_num_endpoint(&self) -> String {
191 format!("{}/secrets?api-version=7.1", self.config.http_endpoint,)
192 }
193
194 async fn get_updated_token(&self) -> Result<Secret<String>, TokenRetrievalError> {
195 let response = self
196 .client
197 .post(&self.get_oauth_token_endpoint())
198 .form(&TokenFromSharedSecretRequestBody {
199 grant_type: "client_credentials",
200 client_id: self.config.client_id.as_str(),
201 client_secret: self.config.client_secret.expose_secret(),
202 resource: self.config.azure_resource_id.as_str(),
203 })
204 .send()
205 .await
206 .map_err(|e| TokenRetrievalError::Transport {
207 internal_error: Box::new(e),
208 })?;
209
210 if !response.status().is_success() {
211 return Err(match response.status() {
212 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
213 TokenRetrievalError::Authentication {
214 internal_error: Box::new(TokenRetrievalHttpError {
215 status_code: response.status(),
216 response: response.json::<TokenRetrievalErrorResponse>().await.ok(),
217 }),
218 }
219 }
220 _ => TokenRetrievalError::HttpResponse {
221 internal_error: TokenRetrievalHttpError {
222 status_code: response.status(),
223 response: response.json::<TokenRetrievalErrorResponse>().await.ok(),
224 },
225 },
226 });
227 };
228
229 let response_body =
230 response
231 .json::<TokenResponse>()
232 .await
233 .map_err(|e| TokenRetrievalError::BodyParse {
234 internal_error: Box::new(e),
235 })?;
236
237 Ok(response_body.access_token)
238 }
239
240 async fn send_read_request_to_key_vault(
241 &self,
242 token: Secret<String>,
243 name: &str,
244 ) -> Result<Secret<String>, ReadSecretError> {
245 let response = self
246 .client
247 .get(&self.get_secret_read_latest_version_endpoint(name))
248 .bearer_auth(token.expose_secret())
249 .send()
250 .await
251 .map_err(|e| ReadSecretError::Request {
252 internal_error: Box::new(e),
253 })?;
254
255 if !response.status().is_success() {
256 return Err(match response.status() {
257 StatusCode::NOT_FOUND => ReadSecretError::KeyNotFound {
258 key: String::from(name),
259 },
260 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
261 ReadSecretError::Authentication {
262 internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
263 }
264 }
265 _ => ReadSecretError::Request {
266 internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
267 },
268 });
269 };
270
271 let response_body =
272 response
273 .json::<ReadSecretResponse>()
274 .await
275 .map_err(|e| ReadSecretError::Request {
276 internal_error: Box::new(e),
277 })?;
278
279 Ok(response_body.value)
280 }
281
282 async fn send_health_probe_to_key_vault(
283 &self,
284 token: Secret<String>,
285 ) -> Result<(), CheckHealthError> {
286 let response = self
287 .client
288 .get(&self.get_secret_list_default_num_endpoint())
291 .bearer_auth(token.expose_secret())
292 .send()
293 .await
294 .map_err(|e| CheckHealthError::Unreachable {
295 internal_error: Box::new(e),
296 })?;
297
298 match response.status() {
299 status if status.is_success() => Ok(()),
300 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
301 Err(CheckHealthError::Authentication {
302 internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
303 })
304 }
305 _ => Err(CheckHealthError::RemoteInternal {
306 internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
307 }),
308 }
309
310 }
312
313 async fn lock_and_update_token(&self) -> Result<Secret<String>, TokenRetrievalError> {
314 let mut token = self.token.lock().await;
315 let updated_token = self.get_updated_token().await?;
316 *token = Some(updated_token.clone());
317
318 Ok(updated_token)
319 }
320
321 async fn attempt_with_token_renewal<O, OFut, F, T, E>(
322 &self,
323 make_request: O,
324 is_authentication_error: F,
325 ) -> Result<T, E>
326 where
327 O: Fn(Secret<String>) -> OFut,
328 OFut: Future<Output = Result<T, E>>,
329 F: Fn(&E) -> bool,
330 T: Debug,
331 E: From<TokenRetrievalError>,
332 {
333 let token = self.get_last_known_token().await;
334 if let Some(token_value) = token {
335 let result = make_request(token_value).await;
336 if !(result.is_err() && is_authentication_error(result.as_ref().unwrap_err())) {
337 return result;
338 }
339 }
340
341 let updated_token = self.lock_and_update_token().await?;
346
347 make_request(updated_token).await
348 }
349}
350
351#[async_trait]
352impl SecretKeyValueStore for KeyVaultSecretKeyValueStore {
353 async fn read(&self, key: &str) -> Result<Secret<String>, ReadSecretError> {
354 self.attempt_with_token_renewal(
355 |token| self.send_read_request_to_key_vault(token, key),
356 |e| matches!(e, ReadSecretError::Authentication { .. }),
357 )
358 .await
359 }
360
361 async fn check_health(&self) -> Result<(), xand_secrets::CheckHealthError> {
362 self.attempt_with_token_renewal(
363 |token| self.send_health_probe_to_key_vault(token),
364 |e| matches!(e, CheckHealthError::Authentication { .. }),
365 )
366 .await
367 }
368}
369
370#[cfg(test)]
371mod tests;