Skip to main content

vortex_array/arrays/shared/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::sync::Arc;
6
7use async_lock::Mutex;
8use async_lock::MutexGuard;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::IntoArray;
15use crate::stats::ArrayStats;
16
17#[derive(Debug, Clone)]
18pub struct SharedArray {
19    pub(super) state: Arc<Mutex<SharedState>>,
20    pub(super) dtype: DType,
21    pub(super) stats: ArrayStats,
22}
23
24#[derive(Debug, Clone)]
25pub(super) enum SharedState {
26    Source(ArrayRef),
27    Cached(Canonical),
28}
29
30impl SharedArray {
31    /// Creates a new `SharedArray` wrapping the given source array.
32    pub fn new(source: ArrayRef) -> Self {
33        Self {
34            dtype: source.dtype().clone(),
35            state: Arc::new(Mutex::new(SharedState::Source(source))),
36            stats: ArrayStats::default(),
37        }
38    }
39
40    #[cfg(not(target_family = "wasm"))]
41    fn lock_sync(&self) -> MutexGuard<'_, SharedState> {
42        self.state.lock_blocking()
43    }
44
45    #[cfg(target_family = "wasm")]
46    fn lock_sync(&self) -> MutexGuard<'_, SharedState> {
47        // this should mirror how parking_lot compiles to wasm
48        self.state
49            .try_lock()
50            .expect("SharedArray: mutex contention on single-threaded wasm target")
51    }
52
53    pub fn get_or_compute(
54        &self,
55        f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
56    ) -> VortexResult<Canonical> {
57        let mut state = self.lock_sync();
58        match &*state {
59            SharedState::Cached(canonical) => Ok(canonical.clone()),
60            SharedState::Source(source) => {
61                let canonical = f(source)?;
62                *state = SharedState::Cached(canonical.clone());
63                Ok(canonical)
64            }
65        }
66    }
67
68    pub async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<Canonical>
69    where
70        F: FnOnce(ArrayRef) -> Fut,
71        Fut: Future<Output = VortexResult<Canonical>>,
72    {
73        let mut state = self.state.lock().await;
74        match &*state {
75            SharedState::Cached(canonical) => Ok(canonical.clone()),
76            SharedState::Source(source) => {
77                let source = source.clone();
78                let canonical = f(source).await?;
79                *state = SharedState::Cached(canonical.clone());
80                Ok(canonical)
81            }
82        }
83    }
84
85    pub(super) fn current_array_ref(&self) -> ArrayRef {
86        let state = self.lock_sync();
87        match &*state {
88            SharedState::Source(source) => source.clone(),
89            SharedState::Cached(canonical) => canonical.clone().into_array(),
90        }
91    }
92
93    pub(super) fn set_source(&mut self, source: ArrayRef) {
94        self.dtype = source.dtype().clone();
95        *self.lock_sync() = SharedState::Source(source);
96    }
97}