swh_graph/compress/
transform.rs

1// Copyright (C) 2023  The Software Heritage developers
2// See the AUTHORS file at the top-level directory of this distribution
3// License: GNU General Public License version 3, or any later version
4// See top-level LICENSE file for more information
5
6use std::path::PathBuf;
7use std::sync::Mutex;
8
9use anyhow::{Context, Result};
10use dsi_bitstream::prelude::BE;
11use dsi_progress_logger::{progress_logger, ProgressLog};
12use lender::Lender;
13use rayon::prelude::*;
14use webgraph::graphs::arc_list_graph::ArcListGraph;
15use webgraph::prelude::*;
16
17use crate::utils::sort::par_sort_arcs;
18
19/// Writes a new graph on disk, obtained by applying the function to all arcs
20/// on the source graph.
21pub fn transform<F, G, Iter>(
22    input_batch_size: usize,
23    sort_batch_size: usize,
24    partitions_per_thread: usize,
25    graph: G,
26    transformation: F,
27    target_path: PathBuf,
28) -> Result<()>
29where
30    F: Fn(usize, usize) -> Iter + Send + Sync,
31    Iter: IntoIterator<Item = (usize, usize)>,
32    G: RandomAccessGraph + Sync,
33{
34    // Adapted from https://github.com/vigna/webgraph-rs/blob/08969fb1ac4ea59aafdbae976af8e026a99c9ac5/src/bin/perm.rs
35    let num_nodes = graph.num_nodes();
36
37    let num_batches = num_nodes.div_ceil(input_batch_size);
38
39    let temp_dir = tempfile::tempdir().context("Could not get temporary_directory")?;
40
41    let num_threads = num_cpus::get();
42    let num_partitions = num_threads * partitions_per_thread;
43    let nodes_per_partition = num_nodes.div_ceil(num_partitions);
44
45    // Avoid empty partitions at the end when there are very few nodes
46    let num_partitions = num_nodes.div_ceil(nodes_per_partition);
47
48    log::info!(
49        "Transforming {} nodes with {} threads, {} partitions, {} nodes per partition, {} batches of size {}",
50        num_nodes,
51        num_threads,
52        num_partitions,
53        nodes_per_partition,
54        num_batches,
55        input_batch_size
56    );
57
58    let mut pl = progress_logger!(
59        display_memory = true,
60        item_name = "node",
61        expected_updates = Some(num_nodes),
62        local_speed = true,
63    );
64    pl.start("Reading and sorting...");
65    let pl = Mutex::new(pl);
66
67    // Merge sorted arc lists into a single sorted arc list
68    let sorted_arcs = par_sort_arcs(
69        temp_dir.path(),
70        sort_batch_size,
71        (0..num_batches).into_par_iter(),
72        num_partitions,
73        (),
74        (),
75        |buffer, batch_id| -> Result<()> {
76            let start = batch_id * input_batch_size;
77            let end = (batch_id + 1) * input_batch_size;
78            graph // Not using PermutedGraph in order to avoid blanket iter_nodes_from
79                .iter_from(start)
80                .take_while(|(node_id, _successors)| *node_id < end)
81                .try_for_each(|(x, succ)| -> Result<()> {
82                    succ.into_iter().try_for_each(|s| -> Result<()> {
83                        for (x, s) in transformation(x, s).into_iter() {
84                            let partition_id = x / nodes_per_partition;
85                            buffer.insert(partition_id, x, s)?;
86                        }
87                        Ok(())
88                    })
89                })?;
90            pl.lock().unwrap().update_with_count(end - start);
91            Ok(())
92        },
93    )
94    .context("Could not sort arcs")?;
95    pl.lock().unwrap().done();
96
97    let arc_list_graphs =
98        sorted_arcs
99            .into_iter()
100            .enumerate()
101            .map(|(partition_id, sorted_arcs_partition)| {
102                webgraph::prelude::Left(ArcListGraph::new_labeled(num_nodes, sorted_arcs_partition))
103                    .iter_from(partition_id * nodes_per_partition)
104                    .take(nodes_per_partition)
105            });
106
107    let compression_flags = CompFlags {
108        compression_window: 1,
109        min_interval_length: 4,
110        max_ref_count: 3,
111        ..CompFlags::default()
112    };
113
114    let temp_bv_dir = temp_dir.path().join("transform-bv");
115    std::fs::create_dir(&temp_bv_dir)
116        .with_context(|| format!("Could not create {}", temp_bv_dir.display()))?;
117    BvComp::parallel_iter::<BE, _>(
118        target_path,
119        arc_list_graphs,
120        num_nodes,
121        compression_flags,
122        &rayon::ThreadPoolBuilder::default()
123            .build()
124            .expect("Could not create BvComp thread pool"),
125        &temp_bv_dir,
126    )
127    .context("Could not build BVGraph from arcs")?;
128
129    drop(temp_dir); // Prevent early deletion
130
131    Ok(())
132}