1extern crate futures;
46extern crate tokio;
47
48mod error;
49
50use futures::prelude::*;
51use futures::{future, Future};
52use std::any::Any;
53use std::error::Error as StdError;
54use tokio::sync::{mpsc, oneshot};
55
56pub use error::Error;
58
/// Type-erased, sendable box used to carry a callback's result across the channel.
///
/// The concrete type is recovered with `downcast` on the receiving side.
type AnyBox = Box<dyn Any + Send + 'static>;
60
61pub struct Lock<T, E>
63where
64 E: StdError + From<Error> + Send + 'static,
65{
66 tx: Option<mpsc::UnboundedSender<Acquire<T, E>>>,
67}
68
69enum Closure<T, E>
70where
71 E: StdError + From<Error> + Send + 'static,
72{
73 Read(Box<(FnMut(&T) -> Box<Future<Item = AnyBox, Error = E> + Send>) + Send>),
74 Write(Box<(FnMut(&mut T) -> Box<Future<Item = AnyBox, Error = E> + Send>) + Send>),
75}
76
77struct Acquire<T, E>
78where
79 E: StdError + From<Error> + Send + 'static,
80{
81 tx: oneshot::Sender<Result<AnyBox, E>>,
82 closure: Closure<T, E>,
83}
84
85impl<T, E> Lock<T, E>
86where
87 E: StdError + From<Error> + Send + 'static,
88{
89 pub fn new() -> Self {
91 Self { tx: None }
92 }
93
94 pub fn manage(&mut self, mut value: T) -> impl Future<Item = (), Error = Error> {
96 let (tx, rx) = mpsc::unbounded_channel();
97
98 self.tx = Some(tx);
99
100 rx.from_err::<Error>()
101 .for_each(move |acquire| {
102 let (res_tx, closure) = (acquire.tx, acquire.closure);
103 let item = match closure {
104 Closure::Read(mut f) => f(&value),
105 Closure::Write(mut f) => f(&mut value),
106 };
107
108 item.then(move |res| res_tx.send(res).map_err(|_| Error::OneShotSend))
109 .from_err()
110 })
111 .from_err()
112 }
113
114 fn run_closure(
115 &mut self,
116 closure: Closure<T, E>,
117 ) -> Box<Future<Item = AnyBox, Error = E> + Send> {
118 let tx = match &mut self.tx {
119 Some(tx) => tx,
120 None => {
121 return Box::new(future::err(E::from(Error::NotRunning)));
122 }
123 };
124
125 let (res_tx, res_rx) = oneshot::channel();
126
127 let acquire = Acquire {
128 tx: res_tx,
129 closure,
130 };
131 if let Err(err) = tx.try_send(acquire) {
132 return Box::new(future::err(E::from(Error::from(err))));
133 }
134
135 Box::new(res_rx.from_err::<Error>().from_err().and_then(|res| res))
136 }
137
138 pub fn get<CB, F, I>(&mut self, mut cb: CB) -> impl Future<Item = I, Error = E>
140 where
141 CB: (FnMut(&T) -> F) + Send + 'static,
142 F: Future<Item = I, Error = E> + Send + 'static,
143 I: Send + 'static,
144 {
145 let closure = Closure::Read(Box::new(move |t| {
146 Box::new(cb(t).map(|t| -> AnyBox { Box::new(t) }))
147 }));
148 self.run_closure(closure)
149 .map(|res| -> I { *res.downcast::<I>().unwrap() })
150 }
151
152 pub fn get_mut<I, CB, F>(&mut self, mut cb: CB) -> impl Future<Item = I, Error = E>
154 where
155 CB: (FnMut(&mut T) -> F) + Send + 'static,
156 F: Future<Item = I, Error = E> + Send + 'static,
157 I: Send + 'static,
158 {
159 let closure = Closure::Write(Box::new(move |t| {
160 Box::new(cb(t).map(|t| -> AnyBox { Box::new(t) }))
161 }));
162 self.run_closure(closure)
163 .map(|res| -> I { *res.downcast::<I>().unwrap() })
164 }
165
166 pub fn stop(&mut self) {
168 self.tx = None;
169 }
170}
171
172impl<T, E> Default for Lock<T, E>
173where
174 E: StdError + From<Error> + Send + 'static,
175{
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181impl<T, E> Clone for Lock<T, E>
182where
183 E: StdError + From<Error> + Send + 'static,
184{
185 fn clone(&self) -> Self {
186 Self {
187 tx: self.tx.clone(),
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::FutureResult;

    /// Plain value placed under the lock; the test reads both fields back.
    struct TestObject {
        x: u32,
        y: u64,
    }

    /// Reads one field through the original handle and one through a clone, and checks
    /// that both values come back while the managing future is running.
    #[test]
    fn it_should_read_fields_through_original_and_cloned_handles() {
        let object = TestObject { x: 23, y: 42 };

        let mut lock = Lock::new();
        let manager = lock.manage(object).map_err(|err| {
            panic!("Got error {}", err);
        });

        let get_x = lock.get(|o| -> FutureResult<u32, Error> { future::ok(o.x) });
        // The clone is taken after `manage`, so it shares the connected sender.
        let get_y = lock
            .clone()
            .get(|o| -> FutureResult<u64, Error> { future::ok(o.y) });

        let reads = get_x
            .join(get_y)
            .map_err(|err| {
                panic!("Got error {}", err);
            })
            .map(move |val| {
                assert_eq!(val, (23, 42));
                // Drop the remaining sender; presumably this lets the manager's request
                // stream end so that `tokio::run` can return.
                lock.stop();
            });

        tokio::run(manager.join(reads).map(|_| ()));
    }
}