xand_secrets_key_vault/
lib.rs

1#![forbid(unsafe_code)]
2
3use async_trait::async_trait;
4use futures::{lock::Mutex, Future};
5use models::{
6    KeyVaultError, KeyVaultErrorResponse, ReadSecretResponse, TokenFromSharedSecretRequestBody,
7    TokenResponse, TokenRetrievalErrorResponse,
8};
9use reqwest::StatusCode;
10use serde::Deserialize;
11use std::{error::Error, fmt::Debug, fmt::Display};
12use thiserror::Error;
13use xand_secrets::{CheckHealthError, ExposeSecret, ReadSecretError, Secret, SecretKeyValueStore};
14
15mod models;
16
/// Token authority used by `get_oauth_token_endpoint` when
/// `KeyVaultConfiguration::custom_oauth_token_endpoint` is not set.
const DEFAULT_OAUTH_TOKEN_ENDPOINT: &str = "https://login.microsoftonline.com";
18
/// Settings needed to obtain an OAuth access token and talk to a single Key Vault.
///
/// Deserialized with kebab-case keys (e.g. `tenant-id`, `http-endpoint`).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct KeyVaultConfiguration {
    /// Tenant id; URL-encoded into the path of the OAuth token endpoint.
    pub tenant_id: String,
    /// Base URL of the Key Vault; request paths such as `/secrets/...` are appended to it.
    pub http_endpoint: String,
    /// Client id used in the client-credentials token request.
    pub client_id: String,
    /// Client secret used in the client-credentials token request.
    pub client_secret: Secret<String>,
    /// Value passed as the `resource` of the client-credentials token request.
    pub azure_resource_id: String,
    /// When set, replaces `DEFAULT_OAUTH_TOKEN_ENDPOINT` as the base of the token URL.
    pub custom_oauth_token_endpoint: Option<String>,
}
29
/// [`SecretKeyValueStore`] implementation that reads secrets from an Azure Key Vault through
/// its REST API, authenticating with an OAuth2 client-credentials access token.
#[derive(Debug)]
pub struct KeyVaultSecretKeyValueStore {
    /// Connection and credential settings.
    config: KeyVaultConfiguration,
    /// HTTP client reused for both token and Key Vault requests.
    client: reqwest::Client,
    /// Most recently fetched access token; `None` until the first successful fetch.
    /// Wrapped in an async mutex so concurrent requests can share and refresh it.
    token: Mutex<Option<Secret<String>>>,
}
36
/// Failure modes of fetching an access token from the OAuth token endpoint.
///
/// Converted into `ReadSecretError` / `CheckHealthError` by the `From` impls in this module.
#[derive(Debug)]
enum TokenRetrievalError {
    /// The token endpoint answered 401 Unauthorized or 403 Forbidden.
    Authentication {
        internal_error: Box<dyn Error + Send + Sync>,
    },
    /// A successful response body could not be deserialized as a token response.
    BodyParse {
        internal_error: Box<dyn Error + Send + Sync>,
    },
    /// The token endpoint answered with a non-success status other than 401/403.
    HttpResponse {
        internal_error: TokenRetrievalHttpError,
    },
    /// Sending the token request failed before a response was obtained.
    Transport {
        internal_error: Box<dyn Error + Send + Sync>,
    },
}
52
53impl From<TokenRetrievalError> for ReadSecretError {
54    fn from(e: TokenRetrievalError) -> Self {
55        match e {
56            TokenRetrievalError::Authentication { internal_error } => {
57                ReadSecretError::Authentication { internal_error }
58            }
59            TokenRetrievalError::BodyParse { internal_error }
60            | TokenRetrievalError::Transport { internal_error } => {
61                ReadSecretError::Request { internal_error }
62            }
63            TokenRetrievalError::HttpResponse { internal_error } => ReadSecretError::Request {
64                internal_error: Box::new(internal_error),
65            },
66        }
67    }
68}
69
70impl From<TokenRetrievalError> for CheckHealthError {
71    fn from(e: TokenRetrievalError) -> Self {
72        match e {
73            TokenRetrievalError::Authentication { internal_error } => {
74                CheckHealthError::Authentication { internal_error }
75            }
76            TokenRetrievalError::BodyParse { internal_error }
77            | TokenRetrievalError::Transport { internal_error } => {
78                CheckHealthError::RemoteInternal { internal_error }
79            }
80            TokenRetrievalError::HttpResponse { internal_error } => {
81                CheckHealthError::RemoteInternal {
82                    internal_error: Box::new(internal_error),
83                }
84            }
85        }
86    }
87}
88
/// Non-success HTTP response from the OAuth token endpoint.
///
/// `Display` is implemented by hand below (no `#[error(...)]` attribute is used).
#[derive(Debug, Error)]
pub struct TokenRetrievalHttpError {
    /// HTTP status returned by the token endpoint.
    pub status_code: StatusCode,
    /// Parsed error body, or `None` if the body could not be parsed as
    /// `TokenRetrievalErrorResponse`.
    pub response: Option<TokenRetrievalErrorResponse>,
}
94
95impl Display for TokenRetrievalHttpError {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        if let Some(error) = &self.response {
98            write!(
99                f,
100                "Authentication endpoint returned status code {}. Error: {}. Description: {}.",
101                self.status_code, error.error, error.error_description
102            )?;
103        } else {
104            write!(
105                f,
106                "Authentication endpoint returned status code {} with an unexpected error body.",
107                self.status_code
108            )?;
109        }
110
111        Ok(())
112    }
113}
114
/// Non-success HTTP response from the Key Vault REST API.
///
/// `Display` is implemented by hand below (no `#[error(...)]` attribute is used).
#[derive(Debug, Error)]
pub struct KeyVaultHttpError {
    /// HTTP status returned by Key Vault.
    pub status_code: StatusCode,
    /// The `error` object of the response body, or `None` if the body could not be
    /// parsed as `KeyVaultErrorResponse`.
    pub response: Option<KeyVaultError>,
}
120
121impl KeyVaultHttpError {
122    async fn from_response(response: reqwest::Response) -> KeyVaultHttpError {
123        KeyVaultHttpError {
124            status_code: response.status(),
125            response: response
126                .json::<KeyVaultErrorResponse>()
127                .await
128                .ok()
129                .map(|r| r.error),
130        }
131    }
132}
133
134impl Display for KeyVaultHttpError {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        if let Some(error) = &self.response {
137            write!(
138                f,
139                "Key Vault returned status code {}. Error code: {}. Message: {}.",
140                self.status_code, error.code, error.message
141            )?;
142        } else {
143            write!(
144                f,
145                "Key Vault returned status code {} with an unexpected error body.",
146                self.status_code
147            )?;
148        }
149
150        Ok(())
151    }
152}
153
154impl KeyVaultSecretKeyValueStore {
155    pub fn create_from_config(config: KeyVaultConfiguration) -> KeyVaultSecretKeyValueStore {
156        KeyVaultSecretKeyValueStore {
157            config,
158            client: reqwest::Client::new(),
159            token: Mutex::new(None),
160        }
161    }
162
163    async fn get_last_known_token(&self) -> Option<Secret<String>> {
164        let token = self.token.lock().await;
165        token.clone()
166    }
167
168    fn get_oauth_token_endpoint(&self) -> String {
169        let token_address = self
170            .config
171            .custom_oauth_token_endpoint
172            .as_ref()
173            .map_or(DEFAULT_OAUTH_TOKEN_ENDPOINT, String::as_str);
174
175        format!(
176            "{}/{}/oauth2/token",
177            token_address,
178            urlencoding::encode(self.config.tenant_id.as_str())
179        )
180    }
181
182    fn get_secret_read_latest_version_endpoint(&self, secret_name: &str) -> String {
183        format!(
184            "{}/secrets/{}?api-version=7.1",
185            self.config.http_endpoint,
186            urlencoding::encode(secret_name),
187        )
188    }
189
190    fn get_secret_list_default_num_endpoint(&self) -> String {
191        format!("{}/secrets?api-version=7.1", self.config.http_endpoint,)
192    }
193
194    async fn get_updated_token(&self) -> Result<Secret<String>, TokenRetrievalError> {
195        let response = self
196            .client
197            .post(&self.get_oauth_token_endpoint())
198            .form(&TokenFromSharedSecretRequestBody {
199                grant_type: "client_credentials",
200                client_id: self.config.client_id.as_str(),
201                client_secret: self.config.client_secret.expose_secret(),
202                resource: self.config.azure_resource_id.as_str(),
203            })
204            .send()
205            .await
206            .map_err(|e| TokenRetrievalError::Transport {
207                internal_error: Box::new(e),
208            })?;
209
210        if !response.status().is_success() {
211            return Err(match response.status() {
212                StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
213                    TokenRetrievalError::Authentication {
214                        internal_error: Box::new(TokenRetrievalHttpError {
215                            status_code: response.status(),
216                            response: response.json::<TokenRetrievalErrorResponse>().await.ok(),
217                        }),
218                    }
219                }
220                _ => TokenRetrievalError::HttpResponse {
221                    internal_error: TokenRetrievalHttpError {
222                        status_code: response.status(),
223                        response: response.json::<TokenRetrievalErrorResponse>().await.ok(),
224                    },
225                },
226            });
227        };
228
229        let response_body =
230            response
231                .json::<TokenResponse>()
232                .await
233                .map_err(|e| TokenRetrievalError::BodyParse {
234                    internal_error: Box::new(e),
235                })?;
236
237        Ok(response_body.access_token)
238    }
239
240    async fn send_read_request_to_key_vault(
241        &self,
242        token: Secret<String>,
243        name: &str,
244    ) -> Result<Secret<String>, ReadSecretError> {
245        let response = self
246            .client
247            .get(&self.get_secret_read_latest_version_endpoint(name))
248            .bearer_auth(token.expose_secret())
249            .send()
250            .await
251            .map_err(|e| ReadSecretError::Request {
252                internal_error: Box::new(e),
253            })?;
254
255        if !response.status().is_success() {
256            return Err(match response.status() {
257                StatusCode::NOT_FOUND => ReadSecretError::KeyNotFound {
258                    key: String::from(name),
259                },
260                StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
261                    ReadSecretError::Authentication {
262                        internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
263                    }
264                }
265                _ => ReadSecretError::Request {
266                    internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
267                },
268            });
269        };
270
271        let response_body =
272            response
273                .json::<ReadSecretResponse>()
274                .await
275                .map_err(|e| ReadSecretError::Request {
276                    internal_error: Box::new(e),
277                })?;
278
279        Ok(response_body.value)
280    }
281
282    async fn send_health_probe_to_key_vault(
283        &self,
284        token: Secret<String>,
285    ) -> Result<(), CheckHealthError> {
286        let response = self
287            .client
288            // Azure does not provide a Key Vault specific health endpoint so the list secrets
289            // endpoint is used to check authenticated access to this particular Key Vault.
290            .get(&self.get_secret_list_default_num_endpoint())
291            .bearer_auth(token.expose_secret())
292            .send()
293            .await
294            .map_err(|e| CheckHealthError::Unreachable {
295                internal_error: Box::new(e),
296            })?;
297
298        match response.status() {
299            status if status.is_success() => Ok(()),
300            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
301                Err(CheckHealthError::Authentication {
302                    internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
303                })
304            }
305            _ => Err(CheckHealthError::RemoteInternal {
306                internal_error: Box::new(KeyVaultHttpError::from_response(response).await),
307            }),
308        }
309
310        // throw out the body; we just care about the HTTP status
311    }
312
313    async fn lock_and_update_token(&self) -> Result<Secret<String>, TokenRetrievalError> {
314        let mut token = self.token.lock().await;
315        let updated_token = self.get_updated_token().await?;
316        *token = Some(updated_token.clone());
317
318        Ok(updated_token)
319    }
320
321    async fn attempt_with_token_renewal<O, OFut, F, T, E>(
322        &self,
323        make_request: O,
324        is_authentication_error: F,
325    ) -> Result<T, E>
326    where
327        O: Fn(Secret<String>) -> OFut,
328        OFut: Future<Output = Result<T, E>>,
329        F: Fn(&E) -> bool,
330        T: Debug,
331        E: From<TokenRetrievalError>,
332    {
333        let token = self.get_last_known_token().await;
334        if let Some(token_value) = token {
335            let result = make_request(token_value).await;
336            if !(result.is_err() && is_authentication_error(result.as_ref().unwrap_err())) {
337                return result;
338            }
339        }
340
341        // Fetching a token from AAD, does not invalidate existing tokens. Thus, the race conditions
342        // introduced here won't cause a failure because other threads will still be able to
343        // authenticate using the tokens they previously fetched. It is possible that multiple threads
344        // will unnecessarily fetch new tokens independently.
345        let updated_token = self.lock_and_update_token().await?;
346
347        make_request(updated_token).await
348    }
349}
350
351#[async_trait]
352impl SecretKeyValueStore for KeyVaultSecretKeyValueStore {
353    async fn read(&self, key: &str) -> Result<Secret<String>, ReadSecretError> {
354        self.attempt_with_token_renewal(
355            |token| self.send_read_request_to_key_vault(token, key),
356            |e| matches!(e, ReadSecretError::Authentication { .. }),
357        )
358        .await
359    }
360
361    async fn check_health(&self) -> Result<(), xand_secrets::CheckHealthError> {
362        self.attempt_with_token_renewal(
363            |token| self.send_health_probe_to_key_vault(token),
364            |e| matches!(e, CheckHealthError::Authentication { .. }),
365        )
366        .await
367    }
368}
369
370#[cfg(test)]
371mod tests;