use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{Expr, ExprLit, Lit};
use syn::{Meta, Path, punctuated::Punctuated, token::Comma};
use syn::Data::Struct;
use syn::DataStruct;
use syn::Fields::Named;
use syn::FieldsNamed;
use crate::utils::{
parse_table_name,
extract_type_from_option,
};
mod utils;
#[proc_macro_derive(ORM, attributes(ssql))]
pub fn ssql(tokens: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(tokens).unwrap();
let table_name = parse_table_name(&ast.attrs);
let struct_name = ast.ident;
let fields = match ast.data {
Struct(DataStruct { fields: Named(FieldsNamed { ref named, .. }), .. }) => named,
_ => unimplemented!()
};
let builder_types = fields.iter().map(|f| {
let mn = f.clone().ident.unwrap().to_string();
let ty = &f.ty.to_token_stream().to_string();
quote! {
#mn => #ty
}
});
let builder_fields_mapping = fields.iter().map(|f| f.clone().ident.unwrap().to_string());
let builder_row_func = fields.iter().map(|f| {
let mn = f.clone().ident.unwrap().to_string();
let field_name = match table_name.as_str() {
"" => format!("{}", &mn),
_ => format!("{}.{}", &table_name, &mn)
};
let ty = &f.ty;
let ty = match extract_type_from_option(ty) {
Some(value) => value,
None => ty
};
let type_name = ty.to_token_stream().to_string();
return match type_name.as_str() {
"String" => {
quote! {
map.insert(#mn.to_string(), row.get::<&str, &str>(#field_name).into())
}
}
"NaiveDateTime" => {
quote! {
map.insert(#mn.to_string(), row.get::<#ty, &str>(#field_name).unwrap().to_string().into())
}
}
_ => {
quote! {
map.insert(#mn.to_string(), row.get::<#ty, &str>(#field_name).into())
}
}
};
});
let builder_insert_rows = fields.iter().map(|f| {
let field = f.clone().ident.unwrap();
return quote! {
row.push(item.#field.into_sql())
};
});
let builder_insert_fields = fields.iter()
.map(|f| { f.clone().ident.unwrap().to_string() })
.reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
let mut fields_count = 0;
let builder_insert_params = fields.iter()
.map(|_| {
fields_count += 1;
return format!("@p{}", fields_count);
})
.reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
let builder_insert_data = fields.iter().map(|f|
f.clone().ident.unwrap()
)
.map(|f| return quote! {&self.#f});
fields_count = 0;
let builder_update_fields = fields.iter()
.map(|f| {
fields_count += 1;
return format!(" {} = @p{}", f.clone().ident.unwrap().to_string(), fields_count);
})
.reduce(|cur: String, next: String| format!("{},{}", cur, &next)).unwrap();
let builder_update_data = builder_insert_data.clone();
#[cfg(feature = "polars")]
let builder_new_vecs = fields.iter().map(|f| {
let field = f.clone().ident.unwrap();
let ty = &f.ty;
quote! {
let mut #field : Vec<#ty> = vec![]
}
});
#[cfg(feature = "polars")]
let builder_insert_to_df = fields.iter().map(|f| {
let field = f.clone().ident.unwrap();
quote! {
#field.push(Phant_Name1.#field)
}
});
#[cfg(feature = "polars")]
let builder_df = fields.iter().map(|f| {
let field = f.clone().ident.unwrap();
let mn = field.to_string();
quote! {
#mn => #field
}
});
let builder_row_to_self_func = fields.iter().map(|f| {
let mn = f.clone().ident.unwrap();
let field_name = match table_name.as_str() {
"" => format!("{}", &mn),
_ => format!("{}.{}", &table_name, &mn)
};
let ty = &f.ty;
return match extract_type_from_option(ty) {
Some(value) => {
let type_name = value.to_token_stream().to_string();
match type_name.as_str() {
"String" => {
quote! {
#mn: row.get::<&str, &str>(#field_name).map(|i| i.to_string())
}
}
_ => {
quote! {
#mn: row.get::<#value, &str>(#field_name)
}
}
}
}
None => {
let type_name = ty.to_token_stream().to_string();
match type_name.as_str() {
"String" => {
quote! {
#mn: row.get::<&str, &str>(#field_name).unwrap().to_string()
}
}
_ => {
quote! {
#mn: row.get::<#ty, &str>(#field_name).unwrap()
}
}
}
}
};
});
let mut result = quote! {
};
let mut relations: Vec<String> = vec![];
let mut tables: Vec<String> = vec![];
let mut primary_key = None;
for field in fields.iter() {
for attr in field.attrs.iter() {
if let Some(ident) = attr.path().get_ident() {
if ident == "ssql" {
if let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) {
for meta in list.iter() {
if let Meta::Path(path) = meta {
let Path { ref segments, .. } = path;
for ssql_tags in segments.iter() {
if ssql_tags.ident == "primary_key" {
primary_key = Some(field.clone());
}
}
}
if let Meta::NameValue(named_v) = meta {
let Path { ref segments, .. } = &named_v.path;
for ssql_tags in segments.iter() {
if ssql_tags.ident == "foreign_key" {
if let Expr::Lit(ExprLit { lit, .. }) = &named_v.value {
if let Lit::Str(v) = lit {
let field_name = field.ident.as_ref().unwrap().to_string();
relations.push(format!("{}.{} = {}", &table_name, field_name, v.value()));
tables.push(v.value()[..v.value().rfind('.').unwrap()].to_string());
}
}
}
}
}
}
}
}
}
}
}
let builder_fields = relations.iter().zip(tables.iter()).map(|(rel, tb)| {
quote! { #tb => {
concat!(" ", #tb, " ON ", #rel)
}}
});
let pk = if let Some(f) = primary_key {
let field_name = f.ident.as_ref().unwrap().to_string();
let mn = f.ident.unwrap();
quote! {
impl #struct_name {
fn primary_key(&self) -> (&'static str, ColumnData) {
(#field_name, self.#mn.to_sql())
}
}
}
} else {
quote! {
impl #struct_name {
fn primary_key(&self) -> (&'static str, ColumnData) {
unimplemented!("Primary key not set");
}
}
}
};
result.extend(pk);
result.extend(quote! {
#[async_trait(?Send)]
impl SsqlMarker for #struct_name {
fn table_name() -> &'static str {
#table_name
}
fn fields() -> Vec<&'static str> {
vec![#(#builder_fields_mapping,)*]
}
fn row_to_json(row:&Row) -> Map<String, Value> {
let mut map = Map::new();
#(#builder_row_func;)*
map
}
fn row_to_struct(row:&Row) -> Self {
Self{
#(#builder_row_to_self_func,)*
}
}
fn query<'a>() -> QueryBuilder<'a, #struct_name> {
QueryBuilder::<#struct_name>::new(
(#table_name, #struct_name::fields()),
#struct_name::relationship)
}
async fn insert_many(iter: impl IntoIterator<Item = #struct_name> , conn: &mut Client<Compat<TcpStream>>) -> SsqlResult<u64>
{
let mut req = conn.bulk_insert(#table_name).await?;
for item in iter{
let mut row = TokenRow::new();
#(#builder_insert_rows;)*
req.send(row).await?;
}
let res = req.finalize().await?;
Ok(res.total())
}
async fn insert(self, conn: &mut Client<Compat<TcpStream>>) -> SsqlResult<()> {
let sql = format!("INSERT INTO {} ({}) values({})", #table_name, #builder_insert_fields, #builder_insert_params);
conn.execute(sql, &[#(#builder_insert_data,)*]).await?;
Ok(())
}
async fn delete(self, conn: &mut Client<Compat<TcpStream>>) -> SsqlResult<()> {
let (pk, dt) = self.primary_key();
QueryBuilder::<#struct_name>::delete(&dt, #table_name, pk, conn).await?;
Ok(())
}
async fn update(&self, conn: &mut Client<Compat<TcpStream>>) -> SsqlResult<()> {
let (pk, dt) = self.primary_key();
let sql = format!("UPDATE {} SET {} WHERE {} {}", #table_name, #builder_update_fields, pk, QueryBuilder::<#struct_name>::process_pk_condition(&dt));
conn.execute(sql, &[#(#builder_update_data,)*]).await?;
Ok(())
}
}
impl #struct_name {
fn relationship(input: &str) -> &'static str {
match input {
#(#builder_fields,)*
_ => unimplemented!("relationship not found"),
}
}
fn column_type(input: &str) -> &'static str{
match input {
#(#builder_types,)*
_ => unimplemented!("column_type not found"),
}
}
}
});
#[cfg(feature = "polars")]
result.extend(quote! {
impl PolarsHelper for #struct_name {
fn dataframe(vec: Vec<Self>) -> PolarsResult<DataFrame> {
#(#builder_new_vecs;)*
#[allow(non_snake_case)]
for Phant_Name1 in vec {
#(#builder_insert_to_df;)*
}
df!(
#(#builder_df,)*
)
}
}
});
result.into()
}