Skip to main content

vyre_wgpu/engine/decode/codec/
format.rs

1//! Supported decode formats and shader source selection.
2
3use crate::engine::decode::DecodeRules;
4
5
6
7
/// Encodings the decode engine knows how to recognise and decode.
///
/// The enum is marked `#[non_exhaustive]` so that further formats (for example
/// Base32 or quoted-printable) can be introduced later without breaking
/// downstream `match` expressions.
///
/// # Examples
///
/// ```
/// use vyre_wgpu::engine::decode::DecodeFormat;
///
/// assert!(matches!(DecodeFormat::Base64, DecodeFormat::Base64));
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeFormat {
    /// Base64 as specified by RFC 4648.
    Base64,
    /// Bytes written as hexadecimal digits.
    Hex,
    /// Percent-encoded bytes as used in URLs.
    Url,
    /// Backslash escape sequences of the form `\xNN` and `\uNNNN`.
    Unicode,
}
32
33
34impl DecodeFormat {
35    /// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
36    /// restricted visibility audit blind spots.
37    pub(crate) fn label(self) -> &'static str {
38        match self {
39            Self::Base64 => "vyre decode base64",
40            Self::Hex => "vyre decode hex",
41            Self::Url => "vyre decode url",
42            Self::Unicode => "vyre decode unicode",
43        }
44    }
45
46    /// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
47    /// restricted visibility audit blind spots.
48    pub(crate) fn min_run(self, rules: &DecodeRules) -> u32 {
49        match self {
50            Self::Base64 => rules.min_base64_run,
51            Self::Hex => rules.min_hex_run,
52            Self::Url | Self::Unicode => 0,
53        }
54    }
55
56    /// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
57    /// restricted visibility audit blind spots.
58    pub(crate) fn op_id(self) -> &'static str {
59        match self {
60            Self::Base64 => vyre::ops::decode::base64::Base64Decode::SPEC.id(),
61            Self::Hex => vyre::ops::decode::hex::HexDecode::SPEC.id(),
62            Self::Url => vyre::ops::decode::url::UrlDecode::SPEC.id(),
63            Self::Unicode => vyre::ops::decode::unicode::UnicodeDecode::SPEC.id(),
64        }
65    }
66
67    /// Fixes architecture_deep_audit.md#10/#13: crate-private visibility avoids
68    /// restricted visibility audit blind spots.
69    pub(crate) fn wgsl(self) -> String {
70        match self {
71            Self::Base64 => [DECODE_WGSL_HEADER, BASE64_WGSL_BODY].concat(),
72            Self::Hex => [DECODE_WGSL_HEADER, HEX_WGSL_BODY].concat(),
73            Self::Url => [DECODE_WGSL_HEADER, URL_WGSL_BODY].concat(),
74            Self::Unicode => [DECODE_WGSL_HEADER, UNICODE_WGSL_BODY].concat(),
75        }
76    }
77}
78
79
/// WGSL prelude shared by every decode shader.
///
/// `DecodeFormat::wgsl` prepends this text to the format-specific body, so it
/// has to be valid WGSL. WGSL has no visibility modifiers and `pub` is a
/// reserved word there, so every declaration is a plain `struct` / `fn`.
///
/// The prelude declares:
///
/// * `Params` (uniform, binding 4): `input_len` (compared against byte offsets
///   by the entry points), `min_run`, `max_regions` and `output_size`
///   (compared against indices into `output_words`). `min_run` is not read by
///   any shader text in this file.
/// * `RegionMeta`: one decoded region, `src_offset`/`src_len` in the input and
///   `dst_offset`/`dst_len` in the output.
/// * Bindings 0-3 of group 0: `input_words` (read-only), `regions`,
///   `output_words` and the atomic `counters`.
/// * `read_byte(offset)`: extracts the byte at `offset` from `input_words`,
///   treating byte 0 of each word as its least significant 8 bits.
/// * `hex_value(byte)`: maps an ASCII hex digit (`0-9`, `A-F`, `a-f`) to
///   `0..=15`, or returns `0xffffffff` for any other byte.
/// * `emit_region(...)`: reserves a region slot through `counters[0]` and
///   `dst_len` output words through `counters[1]`, then records the region and
///   writes up to three decoded values, one per output word. The region is
///   dropped when either reservation falls outside `max_regions` /
///   `output_size`; the counters have already been advanced at that point.
///   NOTE(review): the host presumably clamps the counters on read-back —
///   confirm.
pub const DECODE_WGSL_HEADER: &str = r"
struct Params {
    input_len: u32,
    min_run: u32,
    max_regions: u32,
    output_size: u32,
};

struct RegionMeta {
    src_offset: u32,
    src_len: u32,
    dst_offset: u32,
    dst_len: u32,
};

@group(0) @binding(0) var<storage, read> input_words: array<u32>;
@group(0) @binding(1) var<storage, read_write> regions: array<RegionMeta>;
@group(0) @binding(2) var<storage, read_write> output_words: array<u32>;
@group(0) @binding(3) var<storage, read_write> counters: array<atomic<u32>>;
@group(0) @binding(4) var<uniform> params: Params;

fn read_byte(offset: u32) -> u32 {
    let word = input_words[offset / 4u];
    let shift = (offset % 4u) * 8u;
    return (word >> shift) & 0xffu;
}

fn hex_value(byte: u32) -> u32 {
    if (byte >= 48u && byte <= 57u) { return byte - 48u; }
    if (byte >= 65u && byte <= 70u) { return byte - 55u; }
    if (byte >= 97u && byte <= 102u) { return byte - 87u; }
    return 0xffffffffu;
}

fn emit_region(src_offset: u32, src_len: u32, dst_len: u32, b0: u32, b1: u32, b2: u32) {
    let region_index = atomicAdd(&counters[0], 1u);
    if (region_index >= params.max_regions) { return; }
    let dst_offset = atomicAdd(&counters[1], dst_len);
    if (dst_offset + dst_len > params.output_size) { return; }
    regions[region_index] = RegionMeta(src_offset, src_len, dst_offset, dst_len);
    if (dst_len > 0u) { output_words[dst_offset] = b0; }
    if (dst_len > 1u) { output_words[dst_offset + 1u] = b1; }
    if (dst_len > 2u) { output_words[dst_offset + 2u] = b2; }
}
";
126
127
/// WGSL body of the base64 decode shader; appended to `DECODE_WGSL_HEADER`.
///
/// The text must be valid WGSL, which has no `pub` keyword (it is a reserved
/// word), so the functions are declared with a bare `fn`.
///
/// * `b64_value(byte)` maps one ASCII character of the alphabet
///   `A-Z a-z 0-9 + /` to its 6-bit value and returns `0xffffffff` for every
///   other byte (this includes `=`).
/// * `main` runs with `gid.x` as a byte offset. It returns early unless the
///   four bytes starting at that offset lie inside `params.input_len` and all
///   four are alphabet characters; otherwise it combines them into three
///   bytes and calls `emit_region(offset, 4, 3, ...)`. Every offset is tested
///   independently; no alignment to a 4-character boundary is required.
pub const BASE64_WGSL_BODY: &str = r"
fn b64_value(byte: u32) -> u32 {
    if (byte >= 65u && byte <= 90u) { return byte - 65u; }
    if (byte >= 97u && byte <= 122u) { return byte - 71u; }
    if (byte >= 48u && byte <= 57u) { return byte + 4u; }
    if (byte == 43u) { return 62u; }
    if (byte == 47u) { return 63u; }
    return 0xffffffffu;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len) { return; }
    let a = b64_value(read_byte(offset));
    let b = b64_value(read_byte(offset + 1u));
    let c = b64_value(read_byte(offset + 2u));
    let d = b64_value(read_byte(offset + 3u));
    if (a == 0xffffffffu || b == 0xffffffffu || c == 0xffffffffu || d == 0xffffffffu) { return; }
    let out0 = ((a << 2u) | (b >> 4u)) & 0xffu;
    let out1 = (((b & 15u) << 4u) | (c >> 2u)) & 0xffu;
    let out2 = (((c & 3u) << 6u) | d) & 0xffu;
    emit_region(offset, 4u, 3u, out0, out1, out2);
}
";
154
155
/// WGSL body of the hexadecimal decode shader; appended to
/// `DECODE_WGSL_HEADER`.
///
/// The text must be valid WGSL, which has no `pub` keyword (it is a reserved
/// word), so the entry point is declared with a bare `fn`.
///
/// `main` runs with `gid.x` as a byte offset. It returns early unless the two
/// bytes starting at that offset lie inside `params.input_len` and both are
/// hex digits according to the header's `hex_value`; otherwise it combines
/// them (high nibble first) into one byte and calls
/// `emit_region(offset, 2, 1, ...)`.
pub const HEX_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 1u >= params.input_len) { return; }
    let hi = hex_value(read_byte(offset));
    let lo = hex_value(read_byte(offset + 1u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 2u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
168
169
/// WGSL body of the URL percent-decoding shader; appended to
/// `DECODE_WGSL_HEADER`.
///
/// The text must be valid WGSL, which has no `pub` keyword (it is a reserved
/// word), so the entry point is declared with a bare `fn`.
///
/// `main` runs with `gid.x` as a byte offset. It returns early unless three
/// bytes starting at that offset lie inside `params.input_len`, the first is
/// `%` (ASCII 37) and the next two are hex digits according to the header's
/// `hex_value`; otherwise it combines the two digits (high nibble first) into
/// one byte and calls `emit_region(offset, 3, 1, ...)`.
pub const URL_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 2u >= params.input_len || read_byte(offset) != 37u) { return; }
    let hi = hex_value(read_byte(offset + 1u));
    let lo = hex_value(read_byte(offset + 2u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 3u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
182
183
/// WGSL body of the backslash-escape decode shader; appended to
/// `DECODE_WGSL_HEADER`.
///
/// The text must be valid WGSL, which has no `pub` keyword (it is a reserved
/// word), so the entry point is declared with a bare `fn`.
///
/// `main` runs with `gid.x` as a byte offset and only proceeds when the byte
/// at that offset is `\` (ASCII 92) and at least four bytes are available:
///
/// * `\xNN` (second byte `x`, ASCII 120): when both `N` are hex digits the
///   decoded byte is emitted via `emit_region(offset, 4, 1, ...)`; an invalid
///   digit emits nothing.
/// * `\uNNNN` (second byte `u`, ASCII 117, six bytes available): all four
///   digits must be hex digits, then `emit_region(offset, 6, 1, ...)` is
///   called with the byte formed by the last two digits only.
///   NOTE(review): the first two digits are validated but not emitted, so
///   code points above `0xFF` are truncated — confirm this is intended.
pub const UNICODE_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len || read_byte(offset) != 92u) { return; }
    if (read_byte(offset + 1u) == 120u) {
        let hi = hex_value(read_byte(offset + 2u));
        let lo = hex_value(read_byte(offset + 3u));
        if (hi != 0xffffffffu && lo != 0xffffffffu) {
            emit_region(offset, 4u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
        }
        return;
    }
    if (offset + 5u >= params.input_len || read_byte(offset + 1u) != 117u) { return; }
    let h0 = hex_value(read_byte(offset + 2u));
    let h1 = hex_value(read_byte(offset + 3u));
    let h2 = hex_value(read_byte(offset + 4u));
    let h3 = hex_value(read_byte(offset + 5u));
    if (h0 == 0xffffffffu || h1 == 0xffffffffu || h2 == 0xffffffffu || h3 == 0xffffffffu) { return; }
    emit_region(offset, 6u, 1u, ((h2 << 4u) | h3) & 0xffu, 0u, 0u);
}
";