safety_net/
graph.rs

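/*!

  Netlist analyses: a fan-out table, a simple combinational-depth analysis, and (behind the
  `graph` feature) a petgraph multigraph export of the circuit.

*/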
6
7use crate::circuit::{Instantiable, Net};
8use crate::error::Error;
9#[cfg(feature = "graph")]
10use crate::netlist::Connection;
11use crate::netlist::iter::DFSIterator;
12use crate::netlist::{NetRef, Netlist};
13#[cfg(feature = "graph")]
14use petgraph::graph::DiGraph;
15use std::collections::hash_map::Entry;
16use std::collections::{HashMap, HashSet};
17
18/// A common trait of analyses than can be performed on a netlist.
19/// An analysis becomes stale when the netlist is modified.
20pub trait Analysis<'a, I: Instantiable>
21where
22    Self: Sized + 'a,
23{
24    /// Construct the analysis to the current state of the netlist.
25    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error>;
26}
27
28/// A table that maps nets to the circuit nodes they drive
29pub struct FanOutTable<'a, I: Instantiable> {
30    /// A reference to the underlying netlist
31    _netlist: &'a Netlist<I>,
32    /// Maps a net to the list of nodes it drives
33    net_fan_out: HashMap<Net, Vec<NetRef<I>>>,
34    /// Maps a node to the list of nodes it drives
35    node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>>,
36    /// Contains nets which are outputs
37    is_an_output: HashSet<Net>,
38}
39
40impl<I> FanOutTable<'_, I>
41where
42    I: Instantiable,
43{
44    /// Returns an iterator to the circuit nodes that use `net`.
45    pub fn get_net_users(&self, net: &Net) -> impl Iterator<Item = NetRef<I>> {
46        self.net_fan_out
47            .get(net)
48            .into_iter()
49            .flat_map(|users| users.iter().cloned())
50    }
51
52    /// Returns an iterator to the circuit nodes that use `node`.
53    pub fn get_node_users(&self, node: &NetRef<I>) -> impl Iterator<Item = NetRef<I>> {
54        self.node_fan_out
55            .get(node)
56            .into_iter()
57            .flat_map(|users| users.iter().cloned())
58    }
59
60    /// Returns `true` if the net has any used by any cells in the circuit
61    /// This does incude nets that are only used as outputs.
62    pub fn net_has_uses(&self, net: &Net) -> bool {
63        (self.net_fan_out.contains_key(net) && !self.net_fan_out.get(net).unwrap().is_empty())
64            || self.is_an_output.contains(net)
65    }
66}
67
68impl<'a, I> Analysis<'a, I> for FanOutTable<'a, I>
69where
70    I: Instantiable,
71{
72    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
73        let mut net_fan_out: HashMap<Net, Vec<NetRef<I>>> = HashMap::new();
74        let mut node_fan_out: HashMap<NetRef<I>, Vec<NetRef<I>>> = HashMap::new();
75        let mut is_an_output: HashSet<Net> = HashSet::new();
76
77        // This can only be fully-correct on a verified netlist.
78        netlist.verify()?;
79
80        for c in netlist.connections() {
81            if let Entry::Vacant(e) = net_fan_out.entry(c.net()) {
82                e.insert(vec![c.target().unwrap()]);
83            } else {
84                net_fan_out
85                    .get_mut(&c.net())
86                    .unwrap()
87                    .push(c.target().unwrap());
88            }
89
90            if let Entry::Vacant(e) = node_fan_out.entry(c.src().unwrap()) {
91                e.insert(vec![c.target().unwrap()]);
92            } else {
93                node_fan_out
94                    .get_mut(&c.src().unwrap())
95                    .unwrap()
96                    .push(c.target().unwrap());
97            }
98        }
99
100        for (o, n) in netlist.outputs() {
101            is_an_output.insert(o.as_net().clone());
102            is_an_output.insert(n);
103        }
104
105        Ok(FanOutTable {
106            _netlist: netlist,
107            net_fan_out,
108            node_fan_out,
109            is_an_output,
110        })
111    }
112}
113
114/// An simple example to analyze the logic levels of a netlist.
115/// This analysis checks for cycles, but it doesn't check for registers.
116pub struct SimpleCombDepth<'a, I: Instantiable> {
117    /// A reference to the underlying netlist
118    _netlist: &'a Netlist<I>,
119    /// Maps a net to its logic level as a DAG
120    comb_depth: HashMap<NetRef<I>, usize>,
121    /// The maximum depth of the circuit
122    max_depth: usize,
123}
124
125impl<I> SimpleCombDepth<'_, I>
126where
127    I: Instantiable,
128{
129    /// Returns the logic level of a node in the circuit.
130    pub fn get_comb_depth(&self, node: &NetRef<I>) -> Option<usize> {
131        self.comb_depth.get(node).cloned()
132    }
133
134    /// Returns the maximum logic level of the circuit.
135    pub fn get_max_depth(&self) -> usize {
136        self.max_depth
137    }
138}
139
140impl<'a, I> Analysis<'a, I> for SimpleCombDepth<'a, I>
141where
142    I: Instantiable,
143{
144    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
145        let mut comb_depth: HashMap<NetRef<I>, usize> = HashMap::new();
146
147        let mut nodes = Vec::new();
148        for (driven, _) in netlist.outputs() {
149            let mut dfs = DFSIterator::new(netlist, driven.clone().unwrap());
150            while let Some(n) = dfs.next() {
151                if dfs.check_cycles() {
152                    return Err(Error::CycleDetected(vec![driven.as_net().clone()]));
153                }
154                nodes.push(n);
155            }
156        }
157        nodes.reverse();
158        nodes.dedup();
159
160        for node in nodes {
161            if node.is_an_input() {
162                comb_depth.insert(node.clone(), 0);
163            } else {
164                let max_depth: usize = (0..node.get_num_input_ports())
165                    .filter_map(|i| netlist.get_driver(node.clone(), i))
166                    .filter_map(|n| comb_depth.get(&n))
167                    .max()
168                    .cloned()
169                    .unwrap_or(usize::MAX);
170
171                comb_depth.insert(node, max_depth + 1);
172            }
173        }
174
175        let max_depth = comb_depth.values().max().cloned().unwrap_or(0);
176
177        Ok(SimpleCombDepth {
178            _netlist: netlist,
179            comb_depth,
180            max_depth,
181        })
182    }
183}
184
/// A graph node: either a real circuit node or a user-defined pseudo-node.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// A node that exists in the netlist.
    NetRef(NetRef<I>),
    /// A user-defined node with no counterpart in the netlist.
    Pseudo(T),
}
194
#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Node<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Pseudo-nodes print through their payload's `Display`; real nodes defer to
    /// `NetRef`'s own `fmt`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(payload) => std::fmt::Display::fmt(payload, f),
            Self::NetRef(circuit_node) => circuit_node.fmt(f),
        }
    }
}
208
/// A graph edge: either a real circuit connection or a user-defined pseudo-edge.
#[cfg(feature = "graph")]
#[derive(Debug, Clone)]
pub enum Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// A connection that exists in the netlist.
    Connection(Connection<I>),
    /// A user-defined edge with no counterpart in the netlist.
    Pseudo(T),
}
218
#[cfg(feature = "graph")]
impl<I, T> std::fmt::Display for Edge<I, T>
where
    I: Instantiable,
    T: Clone + std::fmt::Debug + std::fmt::Display,
{
    /// Pseudo-edges print through their payload's `Display`; real connections defer to
    /// `Connection`'s own `fmt`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pseudo(payload) => std::fmt::Display::fmt(payload, f),
            Self::Connection(link) => link.fmt(f),
        }
    }
}
232
/// A petgraph view of a netlist as a directed multigraph,
/// `DiGraph<Node<I, String>, Edge<I, Net>>`.
///
/// Every netlist object becomes a [`Node::NetRef`] and every connection an
/// [`Edge::Connection`]. Each circuit output additionally gets a [`Node::Pseudo`] sink,
/// reached through an [`Edge::Pseudo`] labelled with the net that drives the output.
#[cfg(feature = "graph")]
pub struct MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Netlist the graph was built from; held to tie the graph's lifetime to it.
    _netlist: &'a Netlist<I>,
    /// The constructed graph.
    graph: DiGraph<Node<I, String>, Edge<I, Net>>,
}
239
#[cfg(feature = "graph")]
impl<I: Instantiable> MultiDiGraph<'_, I> {
    /// Borrows the petgraph graph built by this analysis.
    pub fn get_graph(&self) -> &DiGraph<Node<I, String>, Edge<I, Net>> {
        &self.graph
    }
}
250
#[cfg(feature = "graph")]
impl<'a, I> Analysis<'a, I> for MultiDiGraph<'a, I>
where
    I: Instantiable,
{
    /// Builds the graph with:
    /// - one node per netlist object,
    /// - one edge per connection,
    /// - one pseudo-node `Output(<name>)` per circuit output, linked by a pseudo-edge.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Netlist::verify`] if the netlist is not well formed.
    ///
    /// # Panics
    ///
    /// Panics if a connection or output refers to a node that `netlist.objects()` did not
    /// yield. Verification is expected to rule this out.
    fn build(netlist: &'a Netlist<I>) -> Result<Self, Error> {
        netlist.verify()?;
        let mut graph = DiGraph::new();

        // Key graph indices by the node handle itself rather than by its printed name.
        // This avoids allocating a string per lookup. It also no longer assumes that a
        // `NetRef` and its underlying object print identically.
        let mut index_of = HashMap::new();
        for obj in netlist.objects() {
            let idx = graph.add_node(Node::NetRef(obj.clone()));
            index_of.insert(obj, idx);
        }

        for connection in netlist.connections() {
            let s_id = index_of[&connection.src().unwrap()];
            let t_id = index_of[&connection.target().unwrap()];
            graph.add_edge(s_id, t_id, Edge::Connection(connection));
        }

        // Finally, hang a pseudo sink node off every circuit output.
        for (o, n) in netlist.outputs() {
            let s_id = index_of[&o.clone().unwrap()];
            let t_id = graph.add_node(Node::Pseudo(format!("Output({n})")));
            graph.add_edge(s_id, t_id, Edge::Pseudo(o.as_net().clone()));
        }

        Ok(Self {
            _netlist: netlist,
            graph,
        })
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{format_id, netlist::*};

    /// A one-bit full adder cell with inputs `CIN`, `A`, `B` and outputs `S`, `COUT`.
    fn full_adder() -> Gate {
        let inputs = vec!["CIN".into(), "A".into(), "B".into()];
        let outputs = vec!["S".into(), "COUT".into()];
        Gate::new_logical_multi("FA".into(), inputs, outputs)
    }

    /// A 4-bit ripple-carry adder made of chained full adders `fa_0`..`fa_3`.
    ///
    /// Every sum bit is exposed as an output, and the last carry is exposed as `cout`.
    fn ripple_adder() -> GateNetlist {
        let netlist = Netlist::new("ripple_adder".to_string());
        let bitwidth = 4;

        // Primary inputs: the two operand buses and the carry-in.
        let a_bus = netlist.insert_input_escaped_logic_bus("a".to_string(), bitwidth);
        let b_bus = netlist.insert_input_escaped_logic_bus("b".to_string(), bitwidth);
        let mut carry: DrivenNet<Gate> = netlist.insert_input("cin".into());

        for (i, (a_bit, b_bit)) in a_bus.into_iter().zip(b_bus).enumerate() {
            let fa = netlist
                .insert_gate(full_adder(), format_id!("fa_{i}"), &[carry, a_bit, b_bit])
                .unwrap();

            // Each sum bit becomes a circuit output.
            fa.expose_net(&fa.get_net(0)).unwrap();

            // The carry-out ripples into the next stage.
            carry = fa.find_output(&"COUT".into()).unwrap();

            let is_last_stage = i == bitwidth - 1;
            if is_last_stage {
                fa.get_output(1).expose_with_name("cout".into()).unwrap();
            }
        }

        netlist.reclaim().unwrap()
    }

    #[test]
    fn fanout_table() {
        let netlist = ripple_adder();
        let built = FanOutTable::build(&netlist);
        assert!(built.is_ok());
        let table = built.unwrap();
        assert!(netlist.verify().is_ok());

        let cells = netlist.objects().filter(|o| !o.is_an_input());
        for cell in cells {
            // A sum bit only feeds a circuit output, never another cell.
            let sum = cell.find_output(&"S".into()).unwrap();
            assert!(
                table.get_net_users(&sum.as_net()).next().is_none(),
                "Sum bit should not have users"
            );

            assert!(
                cell.get_instance_name().is_some(),
                "Item should have a name. Filtered inputs"
            );

            // Every carry except the last one feeds exactly one downstream adder.
            let carry_net = cell.find_output(&"COUT".into()).unwrap().as_net().clone();
            let mut carry_users = table.get_net_users(&carry_net);
            let feeds_next_stage = cell.get_instance_name().unwrap().to_string() != "fa_3";
            if feeds_next_stage {
                assert!(carry_users.next().is_some(), "Carry bit should have users");
            }

            assert!(
                carry_users.next().is_none(),
                "Carry bit should have 1 or 0 user"
            );
        }
    }
}