tract_core/ops/scan/
mod.rs1use crate::internal::*;
2use std::fmt;
3
4mod decluttered;
5mod optimized;
6
7pub use decluttered::Scan;
8pub use optimized::{OptScan, State};
9
10#[derive(Clone, new, Hash, Eq, PartialEq, Copy, Debug)]
11pub struct ScanInfo {
12 pub axis: usize,
13 pub chunk: isize,
14}
15
16#[derive(Clone, new, Hash, Debug, PartialEq, Eq)]
17pub enum InputMapping {
18 Full,
19 State,
20 Scan(ScanInfo),
21}
22
23impl InputMapping {
24 pub fn is_state(&self) -> bool {
25 matches!(self, InputMapping::State)
26 }
27
28 pub fn is_scan(&self) -> bool {
29 self.as_scan().is_some()
30 }
31
32 pub fn as_scan(&self) -> Option<&ScanInfo> {
33 match self {
34 InputMapping::Scan(s) => Some(s),
35 _ => None,
36 }
37 }
38}
39
40#[derive(Clone, new, Hash, Default, PartialEq, Eq)]
41pub struct OutputMapping<F: Clone> {
42 pub scan: Option<(usize, ScanInfo)>,
43 pub full_dim_hint: Option<F>,
44 pub last_value_slot: Option<usize>,
45 pub state: bool,
46}
47
48impl<F: Clone> OutputMapping<F> {
49 pub fn invisible(&self) -> bool {
50 self.scan.is_none() && self.last_value_slot.is_none()
51 }
52}
53
54impl<F: Clone + DimLike> OutputMapping<F> {
55 pub fn set_symbols(
56 &self,
57 subs: &std::collections::HashMap<Symbol, F>,
58 ) -> TractResult<OutputMapping<F>> {
59 Ok(Self {
60 full_dim_hint: self
61 .full_dim_hint
62 .as_ref()
63 .map(|h| h.substitute_all(subs))
64 .transpose()?,
65 ..self.clone()
66 })
67 }
68}
69
70impl<F: Clone + fmt::Display> fmt::Debug for OutputMapping<F> {
71 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
72 if self.state {
73 write!(fmt, "State. ")?;
74 }
75 if let Some(last_value_slot) = self.last_value_slot {
76 write!(fmt, "Last value to outlet {last_value_slot}. ")?;
77 }
78 if let Some((slot, info)) = self.scan {
79 write!(fmt, "Full value to outlet {} (axis: {}). ", slot, info.axis)?;
80 }
81 if let Some(full_dim_hint) = &self.full_dim_hint {
82 write!(fmt, "Full len {full_dim_hint}. ")?;
83 }
84 Ok(())
85 }
86}
87
88pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> Option<TDim> {
89 let (slot, info) = input_mapping
90 .iter()
91 .enumerate()
92 .find_map(|(slot, im)| im.as_scan().map(|scan| (slot, scan)))?;
93 let outside_dim = inputs[slot].shape[info.axis].clone();
94 Some(outside_dim.div_ceil(info.chunk.unsigned_abs() as u64))
95}