1use std::cell::Cell;
5use std::{mem, ops, fmt, cmp};
6
/// A handle that lets an exclusive reference be used through shared access.
///
/// The wrapped `&'r mut T` lives in a [`Cell`] as an `Option`: it is moved out of the cell
/// while somebody is using it and moved back in afterwards, so the slot is `None` for as long
/// as the reference is handed out.
#[repr(transparent)]
pub struct SharedMutRef<'r, T>(Cell<Option<&'r mut T>>)
where
    T: ?Sized + 'r;
10
11impl<'r, T: ?Sized + 'r> SharedMutRef<'r, T> {
12 pub fn new(value_ref: &'r mut T) -> Self {
23 Self(Cell::new(Some(value_ref)))
24 }
25
26 pub fn into_inner(self) -> &'r mut T {
39 self.0.take().unwrap()
40 }
41
42 pub fn get_temp(&self) -> Option<TempMutRef<'_, 'r, T>> {
58 self.0.take().map(|value_ref| TempMutRef {
59 shared: self,
60 value_ref: Some(value_ref),
61 })
62 }
63
64 pub fn modify<U, F>(&self, f: F) -> Option<U>
78 where
79 F: FnOnce(&mut T) -> U,
80 {
81 self.get_temp()
82 .map(|mut temp_mut_ref| f(&mut *temp_mut_ref))
83 }
84}
85
86impl<'r, T: 'r> SharedMutRef<'r, T> {
87 pub fn set(&self, value: T) -> Result<(), T> {
101 let Some(value_ref) = self.0.take() else {
102 return Err(value);
103 };
104 *value_ref = value;
105 self.0.set(Some(value_ref));
106
107 Ok(())
108 }
109
110 pub fn replace(&self, new_value: T) -> Result<T, T> {
125 let Some(value_ref) = self.0.take() else {
126 return Err(new_value);
127 };
128 let old_value = mem::replace(value_ref, new_value);
129 self.0.set(Some(value_ref));
130
131 Ok(old_value)
132 }
133}
134
135impl<'r, T: Default + 'r> SharedMutRef<'r, T> {
136 pub fn take(&self) -> Option<T> {
151 let Some(value_ref) = self.0.take() else {
152 return None;
153 };
154 let value = mem::take(value_ref);
155 self.0.set(Some(value_ref));
156
157 Some(value)
158 }
159}
160
161impl<'r, T: Clone + 'r> SharedMutRef<'r, T> {
162 pub fn get(&self) -> Option<T> {
175 let Some(value_ref) = self.0.take() else {
176 return None;
177 };
178 let value = value_ref.clone();
179 self.0.set(Some(value_ref));
180
181 Some(value)
182 }
183}
184
185impl<T: fmt::Debug> fmt::Debug for SharedMutRef<'_, T> {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 let Some(value_ref) = self.0.take() else {
188 return f.write_fmt(format_args!("SharedMutRef(<temporary_borrowed>)"));
189 };
190 let ret = f.write_fmt(format_args!("SharedMutRef({:?})", *value_ref));
191 self.0.set(Some(value_ref));
192
193 ret
194 }
195}
196
197impl<'r, T: ?Sized + 'r> From<&'r mut T> for SharedMutRef<'r, T> {
198 fn from(value: &'r mut T) -> Self {
199 Self::new(value)
200 }
201}
202
203pub struct TempMutRef<'s, 'r, T: ?Sized + 'r> {
209 shared: &'s SharedMutRef<'r, T>,
210 value_ref: Option<&'r mut T>,
211}
212
213impl<T: ?Sized> ops::Deref for TempMutRef<'_, '_, T> {
214 type Target = T;
215
216 fn deref(&self) -> &Self::Target {
217 &**self
218 .value_ref
219 .as_ref()
220 .expect("TempMutRef doesn't point to any value")
221 }
222}
223
224impl<T: ?Sized> ops::DerefMut for TempMutRef<'_, '_, T> {
225 fn deref_mut(&mut self) -> &mut Self::Target {
226 &mut **self
227 .value_ref
228 .as_mut()
229 .expect("TempMutRef doesn't point to any value")
230 }
231}
232
233impl<T: fmt::Debug> fmt::Debug for TempMutRef<'_, '_, T> {
234 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235 f.write_fmt(format_args!("TempMutRef({:?})", **self))
236 }
237}
238
239impl<T: fmt::Display> fmt::Display for TempMutRef<'_, '_, T> {
240 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241 (**self).fmt(f)
242 }
243}
244
245impl<T: ?Sized> ops::Drop for TempMutRef<'_, '_, T> {
246 fn drop(&mut self) {
247 self.shared.0.set(self.value_ref.take());
248 }
249}
250
251impl<T: cmp::PartialEq> cmp::PartialEq for TempMutRef<'_, '_, T> {
252 fn eq(&self, other: &Self) -> bool {
253 **self == **other
254 }
255}
256
257impl<T: cmp::Eq> cmp::Eq for TempMutRef<'_, '_, T> {}
258
259impl<T: cmp::PartialOrd> cmp::PartialOrd for TempMutRef<'_, '_, T> {
260 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
261 (**self).partial_cmp(&**other)
262 }
263}
264
265impl<T: cmp::Ord> cmp::Ord for TempMutRef<'_, '_, T> {
266 fn cmp(&self, other: &Self) -> cmp::Ordering {
267 (**self).cmp(&**other)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    // `get` clones the value while the reference is available and refuses while it is taken.
    #[test]
    fn test_shared_mut_ref_get() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        assert_eq!(shared_mut_ref.get(), Some(vec![1, 2, 3]));

        let _temp = shared_mut_ref.get_temp();
        assert!(shared_mut_ref.get().is_none());
    }

    // `set` overwrites the value, and gives the new value back while the reference is taken.
    #[test]
    fn test_shared_mut_ref_set() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        {
            assert!(shared_mut_ref.set(vec![4, 5, 6]).is_ok());
            assert_eq!(shared_mut_ref.get(), Some(vec![4, 5, 6]));
        }

        {
            let _temp = shared_mut_ref.get_temp();
            assert_eq!(shared_mut_ref.set(vec![7, 8, 9]), Err(vec![7, 8, 9]));
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![4, 5, 6]));
    }

    // `replace` returns the old value, and gives the new value back while the reference is taken.
    #[test]
    fn test_shared_mut_ref_replace() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        {
            assert_eq!(shared_mut_ref.replace(vec![4, 5, 6]), Ok(vec![1, 2, 3]));
            assert_eq!(shared_mut_ref.get(), Some(vec![4, 5, 6]));
        }

        {
            let _temp = shared_mut_ref.get_temp();
            assert_eq!(shared_mut_ref.replace(vec![7, 8, 9]), Err(vec![7, 8, 9]));
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![4, 5, 6]));
    }

    // `take` moves the value out and leaves the default behind.
    #[test]
    fn test_shared_mut_ref_take() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        {
            assert_eq!(shared_mut_ref.take(), Some(vec![1, 2, 3]));
            assert_eq!(shared_mut_ref.get(), Some(vec![]));
        }

        {
            let _temp = shared_mut_ref.get_temp();
            assert!(shared_mut_ref.take().is_none());
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![]));
    }

    // `modify` runs the closure only while the reference is available.
    #[test]
    fn test_shared_mut_ref_modify() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        {
            assert_eq!(
                shared_mut_ref.modify(|value_ref| value_ref.reverse()),
                Some(())
            );
            assert_eq!(shared_mut_ref.get(), Some(vec![3, 2, 1]));
        }

        {
            let _temp = shared_mut_ref.get_temp();
            // The closure must not run while the reference is taken.
            assert_eq!(
                shared_mut_ref.modify(|value_ref| value_ref.reverse()),
                None
            );
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![3, 2, 1]));
    }

    // A guard gives read/write access and hands the reference back when dropped.
    #[test]
    fn test_shared_mut_ref_get_temp() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        {
            let temp = shared_mut_ref.get_temp().unwrap();
            assert_eq!(*temp, vec![1, 2, 3]);
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![1, 2, 3]));

        {
            let mut temp = shared_mut_ref.get_temp().unwrap();
            *temp = vec![4, 5, 6];
            assert_eq!(*temp, vec![4, 5, 6]);
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![4, 5, 6]));

        {
            let mut temp1 = shared_mut_ref.get_temp().unwrap();

            // Only one guard can exist at a time.
            let temp2 = shared_mut_ref.get_temp();
            assert!(temp2.is_none());

            *temp1 = vec![1, 2, 3];
            assert_eq!(*temp1, vec![1, 2, 3]);
        }

        assert_eq!(shared_mut_ref.get(), Some(vec![1, 2, 3]));
    }

    // `into_inner` gives back the original reference, including changes made via the wrapper.
    #[test]
    fn test_shared_mut_ref_into_inner() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::from(&mut vec);

        assert!(shared_mut_ref.set(vec![4, 5, 6]).is_ok());
        shared_mut_ref.into_inner().push(7);

        assert_eq!(vec, vec![4, 5, 6, 7]);
    }

    // `Debug` output shows the value, or a placeholder while the reference is taken.
    #[test]
    fn test_shared_mut_ref_debug() {
        let mut vec = vec![1, 2, 3];
        let shared_mut_ref = SharedMutRef::new(&mut vec);

        assert_eq!(format!("{:?}", shared_mut_ref), "SharedMutRef([1, 2, 3])");

        {
            let temp = shared_mut_ref.get_temp().unwrap();
            assert_eq!(
                format!("{:?}", shared_mut_ref),
                "SharedMutRef(<temporary_borrowed>)"
            );
            assert_eq!(format!("{:?}", temp), "TempMutRef([1, 2, 3])");
        }

        // Formatting must not consume the reference.
        assert_eq!(format!("{:?}", shared_mut_ref), "SharedMutRef([1, 2, 3])");
    }
}