Skip to main content

webgraph_cli/analyze/
codes.rs

1/*
2 * SPDX-FileCopyrightText: 2023 Inria
3 * SPDX-FileCopyrightText: 2023 Tommaso Fontana
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use crate::{GlobalArgs, GranularityArgs, NumThreadsArg};
9use anyhow::Result;
10use clap::Parser;
11use dsi_bitstream::{dispatch::factory::CodesReaderFactoryHelper, prelude::*};
12use dsi_progress_logger::prelude::*;
13use std::path::PathBuf;
14use webgraph::prelude::*;
15
16#[derive(Parser, Debug)]
17#[command(name = "codes", about = "Reads a graph and suggests the best codes to use.", long_about = None)]
18pub struct CliArgs {
19    /// The basename of the graph.
20    pub basename: PathBuf,
21
22    #[clap(flatten)]
23    pub num_threads: NumThreadsArg,
24
25    #[clap(flatten)]
26    pub granularity: GranularityArgs,
27
28    #[clap(short = 'k', long, default_value_t = 3)]
29    /// How many codes to show for each type, if k is bigger than the number of codes available
30    /// all codes will be shown.
31    pub top_k: usize,
32}
33
34pub fn main(global_args: GlobalArgs, args: CliArgs) -> Result<()> {
35    match get_endianness(&args.basename)?.as_str() {
36        #[cfg(feature = "be_bins")]
37        BE::NAME => optimize_codes::<BE>(global_args, args),
38        #[cfg(feature = "le_bins")]
39        LE::NAME => optimize_codes::<LE>(global_args, args),
40        e => panic!("Unknown endianness: {}", e),
41    }
42}
43
/// Returns ranges of nodes to process in parallel of size `chunk_size` each,
/// with the last chunk possibly being smaller.
///
/// The equivalent of [`std::slice::Chunks`] but with a `Range` instead of a
/// slice.
#[derive(Debug, Clone)]
pub struct ChunksIter {
    /// The part of the original range that has not been returned yet.
    total: core::ops::Range<usize>,
    /// The length of each returned range (except possibly the last one);
    /// always at least 1.
    chunk_size: usize,
}

impl ChunksIter {
    /// Creates an iterator splitting `total` into consecutive ranges of
    /// `chunk_size` elements.
    ///
    /// A `chunk_size` of zero is treated as one: a zero-length chunk would
    /// never advance through the range and the iterator would never end.
    pub fn new(total: core::ops::Range<usize>, chunk_size: usize) -> Self {
        Self {
            total,
            chunk_size: chunk_size.max(1),
        }
    }
}

impl Iterator for ChunksIter {
    type Item = core::ops::Range<usize>;

    /// Returns the next chunk, or `None` once the whole range has been
    /// covered.
    fn next(&mut self) -> Option<Self::Item> {
        if self.total.start >= self.total.end {
            return None;
        }
        // Saturate so that a huge chunk size cannot overflow `usize`; the
        // result is clamped to the end of the range anyway.
        let end = self
            .total
            .start
            .saturating_add(self.chunk_size)
            .min(self.total.end);
        let range = self.total.start..end;
        self.total.start = end;
        Some(range)
    }
}
72
73pub fn optimize_codes<E: Endianness>(global_args: GlobalArgs, args: CliArgs) -> Result<()>
74where
75    MmapHelper<u32>: CodesReaderFactoryHelper<E>,
76    for<'a> LoadModeCodesReader<'a, E, Mmap>: BitSeek,
77{
78    let mut stats = Default::default();
79    let has_ef = std::fs::metadata(args.basename.with_extension("ef")).is_ok_and(|x| x.is_file());
80
81    // Load the compression flags from the properties file so we can compare them
82    let (_, _, comp_flags) =
83        parse_properties::<E>(args.basename.with_extension(PROPERTIES_EXTENSION))?;
84
85    if has_ef {
86        log::info!(
87            "Analyzing codes in parallel using {} threads",
88            args.num_threads.num_threads
89        );
90        let graph = BvGraph::with_basename(&args.basename)
91            .endianness::<E>()
92            .load()?;
93
94        let mut pl = concurrent_progress_logger![item_name = "node"];
95        pl.display_memory(true)
96            .expected_updates(Some(graph.num_nodes()));
97        pl.start("Scanning...");
98
99        if let Some(duration) = global_args.log_interval {
100            pl.log_interval(duration);
101        }
102
103        let thread_pool = rayon::ThreadPoolBuilder::new()
104            .num_threads(args.num_threads.num_threads)
105            .build()?;
106
107        let node_granularity = args
108            .granularity
109            .into_granularity()
110            .node_granularity(graph.num_nodes(), Some(graph.num_arcs()));
111
112        // TODO!: use FairChunks with the offsets EF to distribute the
113        // work based on number of bits used, not nodes
114        stats = thread_pool.install(|| {
115            ChunksIter::new(0..graph.num_nodes(), node_granularity).par_map_fold_with(
116                pl.clone(),
117                |pl, range| {
118                    let mut iter = graph
119                        .offset_deg_iter_from(range.start)
120                        .map_decoder(|d| StatsDecoder::new(d, Default::default()));
121
122                    for _ in (&mut iter).take(range.len()) {
123                        pl.light_update();
124                    }
125
126                    let mut stats = Default::default();
127                    iter.map_decoder(|d| {
128                        stats = d.stats;
129                        d.codes_reader // not important but we need to return something
130                    });
131                    stats
132                },
133                |mut acc1, acc2| {
134                    acc1 += &acc2;
135                    acc1
136                },
137            )
138        });
139
140        pl.done();
141    } else {
142        if args.num_threads.num_threads != 1 {
143            log::info!(SEQ_PROC_WARN![], args.basename.display());
144        }
145
146        let graph = BvGraphSeq::with_basename(args.basename)
147            .endianness::<E>()
148            .load()?;
149
150        let mut pl = ProgressLogger::default();
151        pl.display_memory(true)
152            .item_name("node")
153            .expected_updates(Some(graph.num_nodes()));
154
155        pl.start("Scanning...");
156
157        // add the stats wrapper to the decoder
158        let mut iter = graph
159            .offset_deg_iter()
160            .map_decoder(|d| StatsDecoder::new(d, Default::default()));
161        // iterate over the graph
162        for _ in iter.by_ref() {
163            pl.light_update();
164        }
165        pl.done();
166        // extract the stats
167        iter.map_decoder(|d| {
168            stats = d.stats;
169            d.codes_reader // not important but we need to return something
170        });
171    }
172
173    println!("Default codes");
174    compare_codes(&stats, CompFlags::default(), args.top_k);
175
176    print!("\n\n\n");
177
178    println!("Current codes");
179    compare_codes(&stats, comp_flags, args.top_k);
180
181    Ok(())
182}
183
184/// Gets the size in bits used by a given code.
185/// This should go in dsi-bitstream eventually.
186fn get_size_by_code(stats: &CodesStats, code: Codes) -> Option<u64> {
187    match code {
188        Codes::Unary => Some(stats.unary),
189        Codes::Gamma => Some(stats.gamma),
190        Codes::Delta => Some(stats.delta),
191        Codes::Omega => Some(stats.omega),
192        Codes::VByteBe | Codes::VByteLe => Some(stats.vbyte),
193        Codes::Zeta(k) => stats.zeta.get(k - 1).copied(),
194        Codes::Golomb(b) => stats.golomb.get(b as usize - 1).copied(),
195        Codes::ExpGolomb(k) => stats.exp_golomb.get(k).copied(),
196        Codes::Rice(k) => stats.rice.get(k).copied(),
197        Codes::Pi(0) => Some(stats.gamma),   // Pi(0) is Gamma
198        Codes::Pi(1) => Some(stats.zeta[1]), // Pi(1) is Zeta(2)
199        Codes::Pi(k) => stats.pi.get(k - 2).copied(),
200        _ => unreachable!("Code {:?} not supported", code),
201    }
202}
203
204/// Prints the statistics of how much the optimal codes improve over the reference ones.
205pub fn compare_codes(stats: &DecoderStats, reference: CompFlags, top_k: usize) {
206    macro_rules! impl_best_code {
207        ($new_bits:expr, $old_bits:expr, $stats:expr, $($code:ident -> $old:expr),*) => {
208            println!("{:>17} {:>20} {:>12} {:>10} {:>10} {:>16}",
209                "Type", "Code", "Improvement", "Weight", "Bytes", "Bits",
210            );
211            $(
212                let (_, new) = $stats.$code.best_code();
213                $new_bits += new;
214                $old_bits += $old;
215            )*
216
217            $(
218                let codes = $stats.$code.get_codes();
219                let (best_code, best_size) = codes[0];
220
221                let improvement = 100.0 * ($old - best_size) as f64 / $old as f64;
222                let weight = 100.0 * ($old as f64 - best_size as f64) / ($old_bits as f64 - $new_bits as f64);
223
224                println!("{:>17} {:>20} {:>12.3}% {:>9.3}% {:>10} {:>16}",
225                    stringify!($code),
226                    format!("{:?}", best_code),
227                    improvement,
228                    weight,
229                    normalize(best_size as f64 / 8.0),
230                    best_size,
231                );
232                for i in 1..top_k.min(codes.len()).max(1) {
233                    let (code, size) = codes[i];
234                    let improvement = 100.0 * ($old as f64 - size as f64) / $old as f64;
235                    println!("{:>17} {:>20} {:>12.3}% {:>10.3} {:>10} {:>16}",
236                        stringify!($code),
237                        format!("{:?}", code),
238                        improvement,
239                        "",
240                        normalize(size as f64 / 8.0),
241                        size,
242                    );
243                }
244                print!("\n");
245            )*
246        };
247    }
248
249    println!("Code optimization results against:");
250    for (name, code) in [
251        ("outdegrees", reference.outdegrees),
252        ("reference offsets", reference.references),
253        ("block counts", reference.blocks),
254        ("blocks", reference.blocks),
255        ("interval counts", reference.intervals),
256        ("interval starts", reference.intervals),
257        ("interval lengths", reference.intervals),
258        ("first residuals", reference.residuals),
259        ("residuals", reference.residuals),
260    ] {
261        println!("\t{:>18} : {:?}", name, code);
262    }
263
264    let mut new_bits = 0;
265    let mut old_bits = 0;
266    impl_best_code!(
267        new_bits,
268        old_bits,
269        stats,
270        outdegrees -> get_size_by_code(&stats.outdegrees, reference.outdegrees).unwrap(),
271        reference_offsets -> get_size_by_code(&stats.reference_offsets, reference.references).unwrap(),
272        block_counts -> get_size_by_code(&stats.block_counts, reference.blocks).unwrap(),
273        blocks -> get_size_by_code(&stats.blocks, reference.blocks).unwrap(),
274        interval_counts -> get_size_by_code(&stats.interval_counts, reference.intervals).unwrap(),
275        interval_starts -> get_size_by_code(&stats.interval_starts, reference.intervals).unwrap(),
276        interval_lens -> get_size_by_code(&stats.interval_lens, reference.intervals).unwrap(),
277        first_residuals -> get_size_by_code(&stats.first_residuals, reference.residuals).unwrap(),
278        residuals -> get_size_by_code(&stats.residuals, reference.residuals).unwrap()
279    );
280
281    println!();
282    println!(" Old bit size: {:>16}", old_bits);
283    println!(" New bit size: {:>16}", new_bits);
284    println!("   Saved bits: {:>16}", old_bits - new_bits);
285
286    println!("Old byte size: {:>16}", normalize(old_bits as f64 / 8.0));
287    println!("New byte size: {:>16}", normalize(new_bits as f64 / 8.0));
288    println!(
289        "  Saved bytes: {:>16}",
290        normalize((old_bits - new_bits) as f64 / 8.0)
291    );
292
293    println!(
294        "  Improvement: {:>15.3}%",
295        100.0 * (old_bits - new_bits) as f64 / old_bits as f64
296    );
297}
298
/// Formats `value` with three decimal digits followed by a one-character
/// unit of measure.
///
/// The value is divided by 1000 for as long as it is strictly greater than
/// 1000, at most four times, and the unit is chosen accordingly among `K`,
/// `M`, `G` and `T`; a value that is never divided gets a space as unit.
fn normalize(mut value: f64) -> String {
    let mut uom = ' ';
    for prefix in ['K', 'M', 'G', 'T'] {
        if value > 1000.0 {
            value /= 1000.0;
            uom = prefix;
        }
    }
    format!("{:.3}{}", value, uom)
}