Skip to main content

secret_manager/
diesel_pg_backend.rs

1//! PostgreSQL-backed `SecretBackend` implementation using Diesel.
2
3use crate::backend::{KeyRecord, SecretBackend};
4use crate::encryptor::Encrypted;
5use crate::pg_queries::*;
6use crate::rotator::SecretRotationBackend;
7use async_trait::async_trait;
8use diesel::sql_types::{Bytea, Nullable, SmallInt, Text, Timestamptz};
9use diesel::{QueryableByName, sql_query};
10use diesel_async::pooled_connection::bb8::Pool;
11use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl};
12use jiff::Timestamp;
13use jiff_diesel::{Timestamp as DieselTimestamp, ToDiesel};
14use std::time::SystemTime;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum DieselPgSecretBackendError {
19    #[error("connection pool error: {0}")]
20    Pool(String),
21    #[error("query error: {0}")]
22    Query(#[from] diesel::result::Error),
23}
24
25#[derive(QueryableByName)]
26struct KeyRow {
27    #[diesel(sql_type = diesel::sql_types::BigInt)]
28    id: i64,
29    #[diesel(sql_type = SmallInt)]
30    version: i16,
31    #[diesel(sql_type = Bytea)]
32    key_bytes: Vec<u8>,
33    #[diesel(sql_type = Nullable<Bytea>)]
34    nonce: Option<Vec<u8>>,
35    #[diesel(sql_type = SmallInt)]
36    encryption_key_version: i16,
37    #[diesel(sql_type = Timestamptz)]
38    activated_at: DieselTimestamp,
39}
40
41impl From<KeyRow> for KeyRecord {
42    fn from(r: KeyRow) -> Self {
43        KeyRecord {
44            id: r.id,
45            version: r.version as u8,
46            key_bytes: r.key_bytes,
47            nonce: r.nonce,
48            encryption_key_version: r.encryption_key_version as u8,
49            activated_at: r.activated_at.to_jiff().into(),
50        }
51    }
52}
53
54#[derive(Clone)]
55pub struct DieselPgSecretBackend {
56    pool: Pool<AsyncPgConnection>,
57}
58
59impl DieselPgSecretBackend {
60    pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
61        Self { pool }
62    }
63}
64
65#[async_trait]
66impl SecretBackend for DieselPgSecretBackend {
67    type Error = DieselPgSecretBackendError;
68
69    async fn load_all(&self, group_id: &str) -> Result<Vec<KeyRecord>, Self::Error> {
70        let mut conn = self
71            .pool
72            .get()
73            .await
74            .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
75        let rows: Vec<KeyRow> = sql_query(LOAD_ALL_QUERY)
76            .bind::<Text, _>(group_id)
77            .get_results(&mut conn)
78            .await?;
79        Ok(rows.into_iter().map(KeyRecord::from).collect())
80    }
81
82    async fn poll_new(
83        &self,
84        group_id: &str,
85        since_time: SystemTime,
86        since_id: i64,
87    ) -> Result<Vec<KeyRecord>, Self::Error> {
88        let mut conn = self
89            .pool
90            .get()
91            .await
92            .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
93        let since_jiff = Timestamp::try_from(since_time)
94            .map_err(|e| diesel::result::Error::SerializationError(Box::new(e)))?;
95        let rows: Vec<KeyRow> = sql_query(POLL_NEW_QUERY)
96            .bind::<Text, _>(group_id)
97            .bind::<Timestamptz, _>(since_jiff.to_diesel())
98            .bind::<diesel::sql_types::BigInt, _>(since_id)
99            .get_results(&mut conn)
100            .await?;
101        Ok(rows.into_iter().map(KeyRecord::from).collect())
102    }
103}
104
105#[derive(QueryableByName)]
106struct KeyInfoRow {
107    #[diesel(sql_type = SmallInt)]
108    version: i16,
109    #[diesel(sql_type = Timestamptz)]
110    activated_at: DieselTimestamp,
111}
112
113#[async_trait]
114impl SecretRotationBackend for DieselPgSecretBackend {
115    type Error = DieselPgSecretBackendError;
116
117    async fn latest_key_info(
118        &self,
119        group_id: &str,
120    ) -> Result<Option<(u8, SystemTime)>, Self::Error> {
121        let mut conn = self
122            .pool
123            .get()
124            .await
125            .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
126        let rows: Vec<KeyInfoRow> = sql_query(LATEST_KEY_INFO_QUERY)
127            .bind::<Text, _>(group_id)
128            .get_results(&mut conn)
129            .await?;
130        Ok(rows
131            .into_iter()
132            .next()
133            .map(|r| (r.version as u8, r.activated_at.to_jiff().into())))
134    }
135
136    async fn try_insert_key(
137        &self,
138        group_id: &str,
139        expected_version: Option<u8>,
140        new_version: u8,
141        encrypted: &Encrypted,
142        activated_at: SystemTime,
143    ) -> Result<bool, Self::Error> {
144        let mut conn = self
145            .pool
146            .get()
147            .await
148            .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
149        let group_id = group_id.to_owned();
150        let ciphertext = encrypted.ciphertext.clone();
151        let nonce = encrypted.nonce.clone();
152        let enc_version = encrypted.key_version as i16;
153        let activated_at_jiff = Timestamp::try_from(activated_at)
154            .map_err(|e| diesel::result::Error::SerializationError(Box::new(e)))?;
155        let activated_at_diesel = activated_at_jiff.to_diesel();
156
157        let inserted = conn
158            .transaction(|conn| {
159                Box::pin(async move {
160                    sql_query(ADVISORY_LOCK_QUERY)
161                        .bind::<Text, _>(&group_id)
162                        .execute(conn)
163                        .await?;
164                    let rows: Vec<KeyInfoRow> = sql_query(LATEST_KEY_INFO_QUERY)
165                        .bind::<Text, _>(&group_id)
166                        .get_results(conn)
167                        .await?;
168                    let current_version = rows.into_iter().next().map(|r| r.version as u8);
169                    if current_version != expected_version {
170                        return Ok::<bool, diesel::result::Error>(false);
171                    }
172                    let rows_affected = sql_query(INSERT_KEY_QUERY)
173                        .bind::<Text, _>(&group_id)
174                        .bind::<SmallInt, _>(new_version as i16)
175                        .bind::<Bytea, _>(ciphertext.as_slice())
176                        .bind::<Nullable<Bytea>, _>(nonce.as_ref().map(|n| n as &[u8]))
177                        .bind::<SmallInt, _>(enc_version)
178                        .bind::<Timestamptz, _>(activated_at_diesel)
179                        .execute(conn)
180                        .await?;
181                    Ok(rows_affected > 0)
182                })
183            })
184            .await?;
185        Ok(inserted)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SecretBackend;
    use crate::encryptor::Encrypted;
    use crate::rotator::SecretRotationBackend;
    use diesel::sql_types::{Bytea, SmallInt};
    use diesel_async::RunQueryDsl;
    use diesel_migrations::{EmbeddedMigrations, embed_migrations};
    use std::time::{Duration, SystemTime};
    use test_containers_util::diesel_pg::PostgresTestDb;
    use uuid::Uuid; // used only to generate unique group IDs for test isolation

    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("tests/diesel-migrations");

    /// Creates a migrated test database and a backend over its pool.
    ///
    /// Tests bind the returned `PostgresTestDb` as `_db` to keep it alive for the whole test
    /// (presumably dropping it tears the database down).
    async fn make_backend() -> (PostgresTestDb, DieselPgSecretBackend) {
        let db = PostgresTestDb::create("secret-manager-diesel", MIGRATIONS, None, None).await;
        let backend = DieselPgSecretBackend::new(db.pool());
        (db, backend)
    }

    /// A fresh, unique key-group id so tests never observe each other's rows.
    fn new_group_id() -> String {
        Uuid::new_v4().simple().to_string()
    }

    /// Wraps raw bytes as an `Encrypted` value with no nonce and key version 0.
    fn no_op_encrypted(bytes: &[u8]) -> Encrypted {
        Encrypted { ciphertext: bytes.to_vec(), nonce: None, key_version: 0 }
    }

    /// Inserts a key row directly with SQL, bypassing the compare-and-swap in
    /// `try_insert_key`. `activated_at` is not supplied, so it comes from the column default.
    async fn insert_key(
        backend: &DieselPgSecretBackend,
        group_id: &str,
        version: i16,
        bytes: &[u8],
    ) {
        let mut conn = backend.pool.get().await.unwrap();
        diesel::sql_query(
            "INSERT INTO secret_keys (key_group, version, key_bytes) VALUES ($1, $2, $3)",
        )
        .bind::<diesel::sql_types::Text, _>(group_id)
        .bind::<SmallInt, _>(version)
        .bind::<Bytea, _>(bytes)
        .execute(&mut conn)
        .await
        .unwrap();
    }

    /// Setup helper: rotates the group to `version` via `try_insert_key` and asserts the
    /// compare-and-swap succeeded. A rejected setup step then fails here, not later as a
    /// confusing length or version mismatch.
    async fn rotate(
        backend: &DieselPgSecretBackend,
        group_id: &str,
        expected: Option<u8>,
        version: u8,
        bytes: &[u8],
        activated_at: SystemTime,
    ) {
        let inserted = backend
            .try_insert_key(group_id, expected, version, &no_op_encrypted(bytes), activated_at)
            .await
            .unwrap();
        assert!(inserted, "setup rotation to version {version} was rejected");
    }

    /// Absolute difference between two instants, regardless of their order.
    fn abs_diff(a: SystemTime, b: SystemTime) -> Duration {
        a.duration_since(b)
            .or_else(|_| b.duration_since(a))
            .unwrap_or(Duration::ZERO)
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_empty_for_unknown_group() {
        let (_db, backend) = make_backend().await;
        let records = backend.load_all(&new_group_id()).await.unwrap();
        assert!(records.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_rows_ordered_by_activated_at() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() - Duration::from_secs(120);
        let t1 = SystemTime::now() - Duration::from_secs(60);
        let t2 = SystemTime::now();

        // Version numbers deliberately do not follow activation order, so the assertions
        // distinguish "ordered by activated_at" from "ordered by version".
        rotate(&backend, &gid, None, 2, &[2u8; 32], t0).await;
        rotate(&backend, &gid, Some(2), 0, &[0u8; 32], t1).await;
        rotate(&backend, &gid, Some(0), 1, &[1u8; 32], t2).await;

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].version, 2);
        assert_eq!(records[1].version, 0);
        assert_eq!(records[2].version, 1);
        assert_eq!(records[0].key_bytes, vec![2u8; 32]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_empty_when_no_newer_key() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t = SystemTime::now();
        rotate(&backend, &gid, None, 5, &[5u8; 32], t).await;

        let inserted = backend.load_all(&gid).await.unwrap();
        let id = inserted[0].id;

        let result = backend.poll_new(&gid, t, id).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_keys_newer_than_cursor() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() - Duration::from_secs(180);
        let t1 = SystemTime::now() - Duration::from_secs(120);
        let t2 = SystemTime::now() - Duration::from_secs(60);

        rotate(&backend, &gid, None, 3, &[3u8; 32], t0).await;
        rotate(&backend, &gid, Some(3), 7, &[7u8; 32], t1).await;
        rotate(&backend, &gid, Some(7), 5, &[5u8; 32], t2).await;

        let all = backend.load_all(&gid).await.unwrap();
        let id0 = all[0].id;

        let records = backend.poll_new(&gid, t0, id0).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 7);
        assert_eq!(records[1].version, 5);
        assert!(abs_diff(t1, records[0].activated_at).as_millis() < 5);
        assert!(abs_diff(t2, records[1].activated_at).as_millis() < 5);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_isolates_groups() {
        let (_db, backend) = make_backend().await;
        let gid_a = new_group_id();
        let gid_b = new_group_id();
        insert_key(&backend, &gid_a, 0, &[10u8; 32]).await;
        insert_key(&backend, &gid_b, 0, &[20u8; 32]).await;

        let a = backend.load_all(&gid_a).await.unwrap();
        let b = backend.load_all(&gid_b).await.unwrap();

        assert_eq!(a.len(), 1);
        assert_eq!(b.len(), 1);
        assert_eq!(a[0].key_bytes, vec![10u8; 32]);
        assert_eq!(b[0].key_bytes, vec![20u8; 32]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_none_for_empty_group() {
        let (_db, backend) = make_backend().await;
        let result = backend.latest_key_info(&new_group_id()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_inserts_first_key() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let activated_at = SystemTime::now() + Duration::from_secs(120);

        let inserted = backend
            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), activated_at)
            .await
            .unwrap();
        assert!(inserted);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 0);
        assert!(abs_diff(info.1, activated_at).as_millis() < 5);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_returns_false_when_version_already_changed() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t = SystemTime::now() + Duration::from_secs(60);

        insert_key(&backend, &gid, 0, &[0u8; 32]).await;

        let inserted = backend
            .try_insert_key(&gid, None, 1, &no_op_encrypted(&[1u8; 32]), t)
            .await
            .unwrap();
        assert!(!inserted, "must return false when version already changed");

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].version, 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_sequential_rotations_succeed() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() + Duration::from_secs(60);
        let t1 = SystemTime::now() + Duration::from_secs(120);

        let ok0 = backend
            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), t0)
            .await
            .unwrap();
        assert!(ok0);

        let ok1 = backend
            .try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t1)
            .await
            .unwrap();
        assert!(ok1);

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 0);
        assert_eq!(records[1].version, 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_most_recently_activated() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t2 = SystemTime::now() + Duration::from_secs(120);

        // Higher version number but activated earlier (column default) than version 2 below.
        insert_key(&backend, &gid, 10, &[10u8; 32]).await;
        let ok = backend
            .try_insert_key(&gid, Some(10), 2, &no_op_encrypted(&[2u8; 32]), t2)
            .await
            .unwrap();
        assert!(ok);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 2, "must return most recently activated, not highest version number");
    }
}