vortex_array/arrays/shared/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use async_lock::Mutex as AsyncMutex;
11use smallvec::smallvec;
12use vortex_error::SharedVortexResult;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15
16use crate::ArrayRef;
17use crate::Canonical;
18use crate::IntoArray;
19use crate::array::Array;
20use crate::array::ArrayParts;
21use crate::array::TypedArrayRef;
22use crate::arrays::Shared;
23
24pub(super) const SOURCE_SLOT: usize = 0;
26pub(super) const NUM_SLOTS: usize = 1;
27pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["source"];
28
/// Metadata for a [`Shared`] array: a lazily populated cache of the array that
/// results from canonicalizing the source, plus a lock that serializes async
/// population of that cache.
///
/// Both fields are behind [`Arc`], so the derived [`Clone`] yields a value that
/// shares the same cache and lock with the original rather than fresh ones.
#[derive(Debug, Clone)]
pub struct SharedData {
    // Write-once result of the computation. Errors are stored as well (as a
    // `SharedVortexResult`), so a failed computation is not retried.
    cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
    // Held for the duration of an async computation so that concurrent async
    // callers wait instead of computing in parallel. The synchronous path does
    // not take this lock.
    async_compute_lock: Arc<AsyncMutex<()>>,
}
38
39impl Display for SharedData {
40 fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
41 Ok(())
42 }
43}
44
45#[expect(async_fn_in_trait)]
46pub trait SharedArrayExt: TypedArrayRef<Shared> {
47 fn source(&self) -> &ArrayRef {
48 self.as_ref().slots()[SOURCE_SLOT]
49 .as_ref()
50 .vortex_expect("validated shared source slot")
51 }
52
53 fn current_array_ref(&self) -> &ArrayRef {
54 match self.cached.get() {
55 Some(Ok(arr)) => arr,
56 _ => self.source(),
57 }
58 }
59
60 fn get_or_compute(
61 &self,
62 f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
63 ) -> VortexResult<ArrayRef> {
64 let result = self
65 .cached
66 .get_or_init(|| f(self.source()).map(|c| c.into_array()).map_err(Arc::new));
67 result.clone().map_err(Into::into)
68 }
69
70 async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
71 where
72 F: FnOnce(ArrayRef) -> Fut,
73 Fut: Future<Output = VortexResult<Canonical>>,
74 {
75 if let Some(result) = self.cached.get() {
76 return result.clone().map_err(Into::into);
77 }
78
79 let _guard = self.async_compute_lock.lock().await;
80
81 if let Some(result) = self.cached.get() {
82 return result.clone().map_err(Into::into);
83 }
84
85 let computed = f(self.source().clone())
86 .await
87 .map(|c| c.into_array())
88 .map_err(Arc::new);
89
90 let result = self.cached.get_or_init(|| computed);
91 result.clone().map_err(Into::into)
92 }
93}
94impl<T: TypedArrayRef<Shared>> SharedArrayExt for T {}
95
96impl SharedData {
97 pub fn new() -> Self {
98 Self {
99 cached: Arc::new(OnceLock::new()),
100 async_compute_lock: Arc::new(AsyncMutex::new(())),
101 }
102 }
103}
104
105impl Default for SharedData {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111impl Array<Shared> {
112 pub fn new(source: ArrayRef) -> Self {
114 let dtype = source.dtype().clone();
115 let len = source.len();
116 unsafe {
117 Array::from_parts_unchecked(
118 ArrayParts::new(Shared, dtype, len, SharedData::new())
119 .with_slots(smallvec![Some(source)]),
120 )
121 }
122 }
123}