1extern crate alloc;
2use crate::scalar::Scalar;
3use crate::{axes::Axes, shape::Shape};
4use alloc::boxed::Box;
5use alloc::string::String;
6use alloc::{vec, vec::Vec};
7
/// Classification of a [`View`], used to pick the matching `View::iterate_*`
/// method.
///
/// `View::view_type` checks the variants in priority order: contiguous, then
/// padded, then reshaped, otherwise strided.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ViewType {
    /// Every stacked view uses the default strides of its shape and has no
    /// padding, so the buffer can be read directly.
    Contiguous,
    /// A single unpadded view whose strides differ from the default strides
    /// of its shape (e.g. after `permute` or `expand`).
    Strided,
    /// More than one stacked view (a reshape of a non-contiguous view),
    /// without padding.
    Reshaped,
    /// At least one stacked view carries non-zero padding.
    Padded,
}
19
/// An index expression generated by `View::cidx`.
///
/// Expressions refer to per-dimension variables named `idx0`, `idx1`, ….
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Index {
    /// Index expression into the underlying buffer; always valid.
    Normal(String),
    /// `(condition, index)`: `index` is only valid when `condition` holds.
    /// When it does not hold, the position lies in padding.
    Padded(String, String),
}
28
29#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
33pub struct View {
34 views: Vec<InnerView>,
36}
37
38#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
39struct InnerView {
40 shape: Shape,
41 strides: Shape,
42 padding: Box<[(i64, i64)]>,
43}
44
45impl InnerView {
46 #[must_use]
47 fn is_contiguous(&self) -> bool {
48 self.shape.strides() == self.strides && !self.is_padded()
49 }
50
51 #[must_use]
52 fn is_padded(&self) -> bool {
53 self.padding.iter().any(|(lp, rp)| *lp != 0 || *rp != 0)
54 }
55}
56
57pub struct CPUPaddedIter<'a, T> {
59 data: &'a [T],
60 view: &'a View,
61 idx: usize,
62 num_iters: usize,
63}
64
65impl<'a, T: Scalar> Iterator for CPUPaddedIter<'a, T> {
66 type Item = T;
67
68 fn next(&mut self) -> Option<Self::Item> {
69 if self.idx > self.num_iters {
70 return None;
71 }
72 let mut idx = self.idx;
73 self.idx += 1;
74 for InnerView {
75 shape,
76 strides,
77 padding,
78 } in &self.view.views
79 {
80 let mut res = 0;
81 for ((d, st), (lp, rp)) in shape.into_iter().zip(strides).zip(padding.iter()).rev() {
82 let mut dim_idx = idx % d;
83 if *lp > 0 {
84 let lpu = *lp as usize;
85 if dim_idx < lpu {
86 return Some(T::zero());
87 }
88 dim_idx -= lpu;
89 } else if *lp < 0 {
90 dim_idx += (-*lp) as usize;
91 }
92 if *rp > 0 {
93 if dim_idx > *rp as usize {
94 return Some(T::zero());
95 }
96 }
97 res += dim_idx * st;
98 idx /= d;
99 }
100 idx = res;
101 }
102 Some(self.data[idx].clone())
103 }
104}
105
106pub struct CPUReshapedIter<'a, T> {
108 data: &'a [T],
109 view: &'a View,
110 idx: usize,
111 num_iters: usize,
112}
113
114impl<'a, T: Scalar> Iterator for CPUReshapedIter<'a, T> {
115 type Item = T;
116
117 fn next(&mut self) -> Option<Self::Item> {
118 if self.idx > self.num_iters {
119 return None;
120 }
121 let mut idx = self.idx;
122 self.idx += 1;
123 for InnerView {
124 shape,
125 strides,
126 padding: _,
127 } in &self.view.views
128 {
129 let mut res = 0;
130 for (d, st) in shape.into_iter().zip(strides).rev() {
131 let dim_idx = idx % d;
132 res += dim_idx * st;
133 idx /= d;
134 }
135 idx = res;
136 }
137 Some(self.data[idx].clone())
138 }
139}
140
141pub struct CPUStridedIter<'a, T> {
143 data: &'a [T],
144 shape: &'a [usize],
145 strides: &'a [usize],
146 idx: usize,
147 num_iters: usize,
148}
149
150impl<'a, T: Scalar> Iterator for CPUStridedIter<'a, T> {
151 type Item = T;
152
153 fn next(&mut self) -> Option<Self::Item> {
154 if self.idx > self.num_iters {
155 return None;
156 }
157 let mut idx = self.idx;
158 self.idx += 1;
159 let mut res = 0;
160 for (d, st) in self
161 .shape
162 .into_iter()
163 .copied()
164 .zip(self.strides.into_iter().copied())
165 .rev()
166 {
167 res += idx % d * st;
168 idx /= d;
169 }
170 Some(self.data[res].clone())
171 }
172}
173
174impl View {
175 #[must_use]
177 pub fn new(shape: Shape) -> Self {
178 Self {
179 views: vec![InnerView {
180 strides: shape.strides(),
181 padding: core::iter::repeat((0, 0)).take(shape.rank()).collect(),
182 shape,
183 }],
184 }
185 }
186
187 #[must_use]
190 pub fn is_contiguous(&self) -> bool {
191 self.views.iter().all(InnerView::is_contiguous)
192 }
193
194 #[must_use]
196 pub fn is_padded(&self) -> bool {
197 self.views.iter().any(InnerView::is_padded)
198 }
199
200 #[must_use]
202 pub fn view_type(&self) -> ViewType {
203 if self.is_contiguous() {
204 ViewType::Contiguous
205 } else if self.is_padded() {
206 ViewType::Padded
207 } else if self.views.len() > 1 {
208 ViewType::Reshaped
209 } else {
210 ViewType::Strided
211 }
212 }
213
214 #[must_use]
216 pub fn iterate_contiguous<'a, T: Scalar>(
217 &'a self,
218 data: &'a [T],
219 ) -> impl Iterator<Item = T> + 'a {
220 data.iter().cloned()
221 }
222
223 #[must_use]
225 pub fn iterate_strided<'a, T: Scalar>(&'a self, data: &'a [T]) -> impl Iterator<Item = T> + 'a {
226 let InnerView {
227 shape,
228 strides,
229 padding: _,
230 } = self.views.first().unwrap();
231 CPUStridedIter {
232 data,
233 num_iters: shape.numel() - 1,
234 shape: shape.as_ref(),
235 strides: strides.as_ref(),
236 idx: 0,
237 }
238 }
239
240 #[must_use]
242 pub fn iterate_reshaped<'a, T: Scalar>(
243 &'a self,
244 data: &'a [T],
245 ) -> impl Iterator<Item = T> + 'a {
246 CPUReshapedIter {
247 data,
248 view: self,
249 idx: 0,
250 num_iters: self.numel() - 1,
251 }
252 }
253
254 #[must_use]
256 pub fn iterate_padded<'a, T: Scalar>(&'a self, data: &'a [T]) -> impl Iterator<Item = T> + 'a {
257 CPUPaddedIter {
258 data,
259 view: self,
260 idx: 0,
261 num_iters: self.numel() - 1,
262 }
263 }
264
265 #[must_use]
271 pub fn cidx(&self) -> Index {
272 use alloc::format as f;
276 let mut idx = String::new();
277 let mut padding_condition = String::new();
278 if self.is_contiguous() {
279 let numel = self.numel();
280 for (i, st) in self.views[0].strides.iter().enumerate() {
281 if *st == 1 {
282 idx += &f!("+idx{i}");
283 } else if *st != numel {
284 idx += &f!("+idx{i}*{st}");
285 }
286 }
287 idx.remove(0);
288 return Index::Normal(idx);
289 }
290 if let Some(InnerView {
291 shape,
292 strides,
293 padding,
294 }) = self.views.first()
295 {
296 for (i, ((d, st), (left_p, right_p))) in shape
297 .iter()
298 .zip(strides.iter())
299 .zip(padding.iter())
300 .enumerate()
301 {
302 match *st {
304 0 => idx += "",
305 1 => idx += &f!("idx{i}+"),
306 _ => idx += &f!("idx{i}*{st}+"),
307 }
308 if *left_p < 0 {
309 idx += &f!("{}+", (-left_p) as usize * st);
310 } else if *left_p > 0 {
311 padding_condition = f!("{padding_condition} && (idx{i}>{})", left_p - 1);
312 }
313 if *right_p > 0 {
314 padding_condition =
315 f!("{padding_condition} && (idx{i}<{})", d - *right_p as usize);
316 }
317 if *left_p > 0 {
318 idx += &f!("-{}+", *left_p as usize * st);
319 }
320 }
321 if idx.is_empty() {
322 idx = f!("0+");
323 }
324 } else {
325 return Index::Normal("0".into());
326 }
327 idx.remove(idx.len() - 1);
328 if self.views.len() == 1 {
329 if padding_condition.is_empty() {
330 return Index::Normal(idx);
331 } else {
332 padding_condition = f!("{}", &padding_condition[4..]);
333 return Index::Padded(padding_condition, idx);
334 }
335 }
336 for InnerView {
337 shape,
338 strides,
339 padding,
340 } in &self.views[1..]
341 {
342 let n = shape.numel();
343 idx.insert(0, '(');
344 idx.push(')');
345 let mut res = String::new();
346 let mut ost = 1;
347 for ((d, st), (left_p, right_p)) in
348 shape.into_iter().zip(strides).zip(padding.iter()).rev()
349 {
350 let mut temp = f!("{idx}");
354 match ost {
355 0 => panic!(),
356 1 => {}
357 _ => temp += &f!("/{ost}"),
358 }
359 ost *= d;
360 match *d {
361 0 => panic!(),
362 1 => temp = f!("0"),
363 _ => {
364 if ost < n {
365 temp += &f!("%{d}");
366 }
367 }
368 }
369 if *left_p < 0 {
370 temp = f!("{temp}+{}", -left_p);
371 } else if *left_p > 0 {
372 padding_condition = f!("{padding_condition} && ({temp}>{})", left_p - 1);
373 }
374 if *right_p > 0 {
375 padding_condition =
376 f!("{padding_condition} && ({temp}<{})", d - *right_p as usize);
377 }
378 if *left_p > 0 {
379 temp = f!("({temp}-{left_p})");
380 }
381 match *st {
382 0 => temp = f!("0"),
383 1 => {}
384 _ => temp += &f!("*{st}"),
385 }
386 res += &f!("{temp}+");
387 }
388 idx = res;
389 if !idx.is_empty() {
390 idx.remove(idx.len() - 1);
391 }
392 }
393 if padding_condition.is_empty() {
394 Index::Normal(idx)
395 } else {
396 padding_condition = f!("{}", &padding_condition[4..]);
397 Index::Padded(padding_condition, idx)
398 }
399 }
400
401 #[must_use]
403 pub fn numel(&self) -> usize {
404 self.shape().numel()
405 }
406
407 #[must_use]
409 pub fn shape(&self) -> &Shape {
410 &self.views.first().unwrap().shape
411 }
412
413 #[must_use]
415 pub fn strides(&self) -> &Shape {
416 &self.views.first().unwrap().strides
417 }
418
419 #[must_use]
421 pub fn original_shape(&self) -> &Shape {
422 &self.views.last().unwrap().shape
423 }
424
425 #[must_use]
427 pub fn original_numel(&self) -> usize {
428 let InnerView {
429 shape,
430 strides,
431 padding,
432 } = self.views.last().unwrap();
433 shape
434 .iter()
435 .zip(strides.iter())
436 .zip(padding.iter())
437 .filter_map(|((d, s), (lp, rp))| {
438 if *s != 0 {
439 Some((*d as i64 - lp - rp) as usize)
440 } else {
441 None
442 }
443 })
444 .product()
445 }
446
447 #[must_use]
449 pub fn expand(&self, shape: &Shape) -> Self {
450 let mut views = self.views.clone();
451 views[0].strides = views[0]
453 .shape
454 .expand_strides(shape, views[0].strides.clone());
455 views[0].shape = shape.clone();
456 let n = shape.rank() - views[0].padding.len();
457 views[0].padding = core::iter::repeat((0, 0))
458 .take(n)
459 .chain(views[0].padding.iter().copied())
460 .collect();
461 Self { views }
463 }
464
465 #[must_use]
467 pub fn pad(&self, new_padding: &[(i64, i64)]) -> Self {
468 let mut views = self.views.clone();
470 if let Some(InnerView {
471 shape,
472 strides: _,
473 padding,
474 }) = views.first_mut()
475 {
476 for (i, d) in shape.iter_mut().rev().enumerate() {
478 if let Some((left, right)) = new_padding.get(i) {
479 *d = (*d as i64 + left + right) as usize;
480 } else {
481 break;
482 }
483 }
484 let n = padding.len() - new_padding.len();
485 *padding = core::iter::repeat(&(0, 0))
486 .take(n)
487 .chain(new_padding.iter().rev())
488 .zip(padding.iter())
489 .map(|(x, y)| (x.0 + y.0, x.1 + y.1))
490 .collect();
491 }
493 Self { views }
494 }
495
496 #[must_use]
498 pub fn reshape(&self, n_shape: &Shape) -> Self {
499 if n_shape == self.shape() {
501 return self.clone();
502 }
503 debug_assert_eq!(
504 n_shape.numel(),
505 self.numel(),
506 "Can't reshape {} to {}",
507 self.shape(),
508 n_shape
509 );
510 let mut views = self.views.clone();
511 if views.first().unwrap().is_contiguous() {
513 views[0] = InnerView {
514 shape: n_shape.clone(),
515 strides: n_shape.strides(),
516 padding: core::iter::repeat((0, 0)).take(n_shape.rank()).collect(),
517 };
518 } else {
519 let shape = self.shape();
520 if n_shape.rank() > shape.rank()
521 && n_shape
522 .iter()
523 .filter(|d| **d != 1)
524 .zip(shape.iter())
525 .all(|(nd, d)| nd == d)
526 {
527 if let Some(InnerView {
530 shape,
531 strides,
532 padding,
533 }) = views.first_mut()
534 {
535 *shape = n_shape.clone();
537 let mut n_strides: Vec<usize> = strides.clone().into();
538 let mut n_padding = padding.to_vec();
539 for (i, d) in n_shape.iter().rev().enumerate() {
540 if *d == 1 {
541 n_strides.insert(
543 n_strides.len() - i,
544 if i == 0 {
545 1
546 } else {
547 n_strides[n_strides.len() - i]
548 },
549 );
550 n_padding.insert(n_padding.len() - i, (0, 0));
551 }
552 }
553 *strides = n_strides.into();
555 *padding = n_padding.into_boxed_slice();
556 }
557 } else {
558 views.insert(
560 0,
561 InnerView {
562 shape: n_shape.clone(),
563 strides: n_shape.strides(),
564 padding: core::iter::repeat((0, 0)).take(n_shape.rank()).collect(),
565 },
566 );
567 }
568 }
569 Self { views }
571 }
572
573 #[must_use]
575 pub fn permute(&self, axes: &Axes) -> Self {
576 let mut views = self.views.clone();
578 views[0].shape = views[0].shape.permute(axes);
579 views[0].strides = views[0].strides.permute(axes);
580 let padding = &views[0].padding;
581 let padding = axes.iter().map(|axis| padding[*axis]).collect();
582 views[0].padding = padding;
583 Self { views }
584 }
585}
586
587