vortex_array/arrays/shared/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use async_lock::Mutex as AsyncMutex;
11use vortex_error::SharedVortexResult;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::IntoArray;
18use crate::array::Array;
19use crate::array::ArrayParts;
20use crate::array::TypedArrayRef;
21use crate::arrays::Shared;
22
23pub(super) const SOURCE_SLOT: usize = 0;
25pub(super) const NUM_SLOTS: usize = 1;
26pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["source"];
27
28#[derive(Debug, Clone)]
33pub struct SharedData {
34 cached: Arc<OnceLock<SharedVortexResult<ArrayRef>>>,
35 async_compute_lock: Arc<AsyncMutex<()>>,
36}
37
38impl Display for SharedData {
39 fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result {
40 Ok(())
41 }
42}
43
44#[allow(async_fn_in_trait)]
45pub trait SharedArrayExt: TypedArrayRef<Shared> {
46 fn source(&self) -> &ArrayRef {
47 self.as_ref().slots()[SOURCE_SLOT]
48 .as_ref()
49 .vortex_expect("validated shared source slot")
50 }
51
52 fn current_array_ref(&self) -> &ArrayRef {
53 match self.cached.get() {
54 Some(Ok(arr)) => arr,
55 _ => self.source(),
56 }
57 }
58
59 fn get_or_compute(
60 &self,
61 f: impl FnOnce(&ArrayRef) -> VortexResult<Canonical>,
62 ) -> VortexResult<ArrayRef> {
63 let result = self
64 .cached
65 .get_or_init(|| f(self.source()).map(|c| c.into_array()).map_err(Arc::new));
66 result.clone().map_err(Into::into)
67 }
68
69 async fn get_or_compute_async<F, Fut>(&self, f: F) -> VortexResult<ArrayRef>
70 where
71 F: FnOnce(ArrayRef) -> Fut,
72 Fut: Future<Output = VortexResult<Canonical>>,
73 {
74 if let Some(result) = self.cached.get() {
75 return result.clone().map_err(Into::into);
76 }
77
78 let _guard = self.async_compute_lock.lock().await;
79
80 if let Some(result) = self.cached.get() {
81 return result.clone().map_err(Into::into);
82 }
83
84 let computed = f(self.source().clone())
85 .await
86 .map(|c| c.into_array())
87 .map_err(Arc::new);
88
89 let result = self.cached.get_or_init(|| computed);
90 result.clone().map_err(Into::into)
91 }
92}
93impl<T: TypedArrayRef<Shared>> SharedArrayExt for T {}
94
95impl SharedData {
96 pub fn new() -> Self {
97 Self {
98 cached: Arc::new(OnceLock::new()),
99 async_compute_lock: Arc::new(AsyncMutex::new(())),
100 }
101 }
102}
103
104impl Default for SharedData {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl Array<Shared> {
111 pub fn new(source: ArrayRef) -> Self {
113 let dtype = source.dtype().clone();
114 let len = source.len();
115 unsafe {
116 Array::from_parts_unchecked(
117 ArrayParts::new(Shared, dtype, len, SharedData::new())
118 .with_slots(vec![Some(source)]),
119 )
120 }
121 }
122}