1use proc_macro2::Span;
2use quote::format_ident;
3use syn::{ext::IdentExt, parse_quote, Attribute, FnArg::Typed, Pat::Ident, PatType, Signature, Stmt};
4
5use crate::utils::extract_docs;
6
/// Classification of a type that appears as an argument of a savvy function.
///
/// The category is decided by `SavvyInputType::from_type`, mostly from the
/// name of the last path segment of the type.
#[derive(Clone, PartialEq)]
enum SavvyInputTypeCategory {
    /// The type named `Sexp`.
    Sexp,
    /// A typed wrapper such as `IntegerSexp`, `RealSexp`, or `ListSexp`.
    SexpWrapper,
    /// A scalar-like type such as `i32`, `f64`, `bool`, or `&str`.
    PrimitiveType,
    /// A reference to any type other than `str`.
    UserDefinedTypeRef,
    /// Any other path type that is not recognized by name.
    UserDefinedType,
    /// The raw pointer `*mut DllInfo`.
    DllInfo,
}
16
17#[derive(Clone)]
18struct SavvyInputType {
19 category: SavvyInputTypeCategory,
20 ty_orig: syn::Type,
21 ty_str: String,
22 optional: bool,
23}
24
25#[allow(dead_code)]
26impl SavvyInputType {
27 fn from_type(ty: &syn::Type, in_option: bool) -> syn::Result<Self> {
28 match &ty {
29 syn::Type::Reference(syn::TypeReference { elem, .. }) => {
32 if let syn::Type::Path(type_path) = elem.as_ref() {
33 let ty_str = type_path.path.segments.last().unwrap().ident.to_string();
34 if &ty_str == "str" {
35 Ok(Self {
36 category: SavvyInputTypeCategory::PrimitiveType,
37 ty_orig: ty.clone(),
38 ty_str: "&str".to_string(),
39 optional: in_option,
40 })
41 } else {
42 Ok(Self {
43 category: SavvyInputTypeCategory::UserDefinedTypeRef,
44 ty_orig: ty.clone(),
45 ty_str,
46 optional: in_option,
47 })
48 }
49 } else {
50 Err(syn::Error::new_spanned(
51 ty.clone(),
52 "Unexpected type specification: {:?}",
53 ))
54 }
55 }
56
57 syn::Type::Path(type_path) => {
58 let type_path_last = type_path.path.segments.last().unwrap();
59 let type_ident = &type_path_last.ident;
60 let ty_str = type_ident.to_string();
61 match ty_str.as_str() {
62 "Option" => {
63 if in_option {
64 return Err(syn::Error::new_spanned(
65 type_path,
66 "`Option` cannot be nested",
67 ));
68 }
69
70 if let syn::PathArguments::AngleBracketed(
71 syn::AngleBracketedGenericArguments { args, .. },
72 ) = &type_path_last.arguments
73 {
74 if args.len() == 1 {
75 if let syn::GenericArgument::Type(ty) = &args.first().unwrap() {
76 return Self::from_type(ty, true);
77 }
78 }
79 }
80
81 Err(syn::Error::new_spanned(
82 type_path,
83 "Option<T> can accept only a type",
84 ))
85 }
86
87 "OwnedIntegerSexp" | "OwnedRealSexp" | "OwnedComplexSexp"
89 | "OwnedLogicalSexp" | "OwnedRawSexp" | "OwnedStringSexp" | "OwnedListSexp" => {
90 let msg = format!(
91 "`Owned-` types are not allowed here. Did you mean `{}`?",
92 ty_str.strip_prefix("Owned").unwrap()
93 );
94 Err(syn::Error::new_spanned(type_path, msg))
95 }
96
97 "Sexp" => Ok(Self {
100 category: SavvyInputTypeCategory::Sexp,
101 ty_orig: ty.clone(),
102 ty_str,
103 optional: in_option,
104 }),
105
106 "IntegerSexp" | "RealSexp" | "NumericSexp" | "ComplexSexp"
108 | "LogicalSexp" | "RawSexp" | "StringSexp" | "ListSexp" | "FunctionSexp"
109 | "EnvironmentSexp" => Ok(Self {
110 category: SavvyInputTypeCategory::SexpWrapper,
111 ty_orig: ty.clone(),
112 ty_str,
113 optional: in_option,
114 }),
115
116 "i32" | "usize" | "f64" | "bool" | "u8" | "NumericScalar" => Ok(Self {
118 category: SavvyInputTypeCategory::PrimitiveType,
119 ty_orig: ty.clone(),
120 ty_str,
121 optional: in_option,
122 }),
123
124 "DllInfo" => Err(syn::Error::new_spanned(
125 type_path,
126 "DllInfo must be `*mut DllInfo`",
127 )),
128
129 _ => Ok(Self {
130 category: SavvyInputTypeCategory::UserDefinedType,
131 ty_orig: ty.clone(),
132 ty_str,
133 optional: in_option,
134 }),
135 }
136 }
137
138 syn::Type::Ptr(syn::TypePtr {
140 mutability, elem, ..
141 }) => {
142 let type_ident = if let syn::Type::Path(p) = elem.as_ref() {
143 p.path.segments.last().unwrap().ident.to_string()
144 } else {
145 "".to_string()
146 };
147
148 if &type_ident != "DllInfo" {
149 return Err(syn::Error::new_spanned(
150 ty.clone(),
151 "Unexpected type specification: {:?}",
152 ));
153 }
154
155 if mutability.is_none() {
156 return Err(syn::Error::new_spanned(
157 ty.clone(),
158 "DllInfo must be `*mut DllInfo`",
159 ));
160 }
161
162 Ok(Self {
163 category: SavvyInputTypeCategory::DllInfo,
164 ty_orig: ty.clone(),
165 ty_str: type_ident.to_string(),
166 optional: in_option,
167 })
168 }
169
170 _ => Err(syn::Error::new_spanned(
171 ty.clone(),
172 "Unexpected type specification: {:?}",
173 )),
174 }
175 }
176
177 fn to_rust_type_outer(&self) -> syn::Type {
179 self.ty_orig.clone()
180 }
181
182 fn to_rust_type_inner(&self) -> syn::Type {
184 if matches!(self.category, SavvyInputTypeCategory::DllInfo) {
185 self.ty_orig.clone()
186 } else {
187 parse_quote!(savvy::ffi::SEXP)
188 }
189 }
190
191 fn to_c_type(&self) -> String {
193 if matches!(self.category, SavvyInputTypeCategory::DllInfo) {
194 "DllInfo*".to_string()
195 } else {
196 "SEXP".to_string()
197 }
198 }
199}
200
201#[derive(Clone)]
202pub struct SavvyFnArg {
203 pub(crate) pat: syn::Ident,
204 ty: SavvyInputType,
205}
206
207impl SavvyFnArg {
208 pub fn pat(&self) -> syn::Ident {
209 self.pat.clone()
210 }
211
212 pub fn is_user_defined_type(&self) -> bool {
213 matches!(
214 &self.ty.category,
215 SavvyInputTypeCategory::UserDefinedTypeRef | SavvyInputTypeCategory::UserDefinedType
216 )
217 }
218
219 pub fn ty_string(&self) -> String {
220 self.ty.ty_str.clone()
221 }
222
223 pub fn is_optional(&self) -> bool {
224 self.ty.optional
225 }
226
227 pub fn to_c_type_string(&self) -> String {
228 self.ty.to_c_type()
229 }
230
231 pub fn to_rust_type_outer(&self) -> syn::Type {
232 self.ty.to_rust_type_outer()
233 }
234
235 pub fn to_rust_type_inner(&self) -> syn::Type {
236 self.ty.to_rust_type_inner()
237 }
238}
239
240impl PartialEq for SavvyFnArg {
241 fn eq(&self, other: &Self) -> bool {
242 self.pat == other.pat && self.ty.category == other.ty.category && self.ty.ty_str == other.ty.ty_str && self.ty.optional == other.ty.optional
243 }
244}
245
246#[derive(Clone)]
252pub struct UserDefinedStructReturnType {
253 pub(crate) ty: syn::Ident,
254 pub(crate) return_type: syn::ReturnType,
255 pub(crate) wrapped_with_result: bool,
256}
257
258#[derive(Clone)]
267pub enum SavvyFnReturnType {
268 Sexp(syn::ReturnType),
269 Unit(syn::ReturnType),
270 UserDefinedStruct(UserDefinedStructReturnType),
271}
272
273impl SavvyFnReturnType {
274 pub fn inner(&self) -> &syn::ReturnType {
275 match self {
276 SavvyFnReturnType::Sexp(ret_ty) => ret_ty,
277 SavvyFnReturnType::Unit(ret_ty) => ret_ty,
278 SavvyFnReturnType::UserDefinedStruct(ret_ty) => &ret_ty.return_type,
279 }
280 }
281}
282
283#[derive(Clone)]
284pub enum SavvyFnType {
285 BareFunction,
287 Method {
290 ty: syn::Type,
291 reference: bool,
292 mutability: bool,
293 },
294 AssociatedFunction(syn::Type),
297 InitFunction,
299}
300
301#[derive(Clone)]
302pub struct SavvyFn {
303 pub docs: Vec<String>,
305 pub attrs: Vec<syn::Attribute>,
307 pub fn_name: syn::Ident,
309 pub fn_type: SavvyFnType,
311 pub args: Vec<SavvyFnArg>,
313 pub return_type: SavvyFnReturnType,
315 pub stmts_additional: Vec<syn::Stmt>,
317}
318
319#[allow(dead_code)]
320impl SavvyFn {
321 pub(crate) fn get_self_ty_ident(&self) -> Option<syn::Ident> {
322 let self_ty = match &self.fn_type {
323 SavvyFnType::Method { ty, .. } => ty,
324 SavvyFnType::AssociatedFunction(ty) => ty,
325 _ => return None,
326 };
327 if let syn::Type::Path(type_path) = self_ty {
328 let ty = type_path
329 .path
330 .segments
331 .last()
332 .expect("Unexpected type path")
333 .ident
334 .clone();
335 Some(ty)
336 } else {
337 panic!("Unexpected self type!")
338 }
339 }
340
341 pub fn fn_name_inner(&self) -> syn::Ident {
342 match self.get_self_ty_ident() {
343 Some(ty) => format_ident!("savvy_{}_{}_inner", ty, self.fn_name),
344 None => format_ident!("savvy_{}_inner", self.fn_name),
345 }
346 }
347
348 pub fn fn_name_c_header(&self) -> syn::Ident {
350 match self.get_self_ty_ident() {
351 Some(ty) => format_ident!("savvy_{}_{}__ffi", ty, self.fn_name),
352 None => format_ident!("savvy_{}__ffi", self.fn_name),
353 }
354 }
355
356 pub fn fn_name_c_impl(&self) -> syn::Ident {
358 match self.get_self_ty_ident() {
359 Some(ty) => format_ident!("savvy_{}_{}__impl", ty, self.fn_name),
360 None => format_ident!("savvy_{}__impl", self.fn_name),
361 }
362 }
363
364 pub fn from_fn(orig: &syn::ItemFn, as_init_fn: bool) -> syn::Result<Self> {
365 let fn_type = if as_init_fn {
366 SavvyFnType::InitFunction
367 } else {
368 SavvyFnType::BareFunction
369 };
370 Self::new(&orig.attrs, &orig.sig, fn_type, None)
371 }
372
373 pub fn from_impl_fn(
374 orig: &syn::ImplItemFn,
375 fn_type: SavvyFnType,
376 self_ty: &syn::Type,
377 ) -> syn::Result<Self> {
378 Self::new(&orig.attrs, &orig.sig, fn_type, Some(self_ty))
379 }
380
381 pub fn new(
382 attrs: &[Attribute],
383 sig: &Signature,
384 fn_type: SavvyFnType,
385 self_ty: Option<&syn::Type>,
386 ) -> syn::Result<Self> {
387 let mut attrs = attrs.to_vec();
390 attrs.retain(|attr| {
392 !(attr == &parse_quote!(#[savvy]) || attr == &parse_quote!(#[savvy_init]))
393 });
394
395 let docs = extract_docs(attrs.as_slice());
397
398 let fn_name = sig.ident.clone();
399
400 let mut stmts_additional: Vec<Stmt> = Vec::new();
401
402 let args_new = sig
403 .inputs
404 .iter()
405 .filter_map(|arg| match arg {
406 Typed(PatType { pat, ty, .. }) => {
407 let pat = match pat.as_ref() {
408 Ident(arg) => arg.ident.clone(),
409 _ => {
410 return Some(Err(syn::Error::new_spanned(
411 pat,
412 "non-ident is not supported",
413 )));
414 }
415 };
416
417 let ty = match SavvyInputType::from_type(ty.as_ref(), false) {
418 Ok(ty) => ty,
419 Err(e) => return Some(Err(e)),
420 };
421 let ty_ident = ty.to_rust_type_outer();
422
423 match (&fn_type, &ty.category) {
424 (&SavvyFnType::InitFunction, &SavvyInputTypeCategory::DllInfo) => {}
426
427 (&SavvyFnType::InitFunction, _) => {
428 return Some(Err(syn::Error::new_spanned(
429 ty.ty_orig,
430 "#[savvy_init] can be used only on a function that takes `*mut DllInfo`",
431 )));
432 }
433
434 (_, &SavvyInputTypeCategory::DllInfo) => {
435 return Some(Err(syn::Error::new_spanned(
436 ty.ty_orig,
437 "#[savvy] doesn't accept `*mut DllInfo`. Did you mean #[savvy_init]?",
438 )));
439 }
440
441 (_, &SavvyInputTypeCategory::Sexp) => {
442 if ty.optional {
443 stmts_additional.push(parse_quote! { let #pat = savvy::Sexp(#pat); });
444 stmts_additional.push(parse_quote! {
445 let #pat = if #pat.is_null() {
446 None
447 } else {
448 Some(#pat)
449 };
450 })
451 } else {
452 stmts_additional.push(parse_quote! {
453 let #pat = savvy::Sexp(#pat);
454 });
455 }
456 }
457
458 (_, _) => {
459 let arg_lit = syn::LitStr::new(&pat.unraw().to_string(), Span::call_site());
460 if ty.optional {
461 stmts_additional.push(parse_quote! { let #pat = savvy::Sexp(#pat); });
462 stmts_additional.push(parse_quote! {
463 let #pat = if #pat.is_null() {
464 None
465 } else {
466 Some(<#ty_ident>::try_from(#pat).map_err(|e| e.with_arg_name(#arg_lit))?)
467 };
468 })
469 } else {
470 stmts_additional.push(parse_quote! {
471 let #pat = <#ty_ident>::try_from(savvy::Sexp(#pat)).map_err(|e| e.with_arg_name(#arg_lit))?;
472 });
473 }
474 }
475 }
476
477 Some(Ok(SavvyFnArg { pat, ty }))
478 }
479 syn::FnArg::Receiver(syn::Receiver { .. }) => None,
481 })
482 .collect::<syn::Result<Vec<SavvyFnArg>>>()?;
483
484 let mut args_after_optional = args_new.iter().skip_while(|x| !x.is_optional());
486 if args_after_optional.any(|x| !x.is_optional()) {
487 return Err(syn::Error::new_spanned(
488 sig.inputs.clone(),
489 "optional args can be placed only after mandatory args",
490 ));
491 }
492
493 let is_init_fn = args_new
495 .iter()
496 .any(|x| matches!(x.ty.category, SavvyInputTypeCategory::DllInfo));
497 if is_init_fn && args_new.len() > 1 {
498 return Err(syn::Error::new_spanned(
499 sig,
500 "Initialization function can accept `*mut DllInfo` only",
501 ));
502 }
503
504 let fn_type = if is_init_fn {
505 SavvyFnType::InitFunction
506 } else {
507 fn_type
508 };
509
510 Ok(Self {
511 docs,
512 attrs,
513 fn_name,
514 fn_type,
515 args: args_new,
516 return_type: get_savvy_return_type(&sig.output, self_ty)?,
517 stmts_additional,
518 })
519 }
520}
521
522fn self_ty_to_actual_ty(self_ty: Option<&syn::Type>) -> Option<syn::Ident> {
523 if let Some(syn::Type::Path(type_path)) = self_ty {
524 Some(type_path.path.segments.last().unwrap().ident.clone())
525 } else {
526 None
527 }
528}
529
530fn get_savvy_return_type(
537 return_type: &syn::ReturnType,
538 self_ty: Option<&syn::Type>,
539) -> syn::Result<SavvyFnReturnType> {
540 match return_type {
541 syn::ReturnType::Default => Err(syn::Error::new_spanned(
542 return_type.clone(),
543 "function must have return type",
544 )),
545
546 syn::ReturnType::Type(_, ty) => {
547 let e = Err(syn::Error::new_spanned(
548 return_type.clone(),
549 "the return type must be savvy::Result<T> or savvy::Result<()>",
550 ));
551
552 let path_args = match ty.as_ref() {
555 syn::Type::Path(type_path) => {
556 if !is_type_path_savvy_or_no_qualifier(type_path) {
558 return e;
559 }
560
561 let last_path_seg = type_path.path.segments.last().unwrap();
562 match (
563 last_path_seg.ident.to_string().as_str(),
564 self_ty_to_actual_ty(self_ty),
565 ) {
566 ("Result", _) => {}
568 (ret_ty_str, Some(ty_actual)) => {
570 if ret_ty_str != "Self" && ty_actual != ret_ty_str {
571 return e;
572 } else {
573 return Ok(SavvyFnReturnType::UserDefinedStruct(
574 UserDefinedStructReturnType {
575 ty: ty_actual,
576 return_type: parse_quote!(-> savvy::Result<#self_ty>),
577 wrapped_with_result: false,
578 },
579 ));
580 }
581 }
582 _ => {
583 return e;
584 }
585 }
586 &last_path_seg.arguments
587 }
588 _ => return e,
589 };
590
591 if let syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
594 args,
595 ..
596 }) = path_args
597 {
598 if args.len() != 1 {
599 return e;
600 }
601
602 if let syn::GenericArgument::Type(ty) = &args.first().unwrap() {
603 match ty {
604 syn::Type::Tuple(type_tuple) => {
605 if type_tuple.elems.is_empty() {
606 return Ok(SavvyFnReturnType::Unit(return_type.clone()));
607 }
608 }
609
610 syn::Type::Path(type_path) => {
611 let last_ident = &type_path.path.segments.last().unwrap().ident;
612 match last_ident.to_string().as_str() {
613 "Sexp" => return Ok(SavvyFnReturnType::Sexp(return_type.clone())),
614
615 "Self" => {
617 if let Some(ty_actual) = self_ty_to_actual_ty(self_ty) {
618 return Ok(SavvyFnReturnType::UserDefinedStruct(
619 UserDefinedStructReturnType {
620 ty: ty_actual,
621 return_type: parse_quote!(-> savvy::Result<#self_ty>),
622 wrapped_with_result: true,
623 },
624 ));
625 }
626 }
627
628 wrong_ty @ ("String" | "i32" | "usize" | "f64" | "bool") => {
630 let msg = format!(
631"Return type must be either (), savvy::Sexp, or a user-defined type.
632You can use .try_into() to convert {wrong_ty} to savvy::Sexp."
633 );
634 return Err(syn::Error::new_spanned(type_path, msg));
635 }
636
637 _ => {
639 return Ok(SavvyFnReturnType::UserDefinedStruct(
640 UserDefinedStructReturnType {
641 ty: last_ident.clone(),
642 return_type: return_type.clone(),
643 wrapped_with_result: true,
644 },
645 ))
646 }
647 }
648 }
649
650 _ => {}
651 }
652 }
653 }
654
655 e
656 }
657 }
658}
659
660fn is_type_path_savvy_or_no_qualifier(type_path: &syn::TypePath) -> bool {
662 if type_path.qself.is_some() || type_path.path.leading_colon.is_some() {
663 return false;
664 }
665
666 match type_path.path.segments.len() {
667 1 => true,
668 2 => {
669 let first_path_seg = type_path.path.segments.first().unwrap();
670 first_path_seg.arguments.is_none() && &first_path_seg.ident.to_string() == "savvy"
671 }
672 _ => false,
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// `Result<Sexp>` is detected with and without the `savvy::` qualifier.
    #[test]
    fn test_detect_return_type_sexp() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<Sexp>),
            parse_quote!(-> savvy::Result<Sexp>),
            parse_quote!(-> savvy::Result<savvy::Sexp>),
        ];

        for rt in ok_cases {
            let srt = get_savvy_return_type(rt, None);
            assert!(srt.is_ok());
            assert!(matches!(srt.unwrap(), SavvyFnReturnType::Sexp(_)));
        }
    }

    /// `Result<()>` is detected as the unit return type.
    #[test]
    fn test_detect_return_type_unit() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<()>),
            parse_quote!(-> savvy::Result<()>),
        ];

        for rt in ok_cases {
            let srt = get_savvy_return_type(rt, None);
            assert!(srt.is_ok());
            assert!(matches!(srt.unwrap(), SavvyFnReturnType::Unit(_)));
        }
    }

    /// `Result<T>` with an unknown `T` is treated as a user-defined struct.
    #[test]
    fn test_detect_return_type_struct() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Result<Foo>),
            parse_quote!(-> savvy::Result<Foo>),
        ];

        for rt in ok_cases {
            let srt = get_savvy_return_type(rt, None);
            assert!(srt.is_ok());
            assert!(matches!(
                srt.unwrap(),
                SavvyFnReturnType::UserDefinedStruct(_)
            ));
        }
    }

    /// `Self` resolves to the implementing type, wrapped in `Result` or not.
    #[test]
    fn test_detect_return_type_self() {
        let ok_cases: &[syn::ReturnType] = &[
            parse_quote!(-> Self),
            parse_quote!(-> Result<Self>),
            parse_quote!(-> savvy::Result<Self>),
        ];
        let self_ty: syn::Type = parse_quote!(Foo);

        for (i, rt) in ok_cases.iter().enumerate() {
            let srt = get_savvy_return_type(rt, Some(&self_ty));
            assert!(srt.is_ok());
            if let SavvyFnReturnType::UserDefinedStruct(uds) = srt.unwrap() {
                assert_eq!(uds.ty.to_string().as_str(), "Foo");
                assert_eq!(uds.return_type, parse_quote!(-> savvy::Result<Foo>));
                // Only the first case (`-> Self`) is not wrapped with `Result`.
                assert_eq!(uds.wrapped_with_result, i != 0);
            } else {
                panic!("Unexpected SavvyFnReturnType");
            }
        }
    }

    /// Unsupported return types are rejected.
    #[test]
    fn test_detect_return_type_fail() {
        let err_cases: &[syn::ReturnType] = &[
            // A bare type outside of an `impl` block.
            parse_quote!(-> Foo),
            // A non-empty tuple.
            parse_quote!(-> savvy::Result<(T, T)>),
            // A qualifier other than `savvy`.
            parse_quote!(-> foo::Result<Sexp>),
            // A leading `::`.
            parse_quote!(-> ::savvy::Result<Sexp>),
            // More than one generic argument.
            parse_quote!(-> Result<Sexp, String>),
            // Types that must be converted to `Sexp` first.
            parse_quote!(-> savvy::Result<i32>),
            parse_quote!(-> Result<String>),
            // `Self` without an implementing type.
            parse_quote!(-> Result<Self>),
            // No return type at all.
            parse_quote!(),
        ];

        for rt in err_cases {
            assert!(get_savvy_return_type(rt, None).is_err())
        }
    }
}