swh_graph/utils/sort/
arcs.rs

1/*
2 * Copyright (C) 2023-2025  The Software Heritage developers
3 * See the AUTHORS file at the top-level directory of this distribution
4 * License: GNU General Public License version 3, or any later version
5 * See top-level LICENSE file for more information
6 */
7
//! Parallel sorting of (labeled) arcs for data that doesn't fit in RAM
9use std::cell::RefCell;
10use std::fs::File;
11use std::io::BufWriter;
12use std::path::{Path, PathBuf};
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::{Arc, Mutex};
15
16use anyhow::{Context, Result};
17use dsi_bitstream::prelude::*;
18use dsi_progress_logger::{concurrent_progress_logger, ProgressLog};
19use mmap_rs::MmapFlags;
20use rayon::prelude::*;
21use webgraph::prelude::{ArcMmapHelper, BitDeserializer, BitSerializer, MmapHelper};
22use webgraph::utils::sort_pairs::{BatchIterator, BitReader, BitWriter, KMergeIters, Triple};
23
24pub struct PartitionedBuffer<
25    L: Ord + Copy + Send + Sync,
26    S: BitSerializer<NE, BitWriter, SerType = L>,
27    D: BitDeserializer<NE, BitReader>,
28> {
29    partitions: Vec<Vec<Triple<L>>>,
30    capacity: usize,
31    sorted_iterators: Arc<Mutex<Vec<Vec<BatchIterator<D>>>>>,
32    temp_dir: PathBuf,
33    label_serializer: S,
34    label_deserializer: D,
35    // total number of items flushed this the buffer was created
36    total_flushed: Arc<AtomicUsize>,
37}
38
39impl<
40        L: Ord + Copy + Send + Sync,
41        S: BitSerializer<NE, BitWriter, SerType = L> + Copy,
42        D: BitDeserializer<NE, BitReader, DeserType = L> + Copy,
43    > PartitionedBuffer<L, S, D>
44{
45    fn new(
46        sorted_iterators: Arc<Mutex<Vec<Vec<BatchIterator<D>>>>>,
47        temp_dir: &Path,
48        batch_size: usize,
49        num_partitions: usize,
50        label_serializer: S,
51        label_deserializer: D,
52        total_flushed: Arc<AtomicUsize>,
53    ) -> Self {
54        let capacity = batch_size / num_partitions;
55        PartitionedBuffer {
56            partitions: vec![Vec::with_capacity(capacity); num_partitions],
57            sorted_iterators,
58            temp_dir: temp_dir.to_owned(),
59            capacity,
60            label_serializer,
61            label_deserializer,
62            total_flushed,
63        }
64    }
65
66    pub fn insert_labeled(
67        &mut self,
68        partition_id: usize,
69        src: usize,
70        dst: usize,
71        label: L,
72    ) -> Result<()> {
73        let partition_buffer = self
74            .partitions
75            .get_mut(partition_id)
76            .expect("Partition sorter out of bound");
77        partition_buffer.push(Triple {
78            pair: [src, dst],
79            label,
80        });
81        if partition_buffer.len() + 1 >= self.capacity {
82            self.flush(partition_id)?;
83        }
84        Ok(())
85    }
86
87    fn flush_all(&mut self) -> Result<()> {
88        for partition_id in 0..self.partitions.len() {
89            self.flush(partition_id)?;
90        }
91        Ok(())
92    }
93
94    fn flush(&mut self, partition_id: usize) -> Result<()> {
95        let partition_buffer = self
96            .partitions
97            .get_mut(partition_id)
98            .expect("Partition buffer out of bound");
99        let batch = flush(
100            &self.temp_dir,
101            &mut partition_buffer[..],
102            self.label_serializer,
103            self.label_deserializer,
104        )?;
105        self.sorted_iterators
106            .lock()
107            .unwrap()
108            .get_mut(partition_id)
109            .expect("Partition sorters out of bound")
110            .push(batch);
111        self.total_flushed
112            .fetch_add(partition_buffer.len(), Ordering::Relaxed);
113        partition_buffer.clear();
114        Ok(())
115    }
116}
117
118impl PartitionedBuffer<(), (), ()> {
119    pub fn insert(&mut self, partition_id: usize, src: usize, dst: usize) -> Result<()> {
120        self.insert_labeled(partition_id, src, dst, ())
121    }
122}
123
124/// Given an iterator and a function to insert its items to [`BatchIterator`]s, returns an
125/// iterator of pairs.
126///
127/// `f` gets as parameters `num_partitions` `BatchIterator`s; and should place its pair
128/// in such a way that all `(src, dst, labels)` in partition `n` should be lexicographically
129/// lower than those in partition `n+1`.
130///
131/// In orther words, `f` writes in arbitrary order in each partition, but partitions
132/// should be sorted with respect to each other. This allows merging partitions in
133/// parallel after they are sorted.
134pub fn par_sort_arcs<Item, Iter, F, L, S, D>(
135    temp_dir: &Path,
136    batch_size: usize,
137    iter: Iter,
138    num_partitions: usize,
139    label_serializer: S,
140    label_deserializer: D,
141    f: F,
142) -> Result<Vec<impl Iterator<Item = (usize, usize, L)> + Clone + Send + Sync>>
143where
144    F: Fn(&mut PartitionedBuffer<L, S, D>, Item) -> Result<()> + Send + Sync,
145    Iter: ParallelIterator<Item = Item>,
146    L: Ord + Copy + Send + Sync,
147    S: BitSerializer<NE, BitWriter, SerType = L> + Send + Sync + Copy,
148    D: BitDeserializer<NE, BitReader, DeserType = L> + Send + Sync + Copy,
149{
150    // For each thread, stores a vector of `num_shards` BatchIterator. The n-th BatchIterator
151    // of each thread stores arcs for nodes [n*shard_size; (n+1)*shard_size)
152    let buffers = thread_local::ThreadLocal::new();
153
154    // Read the input to buffers, and flush buffer to disk (through BatchIterator)
155    // from time to time
156    let sorted_iterators = Arc::new(Mutex::new(vec![Vec::new(); num_partitions]));
157
158    let unmerged_sorted_dir = temp_dir.join("unmerged");
159    std::fs::create_dir(&unmerged_sorted_dir)
160        .with_context(|| format!("Could not create {}", unmerged_sorted_dir.display()))?;
161
162    let num_arcs = Arc::new(AtomicUsize::new(0));
163
164    iter.try_for_each_init(
165        || -> std::cell::RefMut<PartitionedBuffer<L, S, D>> {
166            buffers
167                .get_or(|| {
168                    RefCell::new(PartitionedBuffer::new(
169                        sorted_iterators.clone(),
170                        &unmerged_sorted_dir,
171                        batch_size,
172                        num_partitions,
173                        label_serializer,
174                        label_deserializer,
175                        num_arcs.clone(),
176                    ))
177                })
178                .borrow_mut()
179        },
180        |thread_buffers, item| -> Result<()> {
181            let thread_buffers = &mut *thread_buffers;
182            f(thread_buffers, item)
183        },
184    )?;
185
186    log::info!("Flushing remaining buffers to BatchIterator...");
187
188    // Flush all buffers even if not full
189    buffers.into_iter().par_bridge().try_for_each(
190        |thread_buffer: RefCell<PartitionedBuffer<L, S, D>>| -> Result<()> {
191            thread_buffer.into_inner().flush_all()
192        },
193    )?;
194    log::info!("Done sorting all buffers.");
195
196    let sorted_iterators = Arc::into_inner(sorted_iterators)
197        .expect("Dangling references to sorted_iterators Arc")
198        .into_inner()
199        .unwrap();
200
201    let num_arcs = Arc::into_inner(num_arcs)
202        .expect("Could not take ownership of num_arcs")
203        .into_inner();
204
205    let merged_sorted_dir = temp_dir.join("merged");
206    std::fs::create_dir(&merged_sorted_dir)
207        .with_context(|| format!("Could not create {}", merged_sorted_dir.display()))?;
208
209    let mut pl = concurrent_progress_logger!(
210        display_memory = true,
211        item_name = "arc",
212        local_speed = true,
213        expected_updates = Some(num_arcs),
214    );
215    pl.start("Merging sorted arcs");
216
217    let merged_sorted_iterators = sorted_iterators
218        .into_par_iter()
219        .enumerate()
220        // Concatenate partitions
221        .map_with(
222            pl.clone(),
223            |thread_pl, (partition_id, partition_sorted_iterators)| {
224                // In the previous step, each of the N threads generated M partitions,
225                // so N×M lists. (Unless Rayon did something funny, N=M.)
226                // We now transpose, by taking for each partition what each thread produced,
227                // and merge them together, to get only M lists.
228                // This is done *in parallel*, and saves work when *sequentially* consuming
229                // the final iterator
230                let path = merged_sorted_dir.join(format!("part_{partition_id}"));
231                let num_arcs_in_partition = serialize(
232                    &path,
233                    thread_pl,
234                    label_serializer,
235                    KMergeIters::new(partition_sorted_iterators),
236                )?;
237
238                deserialize(&path, label_deserializer, num_arcs_in_partition)
239            },
240        )
241        .collect::<Result<Vec<_>>>()?;
242
243    pl.done();
244
245    log::info!("Deleted unmerged sorted files");
246    std::fs::remove_dir_all(&unmerged_sorted_dir)
247        .with_context(|| format!("Could not remove {}", unmerged_sorted_dir.display()))?;
248    log::info!("Done");
249
250    Ok(merged_sorted_iterators)
251}
252
253fn serialize<L, S>(
254    path: &Path,
255    pl: &mut impl ProgressLog,
256    label_serializer: S,
257    arcs: impl Iterator<Item = (usize, usize, L)>,
258) -> Result<usize>
259where
260    S: BitSerializer<NE, BitWriter, SerType = L> + Send + Sync + Copy,
261{
262    let file =
263        File::create_new(path).with_context(|| format!("Could not create {}", path.display()))?;
264    let mut write_stream =
265        <BufBitWriter<NE, _>>::new(<WordAdapter<usize, _>>::new(BufWriter::new(file)));
266    let mut prev_src = 0;
267    let mut prev_dst = 0;
268    let mut num_arcs_in_partition: usize = 0;
269    for (src, dst, label) in arcs {
270        write_stream
271            .write_gamma((src - prev_src).try_into().expect("usize overflowed u64"))
272            .context("Could not write src gamma")?;
273        if src != prev_src {
274            prev_dst = 0;
275        }
276        write_stream
277            .write_gamma((dst - prev_dst).try_into().expect("usize overflowed u64"))
278            .context("Could not write dst gamma")?;
279        label_serializer
280            .serialize(&label, &mut write_stream)
281            .context("Could not serialize label")?;
282        prev_src = src;
283        prev_dst = dst;
284        pl.light_update();
285        num_arcs_in_partition += 1;
286    }
287    write_stream.flush().context("Could not flush stream")?;
288    Ok(num_arcs_in_partition)
289}
290
291fn deserialize<L, D>(
292    path: &Path,
293    label_deserializer: D,
294    num_arcs: usize,
295) -> Result<impl Iterator<Item = (usize, usize, L)> + Clone + Send + Sync>
296where
297    D: BitDeserializer<NE, BitReader, DeserType = L> + Send + Sync + Copy,
298{
299    let mut read_stream = <BufBitReader<NE, _>>::new(MemWordReader::new(ArcMmapHelper(Arc::new(
300        MmapHelper::mmap(
301            path,
302            MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::SEQUENTIAL,
303        )
304        .with_context(|| format!("Could not mmap {}", path.display()))?,
305    ))));
306
307    let mut prev_src = 0;
308    let mut prev_dst = 0;
309    let arcs = (0..num_arcs).map(move |_| {
310        let src = prev_src + read_stream.read_gamma().expect("Could not read src gamma");
311        if src != prev_src {
312            prev_dst = 0;
313        }
314        let dst = prev_dst + read_stream.read_gamma().expect("Could not read dst gamma");
315        let label = label_deserializer
316            .deserialize(&mut read_stream)
317            .expect("Could not deserialize label");
318        prev_src = src;
319        prev_dst = dst;
320        let src = usize::try_from(src).expect("deserialized usize overflows usize");
321        let dst = usize::try_from(dst).expect("deserialized usize overflows usize");
322        (src, dst, label)
323    });
324    Ok(arcs)
325}
326
327fn flush<
328    L: Ord + Copy + Send + Sync,
329    S: BitSerializer<NE, BitWriter, SerType = L>,
330    D: BitDeserializer<NE, BitReader, DeserType = L>,
331>(
332    temp_dir: &Path,
333    buffer: &mut [Triple<L>],
334    label_serializer: S,
335    label_deserializer: D,
336) -> Result<BatchIterator<D>> {
337    use rand::Rng;
338    let sorter_id = rand::thread_rng().r#gen::<u64>();
339    let mut sorter_temp_file = temp_dir.to_owned();
340    sorter_temp_file.push(format!("sort-arcs-permute-{sorter_id:#x}"));
341
342    // This is equivalent to BatchIterator::new_from_vec(&sorter_temp_file, buffer),
343    // but without parallelism, which would cause Rayon to re-enter
344    // par_sort_arcs and cause deadlocks: https://github.com/rayon-rs/rayon/issues/1083
345    buffer.sort_unstable_by_key(
346        |Triple {
347             pair: [src, dst],
348             label: _,
349         }| (*src, *dst), // not sorting by label, KMergeIters loses the order anyway
350    );
351    BatchIterator::new_from_vec_sorted_labeled(
352        &sorter_temp_file,
353        buffer,
354        &label_serializer,
355        label_deserializer,
356    )
357    .with_context(|| {
358        format!(
359            "Could not create BatchIterator in {}",
360            sorter_temp_file.display()
361        )
362    })
363}