// tract_api/transform.rs

1use std::collections::HashMap;
2
3use crate::DatumType;
4
/// An opaque, already-serialized transform specification accepted by
/// `ModelInterface::transform`.
///
/// The inner string is handed to the transform registry verbatim. Build one
/// from a plain transform string (`&str`, `String`, `&String`) or from any
/// typed config implementing [`TransformConfig`].
#[derive(Debug, Clone)]
pub struct TransformSpec(String);

impl TransformSpec {
    /// Returns an owned copy of the string passed to the transform registry.
    pub fn to_transform_string(&self) -> String {
        let TransformSpec(raw) = self;
        raw.to_owned()
    }
}

impl From<String> for TransformSpec {
    /// Wraps `s` directly, taking ownership without copying.
    fn from(s: String) -> Self {
        Self(s)
    }
}

impl From<&str> for TransformSpec {
    /// Copies `s` into a new spec.
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}

impl From<&String> for TransformSpec {
    /// Copies `s` into a new spec (delegates to the `&str` conversion).
    fn from(s: &String) -> Self {
        Self::from(s.as_str())
    }
}
36
37/// Trait for typed transform configurations.
38///
39/// Implementors derive [`serde::Serialize`] and provide a transform [`name()`](TransformConfig::name).
40/// The default [`to_transform_string()`](TransformConfig::to_transform_string) serializes the
41/// struct as a JSON object and injects the `"name"` key.
42pub trait TransformConfig: serde::Serialize {
43    /// The transform registry name (e.g. `"pulse"`, `"float_precision"`).
44    fn name(&self) -> &'static str;
45
46    /// Produce the string the transform registry expects.
47    ///
48    /// The default implementation serializes `self` to a JSON object and inserts `"name"`.
49    fn to_transform_string(&self) -> String {
50        let mut obj: serde_json::Map<String, serde_json::Value> = serde_json::to_value(self)
51            .expect("TransformConfig serialization cannot fail")
52            .as_object()
53            .expect("TransformConfig must serialize to a JSON object")
54            .clone();
55        obj.insert("name".into(), serde_json::Value::String(self.name().to_string()));
56        serde_json::to_string(&obj).expect("serialization cannot fail")
57    }
58}
59
/// Wires a typed config into the transform API.
///
/// Given a config type and its registry name, this implements
/// [`TransformConfig`] (returning that name) and `From<config> for TransformSpec`
/// (serializing via [`TransformConfig::to_transform_string`]).
macro_rules! transform_config {
    ($config:ty, $registry_name:expr) => {
        impl TransformConfig for $config {
            fn name(&self) -> &'static str {
                $registry_name
            }
        }

        impl From<$config> for TransformSpec {
            fn from(cfg: $config) -> Self {
                let serialized = cfg.to_transform_string();
                TransformSpec(serialized)
            }
        }
    };
}
76
77/// Typed config for the `set_symbols` transform.
78///
79/// Binds symbolic dimensions to concrete integers (or `TDim` expressions
80/// via [`Self::expr`]).
81///
82/// # Example
83/// ```ignore
84/// model.transform(SetSymbols::new().value("B", 1).value("T", 16))?;
85/// ```
86#[derive(Debug, Clone, Default, serde::Serialize)]
87pub struct SetSymbols {
88    #[serde(serialize_with = "serialize_values")]
89    values: HashMap<String, SetSymbolValue>,
90}
91
92#[derive(Debug, Clone, serde::Serialize)]
93#[serde(untagged)]
94enum SetSymbolValue {
95    Int(i64),
96    Expr(String),
97}
98
99fn serialize_values<S: serde::Serializer>(
100    values: &HashMap<String, SetSymbolValue>,
101    s: S,
102) -> Result<S::Ok, S::Error> {
103    use serde::ser::SerializeMap;
104    let mut map = s.serialize_map(Some(values.len()))?;
105    for (k, v) in values {
106        map.serialize_entry(k, v)?;
107    }
108    map.end()
109}
110
111impl SetSymbols {
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Bind a symbol to a concrete integer value.
117    pub fn value(mut self, symbol: impl Into<String>, val: i64) -> Self {
118        self.values.insert(symbol.into(), SetSymbolValue::Int(val));
119        self
120    }
121
122    /// Bind a symbol to a `TDim` expression (e.g. `"2*S"`) parsed against
123    /// the model's symbol scope at transform time.
124    pub fn expr(mut self, symbol: impl Into<String>, expr: impl Into<String>) -> Self {
125        self.values.insert(symbol.into(), SetSymbolValue::Expr(expr.into()));
126        self
127    }
128}
129
130transform_config!(SetSymbols, "set_symbols");
131
132/// Typed config for the `pulse` transform.
133///
134/// Converts a model to a pulsed (streaming) model.
135///
136/// # Example
137/// ```ignore
138/// model.transform(Pulse::new("5").symbol("B"))?;
139/// ```
140#[derive(Debug, Clone, serde::Serialize)]
141pub struct Pulse {
142    pulse: String,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    symbol: Option<String>,
145}
146
147impl Pulse {
148    /// Create a new Pulse config with the given pulse dimension.
149    pub fn new(pulse: impl Into<String>) -> Self {
150        Self { pulse: pulse.into(), symbol: None }
151    }
152
153    /// Set the symbol to pulse over (defaults to "S" if not set).
154    pub fn symbol(mut self, symbol: impl Into<String>) -> Self {
155        self.symbol = Some(symbol.into());
156        self
157    }
158}
159
160transform_config!(Pulse, "pulse");
161
162/// Typed config for the `float_precision` transform.
163///
164/// Changes the float precision of a model (e.g. F32 to F16).
165///
166/// # Example
167/// ```ignore
168/// use tract_api::DatumType;
169/// model.transform(FloatPrecision::new(DatumType::F32, DatumType::F16))?;
170/// ```
171#[derive(Debug, Clone, serde::Serialize)]
172pub struct FloatPrecision {
173    from: String,
174    to: String,
175    #[serde(skip_serializing_if = "Option::is_none")]
176    include: Option<Vec<String>>,
177    #[serde(skip_serializing_if = "Option::is_none")]
178    exclude: Option<Vec<String>>,
179}
180
181fn datum_type_to_str(dt: DatumType) -> &'static str {
182    match dt {
183        DatumType::F16 => "f16",
184        DatumType::F32 => "f32",
185        DatumType::F64 => "f64",
186        _ => panic!("FloatPrecision only supports float datum types (F16, F32, F64)"),
187    }
188}
189
190impl FloatPrecision {
191    pub fn new(from: DatumType, to: DatumType) -> Self {
192        Self {
193            from: datum_type_to_str(from).to_string(),
194            to: datum_type_to_str(to).to_string(),
195            include: None,
196            exclude: None,
197        }
198    }
199
200    /// Set include patterns — only nodes matching at least one pattern are translated.
201    pub fn include(mut self, patterns: Vec<String>) -> Self {
202        self.include = Some(patterns);
203        self
204    }
205
206    /// Set exclude patterns — matching nodes are excluded from translation.
207    pub fn exclude(mut self, patterns: Vec<String>) -> Self {
208        self.exclude = Some(patterns);
209        self
210    }
211}
212
213transform_config!(FloatPrecision, "float_precision");