tract_core/ops/scan/
mod.rs1use crate::internal::*;
2use std::fmt;
3
4mod decluttered;
5mod optimized;
6
7pub use optimized::{OptScan, State};
8pub use decluttered::Scan;
9
10#[derive(Clone, new, Hash, Eq, PartialEq, Copy, Debug)]
11pub struct ScanInfo {
12 pub axis: usize,
13 pub chunk: isize,
14}
15
16#[derive(Clone, new, Hash, Debug)]
17pub enum InputMapping {
18 Full,
19 State,
20 Scan(ScanInfo),
21}
22
23impl InputMapping {
24 pub fn is_state(&self) -> bool {
25 matches!(self, InputMapping::State)
26 }
27
28 pub fn is_scan(&self) -> bool {
29 self.as_scan().is_some()
30 }
31
32 pub fn as_scan(&self) -> Option<&ScanInfo> {
33 match self {
34 InputMapping::Scan(s) => Some(s),
35 _ => None,
36 }
37 }
38}
39
40#[derive(Clone, new, Hash, Default)]
41pub struct OutputMapping<F: Clone> {
42 pub scan: Option<(usize, ScanInfo)>,
43 pub full_dim_hint: Option<F>,
44 pub last_value_slot: Option<usize>,
45 pub state: bool,
46}
47
48impl<F: Clone> OutputMapping<F> {
49 pub fn invisible(&self) -> bool {
50 self.scan.is_none() && self.last_value_slot.is_none()
51 }
52}
53
54impl<F: Clone + DimLike> OutputMapping<F> {
55 pub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<OutputMapping<F>> {
56 Ok(Self {
57 full_dim_hint: self.full_dim_hint.as_ref().map(|h| h.eval(values)),
58 ..self.clone()
59 })
60 }
61}
62
63impl<F: Clone + fmt::Display> fmt::Debug for OutputMapping<F> {
64 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
65 if self.state {
66 write!(fmt, "State. ")?;
67 }
68 if let Some(last_value_slot) = self.last_value_slot {
69 write!(fmt, "Last value to outlet {last_value_slot}. ")?;
70 }
71 if let Some((slot, info)) = self.scan {
72 write!(fmt, "Full value to outlet {} (axis: {}). ", slot, info.axis)?;
73 }
74 if let Some(full_dim_hint) = &self.full_dim_hint {
75 write!(fmt, "Full len {full_dim_hint}. ")?;
76 }
77 Ok(())
78 }
79}
80
81pub fn iteration_count(input_mapping: &[InputMapping], inputs: &[&TypedFact]) -> Option<TDim> {
82 let (slot, info) = input_mapping
83 .iter()
84 .enumerate()
85 .find_map(|(slot, im)| im.as_scan().map(|scan| (slot, scan)))?;
86 let outside_dim = inputs[slot].shape[info.axis].clone();
87 Some(outside_dim.div_ceil(info.chunk.unsigned_abs() as u64))
88}