Skip to main content

webnn_graph/onnx/
constant_folding.rs

1// Constant folding for ONNX models
2// Eliminates nodes with all-constant inputs by evaluating them at conversion time
3
4pub mod evaluators;
5
6use crate::onnx::convert::OnnxError;
7use crate::protos::onnx::{ModelProto, NodeProto, TensorProto, TensorProto_DataType};
8use std::collections::{HashMap, HashSet};
9
10/// Represents constant tensor data with various types
11#[derive(Debug, Clone)]
12pub enum TensorData {
13    Int64(Vec<i64>),
14    Int32(Vec<i32>),
15    Float32(Vec<f32>),
16    Float64(Vec<f64>),
17    UInt8(Vec<u8>),
18    Int8(Vec<i8>),
19}
20
21impl TensorData {
22    /// Get the number of elements in this tensor
23    pub fn len(&self) -> usize {
24        match self {
25            TensorData::Int64(v) => v.len(),
26            TensorData::Int32(v) => v.len(),
27            TensorData::Float32(v) => v.len(),
28            TensorData::Float64(v) => v.len(),
29            TensorData::UInt8(v) => v.len(),
30            TensorData::Int8(v) => v.len(),
31        }
32    }
33
34    /// Check if the tensor is empty
35    pub fn is_empty(&self) -> bool {
36        self.len() == 0
37    }
38
39    /// Get the data type
40    pub fn data_type(&self) -> TensorProto_DataType {
41        match self {
42            TensorData::Int64(_) => TensorProto_DataType::Int64,
43            TensorData::Int32(_) => TensorProto_DataType::Int32,
44            TensorData::Float32(_) => TensorProto_DataType::Float,
45            TensorData::Float64(_) => TensorProto_DataType::Double,
46            TensorData::UInt8(_) => TensorProto_DataType::Uint8,
47            TensorData::Int8(_) => TensorProto_DataType::Int8,
48        }
49    }
50
51    /// Convert to bytes (little-endian)
52    pub fn to_bytes(&self) -> Vec<u8> {
53        match self {
54            TensorData::Int64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
55            TensorData::Int32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
56            TensorData::Float32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
57            TensorData::Float64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
58            TensorData::UInt8(v) => v.clone(),
59            TensorData::Int8(v) => v.iter().map(|&x| x as u8).collect(),
60        }
61    }
62
63    /// Create from TensorProto
64    pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
65        let raw_data = tensor.raw_data.as_slice();
66        let data_type = tensor.data_type;
67
68        if !raw_data.is_empty() {
69            // Parse from raw bytes
70            match data_type {
71                x if x == TensorProto_DataType::Int64 as i32 => {
72                    let values = raw_data
73                        .chunks_exact(8)
74                        .map(|c| {
75                            i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
76                        })
77                        .collect();
78                    Ok(TensorData::Int64(values))
79                }
80                x if x == TensorProto_DataType::Int32 as i32 => {
81                    let values = raw_data
82                        .chunks_exact(4)
83                        .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
84                        .collect();
85                    Ok(TensorData::Int32(values))
86                }
87                x if x == TensorProto_DataType::Float as i32 => {
88                    let values = raw_data
89                        .chunks_exact(4)
90                        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
91                        .collect();
92                    Ok(TensorData::Float32(values))
93                }
94                x if x == TensorProto_DataType::Double as i32 => {
95                    let values = raw_data
96                        .chunks_exact(8)
97                        .map(|c| {
98                            f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
99                        })
100                        .collect();
101                    Ok(TensorData::Float64(values))
102                }
103                x if x == TensorProto_DataType::Uint8 as i32 => {
104                    Ok(TensorData::UInt8(raw_data.to_vec()))
105                }
106                x if x == TensorProto_DataType::Int8 as i32 => Ok(TensorData::Int8(
107                    raw_data.iter().map(|&x| x as i8).collect(),
108                )),
109                _ => Err(OnnxError::TypeConversion(
110                    webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
111                )),
112            }
113        } else {
114            // Parse from typed data fields
115            match data_type {
116                x if x == TensorProto_DataType::Int64 as i32 => {
117                    Ok(TensorData::Int64(tensor.int64_data.as_slice().to_vec()))
118                }
119                x if x == TensorProto_DataType::Int32 as i32 => {
120                    Ok(TensorData::Int32(tensor.int32_data.as_slice().to_vec()))
121                }
122                x if x == TensorProto_DataType::Float as i32 => {
123                    Ok(TensorData::Float32(tensor.float_data.as_slice().to_vec()))
124                }
125                x if x == TensorProto_DataType::Double as i32 => {
126                    Ok(TensorData::Float64(tensor.double_data.as_slice().to_vec()))
127                }
128                _ => Err(OnnxError::TypeConversion(
129                    webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
130                )),
131            }
132        }
133    }
134}
135
136/// Represents a constant tensor with its shape and type
137#[derive(Debug, Clone)]
138pub struct ConstantTensor {
139    pub data: TensorData,
140    pub shape: Vec<i64>,
141    pub data_type: i32,
142}
143
144impl ConstantTensor {
145    /// Create a ConstantTensor from a TensorProto
146    pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
147        let data = TensorData::from_tensor_proto(tensor)?;
148        let shape = tensor.dims.as_slice().to_vec();
149        let data_type = tensor.data_type;
150
151        Ok(ConstantTensor {
152            data,
153            shape,
154            data_type,
155        })
156    }
157
158    /// Convert to TensorProto
159    pub fn to_tensor_proto(&self, name: &str) -> TensorProto {
160        TensorProto {
161            name: name.to_string(),
162            data_type: self.data_type,
163            dims: self.shape.clone(),
164            raw_data: self.data.to_bytes(),
165            ..Default::default()
166        }
167    }
168
169    /// Get the total number of elements
170    pub fn numel(&self) -> i64 {
171        if self.shape.is_empty() {
172            1
173        } else {
174            self.shape.iter().product()
175        }
176    }
177}
178
179/// Context for constant folding operations
180#[derive(Debug)]
181pub struct ConstantFoldingContext<'a> {
182    /// Map from tensor name to constant value
183    pub constants: HashMap<String, ConstantTensor>,
184    /// Original ONNX initializers (for reference)
185    pub initializers: &'a HashMap<String, &'a TensorProto>,
186}
187
188impl<'a> ConstantFoldingContext<'a> {
189    /// Create a new context from initializers
190    pub fn new(initializers: &'a HashMap<String, &'a TensorProto>) -> Result<Self, OnnxError> {
191        let mut constants = HashMap::new();
192
193        for (name, tensor) in initializers.iter() {
194            // Only add tensors with data
195            if !tensor.raw_data.as_slice().is_empty()
196                || !tensor.int64_data.as_slice().is_empty()
197                || !tensor.int32_data.as_slice().is_empty()
198                || !tensor.float_data.as_slice().is_empty()
199                || !tensor.double_data.as_slice().is_empty()
200            {
201                match ConstantTensor::from_tensor_proto(tensor) {
202                    Ok(ct) => {
203                        constants.insert((*name).clone(), ct);
204                    }
205                    Err(e) => {
206                        crate::debug_println!(
207                            "Warning: Failed to parse initializer '{}': {}",
208                            name,
209                            e
210                        );
211                    }
212                }
213            }
214        }
215
216        Ok(ConstantFoldingContext {
217            constants,
218            initializers,
219        })
220    }
221
222    /// Check if a value is a constant
223    pub fn is_constant(&self, name: &str) -> bool {
224        self.constants.contains_key(name)
225    }
226
227    /// Get a constant by name
228    pub fn get_constant(&self, name: &str) -> Option<&ConstantTensor> {
229        self.constants.get(name)
230    }
231
232    /// Add a new constant
233    pub fn add_constant(&mut self, name: String, tensor: ConstantTensor) {
234        self.constants.insert(name, tensor);
235    }
236}
237
238/// Result of a constant folding pass
239#[derive(Debug, Default)]
240pub struct FoldingResult {
241    /// New initializers to add to the model
242    pub new_initializers: Vec<TensorProto>,
243    /// Node indices to remove from the graph
244    pub nodes_to_remove: HashSet<usize>,
245    /// Number of nodes folded in this pass
246    pub nodes_folded: usize,
247}
248
249/// Trait for operations that support constant evaluation
250pub trait ConstantEvaluator {
251    /// Get the operation type this evaluator handles
252    fn op_type(&self) -> &str;
253
254    /// Check if this evaluator can handle the given node
255    fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool;
256
257    /// Evaluate the node with constant inputs, returning output tensors
258    fn evaluate(
259        &self,
260        node: &NodeProto,
261        ctx: &ConstantFoldingContext,
262    ) -> Result<Vec<ConstantTensor>, OnnxError>;
263}
264
265/// Build the initial context from model initializers
266fn build_context<'a>(
267    _model: &'a ModelProto,
268    initializers_map: &'a HashMap<String, &'a TensorProto>,
269) -> Result<ConstantFoldingContext<'a>, OnnxError> {
270    ConstantFoldingContext::new(initializers_map)
271}
272
273/// Identify nodes that have all constant inputs
274fn identify_constant_nodes(
275    model: &ModelProto,
276    ctx: &ConstantFoldingContext,
277    evaluators: &[Box<dyn ConstantEvaluator>],
278) -> Result<Vec<usize>, OnnxError> {
279    let graph = model.graph.as_ref().unwrap();
280    let mut constant_nodes = Vec::new();
281
282    for (idx, node) in graph.node.as_slice().iter().enumerate() {
283        // Check if any evaluator can handle this node
284        let can_evaluate = evaluators.iter().any(|e| e.can_evaluate(node, ctx));
285
286        if can_evaluate {
287            constant_nodes.push(idx);
288        }
289    }
290
291    Ok(constant_nodes)
292}
293
294/// Evaluate constant nodes and return the folding result
295fn evaluate_constant_nodes(
296    model: &ModelProto,
297    constant_node_indices: &[usize],
298    ctx: &mut ConstantFoldingContext,
299    evaluators: &[Box<dyn ConstantEvaluator>],
300) -> Result<FoldingResult, OnnxError> {
301    let graph = model.graph.as_ref().unwrap();
302    let mut result = FoldingResult::default();
303
304    for &idx in constant_node_indices {
305        let node = &graph.node.as_slice()[idx];
306
307        // Find an evaluator that can handle this node
308        let evaluator = evaluators.iter().find(|e| e.can_evaluate(node, ctx));
309
310        if let Some(evaluator) = evaluator {
311            match evaluator.evaluate(node, ctx) {
312                Ok(output_tensors) => {
313                    // Add outputs as new initializers
314                    for (i, tensor) in output_tensors.iter().enumerate() {
315                        if i < node.output.as_slice().len() {
316                            let output_name = &node.output.as_slice()[i];
317                            let proto = tensor.to_tensor_proto(output_name);
318                            result.new_initializers.push(proto.clone());
319
320                            // Add to context for subsequent evaluations
321                            ctx.add_constant(output_name.to_string(), tensor.clone());
322                        }
323                    }
324
325                    result.nodes_to_remove.insert(idx);
326                    result.nodes_folded += 1;
327                }
328                Err(e) => {
329                    crate::debug_println!(
330                        "Warning: Failed to evaluate constant node '{}' ({}): {}",
331                        node.name.as_str(),
332                        node.op_type.as_str(),
333                        e
334                    );
335                }
336            }
337        }
338    }
339
340    Ok(result)
341}
342
343/// Main entry point: fold constants in an ONNX model
344pub fn fold_constants_in_model(
345    model: &mut ModelProto,
346    evaluators: &[Box<dyn ConstantEvaluator>],
347) -> Result<usize, OnnxError> {
348    let mut total_folded = 0;
349    let max_iterations = 10;
350
351    // Build initializers map
352    let graph = model.graph.as_ref().unwrap();
353    let mut initializers_map: HashMap<String, &TensorProto> = HashMap::new();
354    for init in graph.initializer.as_slice() {
355        initializers_map.insert(init.name.as_str().to_string(), init);
356    }
357
358    for iteration in 0..max_iterations {
359        // 1. Build context from current initializers
360        let initializers_map_ref: HashMap<String, &TensorProto> = model
361            .graph
362            .as_ref()
363            .unwrap()
364            .initializer
365            .as_slice()
366            .iter()
367            .map(|init| (init.name.as_str().to_string(), init))
368            .collect();
369
370        let mut ctx = build_context(model, &initializers_map_ref)?;
371
372        // 2. Identify constant nodes
373        let constant_nodes = identify_constant_nodes(model, &ctx, evaluators)?;
374
375        if constant_nodes.is_empty() {
376            break;
377        }
378
379        // 3. Evaluate constant nodes
380        let result = evaluate_constant_nodes(model, &constant_nodes, &mut ctx, evaluators)?;
381
382        if result.nodes_folded == 0 {
383            break;
384        }
385
386        // 4. Add new initializers to the model
387        let graph_mut = model.graph.as_mut().unwrap();
388        for init in result.new_initializers {
389            graph_mut.initializer.push(init);
390        }
391
392        // 5. Remove evaluated nodes
393        let nodes = graph_mut.node.as_slice().to_vec();
394        graph_mut.node.clear();
395        for (idx, node) in nodes.into_iter().enumerate() {
396            if !result.nodes_to_remove.contains(&idx) {
397                graph_mut.node.push(node);
398            }
399        }
400
401        total_folded += result.nodes_folded;
402
403        crate::debug_println!(
404            "Constant folding iteration {}: {} nodes folded",
405            iteration + 1,
406            result.nodes_folded
407        );
408    }
409
410    Ok(total_folded)
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an int64 constant with the given values and dims.
    fn int64_constant(values: Vec<i64>, shape: Vec<i64>) -> ConstantTensor {
        ConstantTensor {
            data: TensorData::Int64(values),
            shape,
            data_type: TensorProto_DataType::Int64 as i32,
        }
    }

    #[test]
    fn test_tensor_data_len() {
        assert_eq!(TensorData::Int64(vec![1, 2, 3]).len(), 3);
        assert_eq!(TensorData::Float32(vec![1.0, 2.0]).len(), 2);
    }

    #[test]
    fn test_tensor_data_to_bytes() {
        // Three 4-byte elements.
        assert_eq!(TensorData::Int32(vec![1, 2, 3]).to_bytes().len(), 12);
        // Two 8-byte elements.
        assert_eq!(TensorData::Int64(vec![1, 2]).to_bytes().len(), 16);
    }

    #[test]
    fn test_constant_tensor_numel() {
        assert_eq!(int64_constant(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).numel(), 6);
        // A scalar has an empty shape and exactly one element.
        assert_eq!(int64_constant(vec![42], vec![]).numel(), 1);
    }
}