1use std::marker::PhantomData;
2
3use sled::transaction::{ConflictableTransactionResult, TransactionResult};
4
5use crate::{deserialize, serialize, Batch, Tree, KV};
6
7pub struct TransactionalTree<'a, K, V> {
8 inner: &'a sled::transaction::TransactionalTree,
9 _key: PhantomData<fn() -> K>,
10 _value: PhantomData<fn() -> V>,
11}
12
13impl<'a, K, V> TransactionalTree<'a, K, V> {
14 pub(crate) fn new(sled: &'a sled::transaction::TransactionalTree) -> Self {
15 Self {
16 inner: sled,
17 _key: PhantomData,
18 _value: PhantomData,
19 }
20 }
21
22 pub fn insert(
23 &self,
24 key: &K,
25 value: &V,
26 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
27 where
28 K: KV,
29 V: KV,
30 {
31 self.inner
32 .insert(serialize(key), serialize(value))
33 .map(|opt| opt.map(|v| deserialize(&v)))
34 }
35
36 pub fn remove(
37 &self,
38 key: &K,
39 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
40 where
41 K: KV,
42 V: KV,
43 {
44 self.inner
45 .remove(serialize(key))
46 .map(|opt| opt.map(|v| deserialize(&v)))
47 }
48
49 pub fn get(
50 &self,
51 key: &K,
52 ) -> std::result::Result<Option<V>, sled::transaction::UnabortableTransactionError>
53 where
54 K: KV,
55 V: KV,
56 {
57 self.inner
58 .get(serialize(key))
59 .map(|opt| opt.map(|v| deserialize(&v)))
60 }
61
62 pub fn apply_batch(
63 &self,
64 batch: &Batch<K, V>,
65 ) -> std::result::Result<(), sled::transaction::UnabortableTransactionError> {
66 self.inner.apply_batch(&batch.inner)
67 }
68
69 pub fn flush(&self) {
70 self.inner.flush()
71 }
72
73 pub fn generate_id(&self) -> sled::Result<u64> {
74 self.inner.generate_id()
75 }
76}
77
78pub trait Transactional<E = ()> {
79 type View<'a>;
80
81 fn transaction<F, A>(&self, f: F) -> TransactionResult<A, E>
82 where
83 F: for<'a> Fn(Self::View<'a>) -> ConflictableTransactionResult<A, E>;
84}
85
// Generates a `Transactional<E>` impl for a tuple of `&Tree<K, V>` references.
//
// Input: a comma-separated list of `KeyIdent, ValueIdent, tuple_index`
// triples, one triple per slot of the tuple.
macro_rules! impl_transactional {
    ($($key:ident, $value:ident, $idx:tt),+) => {
        impl<E, $($key, $value),+> Transactional<E> for ($(&Tree<$key, $value>),+) {
            type View<'a> = ($(TransactionalTree<'a, $key, $value>),+);

            fn transaction<F, A>(&self, f: F) -> TransactionResult<A, E>
            where
                F: for<'a> Fn(Self::View<'a>) -> ConflictableTransactionResult<A, E>,
            {
                // sled's own trait supplies `.transaction` for the tuple of
                // wrapped trees built below.
                use sled::Transactional;

                // Project every typed tree down to the tree it wraps.
                let raw_trees = ($(&self.$idx.inner),+);

                // Re-wrap each transactional tree sled hands us in its typed
                // view before passing the tuple on to the caller's closure.
                raw_trees.transaction(|raw| {
                    let views = ($(TransactionalTree::new(&raw.$idx)),+);
                    f(views)
                })
            }
        }
    };
}
108
109impl_transactional!(K0, V0, 0, K1, V1, 1);
110impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2);
111impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3);
112impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4);
113impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5);
114impl_transactional!(K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6);
115impl_transactional!(
116 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7
117);
118impl_transactional!(
119 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
120 8
121);
122impl_transactional!(
123 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
124 8, K9, V9, 9
125);
126impl_transactional!(
127 K0, V0, 0, K1, V1, 1, K2, V2, 2, K3, V3, 3, K4, V4, 4, K5, V5, 5, K6, V6, 6, K7, V7, 7, K8, V8,
128 8, K9, V9, 9, K10, V10, 10
129);
130
// Checks that one transaction spanning three differently-typed trees makes
// its insert visible in every one of them.
#[test]
fn test_multiple_tree_transaction() {
    let db = sled::Config::new().temporary(true).open().unwrap();
    let tree0 = Tree::<u32, i32>::open(&db, "tree0");
    let tree1 = Tree::<u16, i16>::open(&db, "tree1");
    let tree2 = Tree::<u8, i8>::open(&db, "tree2");

    (&tree0, &tree1, &tree2)
        .transaction(|trees| {
            trees.0.insert(&0, &0)?;
            trees.1.insert(&0, &0)?;
            trees.2.insert(&0, &0)?;
            // The error type is spelled out because nothing in the closure
            // body pins it down for inference.
            Ok::<(), sled::transaction::ConflictableTransactionError<()>>(())
        })
        .unwrap();

    // Every tree written inside the transaction must hold its value
    // afterwards, including the third one.
    assert_eq!(tree0.get(&0), Ok(Some(0)));
    assert_eq!(tree1.get(&0), Ok(Some(0)));
    assert_eq!(tree2.get(&0), Ok(Some(0)));
}