1use std::cell::RefCell;
10use std::fs::File;
11use std::io::BufWriter;
12use std::path::{Path, PathBuf};
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::{Arc, Mutex};
15
16use anyhow::{Context, Result};
17use dsi_bitstream::prelude::*;
18use dsi_progress_logger::{concurrent_progress_logger, ProgressLog};
19use mmap_rs::MmapFlags;
20use rayon::prelude::*;
21use webgraph::prelude::{ArcMmapHelper, BitDeserializer, BitSerializer, MmapHelper};
22use webgraph::utils::sort_pairs::{BatchIterator, BitReader, BitWriter, KMergeIters, Triple};
23
24pub struct PartitionedBuffer<
25 L: Ord + Copy + Send + Sync,
26 S: BitSerializer<NE, BitWriter, SerType = L>,
27 D: BitDeserializer<NE, BitReader>,
28> {
29 partitions: Vec<Vec<Triple<L>>>,
30 capacity: usize,
31 sorted_iterators: Arc<Mutex<Vec<Vec<BatchIterator<D>>>>>,
32 temp_dir: PathBuf,
33 label_serializer: S,
34 label_deserializer: D,
35 total_flushed: Arc<AtomicUsize>,
37}
38
39impl<
40 L: Ord + Copy + Send + Sync,
41 S: BitSerializer<NE, BitWriter, SerType = L> + Copy,
42 D: BitDeserializer<NE, BitReader, DeserType = L> + Copy,
43 > PartitionedBuffer<L, S, D>
44{
45 fn new(
46 sorted_iterators: Arc<Mutex<Vec<Vec<BatchIterator<D>>>>>,
47 temp_dir: &Path,
48 batch_size: usize,
49 num_partitions: usize,
50 label_serializer: S,
51 label_deserializer: D,
52 total_flushed: Arc<AtomicUsize>,
53 ) -> Self {
54 let capacity = batch_size / num_partitions;
55 PartitionedBuffer {
56 partitions: vec![Vec::with_capacity(capacity); num_partitions],
57 sorted_iterators,
58 temp_dir: temp_dir.to_owned(),
59 capacity,
60 label_serializer,
61 label_deserializer,
62 total_flushed,
63 }
64 }
65
66 pub fn insert_labeled(
67 &mut self,
68 partition_id: usize,
69 src: usize,
70 dst: usize,
71 label: L,
72 ) -> Result<()> {
73 let partition_buffer = self
74 .partitions
75 .get_mut(partition_id)
76 .expect("Partition sorter out of bound");
77 partition_buffer.push(Triple {
78 pair: [src, dst],
79 label,
80 });
81 if partition_buffer.len() + 1 >= self.capacity {
82 self.flush(partition_id)?;
83 }
84 Ok(())
85 }
86
87 fn flush_all(&mut self) -> Result<()> {
88 for partition_id in 0..self.partitions.len() {
89 self.flush(partition_id)?;
90 }
91 Ok(())
92 }
93
94 fn flush(&mut self, partition_id: usize) -> Result<()> {
95 let partition_buffer = self
96 .partitions
97 .get_mut(partition_id)
98 .expect("Partition buffer out of bound");
99 let batch = flush(
100 &self.temp_dir,
101 &mut partition_buffer[..],
102 self.label_serializer,
103 self.label_deserializer,
104 )?;
105 self.sorted_iterators
106 .lock()
107 .unwrap()
108 .get_mut(partition_id)
109 .expect("Partition sorters out of bound")
110 .push(batch);
111 self.total_flushed
112 .fetch_add(partition_buffer.len(), Ordering::Relaxed);
113 partition_buffer.clear();
114 Ok(())
115 }
116}
117
118impl PartitionedBuffer<(), (), ()> {
119 pub fn insert(&mut self, partition_id: usize, src: usize, dst: usize) -> Result<()> {
120 self.insert_labeled(partition_id, src, dst, ())
121 }
122}
123
124pub fn par_sort_arcs<Item, Iter, F, L, S, D>(
135 temp_dir: &Path,
136 batch_size: usize,
137 iter: Iter,
138 num_partitions: usize,
139 label_serializer: S,
140 label_deserializer: D,
141 f: F,
142) -> Result<Vec<impl Iterator<Item = (usize, usize, L)> + Clone + Send + Sync>>
143where
144 F: Fn(&mut PartitionedBuffer<L, S, D>, Item) -> Result<()> + Send + Sync,
145 Iter: ParallelIterator<Item = Item>,
146 L: Ord + Copy + Send + Sync,
147 S: BitSerializer<NE, BitWriter, SerType = L> + Send + Sync + Copy,
148 D: BitDeserializer<NE, BitReader, DeserType = L> + Send + Sync + Copy,
149{
150 let buffers = thread_local::ThreadLocal::new();
153
154 let sorted_iterators = Arc::new(Mutex::new(vec![Vec::new(); num_partitions]));
157
158 let unmerged_sorted_dir = temp_dir.join("unmerged");
159 std::fs::create_dir(&unmerged_sorted_dir)
160 .with_context(|| format!("Could not create {}", unmerged_sorted_dir.display()))?;
161
162 let num_arcs = Arc::new(AtomicUsize::new(0));
163
164 iter.try_for_each_init(
165 || -> std::cell::RefMut<PartitionedBuffer<L, S, D>> {
166 buffers
167 .get_or(|| {
168 RefCell::new(PartitionedBuffer::new(
169 sorted_iterators.clone(),
170 &unmerged_sorted_dir,
171 batch_size,
172 num_partitions,
173 label_serializer,
174 label_deserializer,
175 num_arcs.clone(),
176 ))
177 })
178 .borrow_mut()
179 },
180 |thread_buffers, item| -> Result<()> {
181 let thread_buffers = &mut *thread_buffers;
182 f(thread_buffers, item)
183 },
184 )?;
185
186 log::info!("Flushing remaining buffers to BatchIterator...");
187
188 buffers.into_iter().par_bridge().try_for_each(
190 |thread_buffer: RefCell<PartitionedBuffer<L, S, D>>| -> Result<()> {
191 thread_buffer.into_inner().flush_all()
192 },
193 )?;
194 log::info!("Done sorting all buffers.");
195
196 let sorted_iterators = Arc::into_inner(sorted_iterators)
197 .expect("Dangling references to sorted_iterators Arc")
198 .into_inner()
199 .unwrap();
200
201 let num_arcs = Arc::into_inner(num_arcs)
202 .expect("Could not take ownership of num_arcs")
203 .into_inner();
204
205 let merged_sorted_dir = temp_dir.join("merged");
206 std::fs::create_dir(&merged_sorted_dir)
207 .with_context(|| format!("Could not create {}", merged_sorted_dir.display()))?;
208
209 let mut pl = concurrent_progress_logger!(
210 display_memory = true,
211 item_name = "arc",
212 local_speed = true,
213 expected_updates = Some(num_arcs),
214 );
215 pl.start("Merging sorted arcs");
216
217 let merged_sorted_iterators = sorted_iterators
218 .into_par_iter()
219 .enumerate()
220 .map_with(
222 pl.clone(),
223 |thread_pl, (partition_id, partition_sorted_iterators)| {
224 let path = merged_sorted_dir.join(format!("part_{partition_id}"));
231 let num_arcs_in_partition = serialize(
232 &path,
233 thread_pl,
234 label_serializer,
235 KMergeIters::new(partition_sorted_iterators),
236 )?;
237
238 deserialize(&path, label_deserializer, num_arcs_in_partition)
239 },
240 )
241 .collect::<Result<Vec<_>>>()?;
242
243 pl.done();
244
245 log::info!("Deleted unmerged sorted files");
246 std::fs::remove_dir_all(&unmerged_sorted_dir)
247 .with_context(|| format!("Could not remove {}", unmerged_sorted_dir.display()))?;
248 log::info!("Done");
249
250 Ok(merged_sorted_iterators)
251}
252
253fn serialize<L, S>(
254 path: &Path,
255 pl: &mut impl ProgressLog,
256 label_serializer: S,
257 arcs: impl Iterator<Item = (usize, usize, L)>,
258) -> Result<usize>
259where
260 S: BitSerializer<NE, BitWriter, SerType = L> + Send + Sync + Copy,
261{
262 let file =
263 File::create_new(path).with_context(|| format!("Could not create {}", path.display()))?;
264 let mut write_stream =
265 <BufBitWriter<NE, _>>::new(<WordAdapter<usize, _>>::new(BufWriter::new(file)));
266 let mut prev_src = 0;
267 let mut prev_dst = 0;
268 let mut num_arcs_in_partition: usize = 0;
269 for (src, dst, label) in arcs {
270 write_stream
271 .write_gamma((src - prev_src).try_into().expect("usize overflowed u64"))
272 .context("Could not write src gamma")?;
273 if src != prev_src {
274 prev_dst = 0;
275 }
276 write_stream
277 .write_gamma((dst - prev_dst).try_into().expect("usize overflowed u64"))
278 .context("Could not write dst gamma")?;
279 label_serializer
280 .serialize(&label, &mut write_stream)
281 .context("Could not serialize label")?;
282 prev_src = src;
283 prev_dst = dst;
284 pl.light_update();
285 num_arcs_in_partition += 1;
286 }
287 write_stream.flush().context("Could not flush stream")?;
288 Ok(num_arcs_in_partition)
289}
290
291fn deserialize<L, D>(
292 path: &Path,
293 label_deserializer: D,
294 num_arcs: usize,
295) -> Result<impl Iterator<Item = (usize, usize, L)> + Clone + Send + Sync>
296where
297 D: BitDeserializer<NE, BitReader, DeserType = L> + Send + Sync + Copy,
298{
299 let mut read_stream = <BufBitReader<NE, _>>::new(MemWordReader::new(ArcMmapHelper(Arc::new(
300 MmapHelper::mmap(
301 path,
302 MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::SEQUENTIAL,
303 )
304 .with_context(|| format!("Could not mmap {}", path.display()))?,
305 ))));
306
307 let mut prev_src = 0;
308 let mut prev_dst = 0;
309 let arcs = (0..num_arcs).map(move |_| {
310 let src = prev_src + read_stream.read_gamma().expect("Could not read src gamma");
311 if src != prev_src {
312 prev_dst = 0;
313 }
314 let dst = prev_dst + read_stream.read_gamma().expect("Could not read dst gamma");
315 let label = label_deserializer
316 .deserialize(&mut read_stream)
317 .expect("Could not deserialize label");
318 prev_src = src;
319 prev_dst = dst;
320 let src = usize::try_from(src).expect("deserialized usize overflows usize");
321 let dst = usize::try_from(dst).expect("deserialized usize overflows usize");
322 (src, dst, label)
323 });
324 Ok(arcs)
325}
326
327fn flush<
328 L: Ord + Copy + Send + Sync,
329 S: BitSerializer<NE, BitWriter, SerType = L>,
330 D: BitDeserializer<NE, BitReader, DeserType = L>,
331>(
332 temp_dir: &Path,
333 buffer: &mut [Triple<L>],
334 label_serializer: S,
335 label_deserializer: D,
336) -> Result<BatchIterator<D>> {
337 use rand::Rng;
338 let sorter_id = rand::thread_rng().r#gen::<u64>();
339 let mut sorter_temp_file = temp_dir.to_owned();
340 sorter_temp_file.push(format!("sort-arcs-permute-{sorter_id:#x}"));
341
342 buffer.sort_unstable_by_key(
346 |Triple {
347 pair: [src, dst],
348 label: _,
349 }| (*src, *dst), );
351 BatchIterator::new_from_vec_sorted_labeled(
352 &sorter_temp_file,
353 buffer,
354 &label_serializer,
355 label_deserializer,
356 )
357 .with_context(|| {
358 format!(
359 "Could not create BatchIterator in {}",
360 sorter_temp_file.display()
361 )
362 })
363}