1use std::{
2 any::{Any, TypeId},
3 borrow::Borrow,
4 collections::HashMap,
5 rc::Rc,
6};
7
8use lunar_lib::trace;
9use sled::{
10 CompareAndSwapError, Db, IVec, Transactional, Tree,
11 transaction::{ConflictableTransactionError, TransactionError, TransactionalTree},
12};
13
14use crate::{
15 database::{
16 DatabaseEntry, DatabaseError, DbKey, EntryId, Patchable, deserialize_from_ivec, library_db,
17 serialize_to_ivec, sled_get_raw, validator::DatabaseReferenceError,
18 },
19 library::{
20 album::{Album, AlbumId, TrackReference},
21 artist::{ArtistGroup, ArtistId},
22 track::TrackId,
23 },
24};
25
26#[derive(Debug)]
27pub struct CompareAndSwapValue<T: DatabaseEntry> {
28 pub old: Option<IVec>,
29 pub new: Option<T>,
30}
31
32impl<T: DatabaseEntry> CompareAndSwapValue<T> {
33 #[must_use]
34 pub fn new(old: Option<IVec>, new: Option<T>) -> Self {
35 Self { old, new }
36 }
37}
38
39#[derive(Debug)]
40pub struct TreeCompareAndSwap<T: DatabaseEntry> {
41 tree: Tree,
42 swaps: HashMap<DbKey, CompareAndSwapValue<T>>,
43}
44
45impl<T: DatabaseEntry> TreeCompareAndSwap<T> {
46 fn new(db: &Db) -> Self {
47 Self {
48 tree: T::tree(db),
49 swaps: HashMap::new(),
50 }
51 }
52
53 #[must_use]
54 pub fn tree(&self) -> &Tree {
55 &self.tree
56 }
57}
58
59pub trait GenericCompareAndSwap: Any + std::fmt::Debug {
60 fn tree(&self) -> &Tree;
61 fn as_any(&self) -> &dyn Any;
62 fn as_any_mut(&mut self) -> &mut dyn Any;
63 fn apply(
64 &self,
65 tx_tree: &TransactionalTree,
66 ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>>;
67}
68
69impl<T: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<T> {
70 fn tree(&self) -> &Tree {
71 &self.tree
72 }
73
74 fn as_any(&self) -> &dyn Any {
75 self
76 }
77
78 fn as_any_mut(&mut self) -> &mut dyn Any {
79 self
80 }
81
82 fn apply(
83 &self,
84 tx_tree: &TransactionalTree,
85 ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>> {
86 for (k, v) in &self.swaps {
87 let ivec = tx_tree.get(k)?;
88
89 if ivec == v.old {
90 if let Some(new) = &v.new {
91 tx_tree.insert(k, serialize_to_ivec(&new))?;
92 } else {
93 tx_tree.remove(k)?;
94 }
95 } else {
96 return Err(ConflictableTransactionError::Abort(CompareAndSwapError {
97 current: ivec,
98 proposed: v.new.as_ref().map(serialize_to_ivec),
99 }));
100 }
101 }
102 Ok(())
103 }
104}
105
106#[derive(Debug)]
107pub struct CompareAndSwapTransaction {
108 swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
109 database: Rc<Db>,
110}
111
112impl CompareAndSwapTransaction {
113 #[must_use]
114 pub(crate) fn new() -> Self {
115 Self {
116 swaps: HashMap::new(),
117 database: Rc::new(library_db()),
118 }
119 }
120
121 #[must_use]
122 pub(crate) fn with_db(database: Rc<Db>) -> Self {
123 Self {
124 swaps: HashMap::new(),
125 database,
126 }
127 }
128
129 pub(crate) fn tx_patch<T: Patchable<T> + DatabaseEntry + 'static>(
137 &mut self,
138 item: T,
139 ) -> Result<(), DatabaseError> {
140 if let Some(mut old_item) = self.tx_get(item.id())? {
141 let item_id = item.id();
142 old_item.patch(item);
143 self.tx_upsert(item_id, Some(old_item))?;
144 } else {
145 self.tx_upsert(item.id(), Some(item))?;
146 }
147
148 Ok(())
149 }
150
151 pub fn tx_get<Id: EntryId>(&self, id: Id) -> Result<Option<Id::Entry>, DatabaseError> {
157 if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
158 let cas_tree = boxed
159 .as_any()
160 .downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
161 .unwrap();
162 if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
163 return Ok(get.new.clone());
164 }
165 }
166
167 let tree = Id::Entry::tree(&self.database);
168 let raw = sled_get_raw(&tree, id.as_bytes())?;
169 Ok(raw.map(deserialize_from_ivec))
170 }
171
172 pub(crate) fn tx_get_batch<I, A, Entry: DatabaseEntry>(
173 &self,
174 items: I,
175 ) -> Result<Vec<Entry>, DatabaseError>
176 where
177 I: IntoIterator<Item = A>,
178 A: Borrow<Entry::Id>,
179 {
180 items
181 .into_iter()
182 .map(|id| {
183 self.tx_get(*id.borrow())?
184 .ok_or(DatabaseError::MissingEntry)
185 })
186 .collect()
187 }
188
189 pub fn tx_remove<Id: EntryId>(&mut self, key: Id) -> Result<(), DatabaseError> {
190 let db = self.database.clone();
191 let request = self.get_or_new_request::<Id::Entry>();
192
193 let key = *key.as_bytes();
194 if let Some(get_mut) = request.swaps.get_mut(&key) {
195 get_mut.new = None;
196 } else {
197 let old = sled_get_raw(&Id::Entry::tree(&db), &key)?;
198 request
199 .swaps
200 .insert(key, CompareAndSwapValue { old, new: None });
201 }
202 Ok(())
203 }
204
205 pub fn tx_upsert<T: DatabaseEntry>(
206 &mut self,
207 key: T::Id,
208 mut new: Option<T>,
209 ) -> Result<(), DatabaseError> {
210 let db = self.database.clone();
211 if let Some(new) = &mut new {
212 new.pre_upsert(self)?;
213 }
214
215 let request = self.get_or_new_request::<T>();
216
217 let key = *key.as_bytes();
218
219 if let Some(get_mut) = request.swaps.get_mut(&key) {
220 get_mut.new = new;
221 } else {
222 let old = sled_get_raw(&T::tree(&db), &key)?;
223 request.swaps.insert(key, CompareAndSwapValue { old, new });
224 }
225 Ok(())
226 }
227
228 pub fn tx_insert<T: DatabaseEntry>(&mut self, item: T) -> Result<(), DatabaseError> {
229 if self.tx_get(item.id())?.is_some() {
230 return Err(DatabaseError::AlreadyInDatabase);
231 }
232
233 self.tx_upsert(item.id(), Some(item))?;
234 Ok(())
235 }
236
237 pub fn get_or_new_request<T: DatabaseEntry>(&mut self) -> &mut TreeCompareAndSwap<T> {
238 self.swaps
239 .entry(TypeId::of::<T>())
240 .or_insert_with(|| Box::new(TreeCompareAndSwap::<T>::new(&self.database)))
241 .as_any_mut()
242 .downcast_mut::<TreeCompareAndSwap<T>>()
243 .unwrap()
244 }
245
246 #[must_use]
247 pub fn trees(&self) -> Vec<&Tree> {
248 self.swaps.values().map(|a| a.tree()).collect()
249 }
250}
251
252pub fn apply_cas_tx(
258 tx: CompareAndSwapTransaction,
259 flush: bool,
260) -> Result<(), TransactionError<CompareAndSwapError>> {
261 tx.trees().transaction(|tx_trees| {
262 for (tree, cas) in tx_trees.iter().zip(tx.swaps.values()) {
263 cas.apply(tree)?;
264 if flush {
265 tree.flush();
266 }
267 }
268 Ok(())
269 })
270}
271
272pub fn db_transaction<F, E>(mut f: F, db: Option<Db>, flush: bool) -> Result<(), E>
280where
281 F: FnMut(&mut CompareAndSwapTransaction) -> Result<(), E>,
282 E: From<TransactionError<CompareAndSwapError>>,
283{
284 let db = db.map(Rc::new);
285
286 loop {
287 let mut cas_tx = if let Some(db) = db.clone() {
288 CompareAndSwapTransaction::with_db(db)
289 } else {
290 CompareAndSwapTransaction::new()
291 };
292
293 f(&mut cas_tx)?;
294
295 match apply_cas_tx(cas_tx, flush) {
296 Ok(()) => return Ok(()),
297 Err(TransactionError::Abort(CompareAndSwapError {
298 current: _,
299 proposed: _,
300 })) => {
301 trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
302 }
303 Err(err) => return Err(err.into()),
304 }
305 }
306}
307
308impl CompareAndSwapTransaction {
310 pub fn relink_track_to_album(
312 &mut self,
313 track_id: TrackId,
314 album: Option<AlbumId>,
315 ) -> Result<bool, DatabaseError> {
316 let Some(mut track) = self.tx_get(track_id)? else {
317 return Ok(false);
318 };
319
320 if track.metadata.album == album {
321 return Ok(false);
322 }
323
324 let old_album_id = track.metadata.album;
325
326 track.metadata.album = album;
327 self.tx_upsert(track.id(), Some(track.clone()))?;
328
329 if let Some(old_album_id) = old_album_id {
330 let mut old_album = self.tx_get(old_album_id)?.ok_or({
331 DatabaseReferenceError::TrackDanglingAlbumRef {
332 track: track_id,
333 album: old_album_id,
334 }
335 })?;
336
337 old_album.tracks.retain(|t| t.id != track_id);
338
339 self.tx_upsert(old_album.id(), Some(old_album))?;
340 }
341
342 if let Some(new_album_id) = album {
343 let mut new_album = self
344 .tx_get(new_album_id)?
345 .ok_or(DatabaseError::MissingEntry)?;
346
347 new_album.tracks.push(TrackReference {
348 id: track_id,
349 track_num: None,
350 disc_num: None,
351 });
352
353 self.tx_upsert(new_album.id(), Some(new_album))?;
354 }
355
356 Ok(true)
357 }
358
359 pub fn album_set_and_relink_artists(
361 &mut self,
362 album_id: AlbumId,
363 artists: &[ArtistId],
364 ) -> Result<bool, DatabaseError> {
365 let mut album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
366
367 let old_artists: Vec<ArtistId> = album.artist_group.artist_ids().to_vec();
368
369 album.artist_group = ArtistGroup::from_artist_ids(artists.iter().cloned());
370
371 let removed_artists: Vec<ArtistId> = old_artists
372 .into_iter()
373 .filter(|old_artist| !artists.contains(old_artist))
374 .collect();
375
376 self.artists_add_album(album_id, artists)?;
377 self.artists_remove_album(album_id, &removed_artists)?;
378
379 self.tx_upsert(album_id, Some(album))?;
380 Ok(true)
381 }
382
383 pub fn album_set_and_relink_tracks(
385 &mut self,
386 album_id: AlbumId,
387 tracks: &[TrackId],
388 ) -> Result<bool, DatabaseError> {
389 let album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
390
391 let old_tracks: Vec<TrackId> = album.tracks.iter().map(|t| t.id).collect();
392
393 let removed_tracks: Vec<TrackId> = old_tracks
394 .iter()
395 .filter(|old_track| !tracks.contains(old_track))
396 .cloned()
397 .collect();
398
399 self.album_set_tracks(album, tracks)?;
400 self.tracks_set_album(Some(album_id), tracks)?;
401 self.tracks_set_album(None, &removed_tracks)?;
402 Ok(true)
403 }
404}
405
406impl CompareAndSwapTransaction {
408 pub(crate) fn album_set_tracks(
409 &mut self,
410 mut album: Album,
411 tracks: &[TrackId],
412 ) -> Result<(), DatabaseError> {
413 album.tracks = tracks
414 .iter()
415 .map(|t| {
416 album
417 .tracks
418 .iter()
419 .find(|old| old.id == *t)
420 .cloned()
421 .unwrap_or(TrackReference {
422 id: *t,
423 track_num: None,
424 disc_num: None,
425 })
426 })
427 .collect();
428 self.tx_upsert(album.id(), Some(album))?;
429 Ok(())
430 }
431
432 pub(crate) fn tracks_set_album<'a>(
433 &mut self,
434 album_id: Option<AlbumId>,
435 tracks: impl IntoIterator<Item = &'a TrackId>,
436 ) -> Result<(), DatabaseError> {
437 for track_id in tracks {
438 let Some(mut track) = self.tx_get(*track_id)? else {
439 return Err(DatabaseError::MissingEntry);
440 };
441
442 track.metadata.album = album_id;
443
444 for artist_id in track.metadata.artists.artist_ids() {
445 if let Some(album_id) = album_id {
446 let Some(mut artist) = self.tx_get(*artist_id)? else {
447 return Err(DatabaseError::MissingEntry);
448 };
449
450 if artist.albums.contains(&album_id) {
451 artist.tracks.retain(|t| t != track_id);
452 self.tx_upsert(*artist_id, Some(artist))?;
453 } else {
454 self.artist_add_tracks(*artist_id, &[*track_id])?;
455 }
456 } else {
457 self.artist_add_tracks(*artist_id, &[*track_id])?;
458 }
459 }
460
461 self.tx_upsert(*track_id, Some(track))?;
462 }
463 Ok(())
464 }
465
466 pub(crate) fn artists_remove_album(
467 &mut self,
468 album_id: AlbumId,
469 artists: &[ArtistId],
470 ) -> Result<(), DatabaseError> {
471 for artist_id in artists {
472 let Some(mut artist) = self.tx_get(*artist_id)? else {
473 return Err(DatabaseError::MissingEntry);
474 };
475
476 artist.albums.retain(|a| *a != album_id);
477
478 self.tx_upsert(*artist_id, Some(artist))?;
479 }
480
481 Ok(())
482 }
483
484 pub(crate) fn artists_add_album(
485 &mut self,
486 album_id: AlbumId,
487 artists: &[ArtistId],
488 ) -> Result<(), DatabaseError> {
489 for artist_id in artists {
490 let Some(mut artist) = self.tx_get(*artist_id)? else {
491 return Err(DatabaseError::MissingEntry);
492 };
493
494 if !artist.albums.contains(&album_id) {
495 artist.albums.push(album_id);
496 }
497 self.tx_upsert(*artist_id, Some(artist))?;
498 }
499 Ok(())
500 }
501
502 pub(crate) fn artist_add_tracks(
503 &mut self,
504 artist_id: ArtistId,
505 tracks: &[TrackId],
506 ) -> Result<(), DatabaseError> {
507 let Some(mut artist) = self.tx_get(artist_id)? else {
508 return Err(DatabaseError::MissingEntry);
509 };
510
511 for track_id in tracks {
512 let Some(track) = self.tx_get(*track_id)? else {
513 return Err(DatabaseError::MissingEntry);
514 };
515
516 if let Some(album_id) = track.metadata.album
517 && artist.albums.contains(&album_id)
518 {
519 continue;
520 }
521
522 if !artist.tracks.contains(track_id) {
523 artist.tracks.push(*track_id)
524 }
525 }
526
527 self.tx_upsert(artist_id, Some(artist))?;
528
529 Ok(())
530 }
531}