// vegafusion_runtime/task_graph/cache.rs

1use async_lock::{Mutex, MutexGuard, RwLock};
2use futures::FutureExt;
3use lru::LruCache;
4
5use cfg_if::cfg_if;
6use std::collections::HashMap;
7use std::future::Future;
8use std::panic::{resume_unwind, AssertUnwindSafe};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11use vegafusion_core::error::{DuplicateResult, Result, VegaFusionError};
12use vegafusion_core::task_graph::task_value::TaskValue;
13
14#[cfg(not(target_arch = "wasm32"))]
15use {std::time::Instant, vegafusion_core::error::ToExternalError};
16
17#[derive(Debug, Clone)]
18struct CachedValue {
19    value: NodeValue,
20    _calculation_millis: Option<u128>,
21}
22
23impl CachedValue {
24    pub fn size_of(&self) -> usize {
25        self.value.0.size_of() + self.value.1.iter().map(|v| v.size_of()).sum::<usize>()
26    }
27}
28
29type NodeValue = (TaskValue, Vec<TaskValue>);
30type Initializer = Arc<RwLock<Option<Result<NodeValue>>>>;
31
32/// The VegaFusionCache uses a Segmented LRU (SLRU) cache policy
33/// (https://en.wikipedia.org/wiki/Cache_replacement_policies#Segmented_LRU_(SLRU)) where both the
34/// protected and probationary LRU caches are limited by capacity (number of entries) and memory
35/// limit.
36#[derive(Debug, Clone)]
37pub struct VegaFusionCache {
38    protected_cache: Arc<Mutex<LruCache<u64, CachedValue>>>,
39    probationary_cache: Arc<Mutex<LruCache<u64, CachedValue>>>,
40    protected_fraction: f64,
41    initializers: Arc<RwLock<HashMap<u64, Initializer>>>,
42    size: Arc<AtomicUsize>,
43    protected_memory: Arc<AtomicUsize>,
44    probationary_memory: Arc<AtomicUsize>,
45    capacity: Option<usize>,
46    memory_limit: Option<usize>,
47}
48
49impl VegaFusionCache {
50    pub fn new(capacity: Option<usize>, size_limit: Option<usize>) -> Self {
51        Self {
52            protected_cache: Arc::new(Mutex::new(LruCache::unbounded())),
53            probationary_cache: Arc::new(Mutex::new(LruCache::unbounded())),
54            protected_fraction: 0.5,
55            initializers: Default::default(),
56            capacity,
57            memory_limit: size_limit,
58            size: Arc::new(AtomicUsize::new(0)),
59            protected_memory: Arc::new(AtomicUsize::new(0)),
60            probationary_memory: Arc::new(AtomicUsize::new(0)),
61        }
62    }
63
64    pub fn capacity(&self) -> Option<usize> {
65        self.capacity
66    }
67
68    pub fn memory_limit(&self) -> Option<usize> {
69        self.memory_limit
70    }
71
72    pub fn size(&self) -> usize {
73        self.size.load(Ordering::Relaxed)
74    }
75
76    pub fn total_memory(&self) -> usize {
77        self.protected_memory() + self.probationary_memory()
78    }
79
80    pub fn protected_memory(&self) -> usize {
81        self.protected_memory.load(Ordering::Relaxed)
82    }
83
84    pub fn probationary_memory(&self) -> usize {
85        self.probationary_memory.load(Ordering::Relaxed)
86    }
87
88    pub async fn clear(&self) {
89        // Clear the values cache. There may still be initializers representing in progress
90        // futures which will not be cleared.
91        self.protected_cache.lock().await.clear();
92        self.probationary_cache.lock().await.clear();
93        self.protected_memory.store(0, Ordering::Relaxed);
94        self.probationary_memory.store(0, Ordering::Relaxed);
95        self.size.store(0, Ordering::Relaxed);
96    }
97
98    async fn get(&self, state_fingerprint: u64) -> Option<CachedValue> {
99        let mut protected = self.protected_cache.lock().await;
100        let mut probationary = self.probationary_cache.lock().await;
101
102        if protected.contains(&state_fingerprint) {
103            protected.get(&state_fingerprint).cloned()
104        } else if probationary.contains(&state_fingerprint) {
105            // Promote entry from probationary to protected
106            let value = probationary.pop(&state_fingerprint).unwrap();
107            let value_memory = value.size_of();
108            protected.put(state_fingerprint, value.clone());
109
110            self.protected_memory
111                .fetch_add(value_memory, Ordering::Relaxed);
112            self.probationary_memory
113                .fetch_sub(value_memory, Ordering::Relaxed);
114
115            // Balance caches
116            self.balance(&mut protected, &mut probationary);
117
118            Some(value)
119        } else {
120            None
121        }
122    }
123
124    fn pop_protected_lru(
125        &self,
126        protected: &mut MutexGuard<LruCache<u64, CachedValue>>,
127        probationary: &mut MutexGuard<LruCache<u64, CachedValue>>,
128    ) {
129        // Remove one protected LRU entry
130        let (key, popped_value) = protected.pop_lru().unwrap();
131        let popped_memory = popped_value.size_of();
132
133        // Decrement protected memory
134        self.protected_memory
135            .fetch_sub(popped_memory, Ordering::Relaxed);
136
137        // Add entry to probationary cache
138        probationary.put(key, popped_value);
139
140        // Increment probationary memory
141        self.probationary_memory
142            .fetch_add(popped_memory, Ordering::Relaxed);
143    }
144
145    fn pop_probationary_lru(&self, probationary: &mut MutexGuard<LruCache<u64, CachedValue>>) {
146        let (_, popped_value) = probationary.pop_lru().unwrap();
147        let popped_memory = popped_value.size_of();
148
149        // Decrement protected memory
150        self.probationary_memory
151            .fetch_sub(popped_memory, Ordering::Relaxed);
152    }
153
154    fn balance(
155        &self,
156        protected: &mut MutexGuard<LruCache<u64, CachedValue>>,
157        probationary: &mut MutexGuard<LruCache<u64, CachedValue>>,
158    ) {
159        // Compute capacity and memory limits for both protected and probationary caches
160        let (protected_capacity, probationary_capacity) = if let Some(capacity) = self.capacity {
161            let protected_capacity = (capacity as f64 * self.protected_fraction).ceil() as usize;
162            (
163                Some(protected_capacity),
164                Some(capacity - protected_capacity),
165            )
166        } else {
167            (None, None)
168        };
169
170        let (protected_mem_limit, probationary_mem_limit) =
171            if let Some(memory_limit) = self.memory_limit {
172                let protected_mem_limit =
173                    (memory_limit as f64 * self.protected_fraction).ceil() as usize;
174                (
175                    Some(protected_mem_limit),
176                    Some(memory_limit - protected_mem_limit),
177                )
178            } else {
179                (None, None)
180            };
181
182        // Step 1: Shrink protected cache until it satisfies limits, moving evicted items to
183        //         probationary cache
184        // Pop to capacity limit
185        if let Some(capacity) = protected_capacity {
186            while protected.len() > 1 && protected.len() > capacity {
187                self.pop_protected_lru(protected, probationary);
188            }
189        }
190
191        // Pop LRU to memory limit
192        if let Some(memory_limit) = protected_mem_limit {
193            while protected.len() > 1
194                && self.protected_memory.load(Ordering::Relaxed) > memory_limit
195            {
196                self.pop_protected_lru(protected, probationary);
197            }
198        }
199
200        // Step 2: Shrink probationary cache until it satisfies limits,
201        //         decrementing memory estimate
202        if let Some(capacity) = probationary_capacity {
203            while probationary.len() > 1 && probationary.len() > capacity {
204                self.pop_probationary_lru(probationary);
205            }
206        }
207
208        // Pop LRU to memory limit
209        if let Some(memory_limit) = probationary_mem_limit {
210            while probationary.len() > 1
211                && self.probationary_memory.load(Ordering::Relaxed) > memory_limit
212            {
213                self.pop_probationary_lru(probationary);
214            }
215        }
216
217        // Step 3: Update size atomics
218        self.size
219            .store(protected.len() + probationary.len(), Ordering::Relaxed);
220    }
221
222    async fn set_value(
223        &self,
224        state_fingerprint: u64,
225        value: NodeValue,
226        calculation_millis: Option<u128>,
227    ) {
228        let cache_value = CachedValue {
229            value,
230            _calculation_millis: calculation_millis,
231        };
232        let value_memory = cache_value.size_of();
233
234        let mut protected = self.protected_cache.lock().await;
235        let mut probationary = self.probationary_cache.lock().await;
236        if protected.contains(&state_fingerprint) {
237            // Set on protected to update usage
238            protected.put(state_fingerprint, cache_value);
239        } else if probationary.contains(&state_fingerprint) {
240            // Promote from probationary to protected
241            protected.put(
242                state_fingerprint,
243                probationary.pop(&state_fingerprint).unwrap(),
244            );
245            self.protected_memory
246                .fetch_add(value_memory, Ordering::Relaxed);
247            self.probationary_memory
248                .fetch_sub(value_memory, Ordering::Relaxed);
249            self.balance(&mut protected, &mut probationary);
250        } else {
251            // Add to probationary and update memory usage
252            probationary.put(state_fingerprint, cache_value);
253            self.probationary_memory
254                .fetch_add(value_memory, Ordering::Relaxed);
255            self.balance(&mut protected, &mut probationary);
256        }
257    }
258
259    async fn remove_initializer(&self, state_fingerprint: u64) -> Option<Initializer> {
260        self.initializers.write().await.remove(&state_fingerprint)
261    }
262
263    pub async fn get_or_try_insert_with<F>(
264        &self,
265        state_fingerprint: u64,
266        init: F,
267    ) -> Result<NodeValue>
268    where
269        F: Future<Output = Result<NodeValue>> + Send + 'static,
270    {
271        // Check if present in the values cache
272        if let Some(value) = self.get(state_fingerprint).await {
273            return Ok(value.value);
274        }
275
276        // Check if present in initializers
277        // let mut initializers_lock = self.initializers.write().await;
278        let initializer = {
279            self.initializers
280                .write()
281                .await
282                .get(&state_fingerprint)
283                .cloned()
284        };
285
286        if let Some(initializer) = initializer {
287            // Calculation is in progress, await on Arc clone of it's initializer
288            // Drop lock on initializers collection
289            let result = initializer.read().await;
290            let result = match result.as_ref() {
291                None => self.spawn_initializer(state_fingerprint, init).await,
292                Some(result) => result.duplicate(),
293            };
294            result
295        } else {
296            self.spawn_initializer(state_fingerprint, init).await
297        }
298    }
299
300    async fn spawn_initializer<F>(&self, state_fingerprint: u64, init: F) -> Result<NodeValue>
301    where
302        F: Future<Output = Result<NodeValue>> + Send + 'static,
303    {
304        // Create new initializer
305        let initializer: Initializer = Arc::new(RwLock::new(None));
306
307        // Get and hold write lock for initializer
308        let mut initializer_lock = initializer.write().await;
309
310        // Store Arc clone of initializer in initializers map
311        self.initializers
312            .write()
313            .await
314            .insert(state_fingerprint, initializer.clone());
315
316        // Invoke future to initialize
317        cfg_if! {
318            if #[cfg(target_arch = "wasm32")] {
319                // In WASM we await the future directly since multi-threading with tokio::spawn
320                // is not available.
321                match AssertUnwindSafe(init).catch_unwind().await {
322                    // Resolved.
323                    Ok(value) => {
324                        // If result Ok, clone to values
325                        match value {
326                            Ok(value) => {
327                                *initializer_lock = Some(Ok(value.clone()));
328                                self.set_value(state_fingerprint, value.clone(), None).await;
329
330                                // Stored initializer no longer required. Initializers are Arc
331                                // pointers, so it's fine to drop initializer from here even if
332                                // other tasks are still awaiting on it.
333                                self.remove_initializer(state_fingerprint).await;
334                                Ok(value)
335                            }
336                            Err(e) => {
337                                // Remove initializer so that another future can try again
338                                *initializer_lock = Some(Err(e.duplicate()));
339                                self.remove_initializer(state_fingerprint).await;
340                                Err(e)
341                            }
342                        }
343                    }
344                    // Panicked.
345                    Err(payload) => {
346                        *initializer_lock = Some(Err(VegaFusionError::internal("Panic error")));
347
348                        // Remove the waiter so that others can retry.
349                        self.remove_initializer(state_fingerprint).await;
350                        // triggers panic, so no return value in this branch
351                        resume_unwind(payload);
352                    }
353                }
354            } else {
355                // When not in WASM, use tokio::spawn for multi-threading
356                let start = Instant::now();
357                match AssertUnwindSafe(tokio::spawn(init)).catch_unwind().await {
358                    // Resolved.
359                    Ok(Ok(value)) => {
360                        // If result Ok, clone to values
361                        match value {
362                            Ok(value) => {
363                                *initializer_lock = Some(Ok(value.clone()));
364
365                                // Check if we should add value to long-term cache
366                                let duration = start.elapsed();
367                                let millis = duration.as_millis();
368
369                                self.set_value(state_fingerprint, value.clone(), Some(millis))
370                                    .await;
371
372                                // Stored initializer no longer required. Initializers are Arc
373                                // pointers, so it's fine to drop initializer from here even if
374                                // other tasks are still awaiting on it.
375                                self.remove_initializer(state_fingerprint).await;
376                                Ok(value)
377                            }
378                            Err(e) => {
379                                // Remove initializer so that another future can try again
380                                *initializer_lock = Some(Err(e.duplicate()));
381                                self.remove_initializer(state_fingerprint).await;
382                                Err(e)
383                            }
384                        }
385                    }
386                    Ok(Err(err)) => {
387                        *initializer_lock = Some(Err(VegaFusionError::internal(err.to_string())));
388                        self.remove_initializer(state_fingerprint).await;
389                        Err(err).external("tokio error")
390                    }
391                    // Panicked.
392                    Err(payload) => {
393                        *initializer_lock = Some(Err(VegaFusionError::internal("Panic error")));
394
395                        // Remove the waiter so that others can retry.
396                        self.remove_initializer(state_fingerprint).await;
397                        // triggers panic, so no return value in this branch
398                        resume_unwind(payload);
399                    }
400                }
401            }
402        }
403    }
404}
405
#[cfg(test)]
mod test_cache {
    use crate::task_graph::cache::{NodeValue, VegaFusionCache};
    use tokio::time::Duration;
    use vegafusion_common::data::scalar::ScalarValue;
    use vegafusion_common::error::Result;
    use vegafusion_core::task_graph::task_value::TaskValue;

    /// Simulates a slow calculation by sleeping before returning a scalar node value.
    async fn make_value(value: ScalarValue) -> Result<NodeValue> {
        tokio::time::sleep(Duration::from_millis(1000)).await;
        Ok((TaskValue::Scalar(value), Vec::new()))
    }

    /// Asserts that `result` is `Ok` with the scalar `expected` and no extra values.
    fn assert_scalar(result: &Result<NodeValue>, expected: &ScalarValue) {
        match result {
            Ok((TaskValue::Scalar(actual), extras)) => {
                assert_eq!(actual, expected);
                assert!(extras.is_empty());
            }
            other => panic!("unexpected cache result: {other:?}"),
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn try_cache() {
        let cache = VegaFusionCache::new(Some(4), None);
        let expected = ScalarValue::from(23.5);

        let value_future1 = cache.get_or_try_insert_with(1, make_value(expected.clone()));
        let value_future2 = cache.get_or_try_insert_with(1, make_value(expected.clone()));
        let value_future3 = cache.get_or_try_insert_with(1, make_value(expected.clone()));

        // Concurrent requests for the same fingerprint share one calculation
        let values = futures::future::join_all(vec![value_future1, value_future2]).await;
        assert_eq!(values.len(), 2);
        for value in &values {
            assert_scalar(value, &expected);
        }

        // A later request is served from the values cache
        let next_value = value_future3.await;
        assert_scalar(&next_value, &expected);

        // Exactly one entry cached, and no calculation left registered
        assert_eq!(cache.size(), 1);
        assert!(cache.initializers.read().await.is_empty());
    }
}