swh_graph/views/
subgraph.rs

1// Copyright (C) 2023-2025  The Software Heritage developers
2// See the AUTHORS file at the top-level directory of this distribution
3// License: GNU General Public License version 3, or any later version
4// See top-level LICENSE file for more information
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use anyhow::{anyhow, Result};
10use webgraph::traits::labels::SortedIterator;
11
12use crate::arc_iterators::FlattenedSuccessorsIterator;
13use crate::graph::*;
14use crate::properties;
15use crate::{NodeConstraint, NodeType};
16
// Generates an iterator adapter that filters the output of an unlabeled arc
// iterator (one yielding bare `NodeId`s) through a node filter and an arc filter.
//
// Arguments:
// * `$name`: name of the generated struct;
// * `$inner`: identifier used as the *name of the generic type parameter* for the
//   wrapped iterator (e.g. `Successors`); it is a type parameter, not a concrete type;
// * `$next`: tokens of the `fn next` implementation of `Iterator` for the generated
//   struct; it decides how `node`, `node_filter` and `arc_filter` are applied.
macro_rules! make_filtered_arcs_iterator {
    ($name:ident, $inner:ident, $( $next:tt )*) => {
        pub struct $name<
            'a,
            $inner: Iterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > {
            // wrapped iterator; its items are filtered by the `next` given to the macro
            inner: $inner,
            // the node whose neighbors `inner` iterates over
            node: NodeId,
            // filters are borrowed (typically from the graph view) for the lifetime `'a`
            node_filter: &'a NodeFilter,
            arc_filter: &'a ArcFilter,
        }

        impl<
            'a,
            $inner: Iterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > Iterator for $name<'a, $inner, NodeFilter, ArcFilter> {
            // Same item type as the wrapped iterator, i.e. `NodeId`
            type Item = $inner::Item;

            $( $next )*
        }

        // SAFETY: filtering out elements out of an iterator preserves sortedness
        // (this holds as long as `$next` only drops items of `inner` and never
        // reorders or invents them).
        unsafe impl<
            'a,
            $inner: SortedIterator<Item = NodeId> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > SortedIterator for $name<'a, $inner, NodeFilter, ArcFilter> {
        }
    }
}
52
53make_filtered_arcs_iterator! {
54    FilteredSuccessors,
55    Successors,
56    fn next(&mut self) -> Option<Self::Item> {
57        if !(self.node_filter)(self.node) {
58            return None;
59        }
60        for dst in self.inner.by_ref() {
61            if (self.node_filter)(dst) && (self.arc_filter)(self.node, dst) {
62                return Some(dst)
63            }
64        }
65        None
66    }
67}
68make_filtered_arcs_iterator! {
69    FilteredPredecessors,
70    Predecessors,
71    fn next(&mut self) -> Option<Self::Item> {
72        if !(self.node_filter)(self.node) {
73            return None;
74        }
75        for src in self.inner.by_ref() {
76            if (self.node_filter)(src) && (self.arc_filter)(src, self.node) {
77                return Some(src)
78            }
79        }
80        None
81    }
82}
83
// Generates an iterator adapter that filters the output of a labeled arc
// iterator (one yielding `(NodeId, Labels)` pairs) through a node filter and an
// arc filter. Labels are passed through untouched.
//
// Arguments:
// * `$name`: name of the generated struct;
// * `$inner`: identifier used as the *name of the generic type parameter* for the
//   wrapped iterator (e.g. `LabeledSuccessors`); it is a type parameter, not a
//   concrete type;
// * `$next`: tokens of the `fn next` implementation of `Iterator` for the generated
//   struct; it decides how `node`, `node_filter` and `arc_filter` are applied.
macro_rules! make_filtered_labeled_arcs_iterator {
    ($name:ident, $inner:ident, $( $next:tt )*) => {
        pub struct $name<
            'a,
            Labels,
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > {
            // wrapped iterator of `(neighbor, labels)` pairs, filtered by `next`
            inner: $inner,
            // the node whose neighbors `inner` iterates over
            node: NodeId,
            // filters are borrowed (typically from the graph view) for the lifetime `'a`
            node_filter: &'a NodeFilter,
            arc_filter: &'a ArcFilter,
        }

        impl<
            'a,
            Labels,
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > Iterator for $name<'a, Labels, $inner, NodeFilter, ArcFilter> {
            // Same item type as the wrapped iterator, i.e. `(NodeId, Labels)`
            type Item = $inner::Item;

            $( $next )*
        }

        // SAFETY: filtering out elements out of an iterator preserves sortedness
        // 'Labels' itself does not need to be sorted because we only implement
        // SortedIterator on the outer iterator, not in the inner one.
        unsafe impl<
            'a,
            Labels,
            $inner: SortedIterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > SortedIterator for $name<'a, Labels, $inner, NodeFilter, ArcFilter>
        {
        }

        // Exposes `flatten_labels`, which wraps this iterator in a
        // `FlattenedSuccessorsIterator`. The same flattening type is used for both
        // the successor and the predecessor variants generated by this macro.
        impl<
            'a,
            Labels: IntoIterator,
            $inner: Iterator<Item = (NodeId, Labels)> + 'a,
            NodeFilter: Fn(NodeId) -> bool,
            ArcFilter: Fn(NodeId, NodeId) -> bool,
        > IntoFlattenedLabeledArcsIterator<<Labels as IntoIterator>::Item> for $name<'a, Labels, $inner, NodeFilter, ArcFilter> {
            type Flattened = FlattenedSuccessorsIterator<Self>;

            fn flatten_labels(self) -> Self::Flattened {
                FlattenedSuccessorsIterator::new(self)
            }
        }
    }
}
139
140make_filtered_labeled_arcs_iterator! {
141    FilteredLabeledSuccessors,
142    LabeledSuccessors,
143    fn next(&mut self) -> Option<Self::Item> {
144        if !(self.node_filter)(self.node) {
145            return None;
146        }
147        for (dst, label) in self.inner.by_ref() {
148            if (self.node_filter)(dst) && (self.arc_filter)(self.node, dst) {
149                return Some((dst, label))
150            }
151        }
152        None
153    }
154}
155make_filtered_labeled_arcs_iterator! {
156    FilteredLabeledPredecessors,
157    LabeledPredecessors,
158    fn next(&mut self) -> Option<Self::Item> {
159        if !(self.node_filter)(self.node) {
160            return None;
161        }
162        for (src, label) in self.inner.by_ref() {
163            if (self.node_filter)(src) && (self.arc_filter)(src, self.node) {
164                return Some((src, label))
165            }
166        }
167        None
168    }
169}
170
171/// A view over [`SwhGraph`] and related traits, that filters out some nodes and arcs
172/// based on arbitrary closures.
173pub struct Subgraph<G: SwhGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
174{
175    pub graph: G,
176    pub node_filter: NodeFilter,
177    pub arc_filter: ArcFilter,
178    pub num_nodes_by_type: Option<HashMap<NodeType, usize>>,
179    pub num_arcs_by_type: Option<HashMap<(NodeType, NodeType), usize>>,
180}
181
182impl<G: SwhGraph, NodeFilter: Fn(usize) -> bool> Subgraph<G, NodeFilter, fn(usize, usize) -> bool> {
183    /// Create a [Subgraph] keeping only nodes matching a given node filter function.
184    ///
185    /// Shorthand for `Subgraph { graph, node_filter, arc_filter: |_src, _dst| true }`
186    pub fn with_node_filter(
187        graph: G,
188        node_filter: NodeFilter,
189    ) -> Subgraph<G, NodeFilter, fn(usize, usize) -> bool> {
190        Subgraph {
191            graph,
192            node_filter,
193            arc_filter: |_src, _dst| true,
194            num_nodes_by_type: None,
195            num_arcs_by_type: None,
196        }
197    }
198}
199
200impl<G: SwhGraph, ArcFilter: Fn(usize, usize) -> bool> Subgraph<G, fn(usize) -> bool, ArcFilter> {
201    /// Create a [Subgraph] keeping only arcs matching a arc filter function.
202    ///
203    /// Shorthand for `Subgraph { graph, node_filter: |_node| true, arc_filter }`
204    pub fn with_arc_filter(
205        graph: G,
206        arc_filter: ArcFilter,
207    ) -> Subgraph<G, fn(usize) -> bool, ArcFilter> {
208        Subgraph {
209            graph,
210            node_filter: |_node| true,
211            arc_filter,
212            num_nodes_by_type: None,
213            num_arcs_by_type: None,
214        }
215    }
216}
217
218impl<G> Subgraph<G, fn(usize) -> bool, fn(usize, usize) -> bool>
219where
220    G: SwhGraphWithProperties + Clone,
221    <G as SwhGraphWithProperties>::Maps: properties::Maps,
222{
223    /// Create a [Subgraph] keeping only nodes matching a given node constraint.
224    #[allow(clippy::type_complexity)]
225    pub fn with_node_constraint(
226        graph: G,
227        node_constraint: NodeConstraint,
228    ) -> Subgraph<G, impl Fn(NodeId) -> bool, fn(usize, usize) -> bool> {
229        Subgraph {
230            graph: graph.clone(),
231            num_nodes_by_type: graph.num_nodes_by_type().ok().map(|counts| {
232                counts
233                    .into_iter()
234                    .filter(|&(type_, _count)| node_constraint.matches(type_))
235                    .collect()
236            }),
237            num_arcs_by_type: graph.num_arcs_by_type().ok().map(|counts| {
238                counts
239                    .into_iter()
240                    .filter(|&((src_type, dst_type), _count)| {
241                        node_constraint.matches(src_type) && node_constraint.matches(dst_type)
242                    })
243                    .collect()
244            }),
245            node_filter: move |node| node_constraint.matches(graph.properties().node_type(node)),
246            arc_filter: |_src, _dst| true,
247        }
248    }
249}
250
251impl<G: SwhGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool> SwhGraph
252    for Subgraph<G, NodeFilter, ArcFilter>
253{
254    fn path(&self) -> &Path {
255        self.graph.path()
256    }
257    fn is_transposed(&self) -> bool {
258        self.graph.is_transposed()
259    }
260    // Note: this return the number or nodes in the original graph, before
261    // subgraph filtering.
262    fn num_nodes(&self) -> usize {
263        self.graph.num_nodes()
264    }
265    fn has_node(&self, node_id: NodeId) -> bool {
266        (self.node_filter)(node_id)
267    }
268    // Note: this return the number or arcs in the original graph, before
269    // subgraph filtering.
270    fn num_arcs(&self) -> u64 {
271        self.graph.num_arcs()
272    }
273    fn num_nodes_by_type(&self) -> Result<HashMap<NodeType, usize>> {
274        self.num_nodes_by_type.clone().ok_or(anyhow!(
275            "num_nodes_by_type is not supported by this Subgraph (if possible, use Subgraph::with_node_constraint to build it)"
276        ))
277    }
278    fn num_arcs_by_type(&self) -> Result<HashMap<(NodeType, NodeType), usize>> {
279        self.num_arcs_by_type.clone().ok_or(anyhow!(
280            "num_arcs_by_type is not supported by this Subgraph (if possible, use Subgraph::with_node_constraint to build it)"
281        ))
282    }
283    fn has_arc(&self, src_node_id: NodeId, dst_node_id: NodeId) -> bool {
284        (self.node_filter)(src_node_id)
285            && (self.node_filter)(dst_node_id)
286            && (self.arc_filter)(src_node_id, dst_node_id)
287            && self.graph.has_arc(src_node_id, dst_node_id)
288    }
289}
290
291impl<G: SwhForwardGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
292    SwhForwardGraph for Subgraph<G, NodeFilter, ArcFilter>
293{
294    type Successors<'succ>
295        = FilteredSuccessors<
296        'succ,
297        <<G as SwhForwardGraph>::Successors<'succ> as IntoIterator>::IntoIter,
298        NodeFilter,
299        ArcFilter,
300    >
301    where
302        Self: 'succ;
303
304    fn successors(&self, node_id: NodeId) -> Self::Successors<'_> {
305        FilteredSuccessors {
306            inner: self.graph.successors(node_id).into_iter(),
307            node: node_id,
308            node_filter: &self.node_filter,
309            arc_filter: &self.arc_filter,
310        }
311    }
312    fn outdegree(&self, node_id: NodeId) -> usize {
313        self.successors(node_id).count()
314    }
315}
316
317impl<G: SwhBackwardGraph, NodeFilter: Fn(usize) -> bool, ArcFilter: Fn(usize, usize) -> bool>
318    SwhBackwardGraph for Subgraph<G, NodeFilter, ArcFilter>
319{
320    type Predecessors<'succ>
321        = FilteredPredecessors<
322        'succ,
323        <<G as SwhBackwardGraph>::Predecessors<'succ> as IntoIterator>::IntoIter,
324        NodeFilter,
325        ArcFilter,
326    >
327    where
328        Self: 'succ;
329
330    fn predecessors(&self, node_id: NodeId) -> Self::Predecessors<'_> {
331        FilteredPredecessors {
332            inner: self.graph.predecessors(node_id).into_iter(),
333            node: node_id,
334            node_filter: &self.node_filter,
335            arc_filter: &self.arc_filter,
336        }
337    }
338    fn indegree(&self, node_id: NodeId) -> usize {
339        self.predecessors(node_id).count()
340    }
341}
342
343impl<
344        G: SwhLabeledForwardGraph,
345        NodeFilter: Fn(usize) -> bool,
346        ArcFilter: Fn(usize, usize) -> bool,
347    > SwhLabeledForwardGraph for Subgraph<G, NodeFilter, ArcFilter>
348{
349    type LabeledArcs<'arc>
350        = <G as SwhLabeledForwardGraph>::LabeledArcs<'arc>
351    where
352        Self: 'arc;
353    type LabeledSuccessors<'node>
354        = FilteredLabeledSuccessors<
355        'node,
356        Self::LabeledArcs<'node>,
357        <<G as SwhLabeledForwardGraph>::LabeledSuccessors<'node> as IntoIterator>::IntoIter,
358        NodeFilter,
359        ArcFilter,
360    >
361    where
362        Self: 'node;
363
364    fn untyped_labeled_successors(&self, node_id: NodeId) -> Self::LabeledSuccessors<'_> {
365        FilteredLabeledSuccessors {
366            inner: self.graph.untyped_labeled_successors(node_id).into_iter(),
367            node: node_id,
368            node_filter: &self.node_filter,
369            arc_filter: &self.arc_filter,
370        }
371    }
372}
373
374impl<
375        G: SwhLabeledBackwardGraph,
376        NodeFilter: Fn(usize) -> bool,
377        ArcFilter: Fn(usize, usize) -> bool,
378    > SwhLabeledBackwardGraph for Subgraph<G, NodeFilter, ArcFilter>
379{
380    type LabeledArcs<'arc>
381        = <G as SwhLabeledBackwardGraph>::LabeledArcs<'arc>
382    where
383        Self: 'arc;
384    type LabeledPredecessors<'node>
385        = FilteredLabeledPredecessors<
386        'node,
387        Self::LabeledArcs<'node>,
388        <<G as SwhLabeledBackwardGraph>::LabeledPredecessors<'node> as IntoIterator>::IntoIter,
389        NodeFilter,
390        ArcFilter,
391    >
392    where
393        Self: 'node;
394
395    fn untyped_labeled_predecessors(&self, node_id: NodeId) -> Self::LabeledPredecessors<'_> {
396        FilteredLabeledPredecessors {
397            inner: self.graph.untyped_labeled_predecessors(node_id).into_iter(),
398            node: node_id,
399            node_filter: &self.node_filter,
400            arc_filter: &self.arc_filter,
401        }
402    }
403}
404
405impl<
406        G: SwhGraphWithProperties,
407        NodeFilter: Fn(usize) -> bool,
408        ArcFilter: Fn(usize, usize) -> bool,
409    > SwhGraphWithProperties for Subgraph<G, NodeFilter, ArcFilter>
410{
411    type Maps = <G as SwhGraphWithProperties>::Maps;
412    type Timestamps = <G as SwhGraphWithProperties>::Timestamps;
413    type Persons = <G as SwhGraphWithProperties>::Persons;
414    type Contents = <G as SwhGraphWithProperties>::Contents;
415    type Strings = <G as SwhGraphWithProperties>::Strings;
416    type LabelNames = <G as SwhGraphWithProperties>::LabelNames;
417
418    fn properties(
419        &self,
420    ) -> &properties::SwhGraphProperties<
421        Self::Maps,
422        Self::Timestamps,
423        Self::Persons,
424        Self::Contents,
425        Self::Strings,
426        Self::LabelNames,
427    > {
428        self.graph.properties()
429    }
430}