1use crate::quantize::calibration::QuantizationParams;
10use std::collections::HashMap;
11
/// Kind of operation performed by a graph node.
///
/// The enum carries no data, so it is `Copy`/`Eq`/`Hash`: passes can read a
/// node's type without cloning and can use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// 2-D convolution (parameters in `NodeParams::Conv2d`).
    Conv2d,
    /// Batch normalization (parameters in `NodeParams::BatchNorm`).
    BatchNorm,
    /// ReLU activation.
    ReLU,
    /// Hard-swish activation, `x * relu6(x + 3) / 6`.
    HardSwish,
    /// Quantization boundary node (scale/zero point in `NodeParams::Quantize`).
    Quantize,
    /// Dequantization boundary node (scale/zero point in `NodeParams::Dequantize`).
    Dequantize,
    /// Graph input.
    Input,
    /// Graph output.
    Output,
}
24
/// A single operation in a `ComputationGraph`.
///
/// Edges are stored on both endpoints: `inputs` lists the producers feeding
/// this node and `outputs` the consumers it feeds (`ComputationGraph::connect`
/// updates both sides).
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Node id; `ComputationGraph::add_node` uses the same value as the map key.
    pub id: usize,
    /// Operation kind.
    pub node_type: NodeType,
    /// Ids of producer nodes. Passes in this module treat `inputs[0]` as the
    /// activation input of a conv.
    pub inputs: Vec<usize>,
    /// Ids of consumer nodes. The fusion passes look at the conv's consumer
    /// list to find a following BatchNorm/activation.
    pub outputs: Vec<usize>,
    /// Operation parameters. NOTE(review): expected to match `node_type`, but
    /// nothing enforces that pairing.
    pub params: NodeParams,
}
34
/// Per-operation parameters attached to a graph node.
///
/// Derives `PartialEq` so parameter sets can be compared directly (e.g. in
/// tests). `Eq` is not derived because several fields are `f32`.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeParams {
    /// Convolution weights and shape metadata.
    ///
    /// The passes in this module split `weights` into `out_channels` equal,
    /// contiguous slices, one per output channel. Presumably the layout is
    /// output-channel-major (OIHW). TODO confirm.
    Conv2d {
        weights: Vec<f32>,
        /// Per-output-channel bias; `None` is treated as all zeros by the fusion passes.
        bias: Option<Vec<f32>>,
        in_channels: usize,
        out_channels: usize,
        /// Presumably the side length of a square kernel. TODO confirm.
        kernel_size: usize,
    },
    /// Per-channel BatchNorm statistics; fused as `gamma / sqrt(var + eps)`.
    BatchNorm {
        gamma: Vec<f32>,
        beta: Vec<f32>,
        mean: Vec<f32>,
        var: Vec<f32>,
        eps: f32,
    },
    /// Parameterless activation (used for ReLU and HardSwish nodes).
    Activation,
    /// Affine quantization parameters for a Quantize node.
    Quantize {
        scale: f32,
        zero_point: i32,
    },
    /// Affine quantization parameters for a Dequantize node.
    Dequantize {
        scale: f32,
        zero_point: i32,
    },
    /// No parameters (used for Input and Output nodes).
    None,
}
63
/// A directed graph of `GraphNode`s keyed by node id.
#[derive(Debug, Clone)]
pub struct ComputationGraph {
    /// All live nodes, keyed by their id. Iteration order is that of
    /// `std::collections::HashMap`, i.e. unspecified.
    pub nodes: HashMap<usize, GraphNode>,
    /// Id the next `add_node` call will hand out. `add_node` only increments
    /// it, so ids of removed nodes are not reused by `add_node`.
    pub next_id: usize,
}
70
71impl ComputationGraph {
72 pub fn new() -> Self {
73 Self {
74 nodes: HashMap::new(),
75 next_id: 0,
76 }
77 }
78
79 pub fn add_node(&mut self, node_type: NodeType, params: NodeParams) -> usize {
80 let id = self.next_id;
81 self.next_id += 1;
82 self.nodes.insert(
83 id,
84 GraphNode {
85 id,
86 node_type,
87 inputs: Vec::new(),
88 outputs: Vec::new(),
89 params,
90 },
91 );
92 id
93 }
94
95 pub fn connect(&mut self, from: usize, to: usize) {
96 if let Some(from_node) = self.nodes.get_mut(&from) {
97 from_node.outputs.push(to);
98 }
99 if let Some(to_node) = self.nodes.get_mut(&to) {
100 to_node.inputs.push(from);
101 }
102 }
103
104 pub fn remove_node(&mut self, id: usize) {
105 if let Some(node) = self.nodes.remove(&id) {
106 for &input_id in &node.inputs {
108 if let Some(input_node) = self.nodes.get_mut(&input_id) {
109 input_node.outputs.retain(|&x| x != id);
110 input_node.outputs.extend(&node.outputs);
111 }
112 }
113 for &output_id in &node.outputs {
114 if let Some(output_node) = self.nodes.get_mut(&output_id) {
115 output_node.inputs.retain(|&x| x != id);
116 output_node.inputs.extend(&node.inputs);
117 }
118 }
119 }
120 }
121
122 pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
123 self.nodes.get(&id)
124 }
125
126 pub fn get_node_mut(&mut self, id: usize) -> Option<&mut GraphNode> {
127 self.nodes.get_mut(&id)
128 }
129}
130
131impl Default for ComputationGraph {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137pub fn fuse_batchnorm_to_conv(graph: &mut ComputationGraph) -> usize {
143 let mut fused_count = 0;
144 let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
145
146 for conv_id in node_ids {
147 let conv_node = match graph.get_node(conv_id) {
148 Some(node) if node.node_type == NodeType::Conv2d => node,
149 _ => continue,
150 };
151
152 let bn_id = match conv_node.outputs.first() {
154 Some(&id) => id,
155 None => continue,
156 };
157
158 let bn_node = match graph.get_node(bn_id) {
159 Some(node) if node.node_type == NodeType::BatchNorm => node,
160 _ => continue,
161 };
162
163 let (weights, bias, out_channels) = match &conv_node.params {
165 NodeParams::Conv2d {
166 weights,
167 bias,
168 out_channels,
169 ..
170 } => (weights.clone(), bias.clone(), *out_channels),
171 _ => continue,
172 };
173
174 let (gamma, beta, mean, var, eps) = match &bn_node.params {
175 NodeParams::BatchNorm {
176 gamma,
177 beta,
178 mean,
179 var,
180 eps,
181 } => (gamma, beta, mean, var, *eps),
182 _ => continue,
183 };
184
185 let mut fused_weights = weights;
187 let mut fused_bias = bias.unwrap_or_else(|| vec![0.0; out_channels]);
188
189 for c in 0..out_channels {
190 let scale = gamma[c] / (var[c] + eps).sqrt();
191
192 let weights_per_channel = fused_weights.len() / out_channels;
194 for i in 0..weights_per_channel {
195 fused_weights[c * weights_per_channel + i] *= scale;
196 }
197
198 fused_bias[c] = (fused_bias[c] - mean[c]) * scale + beta[c];
200 }
201
202 if let Some(conv_node) = graph.get_node_mut(conv_id) {
204 if let NodeParams::Conv2d { weights, bias, .. } = &mut conv_node.params {
205 *weights = fused_weights;
206 *bias = Some(fused_bias);
207 }
208 }
209
210 graph.remove_node(bn_id);
212 fused_count += 1;
213 }
214
215 fused_count
216}
217
218pub fn fuse_zp_to_bias(
223 graph: &mut ComputationGraph,
224 quant_params: &HashMap<usize, QuantizationParams>,
225) -> usize {
226 let mut fused_count = 0;
227 let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
228
229 for conv_id in node_ids {
230 let conv_node = match graph.get_node(conv_id) {
231 Some(node) if node.node_type == NodeType::Conv2d => node,
232 _ => continue,
233 };
234
235 let input_id = match conv_node.inputs.first() {
237 Some(&id) => id,
238 None => continue,
239 };
240
241 let input_qparams = match quant_params.get(&input_id) {
242 Some(qp) => qp,
243 None => continue,
244 };
245
246 let zp_input = input_qparams.zero_point as f32;
247
248 let (weights, bias, in_channels, out_channels, kernel_size) = match &conv_node.params {
250 NodeParams::Conv2d {
251 weights,
252 bias,
253 in_channels,
254 out_channels,
255 kernel_size,
256 } => (weights, bias, *in_channels, *out_channels, *kernel_size),
257 _ => continue,
258 };
259
260 let mut fused_bias = bias.clone().unwrap_or_else(|| vec![0.0; out_channels]);
261
262 let weights_per_channel = kernel_size * kernel_size * in_channels;
264 for c in 0..out_channels {
265 let mut weight_sum = 0.0;
266 for i in 0..weights_per_channel {
267 weight_sum += weights[c * weights_per_channel + i];
268 }
269 fused_bias[c] -= zp_input * weight_sum;
271 }
272
273 if let Some(conv_node) = graph.get_node_mut(conv_id) {
275 if let NodeParams::Conv2d { bias, .. } = &mut conv_node.params {
276 *bias = Some(fused_bias);
277 }
278 }
279
280 fused_count += 1;
281 }
282
283 fused_count
284}
285
286pub fn insert_qdq_nodes(
291 graph: &mut ComputationGraph,
292 quant_params: &HashMap<usize, QuantizationParams>,
293) -> usize {
294 let mut inserted_count = 0;
295 let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
296
297 for node_id in node_ids {
298 let (node_type, inputs, outputs) = match graph.get_node(node_id) {
300 Some(n) => (n.node_type.clone(), n.inputs.clone(), n.outputs.clone()),
301 None => continue,
302 };
303
304 if matches!(node_type, NodeType::Quantize | NodeType::Dequantize) {
306 continue;
307 }
308
309 for &input_id in &inputs {
311 let input_node_type = match graph.get_node(input_id) {
312 Some(n) => n.node_type.clone(),
313 None => continue,
314 };
315
316 let needs_quantize = is_quantized_op(&node_type)
318 && !is_quantized_op(&input_node_type)
319 && quant_params.contains_key(&node_id);
320
321 if needs_quantize {
322 let qparams = &quant_params[&node_id];
323 let q_id = graph.add_node(
324 NodeType::Quantize,
325 NodeParams::Quantize {
326 scale: qparams.scale,
327 zero_point: qparams.zero_point,
328 },
329 );
330
331 graph.nodes.get_mut(&input_id).unwrap().outputs.retain(|&x| x != node_id);
333 graph.nodes.get_mut(&input_id).unwrap().outputs.push(q_id);
334 graph.nodes.get_mut(&node_id).unwrap().inputs.retain(|&x| x != input_id);
335 graph.nodes.get_mut(&node_id).unwrap().inputs.push(q_id);
336 graph.nodes.get_mut(&q_id).unwrap().inputs.push(input_id);
337 graph.nodes.get_mut(&q_id).unwrap().outputs.push(node_id);
338
339 inserted_count += 1;
340 }
341 }
342
343 for &output_id in &outputs {
345 let output_node_type = match graph.get_node(output_id) {
346 Some(n) => n.node_type.clone(),
347 None => continue,
348 };
349
350 let needs_dequantize = is_quantized_op(&node_type)
352 && !is_quantized_op(&output_node_type)
353 && quant_params.contains_key(&node_id);
354
355 if needs_dequantize {
356 let qparams = &quant_params[&node_id];
357 let dq_id = graph.add_node(
358 NodeType::Dequantize,
359 NodeParams::Dequantize {
360 scale: qparams.scale,
361 zero_point: qparams.zero_point,
362 },
363 );
364
365 graph.nodes.get_mut(&node_id).unwrap().outputs.retain(|&x| x != output_id);
367 graph.nodes.get_mut(&node_id).unwrap().outputs.push(dq_id);
368 graph.nodes.get_mut(&output_id).unwrap().inputs.retain(|&x| x != node_id);
369 graph.nodes.get_mut(&output_id).unwrap().inputs.push(dq_id);
370 graph.nodes.get_mut(&dq_id).unwrap().inputs.push(node_id);
371 graph.nodes.get_mut(&dq_id).unwrap().outputs.push(output_id);
372
373 inserted_count += 1;
374 }
375 }
376 }
377
378 inserted_count
379}
380
381fn is_quantized_op(node_type: &NodeType) -> bool {
383 matches!(
384 node_type,
385 NodeType::Conv2d | NodeType::Quantize | NodeType::Dequantize
386 )
387}
388
389pub fn fuse_relu(graph: &mut ComputationGraph) -> usize {
393 let mut fused_count = 0;
394 let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
395
396 for conv_id in node_ids {
397 let conv_node = match graph.get_node(conv_id) {
398 Some(node) if node.node_type == NodeType::Conv2d => node,
399 _ => continue,
400 };
401
402 let relu_id = match conv_node.outputs.first() {
404 Some(&id) => id,
405 None => continue,
406 };
407
408 let _relu_node = match graph.get_node(relu_id) {
409 Some(node) if node.node_type == NodeType::ReLU => node,
410 _ => continue,
411 };
412
413 graph.remove_node(relu_id);
416 fused_count += 1;
417 }
418
419 fused_count
420}
421
422pub fn fuse_hardswish(graph: &mut ComputationGraph) -> usize {
427 let mut fused_count = 0;
428 let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
429
430 for conv_id in node_ids {
431 let conv_node = match graph.get_node(conv_id) {
432 Some(node) if node.node_type == NodeType::Conv2d => node,
433 _ => continue,
434 };
435
436 let hs_id = match conv_node.outputs.first() {
438 Some(&id) => id,
439 None => continue,
440 };
441
442 let _hs_node = match graph.get_node(hs_id) {
443 Some(node) if node.node_type == NodeType::HardSwish => node,
444 _ => continue,
445 };
446
447 graph.remove_node(hs_id);
450 fused_count += 1;
451 }
452
453 fused_count
454}
455
/// Builds a 256-entry int8 lookup table for hard-swish, `x * relu6(x + 3) / 6`.
///
/// Input and output share one affine quantization, `real = (q - zero_point) * scale`.
/// The table is indexed by the input's two's-complement byte:
/// `lut[q as u8 as usize]` holds the quantized hard-swish of `q`, so entries
/// 128..=255 correspond to negative inputs.
///
/// Outputs are rounded with `f32::round` (half away from zero), shifted by
/// `zero_point`, and saturated to the `i8` range.
pub fn generate_hardswish_lut(scale: f32, zero_point: i32) -> [i8; 256] {
    let mut table = [0i8; 256];

    for q in i8::MIN..=i8::MAX {
        let x = (i32::from(q) - zero_point) as f32 * scale;
        let gate = (x + 3.0).max(0.0).min(6.0);
        let y = x * gate / 6.0;

        let requantized = (y / scale).round() as i32 + zero_point;
        table[usize::from(q as u8)] =
            requantized.clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8;
    }

    table
}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fuse_batchnorm_to_conv() {
        let mut graph = ComputationGraph::new();

        // 2 output channels x 2 input channels x 1x1 kernel = 4 weights.
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0],
                bias: Some(vec![0.5, 1.0]),
                in_channels: 2,
                out_channels: 2,
                kernel_size: 1,
            },
        );

        let bn_id = graph.add_node(
            NodeType::BatchNorm,
            NodeParams::BatchNorm {
                gamma: vec![2.0, 3.0],
                beta: vec![0.1, 0.2],
                mean: vec![0.5, 1.0],
                var: vec![1.0, 4.0],
                eps: 1e-5,
            },
        );

        graph.connect(conv_id, bn_id);

        let fused = fuse_batchnorm_to_conv(&mut graph);
        assert_eq!(fused, 1);

        // The BatchNorm node is gone after fusion.
        assert!(graph.get_node(bn_id).is_none());

        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { weights, bias, .. } = &conv_node.params {
            // Channel 0: scale = 2 / sqrt(1 + eps) ≈ 2.
            assert!((weights[0] - 2.0).abs() < 0.01);
            assert!((weights[1] - 4.0).abs() < 0.01);

            // Channel 1: scale = 3 / sqrt(4) = 1.5.
            assert!((weights[2] - 4.5).abs() < 0.01);
            assert!((weights[3] - 6.0).abs() < 0.01);

            let bias = bias.as_ref().unwrap();
            // (0.5 - 0.5) * 2 + 0.1 = 0.1
            assert!((bias[0] - 0.1).abs() < 0.01);
            // (1.0 - 1.0) * 1.5 + 0.2 = 0.2
            assert!((bias[1] - 0.2).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    #[test]
    fn test_fuse_zp_to_bias() {
        let mut graph = ComputationGraph::new();

        let input_id = graph.add_node(NodeType::Input, NodeParams::None);

        // 2 output channels x 2 input channels x 1x1 kernel = 4 weights.
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0],
                bias: Some(vec![1.0, 2.0]),
                in_channels: 2,
                out_channels: 2,
                kernel_size: 1,
            },
        );

        graph.connect(input_id, conv_id);

        let mut quant_params = HashMap::new();
        quant_params.insert(
            input_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 10,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );

        let fused = fuse_zp_to_bias(&mut graph, &quant_params);
        assert_eq!(fused, 1);

        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { bias, .. } = &conv_node.params {
            let bias = bias.as_ref().unwrap();
            // 1 - 10 * (1 + 2) = -29
            assert!((bias[0] - (-29.0)).abs() < 0.01);
            // 2 - 10 * (3 + 4) = -68
            assert!((bias[1] - (-68.0)).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    #[test]
    fn test_insert_qdq_nodes() {
        let mut graph = ComputationGraph::new();

        let input_id = graph.add_node(NodeType::Input, NodeParams::None);

        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        let output_id = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(input_id, conv_id);
        graph.connect(conv_id, output_id);

        let mut quant_params = HashMap::new();
        quant_params.insert(
            conv_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 0,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );

        let inserted = insert_qdq_nodes(&mut graph, &quant_params);
        assert_eq!(inserted, 2);

        let conv_node = graph.get_node(conv_id).unwrap();

        // Input -> Quantize -> Conv
        assert_eq!(conv_node.inputs.len(), 1);
        let q_id = conv_node.inputs[0];
        let q_node = graph.get_node(q_id).unwrap();
        assert_eq!(q_node.node_type, NodeType::Quantize);

        // Conv -> Dequantize -> Output
        assert_eq!(conv_node.outputs.len(), 1);
        let dq_id = conv_node.outputs[0];
        let dq_node = graph.get_node(dq_id).unwrap();
        assert_eq!(dq_node.node_type, NodeType::Dequantize);
    }

    #[test]
    fn test_fuse_relu() {
        let mut graph = ComputationGraph::new();

        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        let relu_id = graph.add_node(NodeType::ReLU, NodeParams::Activation);

        graph.connect(conv_id, relu_id);

        let fused = fuse_relu(&mut graph);
        assert_eq!(fused, 1);

        // The ReLU node is gone after fusion.
        assert!(graph.get_node(relu_id).is_none());

        // The ReLU had no consumers, so the conv is left with none either.
        let conv_node = graph.get_node(conv_id).unwrap();
        assert!(conv_node.outputs.is_empty());
    }

    #[test]
    fn test_fuse_hardswish() {
        let mut graph = ComputationGraph::new();

        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        let hs_id = graph.add_node(NodeType::HardSwish, NodeParams::Activation);

        graph.connect(conv_id, hs_id);

        let fused = fuse_hardswish(&mut graph);
        assert_eq!(fused, 1);

        // The HardSwish node is gone after fusion.
        assert!(graph.get_node(hs_id).is_none());
    }

    #[test]
    fn test_hardswish_lut_generation() {
        let scale = 0.1;
        let zero_point = 0;
        let lut = generate_hardswish_lut(scale, zero_point);

        // The LUT is indexed by the input's two's-complement byte.
        let at = |q: i8| lut[q as u8 as usize];

        // hardswish(0) = 0
        assert_eq!(at(0), 0);

        // hardswish(-3) = -3 * relu6(0) / 6 = 0
        assert_eq!(at(-30), 0);

        // hardswish(3) = 3 * relu6(6) / 6 = 3
        let y_pos3 = (at(30) as i32 - zero_point) as f32 * scale;
        assert!((y_pos3 - 3.0).abs() < 0.5);
    }

    #[test]
    fn test_graph_construction() {
        let mut graph = ComputationGraph::new();

        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(id1, id2);
        graph.connect(id2, id3);

        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.get_node(id2).unwrap().inputs, vec![id1]);
        assert_eq!(graph.get_node(id2).unwrap().outputs, vec![id3]);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = ComputationGraph::new();

        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(NodeType::ReLU, NodeParams::Activation);
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(id1, id2);
        graph.connect(id2, id3);

        graph.remove_node(id2);

        assert!(graph.get_node(id2).is_none());

        // The removed node is bypassed: id1 -> id3.
        assert_eq!(graph.get_node(id1).unwrap().outputs, vec![id3]);
        assert_eq!(graph.get_node(id3).unwrap().inputs, vec![id1]);
    }
}