// vyre_wgpu/engine/decode/mod.rs
// Bulk byte decoding engine.
//
// NOTE: This is a host-side workflow dispatcher, not an IR op domain. It
// accepts runtime bytes and TOML-configured decode rules, owns GPU buffers,
// dispatches format-specific kernels, and returns decoded byte regions. The
// IR-side decode operations live under `vyre::ops::decode`; those produce
// `Program` values that go through validate and lower.

/// Decode codecs: the GPU decoder entry points (`codec::decoder`, re-exported
/// at the bottom of this module) and the format identifier type
/// (`codec::format::DecodeFormat`).
pub mod codec;
/// Dispatch of format-specific decode kernels.
pub mod dispatch;
13
// Entropy helpers used by decode region discovery.
//
// NOTE: This is a host-side CPU helper, not part of the vyre IR. The
// IR-side entropy operation lives in `vyre::ops::hash::entropy`.

/// Largest sliding-window size accepted by [`entropy_map_cpu`]; requests with
/// a larger window return an empty result.
pub const MAX_WINDOW_SIZE: usize = 256;

/// Window expansion used by [`find_high_entropy_regions`] when converting
/// entropy-run offsets back into byte-region end offsets.
pub const DEFAULT_REGION_EXPANSION: usize = 256;

/// Maximum input size (64 MiB) accepted by [`entropy_map_cpu`]; larger inputs
/// fail with [`EntropyError::InputTooLarge`].
pub const MAX_INPUT_BYTES: usize = 64 * 1024 * 1024;
27
/// Error type for entropy computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EntropyError {
    /// Input length exceeds the maximum allowed size.
    InputTooLarge,
}

impl core::fmt::Display for EntropyError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant maps to a fixed message, so borrow the &'static str
        // and write it once instead of going through write! formatting.
        let message = match self {
            Self::InputTooLarge => {
                "input length exceeds 64 MiB. Fix: split the input into smaller chunks."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for EntropyError {}
47
/// Compute the Shannon entropy of `bytes`, in bits per byte (`0.0..=8.0`).
///
/// Returns `0.0` for empty input. Per-byte counts saturate at `u32::MAX`
/// instead of overflowing on pathologically large inputs. The probability
/// division uses f64 on purpose: an f32 mantissa loses precision once counts
/// exceed ~16M and the entropy estimate drifts.
pub fn shannon_entropy(bytes: &[u8]) -> f32 {
    if bytes.is_empty() {
        return 0.0;
    }
    // Histogram over all 256 byte values.
    let mut counts = [0u32; 256];
    for &b in bytes {
        counts[b as usize] = counts[b as usize].saturating_add(1);
    }
    let total = bytes.len() as f64;
    let mut entropy = 0.0_f64;
    for &count in &counts {
        if count == 0 {
            continue;
        }
        // H = -sum(p * log2(p)) over byte values that actually occur.
        let p = f64::from(count) / total;
        entropy -= p * p.log2();
    }
    entropy as f32
}
68
69/// Compute Shannon entropy for each sliding window on CPU.
70pub fn entropy_map_cpu(
71    data: &[u8],
72    window_size: usize,
73) -> std::result::Result<Vec<f32>, EntropyError> {
74    if data.len() > MAX_INPUT_BYTES {
75        return Err(EntropyError::InputTooLarge);
76    }
77    if data.is_empty()
78        || window_size == 0
79        || window_size > data.len()
80        || window_size > MAX_WINDOW_SIZE
81    {
82        return Ok(Vec::new());
83    }
84    let windows: Vec<f32> = (0..=data.len() - window_size)
85        .map(|start| shannon_entropy(&data[start..start + window_size]))
86        .collect();
87    Ok(windows)
88}
89
/// Convert entropy values to contiguous high-entropy regions.
///
/// Delegates to [`find_high_entropy_regions_with_window`] with
/// [`DEFAULT_REGION_EXPANSION`] as the window size, so each run's end offset
/// is widened by 256 bytes.
pub fn find_high_entropy_regions(entropy: &[f32], threshold: f32) -> Vec<(usize, usize)> {
    find_high_entropy_regions_with_window(entropy, threshold, DEFAULT_REGION_EXPANSION)
}
94
/// Collapse per-window entropy values into `(start, end)` runs whose values
/// exceed `threshold`, widening each run's end offset by `window_size` so the
/// region covers the bytes spanned by the last qualifying window.
///
/// Comparison is strict `>`, so a NaN entropy value never opens a run and
/// terminates any open one. End offsets saturate instead of overflowing.
pub fn find_high_entropy_regions_with_window(
    entropy: &[f32],
    threshold: f32,
    window_size: usize,
) -> Vec<(usize, usize)> {
    let mut regions = Vec::new();
    let mut active: Option<usize> = None;
    for (idx, &score) in entropy.iter().enumerate() {
        if score > threshold {
            // Open a run at the first qualifying offset; extend silently after.
            if active.is_none() {
                active = Some(idx);
            }
        } else if let Some(begin) = active.take() {
            // Run ended: record it, widened to cover the last window's bytes.
            regions.push((begin, idx.saturating_add(window_size)));
        }
    }
    // Flush a run that extends through the final window position.
    if let Some(begin) = active {
        regions.push((begin, entropy.len().saturating_add(window_size)));
    }
    regions
}
118
119// Host-side recursive decode frontier management.
120//
121// NOTE: This is NOT part of the vyre IR. It runs on the CPU around decode
122// GPU dispatches, tracks decoded-region frontiers, rejects malformed region
123// bounds, and deduplicates recursive decode work. It does not produce a
124// Program, does not go through validate or lower, and is not registered in
125// the op registry.
126
127use std::collections::{HashSet, VecDeque};
128use std::hash::{Hash, Hasher};
129use vyre::{Error, Result};
130
/// Run decode passes over `file_bytes` until no new regions are found or
/// `rules.max_passes` is exhausted, feeding decoded payloads back in as the
/// next pass's frontier.
///
/// Duplicate work is suppressed two ways: `(offset, length)` pairs already
/// recorded are skipped, and payloads whose hash was already visited are not
/// re-queued. Results are sorted by offset, then length, then decoded bytes,
/// so output order is deterministic.
///
/// Crate-private on purpose (architecture_deep_audit.md#10/#13): restricted
/// visibility avoids audit blind spots.
///
/// # Errors
///
/// Returns `Error::Decode` when `rules.max_passes` is zero; otherwise
/// propagates errors from `decode_one` and from region validation.
pub(crate) fn recursive_decode<F>(
    file_bytes: &[u8],
    rules: &DecodeRules,
    mut decode_one: F,
) -> Result<Vec<DecodedRegion>>
where
    F: FnMut(DecodeFormat, &[u8], &DecodeRules) -> Result<Vec<DecodedRegion>>,
{
    if rules.max_passes == 0 {
        return Err(Error::Decode {
            message: "max_passes must be at least 1. Fix: call DecodeRules::validate before dispatch or set max_passes to a positive value.".to_string(),
        });
    }
    // Seed with the whole input's hash so an identity decode never re-queues it.
    let mut visited_hashes = HashSet::<u64>::from([stable_hash(file_bytes)]);
    let mut seen_regions = HashSet::<(usize, usize)>::new();
    let mut frontier = VecDeque::from([(0usize, file_bytes.to_vec())]);
    let mut all_regions = Vec::<DecodedRegion>::new();

    for _ in 0..rules.max_passes {
        let mut next_frontier = VecDeque::new();
        let mut progress = false;
        while let Some((base_offset, bytes)) = frontier.pop_front() {
            let mut state = FrontierState {
                seen_regions: &mut seen_regions,
                visited_hashes: &mut visited_hashes,
                next_frontier: &mut next_frontier,
                all_regions: &mut all_regions,
                progress: &mut progress,
            };
            decode_frontier(base_offset, &bytes, rules, &mut decode_one, &mut state)?;
        }
        // A pass that accepted no new region is a fixed point; stop early.
        if !progress {
            break;
        }
        frontier = next_frontier;
    }
    // Deterministic output order regardless of frontier processing order.
    all_regions.sort_by(|left, right| {
        left.offset
            .cmp(&right.offset)
            .then(left.length.cmp(&right.length))
            .then(left.decoded_bytes.cmp(&right.decoded_bytes))
    });
    Ok(all_regions)
}
177
178/// `decode_frontier` function.
179pub fn decode_frontier<F>(
180    base_offset: usize,
181    bytes: &[u8],
182    rules: &DecodeRules,
183    decode_one: &mut F,
184    state: &mut FrontierState<'_>,
185) -> Result<()>
186where
187    F: FnMut(DecodeFormat, &[u8], &DecodeRules) -> Result<Vec<DecodedRegion>>,
188{
189    for format in [
190        DecodeFormat::Base64,
191        DecodeFormat::Hex,
192        DecodeFormat::Url,
193        DecodeFormat::Unicode,
194    ] {
195        for region in decode_one(format, bytes, rules)? {
196            push_region(base_offset, bytes, region, state)?;
197        }
198    }
199    Ok(())
200}
201
/// Mutable bookkeeping shared across one recursive-decode pass.
///
/// Borrows the dedup sets, the next pass's work queue, the accumulated output,
/// and the per-pass progress flag owned by the caller.
pub struct FrontierState<'a> {
    /// `(offset, length)` pairs already recorded; used to skip duplicates.
    seen_regions: &'a mut HashSet<(usize, usize)>,
    /// Hashes of payloads already queued; prevents re-decoding identical bytes.
    visited_hashes: &'a mut HashSet<u64>,
    /// Work queue for the next pass: `(base offset, decoded bytes)`.
    next_frontier: &'a mut VecDeque<(usize, Vec<u8>)>,
    /// All regions accepted so far, across passes.
    all_regions: &'a mut Vec<DecodedRegion>,
    /// Set to `true` when any new region is accepted during the pass.
    progress: &'a mut bool,
}
210
/// Validate one decoded region against its source buffer and fold it into the
/// frontier state.
///
/// Rejects regions whose `offset + length` overflows `usize` or extends past
/// `bytes`. Identity decodes (decoded bytes equal to the source slice) are
/// silently dropped. Accepted regions are rebased onto `base_offset`,
/// recorded in `state.all_regions`, and — when their payload hash is new —
/// queued for the next pass.
///
/// # Errors
///
/// Returns `Error::Decode` for overflowing or out-of-bounds region bounds.
pub fn push_region(
    base_offset: usize,
    bytes: &[u8],
    region: DecodedRegion,
    state: &mut FrontierState<'_>,
) -> Result<()> {
    let source_end = region
        .offset
        .checked_add(region.length)
        .ok_or_else(|| Error::Decode {
            message: "region overflow while validating source bounds. Fix: ensure the GPU decoder returns offset + length within usize bounds.".to_string(),
        })?;
    if source_end > bytes.len() {
        return Err(Error::Decode {
            message: "decoder returned a region beyond input bounds. Fix: report the decoder shader output and reject this malformed region.".to_string(),
        });
    }
    // Identity decode: the output equals the source slice, so it adds nothing.
    if region.decoded_bytes == bytes[region.offset..source_end] {
        return Ok(());
    }
    // Rebase from frontier-local offsets onto the frontier's base offset.
    let normalized = DecodedRegion {
        offset: base_offset + region.offset,
        length: region.length,
        decoded_bytes: region.decoded_bytes,
    };
    if state
        .seen_regions
        .insert((normalized.offset, normalized.length))
    {
        // A genuinely new (offset, length) pair counts as pass progress.
        *state.progress = true;
        let hash = stable_hash(&normalized.decoded_bytes);
        // Only queue payloads whose bytes have not been decoded before.
        if state.visited_hashes.insert(hash) {
            state
                .next_frontier
                .push_back((normalized.offset, normalized.decoded_bytes.clone()));
        }
        state.all_regions.push(normalized);
    }
    Ok(())
}
252
/// Hash a byte slice with the standard library's `DefaultHasher`.
///
/// Deterministic within one process run, which is all the recursive-decode
/// dedup needs. NOTE(review): `DefaultHasher` output is not guaranteed stable
/// across Rust releases — do not persist these values.
pub fn stable_hash(bytes: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    let mut state = DefaultHasher::new();
    Hash::hash(bytes, &mut state);
    state.finish()
}
259
260/// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
261/// restricted visibility audit blind spots.
262pub(crate) fn flatten_regions(regions: Vec<DecodedRegion>) -> Vec<u8> {
263    regions
264        .into_iter()
265        .flat_map(|region| region.decoded_bytes)
266        .collect()
267}
// Decoded region metadata.

/// A decoded region produced by one decode pass.
///
/// Marked `#[non_exhaustive]` so future region metadata (character encoding,
/// confidence scores, …) can be added without breaking consumers.
///
/// # Examples
///
/// ```
/// use vyre_wgpu::engine::decode::DecodedRegion;
///
/// let region = DecodedRegion::new(0, 4, vec![1, 2, 3]);
/// assert_eq!(region.decoded_bytes.len(), 3);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct DecodedRegion {
    /// Source offset of the encoded region.
    pub offset: usize,
    /// Source length of the encoded region.
    pub length: usize,
    /// Decoded bytes emitted for the region.
    pub decoded_bytes: Vec<u8>,
}

impl DecodedRegion {
    /// Create a decoded region from its source offset, source length, and
    /// decoded payload.
    ///
    /// # Examples
    ///
    /// ```
    /// use vyre_wgpu::engine::decode::DecodedRegion;
    ///
    /// let region = DecodedRegion::new(5, 3, vec![0x20]);
    /// assert_eq!(region.offset, 5);
    /// ```
    #[must_use]
    pub fn new(offset: usize, length: usize, decoded_bytes: Vec<u8>) -> Self {
        DecodedRegion {
            offset,
            length,
            decoded_bytes,
        }
    }
}
314// TOML decode rule validation.
315
316use serde::Deserialize;
317
318impl DecodeRules {
319    /// Create decode rules with explicit values.
320    ///
321    /// Call [`validate`](Self::validate) to check that the values are acceptable
322    /// before using the rules for decode work.
323    ///
324    /// # Examples
325    ///
326    /// ```
327    /// use vyre_wgpu::engine::decode::DecodeRules;
328    ///
329    /// let rules = DecodeRules::with_values(12, 16, 4);
330    /// assert_eq!(rules.min_base64_run, 12);
331    /// ```
332    #[must_use]
333    pub fn with_values(min_base64_run: u32, min_hex_run: u32, max_passes: u32) -> Self {
334        Self {
335            min_base64_run,
336            min_hex_run,
337            max_passes,
338        }
339    }
340
341    /// Parse decode rules from a TOML document.
342    ///
343    /// # Errors
344    ///
345    /// Returns `Error::DecodeConfig` if the TOML is unparsable or the rules fail validation.
346    pub fn from_toml(toml_source: &str) -> Result<Self> {
347        let rules = toml::from_str::<Self>(toml_source).map_err(|error| {
348            Error::DecodeConfig {
349                message: format!("failed to parse decode rules TOML: {error}. Fix: correct the TOML syntax and provide min_base64_run, min_hex_run, and max_passes values."),
350            }
351        })?;
352        rules.validate().map_err(|error| Error::DecodeConfig {
353            message: error.to_string(),
354        })?;
355        Ok(rules)
356    }
357
358    /// Validate thresholds before CPU or GPU work starts.
359    ///
360    /// # Errors
361    ///
362    /// Returns `Error::DecodeConfig` if any threshold is out of range.
363    pub fn validate(&self) -> std::result::Result<(), DecodeError> {
364        if self.min_base64_run < 4 {
365            return Err(DecodeError::MinBase64RunTooSmall);
366        }
367        if self.min_hex_run < 2 {
368            return Err(DecodeError::MinHexRunTooSmall);
369        }
370        if self.max_passes == 0 {
371            return Err(DecodeError::MaxPassesZero);
372        }
373        if self.max_passes > 64 {
374            return Err(DecodeError::MaxPassesOutOfRange);
375        }
376        Ok(())
377    }
378}
379
/// Error type for decode rule validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// `min_base64_run` is below the minimum threshold.
    MinBase64RunTooSmall,
    /// `min_hex_run` is below the minimum threshold.
    MinHexRunTooSmall,
    /// `max_passes` is zero.
    MaxPassesZero,
    /// `max_passes` exceeds the allowed upper bound.
    MaxPassesOutOfRange,
}

impl core::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant maps to a fixed diagnostic with a "Fix:" hint, so
        // borrow the &'static str and emit it with a single write.
        let message = match self {
            Self::MinBase64RunTooSmall => {
                "min_base64_run must be at least 4 to preserve base64 quartets. Fix: set min_base64_run to 4 or greater."
            }
            Self::MinHexRunTooSmall => {
                "min_hex_run must be at least 2 to preserve full bytes. Fix: set min_hex_run to 2 or greater."
            }
            Self::MaxPassesZero => {
                "max_passes must be greater than zero. Fix: set max_passes to at least 1."
            }
            Self::MaxPassesOutOfRange => {
                "max_passes must be at most 64. Fix: set max_passes to 64 or lower."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for DecodeError {}
417
/// TOML-configurable decode thresholds and recursion limits.
///
/// This struct is `#[non_exhaustive]` to allow adding new configuration fields
/// (like per-format recursion caps) without breaking downstream consumers.
/// Values are not range-checked on construction; call
/// [`validate`](Self::validate) before use.
///
/// # Examples
///
/// ```
/// use vyre_wgpu::engine::decode::DecodeRules;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let rules = DecodeRules::from_toml("min_base64_run = 12\nmin_hex_run = 16\nmax_passes = 4")?;
/// assert_eq!(rules.min_base64_run, 12);
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
#[non_exhaustive]
pub struct DecodeRules {
    /// Minimum contiguous base64 run length to attempt decoding.
    pub min_base64_run: u32,
    /// Minimum contiguous hex run length to attempt decoding.
    pub min_hex_run: u32,
    /// Maximum recursive decode passes.
    pub max_passes: u32,
}
444
445impl Default for DecodeRules {
446    fn default() -> Self {
447        Self {
448            min_base64_run: 8,
449            min_hex_run: 8,
450            max_passes: 8,
451        }
452    }
453}
454
455pub use codec::decoder::{
456    decode_base64, decode_bytes, decode_file, decode_file_with_rules, decode_hex, decode_regions,
457    decode_unicode, decode_url, GpuDecoder,
458};
459pub use codec::format::DecodeFormat;