1use crate::backend::{KeyRecord, SecretBackend};
4use crate::encryptor::Encrypted;
5use crate::pg_queries::*;
6use crate::rotator::SecretRotationBackend;
7use async_trait::async_trait;
8use jiff::Timestamp;
9use jiff_sqlx::{Timestamp as SqlxTimestamp, ToSqlx};
10use sqlx::{PgPool, Postgres, Transaction};
11use std::time::SystemTime;
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum SqlxPgSecretBackendError {
16 #[error("query error: {0}")]
17 Query(#[from] sqlx::Error),
18 #[error("timestamp conversion error: {0}")]
19 Timestamp(String),
20}
21
22#[derive(sqlx::FromRow)]
23struct KeyRow {
24 id: i64,
25 version: i16,
26 key_bytes: Vec<u8>,
27 nonce: Option<Vec<u8>>,
28 encryption_key_version: i16,
29 activated_at: SqlxTimestamp,
30}
31
32impl From<KeyRow> for KeyRecord {
33 fn from(r: KeyRow) -> Self {
34 KeyRecord {
35 id: r.id,
36 version: r.version as u8,
37 key_bytes: r.key_bytes,
38 nonce: r.nonce,
39 encryption_key_version: r.encryption_key_version as u8,
40 activated_at: r.activated_at.to_jiff().into(),
41 }
42 }
43}
44
45#[derive(Clone)]
46pub struct SqlxPgSecretBackend {
47 pool: PgPool,
48}
49
50impl SqlxPgSecretBackend {
51 pub fn new(pool: PgPool) -> Self {
52 Self { pool }
53 }
54}
55
56#[async_trait]
57impl SecretBackend for SqlxPgSecretBackend {
58 type Error = SqlxPgSecretBackendError;
59
60 async fn load_all(&self, group_id: &str) -> Result<Vec<KeyRecord>, Self::Error> {
61 let rows = sqlx::query_as::<_, KeyRow>(LOAD_ALL_QUERY)
62 .bind(group_id)
63 .fetch_all(&self.pool)
64 .await?;
65 Ok(rows.into_iter().map(KeyRecord::from).collect())
66 }
67
68 async fn poll_new(
69 &self,
70 group_id: &str,
71 since_time: SystemTime,
72 since_id: i64,
73 ) -> Result<Vec<KeyRecord>, Self::Error> {
74 let since_jiff = Timestamp::try_from(since_time)
75 .map_err(|e| SqlxPgSecretBackendError::Timestamp(e.to_string()))?;
76 let rows = sqlx::query_as::<_, KeyRow>(POLL_NEW_QUERY)
77 .bind(group_id)
78 .bind(since_jiff.to_sqlx())
79 .bind(since_id)
80 .fetch_all(&self.pool)
81 .await?;
82 Ok(rows.into_iter().map(KeyRecord::from).collect())
83 }
84}
85
86#[derive(sqlx::FromRow)]
87struct KeyInfoRow {
88 version: i16,
89 activated_at: SqlxTimestamp,
90}
91
92#[async_trait]
93impl SecretRotationBackend for SqlxPgSecretBackend {
94 type Error = SqlxPgSecretBackendError;
95
96 async fn latest_key_info(
97 &self,
98 group_id: &str,
99 ) -> Result<Option<(u8, SystemTime)>, Self::Error> {
100 let row = sqlx::query_as::<_, KeyInfoRow>(LATEST_KEY_INFO_QUERY)
101 .bind(group_id)
102 .fetch_optional(&self.pool)
103 .await?;
104 Ok(row.map(|r| (r.version as u8, r.activated_at.to_jiff().into())))
105 }
106
107 async fn try_insert_key(
108 &self,
109 group_id: &str,
110 expected_version: Option<u8>,
111 new_version: u8,
112 encrypted: &Encrypted,
113 activated_at: SystemTime,
114 ) -> Result<bool, Self::Error> {
115 let mut tx: Transaction<'_, Postgres> = self.pool.begin().await?;
116 sqlx::query(ADVISORY_LOCK_QUERY)
117 .bind(group_id)
118 .execute(&mut *tx)
119 .await?;
120 let row = sqlx::query_as::<_, KeyInfoRow>(LATEST_KEY_INFO_QUERY)
121 .bind(group_id)
122 .fetch_optional(&mut *tx)
123 .await?;
124 let current_version = row.map(|r| r.version as u8);
125 if current_version != expected_version {
126 return Ok(false);
127 }
128 let activated_at_jiff = Timestamp::try_from(activated_at)
129 .map_err(|e| SqlxPgSecretBackendError::Timestamp(e.to_string()))?;
130 sqlx::query(INSERT_KEY_QUERY)
131 .bind(group_id)
132 .bind(new_version as i16)
133 .bind(&encrypted.ciphertext)
134 .bind(&encrypted.nonce)
135 .bind(encrypted.key_version as i16)
136 .bind(activated_at_jiff.to_sqlx())
137 .execute(&mut *tx)
138 .await?;
139 tx.commit().await?;
140 Ok(true)
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SecretBackend;
    use crate::encryptor::Encrypted;
    use crate::rotator::SecretRotationBackend;
    use std::time::{Duration, SystemTime};
    use test_containers_util::sqlx_pg::PostgresTestDb;
    use uuid::Uuid;

    static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("tests/sqlx-migrations");

    // Creates a migrated test database and a backend over its pool. The
    // `PostgresTestDb` is returned so callers can hold it (as `_db`) for the
    // whole test; presumably dropping it releases the database.
    async fn make_backend() -> (PostgresTestDb, SqlxPgSecretBackend) {
        let db = PostgresTestDb::create("secret-manager-sqlx", &MIGRATIONS, None, None).await;
        let backend = SqlxPgSecretBackend::new(db.pool());
        (db, backend)
    }

    // A random key-group id, so tests never see each other's rows.
    fn new_group_id() -> String {
        Uuid::new_v4().simple().to_string()
    }

    // Wraps raw bytes as an "unencrypted" payload: no nonce, key_version 0.
    fn no_op_encrypted(bytes: &[u8]) -> Encrypted {
        Encrypted { ciphertext: bytes.to_vec(), nonce: None, key_version: 0 }
    }

    // Inserts a row directly, bypassing try_insert_key. activated_at is left
    // to the column default defined by the migrations.
    async fn insert_key(backend: &SqlxPgSecretBackend, group_id: &str, version: i16, bytes: &[u8]) {
        sqlx::query("INSERT INTO secret_keys (key_group, version, key_bytes) VALUES ($1, $2, $3)")
            .bind(group_id)
            .bind(version)
            .bind(bytes)
            .execute(&backend.pool)
            .await
            .unwrap();
    }

    // Absolute distance between two instants, regardless of order.
    fn abs_diff(a: SystemTime, b: SystemTime) -> Duration {
        // duration_since fails only when `b` is later than `a`; the error then
        // carries exactly that (positive) gap.
        a.duration_since(b).unwrap_or_else(|e| e.duration())
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_empty_for_unknown_group() {
        let (_db, backend) = make_backend().await;
        let records = backend.load_all(&new_group_id()).await.unwrap();
        assert!(records.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_rows_ordered_by_activated_at() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() - Duration::from_secs(120);
        let t1 = SystemTime::now() - Duration::from_secs(60);
        let t2 = SystemTime::now();

        backend.try_insert_key(&gid, None, 2, &no_op_encrypted(&[2u8; 32]), t0).await.unwrap();
        backend.try_insert_key(&gid, Some(2), 0, &no_op_encrypted(&[0u8; 32]), t1).await.unwrap();
        backend.try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t2).await.unwrap();

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].version, 2);
        assert_eq!(records[1].version, 0);
        assert_eq!(records[2].version, 1);
        assert_eq!(records[0].key_bytes, vec![2u8; 32]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_empty_when_no_newer_key() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t = SystemTime::now();
        backend.try_insert_key(&gid, None, 5, &no_op_encrypted(&[5u8; 32]), t).await.unwrap();

        let inserted = backend.load_all(&gid).await.unwrap();
        let id = inserted[0].id;

        let result = backend.poll_new(&gid, t, id).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_keys_newer_than_cursor() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() - Duration::from_secs(180);
        let t1 = SystemTime::now() - Duration::from_secs(120);
        let t2 = SystemTime::now() - Duration::from_secs(60);

        backend.try_insert_key(&gid, None, 3, &no_op_encrypted(&[3u8; 32]), t0).await.unwrap();
        backend.try_insert_key(&gid, Some(3), 7, &no_op_encrypted(&[7u8; 32]), t1).await.unwrap();
        backend.try_insert_key(&gid, Some(7), 5, &no_op_encrypted(&[5u8; 32]), t2).await.unwrap();

        let all = backend.load_all(&gid).await.unwrap();
        let id0 = all[0].id;

        let records = backend.poll_new(&gid, t0, id0).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 7);
        assert_eq!(records[1].version, 5);
        assert!(abs_diff(t1, records[0].activated_at).as_millis() < 5);
        assert!(abs_diff(t2, records[1].activated_at).as_millis() < 5);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_isolates_groups() {
        let (_db, backend) = make_backend().await;
        let gid_a = new_group_id();
        let gid_b = new_group_id();
        insert_key(&backend, &gid_a, 0, &[10u8; 32]).await;
        insert_key(&backend, &gid_b, 0, &[20u8; 32]).await;

        let a = backend.load_all(&gid_a).await.unwrap();
        let b = backend.load_all(&gid_b).await.unwrap();

        assert_eq!(a.len(), 1);
        assert_eq!(b.len(), 1);
        assert_eq!(a[0].key_bytes, vec![10u8; 32]);
        assert_eq!(b[0].key_bytes, vec![20u8; 32]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_none_for_empty_group() {
        let (_db, backend) = make_backend().await;
        let result = backend.latest_key_info(&new_group_id()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_inserts_first_key() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let activated_at = SystemTime::now() + Duration::from_secs(120);

        let inserted = backend
            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), activated_at)
            .await
            .unwrap();
        assert!(inserted);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 0);
        assert!(abs_diff(info.1, activated_at).as_millis() < 5);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_returns_false_when_version_already_changed() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t = SystemTime::now() + Duration::from_secs(60);

        insert_key(&backend, &gid, 0, &[0u8; 32]).await;

        let inserted = backend
            .try_insert_key(&gid, None, 1, &no_op_encrypted(&[1u8; 32]), t)
            .await
            .unwrap();
        assert!(!inserted, "must return false when version already changed");

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].version, 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_sequential_rotations_succeed() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t0 = SystemTime::now() + Duration::from_secs(60);
        let t1 = SystemTime::now() + Duration::from_secs(120);

        let ok0 = backend
            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), t0)
            .await
            .unwrap();
        assert!(ok0);

        let ok1 = backend
            .try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t1)
            .await
            .unwrap();
        assert!(ok1);

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 0);
        assert_eq!(records[1].version, 1);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_most_recently_activated() {
        let (_db, backend) = make_backend().await;
        let gid = new_group_id();
        let t2 = SystemTime::now() + Duration::from_secs(120);

        insert_key(&backend, &gid, 10, &[10u8; 32]).await;
        let ok = backend
            .try_insert_key(&gid, Some(10), 2, &no_op_encrypted(&[2u8; 32]), t2)
            .await
            .unwrap();
        assert!(ok);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 2, "must return most recently activated, not highest version number");
    }
}