switchgear_components/discovery/
db.rs

1use crate::discovery::error::DiscoveryBackendStoreError;
2use async_trait::async_trait;
3use chrono::Utc;
4use sea_orm::entity::prelude::*;
5use sea_orm::sea_query::OnConflict;
6use sea_orm::{
7    ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, FromJsonQueryResult,
8    QueryFilter, QueryOrder, QuerySelect, Set, TransactionTrait,
9};
10use secp256k1::PublicKey;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeSet;
13use switchgear_migration::{MigratorTrait, DISCOVERY_BACKEND_GET_ALL_ETAG_ID};
14use switchgear_service_api::discovery::{
15    DiscoveryBackend, DiscoveryBackendPatch, DiscoveryBackendSparse, DiscoveryBackendStore,
16    DiscoveryBackends,
17};
18use switchgear_service_api::service::ServiceErrorSource;
19
/// Set of partition names attached to a discovery backend.
///
/// Newtype over `BTreeSet<String>`. The `Serialize`/`Deserialize` and
/// `FromJsonQueryResult` derives let SeaORM store and load the whole set as a
/// single JSON value (it backs the `partitions` column of [`Model`]).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, FromJsonQueryResult)]
pub struct DiscoveryBackendPartitions(BTreeSet<String>);
22
/// SeaORM entity model for one row of the `discovery_backend` table.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
#[sea_orm(table_name = "discovery_backend")]
pub struct Model {
    /// Partition names of the backend, stored as a binary JSON column.
    #[sea_orm(column_type = "JsonBinary")]
    pub partitions: DiscoveryBackendPartitions,
    /// Primary key: the serialized secp256k1 public key bytes of the backend.
    /// Supplied by the store, never auto-generated.
    #[sea_orm(primary_key, auto_increment = false)]
    pub id: Vec<u8>,
    /// Optional display name.
    pub name: Option<String>,
    /// Backend weight; the API type uses `usize`, the column is `i32`.
    pub weight: i32,
    /// Whether the backend is enabled.
    pub enabled: bool,
    /// Implementation payload, stored as raw bytes and passed through unchanged.
    pub implementation: Vec<u8>,
    /// Timestamp written when the row is first inserted.
    pub created_at: DateTimeWithTimeZone,
    /// Timestamp rewritten by `put` (on conflict) and by `patch`.
    pub updated_at: DateTimeWithTimeZone,
}
37
/// Relations of the `discovery_backend` entity. Empty: no relations are declared.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
40
// No hooks are overridden; the trait's default active-model behavior applies.
impl ActiveModelBehavior for ActiveModel {}
42
/// Entity for the `discovery_backend_etag` table, which holds the change
/// counter that `get_all` reports as its etag.
pub mod etag {
    use super::*;

    /// One counter row. The store reads and increments the row whose `id` is
    /// `DISCOVERY_BACKEND_GET_ALL_ETAG_ID`.
    #[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
    #[sea_orm(table_name = "discovery_backend_etag")]
    pub struct Model {
        /// Primary key of the counter row (supplied explicitly, not auto-generated).
        #[sea_orm(primary_key, auto_increment = false)]
        pub id: i32,
        /// Counter value; the store's write paths add 1 to it and `get_all`
        /// returns it as the etag.
        pub value: i64,
    }

    /// Relations of the etag entity. Empty: no relations are declared.
    #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
    pub enum Relation {}

    // No hooks are overridden; the trait's default active-model behavior applies.
    impl ActiveModelBehavior for ActiveModel {}
}
59
/// [`DiscoveryBackendStore`] implementation backed by a SeaORM database connection.
#[derive(Clone, Debug)]
pub struct DbDiscoveryBackendStore {
    /// Connection used for every query and transaction issued by the store.
    db: DatabaseConnection,
}
64
65impl DbDiscoveryBackendStore {
66    pub async fn connect(
67        uri: &str,
68        max_connections: u32,
69    ) -> Result<Self, DiscoveryBackendStoreError> {
70        let mut opt = sea_orm::ConnectOptions::new(uri);
71        opt.max_connections(max_connections);
72        let db = Database::connect(opt).await.map_err(|e| {
73            DiscoveryBackendStoreError::from_db(
74                ServiceErrorSource::Internal,
75                "connecting to discovery backend database",
76                e,
77            )
78        })?;
79
80        Ok(Self::from_db(db))
81    }
82
83    pub async fn migrate_up(&self) -> Result<(), DiscoveryBackendStoreError> {
84        switchgear_migration::DiscoveryBackendMigrator::up(&self.db, None)
85            .await
86            .map_err(|e| {
87                DiscoveryBackendStoreError::from_db(
88                    ServiceErrorSource::Internal,
89                    "migrating database up",
90                    e,
91                )
92            })?;
93        Ok(())
94    }
95
96    pub async fn migrate_down(&self) -> Result<(), DiscoveryBackendStoreError> {
97        switchgear_migration::DiscoveryBackendMigrator::down(&self.db, None)
98            .await
99            .map_err(|e| {
100                DiscoveryBackendStoreError::from_db(
101                    ServiceErrorSource::Internal,
102                    "migrating database down",
103                    e,
104                )
105            })?;
106        Ok(())
107    }
108
109    pub fn from_db(db: DatabaseConnection) -> Self {
110        Self { db }
111    }
112
113    fn model_to_domain(model: Model) -> Result<DiscoveryBackend, DiscoveryBackendStoreError> {
114        Ok(DiscoveryBackend {
115            public_key: PublicKey::from_slice(&model.id).map_err(|e| {
116                DiscoveryBackendStoreError::internal_error(
117                    ServiceErrorSource::Internal,
118                    format!("deserializing public key {:?} from database", model.id),
119                    format!("deserializing failure: {e}"),
120                )
121            })?,
122            backend: DiscoveryBackendSparse {
123                name: model.name,
124                partitions: model.partitions.0,
125                weight: model.weight as usize,
126                enabled: model.enabled,
127                implementation: model.implementation,
128            },
129        })
130    }
131}
132
133#[async_trait]
134impl DiscoveryBackendStore for DbDiscoveryBackendStore {
135    type Error = DiscoveryBackendStoreError;
136
137    async fn get(&self, public_key: &PublicKey) -> Result<Option<DiscoveryBackend>, Self::Error> {
138        let result = Entity::find_by_id(public_key.serialize())
139            .one(&self.db)
140            .await
141            .map_err(|e| {
142                DiscoveryBackendStoreError::from_db(
143                    ServiceErrorSource::Internal,
144                    format!("fetching backend for public key {public_key}",),
145                    e,
146                )
147            })?;
148
149        match result {
150            Some(model) => Ok(Some(Self::model_to_domain(model)?)),
151            None => Ok(None),
152        }
153    }
154
155    async fn get_all(&self, request_etag: Option<u64>) -> Result<DiscoveryBackends, Self::Error> {
156        let response_etag = etag::Entity::find_by_id(DISCOVERY_BACKEND_GET_ALL_ETAG_ID)
157            .one(&self.db)
158            .await
159            .map_err(|e| {
160                DiscoveryBackendStoreError::from_db(
161                    ServiceErrorSource::Internal,
162                    "fetching etag value",
163                    e,
164                )
165            })?
166            .map(|e| e.value as u64)
167            .unwrap_or(0);
168
169        if request_etag == Some(response_etag) {
170            Ok(DiscoveryBackends {
171                etag: response_etag,
172                backends: None,
173            })
174        } else {
175            let models = Entity::find()
176                .order_by_asc(Column::CreatedAt)
177                .order_by_asc(Column::Id)
178                .all(&self.db)
179                .await
180                .map_err(|e| {
181                    DiscoveryBackendStoreError::from_db(
182                        ServiceErrorSource::Internal,
183                        "fetching all backends",
184                        e,
185                    )
186                })?;
187
188            let backends = models
189                .into_iter()
190                .map(Self::model_to_domain)
191                .collect::<Result<Vec<_>, _>>()?;
192            Ok(DiscoveryBackends {
193                etag: response_etag,
194                backends: Some(backends),
195            })
196        }
197    }
198
199    async fn post(&self, backend: DiscoveryBackend) -> Result<Option<PublicKey>, Self::Error> {
200        let now = Utc::now();
201        let active_model = ActiveModel {
202            partitions: Set(DiscoveryBackendPartitions(backend.backend.partitions)),
203            id: Set(backend.public_key.serialize().to_vec()),
204            name: Set(backend.backend.name),
205            weight: Set(backend.backend.weight as i32),
206            enabled: Set(backend.backend.enabled),
207            implementation: Set(backend.backend.implementation),
208            created_at: Set(now.into()),
209            updated_at: Set(now.into()),
210        };
211
212        let (insert_result, etag_result) = self
213            .db
214            .transaction::<_, (Result<_, _>, Option<Result<_, _>>), sea_orm::DbErr>(|txn| {
215                Box::pin(async move {
216                    let insert = active_model.insert(txn).await;
217                    let etag = if insert.is_ok() {
218                        Some(
219                            etag::Entity::update_many()
220                                .col_expr(
221                                    etag::Column::Value,
222                                    Expr::col(etag::Column::Value).add(1),
223                                )
224                                .filter(etag::Column::Id.eq(DISCOVERY_BACKEND_GET_ALL_ETAG_ID))
225                                .exec(txn)
226                                .await,
227                        )
228                    } else {
229                        None
230                    };
231                    Ok((insert, etag))
232                })
233            })
234            .await
235            .map_err(|e| {
236                DiscoveryBackendStoreError::from_tx(
237                    ServiceErrorSource::Internal,
238                    "post transaction",
239                    e,
240                )
241            })?;
242
243        etag_result.transpose().map_err(|e| {
244            DiscoveryBackendStoreError::from_db(
245                ServiceErrorSource::Internal,
246                "incrementing etag value",
247                e,
248            )
249        })?;
250
251        match insert_result {
252            Ok(_) => Ok(Some(backend.public_key)),
253            // PostgreSQL unique constraint violation
254            Err(sea_orm::DbErr::Query(sea_orm::RuntimeErr::SqlxError(sqlx::Error::Database(
255                db_err,
256            )))) if db_err.is_unique_violation() => Ok(None),
257            // SQLite unique constraint violation
258            Err(sea_orm::DbErr::Exec(sea_orm::RuntimeErr::SqlxError(sqlx::Error::Database(
259                db_err,
260            )))) if db_err.is_unique_violation() => Ok(None),
261            Err(e) => Err(DiscoveryBackendStoreError::from_db(
262                ServiceErrorSource::Internal,
263                format!("inserting backend for public key {}", backend.public_key),
264                e,
265            )),
266        }
267    }
268
269    async fn put(&self, backend: DiscoveryBackend) -> Result<bool, Self::Error> {
270        let now = Utc::now();
271        let future_timestamp = now + chrono::Duration::seconds(1);
272
273        let id = backend.public_key.serialize();
274        let active_model = ActiveModel {
275            partitions: Set(DiscoveryBackendPartitions(backend.backend.partitions)),
276            id: Set(id.to_vec()),
277            name: Set(backend.backend.name),
278            weight: Set(backend.backend.weight as i32),
279            enabled: Set(backend.backend.enabled),
280            implementation: Set(backend.backend.implementation),
281            created_at: Set(now.into()),
282            updated_at: Set(now.into()),
283        };
284
285        let (upsert_result, fetch_result, etag_result) = self
286            .db
287            .transaction::<_, (Result<_, _>, Result<_, _>, Option<Result<_, _>>), sea_orm::DbErr>(
288                |txn| {
289                    Box::pin(async move {
290                        let upsert = Entity::insert(active_model)
291                            .on_conflict(
292                                OnConflict::columns([Column::Id])
293                                    .update_columns([
294                                        Column::Name,
295                                        Column::Weight,
296                                        Column::Enabled,
297                                        Column::Implementation,
298                                    ])
299                                    .value(Column::UpdatedAt, Expr::val(future_timestamp))
300                                    .to_owned(),
301                            )
302                            .exec(txn)
303                            .await;
304
305                        let timestamps = if upsert.is_ok() {
306                            Entity::find()
307                                .filter(Column::Id.eq(id.as_slice()))
308                                .select_only()
309                                .column(Column::CreatedAt)
310                                .column(Column::UpdatedAt)
311                                .into_tuple::<(DateTimeWithTimeZone, DateTimeWithTimeZone)>()
312                                .one(txn)
313                                .await
314                        } else {
315                            Ok(None)
316                        };
317
318                        let etag = if timestamps.is_ok() {
319                            Some(
320                                etag::Entity::update_many()
321                                    .col_expr(
322                                        etag::Column::Value,
323                                        Expr::col(etag::Column::Value).add(1),
324                                    )
325                                    .filter(etag::Column::Id.eq(DISCOVERY_BACKEND_GET_ALL_ETAG_ID))
326                                    .exec(txn)
327                                    .await,
328                            )
329                        } else {
330                            None
331                        };
332
333                        Ok((upsert, timestamps, etag))
334                    })
335                },
336            )
337            .await
338            .map_err(|e| {
339                DiscoveryBackendStoreError::from_tx(
340                    ServiceErrorSource::Internal,
341                    "put transaction",
342                    e,
343                )
344            })?;
345
346        upsert_result.map_err(|e| {
347            DiscoveryBackendStoreError::from_db(
348                ServiceErrorSource::Internal,
349                format!("upserting backend for public key {}", backend.public_key),
350                e,
351            )
352        })?;
353
354        etag_result.transpose().map_err(|e| {
355            DiscoveryBackendStoreError::from_db(
356                ServiceErrorSource::Internal,
357                "incrementing etag value",
358                e,
359            )
360        })?;
361
362        let result = fetch_result
363            .map_err(|e| {
364                DiscoveryBackendStoreError::from_db(
365                    ServiceErrorSource::Internal,
366                    format!(
367                        "fetching backend after upsert for public key {}",
368                        backend.public_key
369                    ),
370                    e,
371                )
372            })?
373            .ok_or_else(|| {
374                DiscoveryBackendStoreError::internal_error(
375                    ServiceErrorSource::Internal,
376                    "upsert succeeded but record not found",
377                    "Record should exist after successful upsert".to_string(),
378                )
379            })?;
380
381        // Compare timestamps to determine if it was insert (true) or update (false)
382        Ok(result.0 == result.1)
383    }
384
385    async fn patch(&self, backend: DiscoveryBackendPatch) -> Result<bool, Self::Error> {
386        let mut update =
387            Entity::update_many().filter(Column::Id.eq(backend.public_key.serialize().as_slice()));
388
389        if let Some(name) = backend.backend.name {
390            update = update.col_expr(Column::Name, Expr::value(name));
391        }
392        if let Some(partitions) = backend.backend.partitions {
393            update = update.col_expr(
394                Column::Partitions,
395                Expr::value(DiscoveryBackendPartitions(partitions)),
396            );
397        }
398        if let Some(weight) = backend.backend.weight {
399            update = update.col_expr(Column::Weight, Expr::value(weight as i32));
400        }
401        if let Some(enabled) = backend.backend.enabled {
402            update = update.col_expr(Column::Enabled, Expr::value(enabled));
403        }
404
405        update = update.col_expr(Column::UpdatedAt, Expr::value(Utc::now()));
406
407        let (patch_result, etag_result) = self
408            .db
409            .transaction::<_, _, _>(|txn| {
410                Box::pin(async move {
411                    let patch = update.exec(txn).await;
412
413                    let etag = if patch
414                        .as_ref()
415                        .ok()
416                        .map(|r| r.rows_affected > 0)
417                        .unwrap_or(false)
418                    {
419                        Some(
420                            etag::Entity::update_many()
421                                .col_expr(
422                                    etag::Column::Value,
423                                    Expr::col(etag::Column::Value).add(1),
424                                )
425                                .filter(etag::Column::Id.eq(DISCOVERY_BACKEND_GET_ALL_ETAG_ID))
426                                .exec(txn)
427                                .await,
428                        )
429                    } else {
430                        None
431                    };
432
433                    Ok((patch, etag))
434                })
435            })
436            .await
437            .map_err(|e| {
438                DiscoveryBackendStoreError::from_tx(
439                    ServiceErrorSource::Internal,
440                    "patch transaction",
441                    e,
442                )
443            })?;
444
445        etag_result.transpose().map_err(|e| {
446            DiscoveryBackendStoreError::from_db(
447                ServiceErrorSource::Internal,
448                "incrementing etag value",
449                e,
450            )
451        })?;
452
453        let result = patch_result.map_err(|e| {
454            DiscoveryBackendStoreError::from_db(
455                ServiceErrorSource::Internal,
456                format!("patching backend for public key {}", backend.public_key),
457                e,
458            )
459        })?;
460
461        Ok(result.rows_affected > 0)
462    }
463
464    async fn delete(&self, public_key: &PublicKey) -> Result<bool, Self::Error> {
465        let id = public_key.serialize();
466
467        let (delete_result, etag_result) = self
468            .db
469            .transaction::<_, _, _>(|txn| {
470                Box::pin(async move {
471                    let delete = Entity::delete_by_id(id).exec(txn).await;
472
473                    let etag = if delete
474                        .as_ref()
475                        .ok()
476                        .map(|r| r.rows_affected > 0)
477                        .unwrap_or(false)
478                    {
479                        Some(
480                            etag::Entity::update_many()
481                                .col_expr(
482                                    etag::Column::Value,
483                                    Expr::col(etag::Column::Value).add(1),
484                                )
485                                .filter(etag::Column::Id.eq(DISCOVERY_BACKEND_GET_ALL_ETAG_ID))
486                                .exec(txn)
487                                .await,
488                        )
489                    } else {
490                        None
491                    };
492
493                    Ok((delete, etag))
494                })
495            })
496            .await
497            .map_err(|e| {
498                DiscoveryBackendStoreError::from_tx(
499                    ServiceErrorSource::Internal,
500                    "delete transaction",
501                    e,
502                )
503            })?;
504
505        etag_result.transpose().map_err(|e| {
506            DiscoveryBackendStoreError::from_db(
507                ServiceErrorSource::Internal,
508                "incrementing etag value",
509                e,
510            )
511        })?;
512
513        let result = delete_result.map_err(|e| {
514            DiscoveryBackendStoreError::from_db(
515                ServiceErrorSource::Internal,
516                format!("deleting backend for public key {public_key}"),
517                e,
518            )
519        })?;
520
521        Ok(result.rows_affected > 0)
522    }
523}