sync_lsp_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::{TokenStream, Span};
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::token::{Paren, Brace, Colon, Bracket, Semi, Eq};
7use quote::{quote, ToTokens};
8use syn::{
9    parse_macro_input, DeriveInput, Data, Fields, Ident, Meta, Expr, Lit, ExprLit,
10    ExprAssign, ExprPath, Generics, Variant, Result, Error, PathSegment,
11    FieldsUnnamed, Visibility, Pat, PatStruct, Path, PatTupleStruct, PatIdent,
12    FieldPat, Member, Field, Type, TypeArray, LitInt, FieldMutability, PatSlice,
13    ItemImpl, ImplItemType, TypeNever, TypePath, ImplItem
14};
15
16/// This macro provides default implementations for all required types in `TypeProvider`.
17/// 
18/// # Example
19/// ```
20/// use sync_lsp::{TypeProvider, type_provider};
21/// 
22/// struct MyServerState;
23/// 
24/// #[type_provider]
25/// impl TypeProvider for MyServerState {
26///     type ShowMessageRequestData = u32;
27///     // All other types will be set to `Option<()>`
28/// }
29/// ```
30#[proc_macro_attribute]
31pub fn type_provider(args: TokenStream, input: TokenStream) -> TokenStream {
32    if !args.is_empty() {
33        return Error::new(
34            Span::call_site().into(),
35            "The type_provider attribute does not take any arguments"
36        ).to_compile_error().into()
37    }
38
39    let mut input = parse_macro_input!(input as ItemImpl);
40    let mut types = Vec::new();
41
42    let unit = Type::Verbatim(quote! {
43        std::option::Option::<()>
44    });
45
46    let default = ImplItemType {
47        attrs: Vec::new(),
48        vis: Visibility::Inherited,
49        defaultness: None,
50        type_token: Default::default(),
51        ident: Ident::new("Default", Span::call_site().into()),
52        eq_token: Eq::default(),
53        ty: Type::Never(TypeNever {
54            bang_token: Default::default()
55        }),
56        semi_token: Semi::default(),
57        generics: Generics::default()
58    };
59
60    types.push(ImplItem::Type(ImplItemType {
61        ident: Ident::new("Command", Span::call_site().into()),
62        ty: Type::Path(TypePath {
63            qself: None,
64            path: {
65                let mut unit = Path {
66                    leading_colon: Some(Default::default()),
67                    segments: Punctuated::new()
68                };
69
70                unit.segments.push(PathSegment::from(Ident::new("sync_lsp", Span::call_site().into())));
71                unit.segments.push(PathSegment::from(Ident::new("workspace", Span::call_site().into())));
72                unit.segments.push(PathSegment::from(Ident::new("execute_command", Span::call_site().into())));
73                unit.segments.push(PathSegment::from(Ident::new("UnitCommand", Span::call_site().into())));
74                unit
75            }
76        }),
77        ..default.clone()
78    }));
79
80    types.push(ImplItem::Type(ImplItemType {
81        ident: Ident::new("CodeLensData", Span::call_site().into()),
82        ty: unit.clone(),
83        ..default.clone()
84    }));
85
86    types.push(ImplItem::Type(ImplItemType {
87        ident: Ident::new("CompletionData", Span::call_site().into()),
88        ty: unit.clone(),
89        ..default.clone()
90    }));
91
92    types.push(ImplItem::Type(ImplItemType {
93        ident: Ident::new("Configuration", Span::call_site().into()),
94        ty: unit.clone(),
95        ..default.clone()
96    }));
97
98    types.push(ImplItem::Type(ImplItemType {
99        ident: Ident::new("InitializeOptions", Span::call_site().into()),
100        ty: unit.clone(),
101        ..default.clone()
102    }));
103
104    types.push(ImplItem::Type(ImplItemType {
105        ident: Ident::new("ShowMessageRequestData", Span::call_site().into()),
106        ty: unit.clone(),
107        ..default.clone()
108    }));
109
110    types.push(ImplItem::Type(ImplItemType {
111        ident: Ident::new("ApplyEditData", Span::call_site().into()),
112        ty: unit.clone(),
113        ..default.clone()
114    }));
115
116    for item in input.items.iter() {
117        let ImplItem::Type(item) = item else { continue };
118        types.retain(|r#type| {
119            let ImplItem::Type(r#type) = r#type else { return false };
120            r#type.ident != item.ident
121        });
122    }
123
124    input.items.extend_from_slice(&types);
125
126    input.into_token_stream().into()
127}
128
129/// This macro implements the [`Command`] trait for a given type.
130/// the `#[command(title = "...")]` attribute can be used to define the title of the command
131/// on enum variants or structs.
132/// 
133/// # Example
134/// ```
135/// use sync_lsp::workspace::execute_command::Command;
136/// 
137/// #[derive(Clone, Command)]
138/// #[command(title = "My command without variants or arguments")]
139/// struct MyCommand;
140/// ```
141/// ```
142/// use sync_lsp::workspace::execute_command::Command;
143/// 
144/// #[derive(Clone, Command)]
145/// enum MyCommand {
146///     #[command(title = "My first command")]
147///     MyCommand,
148///     #[command(title = "My command with arguments")]
149///     MyCommandWithArguments(u32),
150/// }
151/// ```
152#[proc_macro_derive(Command, attributes(command))]
153pub fn command(input: TokenStream) -> TokenStream {
154    let input = parse_macro_input!(input as DeriveInput);
155    match input.data {
156        Data::Struct(data) => {
157            let variant = Variant {
158                attrs: input.attrs,
159                ident: Ident::new("Target", Span::call_site().into()),
160                fields: data.fields,
161                discriminant: None
162            };
163            impl_enum(input.ident, input.generics, vec![variant], None)
164        },
165        Data::Enum(data) => {
166            for attr in input.attrs.iter() {
167                let Meta::List(list) = &attr.meta else { continue };
168                if list.path.is_ident("command") {
169                    return Error::new(
170                        attr.span(), 
171                        "The command attribute is not supported in this position"
172                    ).to_compile_error().into()
173                }
174            }
175
176            let segment = PathSegment::from(Ident::new("Target", Span::call_site().into()));
177            impl_enum(input.ident, input.generics, data.variants.into_iter().collect(), Some(segment))
178        },
179        Data::Union(..) => panic!("Command macro cannot be implemented on unions"),
180    }
181        .unwrap_or_else(|err| err.to_compile_error().into())
182        .into()
183}
184
185fn impl_enum(ident: Ident, generics: Generics, variants: Vec<Variant>, segment: Option<PathSegment>) -> Result<TokenStream> {
186    let arguments = Ident::new("Arguments", Span::call_site().into());
187    let module = Ident::new(&format!("command_{}", ident.to_string().to_lowercase()), ident.span());
188    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
189    let where_clause_extension = where_clause.map(|where_clause| &where_clause.predicates);
190
191    let commands = get_commands(variants.iter());
192    let argument_variants = get_argument_variants(variants.iter());
193    let variant_patterns = get_variant_patterns(segment, variants.iter(), false);
194    let argument_patterns = get_variant_patterns(Some(PathSegment::from(arguments.clone())), argument_variants.iter(), true);
195    let titles = variant_titles(variants.iter())?;
196
197    Ok(quote!(
198        #[allow(non_camel_case_types)]
199        mod #module {
200            use super::#ident as Target;
201            use sync_lsp::workspace::execute_command::Command;
202            use serde::{Serialize, Deserialize};
203            use serde::de::{Deserializer, DeserializeOwned};
204            use serde::ser::Serializer;
205
206            impl #impl_generics Command for Target #ty_generics where Self: Clone, #arguments #ty_generics: Serialize + DeserializeOwned + Clone, #where_clause_extension {
207                fn commands() -> Vec<String> {
208                    let mut commands = Vec::new();
209                    #(commands.push(#commands.to_string());)*
210                    commands
211                }
212
213                fn serialize<__S__: Serializer>(&self, serializer: __S__) -> Result<__S__::Ok, __S__::Error> {
214                    Title {
215                        title: match &self {
216                            #(#variant_patterns => #titles.to_string(),)*
217                        },
218                        arguments: <#arguments #ty_generics as From<_>>::from(self.clone())
219                    }.serialize(serializer)
220                }
221
222                fn deserialize<'__de__, __D__: Deserializer<'__de__>>(deserializer: __D__) -> Result<Self, __D__::Error> {
223                    <#arguments #ty_generics as Deserialize>::deserialize(deserializer).map(Into::into)
224                }
225            }
226
227            impl #impl_generics From<Target #ty_generics> for #arguments #ty_generics #where_clause {
228                fn from(command: Target #ty_generics) -> Self {
229                    match command {
230                        #(#variant_patterns => #argument_patterns,)*
231                    }
232                }
233            }
234
235            impl #impl_generics Into<Target #ty_generics> for #arguments #ty_generics #where_clause {
236                fn into(self) -> Target #ty_generics {
237                    match self {
238                        #(#argument_patterns => #variant_patterns,)*
239                    }
240                }
241            }
242
243            #[derive(Serialize)]
244            struct Title<T> {
245                title: String,
246                #[serde(flatten)]
247                arguments: T
248            }
249
250            #[derive(Serialize, Deserialize, Clone)]
251            #[serde(tag = "command", content = "arguments")]
252            enum #arguments #generics {
253                #(#argument_variants,)*
254            }
255        }
256    ).into())
257}
258
259fn get_variant_patterns<'a>(segment: Option<PathSegment>, variants: impl Iterator<Item = &'a Variant>, cast_unit: bool) -> Vec<Pat> {
260    let mut patterns = Vec::new();
261    
262    for variant in variants {
263
264        let path = if let Some(segment) = segment.clone() {
265            let mut path = Path::from(segment);
266            path.segments.push(PathSegment::from(variant.ident.clone()));
267            path
268        } else {
269            Path::from(PathSegment::from(variant.ident.clone()))
270        };
271
272
273        match variant.fields {
274            Fields::Named(ref named) => {
275                let mut fields = Punctuated::new();
276
277                for (index, field) in named.named.iter().enumerate() {
278                    fields.push(FieldPat {
279                        attrs: Vec::new(),
280                        member: Member::Named(field.ident.clone().unwrap()),
281                        colon_token: Some(Colon::default()),
282                        pat: Box::new(Pat::Ident(PatIdent {
283                            attrs: Vec::new(),
284                            by_ref: None,
285                            mutability: None,
286                            ident: Ident::new(&format!("m{index}"), Span::call_site().into()),
287                            subpat: None
288                        }))
289                    })
290                }
291
292                patterns.push(Pat::Struct(PatStruct {
293                    attrs: Vec::new(),
294                    qself: None,
295                    path,
296                    brace_token: Brace::default(),
297                    fields,
298                    rest: None
299                }))
300            },
301            Fields::Unnamed(ref unnamed) => {
302                let mut elems = Punctuated::new();
303
304                for index in 0..unnamed.unnamed.len() {
305                    elems.push(Pat::Ident(PatIdent {
306                        attrs: Vec::new(),
307                        by_ref: None,
308                        mutability: None,
309                        ident: Ident::new(&format!("m{index}"), Span::call_site().into()),
310                        subpat: None
311                    }));
312                }
313                
314                if elems.len() == 1 && cast_unit {
315                    let mut puncutated = Punctuated::new();
316                    puncutated.push(elems.pop().unwrap().into_value());
317                    elems.push(Pat::Slice(PatSlice {
318                        attrs: Vec::new(),
319                        bracket_token: Bracket::default(),
320                        elems: puncutated
321                    }))
322                }
323
324                patterns.push(Pat::TupleStruct(PatTupleStruct {
325                    attrs: Vec::new(),
326                    qself: None,
327                    path,
328                    paren_token: Paren::default(),
329                    elems
330                }))
331            },
332            Fields::Unit => {
333                patterns.push(Pat::Struct(PatStruct {
334                    attrs: Vec::new(),
335                    qself: None,
336                    path,
337                    brace_token: Brace::default(),
338                    fields: Punctuated::new(),
339                    rest: None
340                }))
341            }
342        }
343    }
344
345    patterns
346}
347
348fn get_argument_variants<'a>(variants: impl Iterator<Item = &'a Variant>) -> Vec<Variant> {
349    variants.map(|variant| {
350        let mut fields = variant.fields.clone();
351
352        if let Fields::Named(named) = fields {
353            fields = Fields::Unnamed(FieldsUnnamed {
354                paren_token: Paren::default(),
355                unnamed: named.named
356            });
357        }
358
359        //Deserialisation will fail if the client omits arguments for a command with no arguments
360        if let Fields::Unit = fields {
361            fields = Fields::Unnamed(FieldsUnnamed {
362                paren_token: Paren::default(),
363                unnamed: Punctuated::new()
364            });
365        }
366
367        //Cast commands with a single argument to an array
368        if fields.len() == 1 {
369            if let Fields::Unnamed(unnamed) = &mut fields {
370                let first = unnamed.unnamed.pop().unwrap().into_value();
371                unnamed.unnamed.push(Field {
372                    attrs: Vec::new(),
373                    vis: Visibility::Inherited,
374                    mutability: FieldMutability::None,
375                    ident: None,
376                    colon_token: None,
377                    ty: Type::Array(TypeArray {
378                        bracket_token: Bracket::default(),
379                        elem: Box::new(first.ty),
380                        semi_token: Semi::default(),
381                        len: Expr::Lit(ExprLit {
382                            attrs: Vec::new(),
383                            lit: Lit::Int(LitInt::new("1", Span::call_site().into()))
384                        })
385                    })
386                })
387            }
388        }
389
390        for field in fields.iter_mut() {
391            field.attrs = Vec::new();
392            field.ident = None;
393            field.vis = Visibility::Inherited;
394        }
395
396        Variant {
397            attrs: Vec::new(),
398            ident: variant.ident.clone(),
399            fields,
400            discriminant: None
401        }
402    }).collect()
403}
404
405fn get_commands<'a>(variants: impl Iterator<Item = &'a Variant>) -> Vec<String> {
406    variants.map(|variant| variant.ident.to_string()).collect()
407}
408
409fn variant_titles<'a>(variants: impl Iterator<Item = &'a Variant>) -> Result<Vec<String>> {
410    let mut titles = Vec::new();
411    
412    for variant in variants {
413        let mut title = variant.ident.to_string();
414
415        for field in variant.fields.iter() {
416            for attr in field.attrs.iter() {
417                let Meta::List(list) = &attr.meta else { continue };
418                if list.path.is_ident("command") {
419                    return Err(Error::new(
420                        attr.span(), 
421                        "The command attribute is not supported on fields"
422                    ))
423                }
424            }
425        }
426
427        for attribute in variant.attrs.iter() {
428            let span = attribute.span();
429            let Meta::List(list) = &attribute.meta else { continue };
430            if !list.path.is_ident("command") { continue };
431            let pair: ExprAssign = list.parse_args()?;
432
433            let Expr::Path(ExprPath { path, .. }) = *pair.left else {
434                return Err(Error::new(
435                    span, 
436                    "Expected a path for the left side of the command attribute")
437                );
438            };
439
440            if !path.is_ident("title") {
441                return Err(Error::new(
442                    span, 
443                    "The command attribute only supports the title field")
444                );
445            };
446
447            let Expr::Lit(ExprLit { lit: Lit::Str(literal), .. }) = *pair.right else {
448                return Err(Error::new(
449                    span, 
450                    "Expected a string literal for the right side of the command attribute")
451                );
452            };
453
454            title = literal.value();
455        }
456
457        titles.push(title);
458    }
459
460    Ok(titles)
461}