vortex_array/arrays/shared/
array.rs1use std::future::Future;
5use std::sync::Arc;
6use std::sync::OnceLock;
7
8use async_lock::Mutex as AsyncMutex;
9use vortex_error::SharedVortexResult;
10use vortex_error::VortexResult;
11
12use crate::ArrayRef;
13use crate::Canonical;
14use crate::IntoArray;
15use crate::dtype::DType;
16use crate::stats::ArrayStats;
17
18#[derive(Debug, Clone)]
23pub struct SharedArray {
24 source: ArrayRef,
25 cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
26 async_compute_lock: Arc<AsyncMutex<()>>,
27 pub(super) dtype: DType,
28 pub(super) stats: ArrayStats,
29}
30
31impl SharedArray {
32 pub fn new(source: ArrayRef) -> Self {
33 Self {
34 dtype: source.dtype().clone(),
35 source,
36 cached: Arc::new(OnceLock::new()),
37 async_compute_lock: Arc::new(AsyncMutex::new(())),
38 stats: ArrayStats::default(),
39 }
40 }
41
42 pub(super) fn current_array_ref(&self) -> &ArrayRef {
47 match self.cached.get() {
48 Some(Ok(arr)) => arr,
49 _ => &self.source,
50 }
51 }
52
53 pub fn get_or_compute(
57 &self,
58 f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
59 ) -> VortexResult<ArrayRef> {
60 let result = self
61 .cached
62 .get_or_init(|| f(&self.source).map(|c| c.into_array()).map_err(Arc::new));
63 result.clone().map_err(Into::into)
64 }
65
66 pub async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
68 where
69 F: FnOnce(ArrayRef) -> Fut,
70 Fut: Future<Output = VortexResult<Canonical>>,
71 {
72 if let Some(result) = self.cached.get() {
74 return result.clone().map_err(Into::into);
75 }
76
77 let _guard = self.async_compute_lock.lock().await;
79
80 if let Some(result) = self.cached.get() {
82 return result.clone().map_err(Into::into);
83 }
84
85 let computed = f(self.source.clone())
86 .await
87 .map(|c| c.into_array())
88 .map_err(Arc::new);
89
90 let result = self.cached.get_or_init(|| computed);
91 result.clone().map_err(Into::into)
92 }
93
94 pub(super) fn set_source(&mut self, source: ArrayRef) {
95 self.dtype = source.dtype().clone();
96 self.source = source;
97 self.cached = Arc::new(OnceLock::new());
98 self.async_compute_lock = Arc::new(AsyncMutex::new(()));
99 }
100}