symbols/
lib.rs

1#![deny(warnings)]
2#![deny(missing_docs)]
3
//! # Symbols
//!
//! This is a utility to build a proc-macro that connects to a database, retrieves data from a given table and populates an enum's variants with primary key values.  
//! It also generates a method for every non-primary-key field and, when there are multiple primary keys, a constructor for every possible subset of primary keys.
8
9use std::{collections::HashMap, env, fs, future::Future, io};
10
11use heck::{ToSnakeCase, ToUpperCamelCase};
12
13use itertools::Itertools;
14
15use proc_macro2::{Ident, Literal, Span, TokenStream};
16
17use quote::quote;
18
19use sea_orm::{
20    DatabaseConnection, EntityName, EntityTrait, Iterable, ModelTrait, PrimaryKeyToColumn, QueryFilter, Value,
21};
22
23use serde::{de::DeserializeOwned, Serialize};
24
25pub use symbols_models::EntityFilter;
26
27use syn::{punctuated::Punctuated, token::Comma, Fields, ItemEnum, Lit, LitBool, Meta, NestedMeta, Variant};
28
29use tracing::{error, info};
30
31/// Main function  
32/// Given a database model (via generics), an enum item, a list of arguments and an async function to retrieve a database connection
33/// it populates the enum using primary key(s) values.  
34/// Only string-typed primary keys are supported.
35///
36/// When a single primary key is present, it simply generate an as_str method and a TryFrom<&str> implementation.  
37/// When multiple primary keys are present, it generates a costructor for every possible subset of primary keys.
38///
39/// For every non-primary key field of a supported type, it generates a const method to retrieve it.
40///
41/// Replacements can be done on every string-typed field, even primary keys, and are done using annotated parameters.  
42/// Two type of replacements are supported:
43/// * basic: written in the form #[macro(field = "enum")] or #[macro(field(type = "enum"))], where we are telling to replace string values from `field` with variants from enum `enum`, variant names will be the CamelCase version of field value.
44/// * advanced: written in the form #[macro(field(type = "bar", fn = "foo"))], where we are telling to replace string values from `field` with a call to method `foo` from struct/enum `bar`, method output is expected to be of type `bar`.
45pub async fn symbols<M, F, Fut>(item: &mut ItemEnum, args: &[NestedMeta], get_conn: F) -> syn::Result<TokenStream>
46where
47    M: EntityTrait + EntityFilter + Default,
48    <M as EntityTrait>::Model: Serialize + DeserializeOwned,
49    <M as EntityTrait>::Column: PartialEq,
50    F: Fn() -> Fut,
51    Fut: Future<Output = syn::Result<DatabaseConnection>>,
52{
53    let name = &item.ident;
54    let primary_keys = <M as EntityTrait>::PrimaryKey::iter().map(|k| k.into_column()).collect::<Vec<_>>();
55
56    let mut constructors = HashMap::new();
57    let mut methods = HashMap::new();
58
59    let data = get_data::<M, _, _>(get_conn).await?;
60
61    data.iter().try_for_each(|v| {
62        let mut key_s = vec![];
63
64        // scan primary keys
65        for k in &primary_keys {
66            let val = v.get(*k);
67            // only string values are accepted
68            if let Value::String(Some(s)) = val {
69                key_s.push(s.to_upper_camel_case());
70
71                // if we have a single primary key, create a method as_str and a counter-trait-impl TryFrom<&str>
72                if primary_keys.len() == 1 {
73                    let key = Ident::new(&s.to_upper_camel_case(), Span::call_site());
74                    let v = Literal::string(s.as_str());
75
76                    let (_, method, _) = methods
77                        .entry(String::from("as_str"))
78                        .or_insert_with(|| (quote! { &'static str }, Punctuated::<_, Comma>::new(), false));
79                    method.push(quote! {
80                        #name::#key => #v
81                    });
82
83                    let (_, method, _) = methods
84                        .entry(String::from("try_from"))
85                        .or_insert_with(|| (quote! { () }, Punctuated::<_, Comma>::new(), false));
86                    method.push(quote! {
87                        #v => Ok(#name::#key)
88                    });
89                }
90            } else {
91                return Err(syn::Error::new(Span::call_site(), format!("Unrecognized value type {val:?}")));
92            }
93        }
94        // push primary keys into enum variants
95        let key_ident = Ident::new(&key_s.join("_"), Span::call_site());
96        item.variants.push(Variant {
97            attrs: vec![],
98            ident: key_ident.clone(),
99            fields: Fields::Unit,
100            discriminant: None,
101        });
102        // generate constructors for every combination of primary keys
103        if primary_keys.len() > 1 {
104            for n in 1..=primary_keys.len() {
105                for combo in primary_keys.iter().enumerate().combinations(n) {
106                    let cols = combo.iter().map(|(_, col)| **col).collect::<Vec<_>>();
107                    let method = combo
108                        .iter()
109                        .map(|(_, col)| format!("{col:?}").to_snake_case())
110                        .collect::<Vec<_>>()
111                        .join("_and_");
112                    let key = combo.iter().map(|(index, _)| key_s[*index].clone()).collect::<Vec<_>>();
113                    let (_, method) = constructors.entry(method).or_insert_with(|| (cols, HashMap::new()));
114                    let (_, idents) =
115                        method.entry(key.join("_")).or_insert_with(|| (key, Punctuated::<_, Comma>::new()));
116                    idents.push(quote! { #name::#key_ident });
117                }
118            }
119        }
120
121        // create a method for every non-primary_key column
122        for col in <M as EntityTrait>::Column::iter() {
123            let replace = get_replacement::<M>(col, args);
124
125            // skip self-describing methods (would be an as_str clone)
126            if primary_keys.len() == 1 && primary_keys.contains(&col) && replace.is_none() {
127                continue;
128            }
129
130            // keep only managed data types
131            let (t, value) = match v.get(col) {
132                Value::Bool(b) => (
133                    quote! { bool },
134                    b.map(|b| {
135                        let v = LitBool::new(b, Span::call_site());
136                        quote! { #v }
137                    }),
138                ),
139                Value::TinyInt(n) => (
140                    quote! { i8 },
141                    n.map(|n| {
142                        let v = Literal::i8_unsuffixed(n);
143                        quote! { #v }
144                    }),
145                ),
146                Value::SmallInt(n) => (
147                    quote! { i16 },
148                    n.map(|n| {
149                        let v = Literal::i16_unsuffixed(n);
150                        quote! { #v }
151                    }),
152                ),
153                Value::Int(n) => (
154                    quote! { i32 },
155                    n.map(|n| {
156                        let v = Literal::i32_unsuffixed(n);
157                        quote! { #v }
158                    }),
159                ),
160                Value::BigInt(n) => (
161                    quote! { i64 },
162                    n.map(|n| {
163                        let v = Literal::i64_unsuffixed(n);
164                        quote! { #v }
165                    }),
166                ),
167                Value::TinyUnsigned(n) => (
168                    quote! { u8 },
169                    n.map(|n| {
170                        let v = Literal::u8_unsuffixed(n);
171                        quote! { #v }
172                    }),
173                ),
174                Value::SmallUnsigned(n) => (
175                    quote! { u16 },
176                    n.map(|n| {
177                        let v = Literal::u16_unsuffixed(n);
178                        quote! { #v }
179                    }),
180                ),
181                Value::Unsigned(n) => (
182                    quote! { u32 },
183                    n.map(|n| {
184                        let v = Literal::u32_unsuffixed(n);
185                        quote! { #v }
186                    }),
187                ),
188                Value::BigUnsigned(n) => (
189                    quote! { u64 },
190                    n.map(|n| {
191                        let v = Literal::u64_unsuffixed(n);
192                        quote! { #v }
193                    }),
194                ),
195                Value::Float(n) => (
196                    quote! { f32 },
197                    n.map(|n| {
198                        let v = Literal::f32_unsuffixed(n);
199                        quote! { #v }
200                    }),
201                ),
202                Value::Double(n) => (
203                    quote! { f64 },
204                    n.map(|n| {
205                        let v = Literal::f64_unsuffixed(n);
206                        quote! { #v }
207                    }),
208                ),
209                Value::String(s) => match replace {
210                    Some(Replacement::Type(r)) => (
211                        r.clone(),
212                        s.map(|s| {
213                            let ident = Ident::new(&s.to_upper_camel_case(), Span::call_site());
214                            quote! { #r::#ident }
215                        }),
216                    ),
217                    Some(Replacement::Fn(f, Some(r))) => (
218                        r.clone(),
219                        s.map(|s| {
220                            let v = Literal::string(s.as_str());
221                            quote! { #r::#f(#v) }
222                        }),
223                    ),
224                    Some(Replacement::Fn(_, None)) => {
225                        // teoretically we could accept only a function, but we won't know the return type
226                        return Err(syn::Error::new(
227                            Span::call_site(),
228                            format!("Missing parameter type for field {col:?}"),
229                        ));
230                    }
231                    _ => (
232                        quote! { &'static str },
233                        s.map(|s| {
234                            let v = Literal::string(s.as_str());
235                            quote! { #v }
236                        }),
237                    ),
238                },
239                // disable ChronoDateTime for now, it would only produce methods for created_at and updated_at fields
240                // Value::ChronoDateTime(dt) => (quote! { chrono::NaiveDateTime }, Lit::Verbatim(Literal)),
241                _ => continue,
242            };
243            let (_, method, option) =
244                methods.entry(format!("{col:?}")).or_insert_with(|| (t, Punctuated::<_, Comma>::new(), false));
245            if let Some(v) = value {
246                method.push(quote! {
247                    #name::#key_ident => #v
248                });
249            } else {
250                *option = true;
251            }
252        }
253
254        Ok(())
255    })?;
256
257    // decorate constructors
258    let constructors = constructors.into_iter().map(|(name, (cols, body))| {
259        let is_full = cols.len() == primary_keys.len();
260        let fn_name = Ident::new(&format!("get_by_{name}"), Span::call_site());
261        let signature = cols
262            .iter()
263            .map(|col| {
264                let field_name = Ident::new(&format!("{col:?}").to_snake_case(), Span::call_site());
265                match get_replacement::<M>(*col, args) {
266                    Some(Replacement::Type(r)) => quote! { #field_name: #r },
267                    _ => quote! { #field_name: &str },
268                }
269            })
270            .collect::<Punctuated<_, Comma>>();
271        let m = cols
272            .iter()
273            .map(|col| {
274                let field_name = Ident::new(&format!("{col:?}").to_snake_case(), Span::call_site());
275                quote! { #field_name }
276            })
277            .collect::<Punctuated<_, Comma>>();
278        let body = body
279            .iter()
280            .map(|(_, (values, array_body))| {
281                let args = cols
282                    .iter()
283                    .enumerate()
284                    .map(|(index, col)| match get_replacement::<M>(*col, args) {
285                        Some(Replacement::Type(r)) => {
286                            let ident = Ident::new(&values[index].to_upper_camel_case(), Span::call_site());
287                            quote! { #r::#ident }
288                        }
289                        _ => {
290                            let v = Literal::string(values[index].as_str());
291                            quote! { #v }
292                        }
293                    })
294                    .collect::<Punctuated<_, Comma>>();
295                if is_full {
296                    quote! {
297                        (#args,) => Some(#array_body)
298                    }
299                } else {
300                    quote! {
301                        (#args,) => &[#array_body]
302                    }
303                }
304            })
305            .collect::<Punctuated<_, Comma>>();
306        if is_full {
307            quote! {
308                pub const fn #fn_name(#signature) -> Option<Self> {
309                    match (#m,) {
310                        #body,
311                        _ => None,
312                    }
313                }
314            }
315        } else {
316            quote! {
317                pub const fn #fn_name(#signature) -> &'static [Self] {
318                    match (#m,) {
319                        #body,
320                        _ => &[],
321                    }
322                }
323            }
324        }
325    });
326
327    // separate try_from from other methods
328    let try_from = methods
329        .remove("try_from")
330        .map(|(_, matches, _)| {
331            quote! {
332                impl<'a> TryFrom<&'a str> for #name {
333                    type Error = String;
334                    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
335                        match s {
336                            #matches,
337                            _ => Err(format!("Unknown {} {}", stringify!(#name), s)),
338                        }
339                    }
340                }
341            }
342        })
343        .unwrap_or_default();
344
345    // decorate methods
346    let methods: TokenStream = methods
347        .into_iter()
348        .map(|(name, (t, matches, option))| {
349            let n = Ident::new(&name.to_snake_case(), Span::call_site());
350            if option {
351                if matches.is_empty() {
352                    quote! {
353                        pub const fn #n(&self) -> Option<#t> {
354                            None
355                        }
356                    }
357                } else {
358                    quote! {
359                        pub const fn #n(&self) -> Option<#t> {
360                            Some(match self {
361                                #matches,
362                                _ => return None,
363                            })
364                        }
365                    }
366                }
367            } else {
368                quote! {
369                    pub const fn #n(&self) -> #t {
370                        match self {
371                            #matches,
372                        }
373                    }
374                }
375            }
376        })
377        .chain(constructors)
378        .collect();
379
380    // output result
381    Ok(quote! {
382        #item
383
384        impl #name {
385            #methods
386        }
387
388        #try_from
389    })
390}
391
392/// Replacement types
393enum Replacement {
394    Type(TokenStream),
395    Fn(TokenStream, Option<TokenStream>),
396}
397
398/// Field replacement facility
399/// Searches between macro arguments
400fn get_replacement<M>(col: M::Column, args: &[NestedMeta]) -> Option<Replacement>
401where
402    M: EntityTrait,
403    M::Column: PartialEq,
404{
405    let col_name = format!("{col:?}");
406    let field_name = col_name.to_snake_case();
407    // search for replacements
408    args.iter().find_map(|arg| {
409        // simple #[macro(field = "enum")]
410        if let NestedMeta::Meta(Meta::NameValue(mv)) = arg {
411            if mv.path.is_ident(&col_name) || mv.path.is_ident(&field_name) {
412                if let Lit::Str(s) = &mv.lit {
413                    let ident = Ident::new(&s.value(), Span::call_site());
414                    return Some(Replacement::Type(quote! { #ident }));
415                }
416            }
417        }
418        // quite complex #[macro(field(type = "enum", fn = "foo"))]
419        if let NestedMeta::Meta(Meta::List(ml)) = arg {
420            if ml.path.is_ident(&col_name) || ml.path.is_ident(&field_name) {
421                return ml.nested.iter().fold(None, |mut acc, nested| {
422                    if let NestedMeta::Meta(Meta::NameValue(mv)) = nested {
423                        if let Lit::Str(s) = &mv.lit {
424                            let ident = Ident::new(&s.value(), Span::call_site());
425                            if mv.path.is_ident("type") {
426                                if let Some(Replacement::Fn(f, None)) = acc {
427                                    acc = Some(Replacement::Fn(f, Some(quote! { #ident })));
428                                } else {
429                                    acc = Some(Replacement::Type(quote! { #ident }));
430                                }
431                            } else if mv.path.is_ident("fn") {
432                                if let Some(Replacement::Type(t)) = acc {
433                                    acc = Some(Replacement::Fn(quote! { #ident }, Some(t)));
434                                } else {
435                                    acc = Some(Replacement::Fn(quote! { #ident }, None));
436                                }
437                            }
438                        }
439                    }
440                    acc
441                });
442            }
443        }
444        None
445    })
446}
447
448/// Data retrieve function with cache capabilities
449/// File access is sync to not have to depend on an async runtime
450async fn get_data<M, F, Fut>(get_conn: F) -> syn::Result<Vec<<M as EntityTrait>::Model>>
451where
452    M: EntityTrait + EntityFilter + Default,
453    <M as EntityTrait>::Model: Serialize + DeserializeOwned,
454    F: Fn() -> Fut,
455    Fut: Future<Output = syn::Result<DatabaseConnection>>,
456{
457    let instance = M::default();
458    let mut cache = env::temp_dir();
459    cache.push(EntityName::table_name(&instance));
460    cache.set_extension("cache");
461    if cache.exists() {
462        info!("Cache file {} exists, loading data from there", cache.display());
463
464        let file = fs::File::open(&cache)
465            .map_err(|e| syn::Error::new(Span::call_site(), format!("Error reading {}: {}", cache.display(), e)))?;
466
467        match bincode::deserialize_from(io::BufReader::new(file)) {
468            Ok(data) => return Ok(data),
469            Err(e) => error!("Error deserializing {}: {}", cache.display(), e),
470        }
471    } else {
472        info!("Cache file {} doesn't exists, creating", cache.display());
473    }
474
475    let conn = get_conn().await?;
476    let data = <M as EntityTrait>::find()
477        .filter(M::filter())
478        .all(&conn)
479        .await
480        .map_err(|e| syn::Error::new(Span::call_site(), e))?;
481    let buf = bincode::serialize(&data)
482        .map_err(|e| syn::Error::new(Span::call_site(), format!("Error serializing {}: {}", cache.display(), e)))?;
483    fs::write(&cache, buf)
484        .map_err(|e| syn::Error::new(Span::call_site(), format!("Error writing {}: {}", cache.display(), e)))?;
485    Ok(data)
486}