slsl/ops/
tile.rs

1use anyhow::Result;
2
3use crate::{DType, Shape, StorageTrait, Tensor, TensorBase, UninitVec};
4
5impl<S: StorageTrait> TensorBase<S> {
6    /// Constructs a tensor by repeating the elements of `self` according to the specified pattern.
7    ///
8    /// This function creates a new tensor where each element of the input tensor is repeated
9    /// according to the `repeats` argument. The `repeats` argument specifies the number of
10    /// repetitions in each dimension.
11    ///
12    /// # Arguments
13    ///
14    /// * `repeats` - An array specifying the number of repetitions for each dimension.
15    ///   Can be converted into `Shape` (e.g., `[2, 3]`, `vec![2, 3]`, etc.)
16    ///
17    /// # Behavior
18    ///
19    /// - **Fewer dimensions in repeats**: If `repeats` has fewer dimensions than `self`,
20    ///   ones are prepended to `repeats` until all dimensions are specified.
21    ///   - Example: `self` shape `(8, 6, 4, 2)`, `repeats` `[2, 2]` → treated as `[1, 1, 2, 2]`
22    ///
23    /// - **More dimensions in repeats**: If `self` has fewer dimensions than `repeats`,
24    ///   `self` is treated as if it were unsqueezed at dimension zero until it has as many
25    ///   dimensions as `repeats` specifies.
26    ///   - Example: `self` shape `(4, 2)`, `repeats` `[3, 3, 2, 2]` → `self` treated as `(1, 1, 4, 2)`
27    ///
28    /// # Returns
29    ///
30    /// A `Result<Tensor>` containing:
31    /// - `Ok(Tensor)`: A new tensor with the repeated data
32    /// - `Err`: If there's an error during the tiling operation
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use slsl::Tensor;
38    ///
39    /// // 1D tensor
40    /// let tensor = Tensor::from_vec(vec![1, 2, 3], [3])?;
41    ///
42    /// // Repeat 2 times
43    /// let tiled = tensor.tile([2])?;
44    /// assert_eq!(tiled.dims(), [6]);
45    /// assert_eq!(tiled.to_flat_vec::<i32>()?, vec![1, 2, 3, 1, 2, 3]);
46    ///
47    /// // 2D tensor
48    /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
49    ///
50    /// // Repeat 2 times in each dimension
51    /// let tiled = tensor_2d.tile([2, 2])?;
52    /// assert_eq!(tiled.dims(), [4, 4]);
53    ///
54    /// // Repeat with different counts
55    /// let tiled = tensor_2d.tile([3, 1])?;
56    /// assert_eq!(tiled.dims(), [6, 2]);
57    /// # Ok::<(), Box<dyn std::error::Error>>(())
58    /// ```
59    ///
60    /// # Notes
61    ///
62    /// - This operation creates a new tensor with copied data (not a view)
63    /// - The resulting tensor size is the element-wise product of `self.shape` and `repeats`
64    /// - The function follows PyTorch's `tile` behavior
65    /// - All supported data types are handled automatically
66    ///
67    /// # See Also
68    ///
69    /// - Note: `repeat` function is not yet implemented in this library
70    /// - [PyTorch tile][]: <https://pytorch.org/docs/stable/generated/torch.tile.html>
71    ///
72    /// [PyTorch tile]: https://pytorch.org/docs/stable/generated/torch.tile.html
73    pub fn tile<D: Into<Shape>>(&self, repeats: D) -> Result<Tensor> {
74        let repeats = repeats.into();
75        let self_rank = self.rank();
76        let repeats_len = repeats.len();
77
78        // Determine the target rank (maximum of self rank and repeats length)
79        let target_rank = self_rank.max(repeats_len);
80
81        // Create expanded shape and repeats arrays using Shape
82        let mut expanded_shape = Shape::empty().with_len(target_rank);
83        let mut expanded_repeats = Shape::empty().with_len(target_rank);
84
85        // repeats (prepend ones if needed)
86        let rep_pad = target_rank - repeats_len;
87        for i in 0..rep_pad {
88            expanded_repeats[i] = 1;
89        }
90        for i in 0..repeats_len {
91            expanded_repeats[rep_pad + i] = repeats[i];
92        }
93
94        // shape (prepend ones to self if needed)
95        let shp_pad = target_rank - self_rank;
96        for i in 0..shp_pad {
97            expanded_shape[i] = 1;
98        }
99        for i in 0..self_rank {
100            expanded_shape[shp_pad + i] = self.shape[i];
101        }
102
103        // Calculate the final shape after tiling
104        let mut final_shape = Shape::empty().with_len(target_rank);
105        for i in 0..target_rank {
106            final_shape[i] = expanded_shape[i] * expanded_repeats[i];
107        }
108
109        // Calculate the total number of elements in the tiled tensor
110        let total_elements = final_shape.numel();
111
112        // We need to handle different data types generically
113        // Let's use a match statement to handle all supported DType cases
114        match self.dtype {
115            DType::Bool => {
116                let out = UninitVec::<bool>::new(total_elements).init_with(|dst| {
117                    Self::_tile_fill::<bool>(self, &expanded_shape, &final_shape, target_rank, dst)
118                });
119                Tensor::from_vec(out, final_shape)
120            }
121            DType::Int8 => {
122                let out = UninitVec::<i8>::new(total_elements).init_with(|dst| {
123                    Self::_tile_fill::<i8>(self, &expanded_shape, &final_shape, target_rank, dst)
124                });
125                Tensor::from_vec(out, final_shape)
126            }
127            DType::Int16 => {
128                let out = UninitVec::<i16>::new(total_elements).init_with(|dst| {
129                    Self::_tile_fill::<i16>(self, &expanded_shape, &final_shape, target_rank, dst)
130                });
131                Tensor::from_vec(out, final_shape)
132            }
133            DType::Int32 => {
134                let out = UninitVec::<i32>::new(total_elements).init_with(|dst| {
135                    Self::_tile_fill::<i32>(self, &expanded_shape, &final_shape, target_rank, dst)
136                });
137                Tensor::from_vec(out, final_shape)
138            }
139            DType::Int64 => {
140                let out = UninitVec::<i64>::new(total_elements).init_with(|dst| {
141                    Self::_tile_fill::<i64>(self, &expanded_shape, &final_shape, target_rank, dst)
142                });
143                Tensor::from_vec(out, final_shape)
144            }
145            DType::Uint8 => {
146                let out = UninitVec::<u8>::new(total_elements).init_with(|dst| {
147                    Self::_tile_fill::<u8>(self, &expanded_shape, &final_shape, target_rank, dst)
148                });
149                Tensor::from_vec(out, final_shape)
150            }
151            DType::Uint16 => {
152                let out = UninitVec::<u16>::new(total_elements).init_with(|dst| {
153                    Self::_tile_fill::<u16>(self, &expanded_shape, &final_shape, target_rank, dst)
154                });
155                Tensor::from_vec(out, final_shape)
156            }
157            DType::Uint32 => {
158                let out = UninitVec::<u32>::new(total_elements).init_with(|dst| {
159                    Self::_tile_fill::<u32>(self, &expanded_shape, &final_shape, target_rank, dst)
160                });
161                Tensor::from_vec(out, final_shape)
162            }
163            DType::Uint64 => {
164                let out = UninitVec::<u64>::new(total_elements).init_with(|dst| {
165                    Self::_tile_fill::<u64>(self, &expanded_shape, &final_shape, target_rank, dst)
166                });
167                Tensor::from_vec(out, final_shape)
168            }
169            DType::Fp16 => {
170                let out = UninitVec::<half::f16>::new(total_elements).init_with(|dst| {
171                    Self::_tile_fill::<half::f16>(
172                        self,
173                        &expanded_shape,
174                        &final_shape,
175                        target_rank,
176                        dst,
177                    )
178                });
179                Tensor::from_vec(out, final_shape)
180            }
181            DType::Fp32 => {
182                let out = UninitVec::<f32>::new(total_elements).init_with(|dst| {
183                    Self::_tile_fill::<f32>(self, &expanded_shape, &final_shape, target_rank, dst)
184                });
185                Tensor::from_vec(out, final_shape)
186            }
187            DType::Fp64 => {
188                let out = UninitVec::<f64>::new(total_elements).init_with(|dst| {
189                    Self::_tile_fill::<f64>(self, &expanded_shape, &final_shape, target_rank, dst)
190                });
191                Tensor::from_vec(out, final_shape)
192            }
193            DType::Bf16 => {
194                let out = UninitVec::<half::bf16>::new(total_elements).init_with(|dst| {
195                    Self::_tile_fill::<half::bf16>(
196                        self,
197                        &expanded_shape,
198                        &final_shape,
199                        target_rank,
200                        dst,
201                    )
202                });
203                Tensor::from_vec(out, final_shape)
204            }
205            _ => {
206                anyhow::bail!("tile function not supported for Auto dtype")
207            }
208        }
209    }
210
211    /// Fast filler for tile using pointer reads and expanded strides
212    fn _tile_fill<T: crate::TensorElement>(
213        &self,
214        expanded_shape: &Shape,
215        final_shape: &Shape,
216        target_rank: usize,
217        dst: &mut [T],
218    ) {
219        let self_rank = self.rank();
220        let shp_pad = target_rank - self_rank;
221
222        // Build expanded strides that correspond to expanded_shape
223        let mut exp_strides = Shape::empty().with_len(target_rank);
224        for i in 0..target_rank {
225            if i < shp_pad {
226                exp_strides[i] = 0;
227            } else {
228                exp_strides[i] = if (i - shp_pad) < self.strides.len() {
229                    self.strides[i - shp_pad]
230                } else {
231                    0
232                };
233            }
234        }
235
236        let base = self.as_ptr() as *const T;
237        let total = final_shape.numel();
238
239        let mut idx = Shape::empty().with_len(target_rank);
240
241        for (pos, slot) in dst.iter_mut().enumerate().take(total) {
242            let mut rem = pos;
243            for i in (0..target_rank).rev() {
244                let dim = final_shape[i];
245                idx[i] = rem % dim;
246                rem /= dim;
247            }
248            // Reduce indices by repeats to original indices
249            let mut offset_elems = 0usize;
250            for i in 0..target_rank {
251                let dim_sz = expanded_shape[i];
252                let orig = if dim_sz == 0 { 0 } else { idx[i] % dim_sz };
253                offset_elems += orig * exp_strides[i];
254            }
255            let val = unsafe { core::ptr::read(base.add(offset_elems)) };
256            *slot = val;
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use crate::Tensor;

    // Tiling a 1-D tensor twice doubles its length and repeats the data.
    #[test]
    fn tile_1d_basic() {
        let src = Tensor::from_vec(vec![1i32, 2, 3], [3]).unwrap();
        let tiled = src.tile([2]).unwrap();
        assert_eq!(tiled.dims(), [6]);
        assert_eq!(tiled.to_flat_vec::<i32>().unwrap(), vec![1, 2, 3, 1, 2, 3]);
    }

    // Repeating both axes of a 2x2 tensor yields a 4x4 result.
    #[test]
    fn tile_2d_symmetric() {
        let src = Tensor::from_vec(vec![1i32, 2, 3, 4], [2, 2]).unwrap();
        let tiled = src.tile([2, 2]).unwrap();
        assert_eq!(tiled.dims(), [4, 4]);
    }

    // A shorter repeats list is left-padded with 1s: [3] acts as [1, 3].
    #[test]
    fn tile_prepended_repeats() {
        let src = Tensor::from_vec(vec![1i32, 2, 3, 4], [2, 2]).unwrap();
        let tiled = src.tile([3]).unwrap();
        assert_eq!(tiled.dims(), [2, 6]);
    }

    // A longer repeats list unsqueezes self at dim 0: (3,) acts as (1, 3).
    #[test]
    fn tile_expand_rank() {
        let src = Tensor::from_vec(vec![1i32, 2, 3], [3]).unwrap();
        let tiled = src.tile([2, 2]).unwrap();
        assert_eq!(tiled.dims(), [2, 6]);
    }
}