Skip to main content

vortex_array/arrays/shared/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use async_lock::Mutex as AsyncMutex;
11use vortex_error::SharedVortexResult;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::IntoArray;
18use crate::array::Array;
19use crate::array::ArrayParts;
20use crate::array::TypedArrayRef;
21use crate::arrays::Shared;
22
23/// The source array that is shared and lazily computed.
24pub(super) const SOURCE_SLOT: usize = 0;
25pub(super) const NUM_SLOTS: usize = 1;
26pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["source"];
27
28/// A lazily-executing array wrapper with a one-way transition from source to cached form.
29///
30/// Before materialization, operations delegate to the source array.
31/// After materialization (via `get_or_compute`), operations delegate to the cached result.
32#[derive(Debug, Clone)]
33pub struct SharedData {
34    cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
35    async_compute_lock: Arc<AsyncMutex<()>>,
36}
37
38impl Display for SharedData {
39    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
40        Ok(())
41    }
42}
43
44#[allow(async_fn_in_trait)]
45pub trait SharedArrayExt: TypedArrayRef<Shared> {
46    fn source(&self) -> &ArrayRef {
47        self.as_ref().slots()[SOURCE_SLOT]
48            .as_ref()
49            .vortex_expect("validated shared source slot")
50    }
51
52    fn current_array_ref(&self) -> &ArrayRef {
53        match self.cached.get() {
54            Some(Ok(arr)) => arr,
55            _ => self.source(),
56        }
57    }
58
59    fn get_or_compute(
60        &self,
61        f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
62    ) -> VortexResult<ArrayRef> {
63        let result = self
64            .cached
65            .get_or_init(|| f(self.source()).map(|c| c.into_array()).map_err(Arc::new));
66        result.clone().map_err(Into::into)
67    }
68
69    async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
70    where
71        F: FnOnce(ArrayRef) -> Fut,
72        Fut: Future<Output = VortexResult<Canonical>>,
73    {
74        if let Some(result) = self.cached.get() {
75            return result.clone().map_err(Into::into);
76        }
77
78        let _guard = self.async_compute_lock.lock().await;
79
80        if let Some(result) = self.cached.get() {
81            return result.clone().map_err(Into::into);
82        }
83
84        let computed = f(self.source().clone())
85            .await
86            .map(|c| c.into_array())
87            .map_err(Arc::new);
88
89        let result = self.cached.get_or_init(|| computed);
90        result.clone().map_err(Into::into)
91    }
92}
93impl<T: TypedArrayRef<Shared>> SharedArrayExt for T {}
94
95impl SharedData {
96    pub fn new() -> Self {
97        Self {
98            cached: Arc::new(OnceLock::new()),
99            async_compute_lock: Arc::new(AsyncMutex::new(())),
100        }
101    }
102}
103
104impl Default for SharedData {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl Array<Shared> {
111    /// Creates a new `SharedArray`.
112    pub fn new(source: ArrayRef) -> Self {
113        let dtype = source.dtype().clone();
114        let len = source.len();
115        unsafe {
116            Array::from_parts_unchecked(
117                ArrayParts::new(Shared, dtype, len, SharedData::new())
118                    .with_slots(vec![Some(source)]),
119            )
120        }
121    }
122}