Skip to main content

tract_api/
transform.rs

1use std::collections::HashMap;
2
3use crate::DatumType;
4
/// A serialized transform specification passed to `ModelInterface::transform`.
///
/// Wraps the string representation expected by the transform registry.
/// Constructed from raw strings or typed config structs implementing [`TransformConfig`].
///
/// Implements `PartialEq`/`Eq`/`Hash` so specs can be compared, deduplicated, or used as
/// map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TransformSpec(String);

impl TransformSpec {
    /// Borrow the string the transform registry expects, without allocating.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Produce the string the transform registry expects.
    ///
    /// Returns an owned copy; prefer [`as_str`](TransformSpec::as_str) when a borrow suffices.
    pub fn to_transform_string(&self) -> String {
        self.0.clone()
    }
}

impl From<&str> for TransformSpec {
    fn from(s: &str) -> Self {
        TransformSpec(s.to_string())
    }
}

impl From<String> for TransformSpec {
    fn from(s: String) -> Self {
        TransformSpec(s)
    }
}

// Separate impl for `&String`: generic `impl Into<TransformSpec>` parameters do not
// deref-coerce `&String` to `&str`.
impl From<&String> for TransformSpec {
    fn from(s: &String) -> Self {
        TransformSpec(s.clone())
    }
}
36
37/// Trait for typed transform configurations.
38///
39/// Implementors derive [`serde::Serialize`] and provide a transform [`name()`](TransformConfig::name).
40/// The default [`to_transform_string()`](TransformConfig::to_transform_string) serializes the
41/// struct as a JSON object and injects the `"name"` key.
42pub trait TransformConfig: serde::Serialize {
43    /// The transform registry name (e.g. `"pulse"`, `"float_precision"`).
44    fn name(&self) -> &'static str;
45
46    /// Produce the string the transform registry expects.
47    ///
48    /// The default implementation serializes `self` to a JSON object and inserts `"name"`.
49    fn to_transform_string(&self) -> String {
50        let mut obj: serde_json::Map<String, serde_json::Value> = serde_json::to_value(self)
51            .expect("TransformConfig serialization cannot fail")
52            .as_object()
53            .expect("TransformConfig must serialize to a JSON object")
54            .clone();
55        obj.insert("name".into(), serde_json::Value::String(self.name().to_string()));
56        serde_json::to_string(&obj).expect("serialization cannot fail")
57    }
58}
59
/// Wires a typed config struct into the transform API.
///
/// Given a type and a registry name, this expands to two impls:
/// - `TransformConfig`, whose `name()` returns the given registry name; and
/// - `From<type> for TransformSpec`, which stores the config's `to_transform_string()` output.
macro_rules! transform_config {
    ($ty:ty, $name:expr) => {
        impl TransformConfig for $ty {
            fn name(&self) -> &'static str {
                $name
            }
        }

        impl From<$ty> for TransformSpec {
            fn from(cfg: $ty) -> Self {
                let serialized = cfg.to_transform_string();
                Self(serialized)
            }
        }
    };
}
76
77/// Typed config for the `concretize_symbols` transform.
78///
79/// Replaces symbolic dimensions with concrete integer values.
80///
81/// # Example
82/// ```ignore
83/// model.transform(ConcretizeSymbols::new().value("B", 1))?;
84/// ```
85#[derive(Debug, Clone, Default, serde::Serialize)]
86pub struct ConcretizeSymbols {
87    values: HashMap<String, i64>,
88}
89
90impl ConcretizeSymbols {
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Set a symbol to a concrete value.
96    pub fn value(mut self, symbol: impl Into<String>, val: i64) -> Self {
97        self.values.insert(symbol.into(), val);
98        self
99    }
100}
101
102transform_config!(ConcretizeSymbols, "concretize_symbols");
103
104/// Typed config for the `pulse` transform.
105///
106/// Converts a model to a pulsed (streaming) model.
107///
108/// # Example
109/// ```ignore
110/// model.transform(Pulse::new("5").symbol("B"))?;
111/// ```
112#[derive(Debug, Clone, serde::Serialize)]
113pub struct Pulse {
114    pulse: String,
115    #[serde(skip_serializing_if = "Option::is_none")]
116    symbol: Option<String>,
117}
118
119impl Pulse {
120    /// Create a new Pulse config with the given pulse dimension.
121    pub fn new(pulse: impl Into<String>) -> Self {
122        Self { pulse: pulse.into(), symbol: None }
123    }
124
125    /// Set the symbol to pulse over (defaults to "S" if not set).
126    pub fn symbol(mut self, symbol: impl Into<String>) -> Self {
127        self.symbol = Some(symbol.into());
128        self
129    }
130}
131
132transform_config!(Pulse, "pulse");
133
134/// Typed config for the `float_precision` transform.
135///
136/// Changes the float precision of a model (e.g. F32 to F16).
137///
138/// # Example
139/// ```ignore
140/// use tract_api::DatumType;
141/// model.transform(FloatPrecision::new(DatumType::F32, DatumType::F16))?;
142/// ```
143#[derive(Debug, Clone, serde::Serialize)]
144pub struct FloatPrecision {
145    from: String,
146    to: String,
147    #[serde(skip_serializing_if = "Option::is_none")]
148    include: Option<Vec<String>>,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    exclude: Option<Vec<String>>,
151}
152
153fn datum_type_to_str(dt: DatumType) -> &'static str {
154    match dt {
155        DatumType::F16 => "f16",
156        DatumType::F32 => "f32",
157        DatumType::F64 => "f64",
158        _ => panic!("FloatPrecision only supports float datum types (F16, F32, F64)"),
159    }
160}
161
162impl FloatPrecision {
163    pub fn new(from: DatumType, to: DatumType) -> Self {
164        Self {
165            from: datum_type_to_str(from).to_string(),
166            to: datum_type_to_str(to).to_string(),
167            include: None,
168            exclude: None,
169        }
170    }
171
172    /// Set include patterns — only nodes matching at least one pattern are translated.
173    pub fn include(mut self, patterns: Vec<String>) -> Self {
174        self.include = Some(patterns);
175        self
176    }
177
178    /// Set exclude patterns — matching nodes are excluded from translation.
179    pub fn exclude(mut self, patterns: Vec<String>) -> Self {
180        self.exclude = Some(patterns);
181        self
182    }
183}
184
185transform_config!(FloatPrecision, "float_precision");