1use anyhow::Result;
2
3use crate::{DType, Shape, StorageTrait, Tensor, TensorBase, UninitVec};
4
5impl<S: StorageTrait> TensorBase<S> {
6 pub fn tile<D: Into<Shape>>(&self, repeats: D) -> Result<Tensor> {
74 let repeats = repeats.into();
75 let self_rank = self.rank();
76 let repeats_len = repeats.len();
77
78 let target_rank = self_rank.max(repeats_len);
80
81 let mut expanded_shape = Shape::empty().with_len(target_rank);
83 let mut expanded_repeats = Shape::empty().with_len(target_rank);
84
85 let rep_pad = target_rank - repeats_len;
87 for i in 0..rep_pad {
88 expanded_repeats[i] = 1;
89 }
90 for i in 0..repeats_len {
91 expanded_repeats[rep_pad + i] = repeats[i];
92 }
93
94 let shp_pad = target_rank - self_rank;
96 for i in 0..shp_pad {
97 expanded_shape[i] = 1;
98 }
99 for i in 0..self_rank {
100 expanded_shape[shp_pad + i] = self.shape[i];
101 }
102
103 let mut final_shape = Shape::empty().with_len(target_rank);
105 for i in 0..target_rank {
106 final_shape[i] = expanded_shape[i] * expanded_repeats[i];
107 }
108
109 let total_elements = final_shape.numel();
111
112 match self.dtype {
115 DType::Bool => {
116 let out = UninitVec::<bool>::new(total_elements).init_with(|dst| {
117 Self::_tile_fill::<bool>(self, &expanded_shape, &final_shape, target_rank, dst)
118 });
119 Tensor::from_vec(out, final_shape)
120 }
121 DType::Int8 => {
122 let out = UninitVec::<i8>::new(total_elements).init_with(|dst| {
123 Self::_tile_fill::<i8>(self, &expanded_shape, &final_shape, target_rank, dst)
124 });
125 Tensor::from_vec(out, final_shape)
126 }
127 DType::Int16 => {
128 let out = UninitVec::<i16>::new(total_elements).init_with(|dst| {
129 Self::_tile_fill::<i16>(self, &expanded_shape, &final_shape, target_rank, dst)
130 });
131 Tensor::from_vec(out, final_shape)
132 }
133 DType::Int32 => {
134 let out = UninitVec::<i32>::new(total_elements).init_with(|dst| {
135 Self::_tile_fill::<i32>(self, &expanded_shape, &final_shape, target_rank, dst)
136 });
137 Tensor::from_vec(out, final_shape)
138 }
139 DType::Int64 => {
140 let out = UninitVec::<i64>::new(total_elements).init_with(|dst| {
141 Self::_tile_fill::<i64>(self, &expanded_shape, &final_shape, target_rank, dst)
142 });
143 Tensor::from_vec(out, final_shape)
144 }
145 DType::Uint8 => {
146 let out = UninitVec::<u8>::new(total_elements).init_with(|dst| {
147 Self::_tile_fill::<u8>(self, &expanded_shape, &final_shape, target_rank, dst)
148 });
149 Tensor::from_vec(out, final_shape)
150 }
151 DType::Uint16 => {
152 let out = UninitVec::<u16>::new(total_elements).init_with(|dst| {
153 Self::_tile_fill::<u16>(self, &expanded_shape, &final_shape, target_rank, dst)
154 });
155 Tensor::from_vec(out, final_shape)
156 }
157 DType::Uint32 => {
158 let out = UninitVec::<u32>::new(total_elements).init_with(|dst| {
159 Self::_tile_fill::<u32>(self, &expanded_shape, &final_shape, target_rank, dst)
160 });
161 Tensor::from_vec(out, final_shape)
162 }
163 DType::Uint64 => {
164 let out = UninitVec::<u64>::new(total_elements).init_with(|dst| {
165 Self::_tile_fill::<u64>(self, &expanded_shape, &final_shape, target_rank, dst)
166 });
167 Tensor::from_vec(out, final_shape)
168 }
169 DType::Fp16 => {
170 let out = UninitVec::<half::f16>::new(total_elements).init_with(|dst| {
171 Self::_tile_fill::<half::f16>(
172 self,
173 &expanded_shape,
174 &final_shape,
175 target_rank,
176 dst,
177 )
178 });
179 Tensor::from_vec(out, final_shape)
180 }
181 DType::Fp32 => {
182 let out = UninitVec::<f32>::new(total_elements).init_with(|dst| {
183 Self::_tile_fill::<f32>(self, &expanded_shape, &final_shape, target_rank, dst)
184 });
185 Tensor::from_vec(out, final_shape)
186 }
187 DType::Fp64 => {
188 let out = UninitVec::<f64>::new(total_elements).init_with(|dst| {
189 Self::_tile_fill::<f64>(self, &expanded_shape, &final_shape, target_rank, dst)
190 });
191 Tensor::from_vec(out, final_shape)
192 }
193 DType::Bf16 => {
194 let out = UninitVec::<half::bf16>::new(total_elements).init_with(|dst| {
195 Self::_tile_fill::<half::bf16>(
196 self,
197 &expanded_shape,
198 &final_shape,
199 target_rank,
200 dst,
201 )
202 });
203 Tensor::from_vec(out, final_shape)
204 }
205 _ => {
206 anyhow::bail!("tile function not supported for Auto dtype")
207 }
208 }
209 }
210
211 fn _tile_fill<T: crate::TensorElement>(
213 &self,
214 expanded_shape: &Shape,
215 final_shape: &Shape,
216 target_rank: usize,
217 dst: &mut [T],
218 ) {
219 let self_rank = self.rank();
220 let shp_pad = target_rank - self_rank;
221
222 let mut exp_strides = Shape::empty().with_len(target_rank);
224 for i in 0..target_rank {
225 if i < shp_pad {
226 exp_strides[i] = 0;
227 } else {
228 exp_strides[i] = if (i - shp_pad) < self.strides.len() {
229 self.strides[i - shp_pad]
230 } else {
231 0
232 };
233 }
234 }
235
236 let base = self.as_ptr() as *const T;
237 let total = final_shape.numel();
238
239 let mut idx = Shape::empty().with_len(target_rank);
240
241 for (pos, slot) in dst.iter_mut().enumerate().take(total) {
242 let mut rem = pos;
243 for i in (0..target_rank).rev() {
244 let dim = final_shape[i];
245 idx[i] = rem % dim;
246 rem /= dim;
247 }
248 let mut offset_elems = 0usize;
250 for i in 0..target_rank {
251 let dim_sz = expanded_shape[i];
252 let orig = if dim_sz == 0 { 0 } else { idx[i] % dim_sz };
253 offset_elems += orig * exp_strides[i];
254 }
255 let val = unsafe { core::ptr::read(base.add(offset_elems)) };
256 *slot = val;
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use crate::Tensor;

    /// Repeating a vector twice along its only axis concatenates it with itself.
    #[test]
    fn tile_1d_basic() {
        let t = Tensor::from_vec(vec![1i32, 2, 3], [3]).unwrap();
        let out = t.tile([2]).unwrap();
        assert_eq!(out.dims(), [6]);
        assert_eq!(out.to_flat_vec::<i32>().unwrap(), vec![1, 2, 3, 1, 2, 3]);
    }

    /// Tiling a 2x2 matrix by [2, 2] repeats it along both rows and columns.
    #[test]
    fn tile_2d_symmetric() {
        let t = Tensor::from_vec(vec![1i32, 2, 3, 4], [2, 2]).unwrap();
        let out = t.tile([2, 2]).unwrap();
        assert_eq!(out.dims(), [4, 4]);
        // Check the element mapping, not just the shape.
        assert_eq!(
            out.to_flat_vec::<i32>().unwrap(),
            vec![1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4]
        );
    }

    /// A shorter `repeats` is left-padded with 1s, so only the last axis grows.
    #[test]
    fn tile_prepended_repeats() {
        let t = Tensor::from_vec(vec![1i32, 2, 3, 4], [2, 2]).unwrap();
        let out = t.tile([3]).unwrap();
        assert_eq!(out.dims(), [2, 6]);
        assert_eq!(
            out.to_flat_vec::<i32>().unwrap(),
            vec![1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4]
        );
    }

    /// A longer `repeats` promotes the tensor by prepending unit axes.
    #[test]
    fn tile_expand_rank() {
        let t = Tensor::from_vec(vec![1i32, 2, 3], [3]).unwrap();
        let out = t.tile([2, 2]).unwrap();
        assert_eq!(out.dims(), [2, 6]);
        assert_eq!(
            out.to_flat_vec::<i32>().unwrap(),
            vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]
        );
    }
}