1use crate::{Doc, Origin, Store, Transaction, TransactionMut};
2use async_lock::futures::{Read, Write};
3use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use thiserror::Error;
7
8pub trait Transact {
11 fn try_transact(&self) -> Result<Transaction, TransactionAcqError>;
19
20 fn try_transact_mut(&self) -> Result<TransactionMut, TransactionAcqError>;
30
31 fn try_transact_mut_with<T>(&self, origin: T) -> Result<TransactionMut, TransactionAcqError>
44 where
45 T: Into<Origin>;
46
47 fn transact_mut_with<T>(&self, origin: T) -> TransactionMut
59 where
60 T: Into<Origin>,
61 {
62 self.try_transact_mut_with(origin).unwrap()
63 }
64
65 fn transact(&self) -> Transaction {
73 self.try_transact().unwrap()
74 }
75
76 fn transact_mut(&self) -> TransactionMut {
85 self.try_transact_mut().unwrap()
86 }
87}
88
89impl Transact for Doc {
90 fn try_transact(&self) -> Result<Transaction, TransactionAcqError> {
91 match self.store.try_read() {
92 Some(store) => Ok(Transaction::new(store)),
93 None => Err(TransactionAcqError::SharedAcqFailed),
94 }
95 }
96
97 fn try_transact_mut(&self) -> Result<TransactionMut, TransactionAcqError> {
98 match self.store.try_write() {
99 Some(store) => Ok(TransactionMut::new(self.clone(), store, None)),
100 None => Err(TransactionAcqError::ExclusiveAcqFailed),
101 }
102 }
103
104 fn try_transact_mut_with<T>(&self, origin: T) -> Result<TransactionMut, TransactionAcqError>
105 where
106 T: Into<Origin>,
107 {
108 match self.store.try_write() {
109 Some(store) => Ok(TransactionMut::new(
110 self.clone(),
111 store,
112 Some(origin.into()),
113 )),
114 None => Err(TransactionAcqError::ExclusiveAcqFailed),
115 }
116 }
117
118 fn transact_mut_with<T>(&self, origin: T) -> TransactionMut
119 where
120 T: Into<Origin>,
121 {
122 let lock = self.store.write_blocking();
123 TransactionMut::new(self.clone(), lock, Some(origin.into()))
124 }
125
126 fn transact(&self) -> Transaction {
127 let lock = self.store.read_blocking();
128 Transaction::new(lock)
129 }
130
131 fn transact_mut(&self) -> TransactionMut {
132 let lock = self.store.write_blocking();
133 TransactionMut::new(self.clone(), lock, None)
134 }
135}
136
137pub trait AsyncTransact<'doc> {
140 type Read: Future<Output = Transaction<'doc>>;
141 type Write: Future<Output = TransactionMut<'doc>>;
142
143 fn transact(&'doc self) -> Self::Read;
144 fn transact_mut(&'doc self) -> Self::Write;
145
146 fn transact_mut_with<T>(&'doc self, origin: T) -> Self::Write
153 where
154 T: Into<Origin>;
155}
156
157impl<'doc> AsyncTransact<'doc> for Doc {
158 type Read = AcquireTransaction<'doc>;
159 type Write = AcquireTransactionMut<'doc>;
160
161 fn transact(&'doc self) -> Self::Read {
162 let fut = self.store.read_async();
163 AcquireTransaction { fut }
164 }
165
166 fn transact_mut(&'doc self) -> Self::Write {
167 let fut = self.store.write_async();
168 AcquireTransactionMut {
169 doc: self.clone(),
170 origin: None,
171 fut,
172 }
173 }
174
175 fn transact_mut_with<T>(&'doc self, origin: T) -> Self::Write
176 where
177 T: Into<Origin>,
178 {
179 let fut = self.store.write_async();
180 AcquireTransactionMut {
181 doc: self.clone(),
182 origin: Some(origin.into()),
183 fut,
184 }
185 }
186}
187
188pub struct AcquireTransaction<'doc> {
189 fut: Read<'doc, Store>,
190}
191
192impl<'doc> Unpin for AcquireTransaction<'doc> {}
193
194impl<'doc> Future for AcquireTransaction<'doc> {
195 type Output = Transaction<'doc>;
196
197 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
198 let pinned = unsafe { Pin::new_unchecked(&mut self.fut) };
199 pinned.poll(cx).map(Transaction::new)
200 }
201}
202
203pub struct AcquireTransactionMut<'doc> {
204 doc: Doc,
205 origin: Option<Origin>,
206 fut: Write<'doc, Store>,
207}
208
209impl<'doc> Unpin for AcquireTransactionMut<'doc> {}
210
211impl<'doc> Future for AcquireTransactionMut<'doc> {
212 type Output = TransactionMut<'doc>;
213
214 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
215 let pinned = unsafe { Pin::new_unchecked(&mut self.fut) };
216 match pinned.poll(cx) {
217 Poll::Ready(store) => {
218 let doc = self.doc.clone();
219 let origin = self.origin.take();
220 Poll::Ready(TransactionMut::new(doc, store, origin))
221 }
222 Poll::Pending => Poll::Pending,
223 }
224 }
225}
226
227#[derive(Error, Debug)]
228pub enum TransactionAcqError {
229 #[error("Failed to acquire read-only transaction. Drop read-write transaction and retry.")]
230 SharedAcqFailed,
231 #[error("Failed to acquire read-write transaction. Drop other transactions and retry.")]
232 ExclusiveAcqFailed,
233 #[error("All references to a parent document containing this structure has been dropped.")]
234 DocumentDropped,
235}
236
#[cfg(test)]
mod test {
    use crate::{Doc, GetString, Text, Transact};
    use rand::random;
    use std::sync::{Arc, Barrier};
    use std::time::{Duration, Instant};

    /// Several threads each open a read-write transaction, hold it for a random delay and
    /// insert one character; afterwards the text must contain exactly one character per thread.
    #[test]
    fn multi_thread_transact_mut() {
        const N: usize = 3;

        let doc = Doc::new();
        let txt = doc.get_or_insert_text("text");

        // One extra participant so the main thread can wait for all workers.
        let barrier = Arc::new(Barrier::new(N + 1));

        let start = Instant::now();
        for _ in 0..N {
            let (d, t, b) = (doc.clone(), txt.clone(), barrier.clone());
            std::thread::spawn(move || {
                {
                    let mut txn = d.transact_mut();
                    // Keep the transaction open for 0-400 ms before writing.
                    let delay = Duration::from_millis((random::<u64>() % 5) * 100);
                    std::thread::sleep(delay);
                    t.insert(&mut txn, 0, "a");
                } // the transaction is dropped here, before reaching the barrier
                b.wait();
            });
        }

        barrier.wait();
        println!("{} threads executed in {:?}", N, start.elapsed());

        let txn = doc.transact();
        assert_eq!(txt.get_string(&txn), "a".repeat(N));
    }
}