/// Diagnostic emitted when a `#[derive(Turbosql)]` struct declares a `u64` or `Option<u64>` field.
const SQLITE_64BIT_ERROR: &str = "Sadly, SQLite cannot natively store unsigned 64-bit integers, so TurboSQL does not support u64 members. Use i64, u32, f64, or a string or binary format instead. (see https://sqlite.org/fileformat.html#record_format )";
use once_cell::sync::Lazy;
use proc_macro2::Span;
use proc_macro_error::{abort, abort_call_site, proc_macro_error};
use quote::{format_ident, quote, ToTokens};
use rusqlite::{params, Connection, Statement};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Mutex;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
parse_macro_input, Data, DeriveInput, Expr, Fields, FieldsNamed, Ident, LitStr, Meta, NestedMeta,
Token, Type,
};
/// Name of the migrations TOML file, resolved against the current working directory.
///
/// Builds with the `test` feature use a separate file so they don't touch the real one.
const MIGRATIONS_FILENAME: &str =
  if cfg!(feature = "test") { "test.migrations.toml" } else { "migrations.toml" };
mod create;
mod insert;
mod select;
/// Everything the derive macro knows about a `#[derive(Turbosql)]` struct.
///
/// Built in `turbosql_derive_macro` and handed to the `create`, `insert` and `select`
/// code generators.
#[derive(Debug, Clone)]
struct Table {
// Identifier of the struct the derive is applied to.
ident: Ident,
// Span of the whole derive input.
span: Span,
// Lower-cased struct name; used as the SQL table name.
name: String,
// One entry per named field that is not marked `#[turbosql(skip)]`.
columns: Vec<Column>,
}
/// Span-free summary of a `Table`, stored in the global `TABLES` registry so that
/// `select!`/`execute!` expansions can look up a table's columns by name.
// NOTE(review): presumably span/ident-free because `Ident`/`Span` cannot be kept in a
// `static` `Mutex` — confirm.
#[derive(Debug)]
struct MiniTable {
// Lower-cased table name (same value as `Table::name`).
name: String,
// Mirrors `Table::columns` without identifiers or spans.
columns: Vec<MiniColumn>,
}
/// Interpolating a `Table` into `quote!` emits just the struct's identifier, so
/// generated code can write `impl #table { ... }`.
impl ToTokens for Table {
  fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
    self.ident.to_tokens(tokens);
  }
}
/// One SQL column, derived from a named field of a `#[derive(Turbosql)]` struct.
#[derive(Debug, Clone)]
struct Column {
// The field's identifier.
ident: Ident,
// Span of the field's type.
span: Span,
// The field name as a string; also used as the SQL column name.
name: String,
// The field type rendered with `quote!(..).to_string()`, e.g. `"Option < i64 >"`.
rust_type: String,
// SQLite column type (plus constraint for `rowid`), e.g. `"INTEGER"` or `"INTEGER PRIMARY KEY"`.
sql_type: &'static str,
}
/// Span-free copy of a `Column`, stored inside `MiniTable`.
#[derive(Debug)]
struct MiniColumn {
// Same as `Column::name`.
name: String,
// Same as `Column::rust_type`.
rust_type: String,
// Same as `Column::sql_type`.
sql_type: &'static str,
}
/// Running, comma-separated list of table names seen by `derive(Turbosql)`: starts as
/// `"none"` and every derive appends `", <table_name>"`.
// NOTE(review): no reader is visible in this file; possibly used by the `create`/`insert`/
// `select` modules or kept for debugging — confirm.
static LAST_TABLE_NAME: Lazy<Mutex<String>> = Lazy::new(|| Mutex::new("none".to_string()));
/// Registry of derived tables keyed by lower-cased struct name. `turbosql_derive_macro`
/// inserts entries; `do_parse_tokens` reads them to expand `select!(Type ...)`.
// NOTE(review): being a process-global static, it only contains derives expanded earlier in
// the same compiler process — confirm expansion-order assumptions.
static TABLES: Lazy<Mutex<HashMap<String, MiniTable>>> = Lazy::new(|| Mutex::new(HashMap::new()));
/// Expansion produced by parsing a `select!` invocation.
#[derive(Debug)]
struct SelectTokens {
tokens: proc_macro2::TokenStream,
}
/// Expansion produced by parsing an `execute!` invocation.
#[derive(Debug)]
struct ExecuteTokens {
tokens: proc_macro2::TokenStream,
}
/// Bound-parameter expressions that follow the SQL in a macro invocation: a leading comma
/// followed by a comma-separated list of expressions, or nothing at all.
#[derive(Debug)]
struct QueryParams {
params: Punctuated<Expr, Token![,]>,
}
/// Parses `, expr, expr, ...` into `QueryParams`; yields an empty list when the next token
/// is not a comma.
impl Parse for QueryParams {
  fn parse(input: ParseStream) -> syn::Result<Self> {
    let params = if input.peek(Token![,]) {
      // `peek` has just confirmed the comma, so this parse cannot fail.
      let _leading_comma: Token![,] = input.parse()?;
      input.parse_terminated(Expr::parse)?
    } else {
      Punctuated::new()
    };
    Ok(QueryParams { params })
  }
}
/// Shape of the optional result type written at the start of `select!`.
///
/// `container` is `Some(Vec)` or `Some(Option)` when the type was written as `Vec<T>` or
/// `Option<T>`, and `None` for a bare type. `contents` is the inner (or bare) type's
/// identifier, or `None` when the inner type was written as `_`.
#[derive(Clone, Debug)]
struct ResultType {
container: Option<Ident>,
contents: Option<Ident>,
}
/// Per-column code fragments used to build a struct from a result row.
#[derive(Debug)]
struct MembersAndCasters {
// (member name, member type, column index) for each result column.
members: Vec<(Ident, Ident, usize)>,
// `name: Type` field declarations, one per member.
struct_members: Vec<proc_macro2::TokenStream>,
// `name: row.get(index)?` field initialisers, one per member.
row_casters: Vec<proc_macro2::TokenStream>,
}
impl MembersAndCasters {
fn create(members: Vec<(Ident, Ident, usize)>) -> MembersAndCasters {
let struct_members: Vec<_> = members.iter().map(|(name, ty, _i)| quote!(#name: #ty)).collect();
let row_casters =
members.iter().map(|(name, _ty, i)| quote!(#name: row.get(#i)?)).collect::<Vec<_>>();
Self { members, struct_members, row_casters }
}
}
/// Placeholder for deriving members from an explicit column list; not implemented yet.
///
/// Always returns `None`. No caller is visible in this file.
fn extract_explicit_members(columns: &[String]) -> Option<MembersAndCasters> {
  // Proc macros run inside the compiler, so this stub must not print anything:
  // stdout output would show up on every expansion.
  let _ = columns;
  None
}
/// Derives struct members from a prepared statement's result column names.
///
/// Each column must be named `<member>_<type>`, where the text after the LAST underscore
/// is the type (`i64` or `String`) and everything before it is the member name. Aborts,
/// pointing at `span`, when a column has no underscore or an unsupported type suffix.
fn extract_stmt_members(stmt: &Statement, span: &Span) -> MembersAndCasters {
  let column_names = stmt.column_names();
  let mut members = Vec::with_capacity(column_names.len());
  for (i, &col_name) in column_names.iter().enumerate() {
    let split_at = match col_name.rfind('_') {
      Some(pos) => pos,
      None => abort!(
        span,
        "SQL column name {:#?} must include a type annotation, e.g. {}_String or {}_i64.",
        col_name,
        col_name,
        col_name
      ),
    };
    // '_' is ASCII, so slicing on either side of it stays on char boundaries.
    let name = &col_name[..split_at];
    let ty = &col_name[split_at + 1..];
    if !matches!(ty, "i64" | "String") {
      abort!(span, "Invalid type annotation \"_{}\", try e.g. _String or _i64.", ty);
    }
    members.push((format_ident!("{}", name), format_ident!("{}", ty), i));
  }
  MembersAndCasters::create(members)
}
/// Which macro is being expanded. `do_parse_tokens` uses this to reject statements that
/// return rows from `execute!` and statements that return none from `select!`.
enum ParseStatementType {
Execute,
Select,
}
use ParseStatementType::{Execute, Select};
/// What preparing a statement against the scratch database revealed about it.
#[derive(Debug)]
struct StatementInfo {
// Number of bound parameters the statement expects (from `Statement::parameter_count`).
parameter_count: usize,
// Result column names; an empty list is treated as "statement returns no rows".
column_names: Vec<String>,
}
impl StatementInfo {
  /// Builds members straight from the result column names, each of which must parse as
  /// a Rust identifier. The member type placeholder is `None`.
  ///
  /// # Errors
  /// Returns the `syn` parse error for the first column name that is not a valid identifier.
  fn membersandcasters(&self) -> syn::parse::Result<MembersAndCasters> {
    let mut members = Vec::with_capacity(self.column_names.len());
    for (column_idx, col_name) in self.column_names.iter().enumerate() {
      let member_ident = syn::parse_str::<Ident>(col_name)?;
      members.push((member_ident, format_ident!("None"), column_idx));
    }
    Ok(MembersAndCasters::create(members))
  }
}
/// Contents of the migrations TOML file (named by `MIGRATIONS_FILENAME`).
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
struct MigrationsToml {
// SQL migrations; replayed in order onto an in-memory database to validate queries.
migrations_append_only: Option<Vec<String>>,
// Generated schema text for human reference; not read anywhere in this file.
autogenerated_schema_for_your_information_do_not_edit: Option<String>,
}
fn migrations_to_tempdb(migrations: &[String]) -> Connection {
let tempdb = rusqlite::Connection::open_in_memory().unwrap();
tempdb
.execute_batch(
"CREATE TABLE turbosql_migrations (rowid INTEGER PRIMARY KEY, migration TEXT NOT NULL);",
)
.unwrap();
migrations.iter().for_each(|m| match tempdb.execute(m, params![]) {
Ok(_) => (),
Err(rusqlite::Error::ExecuteReturnedResults) => (),
Err(e) => abort_call_site!("Running migrations on temp db: {:?}", e),
});
tempdb
}
/// Applies `migrations` to a scratch database and returns the `CREATE TABLE` SQL of every
/// table, sorted by that SQL text and joined with newlines.
///
/// # Errors
/// Propagates any `rusqlite` error from preparing or running the `sqlite_master` query.
fn migrations_to_schema(migrations: &[String]) -> Result<String, rusqlite::Error> {
  let tempdb = migrations_to_tempdb(migrations);
  let mut stmt =
    tempdb.prepare("SELECT sql FROM sqlite_master WHERE type='table' ORDER BY sql")?;
  let table_sql = stmt.query_map(params![], |row| row.get(0))?.collect::<Result<Vec<String>, _>>()?;
  Ok(table_sql.join("\n"))
}
/// Reads and decodes the migrations file (`MIGRATIONS_FILENAME`) from the current working
/// directory. Returns `MigrationsToml::default()` when the file does not exist.
///
/// Reading happens while holding an exclusive `fs2` lock on `migrations.toml.lock` in the
/// system temp directory. The lock is tied to the `lockfile` handle, which is dropped when
/// this function returns.
// NOTE(review): presumably this serialises access with other concurrent macro expansions or
// writers of the migrations file — confirm.
///
/// Aborts the expansion with a descriptive message in any of these cases:
/// - the lock file cannot be created or locked;
/// - the working directory cannot be determined;
/// - the migrations file cannot be read or decoded.
fn read_migrations_toml() -> MigrationsToml {
  let lockfile_path = std::env::temp_dir().join("migrations.toml.lock");
  let lockfile = std::fs::File::create(&lockfile_path).unwrap_or_else(|e| {
    abort_call_site!("Unable to create lock file {}: {:?}", lockfile_path.to_string_lossy(), e)
  });
  fs2::FileExt::lock_exclusive(&lockfile).unwrap_or_else(|e| {
    abort_call_site!("Unable to lock {}: {:?}", lockfile_path.to_string_lossy(), e)
  });

  let current_dir = std::env::current_dir()
    .unwrap_or_else(|e| abort_call_site!("Unable to determine current directory: {:?}", e));
  let migrations_toml_path = current_dir.join(MIGRATIONS_FILENAME);
  let migrations_toml_path_lossy = migrations_toml_path.to_string_lossy();

  // No migrations file yet: behave as if it were empty.
  if !migrations_toml_path.exists() {
    return MigrationsToml::default();
  }

  let toml_str = std::fs::read_to_string(&migrations_toml_path)
    .unwrap_or_else(|e| abort_call_site!("Unable to read {}: {:?}", migrations_toml_path_lossy, e));
  toml::from_str(&toml_str).unwrap_or_else(|e| {
    abort_call_site!("Unable to decode toml in {}: {:?}", migrations_toml_path_lossy, e)
  })
}
/// Prepares `sql` against a scratch database built from the current migrations and reports
/// its bound-parameter count and result column names.
///
/// A migrations file with no `migrations_append_only` list (including a missing file) is
/// treated as an empty migration list. SQL that references tables then yields a normal
/// `rusqlite` error instead of a panic inside the macro.
///
/// # Errors
/// Returns the `rusqlite` error if the statement fails to prepare.
fn validate_sql<S: AsRef<str>>(sql: S) -> Result<StatementInfo, rusqlite::Error> {
  let migrations = read_migrations_toml().migrations_append_only.unwrap_or_default();
  let tempdb = migrations_to_tempdb(&migrations);
  let stmt = tempdb.prepare(sql.as_ref())?;
  Ok(StatementInfo {
    parameter_count: stmt.parameter_count(),
    column_names: stmt.column_names().into_iter().map(str::to_string).collect(),
  })
}
fn validate_sql_or_abort<S: AsRef<str> + std::fmt::Debug>(sql: S) -> StatementInfo {
validate_sql(sql.as_ref()).unwrap_or_else(|e| {
abort_call_site!(r#"Error validating SQL statement: "{}". SQL: {:?}"#, e, sql)
})
}
fn do_parse_tokens(
input: ParseStream,
statement_type: ParseStatementType,
) -> syn::Result<proc_macro2::TokenStream> {
let span = input.span();
let result_type = input.parse::<Type>().ok();
let sql = input.parse::<LitStr>().ok().map(|s| s.value());
let stmt_info = sql.clone().and_then(|s| validate_sql(s).ok());
let (sql, stmt_info) = match (sql, stmt_info) {
(Some(sql), None) => {
let sql_with_select = format!("SELECT {}", sql);
let stmt_info = validate_sql(&sql_with_select).ok();
(Some(if stmt_info.is_some() { sql_with_select } else { sql }), stmt_info)
}
t => t,
};
let result_type = match result_type {
Some(syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }))
if segments.len() == 1 =>
{
let segment = segments.first().unwrap();
Some(match segment.ident.to_string().as_str() {
"Vec" | "Option" => match &segment.arguments {
syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { args, .. })
if args.len() == 1 =>
{
let arg = args.first().unwrap();
match arg {
syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
path: syn::Path { segments, .. },
..
}))
if segments.len() == 1 =>
{
let contents_segment = segments.first().unwrap();
ResultType {
container: Some(segment.ident.clone()),
contents: Some(contents_segment.ident.clone()),
}
}
syn::GenericArgument::Type(syn::Type::Infer(_)) => {
ResultType { container: Some(segment.ident.clone()), contents: None }
}
_ => abort_call_site!("No segments found for container type {:#?}", arg),
}
}
_ => abort_call_site!("No arguments found for container type"),
},
_ => ResultType { container: None, contents: Some(segment.ident.clone()) },
})
}
Some(_) => abort_call_site!("Could not parse result_type"),
None => None,
};
let (sql, stmt_info) = match (result_type.clone(), sql, stmt_info) {
(Some(ResultType { contents: Some(contents), .. }), sql, None) => {
let result_type = contents.to_string();
let table_name = result_type.to_lowercase();
let tables = TABLES.lock().unwrap();
let table = tables.get(&table_name).unwrap_or_else(|| {
abort!(
span,
"Table {:?} not found. Does struct {} exist and have #[derive(Turbosql)]?",
table_name,
result_type
)
});
let column_names_str =
table.columns.iter().map(|c| c.name.as_str()).collect::<Vec<_>>().join(", ");
let sql = format!("SELECT {} FROM {} {}", column_names_str, table_name, sql.unwrap_or_default());
(sql.clone(), validate_sql_or_abort(sql))
}
(_, Some(sql), Some(stmt_info)) => (sql, stmt_info),
_ => abort_call_site!("no predicate and no result type found"),
};
let QueryParams { params } = input.parse()?;
if params.len() != stmt_info.parameter_count {
abort!(
span,
"Expected {} bound parameter{}, got {}: {:?}",
stmt_info.parameter_count,
if stmt_info.parameter_count == 1 { "" } else { "s" },
params.len(),
sql
);
}
if !input.is_empty() {
return Err(input.error("Expected parameters"));
}
if stmt_info.column_names.is_empty() {
if !matches!(statement_type, Execute) {
abort_call_site!("No rows returned from SQL, use execute! instead.");
}
return Ok(quote! {
{
(|| -> Result<_, _> {
let db = ::turbosql::__TURBOSQL_DB.lock().unwrap();
let mut stmt = db.prepare_cached(#sql)?;
stmt.execute(::turbosql::params![#params])
})()
}
});
}
if !matches!(statement_type, Select) {
abort_call_site!("Rows returned from SQL, use select! instead.");
}
let tokens = match result_type {
Some(ResultType { container: Some(container), contents: Some(contents) })
if container == "Vec" =>
{
let m = stmt_info
.membersandcasters()
.unwrap_or_else(|_| abort_call_site!("stmt_info.membersandcasters failed"));
let row_casters = m.row_casters;
quote! {
{
(|| -> Result<Vec<#contents>, ::turbosql::Error> {
let db = ::turbosql::__TURBOSQL_DB.lock().unwrap();
let mut stmt = db.prepare_cached(#sql)?;
let result = stmt.query_map(::turbosql::params![#params], |row| {
Ok(#contents {
#(#row_casters),*
})
})?.collect::<Vec<_>>();
let result = result.into_iter().flatten().collect::<Vec<_>>();
Ok(result)
})()
}
}
}
Some(ResultType { container: Some(container), contents: Some(contents) })
if container == "Option" =>
{
let m = stmt_info
.membersandcasters()
.unwrap_or_else(|_| abort_call_site!("stmt_info.membersandcasters failed"));
let row_casters = m.row_casters;
quote! {
{
(|| -> Result<Option<#contents>, ::turbosql::Error> {
use ::turbosql::OptionalExtension;
let db = ::turbosql::__TURBOSQL_DB.lock().unwrap();
let mut stmt = db.prepare_cached(#sql)?;
let result = stmt.query_row(::turbosql::params![#params], |row| -> Result<#contents, _> {
Ok(#contents {
#(#row_casters),*
})
}).optional()?;
Ok(result)
})()
}
}
}
Some(ResultType { container: None, contents: Some(contents) })
if ["i64", "bool"].contains(&&contents.to_string().as_str()) =>
{
quote! {
{
(|| -> Result<#contents, ::turbosql::Error> {
let db = ::turbosql::__TURBOSQL_DB.lock().unwrap();
let mut stmt = db.prepare_cached(#sql)?;
let result = stmt.query_row(::turbosql::params![#params], |row| -> Result<#contents, _> {
Ok(row.get(0)?)
})?;
Ok(result)
})()
}
}
}
Some(ResultType { container: None, contents: Some(contents) }) => {
let m = stmt_info
.membersandcasters()
.unwrap_or_else(|_| abort_call_site!("stmt_info.membersandcasters failed"));
let row_casters = m.row_casters;
quote! {
{
(|| -> Result<#contents, ::turbosql::Error> {
let db = ::turbosql::__TURBOSQL_DB.lock().unwrap();
let mut stmt = db.prepare_cached(#sql)?;
let result = stmt.query_row(::turbosql::params![#params], |row| -> Result<#contents, _> {
Ok(#contents {
#(#row_casters),*
})
})?;
Ok(result)
})()
}
}
}
Some(ResultType { container: Some(container), contents: None }) => abort_call_site!("INFERRED"),
_ => abort_call_site!("unknown result_type"),
};
Ok(tokens)
}
/// Parses a `select!` invocation via the shared `do_parse_tokens` in `Select` mode.
impl Parse for SelectTokens {
  fn parse(input: ParseStream) -> syn::Result<Self> {
    do_parse_tokens(input, Select).map(|tokens| SelectTokens { tokens })
  }
}
/// Parses an `execute!` invocation via the shared `do_parse_tokens` in `Execute` mode.
impl Parse for ExecuteTokens {
  fn parse(input: ParseStream) -> syn::Result<Self> {
    do_parse_tokens(input, Execute).map(|tokens| ExecuteTokens { tokens })
  }
}
#[proc_macro]
#[proc_macro_error]
pub fn execute(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ExecuteTokens { tokens } = parse_macro_input!(input);
proc_macro::TokenStream::from(tokens)
}
#[proc_macro]
#[proc_macro_error]
pub fn select(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let SelectTokens { tokens } = parse_macro_input!(input);
proc_macro::TokenStream::from(tokens)
}
#[proc_macro_derive(Turbosql, attributes(turbosql))]
#[proc_macro_error]
pub fn turbosql_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let table_span = input.span();
let table_ident = input.ident;
let table_name = table_ident.to_string().to_lowercase();
let ltn = LAST_TABLE_NAME.lock().unwrap().clone();
let mut last_table_name_ref = LAST_TABLE_NAME.lock().unwrap();
*last_table_name_ref = format!("{}, {}", ltn, table_name);
let fields = match input.data {
Data::Struct(ref data) => match data.fields {
Fields::Named(ref fields) => fields,
Fields::Unnamed(_) | Fields::Unit => unimplemented!(),
},
Data::Enum(_) | Data::Union(_) => unimplemented!(),
};
let table = Table {
ident: table_ident,
span: table_span,
name: table_name.clone(),
columns: extract_columns(fields),
};
let minitable = MiniTable {
name: table_name.clone(),
columns: table
.columns
.iter()
.map(|c| MiniColumn {
name: c.name.clone(),
sql_type: c.sql_type,
rust_type: c.rust_type.clone(),
})
.collect(),
};
TABLES.lock().unwrap().insert(table_name, minitable);
let fn_create = create::create(&table);
let fn_insert = insert::insert(&table);
let fn_select = select::select(&table);
proc_macro::TokenStream::from(quote! {
impl #table {
#fn_create
#fn_insert
#fn_select
}
})
}
/// Maps each named field of a `#[derive(Turbosql)]` struct to a SQL `Column`.
///
/// - Fields marked `#[turbosql(skip)]` are omitted.
/// - Each remaining field must have one of the supported `Option<...>` types.
/// - A `rowid: Option<i64>` field is mapped to `INTEGER PRIMARY KEY` and is required.
///
/// Aborts with a spanned error in these cases:
/// - a field type is unsupported; `u64` gets a dedicated message;
/// - a `#[turbosql]` attribute is malformed;
/// - the `rowid` field is missing.
fn extract_columns(fields: &FieldsNamed) -> Vec<Column> {
  let columns = fields
    .named
    .iter()
    .filter_map(|f| {
      // Only `#[turbosql(...)]` attributes are parsed. Unrelated attributes need not be
      // valid `Meta` syntax and must not be able to crash the derive.
      let skip = f.attrs.iter().filter(|attr| attr.path.is_ident("turbosql")).any(|attr| {
        match attr.parse_meta() {
          Ok(Meta::List(list)) => list.nested.iter().any(|nested| {
            matches!(nested, NestedMeta::Meta(Meta::Path(p)) if p.is_ident("skip"))
          }),
          Ok(_) => false,
          Err(e) => abort!(attr, "Unable to parse #[turbosql] attribute: {}", e),
        }
      });
      if skip {
        return None;
      }

      let ident = &f.ident;
      let name = ident.as_ref().unwrap().to_string();
      let ty = &f.ty;
      // Types are matched by their token-stream string, e.g. "Option < i64 >".
      let ty_str = quote!(#ty).to_string();
      let sql_type = match (name.as_str(), ty_str.as_str()) {
        ("rowid", "Option < i64 >") => "INTEGER PRIMARY KEY",
        // NOTE(review): `i54` is not a built-in Rust type; these arms only match if a type
        // named `i54` is in scope at the derive site — confirm intent.
        ("rowid", "Option < i54 >") => "INTEGER PRIMARY KEY",
        (_, "Option < i8 >") => "INTEGER",
        (_, "Option < u8 >") => "INTEGER",
        (_, "Option < i16 >") => "INTEGER",
        (_, "Option < u16 >") => "INTEGER",
        (_, "Option < i32 >") => "INTEGER",
        (_, "Option < u32 >") => "INTEGER",
        (_, "Option < i54 >") => "INTEGER",
        (_, "Option < i64 >") => "INTEGER",
        (_, "u64") => abort!(ty, SQLITE_64BIT_ERROR),
        (_, "Option < u64 >") => abort!(ty, SQLITE_64BIT_ERROR),
        (_, "Option < f64 >") => "REAL",
        (_, "Option < bool >") => "BOOLEAN",
        (_, "Option < String >") => "TEXT",
        (_, "Option < Blob >") => "BLOB",
        _ => abort!(ty, "turbosql doesn't support rust type: {}", ty_str),
      };
      Some(Column {
        ident: ident.clone().unwrap(),
        span: ty.span(),
        rust_type: ty_str,
        name,
        sql_type,
      })
    })
    .collect::<Vec<_>>();

  // Every table needs a `rowid: Option<i64>` primary key.
  if !matches!(
    columns.iter().find(|c| c.name == "rowid"),
    Some(Column { sql_type: "INTEGER PRIMARY KEY", .. })
  ) {
    abort_call_site!("derive(Turbosql) structs must include a 'rowid: Option<i64>' field")
  };
  columns
}