Skip to main content

tempest_repl/
lib.rs

1use std::{
2    io::{self, BufRead, Write},
3    marker::PhantomData,
4    path::PathBuf,
5};
6
7use derive_more::{Display, Error, From};
8use itertools::Itertools;
9use owo_colors::{OwoColorize, colors::css::Gray};
10use tabled::{builder::Builder as TabledBuilder, settings::Style as TabledStyle};
11use tempest_engine::{
12    Engine, EngineError,
13    catalog::{CatalogState, schema::{TypeId, VariantId}},
14    config::EngineConfig,
15    query::QueryResult,
16    types::TempestValue,
17};
18use tempest_io::Io;
19use tempest_tql::ParseError;
20
21use crate::stdio::Stdio;
22
23#[macro_use]
24extern crate tracing;
25
26pub mod stdio;
27
28#[derive(Debug, Display, Error, From)]
29pub enum ReplError {
30    Io(io::Error),
31    Engine(EngineError),
32}
33
/// Banner text printed under the title when a REPL session starts.
///
/// Built from one piece per output line so each line's terminator is explicit.
const REPL_INTRODUCTION: &str = concat!(
    "This is the Tempest REPL: Read, Evaluate, Print, Loop.\n",
    "You can type in some special commands, starting with a dot,\n",
    "or you can enter in valid TQL statements!"
);
37
38fn format_value(val: &TempestValue, catalog: &CatalogState) -> String {
39    format_value_inner(val, catalog, false)
40}
41
42fn format_value_inner(val: &TempestValue, catalog: &CatalogState, quoted: bool) -> String {
43    match val {
44        TempestValue::String(s) if quoted => format!("{:?}", s.as_ref()),
45        TempestValue::Enum { type_id, variant_id, fields } => {
46            let type_schema = catalog.get_type(TypeId(*type_id));
47            let type_prefix = type_schema.map(|ts| match ts.database_id() {
48                Some(db_id) => format!("{}.{}", catalog.databases[&db_id].name, ts.name()),
49                None => ts.name().to_string(),
50            }).unwrap_or_else(|| type_id.to_string());
51            let variant_name = type_schema
52                .and_then(|ts| ts.as_enum())
53                .and_then(|e| e.variants.get(&VariantId(*variant_id)))
54                .map(|v| v.name.to_string())
55                .unwrap_or_else(|| variant_id.to_string());
56            if fields.is_empty() {
57                format!("{}.{}", type_prefix, variant_name)
58            } else {
59                let args: Vec<_> = fields.iter().map(|f| format_value_inner(f, catalog, true)).collect();
60                format!("{}.{}({})", type_prefix, variant_name, args.join(", "))
61            }
62        }
63        other => format!("{}", other),
64    }
65}
66
67pub struct Repl<I: Io, S: Stdio> {
68    data_dir: PathBuf,
69    config: EngineConfig,
70    stdio: S,
71    _marker: PhantomData<I>,
72}
73
74impl<I: Io, S: Stdio> Repl<I, S> {
75    pub fn new(data_dir: PathBuf, config: EngineConfig, stdio: S) -> Self {
76        Self {
77            data_dir,
78            config,
79            stdio,
80            _marker: PhantomData,
81        }
82    }
83
84    fn explain_command(&mut self, name: &str, description: &str) -> io::Result<()> {
85        writeln!(
86            self.stdio.stdout(),
87            "{} {}",
88            name.bright_green(),
89            description.fg::<Gray>()
90        )
91    }
92
93    fn show_help(&mut self) -> io::Result<()> {
94        writeln!(
95            self.stdio.stdout(),
96            "{}",
97            "List of available commands:".bright_green()
98        )?;
99        self.explain_command(".help | .h", "show this menu")?;
100        self.explain_command(".clear | .c", "clear the screen")?;
101        self.explain_command(".quit | .q", "terminate the REPL session")?;
102        self.explain_command(".databases | .dbs", "list all databases")?;
103        self.explain_command(
104            ".tables <database>",
105            "list all tables, optionally scoped to a database",
106        )?;
107        self.explain_command(
108            ".types <database>",
109            "list all types, optionally scoped to a database",
110        )?;
111        Ok(())
112    }
113
114    fn clear_screen(&mut self) -> io::Result<()> {
115        write!(self.stdio.stdout(), "\x1b[2J\x1b[H")
116    }
117
118    fn list_databases(&mut self, catalog: &CatalogState) -> io::Result<()> {
119        let mut builder = TabledBuilder::new();
120        builder.push_record(["database", "tables"].map(|s| format!("{}", s.bold())));
121        for db in catalog.databases.values() {
122            builder.push_record([db.name.to_string(), db.tables.len().to_string()]);
123        }
124        if catalog.databases.is_empty() {
125            builder.push_record([""]);
126        }
127        let mut table = builder.build();
128        table.with(TabledStyle::rounded());
129        writeln!(self.stdio.stdout(), "{table}")
130    }
131
132    fn list_types(&mut self, catalog: &CatalogState, database: Option<&str>) -> io::Result<()> {
133        let mut builder = TabledBuilder::new();
134        builder.push_record(["database", "type", "fields"].map(|s| format!("{}", s.bold())));
135        let types: Vec<_> = if let Some(database) = database {
136            catalog.types_in_database(database).collect()
137        } else {
138            catalog
139                .types
140                .iter()
141                .chain(catalog.global_types.iter())
142                .map(|(tid, schema)| (*tid, schema))
143                .collect()
144        };
145
146        if types.is_empty() {
147            builder.push_record([""]);
148        }
149        for (_, type_schema) in types {
150            use tempest_engine::catalog::schema::TypeSchema;
151            let db_name = match type_schema.database_id() {
152                Some(db_id) => catalog.databases[&db_id].name.to_string(),
153                None => "(global)".to_string(),
154            };
155            let members = match type_schema {
156                TypeSchema::Struct(s) => s.fields.values().map(|f| f.name.to_string()).join(", "),
157                TypeSchema::Enum(e) => {
158                    format!("enum {{ {} }}", e.variants.values().map(|v| v.name.to_string()).join(", "))
159                }
160            };
161            builder.push_record([
162                db_name,
163                type_schema.name().to_string(),
164                members,
165            ]);
166        }
167
168        let mut table = builder.build();
169        table.with(TabledStyle::rounded());
170
171        writeln!(self.stdio.stdout(), "{table}")
172    }
173
174    fn list_tables(&mut self, catalog: &CatalogState, database: Option<&str>) -> io::Result<()> {
175        let mut builder = TabledBuilder::new();
176        builder.push_record(
177            ["database", "table", "type", "columns", "primary key"].map(|s| format!("{}", s.bold())),
178        );
179        let tables: Vec<_> = if let Some(database) = database {
180            catalog.tables_in_database(database).collect()
181        } else {
182            catalog
183                .tables
184                .iter()
185                .map(|(tid, schema)| (*tid, schema))
186                .collect()
187        };
188
189        if tables.is_empty() {
190            builder.push_record([""]);
191        }
192        for (_, table_schema) in tables {
193            let database = &catalog.databases[&table_schema.database_id].name;
194            let type_schema = catalog.get_type(table_schema.type_id);
195            let type_name = type_schema.map(|ts| match ts.database_id() {
196                Some(db_id) => format!("{}.{}", catalog.databases[&db_id].name, ts.name()),
197                None => ts.name().to_string(),
198            }).unwrap_or_else(|| "?".to_string());
199            let columns = type_schema
200                .and_then(|ts| ts.as_struct())
201                .map(|s| s.fields.len())
202                .unwrap_or(0);
203            builder.push_record(vec![
204                database.to_string(),
205                table_schema.name.to_string(),
206                type_name,
207                columns.to_string(),
208                table_schema
209                    .primary_key
210                    .iter()
211                    .map(|path| catalog.pk_path_name(path, table_schema))
212                    .join(", "),
213            ]);
214        }
215
216        let mut table = builder.build();
217        table.with(TabledStyle::rounded());
218        writeln!(self.stdio.stdout(), "{table}")
219    }
220
221    fn print_query_results(&mut self, results: Vec<QueryResult>, catalog: &CatalogState) -> io::Result<()> {
222        for res in results {
223            match res {
224                QueryResult::Rows { columns, rows } => {
225                    let mut builder = TabledBuilder::new();
226                    builder.push_record(columns.iter().map(|col| format!("{}", col.bold())));
227                    for row in rows.iter().map(|row| row.iter().map(|v| format_value(v, catalog))) {
228                        builder.push_record(row);
229                    }
230                    if rows.is_empty() {
231                        builder.push_record([""]);
232                    }
233                    let mut table = builder.build();
234                    table.with(TabledStyle::rounded());
235                    writeln!(self.stdio.stdout(), "{table}")?;
236                }
237                QueryResult::Empty => {}
238                QueryResult::RowsChanged(n) => {
239                    writeln!(
240                        self.stdio.stdout(),
241                        "{}",
242                        format!("Rows changed: {}", n).bold().green()
243                    )?;
244                }
245            }
246        }
247        Ok(())
248    }
249
250    fn print_parse_errors(&mut self, source: &str, errors: &[ParseError]) -> io::Result<()> {
251        use ariadne::{Color, Label, Report, ReportKind, Source};
252
253        for error in errors {
254            Report::build(ReportKind::Error, ("<repl>", error.span.clone()))
255                .with_message("parse error")
256                .with_label(
257                    Label::new(("<repl>", error.span.clone()))
258                        .with_message(format!("{}", error.kind))
259                        .with_color(Color::Red),
260                )
261                .finish()
262                .write(("<repl>", Source::from(source)), self.stdio.stdout())?;
263        }
264
265        Ok(())
266    }
267
268    pub async fn run(&mut self) -> Result<(), ReplError> {
269        let mut engine = Engine::<I>::open(self.data_dir.clone(), self.config.clone()).await?;
270        debug!("starting repl shell");
271
272        writeln!(
273            self.stdio.stdout(),
274            "{}",
275            "-- TempestDB REPL --".bright_cyan().bold()
276        )?;
277        writeln!(self.stdio.stdout(), "{}", REPL_INTRODUCTION.bright_cyan())?;
278        self.show_help()?;
279
280        let mut buf = String::new();
281        let mut interrupts = 0;
282        loop {
283            buf.clear();
284            self.stdio.stdout().flush()?;
285            match self.stdio.stdin().read_line(&mut buf) {
286                Ok(0) => break, // EOF by Ctrl-D
287                Ok(_) => interrupts = 0,
288                Err(e) if e.kind() == io::ErrorKind::Interrupted => {
289                    interrupts += 1;
290                    if interrupts >= 2 {
291                        break;
292                    }
293                    writeln!(self.stdio.stdout(), "(press Ctrl-C again to exit)")?;
294                    continue;
295                }
296                Err(e) => return Err(e.into()),
297            }
298
299            let cmd = buf.trim();
300            if cmd.len() == 0 {
301                continue;
302            }
303
304            if cmd.starts_with(".") {
305                let args: Vec<_> = cmd.split_whitespace().collect();
306                match args[0] {
307                    ".help" | ".h" => self.show_help()?,
308                    ".clear" | ".c" => self.clear_screen()?,
309                    ".quit" | ".q" => break,
310                    ".databases" | ".dbs" => self.list_databases(engine.catalog())?,
311                    ".types" => self.list_types(engine.catalog(), args.get(1).copied())?,
312                    ".tables" => self.list_tables(engine.catalog(), args.get(1).copied())?,
313                    _ => {
314                        writeln!(
315                            self.stdio.stdout(),
316                            "{} `{}`",
317                            "unknown command:".bright_red(),
318                            cmd
319                        )?;
320                        writeln!(
321                            self.stdio.stdout(),
322                            "{}",
323                            "type .help to show available commands"
324                                .bright_green()
325                                .italic()
326                        )?;
327                    }
328                }
329            } else {
330                match engine.execute(cmd).await {
331                    Ok(results) => self.print_query_results(results, engine.catalog())?,
332                    Err(err) => {
333                        writeln!(
334                            self.stdio.stdout(),
335                            "{}",
336                            "Failed to execute query:".bright_red().bold()
337                        )?;
338                        match err {
339                            EngineError::Parse(parse_errors) => {
340                                self.print_parse_errors(cmd, &parse_errors)?;
341                            }
342                            _ => writeln!(
343                                self.stdio.stdout(),
344                                "{}",
345                                format!("{}", err).bright_red()
346                            )?,
347                        }
348                    }
349                }
350            }
351
352            self.stdio.push_history(cmd);
353        }
354
355        self.stdio.stdout().flush()?;
356        Ok(())
357    }
358}
359
360pub async fn repl<I: Io, S: Stdio>(
361    data_dir: PathBuf,
362    config: EngineConfig,
363    stdio: S,
364) -> Result<(), ReplError> {
365    Repl::<I, S>::new(data_dir, config, stdio).run().await
366}