use async_lock::{Mutex, MutexGuard, RwLock};
2use futures::FutureExt;
3use lru::LruCache;
4
5use cfg_if::cfg_if;
6use std::collections::HashMap;
7use std::future::Future;
8use std::panic::{resume_unwind, AssertUnwindSafe};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11use vegafusion_core::error::{DuplicateResult, Result, VegaFusionError};
12use vegafusion_core::task_graph::task_value::TaskValue;
13
14#[cfg(not(target_arch = "wasm32"))]
15use {std::time::Instant, vegafusion_core::error::ToExternalError};
16
/// A task-graph node result stored in one of the cache tiers, together with
/// how long it took to compute (used for diagnostics; read nowhere in this
/// file, hence the leading underscore).
#[derive(Debug, Clone)]
struct CachedValue {
    // The node's value and its exported output values.
    value: NodeValue,
    // Wall-clock milliseconds the computation took; `None` on wasm32,
    // where no timing is collected.
    _calculation_millis: Option<u128>,
}
22
23impl CachedValue {
24 pub fn size_of(&self) -> usize {
25 self.value.0.size_of() + self.value.1.iter().map(|v| v.size_of()).sum::<usize>()
26 }
27}
28
// A computed node result: the node's own value plus its exported output values.
type NodeValue = (TaskValue, Vec<TaskValue>);
// Shared slot for an in-flight computation: `None` while running, then the
// final `Result`. Waiters block on the read lock until the writer releases it.
type Initializer = Arc<RwLock<Option<Result<NodeValue>>>>;
31
/// Two-tier (segmented) LRU cache for task-graph node values, keyed by the
/// node's state fingerprint. New entries land in the probationary tier; a
/// second access promotes them to the protected tier. Evicted protected
/// entries are demoted back to probationary rather than dropped outright.
#[derive(Debug, Clone)]
pub struct VegaFusionCache {
    // Entries that have been accessed at least twice.
    protected_cache: Arc<Mutex<LruCache<u64, CachedValue>>>,
    // Entries seen once; evicted first under pressure.
    probationary_cache: Arc<Mutex<LruCache<u64, CachedValue>>>,
    // Share of `capacity`/`memory_limit` granted to the protected tier (0.5).
    protected_fraction: f64,
    // In-flight computations, so concurrent requests for the same
    // fingerprint await a single initializer instead of recomputing.
    initializers: Arc<RwLock<HashMap<u64, Initializer>>>,
    // Total entry count across both tiers (updated in `balance`).
    size: Arc<AtomicUsize>,
    // Approximate bytes held by each tier.
    protected_memory: Arc<AtomicUsize>,
    probationary_memory: Arc<AtomicUsize>,
    // Combined entry-count limit across both tiers; `None` = unbounded.
    capacity: Option<usize>,
    // Combined memory limit in bytes; `None` = unbounded.
    memory_limit: Option<usize>,
}
48
49impl VegaFusionCache {
50 pub fn new(capacity: Option<usize>, size_limit: Option<usize>) -> Self {
51 Self {
52 protected_cache: Arc::new(Mutex::new(LruCache::unbounded())),
53 probationary_cache: Arc::new(Mutex::new(LruCache::unbounded())),
54 protected_fraction: 0.5,
55 initializers: Default::default(),
56 capacity,
57 memory_limit: size_limit,
58 size: Arc::new(AtomicUsize::new(0)),
59 protected_memory: Arc::new(AtomicUsize::new(0)),
60 probationary_memory: Arc::new(AtomicUsize::new(0)),
61 }
62 }
63
    /// Configured combined entry-count limit, if any.
    pub fn capacity(&self) -> Option<usize> {
        self.capacity
    }

    /// Configured combined memory limit in bytes, if any.
    pub fn memory_limit(&self) -> Option<usize> {
        self.memory_limit
    }

    /// Current number of entries across both tiers.
    pub fn size(&self) -> usize {
        self.size.load(Ordering::Relaxed)
    }

    /// Approximate bytes held across both tiers.
    pub fn total_memory(&self) -> usize {
        self.protected_memory() + self.probationary_memory()
    }

    /// Approximate bytes held by the protected tier.
    pub fn protected_memory(&self) -> usize {
        self.protected_memory.load(Ordering::Relaxed)
    }

    /// Approximate bytes held by the probationary tier.
    pub fn probationary_memory(&self) -> usize {
        self.probationary_memory.load(Ordering::Relaxed)
    }
87
88 pub async fn clear(&self) {
89 self.protected_cache.lock().await.clear();
92 self.probationary_cache.lock().await.clear();
93 self.protected_memory.store(0, Ordering::Relaxed);
94 self.probationary_memory.store(0, Ordering::Relaxed);
95 self.size.store(0, Ordering::Relaxed);
96 }
97
98 async fn get(&self, state_fingerprint: u64) -> Option<CachedValue> {
99 let mut protected = self.protected_cache.lock().await;
100 let mut probationary = self.probationary_cache.lock().await;
101
102 if protected.contains(&state_fingerprint) {
103 protected.get(&state_fingerprint).cloned()
104 } else if probationary.contains(&state_fingerprint) {
105 let value = probationary.pop(&state_fingerprint).unwrap();
107 let value_memory = value.size_of();
108 protected.put(state_fingerprint, value.clone());
109
110 self.protected_memory
111 .fetch_add(value_memory, Ordering::Relaxed);
112 self.probationary_memory
113 .fetch_sub(value_memory, Ordering::Relaxed);
114
115 self.balance(&mut protected, &mut probationary);
117
118 Some(value)
119 } else {
120 None
121 }
122 }
123
124 fn pop_protected_lru(
125 &self,
126 protected: &mut MutexGuard<LruCache<u64, CachedValue>>,
127 probationary: &mut MutexGuard<LruCache<u64, CachedValue>>,
128 ) {
129 let (key, popped_value) = protected.pop_lru().unwrap();
131 let popped_memory = popped_value.size_of();
132
133 self.protected_memory
135 .fetch_sub(popped_memory, Ordering::Relaxed);
136
137 probationary.put(key, popped_value);
139
140 self.probationary_memory
142 .fetch_add(popped_memory, Ordering::Relaxed);
143 }
144
145 fn pop_probationary_lru(&self, probationary: &mut MutexGuard<LruCache<u64, CachedValue>>) {
146 let (_, popped_value) = probationary.pop_lru().unwrap();
147 let popped_memory = popped_value.size_of();
148
149 self.probationary_memory
151 .fetch_sub(popped_memory, Ordering::Relaxed);
152 }
153
    /// Enforce the per-tier capacity and memory limits, then refresh the
    /// total `size` counter.
    ///
    /// The protected tier is shrunk first (demoting entries into
    /// probationary), then the probationary tier is shrunk (dropping
    /// entries). Each `len() > 1` guard keeps at least one entry per tier
    /// even when a single entry exceeds the limit on its own.
    fn balance(
        &self,
        protected: &mut MutexGuard<LruCache<u64, CachedValue>>,
        probationary: &mut MutexGuard<LruCache<u64, CachedValue>>,
    ) {
        // Split the entry-count budget between the tiers; the protected
        // tier gets `protected_fraction` (rounded up), probationary the rest.
        let (protected_capacity, probationary_capacity) = if let Some(capacity) = self.capacity {
            let protected_capacity = (capacity as f64 * self.protected_fraction).ceil() as usize;
            (
                Some(protected_capacity),
                Some(capacity - protected_capacity),
            )
        } else {
            (None, None)
        };

        // Split the memory budget the same way.
        let (protected_mem_limit, probationary_mem_limit) =
            if let Some(memory_limit) = self.memory_limit {
                let protected_mem_limit =
                    (memory_limit as f64 * self.protected_fraction).ceil() as usize;
                (
                    Some(protected_mem_limit),
                    Some(memory_limit - protected_mem_limit),
                )
            } else {
                (None, None)
            };

        // Demote protected entries that exceed the entry-count budget.
        if let Some(capacity) = protected_capacity {
            while protected.len() > 1 && protected.len() > capacity {
                self.pop_protected_lru(protected, probationary);
            }
        }

        // Demote protected entries that exceed the memory budget.
        if let Some(memory_limit) = protected_mem_limit {
            while protected.len() > 1
                && self.protected_memory.load(Ordering::Relaxed) > memory_limit
            {
                self.pop_protected_lru(protected, probationary);
            }
        }

        // Drop probationary entries over the entry-count budget
        // (including any just demoted from protected).
        if let Some(capacity) = probationary_capacity {
            while probationary.len() > 1 && probationary.len() > capacity {
                self.pop_probationary_lru(probationary);
            }
        }

        // Drop probationary entries over the memory budget.
        if let Some(memory_limit) = probationary_mem_limit {
            while probationary.len() > 1
                && self.probationary_memory.load(Ordering::Relaxed) > memory_limit
            {
                self.pop_probationary_lru(probationary);
            }
        }

        // Publish the post-eviction total entry count.
        self.size
            .store(protected.len() + probationary.len(), Ordering::Relaxed);
    }
221
222 async fn set_value(
223 &self,
224 state_fingerprint: u64,
225 value: NodeValue,
226 calculation_millis: Option<u128>,
227 ) {
228 let cache_value = CachedValue {
229 value,
230 _calculation_millis: calculation_millis,
231 };
232 let value_memory = cache_value.size_of();
233
234 let mut protected = self.protected_cache.lock().await;
235 let mut probationary = self.probationary_cache.lock().await;
236 if protected.contains(&state_fingerprint) {
237 protected.put(state_fingerprint, cache_value);
239 } else if probationary.contains(&state_fingerprint) {
240 protected.put(
242 state_fingerprint,
243 probationary.pop(&state_fingerprint).unwrap(),
244 );
245 self.protected_memory
246 .fetch_add(value_memory, Ordering::Relaxed);
247 self.probationary_memory
248 .fetch_sub(value_memory, Ordering::Relaxed);
249 self.balance(&mut protected, &mut probationary);
250 } else {
251 probationary.put(state_fingerprint, cache_value);
253 self.probationary_memory
254 .fetch_add(value_memory, Ordering::Relaxed);
255 self.balance(&mut protected, &mut probationary);
256 }
257 }
258
259 async fn remove_initializer(&self, state_fingerprint: u64) -> Option<Initializer> {
260 self.initializers.write().await.remove(&state_fingerprint)
261 }
262
263 pub async fn get_or_try_insert_with<F>(
264 &self,
265 state_fingerprint: u64,
266 init: F,
267 ) -> Result<NodeValue>
268 where
269 F: Future<Output = Result<NodeValue>> + Send + 'static,
270 {
271 if let Some(value) = self.get(state_fingerprint).await {
273 return Ok(value.value);
274 }
275
276 let initializer = {
279 self.initializers
280 .write()
281 .await
282 .get(&state_fingerprint)
283 .cloned()
284 };
285
286 if let Some(initializer) = initializer {
287 let result = initializer.read().await;
290 let result = match result.as_ref() {
291 None => self.spawn_initializer(state_fingerprint, init).await,
292 Some(result) => result.duplicate(),
293 };
294 result
295 } else {
296 self.spawn_initializer(state_fingerprint, init).await
297 }
298 }
299
    /// Run `init`, publish its result to concurrent waiters, and cache it.
    ///
    /// The initializer's write lock is acquired BEFORE registering it in
    /// the map and is held until the result is written, so any waiter that
    /// finds the initializer blocks on `read()` until the computation
    /// completes. The initializer is always removed from the map afterward
    /// (success, error, or panic).
    ///
    /// On wasm32 the future runs inline (no tokio, no timing); elsewhere it
    /// runs on the tokio runtime via `tokio::spawn` and the elapsed time is
    /// recorded with the cached value.
    async fn spawn_initializer<F>(&self, state_fingerprint: u64, init: F) -> Result<NodeValue>
    where
        F: Future<Output = Result<NodeValue>> + Send + 'static,
    {
        let initializer: Initializer = Arc::new(RwLock::new(None));

        // Take the write guard first so waiters cannot observe `None`
        // after the initializer becomes visible in the map.
        let mut initializer_lock = initializer.write().await;

        self.initializers
            .write()
            .await
            .insert(state_fingerprint, initializer.clone());

        cfg_if! {
            if #[cfg(target_arch = "wasm32")] {
                // catch_unwind so a panicking task still cleans up the
                // initializer before re-raising the panic.
                match AssertUnwindSafe(init).catch_unwind().await {
                    Ok(value) => {
                        match value {
                            Ok(value) => {
                                // Publish to waiters, then cache (no timing on wasm).
                                *initializer_lock = Some(Ok(value.clone()));
                                self.set_value(state_fingerprint, value.clone(), None).await;

                                self.remove_initializer(state_fingerprint).await;
                                Ok(value)
                            }
                            Err(e) => {
                                // Publish the error to waiters; it is not cached.
                                *initializer_lock = Some(Err(e.duplicate()));
                                self.remove_initializer(state_fingerprint).await;
                                Err(e)
                            }
                        }
                    }
                    Err(payload) => {
                        // Waiters see a generic internal error; this task
                        // re-raises the original panic.
                        *initializer_lock = Some(Err(VegaFusionError::internal("Panic error")));

                        self.remove_initializer(state_fingerprint).await;
                        resume_unwind(payload);
                    }
                }
            } else {
                let start = Instant::now();
                // Outer Result: panic payload; middle: tokio JoinError;
                // inner: the computation's own Result.
                match AssertUnwindSafe(tokio::spawn(init)).catch_unwind().await {
                    Ok(Ok(value)) => {
                        match value {
                            Ok(value) => {
                                // Publish to waiters, then cache with timing.
                                *initializer_lock = Some(Ok(value.clone()));

                                let duration = start.elapsed();
                                let millis = duration.as_millis();

                                self.set_value(state_fingerprint, value.clone(), Some(millis))
                                    .await;

                                self.remove_initializer(state_fingerprint).await;
                                Ok(value)
                            }
                            Err(e) => {
                                // Publish the error to waiters; it is not cached.
                                *initializer_lock = Some(Err(e.duplicate()));
                                self.remove_initializer(state_fingerprint).await;
                                Err(e)
                            }
                        }
                    }
                    Ok(Err(err)) => {
                        // The spawned task failed to join (JoinError);
                        // surface it as an external error.
                        *initializer_lock = Some(Err(VegaFusionError::internal(err.to_string())));
                        self.remove_initializer(state_fingerprint).await;
                        Err(err).external("tokio error")
                    }
                    Err(payload) => {
                        // Panic while awaiting the join handle; re-raise.
                        *initializer_lock = Some(Err(VegaFusionError::internal("Panic error")));

                        self.remove_initializer(state_fingerprint).await;
                        resume_unwind(payload);
                    }
                }
            }
        }
    }
404}
405
#[cfg(test)]
mod test_cache {
    use crate::task_graph::cache::{NodeValue, VegaFusionCache};
    use tokio::time::Duration;
    use vegafusion_common::data::scalar::ScalarValue;
    use vegafusion_common::error::Result;
    use vegafusion_core::task_graph::task_value::TaskValue;

    /// Build a scalar NodeValue after a 1s artificial delay, simulating a
    /// slow task computation.
    async fn make_value(value: ScalarValue) -> Result<NodeValue> {
        tokio::time::sleep(Duration::from_millis(1000)).await;
        Ok((TaskValue::Scalar(value), Vec::new()))
    }

    /// Smoke test: three concurrent requests for the same fingerprint
    /// should share one initializer rather than computing three times.
    /// Prints state for manual inspection; no assertions.
    #[tokio::test(flavor = "multi_thread")]
    async fn try_cache() {
        let cache = VegaFusionCache::new(Some(4), None);

        let value_future1 = cache.get_or_try_insert_with(1, make_value(ScalarValue::from(23.5)));
        let value_future2 = cache.get_or_try_insert_with(1, make_value(ScalarValue::from(23.5)));
        let value_future3 = cache.get_or_try_insert_with(1, make_value(ScalarValue::from(23.5)));

        // NOTE(review): the futures above are lazy and have not been polled
        // yet, so the initializers map is presumably still empty here —
        // confirm if this sleep/print was meant to observe an in-flight state.
        tokio::time::sleep(Duration::from_millis(100)).await;
        println!("{:?}", cache.initializers);

        let futures = vec![value_future1, value_future2];
        let values = futures::future::join_all(futures).await;

        // Third request resolves from cache / shared initializer.
        let next_value = value_future3.await;

        println!("{:?}", cache.initializers);
        println!("values: {values:?}");
        println!("next_value: {next_value:?}");
    }
}
446}