secret_vault/
vault.rs

1use crate::encryption::SecretVaultEncryption;
2use crate::secrets_source::SecretsSource;
3use crate::vault_store::SecretVaultStore;
4use crate::*;
5use async_trait::async_trait;
6use std::sync::Arc;
7use tracing::*;
8
9pub struct SecretVault<S, E>
10where
11    S: SecretsSource,
12    E: SecretVaultEncryption + Sync + Send,
13{
14    source: S,
15    store: Arc<SecretVaultStore<E>>,
16    refs: Vec<SecretVaultRef>,
17}
18
19impl<S, E> SecretVault<S, E>
20where
21    S: SecretsSource,
22    E: SecretVaultEncryption + Sync + Send,
23{
24    pub fn new(source: S, encrypter: E) -> SecretVaultResult<Self> {
25        Ok(Self {
26            source,
27            store: Arc::new(SecretVaultStore::new(encrypter)),
28            refs: Vec::new(),
29        })
30    }
31
32    pub fn with_secret_refs(mut self, secret_refs: Vec<&SecretVaultRef>) -> Self {
33        self.refs = secret_refs.into_iter().cloned().collect();
34        self
35    }
36
37    pub fn register_secret_refs(&mut self, secret_refs: Vec<&SecretVaultRef>) -> &mut Self {
38        self.refs = secret_refs.into_iter().cloned().collect();
39        self
40    }
41
42    pub fn add_secret_refs(mut self, secret_refs: Vec<&SecretVaultRef>) -> Self {
43        self.refs = [secret_refs.into_iter().cloned().collect(), self.refs].concat();
44        self
45    }
46
47    pub fn add_secret_ref(&mut self, secret_ref: &SecretVaultRef) -> &mut Self {
48        self.refs.push(secret_ref.clone());
49        self
50    }
51
52    pub fn remove_secret_ref(&mut self, key: &SecretVaultKey) -> &mut Self {
53        self.refs.retain(|secret_ref| secret_ref.key != *key);
54        self
55    }
56
57    pub async fn refresh(&self) -> SecretVaultResult<&Self> {
58        info!(
59            "Refreshing secrets from the source: {}. Expected: {}. Required: {}",
60            self.source.name(),
61            self.refs.len(),
62            self.refs
63                .iter()
64                .filter(|secret_ref| secret_ref.required)
65                .count()
66        );
67
68        let mut secrets_map = self.source.get_secrets(&self.refs).await?;
69
70        for (secret_ref, secret) in secrets_map.drain() {
71            self.store.insert(secret_ref, &secret).await?;
72        }
73
74        info!("Secret vault contains: {} secrets", self.store.len().await);
75
76        self.compact().await?;
77
78        Ok(self)
79    }
80
81    pub async fn refresh_only(
82        &self,
83        predicate: fn(&SecretVaultRef) -> bool,
84    ) -> SecretVaultResult<&Self> {
85        let refs_auto_refresh_enabled: Vec<SecretVaultRef> = self
86            .refs
87            .iter()
88            .filter(|secret_ref| predicate(secret_ref))
89            .cloned()
90            .collect();
91
92        trace!(
93            "Refreshing secrets from the source: {}. All registered secrets: {}. Expected to be refreshed: {}",
94            self.source.name(),
95            self.refs.len(),
96            refs_auto_refresh_enabled.len()
97        );
98
99        let mut secrets_map = self.source.get_secrets(&refs_auto_refresh_enabled).await?;
100
101        for (secret_ref, secret) in secrets_map.drain() {
102            self.store.insert(secret_ref, &secret).await?;
103        }
104
105        trace!(
106            "Secret vault now contains: {} secrets in total",
107            self.store.len().await
108        );
109
110        Ok(self)
111    }
112
113    pub async fn refresh_only_not_present(&self) -> SecretVaultResult<&Self> {
114        let (existing_refs, missing_refs) = self.store.exists(&self.refs).await;
115
116        if !missing_refs.is_empty() {
117            trace!(
118                "Refreshing non cached secrets from the source. Existing: {}. Missing: {}",
119                existing_refs.len(),
120                missing_refs.len()
121            );
122
123            let missing_refs: Vec<SecretVaultRef> = missing_refs.into_iter().cloned().collect();
124
125            let mut secrets_map = self.source.get_secrets(&missing_refs).await?;
126
127            for (secret_ref, secret) in secrets_map.drain() {
128                self.store.insert(secret_ref, &secret).await?;
129            }
130
131            trace!(
132                "Secret vault now contains: {} secrets in total",
133                self.store.len().await
134            );
135        } else {
136            trace!(
137                "No secrets to refresh. All secrets are cached: {}.",
138                self.refs.len()
139            );
140        }
141
142        self.compact().await?;
143
144        Ok(self)
145    }
146
147    pub async fn compact(&self) -> SecretVaultResult<()> {
148        self.store.compact(&self.refs).await
149    }
150
151    pub async fn store_len(&self) -> usize {
152        self.store.len().await
153    }
154
155    pub fn viewer(&self) -> SecretVaultViewer<E> {
156        SecretVaultViewer::new(self.store.clone())
157    }
158
159    pub async fn snapshot<SNB, SN>(&self, builder: SNB) -> SecretVaultResult<SN>
160    where
161        SN: SecretVaultSnapshot,
162        SNB: SecretVaultSnapshotBuilder<SN>,
163    {
164        let refs_allowed_in_snapshot: Vec<SecretVaultRef> = self
165            .refs
166            .iter()
167            .filter(|secret_ref| secret_ref.allow_in_snapshots)
168            .cloned()
169            .collect();
170
171        let mut secrets: Vec<Secret> = Vec::with_capacity(refs_allowed_in_snapshot.len());
172
173        for secret_ref in refs_allowed_in_snapshot {
174            if let Some(secret) = self.store.get_secret(&secret_ref.key).await? {
175                secrets.push(secret);
176            }
177        }
178
179        Ok(builder.build_snapshot(secrets))
180    }
181}
182
183#[async_trait]
184impl<S, E> SecretVaultView for SecretVault<S, E>
185where
186    S: SecretsSource + Send + Sync,
187    E: SecretVaultEncryption + Send + Sync,
188{
189    async fn get_secret_by_ref(
190        &self,
191        secret_ref: &SecretVaultRef,
192    ) -> SecretVaultResult<Option<Secret>> {
193        self.store.get_secret(&secret_ref.key).await
194    }
195}
196
#[cfg(test)]
mod tests {
    use crate::source_tests::*;
    use crate::*;
    use chrono::Utc;
    use proptest::prelude::*;
    use proptest::strategy::ValueTree;
    use proptest::test_runner::TestRunner;
    use secret_vault_value::SecretValue;

    /// After a full refresh, every secret known to the mock source must be
    /// readable from the vault with the same value.
    #[tokio::test]
    async fn refresh_vault_test() {
        let mut runner = TestRunner::default();
        let mock_secrets_store = generate_mock_secrets_source("default".into())
            .new_tree(&mut runner)
            .unwrap()
            .current();

        let mut vault = SecretVaultBuilder::with_source(mock_secrets_store.clone())
            .build()
            .unwrap();

        vault
            .register_secret_refs(mock_secrets_store.keys().iter().collect())
            .refresh()
            .await
            .unwrap();

        for secret_ref in mock_secrets_store.keys() {
            let stored_value = vault
                .get_secret_by_ref(&secret_ref)
                .await
                .unwrap()
                .map(|secret| secret.value);
            let expected_value = mock_secrets_store.get(&secret_ref);

            assert_eq!(stored_value.as_ref(), expected_value.as_ref())
        }
    }

    /// `refresh_only_not_present` must fetch the newly registered secret while
    /// leaving the already cached ones untouched.
    #[tokio::test]
    async fn refresh_only_non_present() {
        let mut runner = TestRunner::default();
        let mut mock_secrets_store = generate_mock_secrets_source("default".into())
            .new_tree(&mut runner)
            .unwrap()
            .current();

        let mut vault = SecretVaultBuilder::with_source(mock_secrets_store.clone())
            .build()
            .unwrap()
            .with_secret_refs(mock_secrets_store.keys().iter().collect());

        vault.refresh().await.unwrap();

        // Everything cached so far was cached at or before this instant.
        let first_refresh_done_at = Utc::now();

        let new_secret_ref =
            SecretVaultRef::new("new_secret".into()).with_namespace("default".into());
        vault.add_secret_ref(&new_secret_ref);
        mock_secrets_store.add(
            new_secret_ref.clone(),
            SecretValue::new("new_secret_value".into()),
        );

        vault.refresh_only_not_present().await.unwrap();

        for secret_ref in mock_secrets_store.keys() {
            let secret_cached_at = vault
                .get_secret_by_ref(&secret_ref)
                .await
                .unwrap()
                .map(|secret| secret.metadata.cached_at)
                .expect("every secret known to the source must be cached in the vault");

            // Compare at full timestamp precision: at whole-second granularity
            // both branches hold trivially, because the test normally runs
            // within a single second.
            if secret_ref.key != new_secret_ref.key {
                assert!(secret_cached_at <= first_refresh_done_at)
            } else {
                assert!(secret_cached_at >= first_refresh_done_at)
            }
        }
    }
}