Skip to main content

safety_net/
graph.rs

1/*!
2
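  Graph analyses over netlists: fan-out tables, combinational depth, and (with the
  `graph` feature) a `petgraph` multigraph view of the circuit.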
4
5*/
6
7use crate::circuit::{Instantiable, Net};
8use crate::error::Error;
9#[cfg(feature = "graph")]
10use crate::netlist::Connection;
11use crate::netlist::{NetRef, Netlist};
12#[cfg(feature = "graph")]
13use petgraph::graph::DiGraph;
14use std::collections::hash_map::Entry;
15use std::collections::{HashMap, HashSet};
16
17/// A common trait of analyses than can be performed on a netlist.
18/// An analysis becomes stale when the netlist is modified.
19pub trait Analysis<'a, I: Instantiable>
20where
21    Self: Sized + 'a,
22{
23    /// Construct the analysis to the current state of the netlist.
24    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
25}
26
27/// A table that maps nets to the circuit nodes they drive
28pub struct FanOutTable<'a, I: Instantiable> {
29    /// A reference to the underlying netlist
30    _netlist: &'a Netlist<I>,
31    /// Maps a net to the list of nodes it drives
32    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
33    /// Maps a node to the list of nodes it drives
34    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
35    /// Contains nets which are outputs
36    is_an_output: HashSet<Net>,
37}
38
39impl<I> FanOutTable<'_, I>
40where
41    I: Instantiable,
42{
43    /// Returns an iterator to the circuit nodes that use `net`.
44    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
45        self.net_fan_out
46            .get(net)
47            .into_iter()
48            .flat_map(|users| users.iter().cloned())
49    }
50
51    /// Returns an iterator to the circuit nodes that use `node`.
52    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
53        self.node_fan_out
54            .get(node)
55            .into_iter()
56            .flat_map(|users| users.iter().cloned())
57    }
58
59    /// Returns `true` if the net has any used by any cells in the circuit
60    /// This does incude nets that are only used as outputs.
61    pub fn net_has_uses(&self, net: &Net) -> bool {
62        (self.net_fan_out.contains_key(net) && !self.net_fan_out.get(net).unwrap().is_empty())
63            || self.is_an_output.contains(net)
64    }
65}
66
67impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
68where
69    I: Instantiable,
70{
71    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
72        let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
73        let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
74        let mut is_an_output: HashSet<Net> = HashSet::new();
75
76        // This can only be fully-correct on a verified netlist.
77        netlist.verify()?;
78
79        for c in netlist.connections() {
80            if let Entry::Vacant(e) = net_fan_out.entry(c.net()) {
81                e.insert(vec![c.target().unwrap()]);
82            } else {
83                net_fan_out
84                    .get_mut(&c.net())
85                    .unwrap()
86                    .push(c.target().unwrap());
87            }
88
89            if let Entry::Vacant(e) = node_fan_out.entry(c.src().unwrap()) {
90                e.insert(vec![c.target().unwrap()]);
91            } else {
92                node_fan_out
93                    .get_mut(&c.src().unwrap())
94                    .unwrap()
95                    .push(c.target().unwrap());
96            }
97        }
98
99        for (o, n) in netlist.outputs() {
100            is_an_output.insert(o.as_net().clone());
101            is_an_output.insert(n);
102        }
103
104        Ok(FanOutTable {
105            _netlist: netlist,
106            net_fan_out,
107            node_fan_out,
108            is_an_output,
109        })
110    }
111}
112
/// Result of combinational depth analysis for a single circuit node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CombDepthResult {
    /// The depth cannot be determined because some input of the node, directly
    /// or transitively through its fan-in, has no driver.
    Undefined,
    /// The node lies on a combinational cycle or depends on one through its fan-in.
    CombCycle,
    /// Length, in combinational logic levels, of the longest path reaching this
    /// node from a primary input or sequential element (which are at depth 0).
    Depth(usize),
}
125
126/// Computes the combinational depth of each net in a netlist.
127///
128/// Each net is classified as having a defined depth, being undefined,
129/// or participating in a combinational cycle.
130pub struct SimpleCombDepth<'a, I: Instantiable> {
131    _netlist: &'a Netlist<I>,
132    results: HashMap<NetRef<I>, CombDepthResult>,
133    /// Max will be None whenever no outputs in the whole netlist have a well defined combinational depth
134    /// for example if they are all undefined or they all partake in a cycle
135    max_depth: Option<usize>,
136}
137
138impl<I> SimpleCombDepth<'_, I>
139where
140    I: Instantiable,
141{
142    /// Returns the logic level of a node in the circuit.
143    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<CombDepthResult> {
144        self.results.get(node).copied()
145    }
146
147    /// Returns the maximum logic level of the circuit.
148    pub fn get_max_depth(&self) -> Option<usize> {
149        self.max_depth
150    }
151}
152impl<'a, I> Analysis<'a, I> for SimpleCombDepth<'a, I>
153where
154    I: Instantiable,
155{
156    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
157        let mut results: HashMap<NetRef<I>, CombDepthResult> = HashMap::new();
158        let mut visiting: HashSet<NetRef<I>> = HashSet::new();
159        let mut max_depth: Option<usize> = None;
160
161        fn compute<I: Instantiable>(
162            node: NetRef<I>,
163            netlist: &Netlist<I>,
164            results: &mut HashMap<NetRef<I>, CombDepthResult>,
165            visiting: &mut HashSet<NetRef<I>>,
166        ) -> CombDepthResult {
167            // Memoized result
168            if let Some(&r) = results.get(&node) {
169                return r;
170            }
171
172            // Cycle detection
173            if visiting.contains(&node) {
174                for n in visiting.iter() {
175                    results.insert(n.clone(), CombDepthResult::CombCycle);
176                }
177                return CombDepthResult::CombCycle;
178            }
179
180            // Input nodes have depth 0
181            if node.is_an_input() {
182                let r = CombDepthResult::Depth(0);
183                results.insert(node.clone(), r);
184                return r;
185            }
186
187            visiting.insert(node.clone());
188
189            let mut max_depth = 0;
190            let mut is_undefined = false;
191
192            for i in 0..node.get_num_input_ports() {
193                let driver = match netlist.get_driver(node.clone(), i) {
194                    Some(d) => d,
195                    None => {
196                        is_undefined = true;
197                        continue;
198                    }
199                };
200
201                if let Some(inst) = driver.get_instance_type()
202                    && inst.is_seq()
203                {
204                    continue;
205                }
206
207                match compute(driver, netlist, results, visiting) {
208                    CombDepthResult::Depth(d) => {
209                        max_depth = max_depth.max(d);
210                    }
211                    CombDepthResult::Undefined => {
212                        is_undefined = true;
213                    }
214                    CombDepthResult::CombCycle => {
215                        let r = CombDepthResult::CombCycle;
216                        results.insert(node.clone(), r);
217                        visiting.remove(&node);
218                        return r;
219                    }
220                }
221            }
222
223            visiting.remove(&node);
224            let r = if is_undefined {
225                CombDepthResult::Undefined
226            } else {
227                CombDepthResult::Depth(max_depth + 1)
228            };
229            results.insert(node.clone(), r);
230            r
231        }
232
233        for (driven, _) in netlist.outputs() {
234            let node = driven.unwrap();
235            let r = compute(node, netlist, &mut results, &mut visiting);
236
237            if let CombDepthResult::Depth(d) = r {
238                max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
239            }
240        }
241
242        for node in netlist.matches(|inst| inst.is_seq()) {
243            results.insert(node.clone(), CombDepthResult::Depth(0));
244            for i in 0..node.get_num_input_ports() {
245                if let Some(driver) = netlist.get_driver(node.clone(), i) {
246                    if driver.get_instance_type().is_some_and(|inst| inst.is_seq()) {
247                        continue;
248                    }
249
250                    let r = compute(driver, netlist, &mut results, &mut visiting);
251                    if let CombDepthResult::Depth(d) = r {
252                        max_depth = Some(max_depth.map_or(d, |m| m.max(d)));
253                    }
254                }
255            }
256        }
257
258        Ok(SimpleCombDepth {
259            _netlist: netlist,
260            results,
261            max_depth,
262        })
263    }
264}
265
/// A node of [`MultiDiGraph`]: either a real circuit node or a user-programmable
/// pseudo-node for any miscellaneous behavior.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// A 'real' circuit node
    NetRef(NetRef<I>),
    /// Any other user-programmable node
    Pseudo(T),
}
275
#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Writes the wrapped circuit node, or the pseudo-node's payload, to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(payload) => std::fmt::Display::fmt(payload, f),
            Self::NetRef(net_ref) => net_ref.fmt(f),
        }
    }
}
289
/// An edge of [`MultiDiGraph`]: either a real circuit connection or a
/// user-programmable pseudo-edge for any miscellaneous behavior.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Edge<I: Instantiable, T: Clone + std::fmt::Debug + std::fmt::Display> {
    /// A 'real' circuit connection
    Connection(Connection<I>),
    /// Any other user-programmable edge
    Pseudo(T),
}
299
#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Writes the wrapped circuit connection, or the pseudo-edge's payload, to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(payload) => std::fmt::Display::fmt(payload, f),
            Self::Connection(conn) => conn.fmt(f),
        }
    }
}
313
/// A `petgraph` view of a netlist as a directed multigraph.
///
/// Every netlist object becomes a [`Node::NetRef`], and every connection becomes an
/// [`Edge::Connection`]. Each netlist output also gets a [`Node::Pseudo`] sink
/// labelled `Output(<net>)`, reached through an [`Edge::Pseudo`] that carries the
/// driving net. The underlying graph has type
/// `DiGraph<Node<I, String>, Edge<I, Net>>`; see [`MultiDiGraph::get_graph`].
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I: Instantiable> {
    /// The netlist this graph was built from.
    _netlist: &'a Netlist<I>,
    /// Objects, connections, and output sinks of the netlist.
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}
320
#[cfg(feature = "graph")]
impl<I> MultiDiGraph<'_, I>
where
    I: Instantiable,
{
    /// Borrows the `petgraph` graph constructed by this analysis.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        let Self { graph, .. } = self;
        graph
    }

    /// Yields the circuit connections in a
    /// [greedy feedback arc set](https://doi.org/10.1016/0020-0190(93)90079-O) of the
    /// graph, i.e. connections whose removal leaves the graph acyclic.
    pub fn greedy_feedback_arcs(&self) -> impl Iterator<Item = Connection<I>> {
        let arcs = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&self.graph);
        arcs.map(|edge| {
            // Pseudo edges only lead into output sink nodes, which cannot lie on a cycle.
            let Edge::Connection(connection) = edge.weight() else {
                unreachable!("Outputs should be sinks")
            };
            connection.clone()
        })
    }
}
341
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the multigraph view of `netlist`.
    ///
    /// # Errors
    /// Returns the error from [`Netlist::verify`] if verification fails.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        // Once verified, object names can serve as graph-node keys.
        netlist.verify()?;

        let mut graph = DiGraph::new();
        let mut index_of = HashMap::new();

        // One graph node per netlist object, keyed by the object's name.
        for object in netlist.objects() {
            let idx = graph.add_node(Node::NetRef(object.clone()));
            index_of.insert(object.to_string(), idx);
        }

        // One graph edge per circuit connection.
        for connection in netlist.connections() {
            let from_name = connection.src().unwrap().get_obj().to_string();
            let to_name = connection.target().unwrap().get_obj().to_string();
            let (from, to) = (index_of[&from_name], index_of[&to_name]);
            graph.add_edge(from, to, Edge::Connection(connection));
        }

        // Each output drives a dedicated pseudo sink node.
        for (driver, n) in netlist.outputs() {
            let from = index_of[&driver.clone().unwrap().get_obj().to_string()];
            let sink = graph.add_node(Node::Pseudo(format!("Output({n})")));
            graph.add_edge(from, sink, Edge::Pseudo(driver.as_net().clone()));
        }

        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{format_id, netlist::*};

    /// A single full-adder cell with inputs `CIN`, `A`, `B` and outputs `S`, `COUT`.
    fn full_adder() -> Gate {
        Gate::new_logical_multi(
            "FA".into(),
            vec!["CIN".into(), "A".into(), "B".into()],
            vec!["S".into(), "COUT".into()],
        )
    }

    /// Builds a 4-bit ripple-carry adder from full-adder cells named `fa_0`..`fa_3`.
    ///
    /// Every sum bit is exposed as an output, and the carry out of `fa_3` is
    /// exposed as `cout`.
    fn ripple_adder() -> GateNetlist {
        let netlist = Netlist::new("ripple_adder".to_string());
        let bitwidth = 4;

        // Add the inputs
        let a = netlist.insert_input_escaped_logic_bus("a".to_string(), bitwidth);
        let b = netlist.insert_input_escaped_logic_bus("b".to_string(), bitwidth);
        let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());

        for (i, (a, b)) in a.into_iter().zip(b.into_iter()).enumerate() {
            // Instantiate a full adder for each bit
            let fa = netlist
                .insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a, b])
                .unwrap();

            // Expose the sum
            fa.expose_net(&fa.get_net(0)).unwrap();

            carry = fa.find_output(&"COUT".into()).unwrap();

            if i == bitwidth - 1 {
                // Last full adder, expose the carry out
                fa.get_output(1).expose_with_name("cout".into()).unwrap();
            }
        }

        netlist.reclaim().unwrap()
    }

    #[test]
    fn fanout_table() {
        let netlist = ripple_adder();
        let analysis = FanOutTable::build(&netlist);
        assert!(analysis.is_ok());
        let analysis = analysis.unwrap();
        assert!(netlist.verify().is_ok());

        for item in netlist.objects().filter(|o| !o.is_an_input()) {
            // Sum bit has no users (it is a direct output)
            assert!(
                analysis
                    .get_net_users(&item.find_output(&"S".into()).unwrap().as_net())
                    .next()
                    .is_none(),
                "Sum bit should not have users"
            );

            // ...but it is exposed, so it still counts as used.
            assert!(
                analysis.net_has_uses(&item.find_output(&"S".into()).unwrap().as_net()),
                "Exposed sum bit should count as used"
            );

            assert!(
                item.get_instance_name().is_some(),
                "Item should have a name. Filtered inputs"
            );

            let net = item.find_output(&"COUT".into()).unwrap().as_net().clone();

            // Every carry is either consumed by the next adder or exposed as `cout`.
            assert!(
                analysis.net_has_uses(&net),
                "Carry bit should be consumed or exposed"
            );

            let mut cout_users = analysis.get_net_users(&net);
            if item.get_instance_name().unwrap().to_string() != "fa_3" {
                assert!(cout_users.next().is_some(), "Carry bit should have users");
            }

            assert!(
                cout_users.next().is_none(),
                "Carry bit should have 1 or 0 user"
            );
        }
    }

    #[test]
    fn comb_depth() {
        let netlist = ripple_adder();
        let analysis = SimpleCombDepth::build(&netlist);
        assert!(analysis.is_ok());
        let analysis = analysis.unwrap();

        // The carry chain passes through all four full adders.
        assert_eq!(analysis.get_max_depth(), Some(4));

        for item in netlist.objects() {
            // Inputs sit at depth 0; `fa_i` is one level above the carry from `fa_{i-1}`.
            let expected = if item.is_an_input() {
                0
            } else {
                let name = item.get_instance_name().unwrap().to_string();
                let bit: usize = name.strip_prefix("fa_").unwrap().parse().unwrap();
                bit + 1
            };
            assert_eq!(
                analysis.get_comb_depth(&item),
                Some(CombDepthResult::Depth(expected)),
                "unexpected depth for {item}"
            );
        }
    }
}