// zyx_core/view.rs

1extern crate alloc;
2use crate::scalar::Scalar;
3use crate::{axes::Axes, shape::Shape};
4use alloc::boxed::Box;
5use alloc::string::String;
6use alloc::{vec, vec::Vec};
7
/// View type, ordered from least to most general.
///
/// Used by the cpu backend to pick the cheapest iteration strategy
/// (see `View::view_type`).
// Derives added for consistency with the other public types in this file
// (`View` and `Index` both derive at least `Clone` and `Debug`); `Copy`,
// `PartialEq` and `Eq` are free on a fieldless enum and let callers
// compare and store the tag directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ViewType {
    /// Contiguous
    Contiguous,
    /// Permuted or expanded
    Strided,
    /// Permuted, expanded or reshaped
    Reshaped,
    /// Permuted, expanded, reshaped or padded
    Padded,
}
19
/// Compiler index
///
/// String expressions (in the target kernel's language) that compiled backends
/// splice into generated code to turn per-dimension `idx0..idx{rank}` variables
/// into a flat offset into the data buffer. Produced by [`View::cidx`].
#[derive(Clone, Debug)]
pub enum Index {
    /// Index without padding: just the offset expression.
    Normal(String),
    /// Padded index: (padding condition, offset expression).
    /// When the condition evaluates to 0 the padding value is applied,
    /// when it evaluates to 1 the value is read from data at the offset
    /// (see the contract documented on `View::cidx`).
    Padded(String, String),
}
28
/// View holds shape of the tensor and allows for arbitrary number of movement ops
/// (reshape, expand, pad, permute) to be executed as noops (without accessing the
/// actual data).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct View {
    // TODO only 2 shape and stride pairs are needed
    // Stack of inner views: index 0 is the most recent one (non-mergeable
    // reshapes are inserted at the front, see `View::reshape`), the last
    // element describes the original layout of the underlying data.
    views: Vec<InnerView>,
}
37
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct InnerView {
    /// Shape as seen through this view (includes any padding applied by `pad`).
    shape: Shape,
    /// Strides mapping an index within `shape` to the next/underlying view.
    /// NOTE(review): expanded dimensions appear to use stride 0 (see the
    /// `*s != 0` filter in `original_numel`) — confirm against `expand_strides`.
    strides: Shape,
    /// Per-dimension (left, right) padding; negative values crop the
    /// dimension instead of padding it (see `View::pad`).
    padding: Box<[(i64, i64)]>,
}
44
45impl InnerView {
46    #[must_use]
47    fn is_contiguous(&self) -> bool {
48        self.shape.strides() == self.strides && !self.is_padded()
49    }
50
51    #[must_use]
52    fn is_padded(&self) -> bool {
53        self.padding.iter().any(|(lp, rp)| *lp != 0 || *rp != 0)
54    }
55}
56
/// CPU iterator
///
/// Iterates `data` through a `View` that may be permuted, expanded,
/// reshaped and padded, yielding `T::zero()` inside padded regions.
pub struct CPUPaddedIter<'a, T> {
    /// Underlying contiguous data buffer.
    data: &'a [T],
    /// View describing how to traverse `data`.
    view: &'a View,
    /// Next logical (contiguous) element index to yield.
    idx: usize,
    /// Last valid logical index (`numel - 1`); iteration stops after it.
    num_iters: usize,
}
64
65impl<'a, T: Scalar> Iterator for CPUPaddedIter<'a, T> {
66    type Item = T;
67
68    fn next(&mut self) -> Option<Self::Item> {
69        if self.idx > self.num_iters {
70            return None;
71        }
72        let mut idx = self.idx;
73        self.idx += 1;
74        for InnerView {
75            shape,
76            strides,
77            padding,
78        } in &self.view.views
79        {
80            let mut res = 0;
81            for ((d, st), (lp, rp)) in shape.into_iter().zip(strides).zip(padding.iter()).rev() {
82                let mut dim_idx = idx % d;
83                if *lp > 0 {
84                    let lpu = *lp as usize;
85                    if dim_idx < lpu {
86                        return Some(T::zero());
87                    }
88                    dim_idx -= lpu;
89                } else if *lp < 0 {
90                    dim_idx += (-*lp) as usize;
91                }
92                if *rp > 0 {
93                    if dim_idx > *rp as usize {
94                        return Some(T::zero());
95                    }
96                }
97                res += dim_idx * st;
98                idx /= d;
99            }
100            idx = res;
101        }
102        Some(self.data[idx].clone())
103    }
104}
105
/// CPU iterator
///
/// Iterates `data` through a `View` that may be permuted, expanded and
/// reshaped, but carries no padding (padding fields are ignored).
pub struct CPUReshapedIter<'a, T> {
    /// Underlying contiguous data buffer.
    data: &'a [T],
    /// View describing how to traverse `data`.
    view: &'a View,
    /// Next logical (contiguous) element index to yield.
    idx: usize,
    /// Last valid logical index (`numel - 1`); iteration stops after it.
    num_iters: usize,
}
113
114impl<'a, T: Scalar> Iterator for CPUReshapedIter<'a, T> {
115    type Item = T;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        if self.idx > self.num_iters {
119            return None;
120        }
121        let mut idx = self.idx;
122        self.idx += 1;
123        for InnerView {
124            shape,
125            strides,
126            padding: _,
127        } in &self.view.views
128        {
129            let mut res = 0;
130            for (d, st) in shape.into_iter().zip(strides).rev() {
131                let dim_idx = idx % d;
132                res += dim_idx * st;
133                idx /= d;
134            }
135            idx = res;
136        }
137        Some(self.data[idx].clone())
138    }
139}
140
/// Strided iterator, only expand and permute
///
/// Single-view traversal: one shape/strides pair, no reshape stack and
/// no padding.
pub struct CPUStridedIter<'a, T> {
    /// Underlying contiguous data buffer.
    data: &'a [T],
    /// Logical shape to iterate over.
    shape: &'a [usize],
    /// Stride for each dimension of `shape` (0 for expanded dimensions).
    strides: &'a [usize],
    /// Next logical (contiguous) element index to yield.
    idx: usize,
    /// Last valid logical index (`numel - 1`); iteration stops after it.
    num_iters: usize,
}
149
150impl<'a, T: Scalar> Iterator for CPUStridedIter<'a, T> {
151    type Item = T;
152
153    fn next(&mut self) -> Option<Self::Item> {
154        if self.idx > self.num_iters {
155            return None;
156        }
157        let mut idx = self.idx;
158        self.idx += 1;
159        let mut res = 0;
160        for (d, st) in self
161            .shape
162            .into_iter()
163            .copied()
164            .zip(self.strides.into_iter().copied())
165            .rev()
166        {
167            res += idx % d * st;
168            idx /= d;
169        }
170        Some(self.data[res].clone())
171    }
172}
173
174impl View {
175    /// Create new view from shape
176    #[must_use]
177    pub fn new(shape: Shape) -> Self {
178        Self {
179            views: vec![InnerView {
180                strides: shape.strides(),
181                padding: core::iter::repeat((0, 0)).take(shape.rank()).collect(),
182                shape,
183            }],
184        }
185    }
186
187    /// Is this view contiguous?
188    /// i. e. no padding, expands or permutes, only reshapes are allowed
189    #[must_use]
190    pub fn is_contiguous(&self) -> bool {
191        self.views.iter().all(InnerView::is_contiguous)
192    }
193
194    /// Is this view padded?
195    #[must_use]
196    pub fn is_padded(&self) -> bool {
197        self.views.iter().any(InnerView::is_padded)
198    }
199
200    /// For cpu backend
201    #[must_use]
202    pub fn view_type(&self) -> ViewType {
203        if self.is_contiguous() {
204            ViewType::Contiguous
205        } else if self.is_padded() {
206            ViewType::Padded
207        } else if self.views.len() > 1 {
208            ViewType::Reshaped
209        } else {
210            ViewType::Strided
211        }
212    }
213
214    /// Simple iteration
215    #[must_use]
216    pub fn iterate_contiguous<'a, T: Scalar>(
217        &'a self,
218        data: &'a [T],
219    ) -> impl Iterator<Item = T> + 'a {
220        data.iter().cloned()
221    }
222
223    /// Iteration with expands and permutes
224    #[must_use]
225    pub fn iterate_strided<'a, T: Scalar>(&'a self, data: &'a [T]) -> impl Iterator<Item = T> + 'a {
226        let InnerView {
227            shape,
228            strides,
229            padding: _,
230        } = self.views.first().unwrap();
231        CPUStridedIter {
232            data,
233            num_iters: shape.numel() - 1,
234            shape: shape.as_ref(),
235            strides: strides.as_ref(),
236            idx: 0,
237        }
238    }
239
240    /// Iteration with expands, permutes and reshapes, but without padding
241    #[must_use]
242    pub fn iterate_reshaped<'a, T: Scalar>(
243        &'a self,
244        data: &'a [T],
245    ) -> impl Iterator<Item = T> + 'a {
246        CPUReshapedIter {
247            data,
248            view: self,
249            idx: 0,
250            num_iters: self.numel() - 1,
251        }
252    }
253
254    /// Iteration with expands, permutes, reshapes and padding
255    #[must_use]
256    pub fn iterate_padded<'a, T: Scalar>(&'a self, data: &'a [T]) -> impl Iterator<Item = T> + 'a {
257        CPUPaddedIter {
258            data,
259            view: self,
260            idx: 0,
261            num_iters: self.numel() - 1,
262        }
263    }
264
    /// Access data called name with idx0-idx{rank} converted into self view.
    /// This is used by compiled backends.
    /// Returns padding condition and index.
    /// If padding condition == 0, padding value is applied, if padding condition
    /// is one, value is drawn from data.
    #[must_use]
    pub fn cidx(&self) -> Index {
        // TODO simplify this as much as possible, not for performance (it is cached),
        // just for clarity, because currently it is a mess.
        //std::println!("View: {self:?}");
        use alloc::format as f;
        // `idx` accumulates the flat-offset expression; `padding_condition`
        // accumulates " && (…)" clauses (the leading " && " is stripped at the end).
        let mut idx = String::new();
        let mut padding_condition = String::new();
        // Fast path: contiguous view needs no padding clauses, just
        // "idx0*st0+idx1*st1+…" with stride-1 dims emitted bare.
        if self.is_contiguous() {
            let numel = self.numel();
            for (i, st) in self.views[0].strides.iter().enumerate() {
                if *st == 1 {
                    idx += &f!("+idx{i}");
                } else if *st != numel {
                    // NOTE(review): a stride equal to numel is skipped —
                    // presumably marks a dimension with no contribution; confirm.
                    idx += &f!("+idx{i}*{st}");
                }
            }
            // Drop the leading '+'. NOTE(review): panics if no term was
            // emitted (all strides == numel) — verify this cannot happen.
            idx.remove(0);
            return Index::Normal(idx);
        }
        // General path, phase 1: emit terms for the most recent inner view,
        // using the per-dimension `idx{i}` variables directly.
        if let Some(InnerView {
            shape,
            strides,
            padding,
        }) = self.views.first()
        {
            for (i, ((d, st), (left_p, right_p))) in shape
                .iter()
                .zip(strides.iter())
                .zip(padding.iter())
                .enumerate()
            {
                //std::println!("i: {i}, d: {d}, st: {st}, lp: {left_p}, rp: {right_p}");
                // Stride 0 (expanded dim) contributes nothing.
                match *st {
                    0 => idx += "",
                    1 => idx += &f!("idx{i}+"),
                    _ => idx += &f!("idx{i}*{st}+"),
                }
                if *left_p < 0 {
                    // Negative left padding crops: add a constant offset.
                    idx += &f!("{}+", (-left_p) as usize * st);
                } else if *left_p > 0 {
                    // Inside left padding when idx{i} <= left_p - 1.
                    padding_condition = f!("{padding_condition} && (idx{i}>{})", left_p - 1);
                }
                if *right_p > 0 {
                    // Inside right padding when idx{i} >= d - right_p.
                    padding_condition =
                        f!("{padding_condition} && (idx{i}<{})", d - *right_p as usize);
                }
                if *left_p > 0 {
                    // Shift the offset back so data index 0 lines up after padding.
                    idx += &f!("-{}+", *left_p as usize * st);
                }
            }
            if idx.is_empty() {
                idx = f!("0+");
            }
        } else {
            return Index::Normal("0".into());
        }
        // Drop the trailing '+'.
        idx.remove(idx.len() - 1);
        if self.views.len() == 1 {
            if padding_condition.is_empty() {
                return Index::Normal(idx);
            } else {
                // Strip the leading " && " from the accumulated condition.
                padding_condition = f!("{}", &padding_condition[4..]);
                return Index::Padded(padding_condition, idx);
            }
        }
        // Phase 2: fold the expression through the remaining (older) views.
        // Each pass rewrites `idx` from "flat index into previous view" into
        // "flat offset into this view" via textual mod/div arithmetic
        // (the string analogue of the CPU iterators' coordinate peeling).
        for InnerView {
            shape,
            strides,
            padding,
        } in &self.views[1..]
        {
            let n = shape.numel();
            // Parenthesize the whole previous expression; it is substituted
            // into every per-dimension term below.
            idx.insert(0, '(');
            idx.push(')');
            let mut res = String::new();
            // `ost` is the running product of dimension sizes already peeled
            // (the divisor for the next coordinate), innermost first.
            let mut ost = 1;
            for ((d, st), (left_p, right_p)) in
                shape.into_iter().zip(strides).zip(padding.iter()).rev()
            {
                //println!("d: {d}, st: {st}, lp: {left_p}, rp: {right_p}");
                //res += &f!("{idx}/{ost}%{d}*{st}+");
                //ost *= d;
                // temp := (idx) / ost % d, with the trivial cases elided.
                let mut temp = f!("{idx}");
                match ost {
                    0 => panic!(),
                    1 => {}
                    _ => temp += &f!("/{ost}"),
                }
                ost *= d;
                match *d {
                    0 => panic!(),
                    1 => temp = f!("0"),
                    _ => {
                        // The outermost coordinate needs no "% d".
                        if ost < n {
                            temp += &f!("%{d}");
                        }
                    }
                }
                // Same padding handling as phase 1, but on the derived
                // coordinate expression instead of a plain idx{i} variable.
                if *left_p < 0 {
                    temp = f!("{temp}+{}", -left_p);
                } else if *left_p > 0 {
                    padding_condition = f!("{padding_condition} && ({temp}>{})", left_p - 1);
                }
                if *right_p > 0 {
                    padding_condition =
                        f!("{padding_condition} && ({temp}<{})", d - *right_p as usize);
                }
                if *left_p > 0 {
                    temp = f!("({temp}-{left_p})");
                }
                match *st {
                    0 => temp = f!("0"),
                    1 => {}
                    _ => temp += &f!("*{st}"),
                }
                res += &f!("{temp}+");
            }
            idx = res;
            // Drop the trailing '+'.
            if !idx.is_empty() {
                idx.remove(idx.len() - 1);
            }
        }
        if padding_condition.is_empty() {
            Index::Normal(idx)
        } else {
            // Strip the leading " && " from the accumulated condition.
            padding_condition = f!("{}", &padding_condition[4..]);
            Index::Padded(padding_condition, idx)
        }
    }
400
401    /// Number of elements in view with self.shape()
402    #[must_use]
403    pub fn numel(&self) -> usize {
404        self.shape().numel()
405    }
406
407    /// Last shape of self.
408    #[must_use]
409    pub fn shape(&self) -> &Shape {
410        &self.views.first().unwrap().shape
411    }
412
413    /// Last strides of self.
414    #[must_use]
415    pub fn strides(&self) -> &Shape {
416        &self.views.first().unwrap().strides
417    }
418
419    /// Original (first) shape of self.
420    #[must_use]
421    pub fn original_shape(&self) -> &Shape {
422        &self.views.last().unwrap().shape
423    }
424
425    /// Original number of elements of self.
426    #[must_use]
427    pub fn original_numel(&self) -> usize {
428        let InnerView {
429            shape,
430            strides,
431            padding,
432        } = self.views.last().unwrap();
433        shape
434            .iter()
435            .zip(strides.iter())
436            .zip(padding.iter())
437            .filter_map(|((d, s), (lp, rp))| {
438                if *s != 0 {
439                    Some((*d as i64 - lp - rp) as usize)
440                } else {
441                    None
442                }
443            })
444            .product()
445    }
446
    /// Expand self into shape
    ///
    /// Only the most recent (first) inner view is rewritten: its strides are
    /// recomputed by `Shape::expand_strides` for the broadcast shape, and its
    /// padding is extended at the front with (0, 0) entries so its length
    /// matches the new rank.
    #[must_use]
    pub fn expand(&self, shape: &Shape) -> Self {
        let mut views = self.views.clone();
        //std::println!("Expanding {views:?}");
        views[0].strides = views[0]
            .shape
            .expand_strides(shape, views[0].strides.clone());
        views[0].shape = shape.clone();
        // NOTE(review): assumes shape.rank() >= current padding length, i.e.
        // expand never reduces rank — underflows otherwise; verify callers.
        let n = shape.rank() - views[0].padding.len();
        views[0].padding = core::iter::repeat((0, 0))
            .take(n)
            .chain(views[0].padding.iter().copied())
            .collect();
        //std::println!("To {views:?}");
        Self { views }
    }
464
    /// Pad self by padding
    ///
    /// `new_padding` holds (left, right) pairs given for the LAST dimensions
    /// first (reverse dimension order); missing entries leave a dimension
    /// untouched. Negative values crop instead of pad. Dimension sizes of the
    /// most recent inner view are adjusted and the new padding is summed
    /// element-wise with any padding already present.
    #[must_use]
    pub fn pad(&self, new_padding: &[(i64, i64)]) -> Self {
        //std::println!("{:?}\n{new_padding:?}", self);
        let mut views = self.views.clone();
        if let Some(InnerView {
            shape,
            strides: _,
            padding,
        }) = views.first_mut()
        {
            // Invert padding order
            // Walk dims from the back so index i lines up with new_padding[i].
            for (i, d) in shape.iter_mut().rev().enumerate() {
                if let Some((left, right)) = new_padding.get(i) {
                    // Grow (or shrink, for negative padding) the dimension.
                    *d = (*d as i64 + left + right) as usize;
                } else {
                    break;
                }
            }
            // NOTE(review): underflows if new_padding has more entries than
            // the view has dimensions — verify callers guarantee this.
            let n = padding.len() - new_padding.len();
            // Front dims keep (0, 0); trailing dims accumulate old + new padding.
            *padding = core::iter::repeat(&(0, 0))
                .take(n)
                .chain(new_padding.iter().rev())
                .zip(padding.iter())
                .map(|(x, y)| (x.0 + y.0, x.1 + y.1))
                .collect();
            //std::println!("new_padding: {:?}", padding);
        }
        Self { views }
    }
495
    /// Reshape self into shape
    ///
    /// Three cases:
    /// 1. the most recent inner view is contiguous — it is simply replaced
    ///    by a fresh contiguous view of `n_shape`;
    /// 2. the reshape only inserts size-1 dimensions (an unsqueeze) — the
    ///    existing inner view is rewritten in place;
    /// 3. otherwise a new contiguous inner view of `n_shape` is pushed onto
    ///    the front of the view stack.
    ///
    /// # Panics
    /// Debug builds assert that `n_shape.numel() == self.numel()`.
    #[must_use]
    pub fn reshape(&self, n_shape: &Shape) -> Self {
        //std::println!("Reshaping {self:?} into {n_shape}");
        if n_shape == self.shape() {
            return self.clone();
        }
        debug_assert_eq!(
            n_shape.numel(),
            self.numel(),
            "Can't reshape {} to {}",
            self.shape(),
            n_shape
        );
        let mut views = self.views.clone();
        // If we are reshaping InnerView that is contiguous, we just delete the last reshape
        if views.first().unwrap().is_contiguous() {
            views[0] = InnerView {
                shape: n_shape.clone(),
                strides: n_shape.strides(),
                padding: core::iter::repeat((0, 0)).take(n_shape.rank()).collect(),
            };
        } else {
            let shape = self.shape();
            // Unsqueeze check: after dropping its size-1 dims, n_shape must
            // match the current shape dimension by dimension.
            if n_shape.rank() > shape.rank()
                && n_shape
                    .iter()
                    .filter(|d| **d != 1)
                    .zip(shape.iter())
                    .all(|(nd, d)| nd == d)
            {
                // If not  contiguous, then merge, this merges if reshape is unsqueeze
                //std::println!("Ok to merge {n_shape} with {}", self.shape());
                if let Some(InnerView {
                    shape,
                    strides,
                    padding,
                }) = views.first_mut()
                {
                    //std::println!("Merging");
                    *shape = n_shape.clone();
                    let mut n_strides: Vec<usize> = strides.clone().into();
                    let mut n_padding = padding.to_vec();
                    // For each inserted size-1 dim (walking from the back),
                    // splice in a stride and a (0, 0) padding entry so the
                    // remaining dims keep their original values.
                    for (i, d) in n_shape.iter().rev().enumerate() {
                        if *d == 1 {
                            //std::println!("Inserting");
                            // NOTE(review): the inserted stride copies its right
                            // neighbour (or 1 at the innermost position) — for a
                            // size-1 dim any stride works; confirm intent.
                            n_strides.insert(
                                n_strides.len() - i,
                                if i == 0 {
                                    1
                                } else {
                                    n_strides[n_strides.len() - i]
                                },
                            );
                            n_padding.insert(n_padding.len() - i, (0, 0));
                        }
                    }
                    //std::println!("n_strides: {n_strides:?}, n_padding: {n_padding:?}");
                    *strides = n_strides.into();
                    *padding = n_padding.into_boxed_slice();
                }
            } else {
                // If there is no merge.
                views.insert(
                    0,
                    InnerView {
                        shape: n_shape.clone(),
                        strides: n_shape.strides(),
                        padding: core::iter::repeat((0, 0)).take(n_shape.rank()).collect(),
                    },
                );
            }
        }
        //std::println!("Merged into: {:?}", views);
        Self { views }
    }
572
573    /// Permute self by axes
574    #[must_use]
575    pub fn permute(&self, axes: &Axes) -> Self {
576        //std::println!("{:?}\n{:?}", self, axes);
577        let mut views = self.views.clone();
578        views[0].shape = views[0].shape.permute(axes);
579        views[0].strides = views[0].strides.permute(axes);
580        let padding = &views[0].padding;
581        let padding = axes.iter().map(|axis| padding[*axis]).collect();
582        views[0].padding = padding;
583        Self { views }
584    }
585}
586
587/*#[test]
588fn view() {
589    use crate::axes::IntoAxes;
590
591    let s0 = View::new([10, 15].into());
592    let s5 = s0.permute(&[1, 0].into_axes(2)).reshape(&[10, 15].into());
593    /*let s0 = View::new(Shape::from([4, 5, 2]));
594    let s1 = s0.permute(&[2, 0, 1].into_axes(3));
595    let s2 = s1.reshape(&[4, 1, 5, 2, 1].into());
596    let s3 = s2.expand(&[4, 3, 5, 2, 2].into());
597    let s4 = s3.permute(&[3, 0, 4, 2, 1].into_axes(5));
598    let s5 = s4.reshape(&[12, 20].into());*/
599    for InnerView { shape, strides, padding } in s5.views {
600        std::println!("{shape:?}, {strides:?}, {padding:?}");
601    }
602    panic!();
603}*/