Skip to main content

vortex_array/arrays/shared/
array.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use async_lock::Mutex as AsyncMutex;
11use smallvec::smallvec;
12use vortex_error::SharedVortexResult;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::IntoArray;
19use crate::array::Array;
20use crate::array::ArrayParts;
21use crate::array::TypedArrayRef;
22use crate::arrays::Shared;
23
24/// The source array that is shared and lazily computed.
25pub(super) const SOURCE_SLOT: usize = 0;
26pub(super) const NUM_SLOTS: usize = 1;
27pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["source"];
28
29/// A lazily-executing array wrapper with a one-way transition from source to cached form.
30///
31/// Before materialization, operations delegate to the source array.
32/// After materialization (via `get_or_compute`), operations delegate to the cached result.
33#[derive(Debug, Clone)]
34pub struct SharedData {
35    cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
36    async_compute_lock: Arc<AsyncMutex<()>>,
37}
38
39impl Display for SharedData {
40    fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
41        Ok(())
42    }
43}
44
45#[expect(async_fn_in_trait)]
46pub trait SharedArrayExt: TypedArrayRef<Shared> {
47    fn source(&self) -> &ArrayRef {
48        self.as_ref().slots()[SOURCE_SLOT]
49            .as_ref()
50            .vortex_expect("validated shared source slot")
51    }
52
53    fn current_array_ref(&self) -> &ArrayRef {
54        match self.cached.get() {
55            Some(Ok(arr)) => arr,
56            _ => self.source(),
57        }
58    }
59
60    fn get_or_compute(
61        &self,
62        f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
63    ) -> VortexResult<ArrayRef> {
64        let result = self
65            .cached
66            .get_or_init(|| f(self.source()).map(|c| c.into_array()).map_err(Arc::new));
67        result.clone().map_err(Into::into)
68    }
69
70    async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
71    where
72        F: FnOnce(ArrayRef) -> Fut,
73        Fut: Future<Output = VortexResult<Canonical>>,
74    {
75        if let Some(result) = self.cached.get() {
76            return result.clone().map_err(Into::into);
77        }
78
79        let _guard = self.async_compute_lock.lock().await;
80
81        if let Some(result) = self.cached.get() {
82            return result.clone().map_err(Into::into);
83        }
84
85        let computed = f(self.source().clone())
86            .await
87            .map(|c| c.into_array())
88            .map_err(Arc::new);
89
90        let result = self.cached.get_or_init(|| computed);
91        result.clone().map_err(Into::into)
92    }
93}
94impl<T: TypedArrayRef<Shared>> SharedArrayExt for T {}
95
96impl SharedData {
97    pub fn new() -> Self {
98        Self {
99            cached: Arc::new(OnceLock::new()),
100            async_compute_lock: Arc::new(AsyncMutex::new(())),
101        }
102    }
103}
104
105impl Default for SharedData {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl Array<Shared> {
112    /// Creates a new `SharedArray`.
113    pub fn new(source: ArrayRef) -> Self {
114        let dtype = source.dtype().clone();
115        let len = source.len();
116        unsafe {
117            Array::from_parts_unchecked(
118                ArrayParts::new(Shared, dtype, len, SharedData::new())
119                    .with_slots(smallvec![Some(source)]),
120            )
121        }
122    }
123}