1use crate::indexing::plan::IndexPlan;
2use crate::indexing::selectors::SliceSelector;
3use crate::interpreter::errors::mex;
4use runmat_builtins::{CellArray, ComplexTensor, StringArray, Tensor, Value};
5use runmat_runtime::RuntimeError;
6
7pub fn build_subsasgn_paren_cell(numeric: &[Value]) -> Result<Value, RuntimeError> {
8 let cell = CellArray::new(numeric.to_vec(), 1, numeric.len())
9 .map_err(|e| format!("subsasgn build error: {e}"))?;
10 Ok(Value::Cell(cell))
11}
12
13pub async fn object_subsasgn_paren(
14 base: Value,
15 numeric: &[Value],
16 rhs: Value,
17) -> Result<Value, RuntimeError> {
18 let cell = build_subsasgn_paren_cell(numeric)?;
19 match base {
20 Value::Object(obj) => {
21 let args = vec![
22 Value::Object(obj),
23 Value::String("subsasgn".to_string()),
24 Value::String("()".to_string()),
25 cell,
26 rhs,
27 ];
28 runmat_runtime::call_builtin_async("call_method", &args).await
29 }
30 Value::HandleObject(handle) => {
31 let args = vec![
32 Value::HandleObject(handle),
33 Value::String("subsasgn".to_string()),
34 Value::String("()".to_string()),
35 cell,
36 rhs,
37 ];
38 runmat_runtime::call_builtin_async("call_method", &args).await
39 }
40 other => Err(format!("slice subsasgn requires object/handle, got {other:?}").into()),
41 }
42}
43
/// Right-hand side of a complex slice assignment, prepared by
/// `build_complex_rhs_view` and consumed by `scatter_complex_with_plan`.
pub enum ComplexRhsView {
    /// A single `(re, im)` pair. `build_complex_rhs_view` produces it from a
    /// complex scalar, or from a real scalar with a zero imaginary part.
    Scalar((f64, f64)),
    /// A complex tensor whose shape has been fitted to the selection rank.
    Tensor {
        /// Element data as `(re, im)` pairs, copied from the source tensor.
        data: Vec<(f64, f64)>,
        /// Extent per selection dimension; an extent of 1 is read at
        /// position 0 for every destination along that dimension.
        shape: Vec<usize>,
        /// Offset multiplier per dimension: the running product of the
        /// preceding extents in `shape`, starting at 1.
        strides: Vec<usize>,
    },
}
52
53pub fn build_complex_rhs_view(
54 rhs: &Value,
55 selection_lengths: &[usize],
56) -> Result<ComplexRhsView, RuntimeError> {
57 match rhs {
58 Value::Complex(re, im) => Ok(ComplexRhsView::Scalar((*re, *im))),
59 Value::Num(n) => Ok(ComplexRhsView::Scalar((*n, 0.0))),
60 Value::ComplexTensor(rt) => {
61 let dims = selection_lengths.len();
62 let mut shape = rt.shape.clone();
63 if shape.len() < dims {
64 shape.resize(dims, 1);
65 }
66 if shape.len() > dims {
67 if shape.iter().skip(dims).any(|&s| s != 1) {
68 return Err("shape mismatch for slice assign".to_string().into());
69 }
70 shape.truncate(dims);
71 }
72 for d in 0..dims {
73 let out_len = selection_lengths[d];
74 let rhs_len = shape[d];
75 if !(rhs_len == 1 || rhs_len == out_len) {
76 return Err("shape mismatch for slice assign".to_string().into());
77 }
78 }
79 let mut rstrides = vec![0usize; dims];
80 let mut racc = 1usize;
81 for d in 0..dims {
82 rstrides[d] = racc;
83 racc *= shape[d];
84 }
85 Ok(ComplexRhsView::Tensor {
86 data: rt.data.clone(),
87 shape,
88 strides: rstrides,
89 })
90 }
91 _ => Err("rhs must be numeric or tensor".to_string().into()),
92 }
93}
94
95pub fn scatter_complex_with_plan(
96 t: &mut ComplexTensor,
97 plan: &IndexPlan,
98 rhs_view: &ComplexRhsView,
99) -> Result<(), RuntimeError> {
100 let dims = plan.dims;
101 let mut idx = vec![0usize; dims];
102 if plan.indices.is_empty() {
103 return Ok(());
104 }
105 let selection_lengths = if plan.selection_lengths.is_empty() {
106 plan.output_shape.clone()
107 } else {
108 plan.selection_lengths.clone()
109 };
110 loop {
111 let mut rlin = 0usize;
112 match rhs_view {
113 ComplexRhsView::Scalar(val) => {
114 let pos = plan.indices[rlin] as usize;
115 t.data[pos] = *val;
116 }
117 ComplexRhsView::Tensor {
118 data,
119 shape,
120 strides,
121 } => {
122 for d in 0..dims {
123 let rhs_len = shape[d];
124 let pos = if rhs_len == 1 { 0 } else { idx[d] };
125 rlin += pos * strides[d];
126 }
127 let lin_pos = {
128 let mut p = 0usize;
129 let mut mul = 1usize;
130 for d in 0..dims {
131 p += idx[d] * mul;
132 mul *= selection_lengths[d].max(1);
133 }
134 p
135 };
136 let dst = plan.indices[lin_pos] as usize;
137 t.data[dst] = data[rlin];
138 }
139 }
140 let mut d = 0usize;
141 while d < dims {
142 idx[d] += 1;
143 if idx[d] < selection_lengths[d].max(1) {
144 break;
145 }
146 idx[d] = 0;
147 d += 1;
148 }
149 if d == dims {
150 break;
151 }
152 }
153 Ok(())
154}
155
/// Right-hand side of a string-array slice assignment, prepared by
/// `build_string_rhs_view` and consumed by `scatter_string_with_plan`.
pub enum StringRhsView {
    /// One string that is cloned into every selected destination.
    Scalar(String),
    /// A string array whose shape has been fitted to the selection rank.
    Tensor {
        /// Element data, copied from the source string array.
        data: Vec<String>,
        /// Extent per selection dimension; an extent of 1 is read at
        /// position 0 for every destination along that dimension.
        shape: Vec<usize>,
        /// Offset multiplier per dimension: the running product of the
        /// preceding extents in `shape`, starting at 1.
        strides: Vec<usize>,
    },
}
164
165pub fn build_string_rhs_view(
166 rhs: &Value,
167 selection_lengths: &[usize],
168) -> Result<StringRhsView, RuntimeError> {
169 let scalar = match rhs {
170 Value::String(s) => Some(s.clone()),
171 Value::CharArray(ca) => Some(ca.to_string()),
172 _ => None,
173 };
174 if let Some(s) = scalar {
175 return Ok(StringRhsView::Scalar(s));
176 }
177 if let Value::StringArray(rt) = rhs {
178 let dims = selection_lengths.len();
179 let mut shape = rt.shape.clone();
180 if shape.len() < dims {
181 shape.resize(dims, 1);
182 }
183 if shape.len() > dims {
184 if shape.iter().skip(dims).any(|&s| s != 1) {
185 return Err("shape mismatch for slice assign".to_string().into());
186 }
187 shape.truncate(dims);
188 }
189 for d in 0..dims {
190 let out_len = selection_lengths[d];
191 let rhs_len = shape[d];
192 if !(rhs_len == 1 || rhs_len == out_len) {
193 return Err("shape mismatch for slice assign".to_string().into());
194 }
195 }
196 let mut rstrides = vec![0usize; dims];
197 let mut racc = 1usize;
198 for d in 0..dims {
199 rstrides[d] = racc;
200 racc *= shape[d];
201 }
202 return Ok(StringRhsView::Tensor {
203 data: rt.data.clone(),
204 shape,
205 strides: rstrides,
206 });
207 }
208 Err("rhs must be string or string array".to_string().into())
209}
210
211pub fn scatter_string_with_plan(
212 sa: &mut StringArray,
213 plan: &IndexPlan,
214 rhs_view: &StringRhsView,
215) -> Result<(), RuntimeError> {
216 let dims = plan.dims;
217 let mut idx = vec![0usize; dims];
218 if plan.indices.is_empty() {
219 return Ok(());
220 }
221 let selection_lengths = if plan.selection_lengths.is_empty() {
222 plan.output_shape.clone()
223 } else {
224 plan.selection_lengths.clone()
225 };
226 loop {
227 match rhs_view {
228 StringRhsView::Scalar(val) => {
229 let lin_pos = {
230 let mut p = 0usize;
231 let mut mul = 1usize;
232 for d in 0..dims {
233 p += idx[d] * mul;
234 mul *= selection_lengths[d].max(1);
235 }
236 p
237 };
238 let dst = plan.indices[lin_pos] as usize;
239 sa.data[dst] = val.clone();
240 }
241 StringRhsView::Tensor {
242 data,
243 shape,
244 strides,
245 } => {
246 let mut rlin = 0usize;
247 for d in 0..dims {
248 let rhs_len = shape[d];
249 let pos = if rhs_len == 1 { 0 } else { idx[d] };
250 rlin += pos * strides[d];
251 }
252 let lin_pos = {
253 let mut p = 0usize;
254 let mut mul = 1usize;
255 for d in 0..dims {
256 p += idx[d] * mul;
257 mul *= selection_lengths[d].max(1);
258 }
259 p
260 };
261 let dst = plan.indices[lin_pos] as usize;
262 sa.data[dst] = data[rlin].clone();
263 }
264 }
265 let mut d = 0usize;
266 while d < dims {
267 idx[d] += 1;
268 if idx[d] < selection_lengths[d].max(1) {
269 break;
270 }
271 idx[d] = 0;
272 d += 1;
273 }
274 if d == dims {
275 break;
276 }
277 }
278 Ok(())
279}
280
281pub async fn materialize_rhs_real_for_plan(
282 rhs: &Value,
283 plan: &IndexPlan,
284) -> Result<Vec<f64>, RuntimeError> {
285 if plan.dims == 1 {
286 let count = plan.selection_lengths.first().copied().unwrap_or(0);
287 materialize_rhs_linear_real(rhs, count).await
288 } else {
289 materialize_rhs_nd_real(rhs, &plan.selection_lengths).await
290 }
291}
292
293pub fn scatter_real_with_plan(
294 t: &mut Tensor,
295 plan: &IndexPlan,
296 rhs_values: &[f64],
297) -> Result<(), RuntimeError> {
298 if rhs_values.len() != plan.indices.len() {
299 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
300 }
301 for (&dst, &value) in plan.indices.iter().zip(rhs_values.iter()) {
302 t.data[dst as usize] = value;
303 }
304 Ok(())
305}
306
307pub async fn assign_tensor_with_plan(
308 mut t: Tensor,
309 plan: &IndexPlan,
310 rhs: &Value,
311) -> Result<Value, RuntimeError> {
312 if plan.indices.is_empty() {
313 return Ok(Value::Tensor(t));
314 }
315 let rhs_values = materialize_rhs_real_for_plan(rhs, plan).await?;
316 scatter_real_with_plan(&mut t, plan, &rhs_values)?;
317 Ok(Value::Tensor(t))
318}
319
320pub async fn assign_gpu_slice_with_plan(
321 handle: &runmat_accelerate_api::GpuTensorHandle,
322 plan: &IndexPlan,
323 rhs: &Value,
324) -> Result<Value, RuntimeError> {
325 if plan.indices.is_empty() {
326 return Ok(Value::GpuTensor(handle.clone()));
327 }
328 let provider = runmat_accelerate_api::provider().ok_or_else(|| {
329 mex(
330 "AccelerationProviderUnavailable",
331 "No acceleration provider registered",
332 )
333 })?;
334 if let Value::GpuTensor(vh) = rhs {
335 let rows = plan.base_shape.first().copied().unwrap_or(1);
336 let cols = plan.base_shape.get(1).copied().unwrap_or(1);
337 if let Some(col) = plan.properties.full_column {
338 if col < cols {
339 let v_rows = match vh.shape.len() {
340 1 | 2 => vh.shape[0],
341 _ => 0,
342 };
343 if v_rows == rows {
344 if let Ok(new_h) = provider.scatter_column(handle, col, vh) {
345 return Ok(Value::GpuTensor(new_h));
346 }
347 }
348 }
349 }
350 if let Some(row) = plan.properties.full_row {
351 if row < rows {
352 let v_cols = match vh.shape.len() {
353 1 => vh.shape[0],
354 2 => vh.shape[1],
355 _ => 0,
356 };
357 if v_cols == cols {
358 if let Ok(new_h) = provider.scatter_row(handle, row, vh) {
359 return Ok(Value::GpuTensor(new_h));
360 }
361 }
362 }
363 }
364 }
365 let rhs_values = materialize_rhs_real_for_plan(rhs, plan).await?;
366 let value_shape = vec![rhs_values.len().max(1), 1];
367 let upload_result = if rhs_values.is_empty() {
368 provider.zeros(&[0, 1])
369 } else {
370 provider.upload(&runmat_accelerate_api::HostTensorView {
371 data: &rhs_values,
372 shape: &value_shape,
373 })
374 };
375 if let Ok(values_handle) = upload_result {
376 if provider
377 .scatter_linear(handle, &plan.indices, &values_handle)
378 .is_ok()
379 {
380 return Ok(Value::GpuTensor(handle.clone()));
381 }
382 }
383
384 let host = provider
385 .download(handle)
386 .await
387 .map_err(|e| format!("gather for slice assign: {e}"))?;
388 let mut t = Tensor::new(host.data, host.shape).map_err(|e| format!("slice assign: {e}"))?;
389 scatter_real_with_plan(&mut t, plan, &rhs_values)?;
390 upload_tensor_to_gpu(&t)
391}
392
393pub async fn materialize_rhs_linear_real(
394 rhs: &Value,
395 count: usize,
396) -> Result<Vec<f64>, RuntimeError> {
397 let host_rhs = runmat_runtime::dispatcher::gather_if_needed_async(rhs).await?;
398 match host_rhs {
399 Value::Num(n) => Ok(vec![n; count]),
400 Value::Int(int_val) => Ok(vec![int_val.to_f64(); count]),
401 Value::Bool(b) => Ok(vec![if b { 1.0 } else { 0.0 }; count]),
402 Value::Tensor(t) => {
403 if t.data.len() == count {
404 Ok(t.data)
405 } else if t.data.len() == 1 {
406 Ok(vec![t.data[0]; count])
407 } else {
408 Err(mex("ShapeMismatch", "shape mismatch for slice assign"))
409 }
410 }
411 Value::LogicalArray(la) => {
412 if la.data.len() == count {
413 Ok(la
414 .data
415 .into_iter()
416 .map(|b| if b != 0 { 1.0 } else { 0.0 })
417 .collect())
418 } else if la.data.len() == 1 {
419 let val = if la.data[0] != 0 { 1.0 } else { 0.0 };
420 Ok(vec![val; count])
421 } else {
422 Err(mex("ShapeMismatch", "shape mismatch for slice assign"))
423 }
424 }
425 other => Err(mex(
426 "InvalidSliceAssignmentRhs",
427 &format!("slice assign: unsupported RHS type {:?}", other),
428 )),
429 }
430}
431
432pub async fn materialize_rhs_nd_real(
433 rhs: &Value,
434 selection_lengths: &[usize],
435) -> Result<Vec<f64>, RuntimeError> {
436 let rhs_host = runmat_runtime::dispatcher::gather_if_needed_async(rhs).await?;
437 enum RhsView {
438 Scalar(f64),
439 Tensor {
440 data: Vec<f64>,
441 shape: Vec<usize>,
442 strides: Vec<usize>,
443 },
444 }
445 let view = match rhs_host {
446 Value::Num(n) => RhsView::Scalar(n),
447 Value::Int(iv) => RhsView::Scalar(iv.to_f64()),
448 Value::Bool(b) => RhsView::Scalar(if b { 1.0 } else { 0.0 }),
449 Value::Tensor(t) => {
450 let mut shape = t.shape.clone();
451 if shape.len() < selection_lengths.len() {
452 shape.resize(selection_lengths.len(), 1);
453 }
454 if shape.len() > selection_lengths.len() {
455 if shape.iter().skip(selection_lengths.len()).any(|&s| s != 1) {
456 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
457 }
458 shape.truncate(selection_lengths.len());
459 }
460 for (dim_len, &sel_len) in shape.iter().zip(selection_lengths.iter()) {
461 if *dim_len != 1 && *dim_len != sel_len {
462 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
463 }
464 }
465 let mut strides = vec![1usize; selection_lengths.len()];
466 for d in 1..selection_lengths.len() {
467 strides[d] = strides[d - 1] * shape[d - 1].max(1);
468 }
469 if t.data.len()
470 != shape
471 .iter()
472 .copied()
473 .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)))
474 {
475 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
476 }
477 RhsView::Tensor {
478 data: t.data,
479 shape,
480 strides,
481 }
482 }
483 Value::LogicalArray(la) => {
484 if la.shape.len() > selection_lengths.len()
485 && la
486 .shape
487 .iter()
488 .skip(selection_lengths.len())
489 .any(|&s| s != 1)
490 {
491 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
492 }
493 let mut shape = la.shape.clone();
494 if shape.len() < selection_lengths.len() {
495 shape.resize(selection_lengths.len(), 1);
496 } else {
497 shape.truncate(selection_lengths.len());
498 }
499 for (dim_len, &sel_len) in shape.iter().zip(selection_lengths.iter()) {
500 if *dim_len != 1 && *dim_len != sel_len {
501 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
502 }
503 }
504 let mut strides = vec![1usize; selection_lengths.len()];
505 for d in 1..selection_lengths.len() {
506 strides[d] = strides[d - 1] * shape[d - 1].max(1);
507 }
508 if la.data.len()
509 != shape
510 .iter()
511 .copied()
512 .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)))
513 {
514 return Err(mex("ShapeMismatch", "shape mismatch for slice assign"));
515 }
516 let data: Vec<f64> = la
517 .data
518 .into_iter()
519 .map(|b| if b != 0 { 1.0 } else { 0.0 })
520 .collect();
521 RhsView::Tensor {
522 data,
523 shape,
524 strides,
525 }
526 }
527 other => {
528 return Err(mex(
529 "InvalidSliceAssignmentRhs",
530 &format!("slice assign: unsupported RHS type {:?}", other),
531 ))
532 }
533 };
534
535 let total = selection_lengths
536 .iter()
537 .copied()
538 .fold(1usize, |acc, len| acc.saturating_mul(len.max(1)));
539 let mut out = Vec::with_capacity(total);
540 let mut idx = vec![0usize; selection_lengths.len()];
541 if selection_lengths.is_empty() {
542 return Ok(out);
543 }
544 loop {
545 match &view {
546 RhsView::Scalar(val) => out.push(*val),
547 RhsView::Tensor {
548 data,
549 shape,
550 strides,
551 } => {
552 let mut rlin = 0usize;
553 for d in 0..idx.len() {
554 let rhs_len = shape[d];
555 let pos = if rhs_len == 1 { 0 } else { idx[d] };
556 rlin += pos * strides[d];
557 }
558 out.push(data.get(rlin).copied().unwrap_or(0.0));
559 }
560 }
561 let mut d = 0usize;
562 while d < idx.len() {
563 idx[d] += 1;
564 if idx[d] < selection_lengths[d].max(1) {
565 break;
566 }
567 idx[d] = 0;
568 d += 1;
569 }
570 if d == idx.len() {
571 break;
572 }
573 }
574 Ok(out)
575}
576
577pub fn upload_tensor_to_gpu(t: &Tensor) -> Result<Value, RuntimeError> {
578 let provider = runmat_accelerate_api::provider().ok_or_else(|| {
579 mex(
580 "AccelerationProviderUnavailable",
581 "No acceleration provider registered",
582 )
583 })?;
584 let view = runmat_accelerate_api::HostTensorView {
585 data: &t.data,
586 shape: &t.shape,
587 };
588 let new_h = provider
589 .upload(&view)
590 .map_err(|e| format!("reupload after slice assign: {e}"))?;
591 Ok(Value::GpuTensor(new_h))
592}
593
/// Borrowed description of an indexing expression, consumed by
/// `build_expr_selectors` to produce one selector per dimension.
pub struct ExprSelectorSpec<'a> {
    /// Number of index positions; one selector is produced for each.
    pub dims: usize,
    /// Bit `d` set means position `d` is a bare colon.
    pub colon_mask: u32,
    /// Bit `d` set means position `d` is a bare `end`, i.e. the extent of
    /// that dimension taken from `shape`.
    pub end_mask: u32,
    /// Positions that are written as ranges. A position listed here takes
    /// precedence over both masks.
    pub range_dims: &'a [usize],
    /// Raw `(start, step)` for each range, used when the matching entry in
    /// `range_start_exprs` / `range_step_exprs` is `None`.
    pub range_params: &'a [(f64, f64)],
    /// Optional expression that overrides the raw start of each range.
    pub range_start_exprs: &'a [Option<crate::bytecode::EndExpr>],
    /// Optional expression that overrides the raw step of each range.
    pub range_step_exprs: &'a [Option<crate::bytecode::EndExpr>],
    /// End expression of each range, looked up by the range's position in
    /// `range_dims`.
    pub range_end_exprs: &'a [crate::bytecode::EndExpr],
    /// Index values for the remaining positions, consumed in order.
    pub numeric: &'a [Value],
    /// Extent per dimension of the value being indexed; a missing entry is
    /// read as 1.
    pub shape: &'a [usize],
}
606
607pub async fn build_expr_selectors<ResolveEnd, Fut>(
608 spec: ExprSelectorSpec<'_>,
609 mut resolve_end: ResolveEnd,
610) -> Result<Vec<SliceSelector>, RuntimeError>
611where
612 ResolveEnd: FnMut(usize, &crate::bytecode::EndExpr) -> Fut,
613 Fut: std::future::Future<Output = Result<i64, RuntimeError>>,
614{
615 let mut selectors: Vec<SliceSelector> = Vec::with_capacity(spec.dims);
616 let mut num_iter = 0usize;
617 let mut rp_iter = 0usize;
618 for d in 0..spec.dims {
619 if let Some(pos) = spec.range_dims.iter().position(|&rd| rd == d) {
620 let (raw_st, raw_sp) = spec.range_params[rp_iter];
621 let dim_len = *spec.shape.get(d).unwrap_or(&1);
622 let st = if let Some(expr) = &spec.range_start_exprs[rp_iter] {
623 resolve_end(dim_len, expr).await? as f64
624 } else {
625 raw_st
626 };
627 let sp = if let Some(expr) = &spec.range_step_exprs[rp_iter] {
628 resolve_end(dim_len, expr).await? as f64
629 } else {
630 raw_sp
631 };
632 rp_iter += 1;
633 let step_i = if sp >= 0.0 {
634 sp as i64
635 } else {
636 -(sp.abs() as i64)
637 };
638 let end_i = resolve_end(dim_len, &spec.range_end_exprs[pos]).await?;
639 if step_i == 0 {
640 return Err(mex("IndexStepZero", "Index step cannot be zero"));
641 }
642 let mut vals = Vec::new();
643 let mut cur = st as i64;
644 if step_i > 0 {
645 while cur <= end_i {
646 if cur < 1 || cur > dim_len as i64 {
647 break;
648 }
649 vals.push(cur as usize);
650 cur += step_i;
651 }
652 } else {
653 while cur >= end_i {
654 if cur < 1 || cur > dim_len as i64 {
655 break;
656 }
657 vals.push(cur as usize);
658 cur += step_i;
659 }
660 }
661 selectors.push(SliceSelector::Indices(vals));
662 continue;
663 }
664 let is_colon = (spec.colon_mask & (1u32 << d)) != 0;
665 let is_end = (spec.end_mask & (1u32 << d)) != 0;
666 if is_colon {
667 selectors.push(SliceSelector::Colon);
668 } else if is_end {
669 selectors.push(SliceSelector::Scalar(*spec.shape.get(d).unwrap_or(&1)));
670 } else {
671 let v = spec
672 .numeric
673 .get(num_iter)
674 .ok_or_else(|| mex("MissingNumericIndex", "missing numeric index"))?;
675 num_iter += 1;
676 let dim_len = *spec.shape.get(d).unwrap_or(&1);
677 selectors.push(
678 match crate::indexing::selectors::selector_from_value_dim(v, dim_len).await? {
679 SliceSelector::LinearIndices { values, .. } => SliceSelector::Indices(values),
680 other => other,
681 },
682 );
683 }
684 }
685 Ok(selectors)
686}