// Source file: ruvector_cnn/quantize/graph_rewrite.rs
1//! Graph Rewrite Passes for INT8 Quantization (ADR-091 Phase 3)
2//!
3//! This module implements four critical graph optimization passes:
4//! - GR-1: fuse_batchnorm_to_conv - Absorb BatchNorm into Conv weights/bias
5//! - GR-2: fuse_zp_to_bias - Pre-compute zero-point correction in bias
6//! - GR-3: insert_qdq_nodes - Insert Quantize/Dequantize nodes at boundaries
7//! - GR-4: fuse_relu/fuse_hardswish - Merge activations into preceding ops
8
9use crate::quantize::calibration::QuantizationParams;
10use std::collections::HashMap;
11
/// Computation graph node types
///
/// Fieldless tags identifying the operation a `GraphNode` performs.
/// Derives `Copy`/`Eq`/`Hash` so the tag can be compared, copied, and used
/// as a map key without cloning (fieldless enums are trivially all three).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    Conv2d,
    BatchNorm,
    ReLU,
    HardSwish,
    /// FP32 → INT8 boundary node.
    Quantize,
    /// INT8 → FP32 boundary node.
    Dequantize,
    /// Graph entry point.
    Input,
    /// Graph exit point.
    Output,
}
24
/// Graph node representing a single operation
///
/// Edges are stored redundantly on both endpoints: an edge `a → b` appears
/// in `a.outputs` and in `b.inputs`.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique id; also this node's key in `ComputationGraph::nodes`.
    pub id: usize,
    /// Kind of operation this node performs.
    pub node_type: NodeType,
    /// Ids of predecessor nodes (data producers).
    pub inputs: Vec<usize>,
    /// Ids of successor nodes (data consumers).
    pub outputs: Vec<usize>,
    /// Operation-specific parameters; variant should match `node_type`.
    pub params: NodeParams,
}
34
/// Parameters for different node types
#[derive(Debug, Clone)]
pub enum NodeParams {
    /// 2D convolution. Weights are laid out flat, grouped per output
    /// channel: `weights[c * k * k * in_channels + i]` — assumed layout,
    /// consistent with how the fusion passes below index it.
    Conv2d {
        weights: Vec<f32>,
        bias: Option<Vec<f32>>,
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
    },
    /// Batch normalization statistics, one entry per channel.
    BatchNorm {
        gamma: Vec<f32>,
        beta: Vec<f32>,
        mean: Vec<f32>,
        var: Vec<f32>,
        /// Numerical-stability epsilon added to `var` before the sqrt.
        eps: f32,
    },
    /// Parameterless activation (ReLU, HardSwish).
    Activation,
    /// FP32 → INT8 conversion parameters.
    Quantize {
        scale: f32,
        zero_point: i32,
    },
    /// INT8 → FP32 conversion parameters.
    Dequantize {
        scale: f32,
        zero_point: i32,
    },
    /// No parameters (Input/Output nodes).
    None,
}
63
/// Computation graph for optimization passes
///
/// Nodes are stored by id in a flat map; connectivity lives in each node's
/// `inputs`/`outputs` adjacency lists.
#[derive(Debug, Clone)]
pub struct ComputationGraph {
    /// All live nodes, keyed by their unique id.
    pub nodes: HashMap<usize, GraphNode>,
    /// Next id handed out by `add_node`; monotonically increasing, ids of
    /// removed nodes are never reused.
    pub next_id: usize,
}
70
71impl ComputationGraph {
72    pub fn new() -> Self {
73        Self {
74            nodes: HashMap::new(),
75            next_id: 0,
76        }
77    }
78
79    pub fn add_node(&mut self, node_type: NodeType, params: NodeParams) -> usize {
80        let id = self.next_id;
81        self.next_id += 1;
82        self.nodes.insert(
83            id,
84            GraphNode {
85                id,
86                node_type,
87                inputs: Vec::new(),
88                outputs: Vec::new(),
89                params,
90            },
91        );
92        id
93    }
94
95    pub fn connect(&mut self, from: usize, to: usize) {
96        if let Some(from_node) = self.nodes.get_mut(&from) {
97            from_node.outputs.push(to);
98        }
99        if let Some(to_node) = self.nodes.get_mut(&to) {
100            to_node.inputs.push(from);
101        }
102    }
103
104    pub fn remove_node(&mut self, id: usize) {
105        if let Some(node) = self.nodes.remove(&id) {
106            // Reconnect inputs directly to outputs
107            for &input_id in &node.inputs {
108                if let Some(input_node) = self.nodes.get_mut(&input_id) {
109                    input_node.outputs.retain(|&x| x != id);
110                    input_node.outputs.extend(&node.outputs);
111                }
112            }
113            for &output_id in &node.outputs {
114                if let Some(output_node) = self.nodes.get_mut(&output_id) {
115                    output_node.inputs.retain(|&x| x != id);
116                    output_node.inputs.extend(&node.inputs);
117                }
118            }
119        }
120    }
121
122    pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
123        self.nodes.get(&id)
124    }
125
126    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut GraphNode> {
127        self.nodes.get_mut(&id)
128    }
129}
130
131impl Default for ComputationGraph {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137/// GR-1: Fuse BatchNorm parameters into Conv weights and bias
138///
139/// Mathematical formulation:
140/// w_fused = w * gamma / sqrt(var + eps)
141/// b_fused = (b - mean) * gamma / sqrt(var + eps) + beta
142pub fn fuse_batchnorm_to_conv(graph: &mut ComputationGraph) -> usize {
143    let mut fused_count = 0;
144    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
145
146    for conv_id in node_ids {
147        let conv_node = match graph.get_node(conv_id) {
148            Some(node) if node.node_type == NodeType::Conv2d => node,
149            _ => continue,
150        };
151
152        // Check if followed by BatchNorm
153        let bn_id = match conv_node.outputs.first() {
154            Some(&id) => id,
155            None => continue,
156        };
157
158        let bn_node = match graph.get_node(bn_id) {
159            Some(node) if node.node_type == NodeType::BatchNorm => node,
160            _ => continue,
161        };
162
163        // Extract parameters
164        let (weights, bias, out_channels) = match &conv_node.params {
165            NodeParams::Conv2d {
166                weights,
167                bias,
168                out_channels,
169                ..
170            } => (weights.clone(), bias.clone(), *out_channels),
171            _ => continue,
172        };
173
174        let (gamma, beta, mean, var, eps) = match &bn_node.params {
175            NodeParams::BatchNorm {
176                gamma,
177                beta,
178                mean,
179                var,
180                eps,
181            } => (gamma, beta, mean, var, *eps),
182            _ => continue,
183        };
184
185        // Compute fused weights and bias
186        let mut fused_weights = weights;
187        let mut fused_bias = bias.unwrap_or_else(|| vec![0.0; out_channels]);
188
189        for c in 0..out_channels {
190            let scale = gamma[c] / (var[c] + eps).sqrt();
191
192            // Fuse weights: w_fused = w * scale
193            let weights_per_channel = fused_weights.len() / out_channels;
194            for i in 0..weights_per_channel {
195                fused_weights[c * weights_per_channel + i] *= scale;
196            }
197
198            // Fuse bias: b_fused = (b - mean) * scale + beta
199            fused_bias[c] = (fused_bias[c] - mean[c]) * scale + beta[c];
200        }
201
202        // Update Conv node with fused parameters
203        if let Some(conv_node) = graph.get_node_mut(conv_id) {
204            if let NodeParams::Conv2d { weights, bias, .. } = &mut conv_node.params {
205                *weights = fused_weights;
206                *bias = Some(fused_bias);
207            }
208        }
209
210        // Remove BatchNorm node
211        graph.remove_node(bn_id);
212        fused_count += 1;
213    }
214
215    fused_count
216}
217
218/// GR-2: Fuse zero-point correction into bias
219///
220/// Eliminates runtime zero-point subtraction by pre-computing:
221/// bias_q = bias - zp_input × Σweights
222pub fn fuse_zp_to_bias(
223    graph: &mut ComputationGraph,
224    quant_params: &HashMap<usize, QuantizationParams>,
225) -> usize {
226    let mut fused_count = 0;
227    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
228
229    for conv_id in node_ids {
230        let conv_node = match graph.get_node(conv_id) {
231            Some(node) if node.node_type == NodeType::Conv2d => node,
232            _ => continue,
233        };
234
235        // Get input quantization params
236        let input_id = match conv_node.inputs.first() {
237            Some(&id) => id,
238            None => continue,
239        };
240
241        let input_qparams = match quant_params.get(&input_id) {
242            Some(qp) => qp,
243            None => continue,
244        };
245
246        let zp_input = input_qparams.zero_point as f32;
247
248        // Extract Conv parameters
249        let (weights, bias, in_channels, out_channels, kernel_size) = match &conv_node.params {
250            NodeParams::Conv2d {
251                weights,
252                bias,
253                in_channels,
254                out_channels,
255                kernel_size,
256            } => (weights, bias, *in_channels, *out_channels, *kernel_size),
257            _ => continue,
258        };
259
260        let mut fused_bias = bias.clone().unwrap_or_else(|| vec![0.0; out_channels]);
261
262        // Compute zero-point correction for each output channel
263        let weights_per_channel = kernel_size * kernel_size * in_channels;
264        for c in 0..out_channels {
265            let mut weight_sum = 0.0;
266            for i in 0..weights_per_channel {
267                weight_sum += weights[c * weights_per_channel + i];
268            }
269            // bias_q = bias - zp_input × Σweights
270            fused_bias[c] -= zp_input * weight_sum;
271        }
272
273        // Update Conv bias
274        if let Some(conv_node) = graph.get_node_mut(conv_id) {
275            if let NodeParams::Conv2d { bias, .. } = &mut conv_node.params {
276                *bias = Some(fused_bias);
277            }
278        }
279
280        fused_count += 1;
281    }
282
283    fused_count
284}
285
286/// GR-3: Insert Quantize/Dequantize nodes at INT8 subgraph boundaries
287///
288/// Detects transitions between FP32 and INT8 operations and inserts
289/// appropriate Q/DQ nodes to maintain numerical correctness.
290pub fn insert_qdq_nodes(
291    graph: &mut ComputationGraph,
292    quant_params: &HashMap<usize, QuantizationParams>,
293) -> usize {
294    let mut inserted_count = 0;
295    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
296
297    for node_id in node_ids {
298        // Collect node info without holding borrow
299        let (node_type, inputs, outputs) = match graph.get_node(node_id) {
300            Some(n) => (n.node_type.clone(), n.inputs.clone(), n.outputs.clone()),
301            None => continue,
302        };
303
304        // Skip nodes that are already Q/DQ
305        if matches!(node_type, NodeType::Quantize | NodeType::Dequantize) {
306            continue;
307        }
308
309        // Check each input for FP32→INT8 transition
310        for &input_id in &inputs {
311            let input_node_type = match graph.get_node(input_id) {
312                Some(n) => n.node_type.clone(),
313                None => continue,
314            };
315
316            // If input is not quantized but current node needs quantized input
317            let needs_quantize = is_quantized_op(&node_type)
318                && !is_quantized_op(&input_node_type)
319                && quant_params.contains_key(&node_id);
320
321            if needs_quantize {
322                let qparams = &quant_params[&node_id];
323                let q_id = graph.add_node(
324                    NodeType::Quantize,
325                    NodeParams::Quantize {
326                        scale: qparams.scale,
327                        zero_point: qparams.zero_point,
328                    },
329                );
330
331                // Reconnect: input → Q → node
332                graph.nodes.get_mut(&input_id).unwrap().outputs.retain(|&x| x != node_id);
333                graph.nodes.get_mut(&input_id).unwrap().outputs.push(q_id);
334                graph.nodes.get_mut(&node_id).unwrap().inputs.retain(|&x| x != input_id);
335                graph.nodes.get_mut(&node_id).unwrap().inputs.push(q_id);
336                graph.nodes.get_mut(&q_id).unwrap().inputs.push(input_id);
337                graph.nodes.get_mut(&q_id).unwrap().outputs.push(node_id);
338
339                inserted_count += 1;
340            }
341        }
342
343        // Check each output for INT8→FP32 transition
344        for &output_id in &outputs {
345            let output_node_type = match graph.get_node(output_id) {
346                Some(n) => n.node_type.clone(),
347                None => continue,
348            };
349
350            // If current node is quantized but output expects FP32
351            let needs_dequantize = is_quantized_op(&node_type)
352                && !is_quantized_op(&output_node_type)
353                && quant_params.contains_key(&node_id);
354
355            if needs_dequantize {
356                let qparams = &quant_params[&node_id];
357                let dq_id = graph.add_node(
358                    NodeType::Dequantize,
359                    NodeParams::Dequantize {
360                        scale: qparams.scale,
361                        zero_point: qparams.zero_point,
362                    },
363                );
364
365                // Reconnect: node → DQ → output
366                graph.nodes.get_mut(&node_id).unwrap().outputs.retain(|&x| x != output_id);
367                graph.nodes.get_mut(&node_id).unwrap().outputs.push(dq_id);
368                graph.nodes.get_mut(&output_id).unwrap().inputs.retain(|&x| x != node_id);
369                graph.nodes.get_mut(&output_id).unwrap().inputs.push(dq_id);
370                graph.nodes.get_mut(&dq_id).unwrap().inputs.push(node_id);
371                graph.nodes.get_mut(&dq_id).unwrap().outputs.push(output_id);
372
373                inserted_count += 1;
374            }
375        }
376    }
377
378    inserted_count
379}
380
381/// Helper: Check if operation is quantized
382fn is_quantized_op(node_type: &NodeType) -> bool {
383    matches!(
384        node_type,
385        NodeType::Conv2d | NodeType::Quantize | NodeType::Dequantize
386    )
387}
388
389/// GR-4: Fuse ReLU activation into preceding convolution
390///
391/// Eliminates separate ReLU node by clamping Conv output to [0, ∞)
392pub fn fuse_relu(graph: &mut ComputationGraph) -> usize {
393    let mut fused_count = 0;
394    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
395
396    for conv_id in node_ids {
397        let conv_node = match graph.get_node(conv_id) {
398            Some(node) if node.node_type == NodeType::Conv2d => node,
399            _ => continue,
400        };
401
402        // Check if followed by ReLU
403        let relu_id = match conv_node.outputs.first() {
404            Some(&id) => id,
405            None => continue,
406        };
407
408        let _relu_node = match graph.get_node(relu_id) {
409            Some(node) if node.node_type == NodeType::ReLU => node,
410            _ => continue,
411        };
412
413        // ReLU fusion is handled at runtime by clamping output
414        // We mark the Conv as having fused ReLU and remove the ReLU node
415        graph.remove_node(relu_id);
416        fused_count += 1;
417    }
418
419    fused_count
420}
421
422/// GR-4: Fuse HardSwish activation using LUT
423///
424/// Replaces HardSwish with 256-entry lookup table (i8→i8)
425/// HardSwish(x) = x * ReLU6(x + 3) / 6
426pub fn fuse_hardswish(graph: &mut ComputationGraph) -> usize {
427    let mut fused_count = 0;
428    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
429
430    for conv_id in node_ids {
431        let conv_node = match graph.get_node(conv_id) {
432            Some(node) if node.node_type == NodeType::Conv2d => node,
433            _ => continue,
434        };
435
436        // Check if followed by HardSwish
437        let hs_id = match conv_node.outputs.first() {
438            Some(&id) => id,
439            None => continue,
440        };
441
442        let _hs_node = match graph.get_node(hs_id) {
443            Some(node) if node.node_type == NodeType::HardSwish => node,
444            _ => continue,
445        };
446
447        // HardSwish fusion is handled at runtime using LUT
448        // We mark the Conv as having fused HardSwish and remove the HardSwish node
449        graph.remove_node(hs_id);
450        fused_count += 1;
451    }
452
453    fused_count
454}
455
/// Generate HardSwish LUT for INT8 quantized values
///
/// Maps i8 input to i8 output using HardSwish(x) = x * ReLU6(x + 3) / 6.
///
/// Index convention (offset-binary): slot `i` holds the output for the
/// quantized input `q = i - 128`, i.e. look up `lut[(q + 128) as usize]`.
/// This matches the indexing the unit tests use; the previous version
/// filled the table with `q = i as i8` (slot 128 ↔ q = -128), which
/// disagreed with that convention.
pub fn generate_hardswish_lut(scale: f32, zero_point: i32) -> [i8; 256] {
    let mut lut = [0i8; 256];

    for (i, entry) in lut.iter_mut().enumerate() {
        // Quantized input represented by this table slot (slot 0 ↔ q = -128).
        let q_input = i as i32 - 128;
        // Dequantize to FP32.
        let x = (q_input - zero_point) as f32 * scale;

        // HardSwish: x * ReLU6(x + 3) / 6
        let relu6 = (x + 3.0).clamp(0.0, 6.0);
        let hs_output = x * relu6 / 6.0;

        // Requantize and saturate to the i8 range.
        let q_output = (hs_output / scale).round() as i32 + zero_point;
        *entry = q_output.clamp(-128, 127) as i8;
    }

    lut
}
478
/// Unit tests for the graph rewrite passes (GR-1 through GR-4) and the
/// supporting graph plumbing.
#[cfg(test)]
mod tests {
    use super::*;

    // GR-1: Conv + BatchNorm collapses into a single Conv with folded
    // weights/bias; expected values computed by hand from the formula
    // w' = w * gamma / sqrt(var + eps), b' = (b - mean) * scale + beta.
    #[test]
    fn test_fuse_batchnorm_to_conv() {
        let mut graph = ComputationGraph::new();

        // Create Conv2d node
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0], // 2 output channels, 2 weights each
                bias: Some(vec![0.5, 1.0]),
                in_channels: 1,
                out_channels: 2,
                kernel_size: 1,
            },
        );

        // Create BatchNorm node
        let bn_id = graph.add_node(
            NodeType::BatchNorm,
            NodeParams::BatchNorm {
                gamma: vec![2.0, 3.0],
                beta: vec![0.1, 0.2],
                mean: vec![0.5, 1.0],
                var: vec![1.0, 4.0],
                eps: 1e-5,
            },
        );

        graph.connect(conv_id, bn_id);

        // Fuse BatchNorm into Conv
        let fused = fuse_batchnorm_to_conv(&mut graph);
        assert_eq!(fused, 1);

        // Verify BatchNorm was removed
        assert!(graph.get_node(bn_id).is_none());

        // Verify Conv parameters were updated
        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { weights, bias, .. } = &conv_node.params {
            // Channel 0: scale = 2.0 / sqrt(1.0 + 1e-5) ≈ 2.0
            // w0 = 1.0 * 2.0 = 2.0, w1 = 2.0 * 2.0 = 4.0
            assert!((weights[0] - 2.0).abs() < 0.01);
            assert!((weights[1] - 4.0).abs() < 0.01);

            // Channel 1: scale = 3.0 / sqrt(4.0 + 1e-5) ≈ 1.5
            // w2 = 3.0 * 1.5 = 4.5, w3 = 4.0 * 1.5 = 6.0
            assert!((weights[2] - 4.5).abs() < 0.01);
            assert!((weights[3] - 6.0).abs() < 0.01);

            // Bias verification
            let bias = bias.as_ref().unwrap();
            // b0 = (0.5 - 0.5) * 2.0 + 0.1 = 0.1
            assert!((bias[0] - 0.1).abs() < 0.01);
            // b1 = (1.0 - 1.0) * 1.5 + 0.2 = 0.2
            assert!((bias[1] - 0.2).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    // GR-2: the zero-point correction uses the quantization params keyed by
    // the Conv's *input* node id.
    #[test]
    fn test_fuse_zp_to_bias() {
        let mut graph = ComputationGraph::new();

        // Create Input node
        let input_id = graph.add_node(NodeType::Input, NodeParams::None);

        // Create Conv2d node
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0], // 2 out channels, 2 weights each
                bias: Some(vec![1.0, 2.0]),
                in_channels: 1,
                out_channels: 2,
                kernel_size: 1,
            },
        );

        graph.connect(input_id, conv_id);

        // Create quantization params
        let mut quant_params = HashMap::new();
        quant_params.insert(
            input_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 10,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );

        // Fuse zero-point correction
        let fused = fuse_zp_to_bias(&mut graph, &quant_params);
        assert_eq!(fused, 1);

        // Verify bias was updated
        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { bias, .. } = &conv_node.params {
            let bias = bias.as_ref().unwrap();
            // Channel 0: weight_sum = 1.0 + 2.0 = 3.0
            // bias_q = 1.0 - 10.0 * 3.0 = -29.0
            assert!((bias[0] - (-29.0)).abs() < 0.01);

            // Channel 1: weight_sum = 3.0 + 4.0 = 7.0
            // bias_q = 2.0 - 10.0 * 7.0 = -68.0
            assert!((bias[1] - (-68.0)).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    // GR-3: an FP32 Input → INT8 Conv → FP32 Output chain gets exactly one
    // Quantize spliced before the Conv and one Dequantize after it.
    #[test]
    fn test_insert_qdq_nodes() {
        let mut graph = ComputationGraph::new();

        // Create FP32 Input node
        let input_id = graph.add_node(NodeType::Input, NodeParams::None);

        // Create quantized Conv2d node
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        // Create FP32 Output node
        let output_id = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(input_id, conv_id);
        graph.connect(conv_id, output_id);

        // Create quantization params
        let mut quant_params = HashMap::new();
        quant_params.insert(
            conv_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 0,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );

        // Insert Q/DQ nodes
        let inserted = insert_qdq_nodes(&mut graph, &quant_params);
        assert_eq!(inserted, 2); // One Q before Conv, one DQ after Conv

        // Verify graph structure
        let conv_node = graph.get_node(conv_id).unwrap();

        // Conv should have one quantize input
        assert_eq!(conv_node.inputs.len(), 1);
        let q_id = conv_node.inputs[0];
        let q_node = graph.get_node(q_id).unwrap();
        assert_eq!(q_node.node_type, NodeType::Quantize);

        // Conv should have one dequantize output
        assert_eq!(conv_node.outputs.len(), 1);
        let dq_id = conv_node.outputs[0];
        let dq_node = graph.get_node(dq_id).unwrap();
        assert_eq!(dq_node.node_type, NodeType::Dequantize);
    }

    // GR-4: a trailing ReLU directly after a Conv disappears from the graph.
    #[test]
    fn test_fuse_relu() {
        let mut graph = ComputationGraph::new();

        // Create Conv2d node
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        // Create ReLU node
        let relu_id = graph.add_node(NodeType::ReLU, NodeParams::Activation);

        graph.connect(conv_id, relu_id);

        // Fuse ReLU
        let fused = fuse_relu(&mut graph);
        assert_eq!(fused, 1);

        // Verify ReLU was removed
        assert!(graph.get_node(relu_id).is_none());

        // Verify Conv outputs are connected to what ReLU was connected to
        let conv_node = graph.get_node(conv_id).unwrap();
        assert_eq!(conv_node.outputs, vec![]); // In this test, ReLU had no outputs
    }

    // GR-4: same shape as the ReLU test but for HardSwish.
    #[test]
    fn test_fuse_hardswish() {
        let mut graph = ComputationGraph::new();

        // Create Conv2d node
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );

        // Create HardSwish node
        let hs_id = graph.add_node(NodeType::HardSwish, NodeParams::Activation);

        graph.connect(conv_id, hs_id);

        // Fuse HardSwish
        let fused = fuse_hardswish(&mut graph);
        assert_eq!(fused, 1);

        // Verify HardSwish was removed
        assert!(graph.get_node(hs_id).is_none());
    }

    // NOTE(review): the indexing below assumes LUT slot = q - zero_point + 128
    // (offset-binary) — confirm generate_hardswish_lut fills the table with
    // the same convention, otherwise these asserts check the wrong slots.
    #[test]
    fn test_hardswish_lut_generation() {
        let scale = 0.1;
        let zero_point = 0;
        let lut = generate_hardswish_lut(scale, zero_point);

        // Test key points
        // x = 0 → HardSwish(0) = 0
        let idx_0 = (0 - zero_point + 128) as usize;
        assert_eq!(lut[idx_0], 0);

        // x = -3 (or less) → HardSwish = 0
        let idx_neg3 = ((-30 as i32 - zero_point + 128) as usize).min(255);
        assert_eq!(lut[idx_neg3], 0);

        // x = 3 (or more) → HardSwish(x) ≈ x
        let idx_pos3 = ((30 as i32 - zero_point + 128) as usize).min(255);
        let x_pos3 = (lut[idx_pos3] as i32 - zero_point) as f32 * scale;
        assert!((x_pos3 - 3.0).abs() < 0.5); // Should be close to 3.0
    }

    // Basic adjacency bookkeeping: connect() records the edge on both ends.
    #[test]
    fn test_graph_construction() {
        let mut graph = ComputationGraph::new();

        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(NodeType::Conv2d, NodeParams::Conv2d {
            weights: vec![1.0; 4],
            bias: None,
            in_channels: 1,
            out_channels: 1,
            kernel_size: 2,
        });
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(id1, id2);
        graph.connect(id2, id3);

        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.get_node(id2).unwrap().inputs, vec![id1]);
        assert_eq!(graph.get_node(id2).unwrap().outputs, vec![id3]);
    }

    // remove_node() must splice predecessors directly to successors.
    #[test]
    fn test_remove_node() {
        let mut graph = ComputationGraph::new();

        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(NodeType::ReLU, NodeParams::Activation);
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);

        graph.connect(id1, id2);
        graph.connect(id2, id3);

        graph.remove_node(id2);

        // id2 should be removed
        assert!(graph.get_node(id2).is_none());

        // id1 should connect directly to id3
        assert_eq!(graph.get_node(id1).unwrap().outputs, vec![id3]);
        assert_eq!(graph.get_node(id3).unwrap().inputs, vec![id1]);
    }
}