// selene_core/database/transaction_args.rs

1use std::{
2    any::{Any, TypeId},
3    borrow::Borrow,
4    collections::HashMap,
5    rc::Rc,
6};
7
8use lunar_lib::trace;
9use sled::{
10    CompareAndSwapError, Db, IVec, Transactional, Tree,
11    transaction::{ConflictableTransactionError, TransactionError, TransactionalTree},
12};
13
14use crate::{
15    database::{
16        DatabaseEntry, DatabaseError, DbKey, EntryId, Patchable, deserialize_from_ivec, library_db,
17        serialize_to_ivec, sled_get_raw, validator::DatabaseReferenceError,
18    },
19    library::{
20        album::{Album, AlbumId, TrackReference},
21        artist::{ArtistGroup, ArtistId},
22        track::TrackId,
23    },
24};
25
26#[derive(Debug)]
27pub struct CompareAndSwapValue<T: DatabaseEntry> {
28    pub old: Option<IVec>,
29    pub new: Option<T>,
30}
31
32impl<T: DatabaseEntry> CompareAndSwapValue<T> {
33    #[must_use]
34    pub fn new(old: Option<IVec>, new: Option<T>) -> Self {
35        Self { old, new }
36    }
37}
38
39#[derive(Debug)]
40pub struct TreeCompareAndSwap<T: DatabaseEntry> {
41    tree: Tree,
42    swaps: HashMap<DbKey, CompareAndSwapValue<T>>,
43}
44
45impl<T: DatabaseEntry> TreeCompareAndSwap<T> {
46    fn new(db: &Db) -> Self {
47        Self {
48            tree: T::tree(db),
49            swaps: HashMap::new(),
50        }
51    }
52
53    #[must_use]
54    pub fn tree(&self) -> &Tree {
55        &self.tree
56    }
57}
58
59pub trait GenericCompareAndSwap: Any + std::fmt::Debug {
60    fn tree(&self) -> &Tree;
61    fn as_any(&self) -> &dyn Any;
62    fn as_any_mut(&mut self) -> &mut dyn Any;
63    fn apply(
64        &self,
65        tx_tree: &TransactionalTree,
66    ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>>;
67}
68
69impl<T: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<T> {
70    fn tree(&self) -> &Tree {
71        &self.tree
72    }
73
74    fn as_any(&self) -> &dyn Any {
75        self
76    }
77
78    fn as_any_mut(&mut self) -> &mut dyn Any {
79        self
80    }
81
82    fn apply(
83        &self,
84        tx_tree: &TransactionalTree,
85    ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>> {
86        for (k, v) in &self.swaps {
87            let ivec = tx_tree.get(k)?;
88
89            if ivec == v.old {
90                if let Some(new) = &v.new {
91                    tx_tree.insert(k, serialize_to_ivec(&new))?;
92                } else {
93                    tx_tree.remove(k)?;
94                }
95            } else {
96                return Err(ConflictableTransactionError::Abort(CompareAndSwapError {
97                    current: ivec,
98                    proposed: v.new.as_ref().map(serialize_to_ivec),
99                }));
100            }
101        }
102        Ok(())
103    }
104}
105
106#[derive(Debug)]
107pub struct CompareAndSwapTransaction {
108    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
109    database: Rc<Db>,
110}
111
112impl CompareAndSwapTransaction {
113    #[must_use]
114    pub(crate) fn new() -> Self {
115        Self {
116            swaps: HashMap::new(),
117            database: Rc::new(library_db()),
118        }
119    }
120
121    #[must_use]
122    pub(crate) fn with_db(database: Rc<Db>) -> Self {
123        Self {
124            swaps: HashMap::new(),
125            database,
126        }
127    }
128
129    /// Patches `item` with the existing database entry, if any, else inserts
130    ///
131    /// Patching rules depend on how [`T`] implements [`Patchable<T>`]
132    ///
133    /// # Warning
134    ///
135    /// This function will create dangling references if not used correctly
136    pub(crate) fn tx_patch<T: Patchable<T> + DatabaseEntry + 'static>(
137        &mut self,
138        item: T,
139    ) -> Result<(), DatabaseError> {
140        if let Some(mut old_item) = self.tx_get(item.id())? {
141            let item_id = item.id();
142            old_item.patch(item);
143            self.tx_upsert(item_id, Some(old_item))?;
144        } else {
145            self.tx_upsert(item.id(), Some(item))?;
146        }
147
148        Ok(())
149    }
150
151    /// Gets the latest version of the item in the transaction context, looking it up in the database if its unmodified
152    ///
153    /// # Errors
154    ///
155    /// Errors if [`DatabaseEntry::db_get()`] errors
156    pub fn tx_get<Id: EntryId>(&self, id: Id) -> Result<Option<Id::Entry>, DatabaseError> {
157        if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
158            let cas_tree = boxed
159                .as_any()
160                .downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
161                .unwrap();
162            if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
163                return Ok(get.new.clone());
164            }
165        }
166
167        let tree = Id::Entry::tree(&self.database);
168        let raw = sled_get_raw(&tree, id.as_bytes())?;
169        Ok(raw.map(deserialize_from_ivec))
170    }
171
172    pub(crate) fn tx_get_batch<I, A, Entry: DatabaseEntry>(
173        &self,
174        items: I,
175    ) -> Result<Vec<Entry>, DatabaseError>
176    where
177        I: IntoIterator<Item = A>,
178        A: Borrow<Entry::Id>,
179    {
180        items
181            .into_iter()
182            .map(|id| {
183                self.tx_get(*id.borrow())?
184                    .ok_or(DatabaseError::MissingEntry)
185            })
186            .collect()
187    }
188
189    pub fn tx_remove<Id: EntryId>(&mut self, key: Id) -> Result<(), DatabaseError> {
190        let db = self.database.clone();
191        let request = self.get_or_new_request::<Id::Entry>();
192
193        let key = *key.as_bytes();
194        if let Some(get_mut) = request.swaps.get_mut(&key) {
195            get_mut.new = None;
196        } else {
197            let old = sled_get_raw(&Id::Entry::tree(&db), &key)?;
198            request
199                .swaps
200                .insert(key, CompareAndSwapValue { old, new: None });
201        }
202        Ok(())
203    }
204
205    pub fn tx_upsert<T: DatabaseEntry>(
206        &mut self,
207        key: T::Id,
208        mut new: Option<T>,
209    ) -> Result<(), DatabaseError> {
210        let db = self.database.clone();
211        if let Some(new) = &mut new {
212            new.pre_upsert(self)?;
213        }
214
215        let request = self.get_or_new_request::<T>();
216
217        let key = *key.as_bytes();
218
219        if let Some(get_mut) = request.swaps.get_mut(&key) {
220            get_mut.new = new;
221        } else {
222            let old = sled_get_raw(&T::tree(&db), &key)?;
223            request.swaps.insert(key, CompareAndSwapValue { old, new });
224        }
225        Ok(())
226    }
227
228    pub fn tx_insert<T: DatabaseEntry>(&mut self, item: T) -> Result<(), DatabaseError> {
229        if self.tx_get(item.id())?.is_some() {
230            return Err(DatabaseError::AlreadyInDatabase);
231        }
232
233        self.tx_upsert(item.id(), Some(item))?;
234        Ok(())
235    }
236
237    pub fn get_or_new_request<T: DatabaseEntry>(&mut self) -> &mut TreeCompareAndSwap<T> {
238        self.swaps
239            .entry(TypeId::of::<T>())
240            .or_insert_with(|| Box::new(TreeCompareAndSwap::<T>::new(&self.database)))
241            .as_any_mut()
242            .downcast_mut::<TreeCompareAndSwap<T>>()
243            .unwrap()
244    }
245
246    #[must_use]
247    pub fn trees(&self) -> Vec<&Tree> {
248        self.swaps.values().map(|a| a.tree()).collect()
249    }
250}
251
252/// Applies a [`CompareAndSwapTransaction`] atomically to the database
253///
254/// # Errors
255///
256/// This function will error if `[sled]` fails get, insert, or remove a key OR abort with a [`CompareAndSwapError`] if the current value does not match the expected value
257pub fn apply_cas_tx(
258    tx: CompareAndSwapTransaction,
259    flush: bool,
260) -> Result<(), TransactionError<CompareAndSwapError>> {
261    tx.trees().transaction(|tx_trees| {
262        for (tree, cas) in tx_trees.iter().zip(tx.swaps.values()) {
263            cas.apply(tree)?;
264            if flush {
265                tree.flush();
266            }
267        }
268        Ok(())
269    })
270}
271
272/// Calls a closure that returns a [`CompareAndSwapTransaction`] and applies it atomically to the database
273///
274/// This function will be called again if a [`CompareAndSwapError`] occurs
275///
276/// # Errors
277///
278/// This function will error if `f()` returns an error, or if [`apply_cas_tx()`] fails with an error other than [`CompareAndSwapError`]
279pub fn db_transaction<F, E>(mut f: F, db: Option<Db>, flush: bool) -> Result<(), E>
280where
281    F: FnMut(&mut CompareAndSwapTransaction) -> Result<(), E>,
282    E: From<TransactionError<CompareAndSwapError>>,
283{
284    let db = db.map(Rc::new);
285
286    loop {
287        let mut cas_tx = if let Some(db) = db.clone() {
288            CompareAndSwapTransaction::with_db(db)
289        } else {
290            CompareAndSwapTransaction::new()
291        };
292
293        f(&mut cas_tx)?;
294
295        match apply_cas_tx(cas_tx, flush) {
296            Ok(()) => return Ok(()),
297            Err(TransactionError::Abort(CompareAndSwapError {
298                current: _,
299                proposed: _,
300            })) => {
301                trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
302            }
303            Err(err) => return Err(err.into()),
304        }
305    }
306}
307
308// Safe two-way helpers
309impl CompareAndSwapTransaction {
310    /// Applies a two-way relinking operation, relinking to the track to a new (or no) album, and disconnecting the tracks old album, if any
311    pub fn relink_track_to_album(
312        &mut self,
313        track_id: TrackId,
314        album: Option<AlbumId>,
315    ) -> Result<bool, DatabaseError> {
316        let Some(mut track) = self.tx_get(track_id)? else {
317            return Ok(false);
318        };
319
320        if track.metadata.album == album {
321            return Ok(false);
322        }
323
324        let old_album_id = track.metadata.album;
325
326        track.metadata.album = album;
327        self.tx_upsert(track.id(), Some(track.clone()))?;
328
329        if let Some(old_album_id) = old_album_id {
330            let mut old_album = self.tx_get(old_album_id)?.ok_or({
331                DatabaseReferenceError::TrackDanglingAlbumRef {
332                    track: track_id,
333                    album: old_album_id,
334                }
335            })?;
336
337            old_album.tracks.retain(|t| t.id != track_id);
338
339            self.tx_upsert(old_album.id(), Some(old_album))?;
340        }
341
342        if let Some(new_album_id) = album {
343            let mut new_album = self
344                .tx_get(new_album_id)?
345                .ok_or(DatabaseError::MissingEntry)?;
346
347            new_album.tracks.push(TrackReference {
348                id: track_id,
349                track_num: None,
350                disc_num: None,
351            });
352
353            self.tx_upsert(new_album.id(), Some(new_album))?;
354        }
355
356        Ok(true)
357    }
358
359    /// Applies a two-way relinking operation, relinking artists to the album, and disconnecting references from removed artists
360    pub fn album_set_and_relink_artists(
361        &mut self,
362        album_id: AlbumId,
363        artists: &[ArtistId],
364    ) -> Result<bool, DatabaseError> {
365        let mut album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
366
367        let old_artists: Vec<ArtistId> = album.artist_group.artist_ids().to_vec();
368
369        album.artist_group = ArtistGroup::from_artist_ids(artists.iter().cloned());
370
371        let removed_artists: Vec<ArtistId> = old_artists
372            .into_iter()
373            .filter(|old_artist| !artists.contains(old_artist))
374            .collect();
375
376        self.artists_add_album(album_id, artists)?;
377        self.artists_remove_album(album_id, &removed_artists)?;
378
379        self.tx_upsert(album_id, Some(album))?;
380        Ok(true)
381    }
382
383    /// Applies a two-way relinking operation, relinking tracks to the album, and disconnecting references from removed tracks
384    pub fn album_set_and_relink_tracks(
385        &mut self,
386        album_id: AlbumId,
387        tracks: &[TrackId],
388    ) -> Result<bool, DatabaseError> {
389        let album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
390
391        let old_tracks: Vec<TrackId> = album.tracks.iter().map(|t| t.id).collect();
392
393        let removed_tracks: Vec<TrackId> = old_tracks
394            .iter()
395            .filter(|old_track| !tracks.contains(old_track))
396            .cloned()
397            .collect();
398
399        self.album_set_tracks(album, tracks)?;
400        self.tracks_set_album(Some(album_id), tracks)?;
401        self.tracks_set_album(None, &removed_tracks)?;
402        Ok(true)
403    }
404}
405
406// Unsafe one-way helpers
407impl CompareAndSwapTransaction {
408    pub(crate) fn album_set_tracks(
409        &mut self,
410        mut album: Album,
411        tracks: &[TrackId],
412    ) -> Result<(), DatabaseError> {
413        album.tracks = tracks
414            .iter()
415            .map(|t| {
416                album
417                    .tracks
418                    .iter()
419                    .find(|old| old.id == *t)
420                    .cloned()
421                    .unwrap_or(TrackReference {
422                        id: *t,
423                        track_num: None,
424                        disc_num: None,
425                    })
426            })
427            .collect();
428        self.tx_upsert(album.id(), Some(album))?;
429        Ok(())
430    }
431
432    pub(crate) fn tracks_set_album<'a>(
433        &mut self,
434        album_id: Option<AlbumId>,
435        tracks: impl IntoIterator<Item = &'a TrackId>,
436    ) -> Result<(), DatabaseError> {
437        for track_id in tracks {
438            let Some(mut track) = self.tx_get(*track_id)? else {
439                return Err(DatabaseError::MissingEntry);
440            };
441
442            track.metadata.album = album_id;
443
444            for artist_id in track.metadata.artists.artist_ids() {
445                if let Some(album_id) = album_id {
446                    let Some(mut artist) = self.tx_get(*artist_id)? else {
447                        return Err(DatabaseError::MissingEntry);
448                    };
449
450                    if artist.albums.contains(&album_id) {
451                        artist.tracks.retain(|t| t != track_id);
452                        self.tx_upsert(*artist_id, Some(artist))?;
453                    } else {
454                        self.artist_add_tracks(*artist_id, &[*track_id])?;
455                    }
456                } else {
457                    self.artist_add_tracks(*artist_id, &[*track_id])?;
458                }
459            }
460
461            self.tx_upsert(*track_id, Some(track))?;
462        }
463        Ok(())
464    }
465
466    pub(crate) fn artists_remove_album(
467        &mut self,
468        album_id: AlbumId,
469        artists: &[ArtistId],
470    ) -> Result<(), DatabaseError> {
471        for artist_id in artists {
472            let Some(mut artist) = self.tx_get(*artist_id)? else {
473                return Err(DatabaseError::MissingEntry);
474            };
475
476            artist.albums.retain(|a| *a != album_id);
477
478            self.tx_upsert(*artist_id, Some(artist))?;
479        }
480
481        Ok(())
482    }
483
484    pub(crate) fn artists_add_album(
485        &mut self,
486        album_id: AlbumId,
487        artists: &[ArtistId],
488    ) -> Result<(), DatabaseError> {
489        for artist_id in artists {
490            let Some(mut artist) = self.tx_get(*artist_id)? else {
491                return Err(DatabaseError::MissingEntry);
492            };
493
494            if !artist.albums.contains(&album_id) {
495                artist.albums.push(album_id);
496            }
497            self.tx_upsert(*artist_id, Some(artist))?;
498        }
499        Ok(())
500    }
501
502    pub(crate) fn artist_add_tracks(
503        &mut self,
504        artist_id: ArtistId,
505        tracks: &[TrackId],
506    ) -> Result<(), DatabaseError> {
507        let Some(mut artist) = self.tx_get(artist_id)? else {
508            return Err(DatabaseError::MissingEntry);
509        };
510
511        for track_id in tracks {
512            let Some(track) = self.tx_get(*track_id)? else {
513                return Err(DatabaseError::MissingEntry);
514            };
515
516            if let Some(album_id) = track.metadata.album
517                && artist.albums.contains(&album_id)
518            {
519                continue;
520            }
521
522            if !artist.tracks.contains(track_id) {
523                artist.tracks.push(*track_id)
524            }
525        }
526
527        self.tx_upsert(artist_id, Some(artist))?;
528
529        Ok(())
530    }
531}