runmat_vm/indexing/
selectors.rs1use crate::indexing::plan::total_len_from_shape;
2use crate::interpreter::errors::mex;
3use runmat_builtins::Value;
4use runmat_runtime::{
5 builtins::common::shape::is_scalar_shape, dispatcher::gather_if_needed_async, RuntimeError,
6};
7
/// Result type returned by the indexing-selector helpers in this module; failures are
/// reported as `runmat_runtime::RuntimeError`.
pub type VmResult<T> = Result<T, RuntimeError>;
9
/// One resolved index position of a slicing expression.
///
/// Index values stored in the variants are 1-based.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SliceSelector {
    /// Every element along the dimension (`:`).
    Colon,
    /// A single 1-based index.
    Scalar(usize),
    /// An explicit list of 1-based indices; may be empty (e.g. an all-false mask).
    Indices(Vec<usize>),
    /// 1-based linear indices plus the shape of the index tensor they were read from.
    LinearIndices {
        /// The 1-based linear indices, in the index tensor's element order.
        values: Vec<usize>,
        /// Shape copied from the originating index tensor (presumably the result
        /// shape of the linear slice — confirm against the consumer).
        output_shape: Vec<usize>,
    },
}
20
21fn index_scalar_from_host_value(value: &Value) -> Option<i64> {
22 match value {
23 Value::Num(n) => Some(*n as i64),
24 Value::Int(int_val) => Some(int_val.to_i64()),
25 Value::Tensor(t) if t.data.len() == 1 && is_scalar_shape(&t.shape) => {
26 Some(t.data[0] as i64)
27 }
28 _ => None,
29 }
30}
31
32pub async fn index_scalar_from_value(value: &Value) -> VmResult<Option<i64>> {
33 if let Value::GpuTensor(handle) = value {
34 let total = total_len_from_shape(&handle.shape);
35 if total != 1 {
36 return Ok(None);
37 }
38 let gathered = gather_if_needed_async(value).await?;
39 return Ok(index_scalar_from_host_value(&gathered));
40 }
41 Ok(index_scalar_from_host_value(value))
42}
43
44pub async fn materialize_index_value(value: &Value) -> VmResult<Value> {
45 if matches!(value, Value::GpuTensor(_)) {
46 return gather_if_needed_async(value)
47 .await
48 .map_err(|e| mex("IndexGather", &format!("Failed to gather index value: {e}")));
49 }
50 Ok(value.clone())
51}
52
53pub async fn indices_from_value_linear(value: &Value, total_len: usize) -> VmResult<Vec<usize>> {
54 if let Value::Bool(b) = value {
55 return Ok(if *b { vec![1] } else { Vec::new() });
56 }
57 if let Value::LogicalArray(la) = value {
58 if la.data.len() == 1 && is_scalar_shape(&la.shape) {
59 return Ok(if la.data[0] != 0 { vec![1] } else { Vec::new() });
60 }
61 }
62 if let Some(idx_val) = index_scalar_from_value(value).await? {
63 if idx_val < 1 || (idx_val as usize) > total_len {
64 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
65 }
66 return Ok(vec![idx_val as usize]);
67 }
68 let materialized;
69 let value = if matches!(value, Value::GpuTensor(_)) {
70 materialized = materialize_index_value(value).await?;
71 &materialized
72 } else {
73 value
74 };
75 match value {
76 Value::Tensor(idx_t) => {
77 let len = idx_t.shape.iter().product::<usize>();
78 let mut indices = Vec::with_capacity(len);
79 for &val in &idx_t.data {
80 let idx = val as isize;
81 if idx < 1 || (idx as usize) > total_len {
82 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
83 }
84 indices.push(idx as usize);
85 }
86 Ok(indices)
87 }
88 Value::LogicalArray(la) => {
89 if la.data.len() != total_len {
90 return Err(mex(
91 "IndexShape",
92 "Logical mask length mismatch for linear indexing",
93 ));
94 }
95 let mut indices = Vec::new();
96 for (i, &b) in la.data.iter().enumerate() {
97 if b != 0 {
98 indices.push(i + 1);
99 }
100 }
101 Ok(indices)
102 }
103 _ => Err(mex(
104 "UnsupportedIndexType",
105 "Unsupported index type for linear indexing",
106 )),
107 }
108}
109
110pub async fn selector_from_value_dim(value: &Value, dim_len: usize) -> VmResult<SliceSelector> {
111 if let Value::Bool(b) = value {
112 if *b {
113 return Ok(SliceSelector::Indices(vec![1]));
114 }
115 return Ok(SliceSelector::Indices(Vec::new()));
116 }
117 if let Value::LogicalArray(la) = value {
118 if la.data.len() == 1 && is_scalar_shape(&la.shape) {
119 if la.data[0] != 0 {
120 return Ok(SliceSelector::Indices(vec![1]));
121 }
122 return Ok(SliceSelector::Indices(Vec::new()));
123 }
124 }
125 if let Some(idx_val) = index_scalar_from_value(value).await? {
126 if idx_val < 1 || (idx_val as usize) > dim_len {
127 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
128 }
129 return Ok(SliceSelector::Scalar(idx_val as usize));
130 }
131 let materialized;
132 let value = if matches!(value, Value::GpuTensor(_)) {
133 materialized = materialize_index_value(value).await?;
134 &materialized
135 } else {
136 value
137 };
138 match value {
139 Value::Tensor(idx_t) => {
140 let len = idx_t.shape.iter().product::<usize>();
141 let mut indices = Vec::with_capacity(len);
142 for &val in &idx_t.data {
143 let idx = val as isize;
144 if idx < 1 || (idx as usize) > dim_len {
145 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
146 }
147 indices.push(idx as usize);
148 }
149 Ok(SliceSelector::Indices(indices))
150 }
151 Value::LogicalArray(la) => {
152 if la.data.len() != dim_len {
153 return Err(mex(
154 "IndexShape",
155 "Logical mask length mismatch for dimension",
156 ));
157 }
158 let mut indices = Vec::new();
159 for (i, &b) in la.data.iter().enumerate() {
160 if b != 0 {
161 indices.push(i + 1);
162 }
163 }
164 Ok(SliceSelector::Indices(indices))
165 }
166 _ => Err(mex(
167 "UnsupportedIndexType",
168 "Unsupported index type for slicing",
169 )),
170 }
171}
172
173pub async fn build_slice_selectors(
174 dims: usize,
175 colon_mask: u32,
176 end_mask: u32,
177 numeric: &[Value],
178 base_shape: &[usize],
179) -> VmResult<Vec<SliceSelector>> {
180 let mut selectors = Vec::with_capacity(dims);
181 if dims == 1 {
182 let total_len = total_len_from_shape(base_shape);
183 if (colon_mask & 1u32) != 0 {
184 selectors.push(SliceSelector::Indices((1..=total_len).collect()));
185 return Ok(selectors);
186 }
187 if (end_mask & 1u32) != 0 {
188 selectors.push(SliceSelector::Scalar(total_len.max(1)));
189 return Ok(selectors);
190 }
191 let value = numeric.first().ok_or_else(|| {
192 mex(
193 "MissingNumericIndex",
194 "missing numeric index for linear slice",
195 )
196 })?;
197 let materialized = materialize_index_value(value).await?;
198 if let Value::Tensor(idx_t) = &materialized {
199 let len = idx_t.shape.iter().product::<usize>();
200 let mut indices = Vec::with_capacity(len);
201 for &val in &idx_t.data {
202 let idx = val as isize;
203 if idx < 1 || (idx as usize) > total_len {
204 return Err(mex("IndexOutOfBounds", "Index out of bounds"));
205 }
206 indices.push(idx as usize);
207 }
208 selectors.push(SliceSelector::LinearIndices {
209 values: indices,
210 output_shape: idx_t.shape.clone(),
211 });
212 } else {
213 let idxs = indices_from_value_linear(&materialized, total_len).await?;
214 selectors.push(SliceSelector::Indices(idxs));
215 }
216 return Ok(selectors);
217 }
218
219 let mut numeric_iter = 0usize;
220 for d in 0..dims {
221 let is_colon = (colon_mask & (1u32 << d)) != 0;
222 if is_colon {
223 selectors.push(SliceSelector::Colon);
224 continue;
225 }
226 let dim_len = base_shape.get(d).copied().unwrap_or(1);
227 let is_end = (end_mask & (1u32 << d)) != 0;
228 if is_end {
229 selectors.push(SliceSelector::Scalar(dim_len));
230 continue;
231 }
232 let value = numeric
233 .get(numeric_iter)
234 .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index for slice"))?;
235 numeric_iter += 1;
236 selectors.push(selector_from_value_dim(value, dim_len).await?);
237 }
238 Ok(selectors)
239}