1use alloc::boxed::Box;
20use alloc::vec;
21use alloc::vec::Vec;
22use core::any::TypeId;
23use core::fmt::Display;
24use std::sync::Mutex;
25
26use wasefire_error::{Code, Error};
27
28use crate::internal::{Builtin, Rule, RuleEnum, RuleStruct, Rules, Wire};
29use crate::reader::Reader;
30use crate::{helper, internal};
31
/// Structural description of the wire format of a type.
///
/// Built from a type's schema with [`View::new`], normalized with [`View::simplify`], and used
/// to check encoded data with [`View::validate`].
#[derive(Debug, Clone, PartialEq, Eq, wasefire_wire_derive::Wire)]
#[wire(crate = crate)]
pub enum View<'a> {
    /// A builtin type.
    Builtin(Builtin),
    /// An array with the given element view and fixed length.
    Array(Box<View<'a>>, usize),
    /// A slice of the given element view (its length is not fixed by the view).
    Slice(Box<View<'a>>),
    /// An ordered sequence of fields, each with an optional name.
    Struct(ViewStruct<'a>),
    /// A tagged union of variants. An empty list describes an uninhabited type.
    Enum(ViewEnum<'a>),
    /// A reference to the enclosing `RecNew` binding the same identifier.
    RecUse(usize),
    /// Binds an identifier that `RecUse` can refer to within the given view.
    RecNew(usize, Box<View<'a>>),
}
43
/// Fields of a struct, in order: an optional name and the view of each field.
pub type ViewStruct<'a> = Vec<(Option<&'a str>, View<'a>)>;
/// Variants of an enum: the name (possibly empty), the tag, and the fields of each variant.
pub type ViewEnum<'a> = Vec<(&'a str, u32, ViewStruct<'a>)>;
46
47impl View<'static> {
48 pub fn new<'a, T: Wire<'a>>() -> View<'static> {
49 let mut rules = Rules::default();
50 T::schema(&mut rules);
51 Traverse::new(&rules).extract_or_empty(TypeId::of::<T::Type<'static>>())
52 }
53}
54
/// State of the conversion of schema [`Rules`] into a [`View`].
struct Traverse<'a> {
    /// Schema rules being traversed.
    rules: &'a Rules,
    /// Last recursion identifier handed out (so identifiers start at 1).
    next: usize,
    /// Types currently being extracted, outermost first, with the recursion identifier
    /// assigned once the type is found to refer to itself.
    path: Vec<(TypeId, Option<usize>)>,
}
60
61impl<'a> Traverse<'a> {
62 fn new(rules: &'a Rules) -> Self {
63 Traverse { rules, next: 0, path: Vec::new() }
64 }
65
66 fn extract_or_empty(&mut self, id: TypeId) -> View<'static> {
67 match self.extract(id) {
68 Some(x) => x,
69 None => View::Enum(Vec::new()),
70 }
71 }
72
73 fn extract(&mut self, id: TypeId) -> Option<View<'static>> {
74 if let Some((_, rec)) = self.path.iter_mut().find(|(x, _)| *x == id) {
75 let rec = rec.get_or_insert_with(|| {
76 self.next += 1;
77 self.next
78 });
79 return Some(View::RecUse(*rec));
80 }
81 self.path.push((id, None));
82 let result: Option<_> = try {
83 match self.rules.get(id) {
84 Rule::Builtin(x) => View::Builtin(*x),
85 Rule::Alias(_) => unreachable!(),
86 Rule::Array(x, n) => View::Array(Box::new(self.extract(*x)?), *n),
87 Rule::Slice(x) => View::Slice(Box::new(self.extract_or_empty(*x))),
88 Rule::Struct(xs) => View::Struct(self.extract_struct(xs)?),
89 Rule::Enum(xs) => View::Enum(self.extract_enum(xs)),
90 }
91 };
92 let (id_, rec) = self.path.pop().unwrap();
93 assert_eq!(id_, id);
94 let result = result?;
95 Some(match rec {
96 Some(rec) => View::RecNew(rec, Box::new(result)),
97 None => result,
98 })
99 }
100
101 fn extract_struct(&mut self, xs: &RuleStruct) -> Option<ViewStruct<'static>> {
102 xs.iter()
103 .map(|(n, x)| Some((*n, self.extract(*x)?)))
104 .filter(|x| !matches!(x, Some((None, View::Struct(xs))) if xs.is_empty()))
105 .collect()
106 }
107
108 fn extract_enum(&mut self, xs: &RuleEnum) -> ViewEnum<'static> {
109 xs.iter().filter_map(|(n, i, xs)| Some((*n, *i, self.extract_struct(xs)?))).collect()
110 }
111}
112
113impl core::fmt::Display for Builtin {
114 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
115 match self {
116 Builtin::Bool => write!(f, "bool"),
117 Builtin::U8 => write!(f, "u8"),
118 Builtin::I8 => write!(f, "i8"),
119 Builtin::U16 => write!(f, "u16"),
120 Builtin::I16 => write!(f, "i16"),
121 Builtin::U32 => write!(f, "u32"),
122 Builtin::I32 => write!(f, "i32"),
123 Builtin::U64 => write!(f, "u64"),
124 Builtin::I64 => write!(f, "i64"),
125 Builtin::Usize => write!(f, "usize"),
126 Builtin::Isize => write!(f, "isize"),
127 }
128 }
129}
130
131impl core::fmt::Display for View<'_> {
132 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
133 match self {
134 View::Builtin(x) => write!(f, "{x}"),
135 View::Array(x, n) => write!(f, "[{x}; {n}]"),
136 View::Slice(x) => write!(f, "[{x}]"),
137 View::Struct(xs) => write_fields(f, xs),
138 View::Enum(xs) => write_list(f, xs),
139 View::RecUse(n) => write!(f, "<{n}>"),
140 View::RecNew(n, x) => write!(f, "<{n}>:{x}"),
141 }
142 }
143}
144
/// Displays a list of struct fields the same way a `View::Struct` is displayed.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ViewFields<'a>(pub &'a ViewStruct<'a>);
147
148impl core::fmt::Display for ViewFields<'_> {
149 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
150 write_fields(f, self.0)
151 }
152}
153
/// Item of a list rendered by `write_list`: items are separated by spaces and enclosed in
/// `BEG` and `END`, each item being written as its name followed by its content.
trait List {
    /// Opening delimiter of the list.
    const BEG: char;
    /// Closing delimiter of the list.
    const END: char;
    /// Writes the prefix identifying the item (possibly nothing).
    fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
    /// Writes the content of the item.
    fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result;
}
160
161impl<'a> List for (Option<&'a str>, View<'a>) {
162 const BEG: char = '(';
163 const END: char = ')';
164 fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
165 match self.0 {
166 Some(x) => write!(f, "{x}:"),
167 None => Ok(()),
168 }
169 }
170 fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
171 self.1.fmt(f)
172 }
173}
174
175impl<'a> List for (&'a str, u32, ViewStruct<'a>) {
176 const BEG: char = '{';
177 const END: char = '}';
178 fn fmt_name(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
179 match self.0 {
180 "" => write!(f, "{}:", self.1),
181 _ => write!(f, "{}={}:", self.0, self.1),
182 }
183 }
184 fn fmt_item(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
185 write_fields(f, &self.2)
186 }
187}
188
189fn write_fields(f: &mut core::fmt::Formatter, xs: &[(Option<&str>, View)]) -> core::fmt::Result {
190 if xs.len() == 1 && xs[0].0.is_none() { xs[0].1.fmt(f) } else { write_list(f, xs) }
191}
192
193fn write_list<T: List>(f: &mut core::fmt::Formatter, xs: &[T]) -> core::fmt::Result {
194 write!(f, "{}", T::BEG)?;
195 let mut first = true;
196 for x in xs.iter() {
197 if !first {
198 write!(f, " ")?;
199 }
200 first = false;
201 x.fmt_name(f)?;
202 x.fmt_item(f)?;
203 }
204 write!(f, "{}", T::END)
205}
206
207impl View<'_> {
208 pub fn simplify(&self) -> View<'static> {
225 self.simplify_(RecStack::Root)
226 }
227
228 pub fn simplify_struct(xs: &ViewStruct) -> View<'static> {
229 View::simplify_struct_(xs, RecStack::Root)
230 }
231
232 fn simplify_(&self, rec: RecStack) -> View<'static> {
233 match self {
234 View::Builtin(x) => View::Builtin(*x),
235 View::Array(_, 0) => View::Struct(Vec::new()),
236 View::Array(x, 1) => x.simplify_(rec),
237 View::Array(x, n) => match x.simplify_(rec) {
238 View::Array(x, m) => View::Array(x, n * m),
239 View::Enum(xs) if xs.is_empty() => View::Enum(xs),
240 x => View::Array(Box::new(x), *n),
241 },
242 View::Slice(x) => View::Slice(Box::new(x.simplify_(rec))),
243 View::Struct(xs) => View::simplify_struct_(xs, rec),
244 View::Enum(xs) => {
245 let mut ys = Vec::new();
246 for (_, t, xs) in xs {
247 let xs = match View::simplify_struct_(xs, rec) {
248 View::Struct(xs) => xs,
249 View::Enum(xs) if xs.is_empty() => continue,
250 x => vec![(None, x)],
251 };
252 ys.push(("", *t, xs));
253 }
254 ys.sort_by_key(|(_, t, _)| *t);
255 View::Enum(ys)
256 }
257 View::RecUse(n) => View::RecUse(rec.use_(*n)),
258 View::RecNew(n, x) => View::RecNew(rec.len(), Box::new(x.simplify_(rec.new(*n)))),
259 }
260 }
261
262 fn simplify_struct_(xs: &ViewStruct, rec: RecStack) -> View<'static> {
263 let mut ys = Vec::new();
264 for (_, x) in xs {
265 match x.simplify_(rec) {
266 View::Struct(mut xs) => ys.append(&mut xs),
267 View::Enum(xs) if xs.is_empty() => return View::Enum(xs),
268 y => ys.push((None, y)),
269 }
270 }
271 let mut zs = Vec::new();
272 for (_, y) in ys {
273 let z = match zs.last_mut() {
274 Some((_, z)) => z,
275 None => {
276 zs.push((None, y));
277 continue;
278 }
279 };
280 match (z, y) {
281 (View::Array(x, n), View::Array(y, m)) if *x == y => *n += m,
282 (View::Array(x, n), y) if **x == y => *n += 1,
283 (x, View::Array(y, m)) if *x == *y => *x = View::Array(y, m + 1),
284 (x, y) if *x == y => *x = View::Array(Box::new(y), 2),
285 (_, y) => zs.push((None, y)),
286 }
287 }
288 match zs.len() {
289 1 => zs.pop().unwrap().1,
290 _ => View::Struct(zs),
291 }
292 }
293
294 pub fn validate(&self, data: &[u8]) -> Result<(), Error> {
298 static GLOBAL_LOCK: Mutex<()> = Mutex::new(());
299 let _global_lock = GLOBAL_LOCK.lock().unwrap();
300 let _lock = ViewFrameLock::new(None, self);
301 let _ = crate::decode::<ViewDecoder>(data)?;
302 Ok(())
303 }
304
305 fn decode(&self, reader: &mut Reader) -> Result<(), Error> {
306 match self {
307 View::Builtin(Builtin::Bool) => drop(bool::decode(reader)?),
308 View::Builtin(Builtin::U8) => drop(u8::decode(reader)?),
309 View::Builtin(Builtin::I8) => drop(i8::decode(reader)?),
310 View::Builtin(Builtin::U16) => drop(u16::decode(reader)?),
311 View::Builtin(Builtin::I16) => drop(i16::decode(reader)?),
312 View::Builtin(Builtin::U32) => drop(u32::decode(reader)?),
313 View::Builtin(Builtin::I32) => drop(i32::decode(reader)?),
314 View::Builtin(Builtin::U64) => drop(u64::decode(reader)?),
315 View::Builtin(Builtin::I64) => drop(i64::decode(reader)?),
316 View::Builtin(Builtin::Usize) => drop(usize::decode(reader)?),
317 View::Builtin(Builtin::Isize) => drop(isize::decode(reader)?),
318 View::Array(x, n) => {
319 let _lock = ViewFrameLock::new(None, x);
320 let _ = helper::decode_array_dyn(*n, reader, decode_view)?;
321 }
322 View::Slice(x) => {
323 let _lock = ViewFrameLock::new(None, x);
324 let _ = helper::decode_slice(reader, decode_view)?;
325 }
326 View::Struct(xs) => {
327 for (_, x) in xs {
328 x.decode(reader)?;
329 }
330 }
331 View::Enum(xs) => {
332 let tag = internal::decode_tag(reader)?;
333 let mut found = false;
334 for (_, i, xs) in xs {
335 if tag == *i {
336 assert!(!std::mem::replace(&mut found, true));
337 for (_, x) in xs {
338 x.decode(reader)?;
339 }
340 }
341 }
342 if !found {
343 return Err(Error::user(Code::InvalidArgument));
344 }
345 }
346 View::RecUse(rec) => {
347 let view = VIEW_STACK
348 .lock()
349 .unwrap()
350 .iter()
351 .find(|x| x.rec == Some(*rec))
352 .ok_or(Error::user(Code::InvalidArgument))?
353 .view;
354 view.decode(reader)?;
355 }
356 View::RecNew(rec, x) => {
357 let _lock = ViewFrameLock::new(Some(*rec), x);
358 x.decode(reader)?;
359 }
360 }
361 Ok(())
362 }
363}
364
/// Frames of the views being decoded by `View::validate`, innermost last.
///
/// `decode_view` decodes against the last frame, and `View::RecUse` looks frames up by
/// recursion identifier. This is global state because `ViewDecoder`'s `Wire::decode` and
/// `decode_view` only receive a reader.
static VIEW_STACK: Mutex<Vec<ViewFrame>> = Mutex::new(Vec::new());
366
/// Entry of `VIEW_STACK`: a view being decoded, with the recursion identifier it binds (if
/// any).
#[derive(Copy, Clone)]
struct ViewFrame {
    /// Recursion identifier bound by this frame (`Some` only for `View::RecNew` bodies).
    rec: Option<usize>,
    /// View to decode against. The `'static` lifetime is artificially extended by
    /// `ViewFrameLock::new`; it is only valid while the frame is on the stack.
    view: &'static View<'static>,
}
372
/// Guard for a frame pushed on `VIEW_STACK`, which pops it (checking it is the same frame)
/// when dropped.
struct ViewFrameLock(ViewFrame);
374
375impl ViewFrame {
376 fn key(self) -> (Option<usize>, *const View<'static>) {
377 (self.rec, self.view as *const _)
378 }
379}
380
381impl ViewFrameLock {
382 fn new(rec: Option<usize>, view: &View) -> Self {
383 #[expect(clippy::unnecessary_cast)]
385 let view = unsafe { &*(view as *const _ as *const View<'static>) };
387 let frame = ViewFrame { rec, view };
388 VIEW_STACK.lock().unwrap().push(frame);
389 ViewFrameLock(frame)
390 }
391}
392
393impl Drop for ViewFrameLock {
394 fn drop(&mut self) {
395 assert_eq!(VIEW_STACK.lock().unwrap().pop().unwrap().key(), self.0.key());
396 }
397}
398
/// Placeholder wire type whose decoding checks the input against the view on top of
/// `VIEW_STACK` (through `decode_view`).
struct ViewDecoder;
400
401impl<'a> internal::Wire<'a> for ViewDecoder {
402 type Type<'b> = ViewDecoder;
403 fn schema(_rules: &mut Rules) {
404 unreachable!()
405 }
406 fn encode(&self, _: &mut internal::Writer<'a>) -> internal::Result<()> {
407 unreachable!()
408 }
409 fn decode(reader: &mut Reader<'a>) -> internal::Result<Self> {
410 decode_view(reader)?;
411 Ok(ViewDecoder)
412 }
413}
414
415fn decode_view(reader: &mut Reader) -> Result<(), Error> {
416 let view = VIEW_STACK.lock().unwrap().last().unwrap().view;
417 view.decode(reader)
418}
419
/// Recursion binders enclosing the view being simplified, innermost first.
#[derive(Copy, Clone)]
enum RecStack<'a> {
    /// No enclosing binder.
    Root,
    /// A binder for the given original identifier, on top of the enclosing binders.
    Binder(usize, &'a RecStack<'a>),
}

impl<'a> RecStack<'a> {
    /// Returns the depth (number of binders below it) of the innermost binder of `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not bound.
    fn use_(&self, x: usize) -> usize {
        let mut stack: &RecStack<'a> = self;
        loop {
            match stack {
                RecStack::Root => unreachable!(),
                RecStack::Binder(y, outer) if *y == x => return outer.len(),
                RecStack::Binder(_, outer) => stack = *outer,
            }
        }
    }

    /// Returns the number of binders.
    fn len(&self) -> usize {
        let mut depth = 0;
        let mut stack: &RecStack<'a> = self;
        while let RecStack::Binder(_, outer) = stack {
            depth += 1;
            stack = *outer;
        }
        depth
    }

    /// Returns this stack with a binder for `x` pushed on top.
    #[allow(clippy::wrong_self_convention)]
    fn new(&'a self, x: usize) -> RecStack<'a> {
        RecStack::Binder(x, self)
    }
}