1use std::collections::{BTreeMap, HashMap};
2
3use darling::util::{SpannedValue, parse_expr};
4use darling::{Error, FromMeta};
5use proc_macro::TokenStream;
6use quote::{ToTokens, quote};
7use syn::ItemImpl;
8
9#[derive(Debug, FromMeta)]
10struct VmInstrArgs {
11 code: SpannedValue<String>,
12
13 #[darling(with = parse_expr::preserve_str_literal)]
14 fmt: syn::Expr,
15
16 #[darling(default)]
17 args: HashMap<String, syn::Expr>,
18 #[darling(default)]
19 cond: Option<syn::Expr>,
20}
21
22#[derive(Debug, FromMeta)]
23struct VmExtInstrArgs {
24 code: syn::Expr,
25 code_bits: syn::Expr,
26 arg_bits: syn::Expr,
27 dump_with: syn::Expr,
28}
29
30#[derive(Debug, FromMeta)]
31struct VmExtRangeInstrArgs {
32 code_min: syn::Expr,
33 code_max: syn::Expr,
34 total_bits: syn::Expr,
35 dump_with: syn::Expr,
36}
37
38#[proc_macro_attribute]
39pub fn vm_module(_: TokenStream, input: TokenStream) -> TokenStream {
40 let mut input = syn::parse_macro_input!(input as ItemImpl);
41
42 let opcodes_arg = quote::format_ident!("__t");
43
44 let mut definitions = Vec::new();
45 let mut errors = Vec::new();
46
47 let mut init_function_names = Vec::new();
48 let mut init_functions = Vec::new();
49 let mut other_functions = Vec::new();
50
51 let mut opcodes = Opcodes::default();
52
53 for impl_item in input.items.drain(..) {
54 let syn::ImplItem::Fn(mut fun) = impl_item else {
55 other_functions.push(impl_item);
56 continue;
57 };
58
59 let mut has_init = false;
60
61 let mut instr_attrs = Vec::with_capacity(fun.attrs.len());
62 let mut ext_instr_attrs = Vec::new();
63 let mut ext_range_instr_attrs = Vec::new();
64 let mut remaining_attr = Vec::new();
65 for attr in fun.attrs.drain(..) {
66 if let Some(path) = attr.meta.path().get_ident() {
67 if path == "op" {
68 instr_attrs.push(attr);
69 continue;
70 } else if path == "op_ext" {
71 ext_instr_attrs.push(attr);
72 continue;
73 } else if path == "op_ext_range" {
74 ext_range_instr_attrs.push(attr);
75 continue;
76 } else if path == "init" {
77 has_init = true;
78 continue;
79 }
80 }
81
82 remaining_attr.push(attr);
83 }
84 fun.attrs = remaining_attr;
85
86 if has_init {
87 fun.sig.ident = quote::format_ident!("__{}", fun.sig.ident);
88 init_function_names.push(fun.sig.ident.clone());
89 init_functions.push(fun);
90 } else {
91 for attr in instr_attrs {
92 match process_instr_definition(&fun, &opcodes_arg, &attr, &mut opcodes) {
93 Ok(definition) => definitions.push(definition),
94 Err(e) => errors.push(e.with_span(&attr)),
95 }
96 }
97
98 for attr in ext_instr_attrs {
99 match process_ext_instr_definition(&fun, &opcodes_arg, &attr) {
100 Ok(definition) => definitions.push(definition),
101 Err(e) => errors.push(e.with_span(&attr)),
102 }
103 }
104
105 for attr in ext_range_instr_attrs {
106 match process_ext_range_instr_definition(&fun, &opcodes_arg, &attr) {
107 Ok(definition) => definitions.push(definition),
108 Err(e) => errors.push(e.with_span(&attr)),
109 }
110 }
111
112 other_functions.push(syn::ImplItem::Fn(fun));
113 }
114 }
115
116 if !errors.is_empty() {
117 return TokenStream::from(Error::multiple(errors).write_errors());
118 }
119
120 let ty = input.self_ty;
121 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
122
123 quote! {
124 impl #impl_generics #ty #ty_generics #where_clause {
125 #(#init_functions)*
126 }
127
128 #[automatically_derived]
129 impl #impl_generics ::tycho_vm::instr::Module for #ty #ty_generics #where_clause {
130 fn init(&self, #opcodes_arg: &mut ::tycho_vm::dispatch::Opcodes) -> ::anyhow::Result<()> {
131 #(self.#init_function_names(#opcodes_arg)?;)*
132 #(#definitions)*
133 Ok(())
134 }
135 }
136
137 #(#other_functions)*
138 }
139 .into()
140}
141
/// The `code` string of an `#[op]` attribute split into its parts.
///
/// The accepted form is `<pattern>` or `<pattern> @ <from>..<to>`, where
/// either bound may be omitted.
struct ParsedCode<'a> {
    /// Opcode pattern (everything before `@`, or the whole string).
    code: &'a str,
    /// Text before `..`, or `None` when it is empty or there is no `@`.
    range_from: Option<&'a str>,
    /// Text after `..`, or `None` when it is empty or there is no `@`.
    range_to: Option<&'a str>,
}
147
148impl<'a> ParsedCode<'a> {
149 fn from_str(s: &'a SpannedValue<String>) -> Result<Self, Error> {
150 match s.split_once('@') {
151 None => Ok(Self {
152 code: s.as_str(),
153 range_from: None,
154 range_to: None,
155 }),
156 Some((code, range)) => {
157 let Some((range_from, range_to)) = range.split_once("..") else {
158 return Err(
159 Error::custom("expected an opcode range after `@`").with_span(&s.span())
160 );
161 };
162 let range_from = range_from.trim();
163 let range_to = range_to.trim();
164
165 Ok(Self {
166 code: code.trim(),
167 range_from: (!range_from.is_empty()).then_some(range_from),
168 range_to: (!range_to.is_empty()).then_some(range_to),
169 })
170 }
171 }
172 }
173}
174
175fn process_instr_definition(
176 function: &syn::ImplItemFn,
177 opcodes_arg: &syn::Ident,
178 attr: &syn::Attribute,
179 opcodes: &mut Opcodes,
180) -> Result<syn::Expr, Error> {
181 let mut instr = VmInstrArgs::from_meta(&attr.meta)?;
182 let parsed = ParsedCode::from_str(&instr.code)?;
183
184 let mut opcode_bits = 0u16;
185 let mut opcode_base_min = 0;
186 let mut binary_mode = false;
187 let mut args = Vec::<(char, u16)>::new();
188 for c in parsed.code.chars() {
189 if c.is_whitespace() || c == '_' {
190 continue;
191 }
192 if c == '$' {
193 binary_mode = true;
194 }
195
196 match c {
197 '$' => {
198 binary_mode = true;
199 continue;
200 }
201 '#' => {
202 binary_mode = false;
203 continue;
204 }
205 c if c.is_whitespace() || c == '_' => {
206 continue;
207 }
208 c if c.is_ascii_alphanumeric() => {}
209 _ => {
210 return Err(
211 Error::custom("Invalid pattern for the opcode").with_span(&instr.code.span())
212 );
213 }
214 }
215
216 let (radix, symbol_bits) = if binary_mode { (2, 1) } else { (16, 4) };
217
218 opcode_base_min <<= symbol_bits;
219
220 if let Some(c) = c.to_digit(radix) {
221 if !args.is_empty() {
222 return Err(
223 Error::custom("Invalid pattern for the opcode").with_span(&instr.code.span())
224 );
225 }
226
227 opcode_bits += symbol_bits;
228 opcode_base_min |= c;
229 } else {
230 if let Some((last, last_bits)) = args.last_mut()
231 && *last == c
232 {
233 *last_bits += symbol_bits;
234 continue;
235 }
236
237 args.push((c, symbol_bits));
238 }
239 }
240 let arg_bits = args.iter().map(|(_, bits)| bits).sum::<u16>();
241
242 if opcode_bits == 0 {
243 return Err(Error::custom("Opcode must have a non-empty fixed prefix")
244 .with_span(&instr.code.span()));
245 }
246
247 let total_bits = opcode_bits + arg_bits;
248 if total_bits as usize > MAX_OPCODE_BITS {
249 return Err(Error::custom(format!(
250 "Too much bits for the opcode: {opcode_bits}/{MAX_OPCODE_BITS}"
251 ))
252 .with_span(&instr.code.span()));
253 }
254 let n = (total_bits / 4) as usize;
255
256 let opcode_base_max = (opcode_base_min | ((1 << arg_bits) - 1)) + 1;
257
258 let remaining_bits = MAX_OPCODE_BITS - total_bits as usize;
259
260 let mut range = OpcodeRange {
261 span: instr.code.span(),
262 aligned_opcode_min: opcode_base_min << remaining_bits,
263 aligned_opcode_max: opcode_base_max << remaining_bits,
264 total_bits,
265 };
266
267 let function_name = function.sig.ident.clone();
268 let fmt = match instr.fmt {
269 syn::Expr::Tuple(items) => items.elems.into_token_stream(),
270 syn::Expr::Lit(expr) if matches!(&expr.lit, syn::Lit::Str(..)) => expr.into_token_stream(),
271 fmt => quote! { "{}", #fmt },
272 };
273
274 let ty = match (!args.is_empty(), parsed.range_from, parsed.range_to) {
275 (false, range_from, range_to) => {
276 let mut errors = Vec::new();
277 if range_from.is_some() {
278 errors.push(
279 Error::custom("Unexpected `range_from` for a simple opcode")
280 .with_span(&instr.code.span()),
281 );
282 }
283 if range_to.is_some() {
284 errors.push(
285 Error::custom("Unexpected `range_to` for a simple opcode")
286 .with_span(&instr.code.span()),
287 );
288 }
289
290 if errors.is_empty() {
291 opcodes.add_opcode(range)?;
292 OpcodeTy::Simple {
293 opcode: opcode_base_min,
294 bits: opcode_bits,
295 }
296 } else {
297 return Err(Error::multiple(errors));
298 }
299 }
300 (true, None, None) => {
301 opcodes.add_opcode(range)?;
302 OpcodeTy::Fixed {
303 opcode: opcode_base_min >> arg_bits,
304 opcode_bits,
305 arg_bits,
306 }
307 }
308 (true, range_from, range_to) => {
309 let opcode_min = if let Some(range_from) = range_from {
310 let range_from_span = &instr.code.span();
311
312 let range_from_bits = range_from.len() * 4;
313 let range_from = u32::from_str_radix(range_from, 16).map_err(|e| {
314 Error::custom(format!("Invalid `range_from` value: {e}"))
315 .with_span(range_from_span)
316 })?;
317
318 if range_from_bits != total_bits as usize {
319 return Err(Error::custom(format!(
320 "Invalid `range_from` size in bits. Expected {total_bits}, got {range_from_bits}",
321 )).with_span(range_from_span));
322 }
323 if range_from <= opcode_base_min {
324 return Err(Error::custom(format!(
325 "`range_from` must be greater than opcode base. Opcode base: {:0n$x}",
326 opcode_base_min >> arg_bits
327 ))
328 .with_span(range_from_span));
329 }
330 if range_from >= opcode_base_max {
331 return Err(Error::custom(format!(
332 "`range_from` must be less than opcode max value. Opcode max value: {:0n$x}",
333 opcode_base_max >> arg_bits
334 ))
335 .with_span(range_from_span));
336 }
337
338 range.aligned_opcode_min = range_from << remaining_bits;
339 range_from
340 } else {
341 opcode_base_min
342 };
343
344 let opcode_max = if let Some(range_to) = range_to {
345 let range_to_span = &instr.code.span();
346
347 let range_to_bits = range_to.len() * 4;
348 let range_to = u32::from_str_radix(range_to, 16).map_err(|e| {
349 Error::custom(format!("Invalid `range_to` value: {e}")).with_span(range_to_span)
350 })?;
351
352 if range_to_bits != total_bits as usize {
353 return Err(Error::custom(format!(
354 "Invalid `range_to` size in bits. Expected {total_bits}, got {range_to_bits}",
355 ))
356 .with_span(range_to_span));
357 }
358 if range_to <= opcode_min {
359 return Err(Error::custom(format!(
360 "`range_to` must be greater than opcode base. Opcode base: {:0n$x}",
361 opcode_min >> arg_bits
362 ))
363 .with_span(range_to_span));
364 }
365 if range_to >= opcode_base_max {
366 return Err(Error::custom(format!(
367 "`range_to` must be less than opcode max value. Opcode max value: {:0n$x}",
368 opcode_base_max >> arg_bits
369 ))
370 .with_span(range_to_span));
371 }
372
373 range.aligned_opcode_max = range_to << remaining_bits;
374 range_to
375 } else {
376 opcode_base_max
377 };
378
379 opcodes.add_opcode(range)?;
380
381 OpcodeTy::FixedRange {
382 opcode_min,
383 opcode_max,
384 total_bits,
385 arg_bits,
386 }
387 }
388 };
389
390 let (arg_definitions, arg_idents) = {
391 let mut shift = arg_bits as u32;
392
393 let function_arg_count = function.sig.inputs.len().saturating_sub(1);
394
395 let mut errors = Vec::new();
396 let mut opcode_args = args.iter().peekable();
397 let mut arg_definitions = Vec::with_capacity(function_arg_count);
398 let mut arg_idents = Vec::with_capacity(function_arg_count);
399
400 #[allow(clippy::never_loop)] for function_arg in function.sig.inputs.iter().skip(1) {
402 let ty;
403 let name = if let syn::FnArg::Typed(input) = function_arg
404 && let syn::Pat::Ident(pat) = &*input.pat
405 {
406 ty = &input.ty;
407 pat.ident.to_string()
408 } else {
409 return Err(Error::custom("Unsupported argument binding").with_span(&function_arg));
410 };
411
412 let explicit_arg = instr.args.remove(&name);
413
414 match opcode_args.peek() {
415 Some((opcode_arg, bits)) => {
416 if opcode_arg.to_string() != name {
417 if let Some(expr) = explicit_arg {
418 let ident = quote::format_ident!("{name}");
419 arg_definitions.push(quote! { let #ident: #ty = #expr; });
420 arg_idents.push(ident);
421 continue;
422 }
423
424 return Err(Error::custom(format!("Expected argument `{opcode_arg}`"))
425 .with_span(&function_arg));
426 }
427
428 let ident = quote::format_ident!("{name}");
429
430 shift -= *bits as u32;
431 arg_definitions.push(match explicit_arg {
432 None if *bits == 1 => {
433 quote! { let #ident: #ty = (args >> #shift) & 0b1 != 0; }
434 }
435 None => {
436 let mask = (1u32 << *bits) - 1;
437 quote! { let #ident: #ty = (args >> #shift) & #mask; }
438 }
439 Some(expr) => {
440 quote! { let #ident: #ty = #expr; }
441 }
442 });
443 arg_idents.push(ident);
444
445 opcode_args.next();
446 }
447 None => match explicit_arg {
448 Some(expr) => {
449 let ident = quote::format_ident!("{name}");
450 arg_definitions.push(quote! { let #ident: #ty = #expr; });
451 arg_idents.push(ident);
452 }
453 None => {
454 errors.push(Error::custom("Unexpected argument").with_span(&function_arg));
455 }
456 },
457 }
458 }
459
460 for (unused_arg, _) in opcode_args {
461 errors.push(
462 Error::custom(format_args!("Unused opcode arg `{unused_arg}`"))
463 .with_span(&instr.code.span()),
464 )
465 }
466 for (unused_arg, expr) in instr.args {
467 errors.push(
468 Error::custom(format_args!("Unused arg override for {unused_arg}"))
469 .with_span(&expr),
470 )
471 }
472 if !errors.is_empty() {
473 return Err(Error::multiple(errors));
474 }
475
476 (arg_definitions, arg_idents)
477 };
478
479 let wrapper_func_name = quote::format_ident!("{function_name}_wrapper");
480
481 #[cfg(feature = "dump")]
482 let dump_func_name = quote::format_ident!("dump_{function_name}");
483 #[cfg(feature = "dump")]
484 let dump_func;
485
486 let wrapper_func = match &ty {
487 OpcodeTy::Simple { .. } => {
488 if let Some(cond) = instr.cond {
489 return Err(
490 Error::custom("Unexpected condition for simple opcode").with_span(&cond)
491 );
492 }
493
494 #[cfg(feature = "dump")]
495 {
496 dump_func = quote! {
497 fn #dump_func_name(__f: &mut dyn ::tycho_vm::DumpOutput) -> ::tycho_vm::error::DumpResult {
498 #(#arg_definitions)*
499 __f.record_opcode(&format_args!(#fmt))
500 }
501 };
502 }
503
504 quote! {
505 fn #wrapper_func_name(st: &mut ::tycho_vm::state::VmState) -> ::tycho_vm::error::VmResult<i32> {
506 #(#arg_definitions)*
507 vm_log_op!(#fmt);
508 #function_name(st, #(#arg_idents),*)
509 }
510 }
511 }
512 OpcodeTy::Fixed { .. } | OpcodeTy::FixedRange { .. } => {
513 let cond = instr.cond.as_ref().map(|cond| {
514 quote! { vm_ensure!(#cond, InvalidOpcode); }
515 });
516
517 #[cfg(feature = "dump")]
518 {
519 let dump_cond = instr.cond.map(|cond| {
520 quote! {
521 if crate::__private::not(#cond) {
522 return Err(::tycho_vm::error::DumpError::InvalidOpcode);
523 }
524 }
525 });
526
527 dump_func = quote! {
528 fn #dump_func_name(args: u32, __f: &mut dyn ::tycho_vm::DumpOutput) -> ::tycho_vm::error::DumpResult {
529 #(#arg_definitions)*
530 #dump_cond
531 __f.record_opcode(&format_args!(#fmt))
532 }
533 };
534 }
535
536 quote! {
537 fn #wrapper_func_name(st: &mut ::tycho_vm::state::VmState, args: u32) -> ::tycho_vm::error::VmResult<i32> {
538 #(#arg_definitions)*
539 #cond
540 vm_log_op!(#fmt);
541 #function_name(st, #(#arg_idents),*)
542 }
543 }
544 }
545 };
546
547 let expr_add = match ty {
548 #[cfg(feature = "dump")]
549 OpcodeTy::Simple { opcode, bits } => quote! {
550 #opcodes_arg.add_simple(#opcode, #bits, #wrapper_func_name, #dump_func_name)
551 },
552 #[cfg(not(feature = "dump"))]
553 OpcodeTy::Simple { opcode, bits } => quote! {
554 #opcodes_arg.add_simple(#opcode, #bits, #wrapper_func_name)
555 },
556 #[cfg(feature = "dump")]
557 OpcodeTy::Fixed {
558 opcode,
559 opcode_bits,
560 arg_bits,
561 } => quote! {
562 #opcodes_arg.add_fixed(
563 #opcode,
564 #opcode_bits,
565 #arg_bits,
566 #wrapper_func_name,
567 #dump_func_name,
568 )
569 },
570 #[cfg(not(feature = "dump"))]
571 OpcodeTy::Fixed {
572 opcode,
573 opcode_bits,
574 arg_bits,
575 } => quote! {
576 #opcodes_arg.add_fixed(#opcode, #opcode_bits, #arg_bits, #wrapper_func_name)
577 },
578 #[cfg(feature = "dump")]
579 OpcodeTy::FixedRange {
580 opcode_min,
581 opcode_max,
582 total_bits,
583 arg_bits,
584 } => quote! {
585 #opcodes_arg.add_fixed_range(
586 #opcode_min,
587 #opcode_max,
588 #total_bits,
589 #arg_bits,
590 #wrapper_func_name,
591 #dump_func_name,
592 )
593 },
594 #[cfg(not(feature = "dump"))]
595 OpcodeTy::FixedRange {
596 opcode_min,
597 opcode_max,
598 total_bits,
599 arg_bits,
600 } => quote! {
601 #opcodes_arg.add_fixed_range(
602 #opcode_min,
603 #opcode_max,
604 #total_bits,
605 #arg_bits,
606 #wrapper_func_name
607 )
608 },
609 };
610
611 #[cfg(feature = "dump")]
612 {
613 Ok(syn::parse_quote! {{
614 #dump_func
615 #wrapper_func
616 #expr_add?;
617 }})
618 }
619
620 #[cfg(not(feature = "dump"))]
621 {
622 Ok(syn::parse_quote! {{
623 #wrapper_func
624 #expr_add?;
625 }})
626 }
627}
628
629fn process_ext_instr_definition(
630 function: &syn::ImplItemFn,
631 opcodes_arg: &syn::Ident,
632 attr: &syn::Attribute,
633) -> Result<syn::Expr, Error> {
634 let VmExtInstrArgs {
635 code,
636 code_bits,
637 arg_bits,
638 dump_with,
639 } = <_>::from_meta(&attr.meta)?;
640
641 let function_name = &function.sig.ident;
642
643 #[cfg(feature = "dump")]
644 {
645 Ok(syn::parse_quote!({
646 #opcodes_arg.add_ext(#code, #code_bits, #arg_bits, #function_name, #dump_with)?;
647 }))
648 }
649
650 #[cfg(not(feature = "dump"))]
651 {
652 _ = dump_with;
653
654 Ok(syn::parse_quote!({
655 #opcodes_arg.add_ext(#code, #code_bits, #arg_bits, #function_name)?;
656 }))
657 }
658}
659
660fn process_ext_range_instr_definition(
661 function: &syn::ImplItemFn,
662 opcodes_arg: &syn::Ident,
663 attr: &syn::Attribute,
664) -> Result<syn::Expr, Error> {
665 let VmExtRangeInstrArgs {
666 code_min,
667 code_max,
668 total_bits,
669 dump_with,
670 } = <_>::from_meta(&attr.meta)?;
671
672 let function_name = &function.sig.ident;
673
674 #[cfg(feature = "dump")]
675 {
676 Ok(syn::parse_quote!({
677 #opcodes_arg.add_ext_range(#code_min, #code_max, #total_bits, #function_name, #dump_with)?;
678 }))
679 }
680
681 #[cfg(not(feature = "dump"))]
682 {
683 _ = dump_with;
684
685 Ok(syn::parse_quote!({
686 #opcodes_arg.add_ext_range(#code_min, #code_max, #total_bits, #function_name)?;
687 }))
688 }
689}
690
/// Kind of registration produced for an `#[op]` attribute.
enum OpcodeTy {
    /// Opcode without encoded arguments; registered with `add_simple`.
    Simple {
        /// Value of the fixed prefix.
        opcode: u32,
        /// Width of the fixed prefix in bits.
        bits: u16,
    },
    /// Opcode with encoded arguments covering the whole argument space;
    /// registered with `add_fixed`.
    Fixed {
        /// Value of the fixed prefix (argument bits shifted out).
        opcode: u32,
        /// Width of the fixed prefix in bits.
        opcode_bits: u16,
        /// Combined width of all arguments in bits.
        arg_bits: u16,
    },
    /// Opcode with encoded arguments restricted by an explicit `@from..to`
    /// range; registered with `add_fixed_range`.
    FixedRange {
        /// Lower bound, including the argument bits.
        opcode_min: u32,
        /// Upper bound, including the argument bits.
        opcode_max: u32,
        /// Width of prefix plus arguments in bits.
        total_bits: u16,
        /// Combined width of all arguments in bits.
        arg_bits: u16,
    },
}
708
709struct OpcodeRange {
710 span: proc_macro2::Span,
711 aligned_opcode_min: u32,
712 aligned_opcode_max: u32,
713 total_bits: u16,
714}
715
716#[derive(Default)]
717struct Opcodes {
718 opcodes: BTreeMap<u32, OpcodeRange>,
719}
720
721impl Opcodes {
722 fn add_opcode(&mut self, range: OpcodeRange) -> Result<(), Error> {
723 assert!(range.aligned_opcode_min < range.aligned_opcode_max);
724 assert!(range.aligned_opcode_max <= MAX_OPCODE);
725
726 if let Some((other_min, other)) = self.opcodes.range(range.aligned_opcode_min..).next()
727 && range.aligned_opcode_max > *other_min
728 {
729 let shift = MAX_OPCODE_BITS - other.total_bits as usize;
730 let other_min = other.aligned_opcode_min >> shift;
731 let other_max = other.aligned_opcode_max >> shift;
732 let n = other.total_bits as usize / 4;
733
734 return Err(Error::custom(format!(
735 "Opcode overlaps with the start of the range of another opcode: \
736 {other_min:0n$x}..{other_max:0n$x}"
737 ))
738 .with_span(&range.span));
739 }
740
741 if let Some((k, prev)) = self.opcodes.range(..=range.aligned_opcode_min).next_back() {
742 debug_assert!(prev.aligned_opcode_min < prev.aligned_opcode_max);
743 debug_assert!(prev.aligned_opcode_min == *k);
744 if range.aligned_opcode_min < prev.aligned_opcode_max {
745 let shift = MAX_OPCODE_BITS - prev.total_bits as usize;
746 let prev_min = prev.aligned_opcode_min >> shift;
747 let prev_max = prev.aligned_opcode_max >> shift;
748 let n = prev.total_bits as usize / 4;
749
750 return Err(Error::custom(format!(
751 "Opcode overlaps with the end of the range of another opcode: \
752 {prev_min:0n$x}..{prev_max:0n$x}"
753 ))
754 .with_span(&range.span));
755 }
756 }
757
758 self.opcodes.insert(range.aligned_opcode_min, range);
759 Ok(())
760 }
761}
762
/// Maximum width of an opcode (fixed prefix plus arguments) in bits.
const MAX_OPCODE_BITS: usize = 24;
/// Exclusive upper bound of a left-aligned opcode value.
const MAX_OPCODE: u32 = 1u32 << MAX_OPCODE_BITS;