// vyre_wgpu/engine/decode/codec/format.rs
use crate::engine::decode::DecodeRules;
4
5
6
7
/// Encoded-text families that the decode codec has a shader for.
///
/// The variant selects the label, minimum-run rule, op id and WGSL source
/// returned by the methods on this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DecodeFormat {
    /// Base64 text (`A-Z`, `a-z`, `0-9`, `+`, `/`).
    Base64,
    /// Pairs of hexadecimal digits.
    Hex,
    /// Percent-encoded bytes (`%XX`).
    Url,
    /// Backslash escapes (`\xXX` and `\uXXXX`).
    Unicode,
}
32
33
34impl DecodeFormat {
35 pub(crate) fn label(self) -> &'static str {
38 match self {
39 Self::Base64 => "vyre decode base64",
40 Self::Hex => "vyre decode hex",
41 Self::Url => "vyre decode url",
42 Self::Unicode => "vyre decode unicode",
43 }
44 }
45
46 pub(crate) fn min_run(self, rules: &DecodeRules) -> u32 {
49 match self {
50 Self::Base64 => rules.min_base64_run,
51 Self::Hex => rules.min_hex_run,
52 Self::Url | Self::Unicode => 0,
53 }
54 }
55
56 pub(crate) fn op_id(self) -> &'static str {
59 match self {
60 Self::Base64 => vyre::ops::decode::base64::Base64Decode::SPEC.id(),
61 Self::Hex => vyre::ops::decode::hex::HexDecode::SPEC.id(),
62 Self::Url => vyre::ops::decode::url::UrlDecode::SPEC.id(),
63 Self::Unicode => vyre::ops::decode::unicode::UnicodeDecode::SPEC.id(),
64 }
65 }
66
67 pub(crate) fn wgsl(self) -> String {
70 match self {
71 Self::Base64 => [DECODE_WGSL_HEADER, BASE64_WGSL_BODY].concat(),
72 Self::Hex => [DECODE_WGSL_HEADER, HEX_WGSL_BODY].concat(),
73 Self::Url => [DECODE_WGSL_HEADER, URL_WGSL_BODY].concat(),
74 Self::Unicode => [DECODE_WGSL_HEADER, UNICODE_WGSL_BODY].concat(),
75 }
76 }
77}
78
79
/// WGSL prelude shared by every decode shader in this file.
///
/// It declares:
/// - `Params` (uniform, binding 4): `input_len`, `min_run`, `max_regions`
///   and `output_size`. `min_run` is declared but not read by any shader
///   code in this file.
/// - `RegionMeta`: source offset/length and destination offset/length of one
///   decoded region.
/// - Storage bindings 0-3: `input_words`, `regions`, `output_words` and the
///   atomic `counters` array.
/// - `read_byte`: extracts byte `offset` from `input_words`, treating each
///   word as four bytes with the lowest-order byte first.
/// - `hex_value`: maps an ASCII hex digit to `0..=15`, or `0xffffffff` for
///   any other byte.
/// - `emit_region`: reserves a region slot through `counters[0]` and output
///   space through `counters[1]`, then writes the metadata and up to three
///   values. Each value is stored in its own `output_words` element. Both
///   counters are incremented even when the corresponding limit check then
///   rejects the region.
///
/// The text must stay plain WGSL: the language has no visibility keyword, so
/// declarations are written as bare `struct` / `fn`.
pub const DECODE_WGSL_HEADER: &str = r"
struct Params {
    input_len: u32,
    min_run: u32,
    max_regions: u32,
    output_size: u32,
};

struct RegionMeta {
    src_offset: u32,
    src_len: u32,
    dst_offset: u32,
    dst_len: u32,
};

@group(0) @binding(0) var<storage, read> input_words: array<u32>;
@group(0) @binding(1) var<storage, read_write> regions: array<RegionMeta>;
@group(0) @binding(2) var<storage, read_write> output_words: array<u32>;
@group(0) @binding(3) var<storage, read_write> counters: array<atomic<u32>>;
@group(0) @binding(4) var<uniform> params: Params;

fn read_byte(offset: u32) -> u32 {
    let word = input_words[offset / 4u];
    let shift = (offset % 4u) * 8u;
    return (word >> shift) & 0xffu;
}

fn hex_value(byte: u32) -> u32 {
    if (byte >= 48u && byte <= 57u) { return byte - 48u; }
    if (byte >= 65u && byte <= 70u) { return byte - 55u; }
    if (byte >= 97u && byte <= 102u) { return byte - 87u; }
    return 0xffffffffu;
}

fn emit_region(src_offset: u32, src_len: u32, dst_len: u32, b0: u32, b1: u32, b2: u32) {
    let region_index = atomicAdd(&counters[0], 1u);
    if (region_index >= params.max_regions) { return; }
    let dst_offset = atomicAdd(&counters[1], dst_len);
    if (dst_offset + dst_len > params.output_size) { return; }
    regions[region_index] = RegionMeta(src_offset, src_len, dst_offset, dst_len);
    if (dst_len > 0u) { output_words[dst_offset] = b0; }
    if (dst_len > 1u) { output_words[dst_offset + 1u] = b1; }
    if (dst_len > 2u) { output_words[dst_offset + 2u] = b2; }
}
";
126
127
/// WGSL body for base64 decoding; appended to [`DECODE_WGSL_HEADER`].
///
/// `b64_value` maps `A-Z`, `a-z`, `0-9`, `+` and `/` to their 6-bit values
/// and returns `0xffffffff` for every other byte (including `=`).
///
/// The `main` entry point (workgroup size 64) treats `gid.x` as a byte
/// offset. It returns early unless `offset + 3 < params.input_len` and all
/// four bytes starting at that offset are alphabet characters. Otherwise it
/// packs the four 6-bit values into three bytes and emits a region with
/// source length 4 and destination length 3.
pub const BASE64_WGSL_BODY: &str = r"
fn b64_value(byte: u32) -> u32 {
    if (byte >= 65u && byte <= 90u) { return byte - 65u; }
    if (byte >= 97u && byte <= 122u) { return byte - 71u; }
    if (byte >= 48u && byte <= 57u) { return byte + 4u; }
    if (byte == 43u) { return 62u; }
    if (byte == 47u) { return 63u; }
    return 0xffffffffu;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len) { return; }
    let a = b64_value(read_byte(offset));
    let b = b64_value(read_byte(offset + 1u));
    let c = b64_value(read_byte(offset + 2u));
    let d = b64_value(read_byte(offset + 3u));
    if (a == 0xffffffffu || b == 0xffffffffu || c == 0xffffffffu || d == 0xffffffffu) { return; }
    let out0 = ((a << 2u) | (b >> 4u)) & 0xffu;
    let out1 = (((b & 15u) << 4u) | (c >> 2u)) & 0xffu;
    let out2 = (((c & 3u) << 6u) | d) & 0xffu;
    emit_region(offset, 4u, 3u, out0, out1, out2);
}
";
154
155
/// WGSL body for hex decoding; appended to [`DECODE_WGSL_HEADER`].
///
/// The `main` entry point (workgroup size 64) treats `gid.x` as a byte
/// offset. It returns early unless `offset + 1 < params.input_len` and both
/// bytes at `offset` and `offset + 1` are hex digits. Otherwise it emits a
/// region with source length 2 and destination length 1 holding the combined
/// byte.
pub const HEX_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 1u >= params.input_len) { return; }
    let hi = hex_value(read_byte(offset));
    let lo = hex_value(read_byte(offset + 1u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 2u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
168
169
/// WGSL body for percent-decoding; appended to [`DECODE_WGSL_HEADER`].
///
/// The `main` entry point (workgroup size 64) treats `gid.x` as a byte
/// offset. It returns early unless `offset + 2 < params.input_len`, the byte
/// at `offset` is `%` (37) and the next two bytes are hex digits. Otherwise
/// it emits a region with source length 3 and destination length 1 holding
/// the decoded byte.
pub const URL_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 2u >= params.input_len || read_byte(offset) != 37u) { return; }
    let hi = hex_value(read_byte(offset + 1u));
    let lo = hex_value(read_byte(offset + 2u));
    if (hi == 0xffffffffu || lo == 0xffffffffu) { return; }
    emit_region(offset, 3u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
}
";
182
183
/// WGSL body for backslash-escape decoding; appended to
/// [`DECODE_WGSL_HEADER`].
///
/// The `main` entry point (workgroup size 64) treats `gid.x` as a byte
/// offset. It returns early unless `offset + 3 < params.input_len` and the
/// byte at `offset` is `\` (92). Two escape shapes are then handled:
///
/// - `\xHH` (second byte 120): when both digits are hex, emits a region with
///   source length 4 and destination length 1.
/// - `\uHHHH` (second byte 117, and `offset + 5 < params.input_len`): when
///   all four digits are hex, emits a region with source length 6 and
///   destination length 1.
///
/// NOTE(review): the `\uHHHH` path validates all four digits but emits only
/// the byte formed from the last two; the first two do not reach the output.
/// Confirm this truncation is intended.
pub const UNICODE_WGSL_BODY: &str = r"
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let offset = gid.x;
    if (offset + 3u >= params.input_len || read_byte(offset) != 92u) { return; }
    if (read_byte(offset + 1u) == 120u) {
        let hi = hex_value(read_byte(offset + 2u));
        let lo = hex_value(read_byte(offset + 3u));
        if (hi != 0xffffffffu && lo != 0xffffffffu) {
            emit_region(offset, 4u, 1u, ((hi << 4u) | lo) & 0xffu, 0u, 0u);
        }
        return;
    }
    if (offset + 5u >= params.input_len || read_byte(offset + 1u) != 117u) { return; }
    let h0 = hex_value(read_byte(offset + 2u));
    let h1 = hex_value(read_byte(offset + 3u));
    let h2 = hex_value(read_byte(offset + 4u));
    let h3 = hex_value(read_byte(offset + 5u));
    if (h0 == 0xffffffffu || h1 == 0xffffffffu || h2 == 0xffffffffu || h3 == 0xffffffffu) { return; }
    emit_region(offset, 6u, 1u, ((h2 << 4u) | h3) & 0xffu, 0u, 0u);
}
";