1extern crate arc_swap;
13
14use arc_swap::{ArcSwap, Guard};
15use std::ops::Deref;
16use std::sync::{Arc, LockResult, Mutex, MutexGuard, PoisonError};
17
18use crate::{GuestAddressSpace, GuestMemory};
19
20#[derive(Debug)]
28pub struct GuestMemoryAtomic<M: GuestMemory> {
29 inner: Arc<(ArcSwap<M>, Mutex<()>)>,
35}
36
37impl<M: GuestMemory> From<Arc<M>> for GuestMemoryAtomic<M> {
38 fn from(map: Arc<M>) -> Self {
41 let inner = (ArcSwap::new(map), Mutex::new(()));
42 GuestMemoryAtomic {
43 inner: Arc::new(inner),
44 }
45 }
46}
47
48impl<M: GuestMemory> GuestMemoryAtomic<M> {
49 pub fn new(map: M) -> Self {
52 Arc::new(map).into()
53 }
54
55 fn load(&self) -> Guard<Arc<M>> {
56 self.inner.0.load()
57 }
58
59 pub fn lock(&self) -> LockResult<GuestMemoryExclusiveGuard<'_, M>> {
65 match self.inner.1.lock() {
66 Ok(guard) => Ok(GuestMemoryExclusiveGuard {
67 parent: self,
68 _guard: guard,
69 }),
70 Err(err) => Err(PoisonError::new(GuestMemoryExclusiveGuard {
71 parent: self,
72 _guard: err.into_inner(),
73 })),
74 }
75 }
76}
77
78impl<M: GuestMemory> Clone for GuestMemoryAtomic<M> {
79 fn clone(&self) -> Self {
80 Self {
81 inner: self.inner.clone(),
82 }
83 }
84}
85
86impl<M: GuestMemory> GuestAddressSpace for GuestMemoryAtomic<M> {
87 type T = GuestMemoryLoadGuard<M>;
88 type M = M;
89
90 fn memory(&self) -> Self::T {
91 GuestMemoryLoadGuard { guard: self.load() }
92 }
93}
94
95#[derive(Debug)]
100pub struct GuestMemoryLoadGuard<M: GuestMemory> {
101 guard: Guard<Arc<M>>,
102}
103
104impl<M: GuestMemory> GuestMemoryLoadGuard<M> {
105 pub fn into_inner(self) -> Arc<M> {
111 Guard::into_inner(self.guard)
112 }
113}
114
115impl<M: GuestMemory> Clone for GuestMemoryLoadGuard<M> {
116 fn clone(&self) -> Self {
117 GuestMemoryLoadGuard {
118 guard: Guard::from_inner(Arc::clone(&*self.guard)),
119 }
120 }
121}
122
123impl<M: GuestMemory> Deref for GuestMemoryLoadGuard<M> {
124 type Target = M;
125
126 fn deref(&self) -> &Self::Target {
127 &self.guard
128 }
129}
130
131#[derive(Debug)]
136pub struct GuestMemoryExclusiveGuard<'a, M: GuestMemory> {
137 parent: &'a GuestMemoryAtomic<M>,
138 _guard: MutexGuard<'a, ()>,
139}
140
141impl<M: GuestMemory> GuestMemoryExclusiveGuard<'_, M> {
142 pub fn replace(self, map: M) {
146 self.parent.inner.0.store(Arc::new(map))
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::region::tests::{new_guest_memory_collection_from_regions, Collection, MockRegion};
    use crate::{GuestAddress, GuestMemory, GuestMemoryBackend, GuestMemoryRegion, GuestUsize};

    type GuestMemoryMmapAtomic = GuestMemoryAtomic<Collection>;

    /// Snapshots obtained via `memory()`, `into_inner()`, an `Arc`'s own
    /// `memory()`, and a cloned handle all expose the same two regions.
    #[test]
    fn test_atomic_memory() {
        let region_size = 0x400;
        let regions = vec![
            (GuestAddress(0x0), region_size),
            (GuestAddress(0x1000), region_size),
        ];
        let mut iterated_regions = Vec::new();
        // Fixed: was the mis-encoded `®ions`, which does not compile.
        let gmm = new_guest_memory_collection_from_regions(&regions).unwrap();
        let gm = GuestMemoryMmapAtomic::new(gmm);
        let vmem = gm.memory();
        let mem = vmem.physical_memory().unwrap();

        for region in mem.iter() {
            assert_eq!(region.len(), region_size as GuestUsize);
        }

        for region in mem.iter() {
            iterated_regions.push((region.start_addr(), region.len()));
        }
        assert_eq!(regions, iterated_regions);
        assert_eq!(mem.num_regions(), 2);
        assert!(mem.find_region(GuestAddress(0x1000)).is_some());
        assert!(mem.find_region(GuestAddress(0x10000)).is_none());

        assert!(regions
            .iter()
            .map(|x| (x.0, x.1))
            .eq(iterated_regions.iter().copied()));

        let mem2 = vmem.into_inner();
        for region in mem2.iter() {
            assert_eq!(region.len(), region_size as GuestUsize);
        }
        assert_eq!(mem2.num_regions(), 2);
        assert!(mem2.find_region(GuestAddress(0x1000)).is_some());
        assert!(mem2.find_region(GuestAddress(0x10000)).is_none());

        assert!(regions
            .iter()
            .map(|x| (x.0, x.1))
            .eq(iterated_regions.iter().copied()));

        let mem3 = mem2.memory();
        for region in mem3.iter() {
            assert_eq!(region.len(), region_size as GuestUsize);
        }
        assert_eq!(mem3.num_regions(), 2);
        assert!(mem3.find_region(GuestAddress(0x1000)).is_some());
        assert!(mem3.find_region(GuestAddress(0x10000)).is_none());

        let gm2 = gm.clone();
        let mem4 = gm2.memory();
        for region in mem4.iter() {
            assert_eq!(region.len(), region_size as GuestUsize);
        }
        assert_eq!(mem4.num_regions(), 2);
        assert!(mem4.find_region(GuestAddress(0x1000)).is_some());
        assert!(mem4.find_region(GuestAddress(0x10000)).is_none());
    }

    /// A cloned load guard outlives the guard it was cloned from.
    #[test]
    fn test_clone_guard() {
        let region_size = 0x400;
        let regions = vec![
            (GuestAddress(0x0), region_size),
            (GuestAddress(0x1000), region_size),
        ];
        // Fixed: was the mis-encoded `®ions`, which does not compile.
        let gmm = new_guest_memory_collection_from_regions(&regions).unwrap();
        let gm = GuestMemoryMmapAtomic::new(gmm);
        let mem = {
            let guard1 = gm.memory();
            Clone::clone(&guard1)
        };
        assert_eq!(mem.num_regions(), 2);
    }

    /// Replacing the map under the exclusive guard is visible to new snapshots
    /// while leaving previously taken snapshots untouched.
    #[test]
    fn test_atomic_hotplug() {
        let region_size = 0x1000;
        let regions = [
            (GuestAddress(0x0), region_size),
            (GuestAddress(0x10_0000), region_size),
        ];
        // Fixed: was the mis-encoded `®ions`, which does not compile.
        let mut gmm = Arc::new(new_guest_memory_collection_from_regions(&regions).unwrap());
        let gm: GuestMemoryAtomic<_> = gmm.clone().into();
        let mem_orig = gm.memory();
        assert_eq!(mem_orig.num_regions(), 2);

        {
            let guard = gm.lock().unwrap();
            let new_gmm = Arc::make_mut(&mut gmm);
            let new_gmm = new_gmm
                .insert_region(Arc::new(MockRegion {
                    start: GuestAddress(0x8000),
                    len: 0x1000,
                }))
                .unwrap();
            let new_gmm = new_gmm
                .insert_region(Arc::new(MockRegion {
                    start: GuestAddress(0x4000),
                    len: 0x1000,
                }))
                .unwrap();
            let new_gmm = new_gmm
                .insert_region(Arc::new(MockRegion {
                    start: GuestAddress(0xc000),
                    len: 0x1000,
                }))
                .unwrap();

            // Inserting an overlapping region must be rejected.
            new_gmm
                .insert_region(Arc::new(MockRegion {
                    start: GuestAddress(0x8000),
                    len: 0x1000,
                }))
                .unwrap_err();

            guard.replace(new_gmm);
        }

        assert_eq!(mem_orig.num_regions(), 2);
        let mem = gm.memory();
        assert_eq!(mem.num_regions(), 5);
    }
}