valkey_module_macros_internals/
lib.rs

1mod api_versions;
2
3use api_versions::{get_feature_flags, API_OLDEST_VERSION, API_VERSION_MAPPING};
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::token::{self, Paren, RArrow};
9use syn::Ident;
10use syn::ItemFn;
11use syn::{self, bracketed, parse_macro_input, ReturnType, Token, Type, TypeTuple};
12
13#[derive(Debug)]
14struct Args {
15    requested_apis: Vec<Ident>,
16    function: ItemFn,
17}
18
19impl Parse for Args {
20    fn parse(input: ParseStream) -> Result<Self> {
21        let content;
22        let _paren_token: token::Bracket = bracketed!(content in input);
23        let vars: Punctuated<Ident, Token![,]> = content.parse_terminated(Ident::parse)?;
24        input.parse::<Token![,]>()?;
25        let function: ItemFn = input.parse()?;
26        Ok(Args {
27            requested_apis: vars.into_iter().collect(),
28            function,
29        })
30    }
31}
32
33/// This proc macro allows specifying which RedisModuleAPI is required by some valekymodue-rs
34/// function. The macro finds, for a given set of RedisModuleAPI, what the minimal Valkey version is
35/// that contains all those APIs and decides whether or not the function might raise an [APIError].
36///
37/// In addition, for each RedisModuleAPI, the proc macro injects a code that extracts the actual
38/// API function pointer and raises an error or panics if the API is invalid.
39///
40/// # Panics
41///
42/// Panics when an API is not available and if the function doesn't return [`Result`]. If it does
43/// return a [`Result`], the panics are replaced with returning a [`Result::Err`].
44///
45/// # Examples
46///
47/// Creating a wrapper for the [`RedisModule_AddPostNotificationJob`]
48/// ```rust,no_run,ignore
49///    redismodule_api!(
50///         [RedisModule_AddPostNotificationJob],
51///         pub fn add_post_notification_job<F: Fn(&Context)>(&self, callback: F) -> Status {
52///             let callback = Box::into_raw(Box::new(callback));
53///             unsafe {
54///                 RedisModule_AddPostNotificationJob(
55///                     self.ctx,
56///                     Some(post_notification_job::<F>),
57///                     callback as *mut c_void,
58///                     Some(post_notification_job_free_callback::<F>),
59///                 )
60///             }
61///             .into()
62///         }
63///     );
64/// ```
65#[proc_macro]
66pub fn api(item: TokenStream) -> TokenStream {
67    let args = parse_macro_input!(item as Args);
68    let minimum_require_version =
69        args.requested_apis
70            .iter()
71            .fold(*API_OLDEST_VERSION, |min_api_version, item| {
72                // if we do not have a version mapping, we assume the API exists and return the minimum version.
73                let api_version = API_VERSION_MAPPING
74                    .get(&item.to_string())
75                    .map(|v| *v)
76                    .unwrap_or(*API_OLDEST_VERSION);
77                api_version.max(min_api_version)
78            });
79
80    let requested_apis = args.requested_apis;
81    let requested_apis_str: Vec<String> = requested_apis.iter().map(|e| e.to_string()).collect();
82
83    let original_func = args.function;
84    let original_func_attr = original_func.attrs;
85    let original_func_code = original_func.block;
86    let original_func_sig = original_func.sig;
87    let original_func_vis = original_func.vis;
88
89    let inner_return_return_type = match original_func_sig.output.clone() {
90        ReturnType::Default => Box::new(Type::Tuple(TypeTuple {
91            paren_token: Paren::default(),
92            elems: Punctuated::new(),
93        })),
94        ReturnType::Type(_, t) => t,
95    };
96    let new_return_return_type = Type::Path(
97        syn::parse(
98            quote!(
99                crate::apierror::APIResult<#inner_return_return_type>
100            )
101            .into(),
102        )
103        .unwrap(),
104    );
105
106    let mut new_func_sig = original_func_sig.clone();
107    new_func_sig.output = ReturnType::Type(RArrow::default(), Box::new(new_return_return_type));
108
109    let old_ver_func = quote!(
110        #(#original_func_attr)*
111        #original_func_vis #new_func_sig {
112            #(
113                #[allow(non_snake_case)]
114                let #requested_apis = unsafe{crate::raw::#requested_apis.ok_or(concat!(#requested_apis_str, " does not exists"))?};
115            )*
116            let __callback__ = move || -> #inner_return_return_type {
117                #original_func_code
118            };
119            Ok(__callback__())
120        }
121    );
122
123    let new_ver_func = quote!(
124        #(#original_func_attr)*
125        #original_func_vis #original_func_sig {
126            #(
127                #[allow(non_snake_case)]
128                let #requested_apis = unsafe{crate::raw::#requested_apis.unwrap()};
129            )*
130            let __callback__ = move || -> #inner_return_return_type {
131                #original_func_code
132            };
133            __callback__()
134        }
135    );
136
137    let (all_lower_features, all_upper_features) = get_feature_flags(minimum_require_version);
138
139    let gen = quote! {
140        cfg_if::cfg_if! {
141            if #[cfg(any(#(#all_lower_features, )*))] {
142                #old_ver_func
143            } else if #[cfg(any(#(#all_upper_features, )*))] {
144                #new_ver_func
145            } else {
146                compile_error!("min-redis-compatibility-version is not set correctly")
147            }
148        }
149    };
150    gen.into()
151}