Skip to main content

vortex_array/arrays/shared/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::future::Future;
5use std::sync::Arc;
6use std::sync::OnceLock;
7
8use async_lock::Mutex as AsyncMutex;
9use vortex_error::SharedVortexResult;
10use vortex_error::VortexResult;
11
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::IntoArray;
15use crate::dtype::DType;
16use crate::stats::ArrayStats;
17
18/// A lazily-executing array wrapper with a one-way transition from source to cached form.
19///
20/// Before materialization, operations delegate to the source array.
21/// After materialization (via `get_or_compute`), operations delegate to the cached result.
22#[derive(Debug, Clone)]
23pub struct SharedArray {
24    source: ArrayRef,
25    cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
26    async_compute_lock: Arc<AsyncMutex<()>>,
27    pub(super) dtype: DType,
28    pub(super) stats: ArrayStats,
29}
30
31impl SharedArray {
32    pub fn new(source: ArrayRef) -> Self {
33        Self {
34            dtype: source.dtype().clone(),
35            source,
36            cached: Arc::new(OnceLock::new()),
37            async_compute_lock: Arc::new(AsyncMutex::new(())),
38            stats: ArrayStats::default(),
39        }
40    }
41
42    /// Returns the current array reference.
43    ///
44    /// After materialization, returns the cached result. Otherwise, returns the source.
45    /// If materialization failed, falls back to the source.
46    pub(super) fn current_array_ref(&self) -> &ArrayRef {
47        match self.cached.get() {
48            Some(Ok(arr)) => arr,
49            _ => &self.source,
50        }
51    }
52
53    /// Compute and cache the result. The computation runs exactly once via `OnceLock`.
54    ///
55    /// If the computation fails, the error is cached and returned on all subsequent calls.
56    pub fn get_or_compute(
57        &self,
58        f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
59    ) -> VortexResult<ArrayRef> {
60        let result = self
61            .cached
62            .get_or_init(|| f(&self.source).map(|c| c.into_array()).map_err(Arc::new));
63        result.clone().map_err(Into::into)
64    }
65
66    /// Async version of `get_or_compute`.
67    pub async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
68    where
69        F: FnOnce(ArrayRef) -> Fut,
70        Fut: Future<Output = VortexResult<Canonical>>,
71    {
72        // Fast path: already computed.
73        if let Some(result) = self.cached.get() {
74            return result.clone().map_err(Into::into);
75        }
76
77        // Serialize async computation to prevent redundant work.
78        let _guard = self.async_compute_lock.lock().await;
79
80        // Double-check after acquiring the lock.
81        if let Some(result) = self.cached.get() {
82            return result.clone().map_err(Into::into);
83        }
84
85        let computed = f(self.source.clone())
86            .await
87            .map(|c| c.into_array())
88            .map_err(Arc::new);
89
90        let result = self.cached.get_or_init(|| computed);
91        result.clone().map_err(Into::into)
92    }
93
94    pub(super) fn set_source(&mut self, source: ArrayRef) {
95        self.dtype = source.dtype().clone();
96        self.source = source;
97        self.cached = Arc::new(OnceLock::new());
98        self.async_compute_lock = Arc::new(AsyncMutex::new(()));
99    }
100}