Skip to main content

zql_cli/db/
config.rs

1use crate::core::system::System;
2use crate::db::driver::Driver;
3use crate::error::{EngineError, MyError, MyResult};
4use std::collections::BTreeMap;
5use std::io::{BufReader, BufWriter, Read, Write};
6use std::marker::PhantomData;
7use std::mem;
8
/// Holds the configured database drivers, keyed by alias, together with the
/// [`System`] implementation used to open and create the backing config file.
pub struct Config<F: Read + Write, S: System<F>> {
    /// Gives access to the config file and history on behalf of this config.
    system: S,
    /// Configured drivers keyed by alias; the `BTreeMap` iterates in alias order.
    drivers: BTreeMap<String, Driver>,
    /// Set once `with_config` has attempted to load the config file.
    imported: bool,
    /// Ties the file type `F` to this struct, which stores no `F` directly.
    phantom: PhantomData<F>,
}
15
16impl<F: Read + Write, S: System<F>> Config<F, S> {
17    pub fn new(system: S) -> MyResult<Self> {
18        let drivers = BTreeMap::new();
19        let imported = false;
20        let phantom = PhantomData;
21        let config = Self { system, drivers, imported, phantom };
22        Ok(config)
23    }
24
25    pub fn with_config(mut self) -> MyResult<Self> {
26        if !self.imported {
27            if let Some(reader) = self.system.open_config()? {
28                self.import_from(reader)?;
29            }
30            self.imported = true;
31        }
32        Ok(self)
33    }
34
35    pub fn export_config(&self) -> MyResult<()> {
36        let mut writer = self.system.create_config()?;
37        self.export_to(&mut writer)?;
38        Ok(())
39    }
40
41    fn import_from(&mut self, reader: F) -> MyResult<()> {
42        let reader = BufReader::new(reader);
43        self.drivers = serde_yml::from_reader(reader)?;
44        let mut default = false;
45        for driver in self.drivers.values_mut() {
46            if default {
47                driver.default = false;
48            } else {
49                default = driver.default;
50            }
51        }
52        Ok(())
53    }
54
55    fn export_to(&self, writer: &mut F) -> MyResult<()> {
56        let writer = BufWriter::new(writer);
57        serde_yml::to_writer(writer, &self.drivers)?;
58        Ok(())
59    }
60
61    pub fn get_aliases(&self) -> Vec<String> {
62        self.drivers.keys().cloned().collect()
63    }
64
65    pub fn add_driver(&mut self, alias: String, odbc: Option<String>, default: bool) -> MyResult<()> {
66        if let Some(odbc) = odbc {
67            let replaced = self.drivers.remove(&alias);
68            let default = default || replaced.map(|x| x.default).unwrap_or(false);
69            let driver = Driver::new(&odbc)?.with_default(default);
70            self.set_driver(alias, driver);
71            Ok(())
72        } else if let Some(driver) = self.drivers.remove(&alias) {
73            let driver = driver.with_default(default);
74            self.set_driver(alias, driver);
75            Ok(())
76        } else {
77            Err(MyError::Engine(EngineError::UnknownDriver(alias)))
78        }
79    }
80
81    fn set_driver(&mut self, alias: String, driver: Driver) {
82        if driver.default {
83            for driver in self.drivers.values_mut() {
84                driver.default = false;
85            }
86        }
87        self.drivers.insert(alias, driver);
88    }
89
90    pub fn remove_driver(&mut self, alias: String) -> MyResult<()> {
91        if let Some(previous) = self.drivers.remove(&alias) {
92            if previous.default {
93                if let Some(mut first) = self.drivers.first_entry() {
94                    let first = first.get_mut();
95                    first.default = true;
96                }
97            }
98        }
99        Ok(())
100    }
101
102    pub fn remove_history(&self) -> MyResult<()> {
103        self.system.remove_history(&self.drivers)
104    }
105
106    pub fn list_drivers<W2: Write>(&self, writer: &mut W2) -> MyResult<()> {
107        for (alias, driver) in &self.drivers {
108            let odbc = driver.format_odbc();
109            let default = if driver.default { " --default" } else { "" };
110            writeln!(writer, "zql config add {} '{}'{}", alias, odbc, default)?;
111        }
112        Ok(())
113    }
114
115    pub fn find_driver(&mut self, alias: Option<String>) -> MyResult<(String, Driver)> {
116        if let Some(alias) = alias {
117            return match self.drivers.remove(&alias) {
118                Some(driver) => Ok((alias, driver)),
119                None => Err(MyError::Engine(EngineError::UnknownDriver(alias))),
120            }
121        }
122        let count = self.drivers.len();
123        match count {
124            1 => {
125                let (alias, driver) = self.drivers.pop_first().unwrap();
126                return Ok((alias, driver));
127            }
128            0 => return Err(MyError::Engine(EngineError::SelectDriver(count))),
129            _ => (),
130        }
131        let drivers = mem::take(&mut self.drivers);
132        let mut drivers = drivers.into_iter().filter(|(_, v)| v.default).collect::<Vec<_>>();
133        match drivers.len() {
134            1 => {
135                let (alias, driver) = drivers.pop().unwrap();
136                Ok((alias, driver))
137            }
138            _ => Err(MyError::Engine(EngineError::SelectDriver(count))),
139        }
140    }
141}
142
143#[cfg(test)]
144mod tests {
145    use crate::core::system::tests::MockSystem;
146    use crate::db::config::{Config, Driver};
147    use crate::error::MyResult;
148    use indexmap::IndexMap;
149    use pretty_assertions::assert_eq;
150    use std::collections::BTreeMap;
151    use std::io::Cursor;
152    use std::marker::PhantomData;
153
154    #[test]
155    fn test_config_is_read_from_file() -> MyResult<()> {
156        // Given a config with no drivers.
157        let source = "\
158dsn:
159  default: true
160  odbc:
161    DSN: datasource
162    Uid: hwalters
163    Pwd: password
164mysql:
165  default: true
166  odbc:
167    Driver: '{MySQL ODBC}'
168    Server: 127.0.0.1
169    Database: hugos
170sqlite:
171  default: false
172  odbc:
173    Driver: SQLite3
174    Database: /home/hwalters/data/hugos.db
175";
176        let mut config = create_config(source, Vec::new())?.with_config()?;
177        // Then the config has three drivers.
178        assert_eq!(config.drivers.len(), 3);
179        let driver = config.drivers.remove("dsn").unwrap_or_default();
180        assert_eq!(driver.default, true);
181        assert_eq!(driver.odbc, create_odbc(vec![
182            ("DSN", "datasource"),
183            ("Uid", "hwalters"),
184            ("Pwd", "password"),
185        ]));
186        let driver = config.drivers.remove("mysql").unwrap_or_default();
187        assert_eq!(driver.default, false);
188        assert_eq!(driver.odbc, create_odbc(vec![
189            ("Driver", "{MySQL ODBC}"),
190            ("Server", "127.0.0.1"),
191            ("Database", "hugos"),
192        ]));
193        let driver = config.drivers.remove("sqlite").unwrap_or_default();
194        assert_eq!(driver.default, false);
195        assert_eq!(driver.odbc, create_odbc(vec![
196            ("Driver", "SQLite3"),
197            ("Database", "/home/hwalters/data/hugos.db"),
198        ]));
199        Ok(())
200    }
201
202    #[test]
203    fn test_config_is_written_to_file() -> MyResult<()> {
204        // Given a config with three drivers.
205        let config = create_config("", vec![
206            ("dsn", "DSN=datasource;Uid=hwalters;Pwd=password", false),
207            ("mysql", "Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos", false),
208            ("sqlite", "Driver=SQLite3;Database=/home/hwalters/data/hugos.db", true),
209        ])?;
210        // When the config is written to the file.
211        let mut buffer = Cursor::new(Vec::new());
212        config.export_to(&mut buffer)?;
213        // Then the file has expected YAML entries.
214        let expected = "\
215dsn:
216  default: false
217  odbc:
218    DSN: datasource
219    Uid: hwalters
220    Pwd: password
221mysql:
222  default: false
223  odbc:
224    Driver: '{MySQL ODBC}'
225    Server: '127.0.0.1'
226    Database: hugos
227sqlite:
228  default: true
229  odbc:
230    Driver: SQLite3
231    Database: /home/hwalters/data/hugos.db
232";
233        assert_eq!(expected, to_string(buffer));
234        Ok(())
235    }
236
237    #[test]
238    fn test_driver_is_added_to_config() -> MyResult<()> {
239        // Given a config with no drivers.
240        let mut config = create_config("", Vec::new())?;
241        // When drivers are added to the config.
242        config.add_driver(String::from("dsn"), Some(String::from("DSN=datasource")), true)?;
243        config.add_driver(String::from("mysql"), Some(String::from("Driver=MySQL")), true)?;
244        config.add_driver(String::from("sqlite"), Some(String::from("Driver=SQLite3")), false)?;
245        // Then the config has three drivers.
246        assert_eq!(config.drivers.len(), 3);
247        let dsn = config.drivers.remove("dsn").unwrap_or_default();
248        let mysql = config.drivers.remove("mysql").unwrap_or_default();
249        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
250        assert_eq!(dsn, Driver { default: false, odbc: partial_odbc("DSN", "datasource") });
251        assert_eq!(mysql, Driver { default: true, odbc: partial_odbc("Driver", "MySQL") });
252        assert_eq!(sqlite, Driver { default: false, odbc: partial_odbc("Driver", "SQLite3") });
253        Ok(())
254    }
255
256    #[test]
257    fn test_driver_is_replaced_in_config() -> MyResult<()> {
258        // Given a config with three drivers.
259        let mut config = create_config("", vec![
260            ("dsn", "DSN=datasource", true),
261            ("mysql", "Driver=MySQL", false),
262            ("sqlite", "Driver=SQLite3", false),
263        ])?;
264        // When a known driver is replaced.
265        config.add_driver(String::from("sqlite"), Some(String::from("Driver=SQLite3a")), true)?;
266        // Then the config has three drivers.
267        assert_eq!(config.drivers.len(), 3);
268        let dsn = config.drivers.remove("dsn").unwrap_or_default();
269        let mysql = config.drivers.remove("mysql").unwrap_or_default();
270        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
271        assert_eq!(dsn, Driver { default: false, odbc: partial_odbc("DSN", "datasource") });
272        assert_eq!(mysql, Driver { default: false, odbc: partial_odbc("Driver", "MySQL") });
273        assert_eq!(sqlite, Driver { default: true, odbc: partial_odbc("Driver", "SQLite3a") });
274        Ok(())
275    }
276
277    #[test]
278    fn test_default_flag_is_set_for_known_driver() -> MyResult<()> {
279        // Given a config with three drivers.
280        let mut config = create_config("", vec![
281            ("dsn", "DSN=datasource", true),
282            ("mysql", "Driver=MySQL", false),
283            ("sqlite", "Driver=SQLite3", false),
284        ])?;
285        // When a known driver is marked as default.
286        config.add_driver(String::from("sqlite"), None, true)?;
287        // Then the default flag is moved.
288        assert_eq!(config.drivers.len(), 3);
289        let dsn = config.drivers.remove("dsn").unwrap_or_default();
290        let mysql = config.drivers.remove("mysql").unwrap_or_default();
291        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
292        assert_eq!(dsn, Driver { default: false, odbc: partial_odbc("DSN", "datasource") });
293        assert_eq!(mysql, Driver { default: false, odbc: partial_odbc("Driver", "MySQL") });
294        assert_eq!(sqlite, Driver { default: true, odbc: partial_odbc("Driver", "SQLite3") });
295        Ok(())
296    }
297
298    #[test]
299    fn test_default_flag_is_kept_for_known_driver() -> MyResult<()> {
300        // Given a config with three drivers.
301        let mut config = create_config("", vec![
302            ("dsn", "DSN=datasource", true),
303            ("mysql", "Driver=MySQL", false),
304            ("sqlite", "Driver=SQLite3", false),
305        ])?;
306        // When a known driver is marked as default.
307        config.add_driver(String::from("dsn"), Some(String::from("DSN=datasource2")), false)?;
308        // Then the default flag is moved.
309        assert_eq!(config.drivers.len(), 3);
310        let dsn = config.drivers.remove("dsn").unwrap_or_default();
311        let mysql = config.drivers.remove("mysql").unwrap_or_default();
312        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
313        assert_eq!(dsn, Driver { default: true, odbc: partial_odbc("DSN", "datasource2") });
314        assert_eq!(mysql, Driver { default: false, odbc: partial_odbc("Driver", "MySQL") });
315        assert_eq!(sqlite, Driver { default: false, odbc: partial_odbc("Driver", "SQLite3") });
316        Ok(())
317    }
318
319    #[test]
320    fn test_default_flag_not_set_for_unknown_driver() -> MyResult<()> {
321        // Given a config with three drivers.
322        let mut config = create_config("", vec![
323            ("dsn", "DSN=datasource", true),
324            ("mysql", "Driver=MySQL", false),
325            ("sqlite", "Driver=SQLite3", false),
326        ])?;
327        // When an unknown driver is marked as default.
328        let error = config.add_driver(String::from("unknown"), None, true).unwrap_err();
329        // Then the default flag is not moved.
330        assert_eq!(config.drivers.len(), 3);
331        let dsn = config.drivers.remove("dsn").unwrap_or_default();
332        let mysql = config.drivers.remove("mysql").unwrap_or_default();
333        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
334        assert_eq!(dsn, Driver { default: true, odbc: partial_odbc("DSN", "datasource") });
335        assert_eq!(mysql, Driver { default: false, odbc: partial_odbc("Driver", "MySQL") });
336        assert_eq!(sqlite, Driver { default: false, odbc: partial_odbc("Driver", "SQLite3") });
337        // And an error is returned.
338        assert_eq!(error.to_string(), "Unknown driver alias \"unknown\"");
339        Ok(())
340    }
341
342    #[test]
343    fn test_default_driver_is_removed_from_config() -> MyResult<()> {
344        // Given a config with three drivers.
345        let mut config = create_config("", vec![
346            ("dsn", "DSN=datasource", true),
347            ("mysql", "Driver=MySQL", false),
348            ("sqlite", "Driver=SQLite3", false),
349        ])?;
350        // When the default driver is removed from the config.
351        config.remove_driver(String::from("dsn"))?;
352        // Then the config has two drivers, and another default is set.
353        assert_eq!(config.drivers.len(), 2);
354        let mysql = config.drivers.remove("mysql").unwrap_or_default();
355        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
356        assert_eq!(mysql, Driver { default: true, odbc: partial_odbc("Driver", "MySQL") });
357        assert_eq!(sqlite, Driver { default: false, odbc: partial_odbc("Driver", "SQLite3") });
358        Ok(())
359    }
360
361    #[test]
362    fn test_non_default_driver_is_removed_from_config() -> MyResult<()> {
363        // Given a config with three drivers.
364        let mut config = create_config("", vec![
365            ("dsn", "DSN=datasource", true),
366            ("mysql", "Driver=MySQL", false),
367            ("sqlite", "Driver=SQLite3", false),
368        ])?;
369        // When a non default driver is removed from the config.
370        config.remove_driver(String::from("mysql"))?;
371        // Then the config has two drivers.
372        assert_eq!(config.drivers.len(), 2);
373        let dsn = config.drivers.remove("dsn").unwrap_or_default();
374        let sqlite = config.drivers.remove("sqlite").unwrap_or_default();
375        assert_eq!(dsn, Driver { default: true, odbc: partial_odbc("DSN", "datasource") });
376        assert_eq!(sqlite, Driver { default: false, odbc: partial_odbc("Driver", "SQLite3") });
377        Ok(())
378    }
379
380    #[test]
381    fn test_drivers_are_listed_as_commands() -> MyResult<()> {
382        // Given a config with three drivers.
383        let config = create_config("", vec![
384            ("dsn", "DSN=datasource;Uid=hwalters;Pwd=password", false),
385            ("mysql", "Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos", false),
386            ("sqlite", "Driver=SQLite3;Database=/home/hwalters/data/hugos.db", true),
387        ])?;
388        // When the drivers are listed.
389        let mut buffer = Cursor::new(Vec::new());
390        config.list_drivers(&mut buffer)?;
391        // Then the output has expected table entries.
392        let expected = "\
393zql config add dsn 'DSN=datasource;Uid=hwalters;Pwd=password'
394zql config add mysql 'Driver={MySQL ODBC};Server=127.0.0.1;Database=hugos'
395zql config add sqlite 'Driver=SQLite3;Database=/home/hwalters/data/hugos.db' --default
396";
397        assert_eq!(expected, to_string(buffer));
398        Ok(())
399    }
400
401    #[test]
402    fn test_driver_is_found_from_valid_supplied_name() -> MyResult<()> {
403        // Given a config with multiple drivers and no default.
404        let mut config = create_config("", vec![
405            ("alias1", "odbc=1", false),
406            ("alias2", "odbc=2", false),
407            ("alias3", "odbc=3", false),
408        ])?;
409        // Then the driver is found from a valid alias.
410        let result1 = config.find_driver(Some(String::from("alias1")))?;
411        let result2 = config.find_driver(Some(String::from("alias2")))?;
412        let result3 = config.find_driver(Some(String::from("alias3")))?;
413        assert_eq!(result1.0, "alias1");
414        assert_eq!(result2.0, "alias2");
415        assert_eq!(result3.0, "alias3");
416        assert_eq!(result1.1.format_odbc(), "odbc=1");
417        assert_eq!(result2.1.format_odbc(), "odbc=2");
418        assert_eq!(result3.1.format_odbc(), "odbc=3");
419        Ok(())
420    }
421
422    #[test]
423    fn test_driver_not_found_from_invalid_supplied_name() -> MyResult<()> {
424        // Given a config with multiple drivers and no default.
425        let mut config = create_config("", vec![
426            ("alias1", "odbc=1", false),
427            ("alias2", "odbc=2", false),
428            ("alias3", "odbc=3", false),
429        ])?;
430        // Then the driver is not found from an invalid alias.
431        let error = config.find_driver(Some(String::from("missing"))).unwrap_err();
432        assert_eq!(error.to_string(), "Unknown driver alias \"missing\"");
433        Ok(())
434    }
435
436    #[test]
437    fn test_driver_is_found_from_single_default_name() -> MyResult<()> {
438        // Given a config with multiple drivers and one default.
439        let mut config = create_config("", vec![
440            ("alias1", "odbc=1", false),
441            ("alias2", "odbc=2", true),
442            ("alias3", "odbc=3", false),
443        ])?;
444        // Then the driver is found from the default.
445        let result = config.find_driver(None)?;
446        assert_eq!(result.0, "alias2");
447        assert_eq!(result.1.format_odbc(), "odbc=2");
448        Ok(())
449    }
450
451    #[test]
452    fn test_driver_not_found_from_missing_default_names() -> MyResult<()> {
453        // Given a config with multiple drivers and no defaults.
454        let mut config = create_config("", vec![
455            ("alias1", "odbc=1", false),
456            ("alias2", "odbc=2", false),
457            ("alias3", "odbc=3", false),
458        ])?;
459        // Then the driver is not found from the entries.
460        let error = config.find_driver(None).unwrap_err();
461        assert_eq!(error.to_string(), "Cannot select from 3 drivers");
462        Ok(())
463    }
464
465    #[test]
466    fn test_driver_not_found_from_multiple_default_names() -> MyResult<()> {
467        // Given a config with multiple drivers and multiple defaults.
468        let mut config = create_config("", vec![
469            ("alias1", "odbc=1", false),
470            ("alias2", "odbc=2", true),
471            ("alias3", "odbc=3", true),
472        ])?;
473        // Then the driver is not found from the defaults.
474        let error = config.find_driver(None).unwrap_err();
475        assert_eq!(error.to_string(), "Cannot select from 3 drivers");
476        Ok(())
477    }
478
479    #[test]
480    fn test_driver_is_found_from_single_configured_entry() -> MyResult<()> {
481        // Given a config with a single driver.
482        let mut config = create_config("", vec![
483            ("alias1", "odbc=1", false),
484        ])?;
485        // Then the driver is found from the entry.
486        let result = config.find_driver(None)?;
487        assert_eq!(result.0, "alias1");
488        assert_eq!(result.1.format_odbc(), "odbc=1");
489        Ok(())
490    }
491
492    fn create_config(source: &str, tuples: Vec<(&str, &str, bool)>) -> MyResult<Config<Cursor<Vec<u8>>, MockSystem>> {
493        let convert = |tuple: (&str, &str, bool)| {
494            let (alias, odbc, default) = tuple;
495            let alias = String::from(alias);
496            let driver = Driver::new(odbc)?.with_default(default);
497            Ok((alias, driver))
498        };
499        let system = MockSystem::new(source);
500        let drivers = tuples.into_iter().map(convert).collect::<MyResult<BTreeMap<_, _>>>()?;
501        let imported = false;
502        let phantom = PhantomData;
503        let config = Config { system, drivers, imported, phantom };
504        Ok(config)
505    }
506
507    fn partial_odbc(key: &str, value: &str) -> IndexMap<String, String> {
508        let mut odbc = IndexMap::new();
509        let key = String::from(key);
510        let value = String::from(value);
511        odbc.insert(key, value);
512        odbc
513    }
514
515    fn create_odbc(pairs: Vec<(&str, &str)>) -> IndexMap<String, String> {
516        let convert = |pair: (&str, &str)| {
517            let (key, value) = pair;
518            let key = String::from(key);
519            let value = String::from(value);
520            (key, value)
521        };
522        pairs.into_iter().map(convert).collect()
523    }
524
525    fn to_string(buffer: Cursor<Vec<u8>>) -> String {
526        String::from_utf8(buffer.into_inner()).unwrap_or_default()
527    }
528}