Skip to main content

secret_manager/
sqlx_pg_backend.rs

1//! PostgreSQL-backed `SecretBackend` implementation using SQLx.
2
3use crate::backend::{KeyRecord, SecretBackend};
4use crate::encryptor::Encrypted;
5use crate::pg_queries::*;
6use crate::rotator::SecretRotationBackend;
7use async_trait::async_trait;
8use jiff::Timestamp;
9use jiff_sqlx::{Timestamp as SqlxTimestamp, ToSqlx};
10use sqlx::{PgPool, Postgres, Transaction};
11use std::time::SystemTime;
12use thiserror::Error;
13
/// Errors returned by the SQLx/PostgreSQL secret backend in this file.
#[derive(Debug, Error)]
pub enum SqlxPgSecretBackendError {
    /// A SQLx operation failed. `#[from]` lets `?` convert a `sqlx::Error` into this variant.
    #[error("query error: {0}")]
    Query(#[from] sqlx::Error),
    /// A `SystemTime` could not be converted into a `jiff::Timestamp`; carries the stringified
    /// conversion error.
    #[error("timestamp conversion error: {0}")]
    Timestamp(String),
}
21
/// Row shape decoded (via `sqlx::FromRow`) from the results of `LOAD_ALL_QUERY` and
/// `POLL_NEW_QUERY` before being turned into a `KeyRecord`.
#[derive(sqlx::FromRow)]
struct KeyRow {
    // Row identifier; also used by callers as part of the `poll_new` cursor.
    id: i64,
    // Decoded as `i16` here and narrowed to `u8` on the `KeyRecord`.
    version: i16,
    // Key material as stored; `try_insert_key` writes `Encrypted::ciphertext` into this column.
    key_bytes: Vec<u8>,
    // Optional nonce accompanying `key_bytes`.
    nonce: Option<Vec<u8>>,
    // Decoded as `i16` here and narrowed to `u8` on the `KeyRecord`.
    encryption_key_version: i16,
    // Activation time in the SQLx wrapper type; converted with `to_jiff()` on the way out.
    activated_at: SqlxTimestamp,
}
31
32impl From<KeyRow> for KeyRecord {
33    fn from(r: KeyRow) -> Self {
34        KeyRecord {
35            id: r.id,
36            version: r.version as u8,
37            key_bytes: r.key_bytes,
38            nonce: r.nonce,
39            encryption_key_version: r.encryption_key_version as u8,
40            activated_at: r.activated_at.to_jiff().into(),
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct SqlxPgSecretBackend {
47    pool: PgPool,
48}
49
50impl SqlxPgSecretBackend {
51    pub fn new(pool: PgPool) -> Self {
52        Self { pool }
53    }
54}
55
56#[async_trait]
57impl SecretBackend for SqlxPgSecretBackend {
58    type Error = SqlxPgSecretBackendError;
59
60    async fn load_all(&self, group_id: &str) -> Result<Vec<KeyRecord>, Self::Error> {
61        let rows = sqlx::query_as::<_, KeyRow>(LOAD_ALL_QUERY)
62            .bind(group_id)
63            .fetch_all(&self.pool)
64            .await?;
65        Ok(rows.into_iter().map(KeyRecord::from).collect())
66    }
67
68    async fn poll_new(
69        &self,
70        group_id: &str,
71        since_time: SystemTime,
72        since_id: i64,
73    ) -> Result<Vec<KeyRecord>, Self::Error> {
74        let since_jiff = Timestamp::try_from(since_time)
75            .map_err(|e| SqlxPgSecretBackendError::Timestamp(e.to_string()))?;
76        let rows = sqlx::query_as::<_, KeyRow>(POLL_NEW_QUERY)
77            .bind(group_id)
78            .bind(since_jiff.to_sqlx())
79            .bind(since_id)
80            .fetch_all(&self.pool)
81            .await?;
82        Ok(rows.into_iter().map(KeyRecord::from).collect())
83    }
84}
85
/// Row shape decoded (via `sqlx::FromRow`) from the result of `LATEST_KEY_INFO_QUERY`.
#[derive(sqlx::FromRow)]
struct KeyInfoRow {
    // Decoded as `i16` here; callers narrow it to `u8`.
    version: i16,
    // Activation time in the SQLx wrapper type; converted with `to_jiff()` where it is used.
    activated_at: SqlxTimestamp,
}
91
92#[async_trait]
93impl SecretRotationBackend for SqlxPgSecretBackend {
94    type Error = SqlxPgSecretBackendError;
95
96    async fn latest_key_info(
97        &self,
98        group_id: &str,
99    ) -> Result<Option<(u8, SystemTime)>, Self::Error> {
100        let row = sqlx::query_as::<_, KeyInfoRow>(LATEST_KEY_INFO_QUERY)
101            .bind(group_id)
102            .fetch_optional(&self.pool)
103            .await?;
104        Ok(row.map(|r| (r.version as u8, r.activated_at.to_jiff().into())))
105    }
106
107    async fn try_insert_key(
108        &self,
109        group_id: &str,
110        expected_version: Option<u8>,
111        new_version: u8,
112        encrypted: &Encrypted,
113        activated_at: SystemTime,
114    ) -> Result<bool, Self::Error> {
115        let mut tx: Transaction<'_, Postgres> = self.pool.begin().await?;
116        sqlx::query(ADVISORY_LOCK_QUERY)
117            .bind(group_id)
118            .execute(&mut *tx)
119            .await?;
120        let row = sqlx::query_as::<_, KeyInfoRow>(LATEST_KEY_INFO_QUERY)
121            .bind(group_id)
122            .fetch_optional(&mut *tx)
123            .await?;
124        let current_version = row.map(|r| r.version as u8);
125        if current_version != expected_version {
126            return Ok(false);
127        }
128        let activated_at_jiff = Timestamp::try_from(activated_at)
129            .map_err(|e| SqlxPgSecretBackendError::Timestamp(e.to_string()))?;
130        sqlx::query(INSERT_KEY_QUERY)
131            .bind(group_id)
132            .bind(new_version as i16)
133            .bind(&encrypted.ciphertext)
134            .bind(&encrypted.nonce)
135            .bind(encrypted.key_version as i16)
136            .bind(activated_at_jiff.to_sqlx())
137            .execute(&mut *tx)
138            .await?;
139        tx.commit().await?;
140        Ok(true)
141    }
142}
143
144#[cfg(test)]
145mod tests {
146    use super::*;
147    use crate::backend::SecretBackend;
148    use crate::encryptor::Encrypted;
149    use crate::rotator::SecretRotationBackend;
150    use std::time::{Duration, SystemTime};
151    use test_containers_util::sqlx_pg::PostgresTestDb;
152    use uuid::Uuid; // used only to generate unique group IDs for test isolation
153
154    static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("tests/sqlx-migrations");
155
156    async fn make_backend() -> (PostgresTestDb, SqlxPgSecretBackend) {
157        let db = PostgresTestDb::create("secret-manager-sqlx", &MIGRATIONS, None, None).await;
158        let backend = SqlxPgSecretBackend::new(db.pool());
159        (db, backend)
160    }
161
162    fn no_op_encrypted(bytes: &[u8]) -> Encrypted {
163        Encrypted { ciphertext: bytes.to_vec(), nonce: None, key_version: 0 }
164    }
165
166    async fn insert_key(backend: &SqlxPgSecretBackend, group_id: &str, version: i16, bytes: &[u8]) {
167        sqlx::query("INSERT INTO secret_keys (key_group, version, key_bytes) VALUES ($1, $2, $3)")
168            .bind(group_id)
169            .bind(version)
170            .bind(bytes)
171            .execute(&backend.pool)
172            .await
173            .unwrap();
174    }
175
176    fn abs_diff(a: SystemTime, b: SystemTime) -> Duration {
177        a.duration_since(b)
178            .or_else(|_| b.duration_since(a))
179            .unwrap_or(Duration::ZERO)
180    }
181
182    #[tokio::test(flavor = "multi_thread")]
183    async fn load_all_returns_empty_for_unknown_group() {
184        let (_db, backend) = make_backend().await;
185        let records = backend.load_all(&Uuid::new_v4().simple().to_string()).await.unwrap();
186        assert!(records.is_empty());
187    }
188
189    #[tokio::test(flavor = "multi_thread")]
190    async fn load_all_returns_rows_ordered_by_activated_at() {
191        let (_db, backend) = make_backend().await;
192        let gid = Uuid::new_v4().simple().to_string();
193        let t0 = SystemTime::now() - Duration::from_secs(120);
194        let t1 = SystemTime::now() - Duration::from_secs(60);
195        let t2 = SystemTime::now();
196
197        backend.try_insert_key(&gid, None, 2, &no_op_encrypted(&[2u8; 32]), t0).await.unwrap();
198        backend.try_insert_key(&gid, Some(2), 0, &no_op_encrypted(&[0u8; 32]), t1).await.unwrap();
199        backend.try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t2).await.unwrap();
200
201        let records = backend.load_all(&gid).await.unwrap();
202        assert_eq!(records.len(), 3);
203        assert_eq!(records[0].version, 2);
204        assert_eq!(records[1].version, 0);
205        assert_eq!(records[2].version, 1);
206        assert_eq!(records[0].key_bytes, vec![2u8; 32]);
207    }
208
209    #[tokio::test(flavor = "multi_thread")]
210    async fn poll_new_returns_empty_when_no_newer_key() {
211        let (_db, backend) = make_backend().await;
212        let gid = Uuid::new_v4().simple().to_string();
213        let t = SystemTime::now();
214        backend.try_insert_key(&gid, None, 5, &no_op_encrypted(&[5u8; 32]), t).await.unwrap();
215
216        let inserted = backend.load_all(&gid).await.unwrap();
217        let id = inserted[0].id;
218
219        let result = backend.poll_new(&gid, t, id).await.unwrap();
220        assert!(result.is_empty());
221    }
222
223    #[tokio::test(flavor = "multi_thread")]
224    async fn poll_new_returns_keys_newer_than_cursor() {
225        let (_db, backend) = make_backend().await;
226        let gid = Uuid::new_v4().simple().to_string();
227        let t0 = SystemTime::now() - Duration::from_secs(180);
228        let t1 = SystemTime::now() - Duration::from_secs(120);
229        let t2 = SystemTime::now() - Duration::from_secs(60);
230
231        backend.try_insert_key(&gid, None, 3, &no_op_encrypted(&[3u8; 32]), t0).await.unwrap();
232        backend.try_insert_key(&gid, Some(3), 7, &no_op_encrypted(&[7u8; 32]), t1).await.unwrap();
233        backend.try_insert_key(&gid, Some(7), 5, &no_op_encrypted(&[5u8; 32]), t2).await.unwrap();
234
235        let all = backend.load_all(&gid).await.unwrap();
236        let id0 = all[0].id;
237
238        let records = backend.poll_new(&gid, t0, id0).await.unwrap();
239        assert_eq!(records.len(), 2);
240        assert_eq!(records[0].version, 7);
241        assert_eq!(records[1].version, 5);
242        assert!(abs_diff(t1, records[0].activated_at).as_millis() < 5);
243        assert!(abs_diff(t2, records[1].activated_at).as_millis() < 5);
244    }
245
246    #[tokio::test(flavor = "multi_thread")]
247    async fn load_all_isolates_groups() {
248        let (_db, backend) = make_backend().await;
249        let gid_a = Uuid::new_v4().simple().to_string();
250        let gid_b = Uuid::new_v4().simple().to_string();
251        insert_key(&backend, &gid_a, 0, &[10u8; 32]).await;
252        insert_key(&backend, &gid_b, 0, &[20u8; 32]).await;
253
254        let a = backend.load_all(&gid_a).await.unwrap();
255        let b = backend.load_all(&gid_b).await.unwrap();
256
257        assert_eq!(a.len(), 1);
258        assert_eq!(b.len(), 1);
259        assert_eq!(a[0].key_bytes, vec![10u8; 32]);
260        assert_eq!(b[0].key_bytes, vec![20u8; 32]);
261    }
262
263    #[tokio::test(flavor = "multi_thread")]
264    async fn latest_key_info_returns_none_for_empty_group() {
265        let (_db, backend) = make_backend().await;
266        let result = backend.latest_key_info(&Uuid::new_v4().simple().to_string()).await.unwrap();
267        assert!(result.is_none());
268    }
269
270    #[tokio::test(flavor = "multi_thread")]
271    async fn try_insert_key_inserts_first_key() {
272        let (_db, backend) = make_backend().await;
273        let gid = Uuid::new_v4().simple().to_string();
274        let activated_at = SystemTime::now() + Duration::from_secs(120);
275
276        let inserted = backend
277            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), activated_at)
278            .await
279            .unwrap();
280        assert!(inserted);
281
282        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
283        assert_eq!(info.0, 0);
284        assert!(abs_diff(info.1, activated_at).as_millis() < 5);
285    }
286
287    #[tokio::test(flavor = "multi_thread")]
288    async fn try_insert_key_returns_false_when_version_already_changed() {
289        let (_db, backend) = make_backend().await;
290        let gid = Uuid::new_v4().simple().to_string();
291        let t = SystemTime::now() + Duration::from_secs(60);
292
293        insert_key(&backend, &gid, 0, &[0u8; 32]).await;
294
295        let inserted = backend
296            .try_insert_key(&gid, None, 1, &no_op_encrypted(&[1u8; 32]), t)
297            .await
298            .unwrap();
299        assert!(!inserted, "must return false when version already changed");
300
301        let records = backend.load_all(&gid).await.unwrap();
302        assert_eq!(records.len(), 1);
303        assert_eq!(records[0].version, 0);
304    }
305
306    #[tokio::test(flavor = "multi_thread")]
307    async fn try_insert_key_sequential_rotations_succeed() {
308        let (_db, backend) = make_backend().await;
309        let gid = Uuid::new_v4().simple().to_string();
310        let t0 = SystemTime::now() + Duration::from_secs(60);
311        let t1 = SystemTime::now() + Duration::from_secs(120);
312
313        let ok0 = backend.try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), t0).await.unwrap();
314        assert!(ok0);
315
316        let ok1 = backend.try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t1).await.unwrap();
317        assert!(ok1);
318
319        let records = backend.load_all(&gid).await.unwrap();
320        assert_eq!(records.len(), 2);
321        assert_eq!(records[0].version, 0);
322        assert_eq!(records[1].version, 1);
323    }
324
325    #[tokio::test(flavor = "multi_thread")]
326    async fn latest_key_info_returns_most_recently_activated() {
327        let (_db, backend) = make_backend().await;
328        let gid = Uuid::new_v4().simple().to_string();
329        let t2 = SystemTime::now() + Duration::from_secs(120);
330
331        insert_key(&backend, &gid, 10, &[10u8; 32]).await;
332        let ok = backend
333            .try_insert_key(&gid, Some(10), 2, &no_op_encrypted(&[2u8; 32]), t2)
334            .await
335            .unwrap();
336        assert!(ok);
337
338        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
339        assert_eq!(info.0, 2, "must return most recently activated, not highest version number");
340    }
341}