swh_graph/utils/sort/
strings.rs

1/*
2 * Copyright (C) 2024  The Software Heritage developers
3 * See the AUTHORS file at the top-level directory of this distribution
4 * License: GNU General Public License version 3, or any later version
5 * See top-level LICENSE file for more information
6 */
7
//! Parallel sorting and deduplication for lists of byte strings that don't fit in RAM
9// Adapted from https://archive.softwareheritage.org/swh:1:cnt:d5129fef934309da995a8895ba9509a6faae0bba;origin=https://github.com/vigna/webgraph-rs;visit=swh:1:snp:76b76a6b68240ad1ec27aed81f7cc30441b69d7c;anchor=swh:1:rel:ef30092122d472899fdfa361e784fc1e04495dab;path=/src/utils/sort_pairs.rs;lines=410-512
10
11use std::io::{Read, Write};
12use std::path::PathBuf;
13
14use anyhow::{ensure, Context, Result};
15use dsi_progress_logger::ProgressLog;
16use rayon::prelude::*;
17
18use super::ParallelDeduplicatingExternalSorter;
19
/// Assumed average length of a string, in bytes.
///
/// The value is arbitrary; it is only used to turn a buffer size into a
/// number of strings when picking a buffer capacity.
const AVERAGE_STRING_LENGTH: usize = 64;

/// An owned, fixed-size sequence of bytes.
type Bytestring = Box<[u8]>;

/// Sorter for [`Bytestring`]s, configured by the size of its buffer.
#[derive(Clone, Copy)]
struct BytestringExternalSorter {
    /// Buffer size; it is divided by [`AVERAGE_STRING_LENGTH`] to obtain a
    /// number of strings.
    buffer_size: usize,
}
29
30impl ParallelDeduplicatingExternalSorter<Bytestring> for BytestringExternalSorter {
31    fn buffer_capacity(&self) -> usize {
32        self.buffer_size
33            .div_ceil(AVERAGE_STRING_LENGTH)
34            .next_power_of_two()
35    }
36
37    #[allow(clippy::get_first)]
38    fn sort_vec(&self, vec: &mut Vec<Bytestring>) -> Result<()> {
39        // Perform a one-level radix sort before handing off to a generic sort.
40
41        // Note: bucket distribution is uniform when called from ExtractPersons because
42        // we manipulate sha256 digests; but is very heterogeneous when called from
43        // ExtractLabels because labels are mostly ASCII text.
44
45        let mut partitions: Vec<_> = (0..65536)
46            .map(|_| Vec::with_capacity(vec.len().div_ceil(65536)))
47            .collect();
48
49        // Split into partitions
50        for string in vec.drain(0..) {
51            let partition_id = ((string.get(0).copied().unwrap_or(0u8) as usize) << 8)
52                | string.get(1).copied().unwrap_or(0u8) as usize;
53            partitions[partition_id].push(string);
54        }
55
56        // Sort each partition. We use a single-threaded sort for each partition
57        // because we are already within a thread, and it would needlessly add churn to
58        // Rayon's scheduler.
59        partitions
60            .par_iter_mut()
61            .for_each(|partition| partition.sort_unstable());
62
63        for partition in partitions {
64            vec.extend(partition);
65        }
66        Ok(())
67    }
68
69    fn serialize(path: PathBuf, strings: impl Iterator<Item = Bytestring>) -> Result<()> {
70        let file = std::fs::File::create_new(&path)
71            .with_context(|| format!("Could not create {}", path.display()))?;
72        let compression_level = 3;
73        let mut encoder = zstd::stream::write::Encoder::new(file, compression_level)
74            .with_context(|| format!("Could not create ZSTD encoder for {}", path.display()))?;
75        for string in strings {
76            let len: u32 = string
77                .len()
78                .try_into()
79                .context("String is 2^32 bytes or longer")?;
80            ensure!(len != u32::MAX, "String is 2^32 -1 bytes long");
81            encoder
82                .write_all(&len.to_ne_bytes())
83                .with_context(|| format!("Could not write string to {}", path.display()))?;
84            encoder
85                .write_all(&string)
86                .with_context(|| format!("Could not write string to {}", path.display()))?;
87        }
88        // mark end of file
89        encoder
90            .write_all(&u32::MAX.to_ne_bytes())
91            .with_context(|| format!("Could not write string to {}", path.display()))?;
92
93        encoder
94            .finish()
95            .with_context(|| format!("Could not flush to {}", path.display()))?;
96        Ok(())
97    }
98
99    fn deserialize(path: PathBuf) -> Result<impl Iterator<Item = Bytestring>> {
100        let file = std::fs::File::open(&path)
101            .with_context(|| format!("Could not open {}", path.display()))?;
102        let mut decoder =
103            zstd::stream::read::Decoder::new(file).context("Could not decompress sorted file")?;
104        Ok(std::iter::repeat(()).map_while(move |()| {
105            let mut buf = [0u8; 4];
106            decoder
107                .read_exact(&mut buf)
108                .expect("Could not read string size");
109            let size = u32::from_ne_bytes(buf);
110            if size == u32::MAX {
111                // end of file marker
112                return None;
113            }
114            let mut line = vec![0; size.try_into().unwrap()].into_boxed_slice();
115            decoder
116                .read_exact(&mut line)
117                .expect("Could not read string");
118            Some(line)
119        }))
120    }
121}
122
123pub fn par_sort_strings<Iter: ParallelIterator<Item = Bytestring>>(
124    iter: Iter,
125    pl: impl ProgressLog + Send,
126    buffer_size: usize,
127) -> Result<impl Iterator<Item = Bytestring>> {
128    BytestringExternalSorter { buffer_size }.par_sort_dedup(iter, pl)
129}