1use crate::core::system::System;
2use crate::db::driver::Driver;
3use crate::error::{EngineError, MyError, MyResult};
4use std::collections::BTreeMap;
5use std::io::{BufReader, BufWriter, Read, Write};
6use std::marker::PhantomData;
7use std::mem;
8
9pub struct Config<F: Read + Write, S: System<F>> {
10 system: S,
11 drivers: BTreeMap<String, Driver>,
12 imported: bool,
13 phantom: PhantomData<F>,
14}
15
16impl<F: Read + Write, S: System<F>> Config<F, S> {
17 pub fn new(system: S) -> MyResult<Self> {
18 let drivers = BTreeMap::new();
19 let imported = false;
20 let phantom = PhantomData;
21 let config = Self { system, drivers, imported, phantom };
22 Ok(config)
23 }
24
25 pub fn with_config(mut self) -> MyResult<Self> {
26 if !self.imported {
27 if let Some(reader) = self.system.open_config()? {
28 self.import_from(reader)?;
29 }
30 self.imported = true;
31 }
32 Ok(self)
33 }
34
35 pub fn export_config(&self) -> MyResult<()> {
36 let mut writer = self.system.create_config()?;
37 self.export_to(&mut writer)?;
38 Ok(())
39 }
40
41 fn import_from(&mut self, reader: F) -> MyResult<()> {
42 let reader = BufReader::new(reader);
43 self.drivers = serde_yml::from_reader(reader)?;
44 let mut default = false;
45 for driver in self.drivers.values_mut() {
46 if default {
47 driver.default = false;
48 } else {
49 default = driver.default;
50 }
51 }
52 Ok(())
53 }
54
55 fn export_to(&self, writer: &mut F) -> MyResult<()> {
56 let writer = BufWriter::new(writer);
57 serde_yml::to_writer(writer, &self.drivers)?;
58 Ok(())
59 }
60
61 pub fn get_aliases(&self) -> Vec<String> {
62 self.drivers.keys().cloned().collect()
63 }
64
65 pub fn add_driver(&mut self, alias: String, odbc: Option<String>, default: bool) -> MyResult<()> {
66 if let Some(odbc) = odbc {
67 let replaced = self.drivers.remove(&alias);
68 let default = default || replaced.map(|x| x.default).unwrap_or(false);
69 let driver = Driver::new(&odbc)?.with_default(default);
70 self.set_driver(alias, driver);
71 Ok(())
72 } else if let Some(driver) = self.drivers.remove(&alias) {
73 let driver = driver.with_default(default);
74 self.set_driver(alias, driver);
75 Ok(())
76 } else {
77 Err(MyError::Engine(EngineError::UnknownDriver(alias)))
78 }
79 }
80
81 fn set_driver(&mut self, alias: String, driver: Driver) {
82 if driver.default {
83 for driver in self.drivers.values_mut() {
84 driver.default = false;
85 }
86 }
87 self.drivers.insert(alias, driver);
88 }
89
90 pub fn remove_driver(&mut self, alias: String) -> MyResult<()> {
91 if let Some(previous) = self.drivers.remove(&alias) {
92 if previous.default {
93 if let Some(mut first) = self.drivers.first_entry() {
94 let first = first.get_mut();
95 first.default = true;
96 }
97 }
98 }
99 Ok(())
100 }
101
102 pub fn remove_history(&self) -> MyResult<()> {
103 self.system.remove_history(&self.drivers)
104 }
105
106 pub fn list_drivers<W2: Write>(&self, writer: &mut W2) -> MyResult<()> {
107 for (alias, driver) in &self.drivers {
108 let odbc = driver.format_odbc();
109 let default = if driver.default { " --default" } else { "" };
110 writeln!(writer, "zql config add {} '{}'{}", alias, odbc, default)?;
111 }
112 Ok(())
113 }
114
115 pub fn find_driver(&mut self, alias: Option<String>) -> MyResult<(String, Driver)> {
116 if let Some(alias) = alias {
117 return match self.drivers.remove(&alias) {
118 Some(driver) => Ok((alias, driver)),
119 None => Err(MyError::Engine(EngineError::UnknownDriver(alias))),
120 }
121 }
122 let count = self.drivers.len();
123 match count {
124 1 => {
125 let (alias, driver) = self.drivers.pop_first().unwrap();
126 return Ok((alias, driver));
127 }
128 0 => return Err(MyError::Engine(EngineError::SelectDriver(count))),
129 _ => (),
130 }
131 let drivers = mem::take(&mut self.drivers);
132 let mut drivers = drivers.into_iter().filter(|(_, v)| v.default).collect::<Vec<_>>();
133 match drivers.len() {
134 1 => {
135 let (alias, driver) = drivers.pop().unwrap();
136 Ok((alias, driver))
137 }
138 _ => Err(MyError::Engine(EngineError::SelectDriver(count))),
139 }
140 }
141}
142
143#[cfg(test)]
144mod tests {
145 use crate::core::system::tests::MockSystem;
146 use crate::db::config::{Config, Driver};
147 use crate::error::MyResult;
148 use indexmap::IndexMap;
149 use pretty_assertions::assert_eq;
150 use std::collections::BTreeMap;
151 use std::io::Cursor;
152 use std::marker::PhantomData;
153
/// Imports a three-driver YAML document via `with_config` and checks every
/// driver's ODBC pairs and default flag.
// NOTE(review): the YAML literal's indentation appears to have been lost in
// this copy of the file; it is reproduced here exactly as found.
154 #[test]
155 fn test_config_is_read_from_file() -> MyResult<()> {
156 let source = "\
158dsn:
159 default: true
160 odbc:
161 DSN: datasource
162 Uid: hwalters
163 Pwd: password
164mysql:
165 default: true
166 odbc:
167 Driver: '{MySQL ODBC}'
168 Server: 127.0.0.1
169 Database: hugos
170sqlite:
171 default: false
172 odbc:
173 Driver: SQLite3
174 Database: /home/hwalters/data/hugos.db
175";
// presumably MockSystem serves `source` as the config reader — verify against MockSystem
176 let mut config = create_config(source, Vec::new())?.with_config()?;
177 assert_eq!(config.drivers.len(), 3);
179 let driver = config.drivers.remove("dsn").unwrap_or_default();
180 assert_eq!(driver.default, true);
181 assert_eq!(driver.odbc, create_odbc(vec![
182 ("DSN", "datasource"),
183 ("Uid", "hwalters"),
184 ("Pwd", "password"),
185 ]));
// `mysql` is flagged in the source too, but only the first flagged alias keeps the flag.
186 let driver = config.drivers.remove("mysql").unwrap_or_default();
187 assert_eq!(driver.default, false);
188 assert_eq!(driver.odbc, create_odbc(vec![
189 ("Driver", "{MySQL ODBC}"),
190 ("Server", "127.0.0.1"),
191 ("Database", "hugos"),
192 ]));
193 let driver = config.drivers.remove("sqlite").unwrap_or_default();
194 assert_eq!(driver.default, false);
195 assert_eq!(driver.odbc, create_odbc(vec![
196 ("Driver", "SQLite3"),
197 ("Database", "/home/hwalters/data/hugos.db"),
198 ]));
199 Ok(())
200 }
201
/// Serializes three drivers with `export_to` into an in-memory buffer and
/// compares the produced text with the expected YAML.
// NOTE(review): the YAML literal's indentation appears to have been lost in
// this copy of the file; it is reproduced here exactly as found.
202 #[test]
203 fn test_config_is_written_to_file() -> MyResult<()> {
204 let config = create_config("", vec![
206 ("dsn", "DSN=datasource;Uid=hwalters;Pwd=password", false),
207 ("mysql", "Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos", false),
208 ("sqlite", "Driver=SQLite3;Database=/home/hwalters/data/hugos.db", true),
209 ])?;
210 let mut buffer = Cursor::new(Vec::new());
212 config.export_to(&mut buffer)?;
213 let expected = "\
215dsn:
216 default: false
217 odbc:
218 DSN: datasource
219 Uid: hwalters
220 Pwd: password
221mysql:
222 default: false
223 odbc:
224 Driver: '{MySQL ODBC}'
225 Server: '127.0.0.1'
226 Database: hugos
227sqlite:
228 default: true
229 odbc:
230 Driver: SQLite3
231 Database: /home/hwalters/data/hugos.db
232";
233 assert_eq!(expected, to_string(buffer));
234 Ok(())
235 }
236
/// Adds three drivers to an empty config; the later `default = true` addition
/// takes the flag away from the earlier one.
#[test]
fn test_driver_is_added_to_config() -> MyResult<()> {
    let mut config = create_config("", vec![])?;
    let additions = [
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", true),
        ("sqlite", "Driver=SQLite3", false),
    ];
    for (alias, odbc, default) in additions {
        config.add_driver(alias.to_string(), Some(odbc.to_string()), default)?;
    }
    assert_eq!(config.drivers.len(), 3);
    let expected = [
        ("dsn", false, "DSN", "datasource"),
        ("mysql", true, "Driver", "MySQL"),
        ("sqlite", false, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
255
/// Replaces an existing driver with a new ODBC string and `default = true`;
/// the previous default loses its flag.
#[test]
fn test_driver_is_replaced_in_config() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    config.add_driver("sqlite".to_string(), Some("Driver=SQLite3a".to_string()), true)?;
    assert_eq!(config.drivers.len(), 3);
    let expected = [
        ("dsn", false, "DSN", "datasource"),
        ("mysql", false, "Driver", "MySQL"),
        ("sqlite", true, "Driver", "SQLite3a"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
276
/// Calling `add_driver` without an ODBC string on a known alias only moves
/// the default flag to that driver.
#[test]
fn test_default_flag_is_set_for_known_driver() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    config.add_driver("sqlite".to_string(), None, true)?;
    assert_eq!(config.drivers.len(), 3);
    let expected = [
        ("dsn", false, "DSN", "datasource"),
        ("mysql", false, "Driver", "MySQL"),
        ("sqlite", true, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
297
/// Replacing the default driver's ODBC string with `default = false` keeps
/// the driver flagged as default.
#[test]
fn test_default_flag_is_kept_for_known_driver() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    config.add_driver("dsn".to_string(), Some("DSN=datasource2".to_string()), false)?;
    assert_eq!(config.drivers.len(), 3);
    let expected = [
        ("dsn", true, "DSN", "datasource2"),
        ("mysql", false, "Driver", "MySQL"),
        ("sqlite", false, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
318
/// Flagging an unknown alias fails with an "unknown driver" error and leaves
/// the configured drivers unchanged.
#[test]
fn test_default_flag_not_set_for_unknown_driver() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    let failure = config.add_driver("unknown".to_string(), None, true).unwrap_err();
    assert_eq!(config.drivers.len(), 3);
    let expected = [
        ("dsn", true, "DSN", "datasource"),
        ("mysql", false, "Driver", "MySQL"),
        ("sqlite", false, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    assert_eq!(failure.to_string(), "Unknown driver alias \"unknown\"");
    Ok(())
}
341
/// Removing the default driver promotes the first remaining driver to default.
#[test]
fn test_default_driver_is_removed_from_config() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    config.remove_driver("dsn".to_string())?;
    assert_eq!(config.drivers.len(), 2);
    let expected = [
        ("mysql", true, "Driver", "MySQL"),
        ("sqlite", false, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
360
/// Removing a non-default driver leaves the default flag where it was.
#[test]
fn test_non_default_driver_is_removed_from_config() -> MyResult<()> {
    let initial = vec![
        ("dsn", "DSN=datasource", true),
        ("mysql", "Driver=MySQL", false),
        ("sqlite", "Driver=SQLite3", false),
    ];
    let mut config = create_config("", initial)?;
    config.remove_driver("mysql".to_string())?;
    assert_eq!(config.drivers.len(), 2);
    let expected = [
        ("dsn", true, "DSN", "datasource"),
        ("sqlite", false, "Driver", "SQLite3"),
    ];
    for (alias, default, key, value) in expected {
        let actual = config.drivers.remove(alias).unwrap_or_default();
        assert_eq!(actual, Driver { default, odbc: partial_odbc(key, value) });
    }
    Ok(())
}
379
/// `list_drivers` prints one `zql config add` command per driver and marks
/// the default driver with `--default`.
#[test]
fn test_drivers_are_listed_as_commands() -> MyResult<()> {
    let drivers = vec![
        ("dsn", "DSN=datasource;Uid=hwalters;Pwd=password", false),
        ("mysql", "Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos", false),
        ("sqlite", "Driver=SQLite3;Database=/home/hwalters/data/hugos.db", true),
    ];
    let config = create_config("", drivers)?;
    let mut output = Cursor::new(Vec::new());
    config.list_drivers(&mut output)?;
    let expected = concat!(
        "zql config add dsn 'DSN=datasource;Uid=hwalters;Pwd=password'\n",
        "zql config add mysql 'Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos'\n",
        "zql config add sqlite 'Driver=SQLite3;Database=/home/hwalters/data/hugos.db' --default\n",
    );
    assert_eq!(expected, to_string(output));
    Ok(())
}
400
/// Each configured alias can be looked up explicitly and yields its own driver.
#[test]
fn test_driver_is_found_from_valid_supplied_name() -> MyResult<()> {
    let expectations = [("alias1", "odbc=1"), ("alias2", "odbc=2"), ("alias3", "odbc=3")];
    let drivers = vec![
        ("alias1", "odbc=1", false),
        ("alias2", "odbc=2", false),
        ("alias3", "odbc=3", false),
    ];
    let mut config = create_config("", drivers)?;
    // Perform all lookups first, then check aliases, then ODBC strings.
    let mut found = Vec::new();
    for (alias, _) in expectations {
        found.push(config.find_driver(Some(String::from(alias)))?);
    }
    for ((alias, _), (expected, _)) in found.iter().zip(expectations) {
        assert_eq!(alias.as_str(), expected);
    }
    for ((_, driver), (_, expected)) in found.iter().zip(expectations) {
        assert_eq!(driver.format_odbc(), expected);
    }
    Ok(())
}
421
/// Looking up an alias that is not configured reports an unknown-driver error.
#[test]
fn test_driver_not_found_from_invalid_supplied_name() -> MyResult<()> {
    let drivers = vec![
        ("alias1", "odbc=1", false),
        ("alias2", "odbc=2", false),
        ("alias3", "odbc=3", false),
    ];
    let mut config = create_config("", drivers)?;
    let outcome = config.find_driver(Some("missing".to_owned()));
    assert_eq!(outcome.unwrap_err().to_string(), "Unknown driver alias \"missing\"");
    Ok(())
}
435
/// Without an alias, the single driver flagged as default is selected.
#[test]
fn test_driver_is_found_from_single_default_name() -> MyResult<()> {
    let drivers = vec![
        ("alias1", "odbc=1", false),
        ("alias2", "odbc=2", true),
        ("alias3", "odbc=3", false),
    ];
    let mut config = create_config("", drivers)?;
    let (alias, driver) = config.find_driver(None)?;
    assert_eq!(alias, "alias2");
    assert_eq!(driver.format_odbc(), "odbc=2");
    Ok(())
}
450
/// Without an alias and without any default among several drivers, selection fails.
#[test]
fn test_driver_not_found_from_missing_default_names() -> MyResult<()> {
    let drivers = vec![
        ("alias1", "odbc=1", false),
        ("alias2", "odbc=2", false),
        ("alias3", "odbc=3", false),
    ];
    let mut config = create_config("", drivers)?;
    let outcome = config.find_driver(None);
    assert_eq!(outcome.unwrap_err().to_string(), "Cannot select from 3 drivers");
    Ok(())
}
464
/// Without an alias and with two drivers flagged as default, selection fails.
#[test]
fn test_driver_not_found_from_multiple_default_names() -> MyResult<()> {
    let drivers = vec![
        ("alias1", "odbc=1", false),
        ("alias2", "odbc=2", true),
        ("alias3", "odbc=3", true),
    ];
    let mut config = create_config("", drivers)?;
    let outcome = config.find_driver(None);
    assert_eq!(outcome.unwrap_err().to_string(), "Cannot select from 3 drivers");
    Ok(())
}
478
/// Without an alias, a lone configured driver is selected even though it is
/// not flagged as default.
#[test]
fn test_driver_is_found_from_single_configured_entry() -> MyResult<()> {
    let drivers = vec![("alias1", "odbc=1", false)];
    let mut config = create_config("", drivers)?;
    let (alias, driver) = config.find_driver(None)?;
    assert_eq!(alias, "alias1");
    assert_eq!(driver.format_odbc(), "odbc=1");
    Ok(())
}
491
492 fn create_config(source: &str, tuples: Vec<(&str, &str, bool)>) -> MyResult<Config<Cursor<Vec<u8>>, MockSystem>> {
493 let convert = |tuple: (&str, &str, bool)| {
494 let (alias, odbc, default) = tuple;
495 let alias = String::from(alias);
496 let driver = Driver::new(odbc)?.with_default(default);
497 Ok((alias, driver))
498 };
499 let system = MockSystem::new(source);
500 let drivers = tuples.into_iter().map(convert).collect::<MyResult<BTreeMap<_, _>>>()?;
501 let imported = false;
502 let phantom = PhantomData;
503 let config = Config { system, drivers, imported, phantom };
504 Ok(config)
505 }
506
507 fn partial_odbc(key: &str, value: &str) -> IndexMap<String, String> {
508 let mut odbc = IndexMap::new();
509 let key = String::from(key);
510 let value = String::from(value);
511 odbc.insert(key, value);
512 odbc
513 }
514
515 fn create_odbc(pairs: Vec<(&str, &str)>) -> IndexMap<String, String> {
516 let convert = |pair: (&str, &str)| {
517 let (key, value) = pair;
518 let key = String::from(key);
519 let value = String::from(value);
520 (key, value)
521 };
522 pairs.into_iter().map(convert).collect()
523 }
524
/// Returns the buffer's bytes as text, or an empty string if they are not
/// valid UTF-8.
fn to_string(buffer: Cursor<Vec<u8>>) -> String {
    match String::from_utf8(buffer.into_inner()) {
        Ok(text) => text,
        Err(_) => String::new(),
    }
}
528}