Skip to main content

torsh_tensor/data_ops/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::core_ops::Tensor;
6use torsh_core::{
7    device::DeviceType,
8    dtype::TensorElement,
9    error::{Result, TorshError},
10};
11
12impl<T: TensorElement + Copy> Tensor<T> {
13    /// Create tensor from a scalar value repeated to fill the shape
14    pub fn from_scalar(value: T, shape: &[usize], device: DeviceType) -> Result<Self>
15    where
16        T: Copy,
17    {
18        let numel = shape.iter().product::<usize>();
19        let data = vec![value; numel];
20        Self::from_data(data, shape.to_vec(), device)
21    }
22    /// Fill tensor with a single value (in-place)
23    pub fn fill_(&mut self, value: T) -> Result<()>
24    where
25        T: Copy,
26    {
27        for i in 0..self.numel() {
28            self.storage.set(i, value)?;
29        }
30        Ok(())
31    }
32    /// Zero out the tensor (in-place)
33    pub fn zero_(&mut self) -> Result<()>
34    where
35        T: Copy,
36    {
37        self.fill_(T::zero())
38    }
39    /// Fill with ones (in-place)
40    pub fn ones_(&mut self) -> Result<()>
41    where
42        T: Copy,
43    {
44        self.fill_(T::one())
45    }
46    /// Copy data from another tensor (in-place)
47    pub fn copy_(&mut self, other: &Self) -> Result<()>
48    where
49        T: Copy,
50    {
51        if self.shape() != other.shape() {
52            return Err(TorshError::ShapeMismatch {
53                expected: self.shape().dims().to_vec(),
54                got: other.shape().dims().to_vec(),
55            });
56        }
57        let other_data = other.to_vec()?;
58        for (i, &value) in other_data.iter().enumerate() {
59            self.storage.set(i, value)?;
60        }
61        Ok(())
62    }
63    /// Get an element by multi-dimensional index
64    pub fn get_item(&self, indices: &[usize]) -> Result<T>
65    where
66        T: Copy,
67    {
68        if indices.len() != self.ndim() {
69            return Err(TorshError::InvalidArgument(format!(
70                "Expected {} indices, got {}",
71                self.ndim(),
72                indices.len()
73            )));
74        }
75        let binding = self.shape();
76        let shape = binding.dims();
77        for (i, &idx) in indices.iter().enumerate() {
78            if idx >= shape[i] {
79                return Err(TorshError::IndexOutOfBounds {
80                    index: idx,
81                    size: shape[i],
82                });
83            }
84        }
85        let flat_index = self.multi_to_flat_index(indices)?;
86        self.get_item_flat(flat_index)
87    }
88    /// Set an element by multi-dimensional index
89    pub fn set_item(&mut self, indices: &[usize], value: T) -> Result<()>
90    where
91        T: Copy,
92    {
93        if indices.len() != self.ndim() {
94            return Err(TorshError::InvalidArgument(format!(
95                "Expected {} indices, got {}",
96                self.ndim(),
97                indices.len()
98            )));
99        }
100        let binding = self.shape();
101        let shape = binding.dims();
102        for (i, &idx) in indices.iter().enumerate() {
103            if idx >= shape[i] {
104                return Err(TorshError::IndexOutOfBounds {
105                    index: idx,
106                    size: shape[i],
107                });
108            }
109        }
110        let flat_index = self.multi_to_flat_index(indices)?;
111        self.set_item_flat(flat_index, value)
112    }
113    /// Get element by flat index
114    pub fn get_item_flat(&self, index: usize) -> Result<T>
115    where
116        T: Copy,
117    {
118        if index >= self.numel() {
119            return Err(TorshError::IndexOutOfBounds {
120                index,
121                size: self.numel(),
122            });
123        }
124        self.storage.get(index)
125    }
126    /// Set element by flat index
127    pub fn set_item_flat(&mut self, index: usize, value: T) -> Result<()>
128    where
129        T: Copy,
130    {
131        if index >= self.numel() {
132            return Err(TorshError::IndexOutOfBounds {
133                index,
134                size: self.numel(),
135            });
136        }
137        self.storage.set(index, value)
138    }
139    /// Convert multi-dimensional indices to flat index
140    pub fn multi_to_flat_index(&self, indices: &[usize]) -> Result<usize> {
141        let binding = self.shape();
142        let shape = binding.dims();
143        if indices.len() != shape.len() {
144            return Err(TorshError::InvalidArgument(format!(
145                "Expected {} indices, got {}",
146                shape.len(),
147                indices.len()
148            )));
149        }
150        let mut flat_index = 0;
151        let mut stride = 1;
152        for i in (0..indices.len()).rev() {
153            flat_index += indices[i] * stride;
154            stride *= shape[i];
155        }
156        Ok(flat_index)
157    }
    /// Gather values along an axis using indices
    ///
    /// The output has the shape of `indices`. In the 1-D case each output
    /// element is `self[indices[i]]`; in the N-D case each output element is
    /// read from `self` at the coordinates of the corresponding `indices`
    /// element, with the coordinate along `dim` replaced by the index value.
    /// Negative indices count back from the end of the gathered dimension.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `dim` is out of range or any index
    /// falls outside the gathered dimension.
    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }
        let self_data = self.to_vec()?;
        let indices_data = indices.to_vec()?;
        let mut result_data = Vec::new();
        let result_shape = indices.shape().dims().to_vec();
        if self.ndim() == 1 {
            // Fast path: 1-D gather is a plain lookup per index.
            for &index in &indices_data {
                // Negative indices wrap once from the end (PyTorch-style).
                let idx = if index < 0 {
                    (self.shape().dims()[0] as i64 + index) as usize
                } else {
                    index as usize
                };
                // An index below -len wraps to a huge usize and is caught
                // by this same check.
                if idx >= self.shape().dims()[0] {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of range for tensor with size {}",
                        index,
                        self.shape().dims()[0]
                    )));
                }
                result_data.push(self_data[idx]);
            }
        } else {
            let self_shape_ref = self.shape();
            let self_shape = self_shape_ref.dims();
            let indices_shape_ref = indices.shape();
            let indices_shape = indices_shape_ref.dims();
            let dim_size = self_shape[dim];
            // Row-major strides for both tensors (innermost stride = 1).
            let mut self_strides = vec![1; self_shape.len()];
            let mut indices_strides = vec![1; indices_shape.len()];
            for i in (0..self_shape.len() - 1).rev() {
                self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
            }
            for i in (0..indices_shape.len() - 1).rev() {
                indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
            }
            // NOTE(review): this path assumes indices.ndim() == self.ndim();
            // a rank mismatch would read self_strides out of bounds (panic)
            // rather than return an error — TODO confirm callers guarantee
            // matching rank.
            let total_elements = indices_data.len();
            for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
                // Decompose the flat position into indices-tensor coordinates.
                let mut indices_coords = vec![0; indices_shape.len()];
                let mut temp_i = i;
                for j in 0..indices_shape.len() {
                    indices_coords[j] = temp_i / indices_strides[j];
                    temp_i %= indices_strides[j];
                }
                // Resolve a possibly-negative index into the gathered dim.
                let idx = if index_value < 0 {
                    (dim_size as i64 + index_value) as usize
                } else {
                    index_value as usize
                };
                if idx >= dim_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {index_value} out of range for dimension {dim} with size {dim_size}"
                    )));
                }
                // Same coordinates as the indices element, except along `dim`.
                let mut self_coords = indices_coords.clone();
                if dim < self_coords.len() {
                    self_coords[dim] = idx;
                }
                // Re-flatten the source coordinates with the source strides.
                let mut flat_idx = 0;
                for j in 0..self_coords.len() {
                    flat_idx += self_coords[j] * self_strides[j];
                }
                result_data.push(self_data[flat_idx]);
            }
        }
        Self::from_data(result_data, result_shape, self.device)
    }
    /// Scatter values along an axis using indices
    ///
    /// Returns a copy of `self` in which, for every element of `src`, the
    /// value is written at the coordinates of the corresponding `indices`
    /// element, with the coordinate along `dim` replaced by the index value.
    /// When multiple indices target the same slot the last write wins.
    /// Negative indices count back from the end of the dimension.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `dim` is out of range, when `indices`
    /// and `src` differ in element count, or when an index is out of bounds.
    pub fn scatter(&self, dim: usize, indices: &Tensor<i64>, src: &Tensor<T>) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }
        // Start from a copy of self; scattered positions are overwritten.
        let mut result_data = self.to_vec()?;
        let indices_data = indices.to_vec()?;
        let src_data = src.to_vec()?;
        if indices_data.len() != src_data.len() {
            return Err(TorshError::InvalidArgument(
                "Indices and source tensor must have the same number of elements".to_string(),
            ));
        }
        if self.ndim() == 1 {
            // Fast path: 1-D scatter is a plain write per index.
            for (i, &index) in indices_data.iter().enumerate() {
                // Negative indices wrap once from the end (PyTorch-style).
                let idx = if index < 0 {
                    (self.shape().dims()[0] as i64 + index) as usize
                } else {
                    index as usize
                };
                if idx >= self.shape().dims()[0] {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of range for tensor with size {}",
                        index,
                        self.shape().dims()[0]
                    )));
                }
                result_data[idx] = src_data[i];
            }
        } else {
            let self_shape_ref = self.shape();
            let self_shape = self_shape_ref.dims();
            let indices_shape_ref = indices.shape();
            let indices_shape = indices_shape_ref.dims();
            let dim_size = self_shape[dim];
            // Row-major strides for both tensors (innermost stride = 1).
            let mut self_strides = vec![1; self_shape.len()];
            let mut indices_strides = vec![1; indices_shape.len()];
            for i in (0..self_shape.len() - 1).rev() {
                self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
            }
            for i in (0..indices_shape.len() - 1).rev() {
                indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
            }
            // NOTE(review): assumes indices.ndim() == self.ndim(); a rank
            // mismatch would read self_strides out of bounds — TODO confirm
            // callers guarantee matching rank.
            let total_elements = indices_data.len();
            for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
                // Decompose the flat position into indices-tensor coordinates.
                let mut indices_coords = vec![0; indices_shape.len()];
                let mut temp_i = i;
                for j in 0..indices_shape.len() {
                    indices_coords[j] = temp_i / indices_strides[j];
                    temp_i %= indices_strides[j];
                }
                // Resolve a possibly-negative index into the scattered dim.
                let idx = if index_value < 0 {
                    (dim_size as i64 + index_value) as usize
                } else {
                    index_value as usize
                };
                if idx >= dim_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {index_value} out of range for dimension {dim} with size {dim_size}"
                    )));
                }
                // Same coordinates as the indices element, except along `dim`.
                let mut self_coords = indices_coords.clone();
                if dim < self_coords.len() {
                    self_coords[dim] = idx;
                }
                let mut flat_idx = 0;
                for j in 0..self_coords.len() {
                    flat_idx += self_coords[j] * self_strides[j];
                }
                result_data[flat_idx] = src_data[i];
            }
        }
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }
    /// Scatter values along an axis using indices and add to existing values
    ///
    /// Unlike `scatter`, duplicate indices accumulate: each source value is
    /// added to whatever is already at the target slot.
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.scatter_add(tensor, dim, index, src)`
    ///
    /// # Arguments
    /// * `dim` - Dimension along which to index
    /// * `indices` - Index tensor (same shape as src)
    /// * `src` - Source tensor containing values to add
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `dim` is out of range, when `indices`
    /// and `src` differ in element count, or when an index is out of bounds.
    ///
    /// # Examples
    /// ```ignore
    /// let tensor = Tensor::zeros(&[5], DeviceType::Cpu)?;
    /// let indices = Tensor::from_data(vec![0i64, 1, 2, 0, 1], vec![5], DeviceType::Cpu)?;
    /// let src = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)?;
    /// let result = tensor.scatter_add(0, &indices, &src)?;
    /// // result[0] += 1.0 + 4.0 = 5.0
    /// // result[1] += 2.0 + 5.0 = 7.0
    /// // result[2] += 3.0 = 3.0
    /// ```
    pub fn scatter_add(&self, dim: usize, indices: &Tensor<i64>, src: &Tensor<T>) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }
        // Start from a copy of self; source values accumulate into it.
        let mut result_data = self.to_vec()?;
        let indices_data = indices.to_vec()?;
        let src_data = src.to_vec()?;
        if indices_data.len() != src_data.len() {
            return Err(TorshError::InvalidArgument(
                "Indices and source tensor must have the same number of elements".to_string(),
            ));
        }
        if self.ndim() == 1 {
            // Fast path: 1-D scatter-add is a plain accumulate per index.
            for (i, &index) in indices_data.iter().enumerate() {
                // Negative indices wrap once from the end (PyTorch-style).
                let idx = if index < 0 {
                    (self.shape().dims()[0] as i64 + index) as usize
                } else {
                    index as usize
                };
                if idx >= self.shape().dims()[0] {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of range for tensor with size {}",
                        index,
                        self.shape().dims()[0]
                    )));
                }
                result_data[idx] = result_data[idx] + src_data[i];
            }
        } else {
            let self_shape_ref = self.shape();
            let self_shape = self_shape_ref.dims();
            let indices_shape_ref = indices.shape();
            let indices_shape = indices_shape_ref.dims();
            let dim_size = self_shape[dim];
            // Row-major strides for both tensors (innermost stride = 1).
            let mut self_strides = vec![1; self_shape.len()];
            let mut indices_strides = vec![1; indices_shape.len()];
            for i in (0..self_shape.len() - 1).rev() {
                self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
            }
            for i in (0..indices_shape.len() - 1).rev() {
                indices_strides[i] = indices_strides[i + 1] * indices_shape[i + 1];
            }
            // NOTE(review): assumes indices.ndim() == self.ndim(); a rank
            // mismatch would read self_strides out of bounds — TODO confirm
            // callers guarantee matching rank.
            let total_elements = indices_data.len();
            for (i, &index_value) in indices_data.iter().enumerate().take(total_elements) {
                // Decompose the flat position into indices-tensor coordinates.
                let mut indices_coords = vec![0; indices_shape.len()];
                let mut temp_i = i;
                for j in 0..indices_shape.len() {
                    indices_coords[j] = temp_i / indices_strides[j];
                    temp_i %= indices_strides[j];
                }
                // Resolve a possibly-negative index into the scattered dim.
                let idx = if index_value < 0 {
                    (dim_size as i64 + index_value) as usize
                } else {
                    index_value as usize
                };
                if idx >= dim_size {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {index_value} out of range for dimension {dim} with size {dim_size}"
                    )));
                }
                // Same coordinates as the indices element, except along `dim`.
                let mut self_coords = indices_coords.clone();
                if dim < self_coords.len() {
                    self_coords[dim] = idx;
                }
                let mut flat_idx = 0;
                for j in 0..self_coords.len() {
                    flat_idx += self_coords[j] * self_strides[j];
                }
                result_data[flat_idx] = result_data[flat_idx] + src_data[i];
            }
        }
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }
410    /// Repeat tensor along specified dimensions
411    pub fn repeat(&self, repeats: &[usize]) -> Result<Self> {
412        if repeats.len() != self.ndim() {
413            return Err(TorshError::InvalidArgument(format!(
414                "Number of repeats {} must match tensor dimensions {}",
415                repeats.len(),
416                self.ndim()
417            )));
418        }
419        let self_data = self.to_vec()?;
420        let shape_binding = self.shape();
421        let self_shape = shape_binding.dims();
422        let new_shape: Vec<usize> = self_shape
423            .iter()
424            .zip(repeats.iter())
425            .map(|(&dim, &repeat)| dim * repeat)
426            .collect();
427        let new_numel = new_shape.iter().product();
428        let mut result_data = Vec::with_capacity(new_numel);
429        for result_idx in 0..new_numel {
430            let mut result_coords = vec![0; new_shape.len()];
431            let mut temp_idx = result_idx;
432            for i in (0..new_shape.len()).rev() {
433                result_coords[i] = temp_idx % new_shape[i];
434                temp_idx /= new_shape[i];
435            }
436            let source_coords: Vec<usize> = result_coords
437                .iter()
438                .zip(self_shape.iter())
439                .map(|(&result_coord, &dim_size)| result_coord % dim_size)
440                .collect();
441            let mut source_idx = 0;
442            let mut stride = 1;
443            for i in (0..self_shape.len()).rev() {
444                source_idx += source_coords[i] * stride;
445                stride *= self_shape[i];
446            }
447            result_data.push(self_data[source_idx]);
448        }
449        Self::from_data(result_data, new_shape, self.device)
450    }
451    /// Add values to tensor at specified indices along a dimension
452    ///
453    /// # PyTorch Compatibility
454    /// Equivalent to `torch.index_add(tensor, dim, index, source, alpha=1.0)`
455    ///
456    /// # Arguments
457    /// * `dim` - Dimension along which to index
458    /// * `index` - 1D tensor containing indices
459    /// * `source` - Source tensor to add
460    ///
461    /// # Examples
462    /// ```ignore
463    /// let tensor = Tensor::zeros(&[3, 5], DeviceType::Cpu)?;
464    /// let index = Tensor::from_data(vec![0i64, 2], vec![2], DeviceType::Cpu)?;
465    /// let source = Tensor::ones(&[2, 5], DeviceType::Cpu)?;
466    /// let result = tensor.index_add(0, &index, &source)?;
467    /// ```
468    pub fn index_add(&self, dim: isize, index: &Tensor<i64>, source: &Self) -> Result<Self>
469    where
470        T: std::ops::Add<Output = T>,
471    {
472        let ndim = self.ndim();
473        let dim = if dim < 0 {
474            (ndim as isize + dim) as usize
475        } else {
476            dim as usize
477        };
478        if dim >= ndim {
479            return Err(TorshError::InvalidArgument(format!(
480                "Dimension {} out of range for {}-D tensor",
481                dim, ndim
482            )));
483        }
484        if index.ndim() != 1 {
485            return Err(TorshError::InvalidArgument(
486                "index must be 1D tensor".to_string(),
487            ));
488        }
489        let index_size = index.shape().dims()[0];
490        let self_shape = self.shape().to_vec();
491        let source_shape = source.shape().to_vec();
492        if source_shape.len() != self_shape.len() {
493            return Err(TorshError::ShapeMismatch {
494                expected: self_shape.clone(),
495                got: source_shape.clone(),
496            });
497        }
498        for (i, (&s, &src_s)) in self_shape.iter().zip(source_shape.iter()).enumerate() {
499            if i == dim {
500                if src_s != index_size {
501                    return Err(TorshError::InvalidArgument(format!(
502                        "source dimension {} size {} must match index size {}",
503                        i, src_s, index_size
504                    )));
505                }
506            } else if s != src_s {
507                return Err(TorshError::ShapeMismatch {
508                    expected: self_shape.clone(),
509                    got: source_shape.clone(),
510                });
511            }
512        }
513        let mut result_data = self.to_vec()?;
514        let source_data = source.to_vec()?;
515        let index_data = index.to_vec()?;
516        let dim_size = self_shape[dim];
517        let outer_size: usize = self_shape[..dim].iter().product();
518        let inner_size: usize = self_shape[dim + 1..].iter().product();
519        for (src_idx_in_dim, &target_idx) in index_data.iter().enumerate() {
520            if target_idx < 0 || target_idx as usize >= dim_size {
521                return Err(TorshError::InvalidArgument(format!(
522                    "Index {} out of range for dimension size {}",
523                    target_idx, dim_size
524                )));
525            }
526            let target_idx = target_idx as usize;
527            for outer in 0..outer_size {
528                for inner in 0..inner_size {
529                    let result_idx =
530                        outer * dim_size * inner_size + target_idx * inner_size + inner;
531                    let source_idx =
532                        outer * index_size * inner_size + src_idx_in_dim * inner_size + inner;
533                    result_data[result_idx] = result_data[result_idx] + source_data[source_idx];
534                }
535            }
536        }
537        Self::from_data(result_data, self_shape, self.device)
538    }
539    /// Copy values from source to tensor at specified indices along a dimension
540    ///
541    /// # PyTorch Compatibility
542    /// Equivalent to `torch.index_copy(tensor, dim, index, source)`
543    ///
544    /// # Arguments
545    /// * `dim` - Dimension along which to index
546    /// * `index` - 1D tensor containing indices
547    /// * `source` - Source tensor to copy from
548    ///
549    /// # Examples
550    /// ```ignore
551    /// let tensor = Tensor::zeros(&[3, 5], DeviceType::Cpu)?;
552    /// let index = Tensor::from_data(vec![0i64, 2], vec![2], DeviceType::Cpu)?;
553    /// let source = Tensor::ones(&[2, 5], DeviceType::Cpu)?;
554    /// let result = tensor.index_copy(0, &index, &source)?;
555    /// ```
556    pub fn index_copy(&self, dim: isize, index: &Tensor<i64>, source: &Self) -> Result<Self> {
557        let ndim = self.ndim();
558        let dim = if dim < 0 {
559            (ndim as isize + dim) as usize
560        } else {
561            dim as usize
562        };
563        if dim >= ndim {
564            return Err(TorshError::InvalidArgument(format!(
565                "Dimension {} out of range for {}-D tensor",
566                dim, ndim
567            )));
568        }
569        if index.ndim() != 1 {
570            return Err(TorshError::InvalidArgument(
571                "index must be 1D tensor".to_string(),
572            ));
573        }
574        let index_size = index.shape().dims()[0];
575        let self_shape = self.shape().to_vec();
576        let source_shape = source.shape().to_vec();
577        if source_shape.len() != self_shape.len() {
578            return Err(TorshError::ShapeMismatch {
579                expected: self_shape.clone(),
580                got: source_shape.clone(),
581            });
582        }
583        for (i, (&s, &src_s)) in self_shape.iter().zip(source_shape.iter()).enumerate() {
584            if i == dim {
585                if src_s != index_size {
586                    return Err(TorshError::InvalidArgument(format!(
587                        "source dimension {} size {} must match index size {}",
588                        i, src_s, index_size
589                    )));
590                }
591            } else if s != src_s {
592                return Err(TorshError::ShapeMismatch {
593                    expected: self_shape.clone(),
594                    got: source_shape.clone(),
595                });
596            }
597        }
598        let mut result_data = self.to_vec()?;
599        let source_data = source.to_vec()?;
600        let index_data = index.to_vec()?;
601        let dim_size = self_shape[dim];
602        let outer_size: usize = self_shape[..dim].iter().product();
603        let inner_size: usize = self_shape[dim + 1..].iter().product();
604        for (src_idx_in_dim, &target_idx) in index_data.iter().enumerate() {
605            if target_idx < 0 || target_idx as usize >= dim_size {
606                return Err(TorshError::InvalidArgument(format!(
607                    "Index {} out of range for dimension size {}",
608                    target_idx, dim_size
609                )));
610            }
611            let target_idx = target_idx as usize;
612            for outer in 0..outer_size {
613                for inner in 0..inner_size {
614                    let result_idx =
615                        outer * dim_size * inner_size + target_idx * inner_size + inner;
616                    let source_idx =
617                        outer * index_size * inner_size + src_idx_in_dim * inner_size + inner;
618                    result_data[result_idx] = source_data[source_idx];
619                }
620            }
621        }
622        Self::from_data(result_data, self_shape, self.device)
623    }
624    /// Fill values in tensor at specified indices along a dimension
625    ///
626    /// # PyTorch Compatibility
627    /// Equivalent to `torch.index_fill(tensor, dim, index, value)`
628    ///
629    /// # Arguments
630    /// * `dim` - Dimension along which to index
631    /// * `index` - 1D tensor containing indices
632    /// * `value` - Scalar value to fill
633    ///
634    /// # Examples
635    /// ```ignore
636    /// let tensor = Tensor::zeros(&[3, 5], DeviceType::Cpu)?;
637    /// let index = Tensor::from_data(vec![0i64, 2], vec![2], DeviceType::Cpu)?;
638    /// let result = tensor.index_fill(0, &index, 3.14)?;
639    /// ```
640    pub fn index_fill(&self, dim: isize, index: &Tensor<i64>, value: T) -> Result<Self> {
641        let ndim = self.ndim();
642        let dim = if dim < 0 {
643            (ndim as isize + dim) as usize
644        } else {
645            dim as usize
646        };
647        if dim >= ndim {
648            return Err(TorshError::InvalidArgument(format!(
649                "Dimension {} out of range for {}-D tensor",
650                dim, ndim
651            )));
652        }
653        if index.ndim() != 1 {
654            return Err(TorshError::InvalidArgument(
655                "index must be 1D tensor".to_string(),
656            ));
657        }
658        let mut result_data = self.to_vec()?;
659        let index_data = index.to_vec()?;
660        let self_shape = self.shape().to_vec();
661        let dim_size = self_shape[dim];
662        let outer_size: usize = self_shape[..dim].iter().product();
663        let inner_size: usize = self_shape[dim + 1..].iter().product();
664        for &target_idx in index_data.iter() {
665            if target_idx < 0 || target_idx as usize >= dim_size {
666                return Err(TorshError::InvalidArgument(format!(
667                    "Index {} out of range for dimension size {}",
668                    target_idx, dim_size
669                )));
670            }
671            let target_idx = target_idx as usize;
672            for outer in 0..outer_size {
673                for inner in 0..inner_size {
674                    let result_idx =
675                        outer * dim_size * inner_size + target_idx * inner_size + inner;
676                    result_data[result_idx] = value;
677                }
678            }
679        }
680        Self::from_data(result_data, self_shape, self.device)
681    }
682    /// Place values at specified flat indices (in-place-like operation, returns new tensor)
683    ///
684    /// # PyTorch Compatibility
685    /// Equivalent to `torch.put_(tensor, indices, values)` but returns new tensor
686    ///
687    /// # Arguments
688    /// * `indices` - 1D tensor of flat indices
689    /// * `values` - 1D tensor of values (must match indices length or be broadcastable)
690    ///
691    /// # Examples
692    /// ```ignore
693    /// let tensor = Tensor::zeros(&[3, 3], DeviceType::Cpu)?;  // [[0,0,0],[0,0,0],[0,0,0]]
694    /// let indices = Tensor::from_data(vec![0i64, 4, 8], vec![3], DeviceType::Cpu)?;
695    /// let values = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
696    /// let result = tensor.put_(&indices, &values)?;  // [[1,0,0],[0,2,0],[0,0,3]]
697    /// ```
698    pub fn put_(&self, indices: &Tensor<i64>, values: &Tensor<T>) -> Result<Self> {
699        if indices.ndim() != 1 {
700            return Err(TorshError::InvalidArgument(
701                "indices must be 1D tensor".to_string(),
702            ));
703        }
704        if values.ndim() != 1 {
705            return Err(TorshError::InvalidArgument(
706                "values must be 1D tensor".to_string(),
707            ));
708        }
709        let indices_data = indices.to_vec()?;
710        let values_data = values.to_vec()?;
711        if indices_data.len() != values_data.len() {
712            return Err(TorshError::InvalidArgument(format!(
713                "Number of values {} must match number of indices {}",
714                values_data.len(),
715                indices_data.len()
716            )));
717        }
718        let mut result_data = self.to_vec()?;
719        let numel = self.numel();
720        for (i, &index) in indices_data.iter().enumerate() {
721            let idx = if index < 0 {
722                ((numel as i64) + index) as usize
723            } else {
724                index as usize
725            };
726            if idx >= numel {
727                return Err(TorshError::InvalidArgument(format!(
728                    "Index {} out of range for tensor with {} elements",
729                    index, numel
730                )));
731            }
732            result_data[idx] = values_data[i];
733        }
734        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
735    }
736    /// Scatter values from source tensor where mask is true (PyTorch-compatible)
737    ///
738    /// Copies values from the source tensor to positions where the mask is true.
739    /// The mask must have the same shape as self. Source values are taken sequentially
740    /// and placed at positions where mask is true.
741    ///
742    /// # PyTorch Compatibility
743    /// Equivalent to `torch.masked_scatter(tensor, mask, source)`
744    ///
745    /// # Arguments
746    /// * `mask` - Boolean tensor with same shape as self
747    /// * `source` - Tensor containing values to scatter (must have at least as many elements as true values in mask)
748    ///
749    /// # Examples
750    /// ```ignore
751    /// let tensor = Tensor::zeros(&[3, 3], DeviceType::Cpu)?;
752    /// let mask = Tensor::from_data(
753    ///     vec![true, false, false, false, true, false, false, false, true],
754    ///     vec![3, 3],
755    ///     DeviceType::Cpu
756    /// )?;
757    /// let source = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
758    /// let result = tensor.masked_scatter(&mask, &source)?;  // [[1,0,0],[0,2,0],[0,0,3]]
759    /// ```
760    pub fn masked_scatter(&self, mask: &Tensor<bool>, source: &Tensor<T>) -> Result<Self> {
761        if self.shape() != mask.shape() {
762            return Err(TorshError::ShapeMismatch {
763                expected: self.shape().dims().to_vec(),
764                got: mask.shape().dims().to_vec(),
765            });
766        }
767        let mask_data = mask.to_vec()?;
768        let true_count = mask_data.iter().filter(|&&x| x).count();
769        if source.numel() < true_count {
770            return Err(TorshError::InvalidArgument(format!(
771                "Source tensor has {} elements but need {} for scatter (mask has {} true values)",
772                source.numel(),
773                true_count,
774                true_count
775            )));
776        }
777        let self_data = self.to_vec()?;
778        let source_data = source.to_vec()?;
779        let mut result_data = Vec::with_capacity(self_data.len());
780        let mut source_idx = 0;
781        for (i, &self_val) in self_data.iter().enumerate() {
782            if i < mask_data.len() && mask_data[i] {
783                result_data.push(source_data[source_idx]);
784                source_idx += 1;
785            } else {
786                result_data.push(self_val);
787            }
788        }
789        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
790    }
    /// Multi-dimensional indexed put operation (PyTorch-compatible)
    ///
    /// Places values from source tensor at positions specified by index tensors.
    /// Each index tensor specifies indices along one dimension. All index tensors
    /// must have exactly the same shape.
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.index_put(tensor, indices, values)` where indices is a tuple of index tensors
    ///
    /// NOTE(review): when fewer index tensors than dimensions are supplied
    /// (`indices.len() < self.ndim()`), the unindexed trailing dimensions contribute
    /// 0 to the flat offset, so exactly one element (at coordinate 0 along those
    /// dimensions) is written per index tuple — PyTorch would write whole slices.
    /// Confirm this restriction is intended.
    ///
    /// # Arguments
    /// * `indices` - Slice of index tensors, one per dimension to index (negative indices count from the end)
    /// * `values` - Tensor of values to place; must have either one element per index tuple, or a single element that is broadcast to all positions
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `indices` is empty, has more tensors than `self`
    /// has dimensions, when `values` has an incompatible element count, or when a
    /// (normalized) index is out of bounds; `ShapeMismatch` when the index tensors
    /// do not all share one shape.
    ///
    /// # Examples
    /// ```ignore
    /// // 2D example: index_put a 3x3 matrix with row=[0,1] col=[1,2]
    /// let tensor = Tensor::zeros(&[3, 3], DeviceType::Cpu)?;
    /// let row_idx = Tensor::from_data(vec![0i64, 1], vec![2], DeviceType::Cpu)?;
    /// let col_idx = Tensor::from_data(vec![1i64, 2], vec![2], DeviceType::Cpu)?;
    /// let values = Tensor::from_data(vec![10.0f32, 20.0], vec![2], DeviceType::Cpu)?;
    /// let result = tensor.index_put(&[row_idx, col_idx], &values)?;
    /// // result[0,1] = 10.0, result[1,2] = 20.0
    /// ```
    pub fn index_put(&self, indices: &[Tensor<i64>], values: &Tensor<T>) -> Result<Self> {
        if indices.is_empty() {
            return Err(TorshError::InvalidArgument(
                "indices cannot be empty".to_string(),
            ));
        }
        if indices.len() > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Too many indices ({}) for tensor with {} dimensions",
                indices.len(),
                self.ndim()
            )));
        }
        // All index tensors must share the shape of the first one; their common
        // element count is the number of positions that will be written.
        let index_shape_ref = indices[0].shape();
        let index_shape = index_shape_ref.dims();
        let num_indices = indices[0].numel();
        for idx_tensor in indices.iter() {
            if idx_tensor.shape().dims() != index_shape {
                return Err(TorshError::ShapeMismatch {
                    expected: index_shape.to_vec(),
                    got: idx_tensor.shape().dims().to_vec(),
                });
            }
        }
        // `values` either supplies one value per index tuple, or one value total
        // (scalar broadcast).
        if values.numel() != num_indices && values.numel() != 1 {
            return Err(TorshError::InvalidArgument(format!(
                "Values tensor has {} elements but need {} (or 1 for broadcasting)",
                values.numel(),
                num_indices
            )));
        }
        let mut result_data = self.to_vec()?;
        let self_shape_ref = self.shape();
        let self_shape = self_shape_ref.dims();
        let values_data = values.to_vec()?;
        // Materialize every index tensor up front; any to_vec failure aborts early.
        let index_data: Result<Vec<Vec<i64>>> = indices.iter().map(|idx| idx.to_vec()).collect();
        let index_data = index_data?;
        // Row-major (contiguous) strides for flattening multi-dim coordinates.
        let mut strides = vec![1; self_shape.len()];
        for i in (0..self_shape.len() - 1).rev() {
            strides[i] = strides[i + 1] * self_shape[i + 1];
        }
        for i in 0..num_indices {
            // Broadcast a single value, otherwise take the i-th value.
            let value = if values_data.len() == 1 {
                values_data[0]
            } else {
                values_data[i]
            };
            // Accumulate the flat offset from the i-th coordinate of each index
            // tensor; dimensions beyond indices.len() contribute nothing (offset 0).
            let mut flat_idx = 0;
            for (dim, idx_vec) in index_data.iter().enumerate() {
                let mut idx = idx_vec[i];
                // Negative indices count from the end of the dimension.
                if idx < 0 {
                    idx += self_shape[dim] as i64;
                }
                if idx < 0 || idx >= self_shape[dim] as i64 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Index {} out of bounds for dimension {} with size {}",
                        idx_vec[i], dim, self_shape[dim]
                    )));
                }
                flat_idx += (idx as usize) * strides[dim];
            }
            result_data[flat_idx] = value;
        }
        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
879    /// Scatter with reduction operation (PyTorch-compatible)
880    ///
881    /// Generalized scatter operation that applies a reduction operation (sum, prod, mean, etc.)
882    /// when scattering values to the same index position.
883    ///
884    /// # PyTorch Compatibility
885    /// Equivalent to `torch.scatter_reduce(tensor, dim, index, src, reduce)`
886    ///
887    /// # Arguments
888    /// * `dim` - Dimension along which to scatter
889    /// * `indices` - Index tensor specifying where to scatter values
890    /// * `src` - Source tensor containing values to scatter
891    /// * `reduce` - Reduction operation ("sum", "prod", "mean", "amax", "amin")
892    ///
893    /// # Examples
894    /// ```ignore
895    /// let tensor = Tensor::zeros(&[5], DeviceType::Cpu)?;
896    /// let indices = Tensor::from_data(vec![0i64, 1, 2, 0, 1], vec![5], DeviceType::Cpu)?;
897    /// let src = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)?;
898    /// let result = tensor.scatter_reduce(0, &indices, &src, "sum")?;
899    /// // result[0] = 1.0 + 4.0 = 5.0 (sum reduction)
900    /// // result[1] = 2.0 + 5.0 = 7.0
901    /// ```
902    pub fn scatter_reduce(
903        &self,
904        dim: usize,
905        indices: &Tensor<i64>,
906        src: &Tensor<T>,
907        reduce: &str,
908    ) -> Result<Self>
909    where
910        T: std::ops::Add<Output = T>
911            + std::ops::Mul<Output = T>
912            + std::ops::Div<Output = T>
913            + PartialOrd
914            + num_traits::FromPrimitive,
915    {
916        if dim >= self.ndim() {
917            return Err(TorshError::InvalidArgument(format!(
918                "Dimension {} out of range for {}-dimensional tensor",
919                dim,
920                self.ndim()
921            )));
922        }
923        if indices.shape() != src.shape() {
924            return Err(TorshError::ShapeMismatch {
925                expected: indices.shape().dims().to_vec(),
926                got: src.shape().dims().to_vec(),
927            });
928        }
929        let indices_data = indices.to_vec()?;
930        let src_data = src.to_vec()?;
931        let mut result_data = self.to_vec()?;
932        let self_shape_ref = self.shape();
933        let self_shape = self_shape_ref.dims();
934        let mut counts = if reduce == "mean" {
935            vec![0usize; result_data.len()]
936        } else {
937            vec![]
938        };
939        if self.ndim() == 1 {
940            for (i, &index) in indices_data.iter().enumerate() {
941                let idx = if index < 0 {
942                    (self_shape[0] as i64 + index) as usize
943                } else {
944                    index as usize
945                };
946                if idx >= self_shape[0] {
947                    return Err(TorshError::InvalidArgument(format!(
948                        "Index {} out of bounds for dimension size {}",
949                        index, self_shape[0]
950                    )));
951                }
952                result_data[idx] = match reduce {
953                    "sum" => result_data[idx] + src_data[i],
954                    "prod" => result_data[idx] * src_data[i],
955                    "mean" => {
956                        counts[idx] += 1;
957                        result_data[idx] + src_data[i]
958                    }
959                    "amax" => {
960                        if src_data[i] > result_data[idx] {
961                            src_data[i]
962                        } else {
963                            result_data[idx]
964                        }
965                    }
966                    "amin" => {
967                        if src_data[i] < result_data[idx] {
968                            src_data[i]
969                        } else {
970                            result_data[idx]
971                        }
972                    }
973                    _ => {
974                        return Err(TorshError::InvalidArgument(format!(
975                            "Unknown reduce operation: {}. Supported: sum, prod, mean, amax, amin",
976                            reduce
977                        )));
978                    }
979                };
980            }
981            if reduce == "mean" {
982                for (i, count) in counts.iter().enumerate() {
983                    if *count > 0 {
984                        result_data[i] = T::from_usize(*count)
985                            .and_then(|c| Some(result_data[i] / c))
986                            .unwrap_or(result_data[i]);
987                    }
988                }
989            }
990        } else {
991            let dim_size = self_shape[dim];
992            let _outer_size: usize = self_shape[..dim].iter().product();
993            let _inner_size: usize = self_shape[dim + 1..].iter().product();
994            let mut self_strides = vec![1; self_shape.len()];
995            for i in (0..self_shape.len() - 1).rev() {
996                self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
997            }
998            let src_shape_ref = src.shape();
999            let src_shape = src_shape_ref.dims();
1000            let mut src_strides = vec![1; src_shape.len()];
1001            for i in (0..src_shape.len() - 1).rev() {
1002                src_strides[i] = src_strides[i + 1] * src_shape[i + 1];
1003            }
1004            for i in 0..indices_data.len() {
1005                let index = indices_data[i];
1006                let idx = if index < 0 {
1007                    (dim_size as i64 + index) as usize
1008                } else {
1009                    index as usize
1010                };
1011                if idx >= dim_size {
1012                    return Err(TorshError::InvalidArgument(format!(
1013                        "Index {} out of bounds for dimension {} size {}",
1014                        index, dim, dim_size
1015                    )));
1016                }
1017                let mut coords = vec![0; self_shape.len()];
1018                let mut remainder = i;
1019                for (d, &stride) in src_strides.iter().enumerate() {
1020                    coords[d] = remainder / stride;
1021                    remainder %= stride;
1022                }
1023                coords[dim] = idx;
1024                let flat_idx = coords
1025                    .iter()
1026                    .zip(self_strides.iter())
1027                    .map(|(c, s)| c * s)
1028                    .sum::<usize>();
1029                result_data[flat_idx] = match reduce {
1030                    "sum" => result_data[flat_idx] + src_data[i],
1031                    "prod" => result_data[flat_idx] * src_data[i],
1032                    "mean" => {
1033                        counts[flat_idx] += 1;
1034                        result_data[flat_idx] + src_data[i]
1035                    }
1036                    "amax" => {
1037                        if src_data[i] > result_data[flat_idx] {
1038                            src_data[i]
1039                        } else {
1040                            result_data[flat_idx]
1041                        }
1042                    }
1043                    "amin" => {
1044                        if src_data[i] < result_data[flat_idx] {
1045                            src_data[i]
1046                        } else {
1047                            result_data[flat_idx]
1048                        }
1049                    }
1050                    _ => {
1051                        return Err(TorshError::InvalidArgument(format!(
1052                            "Unknown reduce operation: {}",
1053                            reduce
1054                        )));
1055                    }
1056                };
1057            }
1058            if reduce == "mean" {
1059                for (i, count) in counts.iter().enumerate() {
1060                    if *count > 0 {
1061                        result_data[i] = T::from_usize(*count)
1062                            .and_then(|c| Some(result_data[i] / c))
1063                            .unwrap_or(result_data[i]);
1064                    }
1065                }
1066            }
1067        }
1068        Self::from_data(result_data, self_shape.to_vec(), self.device)
1069    }
1070    /// Scatter values to the diagonal (PyTorch-compatible)
1071    ///
1072    /// Embeds the values of src tensor into self along the diagonal elements,
1073    /// with respect to dim1 and dim2. The offset determines which diagonal to use.
1074    ///
1075    /// # PyTorch Compatibility
1076    /// Equivalent to `torch.diagonal_scatter(tensor, src, offset, dim1, dim2)`
1077    ///
1078    /// # Arguments
1079    /// * `src` - Source tensor containing values for the diagonal
1080    /// * `offset` - Diagonal offset (0=main diagonal, >0=above, <0=below)
1081    /// * `dim1` - First dimension (default: 0)
1082    /// * `dim2` - Second dimension (default: 1)
1083    ///
1084    /// # Examples
1085    /// ```ignore
1086    /// let tensor = Tensor::zeros(&[3, 3], DeviceType::Cpu)?;
1087    /// let src = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
1088    /// let result = tensor.diagonal_scatter(&src, 0, 0, 1)?;
1089    /// // result = [[1, 0, 0], [0, 2, 0], [0, 0, 3]]
1090    /// ```
1091    pub fn diagonal_scatter(
1092        &self,
1093        src: &Tensor<T>,
1094        offset: isize,
1095        dim1: usize,
1096        dim2: usize,
1097    ) -> Result<Self> {
1098        if dim1 >= self.ndim() || dim2 >= self.ndim() {
1099            return Err(TorshError::InvalidArgument(format!(
1100                "Dimensions ({}, {}) out of range for {}-dimensional tensor",
1101                dim1,
1102                dim2,
1103                self.ndim()
1104            )));
1105        }
1106        if dim1 == dim2 {
1107            return Err(TorshError::InvalidArgument(
1108                "dim1 and dim2 must be different".to_string(),
1109            ));
1110        }
1111        let self_shape_ref = self.shape();
1112        let self_shape = self_shape_ref.dims();
1113        let dim1_size = self_shape[dim1];
1114        let dim2_size = self_shape[dim2];
1115        let diag_len = if offset >= 0 {
1116            let offset_u = offset as usize;
1117            if offset_u >= dim2_size {
1118                0
1119            } else {
1120                std::cmp::min(dim1_size, dim2_size - offset_u)
1121            }
1122        } else {
1123            let offset_u = (-offset) as usize;
1124            if offset_u >= dim1_size {
1125                0
1126            } else {
1127                std::cmp::min(dim1_size - offset_u, dim2_size)
1128            }
1129        };
1130        if src.numel() != diag_len {
1131            return Err(TorshError::ShapeMismatch {
1132                expected: vec![diag_len],
1133                got: vec![src.numel()],
1134            });
1135        }
1136        let mut result_data = self.to_vec()?;
1137        let src_data = src.to_vec()?;
1138        let mut strides = vec![1; self_shape.len()];
1139        for i in (0..self_shape.len() - 1).rev() {
1140            strides[i] = strides[i + 1] * self_shape[i + 1];
1141        }
1142        for i in 0..diag_len {
1143            let mut indices = vec![0; self_shape.len()];
1144            if offset >= 0 {
1145                indices[dim1] = i;
1146                indices[dim2] = i + offset as usize;
1147            } else {
1148                indices[dim1] = i + (-offset) as usize;
1149                indices[dim2] = i;
1150            }
1151            let mut flat_idx = 0;
1152            for (d, &idx) in indices.iter().enumerate() {
1153                flat_idx += idx * strides[d];
1154            }
1155            result_data[flat_idx] = src_data[i];
1156        }
1157        Self::from_data(result_data, self_shape.to_vec(), self.device)
1158    }
1159    /// Scatter values to a selected slice along dimension (PyTorch-compatible)
1160    ///
1161    /// Embeds the values of src tensor into self at the given index along dimension dim.
1162    /// This is the inverse of `select()` operation.
1163    ///
1164    /// # PyTorch Compatibility
1165    /// Equivalent to `torch.select_scatter(tensor, src, dim, index)`
1166    ///
1167    /// # Arguments
1168    /// * `src` - Source tensor to scatter (shape should match self with dim removed)
1169    /// * `dim` - Dimension along which to select
1170    /// * `index` - Index position to scatter to
1171    ///
1172    /// # Examples
1173    /// ```ignore
1174    /// let tensor = Tensor::zeros(&[3, 4, 5], DeviceType::Cpu)?;
1175    /// let src = Tensor::ones(&[3, 5], DeviceType::Cpu)?; // dim=1 removed
1176    /// let result = tensor.select_scatter(&src, 1, 2)?;
1177    /// // result[:, 2, :] = src
1178    /// ```
1179    pub fn select_scatter(&self, src: &Tensor<T>, dim: isize, index: isize) -> Result<Self> {
1180        let ndim = self.ndim() as isize;
1181        let dim_normalized = if dim < 0 { ndim + dim } else { dim };
1182        if dim_normalized < 0 || dim_normalized >= ndim {
1183            return Err(TorshError::InvalidArgument(format!(
1184                "Dimension {} out of range for {}-dimensional tensor",
1185                dim,
1186                self.ndim()
1187            )));
1188        }
1189        let dim_u = dim_normalized as usize;
1190        let self_shape_ref = self.shape();
1191        let self_shape = self_shape_ref.dims();
1192        let index_normalized = if index < 0 {
1193            (self_shape[dim_u] as isize) + index
1194        } else {
1195            index
1196        };
1197        if index_normalized < 0 || index_normalized >= self_shape[dim_u] as isize {
1198            return Err(TorshError::InvalidArgument(format!(
1199                "Index {} out of bounds for dimension {} with size {}",
1200                index, dim_u, self_shape[dim_u]
1201            )));
1202        }
1203        let index_u = index_normalized as usize;
1204        let expected_src_shape: Vec<usize> = self_shape
1205            .iter()
1206            .enumerate()
1207            .filter(|(i, _)| *i != dim_u)
1208            .map(|(_, &s)| s)
1209            .collect();
1210        let src_shape_ref = src.shape();
1211        let src_shape = src_shape_ref.dims();
1212        if src_shape != expected_src_shape.as_slice() {
1213            return Err(TorshError::ShapeMismatch {
1214                expected: expected_src_shape,
1215                got: src_shape.to_vec(),
1216            });
1217        }
1218        let mut result_data = self.to_vec()?;
1219        let src_data = src.to_vec()?;
1220        let mut self_strides = vec![1; self_shape.len()];
1221        for i in (0..self_shape.len() - 1).rev() {
1222            self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
1223        }
1224        let outer_size: usize = self_shape[..dim_u].iter().product();
1225        let inner_size: usize = self_shape[dim_u + 1..].iter().product();
1226        for outer in 0..outer_size {
1227            for inner in 0..inner_size {
1228                let self_idx =
1229                    outer * (self_shape[dim_u] * inner_size) + index_u * inner_size + inner;
1230                let src_idx = outer * inner_size + inner;
1231                result_data[self_idx] = src_data[src_idx];
1232            }
1233        }
1234        Self::from_data(result_data, self_shape.to_vec(), self.device)
1235    }
1236    /// Scatter values to a slice along dimension (PyTorch-compatible)
1237    ///
1238    /// Embeds the values of src tensor into self along dimension dim, starting at
1239    /// start index, ending at end index, with the given step.
1240    ///
1241    /// # PyTorch Compatibility
1242    /// Equivalent to `torch.slice_scatter(tensor, src, dim, start, end, step)`
1243    ///
1244    /// # Arguments
1245    /// * `src` - Source tensor to scatter
1246    /// * `dim` - Dimension along which to slice
1247    /// * `start` - Starting index (None means 0)
1248    /// * `end` - Ending index (None means size of dim)
1249    /// * `step` - Step size (default: 1)
1250    ///
1251    /// # Examples
1252    /// ```ignore
1253    /// let tensor = Tensor::zeros(&[5, 5], DeviceType::Cpu)?;
1254    /// let src = Tensor::ones(&[2, 5], DeviceType::Cpu)?;
1255    /// let result = tensor.slice_scatter(&src, 0, Some(1), Some(3), 1)?;
1256    /// // result[1:3, :] = src
1257    /// ```
1258    pub fn slice_scatter(
1259        &self,
1260        src: &Tensor<T>,
1261        dim: isize,
1262        start: Option<isize>,
1263        end: Option<isize>,
1264        step: usize,
1265    ) -> Result<Self> {
1266        if step == 0 {
1267            return Err(TorshError::InvalidArgument(
1268                "Step must be greater than 0".to_string(),
1269            ));
1270        }
1271        let ndim = self.ndim() as isize;
1272        let dim_normalized = if dim < 0 { ndim + dim } else { dim };
1273        if dim_normalized < 0 || dim_normalized >= ndim {
1274            return Err(TorshError::InvalidArgument(format!(
1275                "Dimension {} out of range for {}-dimensional tensor",
1276                dim,
1277                self.ndim()
1278            )));
1279        }
1280        let dim_u = dim_normalized as usize;
1281        let self_shape_ref = self.shape();
1282        let self_shape = self_shape_ref.dims();
1283        let dim_size = self_shape[dim_u] as isize;
1284        let start_normalized = start.unwrap_or(0);
1285        let start_normalized = if start_normalized < 0 {
1286            dim_size + start_normalized
1287        } else {
1288            start_normalized
1289        };
1290        let start_normalized = std::cmp::max(0, std::cmp::min(start_normalized, dim_size)) as usize;
1291        let end_normalized = end.unwrap_or(dim_size);
1292        let end_normalized = if end_normalized < 0 {
1293            dim_size + end_normalized
1294        } else {
1295            end_normalized
1296        };
1297        let end_normalized = std::cmp::max(0, std::cmp::min(end_normalized, dim_size)) as usize;
1298        let slice_len = if end_normalized > start_normalized {
1299            (end_normalized - start_normalized + step - 1) / step
1300        } else {
1301            0
1302        };
1303        let mut expected_src_shape = self_shape.to_vec();
1304        expected_src_shape[dim_u] = slice_len;
1305        let src_shape_ref = src.shape();
1306        let src_shape = src_shape_ref.dims();
1307        if src_shape != expected_src_shape.as_slice() {
1308            return Err(TorshError::ShapeMismatch {
1309                expected: expected_src_shape,
1310                got: src_shape.to_vec(),
1311            });
1312        }
1313        let mut result_data = self.to_vec()?;
1314        let src_data = src.to_vec()?;
1315        let mut self_strides = vec![1; self_shape.len()];
1316        for i in (0..self_shape.len() - 1).rev() {
1317            self_strides[i] = self_strides[i + 1] * self_shape[i + 1];
1318        }
1319        let outer_size: usize = self_shape[..dim_u].iter().product();
1320        let inner_size: usize = self_shape[dim_u + 1..].iter().product();
1321        for outer in 0..outer_size {
1322            for slice_idx in 0..slice_len {
1323                let self_dim_idx = start_normalized + slice_idx * step;
1324                for inner in 0..inner_size {
1325                    let self_idx = outer * (self_shape[dim_u] * inner_size)
1326                        + self_dim_idx * inner_size
1327                        + inner;
1328                    let src_idx = outer * (slice_len * inner_size) + slice_idx * inner_size + inner;
1329                    result_data[self_idx] = src_data[src_idx];
1330                }
1331            }
1332        }
1333        Self::from_data(result_data, self_shape.to_vec(), self.device)
1334    }
1335}