1use proc_macro2::Span;
2use quote::format_ident;
3use syn::{
4 ext::IdentExt, parse_quote, Attribute, FnArg::Typed, Pat::Ident, PatType, Signature, Stmt,
5};
6
7use crate::utils::extract_docs;
8
/// Coarse classification of an argument type of a savvy function; it decides
/// which conversion statement is generated for the incoming argument.
#[derive(Clone, PartialEq)]
enum SavvyInputTypeCategory {
    /// `savvy::Sexp` itself.
    Sexp,
    /// One of savvy's typed SEXP wrappers (e.g. `IntegerSexp`, `ListSexp`).
    SexpWrapper,
    /// A scalar such as `i32`, `usize`, `f64`, `bool`, `u8`, `NumericScalar`,
    /// or `&str`.
    PrimitiveType,
    /// A reference to any non-`str` path type, treated as a user-defined type.
    UserDefinedTypeRef,
    /// Any other path type, treated as a user-defined type passed by value.
    UserDefinedType,
    /// `*mut DllInfo`; only accepted by `#[savvy_init]` functions.
    DllInfo,
}
18
19#[derive(Clone)]
20struct SavvyInputType {
21 category: SavvyInputTypeCategory,
22 ty_orig: syn::Type,
23 ty_str: String,
24 optional: bool,
25}
26
27#[allow(dead_code)]
28impl SavvyInputType {
29 fn from_type(ty: &syn::Type, in_option: bool) -> syn::Result<Self> {
30 match &ty {
31 syn::Type::Reference(syn::TypeReference { elem, .. }) => {
34 if let syn::Type::Path(type_path) = elem.as_ref() {
35 let ty_str = type_path.path.segments.last().unwrap().ident.to_string();
36 if &ty_str == "str" {
37 Ok(Self {
38 category: SavvyInputTypeCategory::PrimitiveType,
39 ty_orig: ty.clone(),
40 ty_str: "&str".to_string(),
41 optional: in_option,
42 })
43 } else {
44 Ok(Self {
45 category: SavvyInputTypeCategory::UserDefinedTypeRef,
46 ty_orig: ty.clone(),
47 ty_str,
48 optional: in_option,
49 })
50 }
51 } else {
52 Err(syn::Error::new_spanned(
53 ty.clone(),
54 "Unexpected type specification: {:?}",
55 ))
56 }
57 }
58
59 syn::Type::Path(type_path) => {
60 let type_path_last = type_path.path.segments.last().unwrap();
61 let type_ident = &type_path_last.ident;
62 let ty_str = type_ident.to_string();
63 match ty_str.as_str() {
64 "Option" => {
65 if in_option {
66 return Err(syn::Error::new_spanned(
67 type_path,
68 "`Option` cannot be nested",
69 ));
70 }
71
72 if let syn::PathArguments::AngleBracketed(
73 syn::AngleBracketedGenericArguments { args, .. },
74 ) = &type_path_last.arguments
75 {
76 if args.len() == 1 {
77 if let syn::GenericArgument::Type(ty) = &args.first().unwrap() {
78 return Self::from_type(ty, true);
79 }
80 }
81 }
82
83 Err(syn::Error::new_spanned(
84 type_path,
85 "Option<T> can accept only a type",
86 ))
87 }
88
89 "OwnedIntegerSexp" | "OwnedRealSexp" | "OwnedComplexSexp"
91 | "OwnedLogicalSexp" | "OwnedRawSexp" | "OwnedStringSexp" | "OwnedListSexp" => {
92 let msg = format!(
93 "`Owned-` types are not allowed here. Did you mean `{}`?",
94 ty_str.strip_prefix("Owned").unwrap()
95 );
96 Err(syn::Error::new_spanned(type_path, msg))
97 }
98
99 "Sexp" => Ok(Self {
102 category: SavvyInputTypeCategory::Sexp,
103 ty_orig: ty.clone(),
104 ty_str,
105 optional: in_option,
106 }),
107
108 "IntegerSexp" | "RealSexp" | "NumericSexp" | "ComplexSexp" | "LogicalSexp"
110 | "RawSexp" | "StringSexp" | "ListSexp" | "FunctionSexp"
111 | "EnvironmentSexp" => Ok(Self {
112 category: SavvyInputTypeCategory::SexpWrapper,
113 ty_orig: ty.clone(),
114 ty_str,
115 optional: in_option,
116 }),
117
118 "i32" | "usize" | "f64" | "bool" | "u8" | "NumericScalar" => Ok(Self {
120 category: SavvyInputTypeCategory::PrimitiveType,
121 ty_orig: ty.clone(),
122 ty_str,
123 optional: in_option,
124 }),
125
126 "DllInfo" => Err(syn::Error::new_spanned(
127 type_path,
128 "DllInfo must be `*mut DllInfo`",
129 )),
130
131 _ => Ok(Self {
132 category: SavvyInputTypeCategory::UserDefinedType,
133 ty_orig: ty.clone(),
134 ty_str,
135 optional: in_option,
136 }),
137 }
138 }
139
140 syn::Type::Ptr(syn::TypePtr {
142 mutability, elem, ..
143 }) => {
144 let type_ident = if let syn::Type::Path(p) = elem.as_ref() {
145 p.path.segments.last().unwrap().ident.to_string()
146 } else {
147 "".to_string()
148 };
149
150 if &type_ident != "DllInfo" {
151 return Err(syn::Error::new_spanned(
152 ty.clone(),
153 "Unexpected type specification: {:?}",
154 ));
155 }
156
157 if mutability.is_none() {
158 return Err(syn::Error::new_spanned(
159 ty.clone(),
160 "DllInfo must be `*mut DllInfo`",
161 ));
162 }
163
164 Ok(Self {
165 category: SavvyInputTypeCategory::DllInfo,
166 ty_orig: ty.clone(),
167 ty_str: type_ident.to_string(),
168 optional: in_option,
169 })
170 }
171
172 _ => Err(syn::Error::new_spanned(
173 ty.clone(),
174 "Unexpected type specification: {:?}",
175 )),
176 }
177 }
178
179 fn to_rust_type_outer(&self) -> syn::Type {
181 self.ty_orig.clone()
182 }
183
184 fn to_rust_type_inner(&self) -> syn::Type {
186 if matches!(self.category, SavvyInputTypeCategory::DllInfo) {
187 self.ty_orig.clone()
188 } else {
189 parse_quote!(savvy::ffi::SEXP)
190 }
191 }
192
193 fn to_c_type(&self) -> String {
195 if matches!(self.category, SavvyInputTypeCategory::DllInfo) {
196 "DllInfo*".to_string()
197 } else {
198 "SEXP".to_string()
199 }
200 }
201}
202
203#[derive(Clone)]
204pub struct SavvyFnArg {
205 pub(crate) pat: syn::Ident,
206 ty: SavvyInputType,
207}
208
209impl SavvyFnArg {
210 pub fn pat(&self) -> syn::Ident {
211 self.pat.clone()
212 }
213
214 pub fn is_user_defined_type(&self) -> bool {
215 matches!(
216 &self.ty.category,
217 SavvyInputTypeCategory::UserDefinedTypeRef | SavvyInputTypeCategory::UserDefinedType
218 )
219 }
220
221 pub fn ty_string(&self) -> String {
222 self.ty.ty_str.clone()
223 }
224
225 pub fn is_optional(&self) -> bool {
226 self.ty.optional
227 }
228
229 pub fn to_c_type_string(&self) -> String {
230 self.ty.to_c_type()
231 }
232
233 pub fn to_rust_type_outer(&self) -> syn::Type {
234 self.ty.to_rust_type_outer()
235 }
236
237 pub fn to_rust_type_inner(&self) -> syn::Type {
238 self.ty.to_rust_type_inner()
239 }
240}
241
242impl PartialEq for SavvyFnArg {
243 fn eq(&self, other: &Self) -> bool {
244 self.pat == other.pat
245 && self.ty.category == other.ty.category
246 && self.ty.ty_str == other.ty.ty_str
247 && self.ty.optional == other.ty.optional
248 }
249}
250
251#[derive(Clone)]
257pub struct UserDefinedStructReturnType {
258 pub(crate) ty: syn::Ident,
259 pub(crate) return_type: syn::ReturnType,
260 pub(crate) wrapped_with_result: bool,
261}
262
263#[derive(Clone)]
272pub enum SavvyFnReturnType {
273 Sexp(syn::ReturnType),
274 Unit(syn::ReturnType),
275 UserDefinedStruct(UserDefinedStructReturnType),
276}
277
278impl SavvyFnReturnType {
279 pub fn inner(&self) -> &syn::ReturnType {
280 match self {
281 SavvyFnReturnType::Sexp(ret_ty) => ret_ty,
282 SavvyFnReturnType::Unit(ret_ty) => ret_ty,
283 SavvyFnReturnType::UserDefinedStruct(ret_ty) => &ret_ty.return_type,
284 }
285 }
286}
287
288#[derive(Clone)]
289pub enum SavvyFnType {
290 BareFunction,
292 Method {
295 ty: syn::Type,
296 reference: bool,
297 mutability: bool,
298 },
299 AssociatedFunction(syn::Type),
302 InitFunction,
304}
305
306#[derive(Clone)]
307pub struct SavvyFn {
308 pub docs: Vec<String>,
310 pub attrs: Vec<syn::Attribute>,
312 pub fn_name: syn::Ident,
314 pub fn_type: SavvyFnType,
316 pub args: Vec<SavvyFnArg>,
318 pub return_type: SavvyFnReturnType,
320 pub stmts_additional: Vec<syn::Stmt>,
322}
323
324#[allow(dead_code)]
325impl SavvyFn {
326 pub(crate) fn get_self_ty_ident(&self) -> Option<syn::Ident> {
327 let self_ty = match &self.fn_type {
328 SavvyFnType::Method { ty, .. } => ty,
329 SavvyFnType::AssociatedFunction(ty) => ty,
330 _ => return None,
331 };
332 if let syn::Type::Path(type_path) = self_ty {
333 let ty = type_path
334 .path
335 .segments
336 .last()
337 .expect("Unexpected type path")
338 .ident
339 .clone();
340 Some(ty)
341 } else {
342 panic!("Unexpected self type!")
343 }
344 }
345
346 pub fn fn_name_inner(&self) -> syn::Ident {
347 match self.get_self_ty_ident() {
348 Some(ty) => format_ident!("savvy_{}_{}_inner", ty, self.fn_name),
349 None => format_ident!("savvy_{}_inner", self.fn_name),
350 }
351 }
352
353 pub fn fn_name_c_header(&self) -> syn::Ident {
355 match self.get_self_ty_ident() {
356 Some(ty) => format_ident!("savvy_{}_{}__ffi", ty, self.fn_name),
357 None => format_ident!("savvy_{}__ffi", self.fn_name),
358 }
359 }
360
361 pub fn fn_name_c_impl(&self) -> syn::Ident {
363 match self.get_self_ty_ident() {
364 Some(ty) => format_ident!("savvy_{}_{}__impl", ty, self.fn_name),
365 None => format_ident!("savvy_{}__impl", self.fn_name),
366 }
367 }
368
369 pub fn from_fn(orig: &syn::ItemFn, as_init_fn: bool) -> syn::Result<Self> {
370 let fn_type = if as_init_fn {
371 SavvyFnType::InitFunction
372 } else {
373 SavvyFnType::BareFunction
374 };
375 Self::new(&orig.attrs, &orig.sig, fn_type, None)
376 }
377
378 pub fn from_impl_fn(
379 orig: &syn::ImplItemFn,
380 fn_type: SavvyFnType,
381 self_ty: &syn::Type,
382 ) -> syn::Result<Self> {
383 Self::new(&orig.attrs, &orig.sig, fn_type, Some(self_ty))
384 }
385
386 pub fn new(
387 attrs: &[Attribute],
388 sig: &Signature,
389 fn_type: SavvyFnType,
390 self_ty: Option<&syn::Type>,
391 ) -> syn::Result<Self> {
392 let mut attrs = attrs.to_vec();
395 attrs.retain(|attr| {
397 !(attr == &parse_quote!(#[savvy]) || attr == &parse_quote!(#[savvy_init]))
398 });
399
400 let docs = extract_docs(attrs.as_slice());
402
403 let fn_name = sig.ident.clone();
404
405 let mut stmts_additional: Vec<Stmt> = Vec::new();
406
407 let args_new = sig
408 .inputs
409 .iter()
410 .filter_map(|arg| match arg {
411 Typed(PatType { pat, ty, .. }) => {
412 let pat = match pat.as_ref() {
413 Ident(arg) => arg.ident.clone(),
414 _ => {
415 return Some(Err(syn::Error::new_spanned(
416 pat,
417 "non-ident is not supported",
418 )));
419 }
420 };
421
422 let ty = match SavvyInputType::from_type(ty.as_ref(), false) {
423 Ok(ty) => ty,
424 Err(e) => return Some(Err(e)),
425 };
426 let ty_ident = ty.to_rust_type_outer();
427
428 match (&fn_type, &ty.category) {
429 (&SavvyFnType::InitFunction, &SavvyInputTypeCategory::DllInfo) => {}
431
432 (&SavvyFnType::InitFunction, _) => {
433 return Some(Err(syn::Error::new_spanned(
434 ty.ty_orig,
435 "#[savvy_init] can be used only on a function that takes `*mut DllInfo`",
436 )));
437 }
438
439 (_, &SavvyInputTypeCategory::DllInfo) => {
440 return Some(Err(syn::Error::new_spanned(
441 ty.ty_orig,
442 "#[savvy] doesn't accept `*mut DllInfo`. Did you mean #[savvy_init]?",
443 )));
444 }
445
446 (_, &SavvyInputTypeCategory::Sexp) => {
447 if ty.optional {
448 stmts_additional.push(parse_quote! { let #pat = savvy::Sexp(#pat); });
449 stmts_additional.push(parse_quote! {
450 let #pat = if #pat.is_null() {
451 None
452 } else {
453 Some(#pat)
454 };
455 })
456 } else {
457 stmts_additional.push(parse_quote! {
458 let #pat = savvy::Sexp(#pat);
459 });
460 }
461 }
462
463 (_, _) => {
464 let arg_lit = syn::LitStr::new(&pat.unraw().to_string(), Span::call_site());
465 if ty.optional {
466 stmts_additional.push(parse_quote! { let #pat = savvy::Sexp(#pat); });
467 stmts_additional.push(parse_quote! {
468 let #pat = if #pat.is_null() {
469 None
470 } else {
471 Some(<#ty_ident>::try_from(#pat).map_err(|e| e.with_arg_name(#arg_lit))?)
472 };
473 })
474 } else {
475 stmts_additional.push(parse_quote! {
476 let #pat = <#ty_ident>::try_from(savvy::Sexp(#pat)).map_err(|e| e.with_arg_name(#arg_lit))?;
477 });
478 }
479 }
480 }
481
482 Some(Ok(SavvyFnArg { pat, ty }))
483 }
484 syn::FnArg::Receiver(syn::Receiver { .. }) => None,
486 })
487 .collect::<syn::Result<Vec<SavvyFnArg>>>()?;
488
489 let mut args_after_optional = args_new.iter().skip_while(|x| !x.is_optional());
491 if args_after_optional.any(|x| !x.is_optional()) {
492 return Err(syn::Error::new_spanned(
493 sig.inputs.clone(),
494 "optional args can be placed only after mandatory args",
495 ));
496 }
497
498 let is_init_fn = args_new
500 .iter()
501 .any(|x| matches!(x.ty.category, SavvyInputTypeCategory::DllInfo));
502 if is_init_fn && args_new.len() > 1 {
503 return Err(syn::Error::new_spanned(
504 sig,
505 "Initialization function can accept `*mut DllInfo` only",
506 ));
507 }
508
509 let fn_type = if is_init_fn {
510 SavvyFnType::InitFunction
511 } else {
512 fn_type
513 };
514
515 Ok(Self {
516 docs,
517 attrs,
518 fn_name,
519 fn_type,
520 args: args_new,
521 return_type: get_savvy_return_type(&sig.output, self_ty)?,
522 stmts_additional,
523 })
524 }
525}
526
527fn self_ty_to_actual_ty(self_ty: Option<&syn::Type>) -> Option<syn::Ident> {
528 if let Some(syn::Type::Path(type_path)) = self_ty {
529 Some(type_path.path.segments.last().unwrap().ident.clone())
530 } else {
531 None
532 }
533}
534
535fn get_savvy_return_type(
542 return_type: &syn::ReturnType,
543 self_ty: Option<&syn::Type>,
544) -> syn::Result<SavvyFnReturnType> {
545 match return_type {
546 syn::ReturnType::Default => Err(syn::Error::new_spanned(
547 return_type.clone(),
548 "function must have return type",
549 )),
550
551 syn::ReturnType::Type(_, ty) => {
552 let e = Err(syn::Error::new_spanned(
553 return_type.clone(),
554 "the return type must be savvy::Result<T> or savvy::Result<()>",
555 ));
556
557 let path_args = match ty.as_ref() {
560 syn::Type::Path(type_path) => {
561 if !is_type_path_savvy_or_no_qualifier(type_path) {
563 return e;
564 }
565
566 let last_path_seg = type_path.path.segments.last().unwrap();
567 match (
568 last_path_seg.ident.to_string().as_str(),
569 self_ty_to_actual_ty(self_ty),
570 ) {
571 ("Result", _) => {}
573 (ret_ty_str, Some(ty_actual)) => {
575 if ret_ty_str != "Self" && ty_actual != ret_ty_str {
576 return e;
577 } else {
578 return Ok(SavvyFnReturnType::UserDefinedStruct(
579 UserDefinedStructReturnType {
580 ty: ty_actual,
581 return_type: parse_quote!(-> savvy::Result<#self_ty>),
582 wrapped_with_result: false,
583 },
584 ));
585 }
586 }
587 _ => {
588 return e;
589 }
590 }
591 &last_path_seg.arguments
592 }
593 _ => return e,
594 };
595
596 if let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
599 args,
600 ..
601 }) = path_args
602 {
603 if args.len() != 1 {
604 return e;
605 }
606
607 if let syn::GenericArgument::Type(ty) = &args.first().unwrap() {
608 match ty {
609 syn::Type::Tuple(type_tuple) => {
610 if type_tuple.elems.is_empty() {
611 return Ok(SavvyFnReturnType::Unit(return_type.clone()));
612 }
613 }
614
615 syn::Type::Path(type_path) => {
616 let last_ident = &type_path.path.segments.last().unwrap().ident;
617 match last_ident.to_string().as_str() {
618 "Sexp" => return Ok(SavvyFnReturnType::Sexp(return_type.clone())),
619
620 "Self" => {
622 if let Some(ty_actual) = self_ty_to_actual_ty(self_ty) {
623 return Ok(SavvyFnReturnType::UserDefinedStruct(
624 UserDefinedStructReturnType {
625 ty: ty_actual,
626 return_type: parse_quote!(-> savvy::Result<#self_ty>),
627 wrapped_with_result: true,
628 },
629 ));
630 }
631 }
632
633 wrong_ty @ ("String" | "i32" | "usize" | "f64" | "bool") => {
635 let msg = format!(
636"Return type must be either (), savvy::Sexp, or a user-defined type.
637You can use .try_into() to convert {wrong_ty} to savvy::Sexp."
638 );
639 return Err(syn::Error::new_spanned(type_path, msg));
640 }
641
642 _ => {
644 return Ok(SavvyFnReturnType::UserDefinedStruct(
645 UserDefinedStructReturnType {
646 ty: last_ident.clone(),
647 return_type: return_type.clone(),
648 wrapped_with_result: true,
649 },
650 ))
651 }
652 }
653 }
654
655 _ => {}
656 }
657 }
658 }
659
660 e
661 }
662 }
663}
664
665fn is_type_path_savvy_or_no_qualifier(type_path: &syn::TypePath) -> bool {
667 if type_path.qself.is_some() || type_path.path.leading_colon.is_some() {
668 return false;
669 }
670
671 match type_path.path.segments.len() {
672 1 => true,
673 2 => {
674 let first_path_seg = type_path.path.segments.first().unwrap();
675 first_path_seg.arguments.is_none() && &first_path_seg.ident.to_string() == "savvy"
676 }
677 _ => false,
678 }
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Classifies `rt`, failing the test with syn's error message on `Err`.
    fn classify_ok(rt: &syn::ReturnType, self_ty: Option<&syn::Type>) -> SavvyFnReturnType {
        match get_savvy_return_type(rt, self_ty) {
            Ok(srt) => srt,
            Err(e) => panic!("expected a valid return type, got error: {e}"),
        }
    }

    #[test]
    fn test_detect_return_type_sexp() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<Sexp>),
            parse_quote!(-> savvy::Result<Sexp>),
            parse_quote!(-> savvy::Result<savvy::Sexp>),
        ];

        for rt in ok_cases {
            assert!(matches!(classify_ok(rt, None), SavvyFnReturnType::Sexp(_)));
        }
    }

    #[test]
    fn test_detect_return_type_unit() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<()>),
            parse_quote!(-> savvy::Result<()>),
        ];

        for rt in ok_cases {
            assert!(matches!(classify_ok(rt, None), SavvyFnReturnType::Unit(_)));
        }
    }

    #[test]
    fn test_detect_return_type_struct() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<Foo>),
            parse_quote!(-> savvy::Result<Foo>),
        ];

        for rt in ok_cases {
            match classify_ok(rt, None) {
                SavvyFnReturnType::UserDefinedStruct(uds) => {
                    assert_eq!(uds.ty.to_string(), "Foo");
                    // Already wrapped in `Result`, so the return type is kept as written.
                    assert_eq!(&uds.return_type, rt);
                    assert!(uds.wrapped_with_result);
                }
                _ => panic!("Unexpected SavvyFnReturnType"),
            }
        }
    }

    #[test]
    fn test_detect_return_type_self() {
        // (return type, expected `wrapped_with_result`)
        let ok_cases: &[(syn::ReturnType, bool)] = &[
            (parse_quote!(-> Self), false),
            (parse_quote!(-> Foo), false),
            (parse_quote!(-> Result<Self>), true),
            (parse_quote!(-> savvy::Result<Self>), true),
        ];
        let self_ty: syn::Type = parse_quote!(Foo);
        let expected_return_type: syn::ReturnType = parse_quote!(-> savvy::Result<Foo>);

        for (rt, wrapped) in ok_cases {
            match classify_ok(rt, Some(&self_ty)) {
                SavvyFnReturnType::UserDefinedStruct(uds) => {
                    assert_eq!(uds.ty.to_string(), "Foo");
                    assert_eq!(uds.return_type, expected_return_type);
                    assert_eq!(uds.wrapped_with_result, *wrapped);
                }
                _ => panic!("Unexpected SavvyFnReturnType"),
            }
        }
    }

    #[test]
    fn test_detect_return_type_fail() {
        let err_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Foo),
            parse_quote!(-> Self),
            parse_quote!(-> Result<Self>),
            parse_quote!(-> savvy::Result<(T, T)>),
            parse_quote!(-> savvy::Result<Sexp, E>),
            parse_quote!(-> foo::Result<Sexp>),
            parse_quote!(-> ::savvy::Result<Sexp>),
            parse_quote!(),
        ];

        for rt in err_cases {
            assert!(get_savvy_return_type(rt, None).is_err());
        }
    }

    #[test]
    fn test_detect_return_type_fail_with_self_ty() {
        let self_ty: syn::Type = parse_quote!(Foo);
        let err_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Bar),
            parse_quote!(-> savvy::Bar),
            parse_quote!(-> Option<Foo>),
        ];

        for rt in err_cases {
            assert!(get_savvy_return_type(rt, Some(&self_ty)).is_err());
        }
    }

    #[test]
    fn test_detect_return_type_primitive_hint() {
        for prim in ["String", "i32", "usize", "f64", "bool"] {
            let ident = syn::Ident::new(prim, proc_macro2::Span::call_site());
            let rt: syn::ReturnType = parse_quote!(-> savvy::Result<#ident>);
            match get_savvy_return_type(&rt, None) {
                Ok(_) => panic!("`{prim}` must be rejected as a return type"),
                Err(e) => {
                    let expected = format!("convert {prim} to savvy::Sexp");
                    assert!(e.to_string().contains(&expected));
                }
            }
        }
    }

    #[test]
    fn test_is_type_path_savvy_or_no_qualifier() {
        let accepted: &[syn::TypePath] = &[
            parse_quote!(Result<()>),
            parse_quote!(savvy::Result<()>),
        ];
        let rejected: &[syn::TypePath] = &[
            parse_quote!(::savvy::Result<()>),
            parse_quote!(foo::Result<()>),
            parse_quote!(savvy::foo::Result<()>),
            parse_quote!(<T as Trait>::Result),
        ];

        for tp in accepted {
            assert!(is_type_path_savvy_or_no_qualifier(tp));
        }
        for tp in rejected {
            assert!(!is_type_path_savvy_or_no_qualifier(tp));
        }
    }
}