swh_graph/compress/
transform.rs

1// Copyright (C) 2023-2025  The Software Heritage developers
2// See the AUTHORS file at the top-level directory of this distribution
3// License: GNU General Public License version 3, or any later version
4// See top-level LICENSE file for more information
5
6use std::num::NonZeroUsize;
7use std::path::PathBuf;
8
9use anyhow::{Context, Result};
10use dsi_bitstream::prelude::BE;
11use dsi_progress_logger::{concurrent_progress_logger, ProgressLog};
12use itertools::Itertools;
13use lender::{IntoIteratorExt, IntoLender, Lender};
14use rayon::prelude::*;
15use webgraph::graphs::arc_list_graph::ArcListGraph;
16use webgraph::prelude::*;
17use webgraph::utils::ParSortPairs;
18
19/// Writes a new graph on disk, obtained by applying the function to all arcs
20/// on the source graph.
21pub fn transform<F, G, Iter>(
22    input_batch_size: usize,
23    partitions_per_thread: usize,
24    graph: G,
25    transformation: F,
26    target_path: PathBuf,
27) -> Result<()>
28where
29    F: Fn(usize, usize) -> Iter + Send + Sync,
30    Iter: IntoIterator<Item = (usize, usize), IntoIter: Send + Sync>,
31    G: SplitLabeling<Label=usize>,
32    for<'a> <<G as SplitLabeling>::IntoIterator<'a> as IntoIterator>::IntoIter: Send + Sync,
33    for<'a, 'b> <<<G as SplitLabeling>::SplitLender<'a> as NodeLabelsLender<'b>>::IntoIterator as IntoIterator>::IntoIter: Send + Sync,
34{
35    // Adapted from https://github.com/vigna/webgraph-rs/blob/08969fb1ac4ea59aafdbae976af8e026a99c9ac5/src/bin/perm.rs
36    let num_nodes = graph.num_nodes();
37
38    let num_batches = num_nodes.div_ceil(input_batch_size);
39
40    let temp_dir = tempfile::tempdir().context("Could not get temporary_directory")?;
41
42    let num_threads = num_cpus::get();
43    let num_partitions = num_threads * partitions_per_thread;
44    let nodes_per_partition = num_nodes.div_ceil(num_partitions);
45
46    // Avoid empty partitions at the end when there are very few nodes
47    let num_partitions = num_nodes.div_ceil(nodes_per_partition);
48
49    log::info!(
50        "Transforming {} nodes with {} threads, {} partitions, {} nodes per partition, {} batches of size {}",
51        num_nodes,
52        num_threads,
53        num_partitions,
54        nodes_per_partition,
55        num_batches,
56        input_batch_size
57    );
58
59    let mut pl = concurrent_progress_logger!(
60        display_memory = true,
61        item_name = "node",
62        expected_updates = Some(num_nodes),
63        local_speed = true,
64    );
65    pl.start("Reading and sorting...");
66
67    // Merge sorted arc lists into a single sorted arc list
68    let pair_sorter =
69        ParSortPairs::new(num_nodes)?.num_partitions(NonZeroUsize::new(num_partitions).unwrap());
70    let transformation = &transformation;
71    let sorted_arcs = {
72        let pl = pl.clone();
73        pair_sorter
74            .sort(
75                graph
76                    .split_iter(num_partitions)
77                    .into_iter()
78                    .collect::<Vec<_>>()
79                    .into_par_iter()
80                    .flat_map_iter(move |partition| {
81                        let mut pl = pl.clone();
82                        partition
83                            .flat_map(move |(src, succ)| {
84                                let transformed_succ: Vec<_> = succ
85                                    .into_iter()
86                                    .flat_map(move |dst| {
87                                        let res: Vec<_> =
88                                            transformation(src, dst).into_iter().collect();
89                                        println!("{src}->{dst}   ->    {res:?}");
90                                        res.into_iter()
91                                    })
92                                    .collect();
93                                pl.light_update();
94                                transformed_succ.into_into_lender().into_lender()
95                            })
96                            .iter()
97                    }),
98            )
99            .context("Could not sort arcs")?
100    };
101    pl.done();
102
103    let arc_list_graphs = Vec::from(sorted_arcs.iters).into_iter().enumerate().map(
104        |(partition_id, sorted_arcs_partition)| {
105            ArcListGraph::new(num_nodes, sorted_arcs_partition.into_iter().dedup())
106                .iter_from(sorted_arcs.boundaries[partition_id])
107                .take(
108                    sorted_arcs.boundaries[partition_id + 1]
109                        .checked_sub(sorted_arcs.boundaries[partition_id])
110                        .expect("sorted_arcs.boundaries is not sorted"),
111                )
112        },
113    );
114
115    let compression_flags = CompFlags {
116        compression_window: 1,
117        min_interval_length: 4,
118        max_ref_count: 3,
119        ..CompFlags::default()
120    };
121
122    let temp_bv_dir = temp_dir.path().join("transform-bv");
123    std::fs::create_dir(&temp_bv_dir)
124        .with_context(|| format!("Could not create {}", temp_bv_dir.display()))?;
125    BvComp::parallel_iter::<BE, _>(
126        target_path,
127        arc_list_graphs.into_iter(),
128        num_nodes,
129        compression_flags,
130        &rayon::ThreadPoolBuilder::default()
131            .build()
132            .expect("Could not create BvComp thread pool"),
133        &temp_bv_dir,
134    )
135    .context("Could not build BVGraph from arcs")?;
136
137    drop(temp_dir); // Prevent early deletion
138
139    Ok(())
140}