Skip to main content

vortex_tui/browse/ui/
query.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::RecordBatch;
5use ratatui::buffer::Buffer;
6use ratatui::layout::Constraint;
7use ratatui::layout::Layout;
8use ratatui::layout::Rect;
9use ratatui::style::Color;
10use ratatui::style::Style;
11use ratatui::text::Line;
12use ratatui::text::Span;
13use ratatui::widgets::Block;
14use ratatui::widgets::BorderType;
15use ratatui::widgets::Borders;
16use ratatui::widgets::Cell;
17use ratatui::widgets::Paragraph;
18use ratatui::widgets::Row;
19use ratatui::widgets::Scrollbar;
20use ratatui::widgets::ScrollbarOrientation;
21use ratatui::widgets::ScrollbarState;
22use ratatui::widgets::StatefulWidget;
23use ratatui::widgets::Table;
24use ratatui::widgets::TableState;
25use ratatui::widgets::Widget;
26use tokio::sync::oneshot;
27use vortex::session::VortexSession;
28
29use crate::browse::app::AppState;
30use crate::datafusion_helper::arrow_value_to_json;
31use crate::datafusion_helper::execute_vortex_query;
32use crate::datafusion_helper::json_value_to_display;
33
/// Sort direction for table columns.
///
/// The default is [`SortDirection::None`], i.e. no sorting.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SortDirection {
    /// No sorting applied.
    #[default]
    None,
    /// Sort in ascending order.
    Ascending,
    /// Sort in descending order.
    Descending,
}

impl SortDirection {
    /// Return the next direction in the cycle `None -> Ascending -> Descending -> None`.
    pub fn cycle(self) -> Self {
        match self {
            Self::None => Self::Ascending,
            Self::Ascending => Self::Descending,
            Self::Descending => Self::None,
        }
    }

    /// Return the suffix shown next to a column header for this direction.
    ///
    /// The suffix is empty for [`SortDirection::None`]; otherwise it is a leading space
    /// followed by an arrow glyph.
    pub fn indicator(self) -> &'static str {
        match self {
            Self::None => "",
            Self::Ascending => " ▲",
            Self::Descending => " ▼",
        }
    }
}
65
/// Focus state within the Query tab.
///
/// The default is [`QueryFocus::SqlInput`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum QueryFocus {
    /// Focus is on the SQL input field.
    #[default]
    SqlInput,
    /// Focus is on the results table.
    ResultsTable,
}
75
76/// Result from a background query task.
77pub(crate) struct PendingQueryResult {
78    pub row_count: Option<Result<usize, String>>,
79    pub query_result: Result<QueryResults, String>,
80}
81
82/// State for the SQL query interface.
83pub struct QueryState {
84    /// The SQL query input text.
85    pub sql_input: String,
86    /// Cursor position in the SQL input.
87    pub cursor_position: usize,
88    /// Current focus within the Query tab.
89    pub focus: QueryFocus,
90    /// Query results as RecordBatches.
91    pub results: Option<QueryResults>,
92    /// Error message if query failed.
93    pub error: Option<String>,
94    /// Whether a query is currently running.
95    pub running: bool,
96    /// Table state for the results view.
97    pub table_state: TableState,
98    /// Horizontal scroll offset for the results table.
99    pub horizontal_scroll: usize,
100    /// Column being sorted (if any).
101    pub sort_column: Option<usize>,
102    /// Sort direction.
103    pub sort_direction: SortDirection,
104    /// Current page (0-indexed).
105    pub current_page: usize,
106    /// Rows per page (parsed from LIMIT clause).
107    pub page_size: usize,
108    /// Total row count from COUNT(*) query.
109    pub total_row_count: Option<usize>,
110    /// Base SQL query (without LIMIT/OFFSET) for pagination.
111    pub base_query: String,
112    /// ORDER BY clause if any.
113    pub order_clause: Option<String>,
114    /// Whether a query execution is pending (needs to be spawned).
115    pending_execution: bool,
116    /// Whether a row count query is needed on next spawn.
117    needs_row_count: bool,
118    /// Receiver for in-flight background query result.
119    pub(crate) pending_rx: Option<oneshot::Receiver<PendingQueryResult>>,
120}
121
122impl Default for QueryState {
123    fn default() -> Self {
124        let default_sql = "SELECT * FROM data LIMIT 20";
125        Self {
126            sql_input: default_sql.to_string(),
127            cursor_position: default_sql.len(),
128            focus: QueryFocus::default(),
129            results: None,
130            error: None,
131            running: false,
132            table_state: TableState::default(),
133            horizontal_scroll: 0,
134            sort_column: None,
135            sort_direction: SortDirection::default(),
136            current_page: 0,
137            page_size: 20,
138            total_row_count: None,
139            base_query: "SELECT * FROM data".to_string(),
140            order_clause: None,
141            pending_execution: false,
142            needs_row_count: false,
143            pending_rx: None,
144        }
145    }
146}
147
148impl QueryState {
149    /// Insert a character at the cursor position.
150    pub fn insert_char(&mut self, c: char) {
151        self.sql_input.insert(self.cursor_position, c);
152        self.cursor_position += 1;
153    }
154
155    /// Delete the character before the cursor.
156    pub fn delete_char(&mut self) {
157        if self.cursor_position > 0 {
158            self.cursor_position -= 1;
159            self.sql_input.remove(self.cursor_position);
160        }
161    }
162
163    /// Delete the character at the cursor.
164    pub fn delete_char_forward(&mut self) {
165        if self.cursor_position < self.sql_input.len() {
166            self.sql_input.remove(self.cursor_position);
167        }
168    }
169
170    /// Move cursor left.
171    pub fn move_cursor_left(&mut self) {
172        self.cursor_position = self.cursor_position.saturating_sub(1);
173    }
174
175    /// Move cursor right.
176    pub fn move_cursor_right(&mut self) {
177        if self.cursor_position < self.sql_input.len() {
178            self.cursor_position += 1;
179        }
180    }
181
182    /// Move cursor to start.
183    pub fn move_cursor_start(&mut self) {
184        self.cursor_position = 0;
185    }
186
187    /// Move cursor to end.
188    pub fn move_cursor_end(&mut self) {
189        self.cursor_position = self.sql_input.len();
190    }
191
192    /// Clear the SQL input.
193    pub fn clear_input(&mut self) {
194        self.sql_input.clear();
195        self.cursor_position = 0;
196    }
197
198    /// Toggle focus between SQL input and results table.
199    pub fn toggle_focus(&mut self) {
200        self.focus = match self.focus {
201            QueryFocus::SqlInput => QueryFocus::ResultsTable,
202            QueryFocus::ResultsTable => QueryFocus::SqlInput,
203        };
204    }
205
206    /// Prepare initial query - parses SQL, sets flags for async execution.
207    pub fn prepare_initial_query(&mut self) {
208        self.error = None;
209
210        // Parse the SQL to extract base query, order clause, and page size
211        let (base_sql, order_clause, limit) = self.parse_sql_parts();
212        self.base_query = base_sql;
213        self.order_clause = order_clause;
214        self.page_size = limit.unwrap_or(20);
215        self.current_page = 0;
216
217        self.needs_row_count = true;
218        self.rebuild_sql();
219    }
220
221    /// Prepare navigation to next page.
222    pub fn prepare_next_page(&mut self) {
223        let total_pages = self.total_pages();
224        if self.current_page + 1 < total_pages {
225            self.current_page += 1;
226            self.rebuild_sql();
227        }
228    }
229
230    /// Prepare navigation to previous page.
231    pub fn prepare_prev_page(&mut self) {
232        if self.current_page > 0 {
233            self.current_page -= 1;
234            self.rebuild_sql();
235        }
236    }
237
238    /// Get total number of pages.
239    pub fn total_pages(&self) -> usize {
240        match self.total_row_count {
241            Some(total) if total > 0 => total.div_ceil(self.page_size),
242            _ => 1,
243        }
244    }
245
246    /// Build SQL query from current state and set the pending execution flag.
247    fn rebuild_sql(&mut self) {
248        let offset = self.current_page * self.page_size;
249
250        let new_sql = match &self.order_clause {
251            Some(order) => {
252                format!(
253                    "{} {} LIMIT {} OFFSET {}",
254                    self.base_query, order, self.page_size, offset
255                )
256            }
257            None => {
258                format!(
259                    "{} LIMIT {} OFFSET {}",
260                    self.base_query, self.page_size, offset
261                )
262            }
263        };
264
265        self.sql_input = new_sql;
266        self.cursor_position = self.sql_input.len();
267
268        self.running = true;
269        self.error = None;
270        self.pending_execution = true;
271    }
272
273    /// Spawn a background task for the pending query, if any.
274    ///
275    /// After calling `prepare_*` methods, call this to kick off execution.
276    /// The result will arrive on [`pending_rx`] and should be applied with
277    /// [`apply_query_result`].
278    pub(crate) fn spawn_pending(&mut self, session: &VortexSession, file_path: &str) {
279        if !self.pending_execution {
280            return;
281        }
282        self.pending_execution = false;
283
284        let (tx, rx) = oneshot::channel();
285        let session = session.clone();
286        let file_path = file_path.to_string();
287        let sql = self.sql_input.clone();
288        let base_query = self.base_query.clone();
289        let needs_row_count = self.needs_row_count;
290        self.needs_row_count = false;
291
292        tokio::spawn(async move {
293            let row_count = match needs_row_count {
294                true => Some(get_row_count(&session, &file_path, &base_query).await),
295                false => None,
296            };
297            let query_result = execute_query(&session, &file_path, &sql).await;
298            drop(tx.send(PendingQueryResult {
299                row_count,
300                query_result,
301            }));
302        });
303
304        self.pending_rx = Some(rx);
305    }
306
307    /// Apply a completed background query result to the state.
308    pub(crate) fn apply_query_result(&mut self, result: PendingQueryResult) {
309        if let Some(row_count) = result.row_count {
310            self.total_row_count = row_count.ok();
311        }
312        match result.query_result {
313            Ok(results) => {
314                self.results = Some(results);
315                self.table_state.select(Some(0));
316            }
317            Err(e) => {
318                self.error = Some(e);
319            }
320        }
321        self.running = false;
322    }
323
324    /// Parse SQL to extract base query, ORDER BY clause, and LIMIT value.
325    fn parse_sql_parts(&self) -> (String, Option<String>, Option<usize>) {
326        let sql = &self.sql_input;
327        let sql_upper = sql.to_uppercase();
328
329        // Find positions of clauses
330        let order_idx = sql_upper.find(" ORDER BY ");
331        let limit_idx = sql_upper.find(" LIMIT ");
332        let offset_idx = sql_upper.find(" OFFSET ");
333
334        // Extract limit value if present
335        let limit_value = if let Some(li) = limit_idx {
336            let after_limit = &sql[li + 7..]; // Skip " LIMIT "
337            let end_idx = after_limit
338                .find(|c: char| !c.is_ascii_digit() && c != ' ')
339                .unwrap_or(after_limit.len());
340            after_limit[..end_idx].trim().parse::<usize>().ok()
341        } else {
342            None
343        };
344
345        // Find the earliest of LIMIT or OFFSET to know where to cut
346        let cut_idx = match (limit_idx, offset_idx) {
347            (Some(li), Some(oi)) => Some(li.min(oi)),
348            (Some(li), None) => Some(li),
349            (None, Some(oi)) => Some(oi),
350            (None, None) => None,
351        };
352
353        match (order_idx, cut_idx) {
354            (Some(oi), Some(ci)) if oi < ci => {
355                // ORDER BY comes before LIMIT/OFFSET
356                let base = sql[..oi].trim().to_string();
357                let order = sql[oi..ci].trim().to_string();
358                (base, Some(order), limit_value)
359            }
360            (Some(oi), None) => {
361                // Only ORDER BY, no LIMIT/OFFSET
362                let base = sql[..oi].trim().to_string();
363                let order = sql[oi..].trim().to_string();
364                (base, Some(order), limit_value)
365            }
366            (None, Some(ci)) => {
367                // No ORDER BY, just LIMIT/OFFSET
368                let base = sql[..ci].trim().to_string();
369                (base, None, limit_value)
370            }
371            (Some(_oi), Some(ci)) => {
372                // ORDER BY comes after LIMIT (unusual) - just cut at LIMIT
373                let base = sql[..ci].trim().to_string();
374                (base, None, limit_value)
375            }
376            (None, None) => {
377                // No ORDER BY or LIMIT/OFFSET
378                (sql.clone(), None, limit_value)
379            }
380        }
381    }
382
383    /// Get the currently selected column index.
384    pub fn selected_column(&self) -> usize {
385        self.horizontal_scroll
386    }
387
388    /// Total number of columns in results.
389    pub fn column_count(&self) -> usize {
390        self.results
391            .as_ref()
392            .and_then(|r| r.batches.first())
393            .map(|b| b.num_columns())
394            .unwrap_or(0)
395    }
396
397    /// Prepare sort on a column by modifying the ORDER BY clause and setting execution flag.
398    pub fn prepare_sort(&mut self, column: usize) {
399        // Get the column name from results
400        let column_name = match &self.results {
401            Some(results) if column < results.column_names.len() => {
402                results.column_names[column].clone()
403            }
404            _ => return,
405        };
406
407        // Cycle sort direction
408        if self.sort_column == Some(column) {
409            self.sort_direction = self.sort_direction.cycle();
410            if self.sort_direction == SortDirection::None {
411                self.sort_column = None;
412            }
413        } else {
414            self.sort_column = Some(column);
415            self.sort_direction = SortDirection::Ascending;
416        }
417
418        // Update the ORDER BY clause
419        self.order_clause = if self.sort_direction == SortDirection::None {
420            None
421        } else {
422            let direction = match self.sort_direction {
423                SortDirection::Ascending => "ASC",
424                SortDirection::Descending => "DESC",
425                SortDirection::None => unreachable!(),
426            };
427            Some(format!("ORDER BY \"{column_name}\" {direction}"))
428        };
429
430        // Reset to first page and set pending execution
431        self.current_page = 0;
432        self.rebuild_sql();
433    }
434}
435
436/// Holds query results for display.
437pub struct QueryResults {
438    pub batches: Vec<RecordBatch>,
439    pub total_rows: usize,
440    pub column_names: Vec<String>,
441}
442
443/// Execute a SQL query against the Vortex file.
444async fn execute_query(
445    session: &VortexSession,
446    file_path: &str,
447    sql: &str,
448) -> Result<QueryResults, String> {
449    let batches = execute_vortex_query(session, file_path, sql).await?;
450
451    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
452
453    let column_names = if let Some(batch) = batches.first() {
454        let schema = batch.schema();
455        schema.fields().iter().map(|f| f.name().clone()).collect()
456    } else {
457        vec![]
458    };
459
460    Ok(QueryResults {
461        batches,
462        total_rows,
463        column_names,
464    })
465}
466
467/// Get total row count for a base query using COUNT(*).
468async fn get_row_count(
469    session: &VortexSession,
470    file_path: &str,
471    base_query: &str,
472) -> Result<usize, String> {
473    let count_sql = format!("SELECT COUNT(*) as count FROM ({base_query}) AS subquery");
474
475    let batches = execute_vortex_query(session, file_path, &count_sql).await?;
476
477    // Extract count from result
478    if let Some(batch) = batches.first()
479        && batch.num_rows() > 0
480        && batch.num_columns() > 0
481    {
482        use arrow_array::Int64Array;
483        if let Some(arr) = batch.column(0).as_any().downcast_ref::<Int64Array>() {
484            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
485            return Ok(arr.value(0) as usize);
486        }
487    }
488
489    Ok(0)
490}
491
492/// Render the Query tab UI.
493pub fn render_query(app: &mut AppState, area: Rect, buf: &mut Buffer) {
494    let [input_area, results_area] =
495        Layout::vertical([Constraint::Length(5), Constraint::Min(10)]).areas(area);
496
497    render_sql_input(app, input_area, buf);
498    render_results_table(app, results_area, buf);
499}
500
501fn render_sql_input(app: &mut AppState, area: Rect, buf: &mut Buffer) {
502    let is_focused = app.query_state.focus == QueryFocus::SqlInput;
503
504    let border_color = if is_focused {
505        Color::Cyan
506    } else {
507        Color::DarkGray
508    };
509
510    let block = Block::default()
511        .title("SQL Query (Enter to execute, Esc to switch focus)")
512        .borders(Borders::ALL)
513        .border_type(BorderType::Rounded)
514        .border_style(Style::default().fg(border_color));
515
516    let inner = block.inner(area);
517    block.render(area, buf);
518
519    // Create the input text with cursor
520    let sql = &app.query_state.sql_input;
521    let cursor_pos = app.query_state.cursor_position;
522
523    let (before_cursor, after_cursor) = sql.split_at(cursor_pos.min(sql.len()));
524
525    let first_char = after_cursor.chars().next();
526    let cursor_char = if is_focused {
527        match first_char {
528            None => Span::styled(" ", Style::default().bg(Color::White).fg(Color::Black)),
529            Some(c) => Span::styled(
530                c.to_string(),
531                Style::default().bg(Color::White).fg(Color::Black),
532            ),
533        }
534    } else {
535        match first_char {
536            None => Span::raw(""),
537            Some(c) => Span::raw(c.to_string()),
538        }
539    };
540
541    let rest = match first_char {
542        Some(c) if after_cursor.len() > c.len_utf8() => &after_cursor[c.len_utf8()..],
543        _ => "",
544    };
545
546    let line = Line::from(vec![Span::raw(before_cursor), cursor_char, Span::raw(rest)]);
547
548    let paragraph = Paragraph::new(line).style(Style::default().fg(Color::White));
549
550    paragraph.render(inner, buf);
551}
552
553fn render_results_table(app: &mut AppState, area: Rect, buf: &mut Buffer) {
554    let is_focused = app.query_state.focus == QueryFocus::ResultsTable;
555
556    let border_color = if is_focused {
557        Color::Cyan
558    } else {
559        Color::DarkGray
560    };
561
562    // Show status in title
563    let title = if app.query_state.running {
564        "Results (running...)".to_string()
565    } else if let Some(ref error) = app.query_state.error {
566        format!("Results (error: {})", truncate_str(error, 50))
567    } else if let Some(ref _results) = app.query_state.results {
568        let total_rows = app.query_state.total_row_count.unwrap_or(0);
569        let total_pages = app.query_state.total_pages();
570        format!(
571            "Results ({} rows, page {}/{}) [hjkl navigate, [/] pages, s sort]",
572            total_rows,
573            app.query_state.current_page + 1,
574            total_pages,
575        )
576    } else {
577        "Results (press Enter to execute query)".to_string()
578    };
579
580    let block = Block::default()
581        .title(title)
582        .borders(Borders::ALL)
583        .border_type(BorderType::Rounded)
584        .border_style(Style::default().fg(border_color));
585
586    let inner = block.inner(area);
587    block.render(area, buf);
588
589    if let Some(ref error) = app.query_state.error {
590        let error_text = Paragraph::new(error.as_str())
591            .style(Style::default().fg(Color::Red))
592            .wrap(ratatui::widgets::Wrap { trim: true });
593        error_text.render(inner, buf);
594        return;
595    }
596
597    let Some(ref results) = app.query_state.results else {
598        let help = Paragraph::new("Enter a SQL query above and press Enter to execute.\nThe table is available as 'data'.\n\nExample: SELECT * FROM data WHERE column > 10 LIMIT 100")
599            .style(Style::default().fg(Color::Gray));
600        help.render(inner, buf);
601        return;
602    };
603
604    if results.batches.is_empty() || results.total_rows == 0 {
605        let empty =
606            Paragraph::new("Query returned no results.").style(Style::default().fg(Color::Yellow));
607        empty.render(inner, buf);
608        return;
609    }
610
611    // Build header row with sort indicators
612    let header_cells: Vec<Cell> = results
613        .column_names
614        .iter()
615        .enumerate()
616        .map(|(i, name)| {
617            let indicator = if app.query_state.sort_column == Some(i) {
618                app.query_state.sort_direction.indicator()
619            } else {
620                ""
621            };
622
623            let style = if is_focused && i == app.query_state.horizontal_scroll {
624                Style::default().fg(Color::Black).bg(Color::Cyan).bold()
625            } else {
626                Style::default().fg(Color::Green).bold()
627            };
628
629            Cell::from(format!("{name}{indicator}")).style(style)
630        })
631        .collect();
632
633    let header = Row::new(header_cells).height(1);
634
635    // Since we use LIMIT/OFFSET in SQL, batches contain only the current page's data
636    // Display all rows from the batches
637    let rows = get_all_rows(results, &app.query_state);
638
639    // Calculate column widths
640    #[allow(clippy::cast_possible_truncation)]
641    let widths: Vec<Constraint> = results
642        .column_names
643        .iter()
644        .map(|name| Constraint::Min((name.len() + 3).max(10) as u16))
645        .collect();
646
647    let table = Table::new(rows, widths)
648        .header(header)
649        .row_highlight_style(Style::default().bg(Color::DarkGray));
650
651    // Split area for table and scrollbar
652    let [table_area, scrollbar_area] =
653        Layout::horizontal([Constraint::Min(0), Constraint::Length(1)]).areas(inner);
654
655    StatefulWidget::render(table, table_area, buf, &mut app.query_state.table_state);
656
657    // Render vertical scrollbar
658    let total_pages = app.query_state.total_pages();
659    if total_pages > 1 {
660        let mut scrollbar_state = ScrollbarState::new(total_pages)
661            .position(app.query_state.current_page)
662            .viewport_content_length(1);
663
664        Scrollbar::new(ScrollbarOrientation::VerticalRight)
665            .begin_symbol(Some("▲"))
666            .end_symbol(Some("▼"))
667            .render(scrollbar_area, buf, &mut scrollbar_state);
668    }
669}
670
671/// Get all rows from batches (pagination is handled via SQL LIMIT/OFFSET).
672fn get_all_rows<'a>(results: &'a QueryResults, query_state: &QueryState) -> Vec<Row<'a>> {
673    let mut rows = Vec::new();
674
675    for batch in &results.batches {
676        for row_idx in 0..batch.num_rows() {
677            let cells: Vec<Cell> = (0..batch.num_columns())
678                .map(|col_idx| {
679                    let json_value = arrow_value_to_json(batch.column(col_idx).as_ref(), row_idx);
680                    let value = json_value_to_display(json_value);
681                    let style = if query_state.sort_column == Some(col_idx) {
682                        Style::default().fg(Color::Cyan)
683                    } else {
684                        Style::default()
685                    };
686                    Cell::from(truncate_str(&value, 30).to_string()).style(style)
687                })
688                .collect();
689            rows.push(Row::new(cells));
690        }
691    }
692
693    rows
694}
695
/// Shorten `s` for display without ever splitting a character.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings are cut to
/// at most `max_len - 3` bytes (saturating at zero), moving the cut back to the nearest
/// character boundary so multi-byte UTF-8 text cannot cause a slicing panic. No ellipsis
/// is appended because the result borrows from `s`.
fn truncate_str(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }

    let mut end = max_len.saturating_sub(3);
    // Offset 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
702}