1use crate::core_ops::Tensor;
6use torsh_core::{
7 device::DeviceType,
8 dtype::TensorElement,
9 error::{Result, TorshError},
10};
11
12impl<T: TensorElement + Copy> Tensor<T> {
13 pub fn from_scalar(value: T, shape: &[usize], device: DeviceType) -> Result<Self>
15 where
16 T: Copy,
17 {
18 let numel = shape.iter().product::<usize>();
19 let data = vec![value; numel];
20 Self::from_data(data, shape.to_vec(), device)
21 }
22 pub fn fill_(&mut self, value: T) -> Result<()>
24 where
25 T: Copy,
26 {
27 for i in 0..self.numel() {
28 self.storage.set(i, value)?;
29 }
30 Ok(())
31 }
32 pub fn zero_(&mut self) -> Result<()>
34 where
35 T: Copy,
36 {
37 self.fill_(T::zero())
38 }
39 pub fn ones_(&mut self) -> Result<()>
41 where
42 T: Copy,
43 {
44 self.fill_(T::one())
45 }
46 pub fn copy_(&mut self, other: &Self) -> Result<()>
48 where
49 T: Copy,
50 {
51 if self.shape() != other.shape() {
52 return Err(TorshError::ShapeMismatch {
53 expected: self.shape().dims().to_vec(),
54 got: other.shape().dims().to_vec(),
55 });
56 }
57 let other_data = other.to_vec()?;
58 for (i, &value) in other_data.iter().enumerate() {
59 self.storage.set(i, value)?;
60 }
61 Ok(())
62 }
63 pub fn get_item(&self, indices: &[usize]) -> Result<T>
65 where
66 T: Copy,
67 {
68 if indices.len() != self.ndim() {
69 return Err(TorshError::InvalidArgument(format!(
70 "Expected {} indices, got {}",
71 self.ndim(),
72 indices.len()
73 )));
74 }
75 let binding = self.shape();
76 let shape = binding.dims();
77 for (i, &idx) in indices.iter().enumerate() {
78 if idx >= shape[i] {
79 return Err(TorshError::IndexOutOfBounds {
80 index: idx,
81 size: shape[i],
82 });
83 }
84 }
85 let flat_index = self.multi_to_flat_index(indices)?;
86 self.get_item_flat(flat_index)
87 }
88 pub fn set_item(&mut self, indices: &[usize], value: T) -> Result<()>
90 where
91 T: Copy,
92 {
93 if indices.len() != self.ndim() {
94 return Err(TorshError::InvalidArgument(format!(
95 "Expected {} indices, got {}",
96 self.ndim(),
97 indices.len()
98 )));
99 }
100 let binding = self.shape();
101 let shape = binding.dims();
102 for (i, &idx) in indices.iter().enumerate() {
103 if idx >= shape[i] {
104 return Err(TorshError::IndexOutOfBounds {
105 index: idx,
106 size: shape[i],
107 });
108 }
109 }
110 let flat_index = self.multi_to_flat_index(indices)?;
111 self.set_item_flat(flat_index, value)
112 }
113 pub fn get_item_flat(&self, index: usize) -> Result<T>
115 where
116 T: Copy,
117 {
118 if index >= self.numel() {
119 return Err(TorshError::IndexOutOfBounds {
120 index,
121 size: self.numel(),
122 });
123 }
124 self.storage.get(index)
125 }
126 pub fn set_item_flat(&mut self, index: usize, value: T) -> Result<()>
128 where
129 T: Copy,
130 {
131 if index >= self.numel() {
132 return Err(TorshError::IndexOutOfBounds {
133 index,
134 size: self.numel(),
135 });
136 }
137 self.storage.set(index, value)
138 }
139 pub fn multi_to_flat_index(&self, indices: &[usize]) -> Result<usize> {
141 let binding = self.shape();
142 let shape = binding.dims();
143 if indices.len() != shape.len() {
144 return Err(TorshError::InvalidArgument(format!(
145 "Expected {} indices, got {}",
146 shape.len(),
147 indices.len()
148 )));
149 }
150 let mut flat_index = 0;
151 let mut stride = 1;
152 for i in (0..indices.len()).rev() {
153 flat_index += indices[i] * stride;
154 stride *= shape[i];
155 }
156 Ok(flat_index)
157 }
158 pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
160 if dim >= self.ndim() {
161 return Err(TorshError::InvalidArgument(format!(
162 "Dimension {} out of range for tensor with {} dimensions",
163 dim,
164 self.ndim()
165 )));
166 }
167 let self_data = self.to_vec()?;
168 let indices_data = indices.to_vec()?;
169 let mut result_data = Vec::new();
170 let result_shape = indices.shape().dims().to_vec();
171 if self.ndim() == 1 {
172 for &index in &indices_data {
173 let idx = if index < 0 {
174 (self.shape().dims()[0] as i64 + index) as usize
175 } else {
176 index as usize
177 };
178 if idx >= self.shape().dims()[0] {
179 return Err(TorshError::InvalidArgument(format!(
180 "Index {} out of range for tensor with size {}",
181 index,
182 self.shape().dims()[0]
183 )));
184 }
185 result_data.push(self_data[idx]);
186 }
187 } else {
188 let self_shape_ref = self.shape();
189 let self_shape = self_shape_ref.dims();
190 let indices_shape_ref = indices.shape();
191 let indices_shape = indices_shape_ref.dims();
192 let dim_size = self_shape[dim];
193 let mut self_strides = vec![1; self_shape.len()];
194 let mut indices_strides = vec![1; indices_shape.len()];
195 for i in (0..self_shape.len() - 1).rev() {
196 self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
197 }
198 for i in (0..indices_shape.len() - 1).rev() {
199 indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
200 }
201 let total_elements = indices_data.len();
202 for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
203 let mut indices_coords = vec![0; indices_shape.len()];
204 let mut temp_i = i;
205 for j in 0..indices_shape.len() {
206 indices_coords[j] = temp_i / indices_strides[j];
207 temp_i %= indices_strides[j];
208 }
209 let idx = if index_value < 0 {
210 (dim_size as i64 + index_value) as usize
211 } else {
212 index_value as usize
213 };
214 if idx >= dim_size {
215 return Err(TorshError::InvalidArgument(format!(
216 "Index {index_value} out of range for dimension {dim} with size {dim_size}"
217 )));
218 }
219 let mut self_coords = indices_coords.clone();
220 if dim < self_coords.len() {
221 self_coords[dim] = idx;
222 }
223 let mut flat_idx = 0;
224 for j in 0..self_coords.len() {
225 flat_idx += self_coords[j] * self_strides[j];
226 }
227 result_data.push(self_data[flat_idx]);
228 }
229 }
230 Self::from_data(result_data, result_shape, self.device)
231 }
232 pub fn scatter(&self, dim: usize, indices: &Tensor<i64>, src: &Tensor<T>) -> Result<Self> {
234 if dim >= self.ndim() {
235 return Err(TorshError::InvalidArgument(format!(
236 "Dimension {} out of range for tensor with {} dimensions",
237 dim,
238 self.ndim()
239 )));
240 }
241 let mut result_data = self.to_vec()?;
242 let indices_data = indices.to_vec()?;
243 let src_data = src.to_vec()?;
244 if indices_data.len() != src_data.len() {
245 return Err(TorshError::InvalidArgument(
246 "Indices and source tensor must have the same number of elements".to_string(),
247 ));
248 }
249 if self.ndim() == 1 {
250 for (i, &index) in indices_data.iter().enumerate() {
251 let idx = if index < 0 {
252 (self.shape().dims()[0] as i64 + index) as usize
253 } else {
254 index as usize
255 };
256 if idx >= self.shape().dims()[0] {
257 return Err(TorshError::InvalidArgument(format!(
258 "Index {} out of range for tensor with size {}",
259 index,
260 self.shape().dims()[0]
261 )));
262 }
263 result_data[idx] = src_data[i];
264 }
265 } else {
266 let self_shape_ref = self.shape();
267 let self_shape = self_shape_ref.dims();
268 let indices_shape_ref = indices.shape();
269 let indices_shape = indices_shape_ref.dims();
270 let dim_size = self_shape[dim];
271 let mut self_strides = vec![1; self_shape.len()];
272 let mut indices_strides = vec![1; indices_shape.len()];
273 for i in (0..self_shape.len() - 1).rev() {
274 self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
275 }
276 for i in (0..indices_shape.len() - 1).rev() {
277 indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
278 }
279 let total_elements = indices_data.len();
280 for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
281 let mut indices_coords = vec![0; indices_shape.len()];
282 let mut temp_i = i;
283 for j in 0..indices_shape.len() {
284 indices_coords[j] = temp_i / indices_strides[j];
285 temp_i %= indices_strides[j];
286 }
287 let idx = if index_value < 0 {
288 (dim_size as i64 + index_value) as usize
289 } else {
290 index_value as usize
291 };
292 if idx >= dim_size {
293 return Err(TorshError::InvalidArgument(format!(
294 "Index {index_value} out of range for dimension {dim} with size {dim_size}"
295 )));
296 }
297 let mut self_coords = indices_coords.clone();
298 if dim < self_coords.len() {
299 self_coords[dim] = idx;
300 }
301 let mut flat_idx = 0;
302 for j in 0..self_coords.len() {
303 flat_idx += self_coords[j] * self_strides[j];
304 }
305 result_data[flat_idx] = src_data[i];
306 }
307 }
308 Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
309 }
310 pub fn scatter_add(&self, dim: usize, indices: &Tensor<i64>, src: &Tensor<T>) -> Result<Self>
331 where
332 T: std::ops::Add<Output = T>,
333 {
334 if dim >= self.ndim() {
335 return Err(TorshError::InvalidArgument(format!(
336 "Dimension {} out of range for tensor with {} dimensions",
337 dim,
338 self.ndim()
339 )));
340 }
341 let mut result_data = self.to_vec()?;
342 let indices_data = indices.to_vec()?;
343 let src_data = src.to_vec()?;
344 if indices_data.len() != src_data.len() {
345 return Err(TorshError::InvalidArgument(
346 "Indices and source tensor must have the same number of elements".to_string(),
347 ));
348 }
349 if self.ndim() == 1 {
350 for (i, &index) in indices_data.iter().enumerate() {
351 let idx = if index < 0 {
352 (self.shape().dims()[0] as i64 + index) as usize
353 } else {
354 index as usize
355 };
356 if idx >= self.shape().dims()[0] {
357 return Err(TorshError::InvalidArgument(format!(
358 "Index {} out of range for tensor with size {}",
359 index,
360 self.shape().dims()[0]
361 )));
362 }
363 result_data[idx] = result_data[idx] + src_data[i];
364 }
365 } else {
366 let self_shape_ref = self.shape();
367 let self_shape = self_shape_ref.dims();
368 let indices_shape_ref = indices.shape();
369 let indices_shape = indices_shape_ref.dims();
370 let dim_size = self_shape[dim];
371 let mut self_strides = vec![1; self_shape.len()];
372 let mut indices_strides = vec![1; indices_shape.len()];
373 for i in (0..self_shape.len() - 1).rev() {
374 self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
375 }
376 for i in (0..indices_shape.len() - 1).rev() {
377 indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
378 }
379 let total_elements = indices_data.len();
380 for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
381 let mut indices_coords = vec![0; indices_shape.len()];
382 let mut temp_i = i;
383 for j in 0..indices_shape.len() {
384 indices_coords[j] = temp_i / indices_strides[j];
385 temp_i %= indices_strides[j];
386 }
387 let idx = if index_value < 0 {
388 (dim_size as i64 + index_value) as usize
389 } else {
390 index_value as usize
391 };
392 if idx >= dim_size {
393 return Err(TorshError::InvalidArgument(format!(
394 "Index {index_value} out of range for dimension {dim} with size {dim_size}"
395 )));
396 }
397 let mut self_coords = indices_coords.clone();
398 if dim < self_coords.len() {
399 self_coords[dim] = idx;
400 }
401 let mut flat_idx = 0;
402 for j in 0..self_coords.len() {
403 flat_idx += self_coords[j] * self_strides[j];
404 }
405 result_data[flat_idx] = result_data[flat_idx] + src_data[i];
406 }
407 }
408 Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
409 }
410 pub fn repeat(&self, repeats: &[usize]) -> Result<Self> {
412 if repeats.len() != self.ndim() {
413 return Err(TorshError::InvalidArgument(format!(
414 "Number of repeats {} must match tensor dimensions {}",
415 repeats.len(),
416 self.ndim()
417 )));
418 }
419 let self_data = self.to_vec()?;
420 let shape_binding = self.shape();
421 let self_shape = shape_binding.dims();
422 let new_shape: Vec<usize> = self_shape
423 .iter()
424 .zip(repeats.iter())
425 .map(|(&dim, &repeat)| dim * repeat)
426 .collect();
427 let new_numel = new_shape.iter().product();
428 let mut result_data = Vec::with_capacity(new_numel);
429 for result_idx in 0..new_numel {
430 let mut result_coords = vec![0; new_shape.len()];
431 let mut temp_idx = result_idx;
432 for i in (0..new_shape.len()).rev() {
433 result_coords[i] = temp_idx % new_shape[i];
434 temp_idx /= new_shape[i];
435 }
436 let source_coords: Vec<usize> = result_coords
437 .iter()
438 .zip(self_shape.iter())
439 .map(|(&result_coord, &dim_size)| result_coord % dim_size)
440 .collect();
441 let mut source_idx = 0;
442 let mut stride = 1;
443 for i in (0..self_shape.len()).rev() {
444 source_idx += source_coords[i] * stride;
445 stride *= self_shape[i];
446 }
447 result_data.push(self_data[source_idx]);
448 }
449 Self::from_data(result_data, new_shape, self.device)
450 }
    /// Returns a copy of `self` where, along dimension `dim` (negative values
    /// count from the end), the rows selected by `index` have the matching
    /// rows of `source` added to them.
    ///
    /// `index` must be 1-D. `source` must have the same shape as `self`
    /// except along `dim`, where its size must equal the index count.
    /// Repeated indices accumulate (each occurrence adds again). Negative
    /// values inside `index` itself are rejected.
    pub fn index_add(&self, dim: isize, index: &Tensor<i64>, source: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        let ndim = self.ndim();
        // Normalize a negative dimension; a value more negative than -ndim
        // wraps to a huge usize and is rejected by the check below.
        let dim = if dim < 0 {
            (ndim as isize + dim) as usize
        } else {
            dim as usize
        };
        if dim >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                dim, ndim
            )));
        }
        if index.ndim() != 1 {
            return Err(TorshError::InvalidArgument(
                "index must be 1D tensor".to_string(),
            ));
        }
        let index_size = index.shape().dims()[0];
        let self_shape = self.shape().to_vec();
        let source_shape = source.shape().to_vec();
        if source_shape.len() != self_shape.len() {
            return Err(TorshError::ShapeMismatch {
                expected: self_shape.clone(),
                got: source_shape.clone(),
            });
        }
        // Every axis of `source` must match `self`, except `dim`, which must
        // match the number of indices.
        for (i, (&s, &src_s)) in self_shape.iter().zip(source_shape.iter()).enumerate() {
            if i == dim {
                if src_s != index_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "source dimension {} size {} must match index size {}",
                        i, src_s, index_size
                    )));
                }
            } else if s != src_s {
                return Err(TorshError::ShapeMismatch {
                    expected: self_shape.clone(),
                    got: source_shape.clone(),
                });
            }
        }
        let mut result_data = self.to_vec()?;
        let source_data = source.to_vec()?;
        let index_data = index.to_vec()?;
        let dim_size = self_shape[dim];
        // Row-major layout: flat = outer * dim_size * inner + d * inner + off.
        let outer_size: usize = self_shape[..dim].iter().product();
        let inner_size: usize = self_shape[dim + 1..].iter().product();
        for (src_idx_in_dim, &target_idx) in index_data.iter().enumerate() {
            if target_idx < 0 || target_idx as usize >= dim_size {
                return Err(TorshError::InvalidArgument(format!(
                    "Index {} out of range for dimension size {}",
                    target_idx, dim_size
                )));
            }
            let target_idx = target_idx as usize;
            // Add the whole (outer, row, inner) slab of `source` onto the
            // corresponding slab of the result.
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let result_idx =
                        outer * dim_size * inner_size + target_idx * inner_size + inner;
                    let source_idx =
                        outer * index_size * inner_size + src_idx_in_dim * inner_size + inner;
                    result_data[result_idx] = result_data[result_idx] + source_data[source_idx];
                }
            }
        }
        Self::from_data(result_data, self_shape, self.device)
    }
    /// Returns a copy of `self` where, along dimension `dim` (negative values
    /// count from the end), the rows selected by `index` are replaced by the
    /// matching rows of `source`.
    ///
    /// `index` must be 1-D. `source` must have the same shape as `self`
    /// except along `dim`, where its size must equal the index count.
    /// When an index repeats, the last occurrence wins. Negative values
    /// inside `index` itself are rejected.
    pub fn index_copy(&self, dim: isize, index: &Tensor<i64>, source: &Self) -> Result<Self> {
        let ndim = self.ndim();
        // Normalize a negative dimension; a value more negative than -ndim
        // wraps to a huge usize and is rejected by the check below.
        let dim = if dim < 0 {
            (ndim as isize + dim) as usize
        } else {
            dim as usize
        };
        if dim >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                dim, ndim
            )));
        }
        if index.ndim() != 1 {
            return Err(TorshError::InvalidArgument(
                "index must be 1D tensor".to_string(),
            ));
        }
        let index_size = index.shape().dims()[0];
        let self_shape = self.shape().to_vec();
        let source_shape = source.shape().to_vec();
        if source_shape.len() != self_shape.len() {
            return Err(TorshError::ShapeMismatch {
                expected: self_shape.clone(),
                got: source_shape.clone(),
            });
        }
        // Every axis of `source` must match `self`, except `dim`, which must
        // match the number of indices.
        for (i, (&s, &src_s)) in self_shape.iter().zip(source_shape.iter()).enumerate() {
            if i == dim {
                if src_s != index_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "source dimension {} size {} must match index size {}",
                        i, src_s, index_size
                    )));
                }
            } else if s != src_s {
                return Err(TorshError::ShapeMismatch {
                    expected: self_shape.clone(),
                    got: source_shape.clone(),
                });
            }
        }
        let mut result_data = self.to_vec()?;
        let source_data = source.to_vec()?;
        let index_data = index.to_vec()?;
        let dim_size = self_shape[dim];
        // Row-major layout: flat = outer * dim_size * inner + d * inner + off.
        let outer_size: usize = self_shape[..dim].iter().product();
        let inner_size: usize = self_shape[dim + 1..].iter().product();
        for (src_idx_in_dim, &target_idx) in index_data.iter().enumerate() {
            if target_idx < 0 || target_idx as usize >= dim_size {
                return Err(TorshError::InvalidArgument(format!(
                    "Index {} out of range for dimension size {}",
                    target_idx, dim_size
                )));
            }
            let target_idx = target_idx as usize;
            // Overwrite the whole (outer, row, inner) slab with `source`'s slab.
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let result_idx =
                        outer * dim_size * inner_size + target_idx * inner_size + inner;
                    let source_idx =
                        outer * index_size * inner_size + src_idx_in_dim * inner_size + inner;
                    result_data[result_idx] = source_data[source_idx];
                }
            }
        }
        Self::from_data(result_data, self_shape, self.device)
    }
624 pub fn index_fill(&self, dim: isize, index: &Tensor<i64>, value: T) -> Result<Self> {
641 let ndim = self.ndim();
642 let dim = if dim < 0 {
643 (ndim as isize + dim) as usize
644 } else {
645 dim as usize
646 };
647 if dim >= ndim {
648 return Err(TorshError::InvalidArgument(format!(
649 "Dimension {} out of range for {}-D tensor",
650 dim, ndim
651 )));
652 }
653 if index.ndim() != 1 {
654 return Err(TorshError::InvalidArgument(
655 "index must be 1D tensor".to_string(),
656 ));
657 }
658 let mut result_data = self.to_vec()?;
659 let index_data = index.to_vec()?;
660 let self_shape = self.shape().to_vec();
661 let dim_size = self_shape[dim];
662 let outer_size: usize = self_shape[..dim].iter().product();
663 let inner_size: usize = self_shape[dim + 1..].iter().product();
664 for &target_idx in index_data.iter() {
665 if target_idx < 0 || target_idx as usize >= dim_size {
666 return Err(TorshError::InvalidArgument(format!(
667 "Index {} out of range for dimension size {}",
668 target_idx, dim_size
669 )));
670 }
671 let target_idx = target_idx as usize;
672 for outer in 0..outer_size {
673 for inner in 0..inner_size {
674 let result_idx =
675 outer * dim_size * inner_size + target_idx * inner_size + inner;
676 result_data[result_idx] = value;
677 }
678 }
679 }
680 Self::from_data(result_data, self_shape, self.device)
681 }
682 pub fn put_(&self, indices: &Tensor<i64>, values: &Tensor<T>) -> Result<Self> {
699 if indices.ndim() != 1 {
700 return Err(TorshError::InvalidArgument(
701 "indices must be 1D tensor".to_string(),
702 ));
703 }
704 if values.ndim() != 1 {
705 return Err(TorshError::InvalidArgument(
706 "values must be 1D tensor".to_string(),
707 ));
708 }
709 let indices_data = indices.to_vec()?;
710 let values_data = values.to_vec()?;
711 if indices_data.len() != values_data.len() {
712 return Err(TorshError::InvalidArgument(format!(
713 "Number of values {} must match number of indices {}",
714 values_data.len(),
715 indices_data.len()
716 )));
717 }
718 let mut result_data = self.to_vec()?;
719 let numel = self.numel();
720 for (i, &index) in indices_data.iter().enumerate() {
721 let idx = if index < 0 {
722 ((numel as i64) + index) as usize
723 } else {
724 index as usize
725 };
726 if idx >= numel {
727 return Err(TorshError::InvalidArgument(format!(
728 "Index {} out of range for tensor with {} elements",
729 index, numel
730 )));
731 }
732 result_data[idx] = values_data[i];
733 }
734 Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
735 }
736 pub fn masked_scatter(&self, mask: &Tensor<bool>, source: &Tensor<T>) -> Result<Self> {
761 if self.shape() != mask.shape() {
762 return Err(TorshError::ShapeMismatch {
763 expected: self.shape().dims().to_vec(),
764 got: mask.shape().dims().to_vec(),
765 });
766 }
767 let mask_data = mask.to_vec()?;
768 let true_count = mask_data.iter().filter(|&&x| x).count();
769 if source.numel() < true_count {
770 return Err(TorshError::InvalidArgument(format!(
771 "Source tensor has {} elements but need {} for scatter (mask has {} true values)",
772 source.numel(),
773 true_count,
774 true_count
775 )));
776 }
777 let self_data = self.to_vec()?;
778 let source_data = source.to_vec()?;
779 let mut result_data = Vec::with_capacity(self_data.len());
780 let mut source_idx = 0;
781 for (i, &self_val) in self_data.iter().enumerate() {
782 if i < mask_data.len() && mask_data[i] {
783 result_data.push(source_data[source_idx]);
784 source_idx += 1;
785 } else {
786 result_data.push(self_val);
787 }
788 }
789 Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
790 }
    /// Returns a copy of `self` with `values` written at the coordinate
    /// tuples described by the per-dimension index tensors.
    ///
    /// All tensors in `indices` must share one shape; element `i` of the
    /// `d`-th tensor gives the coordinate along dimension `d` for the `i`-th
    /// write. `values` must hold either one element per tuple or a single
    /// element, which is broadcast to every write. Negative coordinates
    /// count from the end of their dimension.
    ///
    /// NOTE(review): when `indices.len() < self.ndim()`, strides for the
    /// trailing dimensions are never added, so those coordinates are
    /// implicitly 0 — confirm this matches the intended semantics.
    pub fn index_put(&self, indices: &[Tensor<i64>], values: &Tensor<T>) -> Result<Self> {
        if indices.is_empty() {
            return Err(TorshError::InvalidArgument(
                "indices cannot be empty".to_string(),
            ));
        }
        if indices.len() > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Too many indices ({}) for tensor with {} dimensions",
                indices.len(),
                self.ndim()
            )));
        }
        // Every index tensor must have the shape of the first one.
        let index_shape_ref = indices[0].shape();
        let index_shape = index_shape_ref.dims();
        let num_indices = indices[0].numel();
        for idx_tensor in indices.iter() {
            if idx_tensor.shape().dims() != index_shape {
                return Err(TorshError::ShapeMismatch {
                    expected: index_shape.to_vec(),
                    got: idx_tensor.shape().dims().to_vec(),
                });
            }
        }
        if values.numel() != num_indices && values.numel() != 1 {
            return Err(TorshError::InvalidArgument(format!(
                "Values tensor has {} elements but need {} (or 1 for broadcasting)",
                values.numel(),
                num_indices
            )));
        }
        let mut result_data = self.to_vec()?;
        let self_shape_ref = self.shape();
        let self_shape = self_shape_ref.dims();
        let values_data = values.to_vec()?;
        let index_data: Result<Vec<Vec<i64>>> = indices.iter().map(|idx| idx.to_vec()).collect();
        let index_data = index_data?;
        // Row-major strides of `self`.
        let mut strides = vec![1; self_shape.len()];
        for i in (0..self_shape.len() - 1).rev() {
            strides[i] = strides[i + 1] * self_shape[i + 1];
        }
        for i in 0..num_indices {
            // A single-element `values` tensor is broadcast to every write.
            let value = if values_data.len() == 1 {
                values_data[0]
            } else {
                values_data[i]
            };
            let mut flat_idx = 0;
            for (dim, idx_vec) in index_data.iter().enumerate() {
                let mut idx = idx_vec[i];
                // Wrap a negative coordinate, then validate the result.
                if idx < 0 {
                    idx += self_shape[dim] as i64;
                }
                if idx < 0 || idx >= self_shape[dim] as i64 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of bounds for dimension {} with size {}",
                        idx_vec[i], dim, self_shape[dim]
                    )));
                }
                flat_idx += (idx as usize) * strides[dim];
            }
            result_data[flat_idx] = value;
        }
        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
    /// Returns a copy of `self` where the elements of `src` are folded into
    /// the positions selected by `indices` along dimension `dim`, using the
    /// reduction named by `reduce`: `"sum"`, `"prod"`, `"mean"`, `"amax"`,
    /// or `"amin"`.
    ///
    /// `indices` and `src` must share a shape. The existing value of `self`
    /// at each target position is the accumulator's initial value. For
    /// `"mean"`, the accumulated sum is divided by the number of scattered
    /// contributions only — the pre-existing element is added but not
    /// counted. NOTE(review): confirm this matches the intended mean
    /// semantics.
    ///
    /// NOTE(review): the multi-dimensional path indexes `coords` (length
    /// `self.ndim()`) by `src`'s axes, so it assumes `src.ndim() ==
    /// self.ndim()`; that is not validated here — confirm callers uphold it.
    pub fn scatter_reduce(
        &self,
        dim: usize,
        indices: &Tensor<i64>,
        src: &Tensor<T>,
        reduce: &str,
    ) -> Result<Self>
    where
        T: std::ops::Add<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + PartialOrd
            + num_traits::FromPrimitive,
    {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-dimensional tensor",
                dim,
                self.ndim()
            )));
        }
        if indices.shape() != src.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: indices.shape().dims().to_vec(),
                got: src.shape().dims().to_vec(),
            });
        }
        let indices_data = indices.to_vec()?;
        let src_data = src.to_vec()?;
        let mut result_data = self.to_vec()?;
        let self_shape_ref = self.shape();
        let self_shape = self_shape_ref.dims();
        // Per-position contribution counters; only allocated for "mean".
        let mut counts = if reduce == "mean" {
            vec![0usize; result_data.len()]
        } else {
            vec![]
        };
        if self.ndim() == 1 {
            for (i, &index) in indices_data.iter().enumerate() {
                // Negative indices wrap from the end of the dimension.
                let idx = if index < 0 {
                    (self_shape[0] as i64 + index) as usize
                } else {
                    index as usize
                };
                if idx >= self_shape[0] {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of bounds for dimension size {}",
                        index, self_shape[0]
                    )));
                }
                result_data[idx] = match reduce {
                    "sum" => result_data[idx] + src_data[i],
                    "prod" => result_data[idx] * src_data[i],
                    "mean" => {
                        counts[idx] += 1;
                        result_data[idx] + src_data[i]
                    }
                    "amax" => {
                        if src_data[i] > result_data[idx] {
                            src_data[i]
                        } else {
                            result_data[idx]
                        }
                    }
                    "amin" => {
                        if src_data[i] < result_data[idx] {
                            src_data[i]
                        } else {
                            result_data[idx]
                        }
                    }
                    _ => {
                        return Err(TorshError::InvalidArgument(format!(
                            "Unknown reduce operation: {}. Supported: sum, prod, mean, amax, amin",
                            reduce
                        )));
                    }
                };
            }
            if reduce == "mean" {
                // Divide accumulated sums by their contribution counts;
                // positions never scattered to keep their original value.
                // Falls back to the undivided sum if the count cannot be
                // represented in T.
                for (i, count) in counts.iter().enumerate() {
                    if *count > 0 {
                        result_data[i] = T::from_usize(*count)
                            .and_then(|c| Some(result_data[i] / c))
                            .unwrap_or(result_data[i]);
                    }
                }
            }
        } else {
            let dim_size = self_shape[dim];
            let _outer_size: usize = self_shape[..dim].iter().product();
            let _inner_size: usize = self_shape[dim + 1..].iter().product();
            // Row-major strides for `self` and `src`.
            let mut self_strides = vec![1; self_shape.len()];
            for i in (0..self_shape.len() - 1).rev() {
                self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
            }
            let src_shape_ref = src.shape();
            let src_shape = src_shape_ref.dims();
            let mut src_strides = vec![1; src_shape.len()];
            for i in (0..src_shape.len() - 1).rev() {
                src_strides[i] = src_strides[i + 1] * src_shape[i + 1];
            }
            for i in 0..indices_data.len() {
                let index = indices_data[i];
                // Negative indices wrap from the end of the dimension.
                let idx = if index < 0 {
                    (dim_size as i64 + index) as usize
                } else {
                    index as usize
                };
                if idx >= dim_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of bounds for dimension {} size {}",
                        index, dim, dim_size
                    )));
                }
                // Decompose the flat source position into coordinates, then
                // redirect the coordinate along `dim` to the scattered row.
                let mut coords = vec![0; self_shape.len()];
                let mut remainder = i;
                for (d, &stride) in src_strides.iter().enumerate() {
                    coords[d] = remainder / stride;
                    remainder %= stride;
                }
                coords[dim] = idx;
                let flat_idx = coords
                    .iter()
                    .zip(self_strides.iter())
                    .map(|(c, s)| c * s)
                    .sum::<usize>();
                result_data[flat_idx] = match reduce {
                    "sum" => result_data[flat_idx] + src_data[i],
                    "prod" => result_data[flat_idx] * src_data[i],
                    "mean" => {
                        counts[flat_idx] += 1;
                        result_data[flat_idx] + src_data[i]
                    }
                    "amax" => {
                        if src_data[i] > result_data[flat_idx] {
                            src_data[i]
                        } else {
                            result_data[flat_idx]
                        }
                    }
                    "amin" => {
                        if src_data[i] < result_data[flat_idx] {
                            src_data[i]
                        } else {
                            result_data[flat_idx]
                        }
                    }
                    _ => {
                        return Err(TorshError::InvalidArgument(format!(
                            "Unknown reduce operation: {}",
                            reduce
                        )));
                    }
                };
            }
            if reduce == "mean" {
                // Same mean finalization as the 1-D path.
                for (i, count) in counts.iter().enumerate() {
                    if *count > 0 {
                        result_data[i] = T::from_usize(*count)
                            .and_then(|c| Some(result_data[i] / c))
                            .unwrap_or(result_data[i]);
                    }
                }
            }
        }
        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
    /// Returns a copy of `self` with the diagonal of the (`dim1`, `dim2`)
    /// plane replaced by the elements of `src`.
    ///
    /// `offset >= 0` shifts the diagonal along `dim2` (super-diagonals);
    /// `offset < 0` shifts it along `dim1` (sub-diagonals). `src.numel()`
    /// must equal the length of the selected diagonal. For tensors with more
    /// than two dimensions, all coordinates other than `dim1`/`dim2` are
    /// fixed at 0 — only that single plane is written.
    pub fn diagonal_scatter(
        &self,
        src: &Tensor<T>,
        offset: isize,
        dim1: usize,
        dim2: usize,
    ) -> Result<Self> {
        if dim1 >= self.ndim() || dim2 >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimensions ({}, {}) out of range for {}-dimensional tensor",
                dim1,
                dim2,
                self.ndim()
            )));
        }
        if dim1 == dim2 {
            return Err(TorshError::InvalidArgument(
                "dim1 and dim2 must be different".to_string(),
            ));
        }
        let self_shape_ref = self.shape();
        let self_shape = self_shape_ref.dims();
        let dim1_size = self_shape[dim1];
        let dim2_size = self_shape[dim2];
        // Length of the selected diagonal; 0 when the offset pushes it
        // entirely outside the plane.
        let diag_len = if offset >= 0 {
            let offset_u = offset as usize;
            if offset_u >= dim2_size {
                0
            } else {
                std::cmp::min(dim1_size, dim2_size - offset_u)
            }
        } else {
            let offset_u = (-offset) as usize;
            if offset_u >= dim1_size {
                0
            } else {
                std::cmp::min(dim1_size - offset_u, dim2_size)
            }
        };
        if src.numel() != diag_len {
            return Err(TorshError::ShapeMismatch {
                expected: vec![diag_len],
                got: vec![src.numel()],
            });
        }
        let mut result_data = self.to_vec()?;
        let src_data = src.to_vec()?;
        // Row-major strides of `self`.
        let mut strides = vec![1; self_shape.len()];
        for i in (0..self_shape.len() - 1).rev() {
            strides[i] = strides[i + 1] * self_shape[i + 1];
        }
        for i in 0..diag_len {
            // The i-th diagonal element sits at (i, i + offset) — or
            // (i - offset, i) for negative offsets — in the (dim1, dim2)
            // plane; every other coordinate stays 0.
            let mut indices = vec![0; self_shape.len()];
            if offset >= 0 {
                indices[dim1] = i;
                indices[dim2] = i + offset as usize;
            } else {
                indices[dim1] = i + (-offset) as usize;
                indices[dim2] = i;
            }
            let mut flat_idx = 0;
            for (d, &idx) in indices.iter().enumerate() {
                flat_idx += idx * strides[d];
            }
            result_data[flat_idx] = src_data[i];
        }
        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
1159 pub fn select_scatter(&self, src: &Tensor<T>, dim: isize, index: isize) -> Result<Self> {
1180 let ndim = self.ndim() as isize;
1181 let dim_normalized = if dim < 0 { ndim + dim } else { dim };
1182 if dim_normalized < 0 || dim_normalized >= ndim {
1183 return Err(TorshError::InvalidArgument(format!(
1184 "Dimension {} out of range for {}-dimensional tensor",
1185 dim,
1186 self.ndim()
1187 )));
1188 }
1189 let dim_u = dim_normalized as usize;
1190 let self_shape_ref = self.shape();
1191 let self_shape = self_shape_ref.dims();
1192 let index_normalized = if index < 0 {
1193 (self_shape[dim_u] as isize) + index
1194 } else {
1195 index
1196 };
1197 if index_normalized < 0 || index_normalized >= self_shape[dim_u] as isize {
1198 return Err(TorshError::InvalidArgument(format!(
1199 "Index {} out of bounds for dimension {} with size {}",
1200 index, dim_u, self_shape[dim_u]
1201 )));
1202 }
1203 let index_u = index_normalized as usize;
1204 let expected_src_shape: Vec<usize> = self_shape
1205 .iter()
1206 .enumerate()
1207 .filter(|(i, _)| *i != dim_u)
1208 .map(|(_, &s)| s)
1209 .collect();
1210 let src_shape_ref = src.shape();
1211 let src_shape = src_shape_ref.dims();
1212 if src_shape != expected_src_shape.as_slice() {
1213 return Err(TorshError::ShapeMismatch {
1214 expected: expected_src_shape,
1215 got: src_shape.to_vec(),
1216 });
1217 }
1218 let mut result_data = self.to_vec()?;
1219 let src_data = src.to_vec()?;
1220 let mut self_strides = vec![1; self_shape.len()];
1221 for i in (0..self_shape.len() - 1).rev() {
1222 self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
1223 }
1224 let outer_size: usize = self_shape[..dim_u].iter().product();
1225 let inner_size: usize = self_shape[dim_u + 1..].iter().product();
1226 for outer in 0..outer_size {
1227 for inner in 0..inner_size {
1228 let self_idx =
1229 outer * (self_shape[dim_u] * inner_size) + index_u * inner_size + inner;
1230 let src_idx = outer * inner_size + inner;
1231 result_data[self_idx] = src_data[src_idx];
1232 }
1233 }
1234 Self::from_data(result_data, self_shape.to_vec(), self.device)
1235 }
1236 pub fn slice_scatter(
1259 &self,
1260 src: &Tensor<T>,
1261 dim: isize,
1262 start: Option<isize>,
1263 end: Option<isize>,
1264 step: usize,
1265 ) -> Result<Self> {
1266 if step == 0 {
1267 return Err(TorshError::InvalidArgument(
1268 "Step must be greater than 0".to_string(),
1269 ));
1270 }
1271 let ndim = self.ndim() as isize;
1272 let dim_normalized = if dim < 0 { ndim + dim } else { dim };
1273 if dim_normalized < 0 || dim_normalized >= ndim {
1274 return Err(TorshError::InvalidArgument(format!(
1275 "Dimension {} out of range for {}-dimensional tensor",
1276 dim,
1277 self.ndim()
1278 )));
1279 }
1280 let dim_u = dim_normalized as usize;
1281 let self_shape_ref = self.shape();
1282 let self_shape = self_shape_ref.dims();
1283 let dim_size = self_shape[dim_u] as isize;
1284 let start_normalized = start.unwrap_or(0);
1285 let start_normalized = if start_normalized < 0 {
1286 dim_size + start_normalized
1287 } else {
1288 start_normalized
1289 };
1290 let start_normalized = std::cmp::max(0, std::cmp::min(start_normalized, dim_size)) as usize;
1291 let end_normalized = end.unwrap_or(dim_size);
1292 let end_normalized = if end_normalized < 0 {
1293 dim_size + end_normalized
1294 } else {
1295 end_normalized
1296 };
1297 let end_normalized = std::cmp::max(0, std::cmp::min(end_normalized, dim_size)) as usize;
1298 let slice_len = if end_normalized > start_normalized {
1299 (end_normalized - start_normalized + step - 1) / step
1300 } else {
1301 0
1302 };
1303 let mut expected_src_shape = self_shape.to_vec();
1304 expected_src_shape[dim_u] = slice_len;
1305 let src_shape_ref = src.shape();
1306 let src_shape = src_shape_ref.dims();
1307 if src_shape != expected_src_shape.as_slice() {
1308 return Err(TorshError::ShapeMismatch {
1309 expected: expected_src_shape,
1310 got: src_shape.to_vec(),
1311 });
1312 }
1313 let mut result_data = self.to_vec()?;
1314 let src_data = src.to_vec()?;
1315 let mut self_strides = vec![1; self_shape.len()];
1316 for i in (0..self_shape.len() - 1).rev() {
1317 self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
1318 }
1319 let outer_size: usize = self_shape[..dim_u].iter().product();
1320 let inner_size: usize = self_shape[dim_u + 1..].iter().product();
1321 for outer in 0..outer_size {
1322 for slice_idx in 0..slice_len {
1323 let self_dim_idx = start_normalized + slice_idx * step;
1324 for inner in 0..inner_size {
1325 let self_idx = outer * (self_shape[dim_u] * inner_size)
1326 + self_dim_idx * inner_size
1327 + inner;
1328 let src_idx = outer * (slice_len * inner_size) + slice_idx * inner_size + inner;
1329 result_data[self_idx] = src_data[src_idx];
1330 }
1331 }
1332 }
1333 Self::from_data(result_data, self_shape.to_vec(), self.device)
1334 }
1335}