yazi_codegen/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Attribute, Data, DeriveInput, Fields, FnArg, ItemFn, ext::IdentExt, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn command(_: TokenStream, item: TokenStream) -> TokenStream {
7	let mut f: ItemFn = syn::parse(item).unwrap();
8	let mut ins = f.sig.inputs.clone();
9
10	// Turn `opt: Opt` into `opt: impl Into<Opt>`
11	ins[1] = {
12		let FnArg::Typed(opt) = &f.sig.inputs[1] else {
13			panic!("Cannot find the `opt` argument in the function signature.");
14		};
15
16		let opt_ty = &opt.ty;
17		syn::parse2(quote! { opt: impl Into<#opt_ty> }).unwrap()
18	};
19
20	// Make the original function private and add a public wrapper
21	assert!(matches!(f.vis, syn::Visibility::Public(_)));
22	f.vis = syn::Visibility::Inherited;
23
24	// Add `__` prefix to the original function name
25	let name_ori = f.sig.ident;
26	f.sig.ident = format_ident!("__{}", name_ori.unraw());
27	let name_new = &f.sig.ident;
28
29	// Collect the rest of the arguments
30	let rest_args = ins.iter().skip(2).map(|arg| match arg {
31		FnArg::Receiver(_) => unreachable!(),
32		FnArg::Typed(t) => &t.pat,
33	});
34
35	quote! {
36		#[inline]
37		pub fn #name_ori(#ins) { self.#name_new(opt.into(), #(#rest_args),*); }
38		#f
39	}
40	.into()
41}
42
43#[proc_macro_derive(DeserializeOver1)]
44pub fn deserialize_over1(input: TokenStream) -> TokenStream {
45	// Parse the input tokens into a syntax tree
46	let input = parse_macro_input!(input as DeriveInput);
47
48	// Get the name of the struct
49	let name = &input.ident;
50	let shadow_name = format_ident!("__{name}Shadow");
51
52	// Process the struct fields
53	let (shadow_fields, field_calls) = match &input.data {
54		Data::Struct(struct_) => match &struct_.fields {
55			Fields::Named(fields) => {
56				let mut shadow_fields = Vec::with_capacity(fields.named.len());
57				let mut field_calls = Vec::with_capacity(fields.named.len());
58
59				for field in &fields.named {
60					let name = &field.ident;
61					let attrs: Vec<&Attribute> =
62						field.attrs.iter().filter(|&a| a.path().is_ident("serde")).collect();
63
64					shadow_fields.push(quote! {
65							#(#attrs)*
66							pub(crate) #name: Option<toml::Value>
67					});
68					field_calls.push(quote! {
69						if let Some(value) = shadow.#name {
70							self.#name = self.#name.deserialize_over(value).map_err(serde::de::Error::custom)?;
71						}
72					});
73				}
74
75				(shadow_fields, field_calls)
76			}
77			_ => panic!("DeserializeOver1 only supports structs with named fields"),
78		},
79		_ => panic!("DeserializeOver1 only supports structs"),
80	};
81
82	quote! {
83		#[derive(serde::Deserialize)]
84		pub(crate) struct #shadow_name {
85			#(#shadow_fields),*
86		}
87
88		impl #name {
89			#[inline]
90			pub(crate) fn deserialize_over<'de, D>(self, deserializer: D) -> Result<Self, D::Error>
91			where
92				D: serde::Deserializer<'de>,
93			{
94				self.deserialize_over_with::<D>(Self::deserialize_shadow(deserializer)?)
95			}
96
97			#[inline]
98			pub(crate) fn deserialize_shadow<'de, D>(deserializer: D) -> Result<#shadow_name, D::Error>
99			where
100				D: serde::Deserializer<'de>,
101			{
102				#shadow_name::deserialize(deserializer)
103			}
104
105			#[inline]
106			pub(crate) fn deserialize_over_with<'de, D>(mut self, shadow: #shadow_name) -> Result<Self, D::Error>
107			where
108				D: serde::Deserializer<'de>,
109			{
110				#(#field_calls)*
111				Ok(self)
112			}
113		}
114	}
115	.into()
116}
117
118#[proc_macro_derive(DeserializeOver2)]
119pub fn deserialize_over2(input: TokenStream) -> TokenStream {
120	// Parse the input tokens into a syntax tree
121	let input = parse_macro_input!(input as DeriveInput);
122
123	// Get the name of the struct
124	let name = &input.ident;
125	let shadow_name = format_ident!("__{name}Shadow");
126
127	// Process the struct fields
128	let (shadow_fields, field_assignments) = match &input.data {
129		Data::Struct(struct_) => match &struct_.fields {
130			Fields::Named(fields) => {
131				let mut shadow_fields = Vec::with_capacity(fields.named.len());
132				let mut field_assignments = Vec::with_capacity(fields.named.len());
133
134				for field in &fields.named {
135					let (ty, name) = (&field.ty, &field.ident);
136					shadow_fields.push(quote! {
137						pub(crate) #name: Option<#ty>
138					});
139					field_assignments.push(quote! {
140						if let Some(value) = shadow.#name {
141							self.#name = value;
142						}
143					});
144				}
145
146				(shadow_fields, field_assignments)
147			}
148			_ => panic!("DeserializeOver2 only supports structs with named fields"),
149		},
150		_ => panic!("DeserializeOver2 only supports structs"),
151	};
152
153	quote! {
154		#[derive(serde::Deserialize)]
155		pub(crate) struct #shadow_name {
156			#(#shadow_fields),*
157		}
158
159		impl #name {
160			#[inline]
161			pub(crate) fn deserialize_over<'de, D>(mut self, deserializer: D) -> Result<Self, D::Error>
162			where
163				D: serde::Deserializer<'de>
164			{
165				Ok(self.deserialize_over_with(Self::deserialize_shadow(deserializer)?))
166			}
167
168			#[inline]
169			pub(crate) fn deserialize_shadow<'de, D>(deserializer: D) -> Result<#shadow_name, D::Error>
170			where
171				D: serde::Deserializer<'de>
172			{
173				#shadow_name::deserialize(deserializer)
174			}
175
176			#[inline]
177			pub(crate) fn deserialize_over_with(mut self, shadow: #shadow_name) -> Self {
178				#(#field_assignments)*
179				self
180			}
181		}
182	}
183	.into()
184}