Skip to main content

tract_core/ops/scan/
mod.rs

1use crate::internal::*;
2use std::fmt;
3
4mod decluttered;
5mod optimized;
6
7pub use decluttered::Scan;
8pub use optimized::{OptScan, State};
9
10#[derive(Clone, new, Hash, Eq, PartialEq, Copy, Debug)]
11pub struct ScanInfo {
12    pub axis: usize,
13    pub chunk: isize,
14}
15
16#[derive(Clone, new, Hash, Debug, PartialEq, Eq)]
17pub enum InputMapping {
18    Full,
19    State,
20    Scan(ScanInfo),
21}
22
23impl InputMapping {
24    pub fn is_state(&self) -> bool {
25        matches!(self, InputMapping::State)
26    }
27
28    pub fn is_scan(&self) -> bool {
29        self.as_scan().is_some()
30    }
31
32    pub fn as_scan(&self) -> Option<&ScanInfo> {
33        match self {
34            InputMapping::Scan(s) => Some(s),
35            _ => None,
36        }
37    }
38}
39
40#[derive(Clone, new, Hash, Default, PartialEq, Eq)]
41pub struct OutputMapping<F: Clone> {
42    pub scan: Option<(usize, ScanInfo)>,
43    pub full_dim_hint: Option<F>,
44    pub last_value_slot: Option<usize>,
45    pub state: bool,
46}
47
48impl<F: Clone> OutputMapping<F> {
49    pub fn invisible(&self) -> bool {
50        self.scan.is_none() && self.last_value_slot.is_none()
51    }
52}
53
54impl<F: Clone + DimLike> OutputMapping<F> {
55    pub fn set_symbols(
56        &self,
57        subs: &std::collections::HashMap<Symbol, F>,
58    ) -> TractResult<OutputMapping<F>> {
59        Ok(Self {
60            full_dim_hint: self
61                .full_dim_hint
62                .as_ref()
63                .map(|h| h.substitute_all(subs))
64                .transpose()?,
65            ..self.clone()
66        })
67    }
68}
69
70impl<F: Clone + fmt::Display> fmt::Debug for OutputMapping<F> {
71    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
72        if self.state {
73            write!(fmt, "State. ")?;
74        }
75        if let Some(last_value_slot) = self.last_value_slot {
76            write!(fmt, "Last value to outlet {last_value_slot}. ")?;
77        }
78        if let Some((slot, info)) = self.scan {
79            write!(fmt, "Full value to outlet {} (axis: {}). ", slot, info.axis)?;
80        }
81        if let Some(full_dim_hint) = &self.full_dim_hint {
82            write!(fmt, "Full len {full_dim_hint}. ")?;
83        }
84        Ok(())
85    }
86}
87
88pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> Option<TDim> {
89    let (slot, info) = input_mapping
90        .iter()
91        .enumerate()
92        .find_map(|(slot, im)| im.as_scan().map(|scan| (slot, scan)))?;
93    let outside_dim = inputs[slot].shape[info.axis].clone();
94    Some(outside_dim.div_ceil(info.chunk.unsigned_abs() as u64))
95}