1#![deny(warnings)]
2#![deny(missing_docs)]
3
4use std::{collections::HashMap, env, fs, future::Future, io};
10
11use heck::{ToSnakeCase, ToUpperCamelCase};
12
13use itertools::Itertools;
14
15use proc_macro2::{Ident, Literal, Span, TokenStream};
16
17use quote::quote;
18
19use sea_orm::{
20 DatabaseConnection, EntityName, EntityTrait, Iterable, ModelTrait, PrimaryKeyToColumn, QueryFilter, Value,
21};
22
23use serde::{de::DeserializeOwned, Serialize};
24
25pub use symbols_models::EntityFilter;
26
27use syn::{punctuated::Punctuated, token::Comma, Fields, ItemEnum, Lit, LitBool, Meta, NestedMeta, Variant};
28
29use tracing::{error, info};
30
31pub async fn symbols<M, F, Fut>(item: &mut ItemEnum, args: &[NestedMeta], get_conn: F) -> syn::Result<TokenStream>
46where
47 M: EntityTrait + EntityFilter + Default,
48 <M as EntityTrait>::Model: Serialize + DeserializeOwned,
49 <M as EntityTrait>::Column: PartialEq,
50 F: Fn() -> Fut,
51 Fut: Future<Output = syn::Result<DatabaseConnection>>,
52{
53 let name = &item.ident;
54 let primary_keys = <M as EntityTrait>::PrimaryKey::iter().map(|k| k.into_column()).collect::<Vec<_>>();
55
56 let mut constructors = HashMap::new();
57 let mut methods = HashMap::new();
58
59 let data = get_data::<M, _, _>(get_conn).await?;
60
61 data.iter().try_for_each(|v| {
62 let mut key_s = vec![];
63
64 for k in &primary_keys {
66 let val = v.get(*k);
67 if let Value::String(Some(s)) = val {
69 key_s.push(s.to_upper_camel_case());
70
71 if primary_keys.len() == 1 {
73 let key = Ident::new(&s.to_upper_camel_case(), Span::call_site());
74 let v = Literal::string(s.as_str());
75
76 let (_, method, _) = methods
77 .entry(String::from("as_str"))
78 .or_insert_with(|| (quote! { &'static str }, Punctuated::<_, Comma>::new(), false));
79 method.push(quote! {
80 #name::#key => #v
81 });
82
83 let (_, method, _) = methods
84 .entry(String::from("try_from"))
85 .or_insert_with(|| (quote! { () }, Punctuated::<_, Comma>::new(), false));
86 method.push(quote! {
87 #v => Ok(#name::#key)
88 });
89 }
90 } else {
91 return Err(syn::Error::new(Span::call_site(), format!("Unrecognized value type {val:?}")));
92 }
93 }
94 let key_ident = Ident::new(&key_s.join("_"), Span::call_site());
96 item.variants.push(Variant {
97 attrs: vec![],
98 ident: key_ident.clone(),
99 fields: Fields::Unit,
100 discriminant: None,
101 });
102 if primary_keys.len() > 1 {
104 for n in 1..=primary_keys.len() {
105 for combo in primary_keys.iter().enumerate().combinations(n) {
106 let cols = combo.iter().map(|(_, col)| **col).collect::<Vec<_>>();
107 let method = combo
108 .iter()
109 .map(|(_, col)| format!("{col:?}").to_snake_case())
110 .collect::<Vec<_>>()
111 .join("_and_");
112 let key = combo.iter().map(|(index, _)| key_s[*index].clone()).collect::<Vec<_>>();
113 let (_, method) = constructors.entry(method).or_insert_with(|| (cols, HashMap::new()));
114 let (_, idents) =
115 method.entry(key.join("_")).or_insert_with(|| (key, Punctuated::<_, Comma>::new()));
116 idents.push(quote! { #name::#key_ident });
117 }
118 }
119 }
120
121 for col in <M as EntityTrait>::Column::iter() {
123 let replace = get_replacement::<M>(col, args);
124
125 if primary_keys.len() == 1 && primary_keys.contains(&col) && replace.is_none() {
127 continue;
128 }
129
130 let (t, value) = match v.get(col) {
132 Value::Bool(b) => (
133 quote! { bool },
134 b.map(|b| {
135 let v = LitBool::new(b, Span::call_site());
136 quote! { #v }
137 }),
138 ),
139 Value::TinyInt(n) => (
140 quote! { i8 },
141 n.map(|n| {
142 let v = Literal::i8_unsuffixed(n);
143 quote! { #v }
144 }),
145 ),
146 Value::SmallInt(n) => (
147 quote! { i16 },
148 n.map(|n| {
149 let v = Literal::i16_unsuffixed(n);
150 quote! { #v }
151 }),
152 ),
153 Value::Int(n) => (
154 quote! { i32 },
155 n.map(|n| {
156 let v = Literal::i32_unsuffixed(n);
157 quote! { #v }
158 }),
159 ),
160 Value::BigInt(n) => (
161 quote! { i64 },
162 n.map(|n| {
163 let v = Literal::i64_unsuffixed(n);
164 quote! { #v }
165 }),
166 ),
167 Value::TinyUnsigned(n) => (
168 quote! { u8 },
169 n.map(|n| {
170 let v = Literal::u8_unsuffixed(n);
171 quote! { #v }
172 }),
173 ),
174 Value::SmallUnsigned(n) => (
175 quote! { u16 },
176 n.map(|n| {
177 let v = Literal::u16_unsuffixed(n);
178 quote! { #v }
179 }),
180 ),
181 Value::Unsigned(n) => (
182 quote! { u32 },
183 n.map(|n| {
184 let v = Literal::u32_unsuffixed(n);
185 quote! { #v }
186 }),
187 ),
188 Value::BigUnsigned(n) => (
189 quote! { u64 },
190 n.map(|n| {
191 let v = Literal::u64_unsuffixed(n);
192 quote! { #v }
193 }),
194 ),
195 Value::Float(n) => (
196 quote! { f32 },
197 n.map(|n| {
198 let v = Literal::f32_unsuffixed(n);
199 quote! { #v }
200 }),
201 ),
202 Value::Double(n) => (
203 quote! { f64 },
204 n.map(|n| {
205 let v = Literal::f64_unsuffixed(n);
206 quote! { #v }
207 }),
208 ),
209 Value::String(s) => match replace {
210 Some(Replacement::Type(r)) => (
211 r.clone(),
212 s.map(|s| {
213 let ident = Ident::new(&s.to_upper_camel_case(), Span::call_site());
214 quote! { #r::#ident }
215 }),
216 ),
217 Some(Replacement::Fn(f, Some(r))) => (
218 r.clone(),
219 s.map(|s| {
220 let v = Literal::string(s.as_str());
221 quote! { #r::#f(#v) }
222 }),
223 ),
224 Some(Replacement::Fn(_, None)) => {
225 return Err(syn::Error::new(
227 Span::call_site(),
228 format!("Missing parameter type for field {col:?}"),
229 ));
230 }
231 _ => (
232 quote! { &'static str },
233 s.map(|s| {
234 let v = Literal::string(s.as_str());
235 quote! { #v }
236 }),
237 ),
238 },
239 _ => continue,
242 };
243 let (_, method, option) =
244 methods.entry(format!("{col:?}")).or_insert_with(|| (t, Punctuated::<_, Comma>::new(), false));
245 if let Some(v) = value {
246 method.push(quote! {
247 #name::#key_ident => #v
248 });
249 } else {
250 *option = true;
251 }
252 }
253
254 Ok(())
255 })?;
256
257 let constructors = constructors.into_iter().map(|(name, (cols, body))| {
259 let is_full = cols.len() == primary_keys.len();
260 let fn_name = Ident::new(&format!("get_by_{name}"), Span::call_site());
261 let signature = cols
262 .iter()
263 .map(|col| {
264 let field_name = Ident::new(&format!("{col:?}").to_snake_case(), Span::call_site());
265 match get_replacement::<M>(*col, args) {
266 Some(Replacement::Type(r)) => quote! { #field_name: #r },
267 _ => quote! { #field_name: &str },
268 }
269 })
270 .collect::<Punctuated<_, Comma>>();
271 let m = cols
272 .iter()
273 .map(|col| {
274 let field_name = Ident::new(&format!("{col:?}").to_snake_case(), Span::call_site());
275 quote! { #field_name }
276 })
277 .collect::<Punctuated<_, Comma>>();
278 let body = body
279 .iter()
280 .map(|(_, (values, array_body))| {
281 let args = cols
282 .iter()
283 .enumerate()
284 .map(|(index, col)| match get_replacement::<M>(*col, args) {
285 Some(Replacement::Type(r)) => {
286 let ident = Ident::new(&values[index].to_upper_camel_case(), Span::call_site());
287 quote! { #r::#ident }
288 }
289 _ => {
290 let v = Literal::string(values[index].as_str());
291 quote! { #v }
292 }
293 })
294 .collect::<Punctuated<_, Comma>>();
295 if is_full {
296 quote! {
297 (#args,) => Some(#array_body)
298 }
299 } else {
300 quote! {
301 (#args,) => &[#array_body]
302 }
303 }
304 })
305 .collect::<Punctuated<_, Comma>>();
306 if is_full {
307 quote! {
308 pub const fn #fn_name(#signature) -> Option<Self> {
309 match (#m,) {
310 #body,
311 _ => None,
312 }
313 }
314 }
315 } else {
316 quote! {
317 pub const fn #fn_name(#signature) -> &'static [Self] {
318 match (#m,) {
319 #body,
320 _ => &[],
321 }
322 }
323 }
324 }
325 });
326
327 let try_from = methods
329 .remove("try_from")
330 .map(|(_, matches, _)| {
331 quote! {
332 impl<'a> TryFrom<&'a str> for #name {
333 type Error = String;
334 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
335 match s {
336 #matches,
337 _ => Err(format!("Unknown {} {}", stringify!(#name), s)),
338 }
339 }
340 }
341 }
342 })
343 .unwrap_or_default();
344
345 let methods: TokenStream = methods
347 .into_iter()
348 .map(|(name, (t, matches, option))| {
349 let n = Ident::new(&name.to_snake_case(), Span::call_site());
350 if option {
351 if matches.is_empty() {
352 quote! {
353 pub const fn #n(&self) -> Option<#t> {
354 None
355 }
356 }
357 } else {
358 quote! {
359 pub const fn #n(&self) -> Option<#t> {
360 Some(match self {
361 #matches,
362 _ => return None,
363 })
364 }
365 }
366 }
367 } else {
368 quote! {
369 pub const fn #n(&self) -> #t {
370 match self {
371 #matches,
372 }
373 }
374 }
375 }
376 })
377 .chain(constructors)
378 .collect();
379
380 Ok(quote! {
382 #item
383
384 impl #name {
385 #methods
386 }
387
388 #try_from
389 })
390}
391
392enum Replacement {
394 Type(TokenStream),
395 Fn(TokenStream, Option<TokenStream>),
396}
397
398fn get_replacement<M>(col: M::Column, args: &[NestedMeta]) -> Option<Replacement>
401where
402 M: EntityTrait,
403 M::Column: PartialEq,
404{
405 let col_name = format!("{col:?}");
406 let field_name = col_name.to_snake_case();
407 args.iter().find_map(|arg| {
409 if let NestedMeta::Meta(Meta::NameValue(mv)) = arg {
411 if mv.path.is_ident(&col_name) || mv.path.is_ident(&field_name) {
412 if let Lit::Str(s) = &mv.lit {
413 let ident = Ident::new(&s.value(), Span::call_site());
414 return Some(Replacement::Type(quote! { #ident }));
415 }
416 }
417 }
418 if let NestedMeta::Meta(Meta::List(ml)) = arg {
420 if ml.path.is_ident(&col_name) || ml.path.is_ident(&field_name) {
421 return ml.nested.iter().fold(None, |mut acc, nested| {
422 if let NestedMeta::Meta(Meta::NameValue(mv)) = nested {
423 if let Lit::Str(s) = &mv.lit {
424 let ident = Ident::new(&s.value(), Span::call_site());
425 if mv.path.is_ident("type") {
426 if let Some(Replacement::Fn(f, None)) = acc {
427 acc = Some(Replacement::Fn(f, Some(quote! { #ident })));
428 } else {
429 acc = Some(Replacement::Type(quote! { #ident }));
430 }
431 } else if mv.path.is_ident("fn") {
432 if let Some(Replacement::Type(t)) = acc {
433 acc = Some(Replacement::Fn(quote! { #ident }, Some(t)));
434 } else {
435 acc = Some(Replacement::Fn(quote! { #ident }, None));
436 }
437 }
438 }
439 }
440 acc
441 });
442 }
443 }
444 None
445 })
446}
447
448async fn get_data<M, F, Fut>(get_conn: F) -> syn::Result<Vec<<M as EntityTrait>::Model>>
451where
452 M: EntityTrait + EntityFilter + Default,
453 <M as EntityTrait>::Model: Serialize + DeserializeOwned,
454 F: Fn() -> Fut,
455 Fut: Future<Output = syn::Result<DatabaseConnection>>,
456{
457 let instance = M::default();
458 let mut cache = env::temp_dir();
459 cache.push(EntityName::table_name(&instance));
460 cache.set_extension("cache");
461 if cache.exists() {
462 info!("Cache file {} exists, loading data from there", cache.display());
463
464 let file = fs::File::open(&cache)
465 .map_err(|e| syn::Error::new(Span::call_site(), format!("Error reading {}: {}", cache.display(), e)))?;
466
467 match bincode::deserialize_from(io::BufReader::new(file)) {
468 Ok(data) => return Ok(data),
469 Err(e) => error!("Error deserializing {}: {}", cache.display(), e),
470 }
471 } else {
472 info!("Cache file {} doesn't exists, creating", cache.display());
473 }
474
475 let conn = get_conn().await?;
476 let data = <M as EntityTrait>::find()
477 .filter(M::filter())
478 .all(&conn)
479 .await
480 .map_err(|e| syn::Error::new(Span::call_site(), e))?;
481 let buf = bincode::serialize(&data)
482 .map_err(|e| syn::Error::new(Span::call_site(), format!("Error serializing {}: {}", cache.display(), e)))?;
483 fs::write(&cache, buf)
484 .map_err(|e| syn::Error::new(Span::call_site(), format!("Error writing {}: {}", cache.display(), e)))?;
485 Ok(data)
486}