tract_core/ops/scan/mod.rs

1use crate::internal::*;
2use std::fmt;
3
4mod decluttered;
5mod optimized;
6
7pub use optimized::{OptScan, State};
8pub use decluttered::Scan;
9
10#[derive(Clone, new, Hash, Eq, PartialEq, Copy, Debug)]
11pub struct ScanInfo {
12    pub axis: usize,
13    pub chunk: isize,
14}
15
16#[derive(Clone, new, Hash, Debug)]
17pub enum InputMapping {
18    Full,
19    State,
20    Scan(ScanInfo),
21}
22
23impl InputMapping {
24    pub fn is_state(&self) -> bool {
25        matches!(self, InputMapping::State)
26    }
27
28    pub fn is_scan(&self) -> bool {
29        self.as_scan().is_some()
30    }
31
32    pub fn as_scan(&self) -> Option<&ScanInfo> {
33        match self {
34            InputMapping::Scan(s) => Some(s),
35            _ => None,
36        }
37    }
38}
39
40#[derive(Clone, new, Hash, Default)]
41pub struct OutputMapping<F: Clone> {
42    pub scan: Option<(usize, ScanInfo)>,
43    pub full_dim_hint: Option<F>,
44    pub last_value_slot: Option<usize>,
45    pub state: bool,
46}
47
48impl<F: Clone> OutputMapping<F> {
49    pub fn invisible(&self) -> bool {
50        self.scan.is_none() && self.last_value_slot.is_none()
51    }
52}
53
54impl<F: Clone + DimLike> OutputMapping<F> {
55    pub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<OutputMapping<F>> {
56        Ok(Self {
57            full_dim_hint: self.full_dim_hint.as_ref().map(|h| h.eval(values)),
58            ..self.clone()
59        })
60    }
61}
62
63impl<F: Clone + fmt::Display> fmt::Debug for OutputMapping<F> {
64    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
65        if self.state {
66            write!(fmt, "State. ")?;
67        }
68        if let Some(last_value_slot) = self.last_value_slot {
69            write!(fmt, "Last value to outlet {last_value_slot}. ")?;
70        }
71        if let Some((slot, info)) = self.scan {
72            write!(fmt, "Full value to outlet {} (axis: {}). ", slot, info.axis)?;
73        }
74        if let Some(full_dim_hint) = &self.full_dim_hint {
75            write!(fmt, "Full len {full_dim_hint}. ")?;
76        }
77        Ok(())
78    }
79}
80
81pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> Option<TDim> {
82    let (slot, info) = input_mapping
83        .iter()
84        .enumerate()
85        .find_map(|(slot, im)| im.as_scan().map(|scan| (slot, scan)))?;
86    let outside_dim = inputs[slot].shape[info.axis].clone();
87    Some(outside_dim.div_ceil(info.chunk.unsigned_abs() as u64))
88}