1pub mod evaluators;
5
6use crate::onnx::convert::OnnxError;
7use crate::protos::onnx::{ModelProto, NodeProto, TensorProto, TensorProto_DataType};
8use std::collections::{HashMap, HashSet};
9
/// Typed element storage for a constant tensor.
///
/// Each variant owns a flat vector of elements of one primitive type; the
/// tensor's shape is tracked separately by the owner of this value.
#[derive(Debug, Clone)]
pub enum TensorData {
    /// 64-bit signed integer elements.
    Int64(Vec<i64>),
    /// 32-bit signed integer elements.
    Int32(Vec<i32>),
    /// 32-bit floating point elements.
    Float32(Vec<f32>),
    /// 64-bit floating point elements.
    Float64(Vec<f64>),
    /// 8-bit unsigned integer elements.
    UInt8(Vec<u8>),
    /// 8-bit signed integer elements.
    Int8(Vec<i8>),
}
20
21impl TensorData {
22 pub fn len(&self) -> usize {
24 match self {
25 TensorData::Int64(v) => v.len(),
26 TensorData::Int32(v) => v.len(),
27 TensorData::Float32(v) => v.len(),
28 TensorData::Float64(v) => v.len(),
29 TensorData::UInt8(v) => v.len(),
30 TensorData::Int8(v) => v.len(),
31 }
32 }
33
34 pub fn is_empty(&self) -> bool {
36 self.len() == 0
37 }
38
39 pub fn data_type(&self) -> TensorProto_DataType {
41 match self {
42 TensorData::Int64(_) => TensorProto_DataType::Int64,
43 TensorData::Int32(_) => TensorProto_DataType::Int32,
44 TensorData::Float32(_) => TensorProto_DataType::Float,
45 TensorData::Float64(_) => TensorProto_DataType::Double,
46 TensorData::UInt8(_) => TensorProto_DataType::Uint8,
47 TensorData::Int8(_) => TensorProto_DataType::Int8,
48 }
49 }
50
51 pub fn to_bytes(&self) -> Vec<u8> {
53 match self {
54 TensorData::Int64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
55 TensorData::Int32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
56 TensorData::Float32(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
57 TensorData::Float64(v) => v.iter().flat_map(|&x| x.to_le_bytes()).collect(),
58 TensorData::UInt8(v) => v.clone(),
59 TensorData::Int8(v) => v.iter().map(|&x| x as u8).collect(),
60 }
61 }
62
63 pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
65 let raw_data = tensor.raw_data.as_slice();
66 let data_type = tensor.data_type;
67
68 if !raw_data.is_empty() {
69 match data_type {
71 x if x == TensorProto_DataType::Int64 as i32 => {
72 let values = raw_data
73 .chunks_exact(8)
74 .map(|c| {
75 i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
76 })
77 .collect();
78 Ok(TensorData::Int64(values))
79 }
80 x if x == TensorProto_DataType::Int32 as i32 => {
81 let values = raw_data
82 .chunks_exact(4)
83 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
84 .collect();
85 Ok(TensorData::Int32(values))
86 }
87 x if x == TensorProto_DataType::Float as i32 => {
88 let values = raw_data
89 .chunks_exact(4)
90 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
91 .collect();
92 Ok(TensorData::Float32(values))
93 }
94 x if x == TensorProto_DataType::Double as i32 => {
95 let values = raw_data
96 .chunks_exact(8)
97 .map(|c| {
98 f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
99 })
100 .collect();
101 Ok(TensorData::Float64(values))
102 }
103 x if x == TensorProto_DataType::Uint8 as i32 => {
104 Ok(TensorData::UInt8(raw_data.to_vec()))
105 }
106 x if x == TensorProto_DataType::Int8 as i32 => Ok(TensorData::Int8(
107 raw_data.iter().map(|&x| x as i8).collect(),
108 )),
109 _ => Err(OnnxError::TypeConversion(
110 webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
111 )),
112 }
113 } else {
114 match data_type {
116 x if x == TensorProto_DataType::Int64 as i32 => {
117 Ok(TensorData::Int64(tensor.int64_data.as_slice().to_vec()))
118 }
119 x if x == TensorProto_DataType::Int32 as i32 => {
120 Ok(TensorData::Int32(tensor.int32_data.as_slice().to_vec()))
121 }
122 x if x == TensorProto_DataType::Float as i32 => {
123 Ok(TensorData::Float32(tensor.float_data.as_slice().to_vec()))
124 }
125 x if x == TensorProto_DataType::Double as i32 => {
126 Ok(TensorData::Float64(tensor.double_data.as_slice().to_vec()))
127 }
128 _ => Err(OnnxError::TypeConversion(
129 webnn_onnx_utils::error::ConversionError::UnsupportedOnnxDataType(data_type),
130 )),
131 }
132 }
133 }
134}
135
136#[derive(Debug, Clone)]
138pub struct ConstantTensor {
139 pub data: TensorData,
140 pub shape: Vec<i64>,
141 pub data_type: i32,
142}
143
144impl ConstantTensor {
145 pub fn from_tensor_proto(tensor: &TensorProto) -> Result<Self, OnnxError> {
147 let data = TensorData::from_tensor_proto(tensor)?;
148 let shape = tensor.dims.as_slice().to_vec();
149 let data_type = tensor.data_type;
150
151 Ok(ConstantTensor {
152 data,
153 shape,
154 data_type,
155 })
156 }
157
158 pub fn to_tensor_proto(&self, name: &str) -> TensorProto {
160 TensorProto {
161 name: name.to_string(),
162 data_type: self.data_type,
163 dims: self.shape.clone(),
164 raw_data: self.data.to_bytes(),
165 ..Default::default()
166 }
167 }
168
169 pub fn numel(&self) -> i64 {
171 if self.shape.is_empty() {
172 1
173 } else {
174 self.shape.iter().product()
175 }
176 }
177}
178
179#[derive(Debug)]
181pub struct ConstantFoldingContext<'a> {
182 pub constants: HashMap<String, ConstantTensor>,
184 pub initializers: &'a HashMap<String, &'a TensorProto>,
186}
187
188impl<'a> ConstantFoldingContext<'a> {
189 pub fn new(initializers: &'a HashMap<String, &'a TensorProto>) -> Result<Self, OnnxError> {
191 let mut constants = HashMap::new();
192
193 for (name, tensor) in initializers.iter() {
194 if !tensor.raw_data.as_slice().is_empty()
196 || !tensor.int64_data.as_slice().is_empty()
197 || !tensor.int32_data.as_slice().is_empty()
198 || !tensor.float_data.as_slice().is_empty()
199 || !tensor.double_data.as_slice().is_empty()
200 {
201 match ConstantTensor::from_tensor_proto(tensor) {
202 Ok(ct) => {
203 constants.insert((*name).clone(), ct);
204 }
205 Err(e) => {
206 crate::debug_println!(
207 "Warning: Failed to parse initializer '{}': {}",
208 name,
209 e
210 );
211 }
212 }
213 }
214 }
215
216 Ok(ConstantFoldingContext {
217 constants,
218 initializers,
219 })
220 }
221
222 pub fn is_constant(&self, name: &str) -> bool {
224 self.constants.contains_key(name)
225 }
226
227 pub fn get_constant(&self, name: &str) -> Option<&ConstantTensor> {
229 self.constants.get(name)
230 }
231
232 pub fn add_constant(&mut self, name: String, tensor: ConstantTensor) {
234 self.constants.insert(name, tensor);
235 }
236}
237
238#[derive(Debug, Default)]
240pub struct FoldingResult {
241 pub new_initializers: Vec<TensorProto>,
243 pub nodes_to_remove: HashSet<usize>,
245 pub nodes_folded: usize,
247}
248
249pub trait ConstantEvaluator {
251 fn op_type(&self) -> &str;
253
254 fn can_evaluate(&self, node: &NodeProto, ctx: &ConstantFoldingContext) -> bool;
256
257 fn evaluate(
259 &self,
260 node: &NodeProto,
261 ctx: &ConstantFoldingContext,
262 ) -> Result<Vec<ConstantTensor>, OnnxError>;
263}
264
265fn build_context<'a>(
267 _model: &'a ModelProto,
268 initializers_map: &'a HashMap<String, &'a TensorProto>,
269) -> Result<ConstantFoldingContext<'a>, OnnxError> {
270 ConstantFoldingContext::new(initializers_map)
271}
272
273fn identify_constant_nodes(
275 model: &ModelProto,
276 ctx: &ConstantFoldingContext,
277 evaluators: &[Box<dyn ConstantEvaluator>],
278) -> Result<Vec<usize>, OnnxError> {
279 let graph = model.graph.as_ref().unwrap();
280 let mut constant_nodes = Vec::new();
281
282 for (idx, node) in graph.node.as_slice().iter().enumerate() {
283 let can_evaluate = evaluators.iter().any(|e| e.can_evaluate(node, ctx));
285
286 if can_evaluate {
287 constant_nodes.push(idx);
288 }
289 }
290
291 Ok(constant_nodes)
292}
293
294fn evaluate_constant_nodes(
296 model: &ModelProto,
297 constant_node_indices: &[usize],
298 ctx: &mut ConstantFoldingContext,
299 evaluators: &[Box<dyn ConstantEvaluator>],
300) -> Result<FoldingResult, OnnxError> {
301 let graph = model.graph.as_ref().unwrap();
302 let mut result = FoldingResult::default();
303
304 for &idx in constant_node_indices {
305 let node = &graph.node.as_slice()[idx];
306
307 let evaluator = evaluators.iter().find(|e| e.can_evaluate(node, ctx));
309
310 if let Some(evaluator) = evaluator {
311 match evaluator.evaluate(node, ctx) {
312 Ok(output_tensors) => {
313 for (i, tensor) in output_tensors.iter().enumerate() {
315 if i < node.output.as_slice().len() {
316 let output_name = &node.output.as_slice()[i];
317 let proto = tensor.to_tensor_proto(output_name);
318 result.new_initializers.push(proto.clone());
319
320 ctx.add_constant(output_name.to_string(), tensor.clone());
322 }
323 }
324
325 result.nodes_to_remove.insert(idx);
326 result.nodes_folded += 1;
327 }
328 Err(e) => {
329 crate::debug_println!(
330 "Warning: Failed to evaluate constant node '{}' ({}): {}",
331 node.name.as_str(),
332 node.op_type.as_str(),
333 e
334 );
335 }
336 }
337 }
338 }
339
340 Ok(result)
341}
342
343pub fn fold_constants_in_model(
345 model: &mut ModelProto,
346 evaluators: &[Box<dyn ConstantEvaluator>],
347) -> Result<usize, OnnxError> {
348 let mut total_folded = 0;
349 let max_iterations = 10;
350
351 let graph = model.graph.as_ref().unwrap();
353 let mut initializers_map: HashMap<String, &TensorProto> = HashMap::new();
354 for init in graph.initializer.as_slice() {
355 initializers_map.insert(init.name.as_str().to_string(), init);
356 }
357
358 for iteration in 0..max_iterations {
359 let initializers_map_ref: HashMap<String, &TensorProto> = model
361 .graph
362 .as_ref()
363 .unwrap()
364 .initializer
365 .as_slice()
366 .iter()
367 .map(|init| (init.name.as_str().to_string(), init))
368 .collect();
369
370 let mut ctx = build_context(model, &initializers_map_ref)?;
371
372 let constant_nodes = identify_constant_nodes(model, &ctx, evaluators)?;
374
375 if constant_nodes.is_empty() {
376 break;
377 }
378
379 let result = evaluate_constant_nodes(model, &constant_nodes, &mut ctx, evaluators)?;
381
382 if result.nodes_folded == 0 {
383 break;
384 }
385
386 let graph_mut = model.graph.as_mut().unwrap();
388 for init in result.new_initializers {
389 graph_mut.initializer.push(init);
390 }
391
392 let nodes = graph_mut.node.as_slice().to_vec();
394 graph_mut.node.clear();
395 for (idx, node) in nodes.into_iter().enumerate() {
396 if !result.nodes_to_remove.contains(&idx) {
397 graph_mut.node.push(node);
398 }
399 }
400
401 total_folded += result.nodes_folded;
402
403 crate::debug_println!(
404 "Constant folding iteration {}: {} nodes folded",
405 iteration + 1,
406 result.nodes_folded
407 );
408 }
409
410 Ok(total_folded)
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` reports the element count for integer and float payloads.
    #[test]
    fn test_tensor_data_len() {
        assert_eq!(TensorData::Int64(vec![1, 2, 3]).len(), 3);
        assert_eq!(TensorData::Float32(vec![1.0, 2.0]).len(), 2);
    }

    /// `to_bytes` emits one fixed-width encoding per element.
    #[test]
    fn test_tensor_data_to_bytes() {
        // Three 4-byte elements.
        let int32_bytes = TensorData::Int32(vec![1, 2, 3]).to_bytes();
        assert_eq!(int32_bytes.len(), 12);

        // Two 8-byte elements.
        let int64_bytes = TensorData::Int64(vec![1, 2]).to_bytes();
        assert_eq!(int64_bytes.len(), 16);
    }

    /// `numel` multiplies the dimensions and treats an empty shape as a scalar.
    #[test]
    fn test_constant_tensor_numel() {
        let matrix = ConstantTensor {
            data: TensorData::Int64(vec![1, 2, 3, 4, 5, 6]),
            shape: vec![2, 3],
            data_type: TensorProto_DataType::Int64 as i32,
        };
        assert_eq!(matrix.numel(), 6);

        let scalar = ConstantTensor {
            data: TensorData::Int64(vec![42]),
            shape: vec![],
            data_type: TensorProto_DataType::Int64 as i32,
        };
        assert_eq!(scalar.numel(), 1);
    }
}