webgraph_cli/analyze/
codes.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Inria
3 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use crate::{GlobalArgs, GranularityArgs, NumThreadsArg};
9use anyhow::Result;
10use clap::Parser;
11use dsi_bitstream::{dispatch::factory::CodesReaderFactoryHelper, prelude::*};
12use dsi_progress_logger::prelude::*;
13use std::path::PathBuf;
14use webgraph::prelude::*;
15
16#[derive(Parser, Debug)]
17#[command(name = "codes", about = "Reads a graph and suggests the best codes to use.", long_about = None)]
18pub struct CliArgs {
19    /// The basename of the graph.
20    pub src: PathBuf,
21
22    #[clap(flatten)]
23    pub num_threads: NumThreadsArg,
24
25    #[clap(flatten)]
26    pub granularity: GranularityArgs,
27}
28
29pub fn main(global_args: GlobalArgs, args: CliArgs) -> Result<()> {
30    match get_endianness(&args.src)?.as_str() {
31        #[cfg(feature = "be_bins")]
32        BE::NAME => optimize_codes::<BE>(global_args, args),
33        #[cfg(feature = "le_bins")]
34        LE::NAME => optimize_codes::<LE>(global_args, args),
35        e => panic!("Unknown endianness: {}", e),
36    }
37}
38
/// An iterator over consecutive sub-ranges of a range of nodes, each of
/// `chunk_size` elements except possibly the last one, which may be smaller.
///
/// This is the equivalent of [`slice::chunks`] for a
/// [`Range`](core::ops::Range) instead of a slice; it is used to split the
/// nodes of a graph into batches to be processed in parallel.
#[derive(Debug, Clone)]
pub struct Chunks {
    /// The part of the range that has not been yielded yet.
    total: core::ops::Range<usize>,
    /// The size of each chunk; always at least one (see [`Chunks::new`]).
    chunk_size: usize,
}

impl Chunks {
    /// Creates an iterator over the chunks of `total`, each of `chunk_size`
    /// elements except possibly the last one.
    ///
    /// A `chunk_size` of zero would never advance through the range, so it
    /// is treated as one to guarantee termination.
    pub fn new(total: core::ops::Range<usize>, chunk_size: usize) -> Self {
        Self {
            total,
            chunk_size: chunk_size.max(1),
        }
    }
}

impl Iterator for Chunks {
    type Item = core::ops::Range<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.total.is_empty() {
            return None;
        }
        let start = self.total.start;
        // Saturate so that a huge `chunk_size` (e.g., `usize::MAX`) cannot
        // overflow; the result is clamped to the end of the range anyway.
        let end = start.saturating_add(self.chunk_size).min(self.total.end);
        self.total.start = end;
        Some(start..end)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Number of remaining chunks, i.e., ceil(len / chunk_size), computed
        // without the overflow-prone `len + chunk_size - 1`.
        let len = self.total.len();
        let remaining = len / self.chunk_size + usize::from(len % self.chunk_size != 0);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Chunks {}

// Once the range is exhausted, `next` keeps returning `None`.
impl core::iter::FusedIterator for Chunks {}
67
68pub fn optimize_codes<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result<()>
69where
70    MmapHelper<u32>: CodesReaderFactoryHelper<E>,
71    for<'a> LoadModeCodesReader<'a, E, Mmap>: BitSeek,
72{
73    let mut stats = Default::default();
74    let has_ef = std::fs::metadata(args.src.with_extension("ef")).is_ok_and(|x| x.is_file());
75
76    if has_ef {
77        log::info!(
78            "Analyzing codes in parallel using {} threads",
79            args.num_threads.num_threads
80        );
81        let graph = BvGraph::with_basename(&args.src).endianness::<E>().load()?;
82
83        let mut pl = concurrent_progress_logger![item_name = "node"];
84        pl.display_memory(true)
85            .expected_updates(Some(graph.num_nodes()));
86        pl.start("Scanning...");
87
88        if let Some(duration) = global_args.log_interval {
89            pl.log_interval(duration);
90        }
91
92        let thread_pool = rayon::ThreadPoolBuilder::new()
93            .num_threads(args.num_threads.num_threads)
94            .build()?;
95
96        let node_granularity = args
97            .granularity
98            .into_granularity()
99            .node_granularity(graph.num_nodes(), Some(graph.num_arcs()));
100
101        // TODO!: use FairChunks with the offsets EF to distribute the
102        // work based on number of bits used, not nodes
103        stats = Chunks::new(0..graph.num_nodes(), node_granularity).par_map_fold_with(
104            pl.clone(),
105            |pl, range| {
106                let mut iter = graph
107                    .offset_deg_iter_from(range.start)
108                    .map_decoder(|d| StatsDecoder::new(d, Default::default()));
109
110                for _ in (&mut iter).take(range.len()) {
111                    pl.light_update();
112                }
113
114                let mut stats = Default::default();
115                iter.map_decoder(|d| {
116                    stats = d.stats;
117                    d.codes_reader // not important but we need to return something
118                });
119                stats
120            },
121            |mut acc1, acc2| {
122                acc1 += &acc2;
123                acc1
124            },
125            &thread_pool,
126        );
127
128        pl.done();
129    } else {
130        if args.num_threads.num_threads != 1 {
131            log::info!(
132                "Analyzing codes sequentially, this might be faster if you build the Elias-Fano index using `webgraph build ef {}` which will generate file {}",
133                args.src.display(),
134                args.src.with_extension("ef").display()
135            );
136        }
137
138        let graph = BvGraphSeq::with_basename(args.src)
139            .endianness::<E>()
140            .load()?;
141
142        let mut pl = ProgressLogger::default();
143        pl.display_memory(true)
144            .item_name("node")
145            .expected_updates(Some(graph.num_nodes()));
146
147        pl.start("Scanning...");
148
149        // add the stats wrapper to the decoder
150        let mut iter = graph
151            .offset_deg_iter()
152            .map_decoder(|d| StatsDecoder::new(d, Default::default()));
153        // iterate over the graph
154        for _ in iter.by_ref() {
155            pl.light_update();
156        }
157        pl.done();
158        // extract the stats
159        iter.map_decoder(|d| {
160            stats = d.stats;
161            d.codes_reader // not important but we need to return something
162        });
163    }
164
165    macro_rules! impl_best_code {
166        ($new_bits:expr, $old_bits:expr, $stats:expr, $($code:ident - $old:expr),*) => {
167            println!("{:>17} {:>16} {:>12} {:>8} {:>10} {:>16}",
168                "Type", "Code", "Improvement", "Weight", "Bytes", "Bits",
169            );
170            $(
171                let (_, new) = $stats.$code.best_code();
172                $new_bits += new;
173                $old_bits += $old;
174            )*
175
176            $(
177                let (code, new) = $stats.$code.best_code();
178                println!("{:>17} {:>16} {:>12} {:>8} {:>10} {:>16}",
179                    stringify!($code), format!("{:?}", code),
180                    format!("{:.3}%", 100.0 * ($old - new) as f64 / $old as f64),
181                    format!("{:.3}", (($old - new) as f64 / ($old_bits - $new_bits) as f64)),
182                    normalize(($old - new) as f64 / 8.0),
183                    $old - new,
184                );
185            )*
186        };
187    }
188
189    let mut new_bits = 0;
190    let mut old_bits = 0;
191    impl_best_code!(
192        new_bits,
193        old_bits,
194        stats,
195        outdegrees - stats.outdegrees.gamma,
196        reference_offsets - stats.reference_offsets.unary,
197        block_counts - stats.block_counts.gamma,
198        blocks - stats.blocks.gamma,
199        interval_counts - stats.interval_counts.gamma,
200        interval_starts - stats.interval_starts.gamma,
201        interval_lens - stats.interval_lens.gamma,
202        first_residuals - stats.first_residuals.zeta[2],
203        residuals - stats.residuals.zeta[2]
204    );
205
206    println!();
207    println!(" Old bit size: {:>16}", old_bits);
208    println!(" New bit size: {:>16}", new_bits);
209    println!("   Saved bits: {:>16}", old_bits - new_bits);
210
211    println!("Old byte size: {:>16}", normalize(old_bits as f64 / 8.0));
212    println!("New byte size: {:>16}", normalize(new_bits as f64 / 8.0));
213    println!(
214        "  Saved bytes: {:>16}",
215        normalize((old_bits - new_bits) as f64 / 8.0)
216    );
217
218    println!(
219        "  Improvement: {:>15.3}%",
220        100.0 * (old_bits - new_bits) as f64 / old_bits as f64
221    );
222    Ok(())
223}
224
/// Formats `value` with three decimal digits followed by a decimal
/// (powers of 1000) unit-of-measure suffix: `' '` (none), `K`, `M`, `G` or
/// `T`.
///
/// The value is scaled down by 1000 while it is at least 1000 and a larger
/// unit is available, so exactly 1000 is printed as `1.000K`. Values of 1000 T
/// and beyond remain expressed in `T`. NaN and negative values are printed
/// unscaled.
fn normalize(mut value: f64) -> String {
    const UNITS: [char; 4] = ['K', 'M', 'G', 'T'];
    let mut uom = ' ';
    for &unit in &UNITS {
        // Written as `>=` / `else break` (rather than `< 1000.0`) so that NaN
        // stops the scaling immediately instead of being labeled `T`.
        if value >= 1000.0 {
            value /= 1000.0;
            uom = unit;
        } else {
            break;
        }
    }
    format!("{:.3}{}", value, uom)
}