sql_db_mapper/
lib.rs

1#![forbid(unsafe_code)]
2//! Connects to a PostgreSQL database and creates a rust module representing all the schemas complete with mappings for stored functions/procedures
3
4pub mod ast_convert;
5pub mod connection;
6mod pg_select_types;
7mod sql_tree;
8
9pub const VERSION: &str = std::env!("CARGO_PKG_VERSION");
10
11use postgres::{Client, NoTls};
12use std::path::PathBuf;
13use structopt::StructOpt;
14
15/// The program options for the code generation
16#[derive(Debug, StructOpt)]
17#[structopt(
18	name = "sql_db_mapper",
19	about = "Generate a rust wrapper for a PostgreSQL database",
20	version = VERSION
21)]
22pub struct Opt {
23	/// Activate debug mode
24	#[structopt(short, long)]
25	pub debug: bool,
26
27	/// Skip running output through rustfmt
28	#[structopt(short, long)]
29	pub ugly: bool,
30
31	/// Program will treat output as a directory name rather than a file and generate a whole crate. If output is not provided code is printed as usual
32	#[structopt(long)]
33	pub dir: bool,
34
35	/// Convert names from the database to rust standard (i.e. table names in CamelCase, fields and functions in snake_case)
36	#[structopt(long)]
37	pub rust_case: bool,
38
39	/// string passed to rustfmt --config
40	#[structopt(long)]
41	pub rustfmt_config: Option<String>,
42
43	/// string passed to rustfmt --config-path
44	#[structopt(long)]
45	pub rustfmt_config_path: Option<String>,
46
47	/// Only make mappings for tables and views
48	#[structopt(long)]
49	pub no_functions: bool,
50
51	/// How to use tuples (used by default for just overloads). Options:
52	/// overloads (the default, use tuples to represent function overloading).
53	/// all (Have all functions take a tuple for consitency).
54	/// none (skip mapping overloaded procs at all).
55	/// one_overload (avoid tuples by only mapping the oldest sql proc in the database).
56	#[structopt(long, default_value = "overloads")]
57	pub use_tuples: Tuples,
58
59	/// A comma seperated list of third party crates which contain types that will be mapped to and from sql types.
60	/// Valid values are "bit_vec,chrono,eui48,geo_types,rust_decimal,serde_json,time,uuid"
61	#[structopt(long, use_delimiter = true)]
62	pub third_party: Vec<ThirdParty>,
63
64	/// String to connect to database, see tokio_postgres::Config for details.
65	/// If not provided environment variable DATABASE_URL is checked instead
66	#[structopt(long, env = "DATABASE_URL")]
67	pub conn: String,
68
69	/// Output file, stdout if not present
70	#[structopt(parse(from_os_str))]
71	pub output: Option<PathBuf>,
72}
73
74#[derive(Debug, StructOpt, Clone, Copy, PartialEq, Eq)]
75pub enum Tuples {
76	/// use tuples to represent function overloading
77	ForOverloads,
78	/// Have all functions take a tuple for consitency
79	ForAll,
80	/// skip mapping overloaded procs at all
81	NoOverloads,
82	/// avoid tuples by only mapping the oldest sql proc in the database
83	OldestOverload,
84}
85impl std::str::FromStr for Tuples {
86	type Err = &'static str;
87
88	fn from_str(s: &str) -> Result<Tuples, &'static str> {
89		match s {
90			"overloads" => Ok(Tuples::ForOverloads),
91			"all" => Ok(Tuples::ForAll),
92			"none" => Ok(Tuples::NoOverloads),
93			"one_overload" => Ok(Tuples::OldestOverload),
94			_ => Err("Invalid tuple handling option, use one of (overloads, all, none, one_overload)"),
95		}
96	}
97}
98impl Tuples {
99	fn to_str(&self) -> &'static str {
100		match self {
101			Tuples::ForOverloads => "overloads",
102			Tuples::ForAll => "all",
103			Tuples::NoOverloads => "none",
104			Tuples::OldestOverload => "one_overload",
105		}
106	}
107}
/// A third party crate whose types get mapped to and from SQL types (a `--third-party` value).
///
/// Each requested crate adds its `with-*` feature to the `sql_db_mapper_core` dependency of a
/// generated crate's Cargo.toml.
#[derive(Debug, StructOpt, Clone, Copy, PartialEq, Eq)]
pub enum ThirdParty {
	/// Command line value `chrono`; feature `with-chrono-0_4`.
	Chrono,
	/// Command line value `time`; feature `with-time-0_2`.
	Time,
	/// Command line value `eui48`; feature `with-eui48-0_4`.
	Eui48,
	/// Command line value `geo_types`; feature `with-geo-types-0_6`.
	GeoTypes,
	/// Command line value `serde_json`; feature `with-serde_json-1`.
	SerdeJson,
	/// Command line value `uuid`; feature `with-uuid-0_8`.
	Uuid,
	/// Command line value `bit_vec`; feature `with-bit-vec-0_6`.
	BitVec,
	/// Command line value `rust_decimal`; feature `with-rust_decimal-1`.
	RustDecimal,
}
119impl std::str::FromStr for ThirdParty {
120	type Err = String;
121
122	fn from_str(s: &str) -> Result<ThirdParty, String> {
123		match s {
124			"bit_vec" => Ok(ThirdParty::BitVec),
125			"chrono" => Ok(ThirdParty::Chrono),
126			"eui48" => Ok(ThirdParty::Eui48),
127			"geo_types" => Ok(ThirdParty::GeoTypes),
128			"rust_decimal" => Ok(ThirdParty::RustDecimal),
129			"serde_json" => Ok(ThirdParty::SerdeJson),
130			"time" => Ok(ThirdParty::Time),
131			"uuid" => Ok(ThirdParty::Uuid),
132			_ => Err(String::from(s)),
133		}
134	}
135}
136impl ThirdParty {
137	fn to_str(&self) -> &'static str {
138		match self {
139			ThirdParty::BitVec => "with-bit-vec-0_6",
140			ThirdParty::Chrono => "with-chrono-0_4",
141			ThirdParty::Eui48 => "with-eui48-0_4",
142			ThirdParty::GeoTypes => "with-geo-types-0_6",
143			ThirdParty::RustDecimal => "with-rust_decimal-1",
144			ThirdParty::SerdeJson => "with-serde_json-1",
145			ThirdParty::Uuid => "with-uuid-0_8",
146			ThirdParty::Time => "with-time-0_2",
147		}
148	}
149}
150
151impl Opt {
152	/// Produce the Cargo.toml file contents (the dependecies of the generated code)
153	pub fn get_cargo_toml(&self) -> String {
154		let package_name = self
155			.output
156			.as_ref()
157			.map(|v| v.file_name())
158			.flatten()
159			.map(|v| v.to_str())
160			.flatten()
161			.unwrap_or("my_db_mapping");
162
163		let dependencies = format!("[package]\nname = \"{}\"", package_name)
164			+ r#"
165version = "0.1.0"
166edition = "2018"
167
168[dependencies]
169sql_db_mapper_core = { version = "0.1.0", features = ["#
170			+ &self.get_dependencies()
171			+ r#"] }
172postgres-types = { version = "0.2", features = ["derive"] }
173async-trait = { version = "0.1", optional = true }
174
175serde = { version = "1.0", features = ["derive"] }
176
177[features]
178sync = []
179async = ["async-trait"]
180"#;
181
182		dependencies
183	}
184
185	fn get_dependencies(&self) -> String {
186		let mut ret = String::new();
187		if self.third_party.contains(&ThirdParty::BitVec) {
188			ret += r#""with-bit-vec-0_6", "#;
189		}
190		if self.third_party.contains(&ThirdParty::Chrono) {
191			ret += r#""with-chrono-0_4", "#;
192		}
193		if self.third_party.contains(&ThirdParty::Eui48) {
194			ret += r#""with-eui48-0_4", "#;
195		}
196		if self.third_party.contains(&ThirdParty::GeoTypes) {
197			ret += r#""with-geo-types-0_6", "#;
198		}
199		if self.third_party.contains(&ThirdParty::RustDecimal) {
200			ret += r#""with-rust_decimal-1", "#;
201		}
202		if self.third_party.contains(&ThirdParty::SerdeJson) {
203			ret += r#""with-serde_json-1", "#;
204		}
205		if self.third_party.contains(&ThirdParty::Time) {
206			ret += r#""with-time-0_2", "#;
207		}
208		if self.third_party.contains(&ThirdParty::Uuid) {
209			ret += r#""with-uuid-0_8", "#;
210		}
211		ret
212	}
213
214	/// Build a call string that could be used to get the same options
215	pub fn get_call_string(&self) -> String {
216		let ugly = if self.ugly { " -u" } else { "" };
217		let dir = if self.dir { " --dir" } else { "" };
218		let rust_case = if self.rust_case { " --rust_case" } else { "" };
219		let no_functions = if self.no_functions { " --no_functions" } else { "" };
220		let use_tuples = if self.use_tuples == Tuples::ForOverloads {
221			String::new()
222		} else {
223			format!(" --use-tuples {}", self.use_tuples.to_str())
224		};
225		let third_party = if self.third_party.is_empty() {
226			String::new()
227		} else {
228			let list = self.third_party.iter().map(|v| v.to_str()).fold(String::new(), |acc, v| acc+v+",");
229			format!(" --third-party \"{}\"", &list[..(list.len()-1)])
230		};
231		format!(
232			"sql_db_mapper{ugly}{dir}{rust_case}{no_functions}{use_tuples}{third_party}",
233			ugly = ugly,
234			dir = dir,
235			rust_case = rust_case,
236			no_functions = no_functions,
237			use_tuples = use_tuples,
238			third_party = third_party,
239		)
240	}
241
242	pub fn get_client(&self) -> connection::MyClient {
243		let client = Client::connect(&self.conn, NoTls)
244			.expect("Failed to connect to database, please check your connection string and try again");
245
246		connection::MyClient::new(client)
247	}
248
249	fn uses_lib(&self, lib_name: ThirdParty) -> bool {
250		self.third_party.contains(&lib_name)
251	}
252}
253
/// Calls rustfmt (the program) on the input.
///
/// rustfmt is run with `--emit=stdout --edition=2018`. `rustfmt_config` is forwarded as
/// `--config <value>` and `rustfmt_config_path` as `--config-path <value>`. Anything rustfmt
/// writes to stderr is echoed to our stderr.
///
/// This never panics on rustfmt problems. A diagnostic is written to stderr and an unmodified
/// copy of the input is returned if rustfmt:
/// - cannot be spawned,
/// - cannot be sent the whole input,
/// - exits unsuccessfully, or
/// - produces output that is not UTF-8.
pub fn format_rust(value: &str, rustfmt_config: Option<&str>, rustfmt_config_path: Option<&str>) -> String {
	use std::{
		io::Write,
		process::{Command, Stdio},
	};

	let mut args = Vec::new();
	if let Some(s) = rustfmt_config {
		args.push("--config");
		args.push(s);
	}
	if let Some(s) = rustfmt_config_path {
		args.push("--config-path");
		args.push(s);
	}

	let mut proc = match Command::new("rustfmt")
		.arg("--emit=stdout")
		.arg("--edition=2018")
		.args(&args)
		.stdin(Stdio::piped())
		.stdout(Stdio::piped())
		.stderr(Stdio::piped())
		.spawn()
	{
		Ok(proc) => proc,
		Err(e) => {
			eprintln!("failed to spawn rustfmt: {}", e);
			return value.to_string();
		},
	};

	// Feed stdin from another thread while this one drains stdout/stderr. Otherwise a child that
	// writes before consuming all its input could deadlock against us. The handle is moved into
	// the thread and dropped when the write finishes, which signals EOF to rustfmt.
	let mut stdin = proc.stdin.take().expect("rustfmt stdin is configured as piped");
	let input = value.to_owned();
	let writer = std::thread::spawn(move || stdin.write_all(input.as_bytes()));

	let output = match proc.wait_with_output() {
		Ok(output) => output,
		Err(e) => {
			eprintln!("Error running rustfmt: {}", e);
			return value.to_string();
		},
	};

	// A failed write is typically a broken pipe because rustfmt exited early, e.g. after
	// rejecting its arguments. Report it instead of panicking.
	let input_delivered = match writer.join() {
		Ok(Ok(())) => true,
		Ok(Err(e)) => {
			eprintln!("Error writing input to rustfmt: {}", e);
			false
		},
		Err(_) => {
			eprintln!("Error writing input to rustfmt: writer thread panicked");
			false
		},
	};

	if !output.stderr.is_empty() {
		eprintln!("{}", String::from_utf8_lossy(&output.stderr));
	}
	if !output.status.success() {
		eprintln!("rustfmt failed ({})", output.status);
		eprintln!("{}", String::from_utf8_lossy(&output.stdout));
		return value.to_string();
	}
	// If rustfmt did not receive all of the input, it may have "successfully" formatted a truncated file.
	if !input_delivered {
		return value.to_string();
	}
	match String::from_utf8(output.stdout) {
		Ok(formatted) => formatted,
		Err(_) => {
			eprintln!("rustfmt output was not valid UTF-8");
			value.to_string()
		},
	}
}