1use crate::backend::{KeyRecord, SecretBackend};
4use crate::encryptor::Encrypted;
5use crate::pg_queries::*;
6use crate::rotator::SecretRotationBackend;
7use async_trait::async_trait;
8use diesel::sql_types::{Bytea, Nullable, SmallInt, Text, Timestamptz};
9use diesel::{QueryableByName, sql_query};
10use diesel_async::pooled_connection::bb8::Pool;
11use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl};
12use jiff::Timestamp;
13use jiff_diesel::{Timestamp as DieselTimestamp, ToDiesel};
14use std::time::SystemTime;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum DieselPgSecretBackendError {
19 #[error("connection pool error: {0}")]
20 Pool(String),
21 #[error("query error: {0}")]
22 Query(#[from] diesel::result::Error),
23}
24
25#[derive(QueryableByName)]
26struct KeyRow {
27 #[diesel(sql_type = diesel::sql_types::BigInt)]
28 id: i64,
29 #[diesel(sql_type = SmallInt)]
30 version: i16,
31 #[diesel(sql_type = Bytea)]
32 key_bytes: Vec<u8>,
33 #[diesel(sql_type = Nullable<Bytea>)]
34 nonce: Option<Vec<u8>>,
35 #[diesel(sql_type = SmallInt)]
36 encryption_key_version: i16,
37 #[diesel(sql_type = Timestamptz)]
38 activated_at: DieselTimestamp,
39}
40
41impl From<KeyRow> for KeyRecord {
42 fn from(r: KeyRow) -> Self {
43 KeyRecord {
44 id: r.id,
45 version: r.version as u8,
46 key_bytes: r.key_bytes,
47 nonce: r.nonce,
48 encryption_key_version: r.encryption_key_version as u8,
49 activated_at: r.activated_at.to_jiff().into(),
50 }
51 }
52}
53
54#[derive(Clone)]
55pub struct DieselPgSecretBackend {
56 pool: Pool<AsyncPgConnection>,
57}
58
59impl DieselPgSecretBackend {
60 pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
61 Self { pool }
62 }
63}
64
65#[async_trait]
66impl SecretBackend for DieselPgSecretBackend {
67 type Error = DieselPgSecretBackendError;
68
69 async fn load_all(&self, group_id: &str) -> Result<Vec<KeyRecord>, Self::Error> {
70 let mut conn = self
71 .pool
72 .get()
73 .await
74 .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
75 let rows: Vec<KeyRow> = sql_query(LOAD_ALL_QUERY)
76 .bind::<Text, _>(group_id)
77 .get_results(&mut conn)
78 .await?;
79 Ok(rows.into_iter().map(KeyRecord::from).collect())
80 }
81
82 async fn poll_new(
83 &self,
84 group_id: &str,
85 since_time: SystemTime,
86 since_id: i64,
87 ) -> Result<Vec<KeyRecord>, Self::Error> {
88 let mut conn = self
89 .pool
90 .get()
91 .await
92 .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
93 let since_jiff = Timestamp::try_from(since_time)
94 .map_err(|e| diesel::result::Error::SerializationError(Box::new(e)))?;
95 let rows: Vec<KeyRow> = sql_query(POLL_NEW_QUERY)
96 .bind::<Text, _>(group_id)
97 .bind::<Timestamptz, _>(since_jiff.to_diesel())
98 .bind::<diesel::sql_types::BigInt, _>(since_id)
99 .get_results(&mut conn)
100 .await?;
101 Ok(rows.into_iter().map(KeyRecord::from).collect())
102 }
103}
104
105#[derive(QueryableByName)]
106struct KeyInfoRow {
107 #[diesel(sql_type = SmallInt)]
108 version: i16,
109 #[diesel(sql_type = Timestamptz)]
110 activated_at: DieselTimestamp,
111}
112
113#[async_trait]
114impl SecretRotationBackend for DieselPgSecretBackend {
115 type Error = DieselPgSecretBackendError;
116
117 async fn latest_key_info(
118 &self,
119 group_id: &str,
120 ) -> Result<Option<(u8, SystemTime)>, Self::Error> {
121 let mut conn = self
122 .pool
123 .get()
124 .await
125 .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
126 let rows: Vec<KeyInfoRow> = sql_query(LATEST_KEY_INFO_QUERY)
127 .bind::<Text, _>(group_id)
128 .get_results(&mut conn)
129 .await?;
130 Ok(rows
131 .into_iter()
132 .next()
133 .map(|r| (r.version as u8, r.activated_at.to_jiff().into())))
134 }
135
136 async fn try_insert_key(
137 &self,
138 group_id: &str,
139 expected_version: Option<u8>,
140 new_version: u8,
141 encrypted: &Encrypted,
142 activated_at: SystemTime,
143 ) -> Result<bool, Self::Error> {
144 let mut conn = self
145 .pool
146 .get()
147 .await
148 .map_err(|e| DieselPgSecretBackendError::Pool(e.to_string()))?;
149 let group_id = group_id.to_owned();
150 let ciphertext = encrypted.ciphertext.clone();
151 let nonce = encrypted.nonce.clone();
152 let enc_version = encrypted.key_version as i16;
153 let activated_at_jiff = Timestamp::try_from(activated_at)
154 .map_err(|e| diesel::result::Error::SerializationError(Box::new(e)))?;
155 let activated_at_diesel = activated_at_jiff.to_diesel();
156
157 let inserted = conn
158 .transaction(|conn| {
159 Box::pin(async move {
160 sql_query(ADVISORY_LOCK_QUERY)
161 .bind::<Text, _>(&group_id)
162 .execute(conn)
163 .await?;
164 let rows: Vec<KeyInfoRow> = sql_query(LATEST_KEY_INFO_QUERY)
165 .bind::<Text, _>(&group_id)
166 .get_results(conn)
167 .await?;
168 let current_version = rows.into_iter().next().map(|r| r.version as u8);
169 if current_version != expected_version {
170 return Ok::<bool, diesel::result::Error>(false);
171 }
172 let rows_affected = sql_query(INSERT_KEY_QUERY)
173 .bind::<Text, _>(&group_id)
174 .bind::<SmallInt, _>(new_version as i16)
175 .bind::<Bytea, _>(ciphertext.as_slice())
176 .bind::<Nullable<Bytea>, _>(nonce.as_ref().map(|n| n as &[u8]))
177 .bind::<SmallInt, _>(enc_version)
178 .bind::<Timestamptz, _>(activated_at_diesel)
179 .execute(conn)
180 .await?;
181 Ok(rows_affected > 0)
182 })
183 })
184 .await?;
185 Ok(inserted)
186 }
187}
188
// Integration tests for `DieselPgSecretBackend` against a real PostgreSQL
// database provisioned through `test_containers_util::diesel_pg::PostgresTestDb`
// and migrated with the migrations embedded from `tests/diesel-migrations`.
// Each test uses a fresh random group id (UUID) so groups never collide.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SecretBackend;
    use crate::encryptor::Encrypted;
    use crate::rotator::SecretRotationBackend;
    use diesel::sql_types::{Bytea, SmallInt};
    use diesel_async::RunQueryDsl;
    use diesel_migrations::{EmbeddedMigrations, embed_migrations};
    use std::time::{Duration, SystemTime};
    use test_containers_util::diesel_pg::PostgresTestDb;
    use uuid::Uuid;

    // Migrations applied to every test database created by `make_backend`.
    const MIGRATIONS: EmbeddedMigrations = embed_migrations!("tests/diesel-migrations");

    // Creates a test database and a backend over its pool. Tests bind the
    // returned `PostgresTestDb` to `_db` rather than `_`, because `let _ = ...`
    // would drop it immediately. Presumably dropping it tears the database
    // down — confirm in test_containers_util.
    async fn make_backend() -> (PostgresTestDb, DieselPgSecretBackend) {
        let db = PostgresTestDb::create("secret-manager-diesel", MIGRATIONS, None, None).await;
        let backend = DieselPgSecretBackend::new(db.pool());
        (db, backend)
    }

    // Builds an `Encrypted` whose ciphertext is `bytes` verbatim, with no nonce
    // and key version 0, so the stored `key_bytes` equal the input bytes.
    fn no_op_encrypted(bytes: &[u8]) -> Encrypted {
        Encrypted { ciphertext: bytes.to_vec(), nonce: None, key_version: 0 }
    }

    // Inserts a key row directly with raw SQL, bypassing `try_insert_key` and
    // its version check. `activated_at` is not supplied, so it takes whatever
    // default the migrations define (presumably now() — confirm in migrations).
    async fn insert_key(
        backend: &DieselPgSecretBackend,
        group_id: &str,
        version: i16,
        bytes: &[u8],
    ) {
        let mut conn = backend.pool.get().await.unwrap();
        diesel::sql_query(
            "INSERT INTO secret_keys (key_group, version, key_bytes) VALUES ($1, $2, $3)",
        )
        .bind::<diesel::sql_types::Text, _>(group_id)
        .bind::<SmallInt, _>(version)
        .bind::<Bytea, _>(bytes)
        .execute(&mut conn)
        .await
        .unwrap();
    }

    // Absolute difference between two `SystemTime`s. Used to compare round-tripped
    // timestamps, which may lose sub-microsecond precision in the database.
    fn abs_diff(a: SystemTime, b: SystemTime) -> Duration {
        a.duration_since(b)
            .or_else(|_| b.duration_since(a))
            .unwrap_or(Duration::ZERO)
    }

    // A group with no rows yields an empty list, not an error.
    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_empty_for_unknown_group() {
        let (_db, backend) = make_backend().await;
        let records = backend.load_all(&Uuid::new_v4().simple().to_string()).await.unwrap();
        assert!(records.is_empty());
    }

    // Versions are deliberately inserted out of numeric order (2, 0, 1) with
    // increasing activation times. The result must follow activation order,
    // not version order.
    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_returns_rows_ordered_by_activated_at() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t0 = SystemTime::now() - Duration::from_secs(120);
        let t1 = SystemTime::now() - Duration::from_secs(60);
        let t2 = SystemTime::now();

        backend.try_insert_key(&gid, None, 2, &no_op_encrypted(&[2u8; 32]), t0).await.unwrap();
        backend.try_insert_key(&gid, Some(2), 0, &no_op_encrypted(&[0u8; 32]), t1).await.unwrap();
        backend.try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t2).await.unwrap();

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].version, 2);
        assert_eq!(records[1].version, 0);
        assert_eq!(records[2].version, 1);
        assert_eq!(records[0].key_bytes, vec![2u8; 32]);
    }

    // Polling with the cursor set to the only existing row returns nothing.
    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_empty_when_no_newer_key() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t = SystemTime::now();
        backend.try_insert_key(&gid, None, 5, &no_op_encrypted(&[5u8; 32]), t).await.unwrap();

        let inserted = backend.load_all(&gid).await.unwrap();
        let id = inserted[0].id;

        let result = backend.poll_new(&gid, t, id).await.unwrap();
        assert!(result.is_empty());
    }

    // With the cursor on the first row, the two later rows come back in
    // activation order, with activation times preserved to within 5 ms.
    #[tokio::test(flavor = "multi_thread")]
    async fn poll_new_returns_keys_newer_than_cursor() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t0 = SystemTime::now() - Duration::from_secs(180);
        let t1 = SystemTime::now() - Duration::from_secs(120);
        let t2 = SystemTime::now() - Duration::from_secs(60);

        backend.try_insert_key(&gid, None, 3, &no_op_encrypted(&[3u8; 32]), t0).await.unwrap();
        backend.try_insert_key(&gid, Some(3), 7, &no_op_encrypted(&[7u8; 32]), t1).await.unwrap();
        backend.try_insert_key(&gid, Some(7), 5, &no_op_encrypted(&[5u8; 32]), t2).await.unwrap();

        let all = backend.load_all(&gid).await.unwrap();
        // Relies on `load_all` returning rows in activation order (see above).
        let id0 = all[0].id;

        let records = backend.poll_new(&gid, t0, id0).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 7);
        assert_eq!(records[1].version, 5);
        assert!(abs_diff(t1, records[0].activated_at).as_millis() < 5);
        assert!(abs_diff(t2, records[1].activated_at).as_millis() < 5);
    }

    // Rows in one group must not leak into another group's results.
    #[tokio::test(flavor = "multi_thread")]
    async fn load_all_isolates_groups() {
        let (_db, backend) = make_backend().await;
        let gid_a = Uuid::new_v4().simple().to_string();
        let gid_b = Uuid::new_v4().simple().to_string();
        insert_key(&backend, &gid_a, 0, &[10u8; 32]).await;
        insert_key(&backend, &gid_b, 0, &[20u8; 32]).await;

        let a = backend.load_all(&gid_a).await.unwrap();
        let b = backend.load_all(&gid_b).await.unwrap();

        assert_eq!(a.len(), 1);
        assert_eq!(b.len(), 1);
        assert_eq!(a[0].key_bytes, vec![10u8; 32]);
        assert_eq!(b[0].key_bytes, vec![20u8; 32]);
    }

    // An empty group reports `None` rather than an error.
    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_none_for_empty_group() {
        let (_db, backend) = make_backend().await;
        let result = backend.latest_key_info(&Uuid::new_v4().simple().to_string()).await.unwrap();
        assert!(result.is_none());
    }

    // `expected_version = None` succeeds on an empty group. The stored
    // version and activation time are visible through `latest_key_info`.
    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_inserts_first_key() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let activated_at = SystemTime::now() + Duration::from_secs(120);

        let inserted = backend
            .try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), activated_at)
            .await
            .unwrap();
        assert!(inserted);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 0);
        assert!(abs_diff(info.1, activated_at).as_millis() < 5);
    }

    // A stale `expected_version` (None while a row already exists) returns
    // `false` and leaves the group unchanged.
    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_returns_false_when_version_already_changed() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t = SystemTime::now() + Duration::from_secs(60);

        insert_key(&backend, &gid, 0, &[0u8; 32]).await;

        let inserted = backend
            .try_insert_key(&gid, None, 1, &no_op_encrypted(&[1u8; 32]), t)
            .await
            .unwrap();
        assert!(!inserted, "must return false when version already changed");

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].version, 0);
    }

    // Chained rotations, each naming the previous version, all succeed.
    #[tokio::test(flavor = "multi_thread")]
    async fn try_insert_key_sequential_rotations_succeed() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t0 = SystemTime::now() + Duration::from_secs(60);
        let t1 = SystemTime::now() + Duration::from_secs(120);

        let ok0 = backend.try_insert_key(&gid, None, 0, &no_op_encrypted(&[0u8; 32]), t0).await.unwrap();
        assert!(ok0);

        let ok1 = backend.try_insert_key(&gid, Some(0), 1, &no_op_encrypted(&[1u8; 32]), t1).await.unwrap();
        assert!(ok1);

        let records = backend.load_all(&gid).await.unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].version, 0);
        assert_eq!(records[1].version, 1);
    }

    // "Latest" means most recently activated, not highest version number.
    // Version 10 is inserted first; version 2 then becomes the latest key.
    #[tokio::test(flavor = "multi_thread")]
    async fn latest_key_info_returns_most_recently_activated() {
        let (_db, backend) = make_backend().await;
        let gid = Uuid::new_v4().simple().to_string();
        let t2 = SystemTime::now() + Duration::from_secs(120);

        insert_key(&backend, &gid, 10, &[10u8; 32]).await;
        let ok = backend
            .try_insert_key(&gid, Some(10), 2, &no_op_encrypted(&[2u8; 32]), t2)
            .await
            .unwrap();
        assert!(ok);

        let info = backend.latest_key_info(&gid).await.unwrap().expect("expected Some");
        assert_eq!(info.0, 2, "must return most recently activated, not highest version number");
    }
}