Skip to main content

sqlite3_ext_macro/
lib.rs

1use convert_case::{Case, Casing};
2use ext_attr::*;
3use fn_attr::*;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote, ToTokens};
7use regex::Regex;
8use std::mem::replace;
9use syn::{punctuated::Punctuated, *};
10use vtab_attr::*;
11
12mod ext_attr;
13mod fn_attr;
14mod vtab_attr;
15
16mod kw {
17    syn::custom_keyword!(DirectOnly);
18    syn::custom_keyword!(EponymousModule);
19    syn::custom_keyword!(EponymousOnlyModule);
20    syn::custom_keyword!(FindFunctionVTab);
21    syn::custom_keyword!(Innocuous);
22    syn::custom_keyword!(RenameVTab);
23    syn::custom_keyword!(StandardModule);
24    syn::custom_keyword!(TransactionVTab);
25    syn::custom_keyword!(UpdateVTab);
26    syn::custom_keyword!(deterministic);
27    syn::custom_keyword!(export);
28    syn::custom_keyword!(n_args);
29    syn::custom_keyword!(persistent);
30    syn::custom_keyword!(risk_level);
31}
32
33/// Declare the primary extension entry point for the crate.
34///
35/// This is equivalent to [macro@sqlite3_ext_init], but it will automatically name the export
36/// according to the name of the crate (e.g. `sqlite3_myextension_init`).
37///
38/// # Examples
39///
40/// Specify a persistent extension:
41///
42/// ```no_run
43/// # use sqlite3_ext_macro::*;
44/// use sqlite3_ext::*;
45///
46/// #[sqlite3_ext_main(persistent)]
47/// fn init(db: &Connection) -> Result<()> {
48///     Ok(())
49/// }
50/// ```
51#[proc_macro_attribute]
52pub fn sqlite3_ext_main(attr: TokenStream, item: TokenStream) -> TokenStream {
53    let attr = proc_macro2::TokenStream::from(attr);
54    let item = parse_macro_input!(item as ItemFn);
55    let crate_name = std::env::var("CARGO_CRATE_NAME").unwrap();
56    let export_base = crate_name.to_lowercase();
57    let export_base = Regex::new("[^a-z]").unwrap().replace_all(&export_base, "");
58    let init_ident = format_ident!("sqlite3_{}_init", export_base);
59    let expanded = quote! {
60        #[::sqlite3_ext::sqlite3_ext_init(export = #init_ident, #attr)]
61        #item
62    };
63    TokenStream::from(expanded)
64}
65
66/// Declare the entry point to an extension.
67///
68/// This method generates an `extern "C"` function suitable for use by SQLite's loadable
69/// extension feature. An export name can optionally be provided. Consult [the SQLite
70/// documentation](https://www.sqlite.org/loadext.html#loading_an_extension) for information
71/// about naming the exported method, but generally you can use [macro@sqlite3_ext_main] to
72/// automatically name the export correctly.
73///
74/// If the persistent keyword is included in the attribute, the extension will be loaded
75/// permanently. See [the SQLite
76/// documentation](https://www.sqlite.org/loadext.html#persistent_loadable_extensions) for more
77/// information.
78///
79/// # Example
80///
81/// Specifying a nonstandard entry point name:
82///
83/// ```no_run
84/// # use sqlite3_ext_macro::*;
85/// use sqlite3_ext::*;
86///
87/// #[sqlite3_ext_init(export = nonstandard_entry_point, persistent)]
88/// fn init(db: &Connection) -> Result<()> {
89///     Ok(())
90/// }
91/// ```
92///
93/// This extension could be loaded from SQLite:
94///
95/// ```sql
96/// SELECT load_extension('path/to/extension', 'nonstandard_entry_point');
97/// ```
98///
99/// # Implementation
100///
101/// This macro renames the original Rust function and instead creates an
102/// `sqlite3_ext::Extension` object in its place. Because `Extension` dereferences to the
103/// original function, you generally won't notice this change. This behavior allows you to use
104/// the original identifier to pass the auto extension methods.
105#[proc_macro_attribute]
106pub fn sqlite3_ext_init(attr: TokenStream, item: TokenStream) -> TokenStream {
107    let directives =
108        parse_macro_input!(attr with Punctuated::<ExtAttr, Token![,]>::parse_terminated);
109    let mut export: Option<Ident> = None;
110    let mut persistent: Option<kw::persistent> = None;
111    for d in directives {
112        match d {
113            ExtAttr::Export(ExtAttrExport { value }) => {
114                if let Some(_) = export {
115                    return Error::new(value.span(), "export specified multiple times")
116                        .into_compile_error()
117                        .into();
118                } else {
119                    export = Some(value)
120                }
121            }
122            ExtAttr::Persistent(tok) => {
123                persistent = Some(tok);
124            }
125        }
126    }
127    let mut item = parse_macro_input!(item as ItemFn);
128    let extension_vis = replace(&mut item.vis, Visibility::Inherited);
129    let name = item.sig.ident.clone();
130    let load_result = match persistent {
131        None => quote!(::sqlite3_ext::ffi::SQLITE_OK),
132        Some(tok) => {
133            if let Some(_) = export {
134                // Persistent loadable extensions were added in SQLite 3.14.0. If
135                // we were to return SQLITE_OK_LOAD_PERSISTENT, then the load
136                // would fail. We want the load to complete: any API which
137                // requires persistent extensions would return an error, but
138                // ignored errors imply that the persistent loading requirement
139                // is optional.
140                quote!(::sqlite3_ext::sqlite3_match_version!(
141                    3_014_000 => ::sqlite3_ext::ffi::SQLITE_OK_LOAD_PERMANENTLY,
142                    _ => ::sqlite3_ext::ffi::SQLITE_OK,
143                ))
144            } else {
145                return Error::new(tok.span, "unexported extension cannot be persistent")
146                    .into_compile_error()
147                    .into();
148            }
149        }
150    };
151
152    let c_export = export.as_ref().map(|_| quote!(#[no_mangle] pub));
153    let c_name = match export {
154        None => format_ident!("{}_entry", item.sig.ident),
155        Some(x) => x,
156    };
157
158    let expanded = quote! {
159        #[allow(non_upper_case_globals)]
160        #extension_vis static #name: ::sqlite3_ext::Extension = {
161            #c_export
162            unsafe extern "C" fn #c_name(
163                db: *mut ::sqlite3_ext::ffi::sqlite3,
164                err_msg: *mut *mut ::std::os::raw::c_char,
165                api: *mut ::sqlite3_ext::ffi::sqlite3_api_routines,
166            ) -> ::std::os::raw::c_int {
167                if let Err(e) = ::sqlite3_ext::ffi::init_api_routines(api) {
168                    return ::sqlite3_ext::ffi::handle_error(e, err_msg);
169                }
170                match #name(::sqlite3_ext::Connection::from_ptr(db)) {
171                    Ok(_) => #load_result,
172                    Err(e) => ::sqlite3_ext::ffi::handle_error(e, err_msg),
173                }
174            }
175
176            #item
177
178            ::sqlite3_ext::Extension::new(#c_name, #name)
179        };
180    };
181    TokenStream::from(expanded)
182}
183
184/// Declare a virtual table module.
185///
186/// This attribute is intended to be applied to the struct which implements VTab and related
187/// traits. The first parameter to the attribute is the type of module to create, which is one
188/// of StandardModule, EponymousModule, EponymousOnlyModule. The subsequent parameters refer to
189/// traits in sqlite3_ext::vtab, and describe the functionality which the virtual table
190/// supports. See the corresponding structs and traits in sqlite3_ext::vtab for more details.
191///
192/// The resulting struct will have an associated method `module` which returns the concrete
193/// type of module specified in the first parameter, or a Result containing it.
194///
195/// # Examples
196///
197/// Declare a table-valued function:
198///
199/// ```no_run
200/// # use sqlite3_ext_macro::*;
201/// use sqlite3_ext::*;
202///
203/// #[sqlite3_ext_vtab(EponymousModule)]
204/// struct MyTableFunction {}
205/// # sqlite3_ext_doctest_impl!(MyTableFunction);
206///
207/// #[sqlite3_ext_main]
208/// fn init(db: &Connection) -> Result<()> {
209///     db.create_module("my_table_function", MyTableFunction::module(), ())?;
210///     Ok(())
211/// }
212/// ```
213///
214/// Declare a standard virtual table that supports updates:
215///
216/// ```no_run
217/// # use sqlite3_ext_macro::*;
218/// use sqlite3_ext::*;
219///
220/// #[sqlite3_ext_vtab(StandardModule, UpdateVTab)]
221/// struct MyTable {}
222/// # sqlite3_ext_doctest_impl!(MyTable);
223///
224/// #[sqlite3_ext_main]
225/// fn init(db: &Connection) -> Result<()> {
226///     db.create_module("my_table", MyTable::module(), ())?;
227///     Ok(())
228/// }
229/// ```
230///
231/// Declare an eponymous-only table that supports updates:
232///
233/// ```no_run
234/// # use sqlite3_ext_macro::*;
235/// use sqlite3_ext::*;
236///
237/// #[sqlite3_ext_vtab(EponymousOnlyModule, UpdateVTab)]
238/// struct MyTable {}
239/// # sqlite3_ext_doctest_impl!(MyTable);
240///
241/// #[sqlite3_ext_main]
242/// fn init(db: &Connection) -> Result<()> {
243///     db.create_module("my_table", MyTable::module()?, ())?;
244///     Ok(())
245/// }
246/// ```
247#[proc_macro_attribute]
248pub fn sqlite3_ext_vtab(attr: TokenStream, item: TokenStream) -> TokenStream {
249    let attr = match parse::<VTabAttr>(attr) {
250        Ok(syntax_tree) => syntax_tree,
251        Err(err) => {
252            let mut ret = TokenStream::from(err.to_compile_error());
253            ret.extend(item);
254            return ret;
255        }
256    };
257    let item = parse_macro_input!(item as ItemStruct);
258    let struct_generics = &item.generics;
259    let impl_arguments = if struct_generics.params.is_empty() {
260        None
261    } else {
262        Some(AngleBracketedGenericArguments {
263            colon2_token: None,
264            lt_token: token::Lt::default(),
265            args: struct_generics
266                .params
267                .iter()
268                .map(|gp| match gp {
269                    GenericParam::Type(t) => {
270                        GenericArgument::Type(Type::Verbatim(t.ident.to_token_stream()))
271                    }
272                    GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
273                    GenericParam::Const(c) => {
274                        GenericArgument::Const(Expr::Verbatim(c.ident.to_token_stream()))
275                    }
276                })
277                .collect(),
278            gt_token: token::Gt::default(),
279        })
280    };
281    let struct_generic_def = {
282        let mut segments = Punctuated::default();
283        segments.push_value(PathSegment {
284            ident: item.ident.clone(),
285            arguments: impl_arguments
286                .map(PathArguments::AngleBracketed)
287                .unwrap_or(PathArguments::None),
288        });
289        Type::Path(TypePath {
290            qself: None,
291            path: Path {
292                leading_colon: None,
293                segments,
294            },
295        })
296    };
297    let lifetime = quote!('sqlite3_ext_vtab);
298    let lifetime_bounds: Punctuated<_, Token![+]> = struct_generics
299        .params
300        .iter()
301        .filter_map(|gp| {
302            if let GenericParam::Lifetime(LifetimeDef { lifetime, .. }) = gp {
303                Some(lifetime)
304            } else {
305                None
306            }
307        })
308        .collect();
309    let lifetime_bounds = if lifetime_bounds.is_empty() {
310        quote!()
311    } else {
312        quote!(: #lifetime_bounds)
313    };
314    let base = match attr.base {
315        VTabBase::Standard(_) => quote!(::sqlite3_ext::vtab::StandardModule),
316        VTabBase::Eponymous(_) => quote!(::sqlite3_ext::vtab::EponymousModule),
317        VTabBase::EponymousOnly(_) => quote!(::sqlite3_ext::vtab::EponymousOnlyModule),
318    };
319    let mut expr = quote!(#base::<Self>::new());
320    let ret = if let VTabBase::EponymousOnly(_) = attr.base {
321        expr.extend(quote!(?));
322        quote!(::sqlite3_ext::Result<#base<#lifetime, Self>>)
323    } else {
324        quote!(#base<#lifetime, Self>)
325    };
326    for t in attr.additional {
327        match t {
328            VTabTrait::UpdateVTab(_) => expr.extend(quote!(.with_update())),
329            VTabTrait::TransactionVTab(_) => expr.extend(quote!(.with_transactions())),
330            VTabTrait::FindFunctionVTab(_) => expr.extend(quote!(.with_find_function())),
331            VTabTrait::RenameVTab(_) => expr.extend(quote!(.with_rename())),
332        }
333    }
334    if let VTabBase::EponymousOnly(_) = attr.base {
335        expr = quote!(Ok(#expr));
336    };
337    let expanded = quote! {
338        #item
339
340        #[automatically_derived]
341        impl #struct_generics #struct_generic_def {
342            /// Return the [Module](::sqlite3_ext::vtab::Module) associated with
343            /// this virtual table.
344            pub fn module<#lifetime #lifetime_bounds> () -> #ret {
345                use ::sqlite3_ext::vtab::*;
346                #expr
347            }
348        }
349    };
350    TokenStream::from(expanded)
351}
352
353/// Create a FunctionOptions for an application-defined function.
354///
355/// This macro declares a FunctionOptions constant with the provided values. The constant will
356/// take on a name based on the function, so for example applying this attribute to a function
357/// named "count_horses" or a trait named "CountHorses" will create a constant named
358/// "COUNT_HORSES_OPTS".
359///
360/// # Syntax
361///
362/// Arguments passed to the macro are comma-separated. The following are supported:
363///
364/// - `n_args=N` corresponds to set_n_args.
365/// - `risk_level=X` corresponds to set_risk_level.
366/// - `deterministic` corresponds to set_desterministic with true.
367///
368/// # Example
369///
370/// ```no_run
371/// use sqlite3_ext::{function::*, *};
372///
373/// #[sqlite3_ext_fn(n_args=0, risk_level=Innocuous)]
374/// pub fn random_number(ctx: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
375///     ctx.set_result(4) // chosen by fair dice roll.
376/// }
377///
378/// pub fn init(db: &Connection) -> Result<()> {
379///     db.create_scalar_function("random_number", &RANDOM_NUMBER_OPTS, random_number)
380/// }
381/// ```
382#[proc_macro_attribute]
383pub fn sqlite3_ext_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
384    let directives =
385        parse_macro_input!(attr with Punctuated::<FnAttr, Token![,]>::parse_terminated);
386    let item = parse_macro_input!(item as Item);
387    let (ident, vis) = match &item {
388        Item::Fn(item) => (&item.sig.ident, &item.vis),
389        Item::Struct(item) => (&item.ident, &item.vis),
390        _ => {
391            return TokenStream::from(
392                Error::new(Span::call_site(), "only applies to fn or struct").into_compile_error(),
393            )
394        }
395    };
396    let opts_name = Ident::new(
397        &format!("{ident}_opts").to_case(Case::UpperSnake),
398        Span::call_site(),
399    );
400    let mut opts = quote! {
401        #[automatically_derived]
402        #vis const #opts_name: ::sqlite3_ext::function::FunctionOptions = ::sqlite3_ext::function::FunctionOptions::default()
403    };
404    for d in directives {
405        match d {
406            FnAttr::NumArgs(x) => opts.extend(quote!(.set_n_args(#x))),
407            FnAttr::RiskLevel(FnAttrRiskLevel::Innocuous) => {
408                opts.extend(quote!(.set_risk_level(::sqlite3_ext::RiskLevel::Innocuous)))
409            }
410            FnAttr::RiskLevel(FnAttrRiskLevel::DirectOnly) => {
411                opts.extend(quote!(.set_risk_level(::sqlite3_ext::RiskLevel::DirectOnly)))
412            }
413            FnAttr::Deterministic => opts.extend(quote!(.set_deterministic(true))),
414        }
415    }
416    let expanded = quote! {
417        #opts;
418        #item
419    };
420    TokenStream::from(expanded)
421}
422
423#[doc(hidden)]
424#[proc_macro]
425pub fn sqlite3_ext_doctest_impl(item: TokenStream) -> TokenStream {
426    let item = parse_macro_input!(item as Type);
427    let expanded = quote! {
428        impl<'vtab> ::sqlite3_ext::vtab::VTab<'vtab> for #item {
429            type Aux = ();
430            type Cursor = Cursor;
431
432            fn connect(_: &::sqlite3_ext::vtab::VTabConnection, _: &Self::Aux, _: &[&str]) -> std::result::Result<(String, Self), ::sqlite3_ext::Error> { todo!() }
433            fn best_index(&self, _: &mut ::sqlite3_ext::vtab::IndexInfo) -> std::result::Result<(), ::sqlite3_ext::Error> { todo!() }
434            fn open(&self) -> std::result::Result<Self::Cursor, ::sqlite3_ext::Error> { todo!() }
435        }
436
437        impl<'vtab> ::sqlite3_ext::vtab::CreateVTab<'vtab> for #item {
438            fn create(_: &::sqlite3_ext::vtab::VTabConnection, _: &Self::Aux, _: &[&str]) -> std::result::Result<(String, Self), ::sqlite3_ext::Error> { todo!() }
439            fn destroy(self) -> ::sqlite3_ext::vtab::DisconnectResult<Self> { todo!() }
440        }
441
442        impl<'vtab> ::sqlite3_ext::vtab::UpdateVTab<'vtab> for #item {
443            fn update(&self, _: &mut ::sqlite3_ext::vtab::ChangeInfo) -> ::sqlite3_ext::Result<i64> { todo!() }
444        }
445
446        struct Cursor {}
447        impl ::sqlite3_ext::vtab::VTabCursor for Cursor {
448            fn filter(&mut self, _: i32, _: Option<&str>, _: &mut [&mut ValueRef]) -> std::result::Result<(), ::sqlite3_ext::Error> { todo!() }
449            fn next(&mut self) -> std::result::Result<(), ::sqlite3_ext::Error> { todo!() }
450            fn eof(&mut self) -> bool { todo!() }
451            fn column(&mut self, _: usize, _: &::sqlite3_ext::vtab::ColumnContext) -> Result<()> { todo!() }
452            fn rowid(&mut self) -> std::result::Result<i64, ::sqlite3_ext::Error> { todo!() }
453        }
454    };
455    TokenStream::from(expanded)
456}