// Reject any `unsafe` code in this module tree; unlike `deny`, `forbid` cannot be
// re-allowed by an inner `#[allow]`.
#![forbid(unsafe_code)]
pub mod ast_convert;
// Provides `MyClient`, which `Opt::get_client` wraps around a `postgres::Client`.
pub mod connection;
mod pg_select_types;
mod sql_tree;

/// This package's version string, read from `CARGO_PKG_VERSION` at compile time.
///
/// Also handed to structopt as the `--version` string of [`Opt`].
pub const VERSION: &str = std::env!("CARGO_PKG_VERSION");
10
11use postgres::{Client, NoTls};
12use std::path::PathBuf;
13use structopt::StructOpt;
14
// Command-line options of the `sql_db_mapper` binary, parsed by structopt.
//
// NOTE(review): plain `//` comments are used deliberately here. structopt turns
// `///` doc comments on this struct and its fields into `--help` text, so
// documenting with `///` would change the program's CLI output.
#[derive(Debug, StructOpt)]
#[structopt(
    name = "sql_db_mapper",
    about = "Generate a rust wrapper for a PostgreSQL database",
    version = VERSION
)]
pub struct Opt {
    // `-d` / `--debug`. Not read anywhere in this file; presumably enables extra
    // diagnostic output elsewhere in the crate — TODO confirm.
    #[structopt(short, long)]
    pub debug: bool,

    // `-u` / `--ugly`. Echoed as ` -u` by `get_call_string`. NOTE(review): the name
    // suggests it skips the `format_rust` pass — confirm against the caller.
    #[structopt(short, long)]
    pub ugly: bool,

    // `--dir`. Included in `get_call_string` output; its effect is not visible here.
    #[structopt(long)]
    pub dir: bool,

    // Long flag `--rust-case` (structopt derives kebab-case names from fields).
    // Included in `get_call_string` output; its effect is not visible here.
    #[structopt(long)]
    pub rust_case: bool,

    // Inline rustfmt configuration. `format_rust` forwards a value like this to
    // rustfmt as `--config`; the wiring between the two is not shown in this file.
    #[structopt(long)]
    pub rustfmt_config: Option<String>,

    // Path to a rustfmt configuration file. `format_rust` forwards a value like
    // this to rustfmt as `--config-path`; the wiring is not shown in this file.
    #[structopt(long)]
    pub rustfmt_config_path: Option<String>,

    // Long flag `--no-functions`. Included in `get_call_string` output; its
    // effect is not visible here.
    #[structopt(long)]
    pub no_functions: bool,

    // `--use-tuples`: parsed with `Tuples::from_str`, which accepts `overloads`,
    // `all`, `none` and `one_overload`; defaults to `overloads`.
    #[structopt(long, default_value = "overloads")]
    pub use_tuples: Tuples,

    // `--third-party`: comma-separated list (`use_delimiter`), each entry parsed
    // with `ThirdParty::from_str`. `get_dependencies` maps the selection to
    // `sql_db_mapper_core` Cargo features.
    #[structopt(long, use_delimiter = true)]
    pub third_party: Vec<ThirdParty>,

    // `--conn`, falling back to the `DATABASE_URL` environment variable.
    // `get_client` connects to it without TLS.
    #[structopt(long, env = "DATABASE_URL")]
    pub conn: String,

    // Optional positional output path. `get_cargo_toml` uses its last component
    // as the generated package name.
    #[structopt(parse(from_os_str))]
    pub output: Option<PathBuf>,
}
73
// Tuple-handling mode selected with `--use-tuples` (spellings are defined by the
// `FromStr` impl that follows). NOTE(review): the variant names suggest this
// controls how overloaded database functions are exposed — confirm in the code
// generator. Plain `//` comments are used because the `StructOpt` derive would
// turn `///` doc comments into generated help text.
#[derive(Debug, StructOpt, Clone, Copy, PartialEq, Eq)]
pub enum Tuples {
    // `--use-tuples overloads` (the CLI default).
    ForOverloads,
    // `--use-tuples all`.
    ForAll,
    // `--use-tuples none`.
    NoOverloads,
    // `--use-tuples one_overload`.
    OldestOverload,
}
85impl std::str::FromStr for Tuples {
86 type Err = &'static str;
87
88 fn from_str(s: &str) -> Result<Tuples, &'static str> {
89 match s {
90 "overloads" => Ok(Tuples::ForOverloads),
91 "all" => Ok(Tuples::ForAll),
92 "none" => Ok(Tuples::NoOverloads),
93 "one_overload" => Ok(Tuples::OldestOverload),
94 _ => Err("Invalid tuple handling option, use one of (overloads, all, none, one_overload)"),
95 }
96 }
97}
98impl Tuples {
99 fn to_str(&self) -> &'static str {
100 match self {
101 Tuples::ForOverloads => "overloads",
102 Tuples::ForAll => "all",
103 Tuples::NoOverloads => "none",
104 Tuples::OldestOverload => "one_overload",
105 }
106 }
107}
// A third-party type library selectable with `--third-party`. Each variant maps
// to a `sql_db_mapper_core` Cargo feature via `to_str` below. Plain `//` comments
// are used because the `StructOpt` derive would turn `///` doc comments into
// generated help text.
#[derive(Debug, StructOpt, Clone, Copy, PartialEq, Eq)]
pub enum ThirdParty {
    // CLI `chrono`; feature `with-chrono-0_4`.
    Chrono,
    // CLI `time`; feature `with-time-0_2`.
    Time,
    // CLI `eui48`; feature `with-eui48-0_4`.
    Eui48,
    // CLI `geo_types`; feature `with-geo-types-0_6`.
    GeoTypes,
    // CLI `serde_json`; feature `with-serde_json-1`.
    SerdeJson,
    // CLI `uuid`; feature `with-uuid-0_8`.
    Uuid,
    // CLI `bit_vec`; feature `with-bit-vec-0_6`.
    BitVec,
    // CLI `rust_decimal`; feature `with-rust_decimal-1`.
    RustDecimal,
}
119impl std::str::FromStr for ThirdParty {
120 type Err = String;
121
122 fn from_str(s: &str) -> Result<ThirdParty, String> {
123 match s {
124 "bit_vec" => Ok(ThirdParty::BitVec),
125 "chrono" => Ok(ThirdParty::Chrono),
126 "eui48" => Ok(ThirdParty::Eui48),
127 "geo_types" => Ok(ThirdParty::GeoTypes),
128 "rust_decimal" => Ok(ThirdParty::RustDecimal),
129 "serde_json" => Ok(ThirdParty::SerdeJson),
130 "time" => Ok(ThirdParty::Time),
131 "uuid" => Ok(ThirdParty::Uuid),
132 _ => Err(String::from(s)),
133 }
134 }
135}
136impl ThirdParty {
137 fn to_str(&self) -> &'static str {
138 match self {
139 ThirdParty::BitVec => "with-bit-vec-0_6",
140 ThirdParty::Chrono => "with-chrono-0_4",
141 ThirdParty::Eui48 => "with-eui48-0_4",
142 ThirdParty::GeoTypes => "with-geo-types-0_6",
143 ThirdParty::RustDecimal => "with-rust_decimal-1",
144 ThirdParty::SerdeJson => "with-serde_json-1",
145 ThirdParty::Uuid => "with-uuid-0_8",
146 ThirdParty::Time => "with-time-0_2",
147 }
148 }
149}
150
151impl Opt {
152 pub fn get_cargo_toml(&self) -> String {
154 let package_name = self
155 .output
156 .as_ref()
157 .map(|v| v.file_name())
158 .flatten()
159 .map(|v| v.to_str())
160 .flatten()
161 .unwrap_or("my_db_mapping");
162
163 let dependencies = format!("[package]\nname = \"{}\"", package_name)
164 + r#"
165version = "0.1.0"
166edition = "2018"
167
168[dependencies]
169sql_db_mapper_core = { version = "0.1.0", features = ["#
170 + &self.get_dependencies()
171 + r#"] }
172postgres-types = { version = "0.2", features = ["derive"] }
173async-trait = { version = "0.1", optional = true }
174
175serde = { version = "1.0", features = ["derive"] }
176
177[features]
178sync = []
179async = ["async-trait"]
180"#;
181
182 dependencies
183 }
184
185 fn get_dependencies(&self) -> String {
186 let mut ret = String::new();
187 if self.third_party.contains(&ThirdParty::BitVec) {
188 ret += r#""with-bit-vec-0_6", "#;
189 }
190 if self.third_party.contains(&ThirdParty::Chrono) {
191 ret += r#""with-chrono-0_4", "#;
192 }
193 if self.third_party.contains(&ThirdParty::Eui48) {
194 ret += r#""with-eui48-0_4", "#;
195 }
196 if self.third_party.contains(&ThirdParty::GeoTypes) {
197 ret += r#""with-geo-types-0_6", "#;
198 }
199 if self.third_party.contains(&ThirdParty::RustDecimal) {
200 ret += r#""with-rust_decimal-1", "#;
201 }
202 if self.third_party.contains(&ThirdParty::SerdeJson) {
203 ret += r#""with-serde_json-1", "#;
204 }
205 if self.third_party.contains(&ThirdParty::Time) {
206 ret += r#""with-time-0_2", "#;
207 }
208 if self.third_party.contains(&ThirdParty::Uuid) {
209 ret += r#""with-uuid-0_8", "#;
210 }
211 ret
212 }
213
214 pub fn get_call_string(&self) -> String {
216 let ugly = if self.ugly { " -u" } else { "" };
217 let dir = if self.dir { " --dir" } else { "" };
218 let rust_case = if self.rust_case { " --rust_case" } else { "" };
219 let no_functions = if self.no_functions { " --no_functions" } else { "" };
220 let use_tuples = if self.use_tuples == Tuples::ForOverloads {
221 String::new()
222 } else {
223 format!(" --use-tuples {}", self.use_tuples.to_str())
224 };
225 let third_party = if self.third_party.is_empty() {
226 String::new()
227 } else {
228 let list = self.third_party.iter().map(|v| v.to_str()).fold(String::new(), |acc, v| acc+v+",");
229 format!(" --third-party \"{}\"", &list[..(list.len()-1)])
230 };
231 format!(
232 "sql_db_mapper{ugly}{dir}{rust_case}{no_functions}{use_tuples}{third_party}",
233 ugly = ugly,
234 dir = dir,
235 rust_case = rust_case,
236 no_functions = no_functions,
237 use_tuples = use_tuples,
238 third_party = third_party,
239 )
240 }
241
242 pub fn get_client(&self) -> connection::MyClient {
243 let client = Client::connect(&self.conn, NoTls)
244 .expect("Failed to connect to database, please check your connection string and try again");
245
246 connection::MyClient::new(client)
247 }
248
249 fn uses_lib(&self, lib_name: ThirdParty) -> bool {
250 self.third_party.contains(&lib_name)
251 }
252}
253
/// Formats `value` as Rust source by piping it through an external `rustfmt`.
///
/// `rustfmt_config` is forwarded as `--config` and `rustfmt_config_path` as
/// `--config-path`. rustfmt always runs with `--emit=stdout --edition=2018`.
///
/// This is best effort and never panics on rustfmt problems. If rustfmt cannot be
/// spawned, cannot be fed the whole input, exits unsuccessfully, or prints
/// non-UTF-8 output, a diagnostic goes to stderr and `value` is returned
/// unchanged. Anything rustfmt writes to its stderr is echoed to this process's
/// stderr.
pub fn format_rust(value: &str, rustfmt_config: Option<&str>, rustfmt_config_path: Option<&str>) -> String {
    use std::{
        io::{self, Write},
        process::{Command, Stdio},
        thread,
    };

    let mut args = Vec::new();
    if let Some(s) = rustfmt_config {
        args.push("--config");
        args.push(s);
    }
    if let Some(s) = rustfmt_config_path {
        args.push("--config-path");
        args.push(s);
    }

    let mut proc = match Command::new("rustfmt")
        .arg("--emit=stdout")
        .arg("--edition=2018")
        .args(&args)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
    {
        Ok(proc) => proc,
        Err(e) => {
            eprintln!("failed to spawn rustfmt: {}", e);
            return value.to_string();
        }
    };

    // Feed stdin from a separate thread while `wait_with_output` drains stdout and
    // stderr. Writing everything first could deadlock if rustfmt filled an output
    // pipe before consuming its input. The handle is dropped when the closure
    // finishes, which closes the pipe so rustfmt sees EOF.
    let mut stdin = proc
        .stdin
        .take()
        .expect("rustfmt stdin was configured with Stdio::piped()");
    let input = value.as_bytes().to_vec();
    let writer = thread::spawn(move || stdin.write_all(&input));

    let output = match proc.wait_with_output() {
        Ok(output) => output,
        Err(e) => {
            eprintln!("Error running rustfmt: {}", e);
            return value.to_string();
        }
    };
    let write_result = writer.join().unwrap_or_else(|_| {
        Err(io::Error::new(
            io::ErrorKind::Other,
            "rustfmt stdin writer thread panicked",
        ))
    });

    if !output.stderr.is_empty() {
        eprintln!("{}", String::from_utf8_lossy(&output.stderr));
    }
    // If rustfmt did not receive the whole input (e.g. it exited early and the
    // write hit a broken pipe), its output cannot be trusted even on success.
    if let Err(e) = write_result {
        eprintln!("Error writing to rustfmt: {}", e);
        return value.to_string();
    }
    if !output.status.success() {
        eprintln!("{:?}", output.status.code());
        eprintln!("{}", String::from_utf8_lossy(&output.stdout));
        return value.to_string();
    }
    match String::from_utf8(output.stdout) {
        Ok(formatted) => formatted,
        Err(e) => {
            eprintln!("rustfmt produced invalid UTF-8: {}", e);
            value.to_string()
        }
    }
}