Skip to main content

vortex_tui/browse/ui/
query.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::RecordBatch;
5use ratatui::buffer::Buffer;
6use ratatui::layout::Constraint;
7use ratatui::layout::Layout;
8use ratatui::layout::Rect;
9use ratatui::style::Color;
10use ratatui::style::Style;
11use ratatui::text::Line;
12use ratatui::text::Span;
13use ratatui::widgets::Block;
14use ratatui::widgets::BorderType;
15use ratatui::widgets::Borders;
16use ratatui::widgets::Cell;
17use ratatui::widgets::Paragraph;
18use ratatui::widgets::Row;
19use ratatui::widgets::Scrollbar;
20use ratatui::widgets::ScrollbarOrientation;
21use ratatui::widgets::ScrollbarState;
22use ratatui::widgets::StatefulWidget;
23use ratatui::widgets::Table;
24use ratatui::widgets::TableState;
25use ratatui::widgets::Widget;
26use tokio::runtime::Handle;
27use tokio::task::block_in_place;
28
29use crate::browse::app::AppState;
30use crate::datafusion_helper::arrow_value_to_json;
31use crate::datafusion_helper::execute_vortex_query;
32use crate::datafusion_helper::json_value_to_display;
33
/// Ordering that can be applied to a column of the results table.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SortDirection {
    /// The column is not sorted.
    #[default]
    None,
    /// Sorted in ascending order.
    Ascending,
    /// Sorted in descending order.
    Descending,
}

impl SortDirection {
    /// Return the direction that follows `self` in the ring
    /// `None` -> `Ascending` -> `Descending` -> `None`.
    pub fn cycle(self) -> Self {
        match self {
            Self::Descending => Self::None,
            Self::Ascending => Self::Descending,
            Self::None => Self::Ascending,
        }
    }

    /// Return the suffix shown after a column name for this direction.
    ///
    /// The unsorted state yields an empty string; the sorted states yield a
    /// leading space followed by an arrow glyph.
    pub fn indicator(self) -> &'static str {
        match self {
            Self::Ascending => " ▲",
            Self::Descending => " ▼",
            Self::None => "",
        }
    }
}
65
/// Which part of the Query tab currently has focus.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum QueryFocus {
    /// The SQL input field has focus (the default).
    #[default]
    SqlInput,
    /// The results table has focus.
    ResultsTable,
}
75
76/// State for the SQL query interface.
77pub struct QueryState {
78    /// The SQL query input text.
79    pub sql_input: String,
80    /// Cursor position in the SQL input.
81    pub cursor_position: usize,
82    /// Current focus within the Query tab.
83    pub focus: QueryFocus,
84    /// Query results as RecordBatches.
85    pub results: Option<QueryResults>,
86    /// Error message if query failed.
87    pub error: Option<String>,
88    /// Whether a query is currently running.
89    pub running: bool,
90    /// Table state for the results view.
91    pub table_state: TableState,
92    /// Horizontal scroll offset for the results table.
93    pub horizontal_scroll: usize,
94    /// Column being sorted (if any).
95    pub sort_column: Option<usize>,
96    /// Sort direction.
97    pub sort_direction: SortDirection,
98    /// Current page (0-indexed).
99    pub current_page: usize,
100    /// Rows per page (parsed from LIMIT clause).
101    pub page_size: usize,
102    /// Total row count from COUNT(*) query.
103    pub total_row_count: Option<usize>,
104    /// Base SQL query (without LIMIT/OFFSET) for pagination.
105    pub base_query: String,
106    /// ORDER BY clause if any.
107    pub order_clause: Option<String>,
108}
109
110impl Default for QueryState {
111    fn default() -> Self {
112        let default_sql = "SELECT * FROM data LIMIT 20";
113        Self {
114            sql_input: default_sql.to_string(),
115            cursor_position: default_sql.len(),
116            focus: QueryFocus::default(),
117            results: None,
118            error: None,
119            running: false,
120            table_state: TableState::default(),
121            horizontal_scroll: 0,
122            sort_column: None,
123            sort_direction: SortDirection::default(),
124            current_page: 0,
125            page_size: 20,
126            total_row_count: None,
127            base_query: "SELECT * FROM data".to_string(),
128            order_clause: None,
129        }
130    }
131}
132
133impl QueryState {
134    /// Insert a character at the cursor position.
135    pub fn insert_char(&mut self, c: char) {
136        self.sql_input.insert(self.cursor_position, c);
137        self.cursor_position += 1;
138    }
139
140    /// Delete the character before the cursor.
141    pub fn delete_char(&mut self) {
142        if self.cursor_position > 0 {
143            self.cursor_position -= 1;
144            self.sql_input.remove(self.cursor_position);
145        }
146    }
147
148    /// Delete the character at the cursor.
149    pub fn delete_char_forward(&mut self) {
150        if self.cursor_position < self.sql_input.len() {
151            self.sql_input.remove(self.cursor_position);
152        }
153    }
154
155    /// Move cursor left.
156    pub fn move_cursor_left(&mut self) {
157        self.cursor_position = self.cursor_position.saturating_sub(1);
158    }
159
160    /// Move cursor right.
161    pub fn move_cursor_right(&mut self) {
162        if self.cursor_position < self.sql_input.len() {
163            self.cursor_position += 1;
164        }
165    }
166
167    /// Move cursor to start.
168    pub fn move_cursor_start(&mut self) {
169        self.cursor_position = 0;
170    }
171
172    /// Move cursor to end.
173    pub fn move_cursor_end(&mut self) {
174        self.cursor_position = self.sql_input.len();
175    }
176
177    /// Clear the SQL input.
178    pub fn clear_input(&mut self) {
179        self.sql_input.clear();
180        self.cursor_position = 0;
181    }
182
183    /// Toggle focus between SQL input and results table.
184    pub fn toggle_focus(&mut self) {
185        self.focus = match self.focus {
186            QueryFocus::SqlInput => QueryFocus::ResultsTable,
187            QueryFocus::ResultsTable => QueryFocus::SqlInput,
188        };
189    }
190
191    /// Execute initial query - parses SQL, gets total count, fetches first page.
192    pub fn execute_initial_query(
193        &mut self,
194        session: &vortex::session::VortexSession,
195        file_path: &str,
196    ) {
197        self.running = true;
198        self.error = None;
199
200        // Parse the SQL to extract base query, order clause, and page size
201        let (base_sql, order_clause, limit) = self.parse_sql_parts();
202        self.base_query = base_sql;
203        self.order_clause = order_clause;
204        self.page_size = limit.unwrap_or(20);
205        self.current_page = 0;
206
207        // Get total row count
208        self.total_row_count = get_row_count(session, file_path, &self.base_query).ok();
209
210        // Build and execute the query
211        self.rebuild_and_execute(session, file_path);
212    }
213
214    /// Navigate to next page.
215    pub fn next_page(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
216        let total_pages = self.total_pages();
217        if self.current_page + 1 < total_pages {
218            self.current_page += 1;
219            self.rebuild_and_execute(session, file_path);
220        }
221    }
222
223    /// Navigate to previous page.
224    pub fn prev_page(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
225        if self.current_page > 0 {
226            self.current_page -= 1;
227            self.rebuild_and_execute(session, file_path);
228        }
229    }
230
231    /// Get total number of pages.
232    pub fn total_pages(&self) -> usize {
233        match self.total_row_count {
234            Some(total) if total > 0 => total.div_ceil(self.page_size),
235            _ => 1,
236        }
237    }
238
239    /// Build SQL query from current state and execute it.
240    fn rebuild_and_execute(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
241        let offset = self.current_page * self.page_size;
242
243        let new_sql = match &self.order_clause {
244            Some(order) => {
245                format!(
246                    "{} {} LIMIT {} OFFSET {}",
247                    self.base_query, order, self.page_size, offset
248                )
249            }
250            None => {
251                format!(
252                    "{} LIMIT {} OFFSET {}",
253                    self.base_query, self.page_size, offset
254                )
255            }
256        };
257
258        self.sql_input = new_sql;
259        self.cursor_position = self.sql_input.len();
260
261        self.running = true;
262        self.error = None;
263
264        match execute_query(session, file_path, &self.sql_input) {
265            Ok(results) => {
266                self.results = Some(results);
267                self.table_state.select(Some(0));
268            }
269            Err(e) => {
270                self.error = Some(e);
271            }
272        }
273        self.running = false;
274    }
275
276    /// Parse SQL to extract base query, ORDER BY clause, and LIMIT value.
277    fn parse_sql_parts(&self) -> (String, Option<String>, Option<usize>) {
278        let sql = &self.sql_input;
279        let sql_upper = sql.to_uppercase();
280
281        // Find positions of clauses
282        let order_idx = sql_upper.find(" ORDER BY ");
283        let limit_idx = sql_upper.find(" LIMIT ");
284        let offset_idx = sql_upper.find(" OFFSET ");
285
286        // Extract limit value if present
287        let limit_value = if let Some(li) = limit_idx {
288            let after_limit = &sql[li + 7..]; // Skip " LIMIT "
289            let end_idx = after_limit
290                .find(|c: char| !c.is_ascii_digit() && c != ' ')
291                .unwrap_or(after_limit.len());
292            after_limit[..end_idx].trim().parse::<usize>().ok()
293        } else {
294            None
295        };
296
297        // Find the earliest of LIMIT or OFFSET to know where to cut
298        let cut_idx = match (limit_idx, offset_idx) {
299            (Some(li), Some(oi)) => Some(li.min(oi)),
300            (Some(li), None) => Some(li),
301            (None, Some(oi)) => Some(oi),
302            (None, None) => None,
303        };
304
305        match (order_idx, cut_idx) {
306            (Some(oi), Some(ci)) if oi < ci => {
307                // ORDER BY comes before LIMIT/OFFSET
308                let base = sql[..oi].trim().to_string();
309                let order = sql[oi..ci].trim().to_string();
310                (base, Some(order), limit_value)
311            }
312            (Some(oi), None) => {
313                // Only ORDER BY, no LIMIT/OFFSET
314                let base = sql[..oi].trim().to_string();
315                let order = sql[oi..].trim().to_string();
316                (base, Some(order), limit_value)
317            }
318            (None, Some(ci)) => {
319                // No ORDER BY, just LIMIT/OFFSET
320                let base = sql[..ci].trim().to_string();
321                (base, None, limit_value)
322            }
323            (Some(_oi), Some(ci)) => {
324                // ORDER BY comes after LIMIT (unusual) - just cut at LIMIT
325                let base = sql[..ci].trim().to_string();
326                (base, None, limit_value)
327            }
328            (None, None) => {
329                // No ORDER BY or LIMIT/OFFSET
330                (sql.clone(), None, limit_value)
331            }
332        }
333    }
334
335    /// Get the currently selected column index.
336    pub fn selected_column(&self) -> usize {
337        self.horizontal_scroll
338    }
339
340    /// Total number of columns in results.
341    pub fn column_count(&self) -> usize {
342        self.results
343            .as_ref()
344            .and_then(|r| r.batches.first())
345            .map(|b| b.num_columns())
346            .unwrap_or(0)
347    }
348
349    /// Apply sort on a column by modifying the ORDER BY clause and re-executing.
350    pub fn apply_sort(
351        &mut self,
352        session: &vortex::session::VortexSession,
353        column: usize,
354        file_path: &str,
355    ) {
356        // Get the column name from results
357        let column_name = match &self.results {
358            Some(results) if column < results.column_names.len() => {
359                results.column_names[column].clone()
360            }
361            _ => return,
362        };
363
364        // Cycle sort direction
365        if self.sort_column == Some(column) {
366            self.sort_direction = self.sort_direction.cycle();
367            if self.sort_direction == SortDirection::None {
368                self.sort_column = None;
369            }
370        } else {
371            self.sort_column = Some(column);
372            self.sort_direction = SortDirection::Ascending;
373        }
374
375        // Update the ORDER BY clause
376        self.order_clause = if self.sort_direction == SortDirection::None {
377            None
378        } else {
379            let direction = match self.sort_direction {
380                SortDirection::Ascending => "ASC",
381                SortDirection::Descending => "DESC",
382                SortDirection::None => unreachable!(),
383            };
384            Some(format!("ORDER BY \"{column_name}\" {direction}"))
385        };
386
387        // Reset to first page and re-execute
388        self.current_page = 0;
389        self.rebuild_and_execute(session, file_path);
390    }
391}
392
393/// Holds query results for display.
394pub struct QueryResults {
395    pub batches: Vec<RecordBatch>,
396    pub total_rows: usize,
397    pub column_names: Vec<String>,
398}
399
400/// Execute a SQL query against the Vortex file.
401pub fn execute_query(
402    session: &vortex::session::VortexSession,
403    file_path: &str,
404    sql: &str,
405) -> Result<QueryResults, String> {
406    block_in_place(|| {
407        Handle::current().block_on(async {
408            let batches = execute_vortex_query(session, file_path, sql).await?;
409
410            let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
411
412            let column_names = if let Some(batch) = batches.first() {
413                let schema = batch.schema();
414                schema.fields().iter().map(|f| f.name().clone()).collect()
415            } else {
416                vec![]
417            };
418
419            Ok(QueryResults {
420                batches,
421                total_rows,
422                column_names,
423            })
424        })
425    })
426}
427
428/// Get total row count for a base query using COUNT(*).
429pub fn get_row_count(
430    session: &vortex::session::VortexSession,
431    file_path: &str,
432    base_query: &str,
433) -> Result<usize, String> {
434    block_in_place(|| {
435        Handle::current().block_on(async {
436            let count_sql = format!("SELECT COUNT(*) as count FROM ({base_query}) AS subquery");
437
438            let batches = execute_vortex_query(session, file_path, &count_sql).await?;
439
440            // Extract count from result
441            if let Some(batch) = batches.first()
442                && batch.num_rows() > 0
443                && batch.num_columns() > 0
444            {
445                use arrow_array::Int64Array;
446                if let Some(arr) = batch.column(0).as_any().downcast_ref::<Int64Array>() {
447                    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
448                    return Ok(arr.value(0) as usize);
449                }
450            }
451
452            Ok(0)
453        })
454    })
455}
456
457/// Render the Query tab UI.
458pub fn render_query(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
459    let [input_area, results_area] =
460        Layout::vertical([Constraint::Length(5), Constraint::Min(10)]).areas(area);
461
462    render_sql_input(app, input_area, buf);
463    render_results_table(app, results_area, buf);
464}
465
466fn render_sql_input(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
467    let is_focused = app.query_state.focus == QueryFocus::SqlInput;
468
469    let border_color = if is_focused {
470        Color::Cyan
471    } else {
472        Color::DarkGray
473    };
474
475    let block = Block::default()
476        .title("SQL Query (Enter to execute, Esc to switch focus)")
477        .borders(Borders::ALL)
478        .border_type(BorderType::Rounded)
479        .border_style(Style::default().fg(border_color));
480
481    let inner = block.inner(area);
482    block.render(area, buf);
483
484    // Create the input text with cursor
485    let sql = &app.query_state.sql_input;
486    let cursor_pos = app.query_state.cursor_position;
487
488    let (before_cursor, after_cursor) = sql.split_at(cursor_pos.min(sql.len()));
489
490    let first_char = after_cursor.chars().next();
491    let cursor_char = if is_focused {
492        match first_char {
493            None => Span::styled(" ", Style::default().bg(Color::White).fg(Color::Black)),
494            Some(c) => Span::styled(
495                c.to_string(),
496                Style::default().bg(Color::White).fg(Color::Black),
497            ),
498        }
499    } else {
500        match first_char {
501            None => Span::raw(""),
502            Some(c) => Span::raw(c.to_string()),
503        }
504    };
505
506    let rest = match first_char {
507        Some(c) if after_cursor.len() > c.len_utf8() => &after_cursor[c.len_utf8()..],
508        _ => "",
509    };
510
511    let line = Line::from(vec![Span::raw(before_cursor), cursor_char, Span::raw(rest)]);
512
513    let paragraph = Paragraph::new(line).style(Style::default().fg(Color::White));
514
515    paragraph.render(inner, buf);
516}
517
518fn render_results_table(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
519    let is_focused = app.query_state.focus == QueryFocus::ResultsTable;
520
521    let border_color = if is_focused {
522        Color::Cyan
523    } else {
524        Color::DarkGray
525    };
526
527    // Show status in title
528    let title = if app.query_state.running {
529        "Results (running...)".to_string()
530    } else if let Some(ref error) = app.query_state.error {
531        format!("Results (error: {})", truncate_str(error, 50))
532    } else if let Some(ref _results) = app.query_state.results {
533        let total_rows = app.query_state.total_row_count.unwrap_or(0);
534        let total_pages = app.query_state.total_pages();
535        format!(
536            "Results ({} rows, page {}/{}) [hjkl navigate, [/] pages, s sort]",
537            total_rows,
538            app.query_state.current_page + 1,
539            total_pages,
540        )
541    } else {
542        "Results (press Enter to execute query)".to_string()
543    };
544
545    let block = Block::default()
546        .title(title)
547        .borders(Borders::ALL)
548        .border_type(BorderType::Rounded)
549        .border_style(Style::default().fg(border_color));
550
551    let inner = block.inner(area);
552    block.render(area, buf);
553
554    if let Some(ref error) = app.query_state.error {
555        let error_text = Paragraph::new(error.as_str())
556            .style(Style::default().fg(Color::Red))
557            .wrap(ratatui::widgets::Wrap { trim: true });
558        error_text.render(inner, buf);
559        return;
560    }
561
562    let Some(ref results) = app.query_state.results else {
563        let help = Paragraph::new("Enter a SQL query above and press Enter to execute.\nThe table is available as 'data'.\n\nExample: SELECT * FROM data WHERE column > 10 LIMIT 100")
564            .style(Style::default().fg(Color::Gray));
565        help.render(inner, buf);
566        return;
567    };
568
569    if results.batches.is_empty() || results.total_rows == 0 {
570        let empty =
571            Paragraph::new("Query returned no results.").style(Style::default().fg(Color::Yellow));
572        empty.render(inner, buf);
573        return;
574    }
575
576    // Build header row with sort indicators
577    let header_cells: Vec<Cell> = results
578        .column_names
579        .iter()
580        .enumerate()
581        .map(|(i, name)| {
582            let indicator = if app.query_state.sort_column == Some(i) {
583                app.query_state.sort_direction.indicator()
584            } else {
585                ""
586            };
587
588            let style = if is_focused && i == app.query_state.horizontal_scroll {
589                Style::default().fg(Color::Black).bg(Color::Cyan).bold()
590            } else {
591                Style::default().fg(Color::Green).bold()
592            };
593
594            Cell::from(format!("{name}{indicator}")).style(style)
595        })
596        .collect();
597
598    let header = Row::new(header_cells).height(1);
599
600    // Since we use LIMIT/OFFSET in SQL, batches contain only the current page's data
601    // Display all rows from the batches
602    let rows = get_all_rows(results, &app.query_state);
603
604    // Calculate column widths
605    #[allow(clippy::cast_possible_truncation)]
606    let widths: Vec<Constraint> = results
607        .column_names
608        .iter()
609        .map(|name| Constraint::Min((name.len() + 3).max(10) as u16))
610        .collect();
611
612    let table = Table::new(rows, widths)
613        .header(header)
614        .row_highlight_style(Style::default().bg(Color::DarkGray));
615
616    // Split area for table and scrollbar
617    let [table_area, scrollbar_area] =
618        Layout::horizontal([Constraint::Min(0), Constraint::Length(1)]).areas(inner);
619
620    StatefulWidget::render(table, table_area, buf, &mut app.query_state.table_state);
621
622    // Render vertical scrollbar
623    let total_pages = app.query_state.total_pages();
624    if total_pages > 1 {
625        let mut scrollbar_state = ScrollbarState::new(total_pages)
626            .position(app.query_state.current_page)
627            .viewport_content_length(1);
628
629        Scrollbar::new(ScrollbarOrientation::VerticalRight)
630            .begin_symbol(Some("▲"))
631            .end_symbol(Some("▼"))
632            .render(scrollbar_area, buf, &mut scrollbar_state);
633    }
634}
635
636/// Get all rows from batches (pagination is handled via SQL LIMIT/OFFSET).
637fn get_all_rows<'a>(results: &'a QueryResults, query_state: &QueryState) -> Vec<Row<'a>> {
638    let mut rows = Vec::new();
639
640    for batch in &results.batches {
641        for row_idx in 0..batch.num_rows() {
642            let cells: Vec<Cell> = (0..batch.num_columns())
643                .map(|col_idx| {
644                    let json_value = arrow_value_to_json(batch.column(col_idx).as_ref(), row_idx);
645                    let value = json_value_to_display(json_value);
646                    let style = if query_state.sort_column == Some(col_idx) {
647                        Style::default().fg(Color::Cyan)
648                    } else {
649                        Style::default()
650                    };
651                    Cell::from(truncate_str(&value, 30).to_string()).style(style)
652                })
653                .collect();
654            rows.push(Row::new(cells));
655        }
656    }
657
658    rows
659}
660
/// Shorten `s` for display.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings
/// are cut to at most `max_len - 3` bytes (saturating at 0). The cut is moved
/// back to the nearest character boundary so that multi-byte text is never
/// split inside a character, which would panic.
fn truncate_str(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }

    // `end < s.len()` here, and offset 0 is always a boundary, so the loop
    // terminates with a valid slice end.
    let mut end = max_len.saturating_sub(3);
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}