1use arrow_array::RecordBatch;
5use ratatui::buffer::Buffer;
6use ratatui::layout::Constraint;
7use ratatui::layout::Layout;
8use ratatui::layout::Rect;
9use ratatui::style::Color;
10use ratatui::style::Style;
11use ratatui::text::Line;
12use ratatui::text::Span;
13use ratatui::widgets::Block;
14use ratatui::widgets::BorderType;
15use ratatui::widgets::Borders;
16use ratatui::widgets::Cell;
17use ratatui::widgets::Paragraph;
18use ratatui::widgets::Row;
19use ratatui::widgets::Scrollbar;
20use ratatui::widgets::ScrollbarOrientation;
21use ratatui::widgets::ScrollbarState;
22use ratatui::widgets::StatefulWidget;
23use ratatui::widgets::Table;
24use ratatui::widgets::TableState;
25use ratatui::widgets::Widget;
26use tokio::runtime::Handle;
27use tokio::task::block_in_place;
28
29use crate::browse::app::AppState;
30use crate::datafusion_helper::arrow_value_to_json;
31use crate::datafusion_helper::execute_vortex_query;
32use crate::datafusion_helper::json_value_to_display;
33
/// Ordering requested for a results column.
///
/// Values advance through [`SortDirection::cycle`] in the order
/// `None -> Ascending -> Descending -> None`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SortDirection {
    /// No interactive ordering is applied.
    #[default]
    None,
    /// Rendered as `ASC` in the generated ORDER BY clause.
    Ascending,
    /// Rendered as `DESC` in the generated ORDER BY clause.
    Descending,
}

impl SortDirection {
    /// Returns the direction that follows `self` in the sort cycle.
    pub fn cycle(self) -> Self {
        match self {
            Self::None => Self::Ascending,
            Self::Ascending => Self::Descending,
            Self::Descending => Self::None,
        }
    }

    /// Returns the suffix appended to a column header to show this direction
    /// (empty when unsorted).
    pub fn indicator(self) -> &'static str {
        match self {
            Self::Ascending => " ▲",
            Self::Descending => " ▼",
            Self::None => "",
        }
    }
}
65
/// Which pane of the query view receives keyboard input.
///
/// The editor is focused initially; `QueryState::toggle_focus` flips between
/// the two variants.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum QueryFocus {
    /// The SQL text editor at the top of the view.
    #[default]
    SqlInput,
    /// The paged results table below the editor.
    ResultsTable,
}
75
76pub struct QueryState {
78 pub sql_input: String,
80 pub cursor_position: usize,
82 pub focus: QueryFocus,
84 pub results: Option<QueryResults>,
86 pub error: Option<String>,
88 pub running: bool,
90 pub table_state: TableState,
92 pub horizontal_scroll: usize,
94 pub sort_column: Option<usize>,
96 pub sort_direction: SortDirection,
98 pub current_page: usize,
100 pub page_size: usize,
102 pub total_row_count: Option<usize>,
104 pub base_query: String,
106 pub order_clause: Option<String>,
108}
109
110impl Default for QueryState {
111 fn default() -> Self {
112 let default_sql = "SELECT * FROM data LIMIT 20";
113 Self {
114 sql_input: default_sql.to_string(),
115 cursor_position: default_sql.len(),
116 focus: QueryFocus::default(),
117 results: None,
118 error: None,
119 running: false,
120 table_state: TableState::default(),
121 horizontal_scroll: 0,
122 sort_column: None,
123 sort_direction: SortDirection::default(),
124 current_page: 0,
125 page_size: 20,
126 total_row_count: None,
127 base_query: "SELECT * FROM data".to_string(),
128 order_clause: None,
129 }
130 }
131}
132
133impl QueryState {
134 pub fn insert_char(&mut self, c: char) {
136 self.sql_input.insert(self.cursor_position, c);
137 self.cursor_position += 1;
138 }
139
140 pub fn delete_char(&mut self) {
142 if self.cursor_position > 0 {
143 self.cursor_position -= 1;
144 self.sql_input.remove(self.cursor_position);
145 }
146 }
147
148 pub fn delete_char_forward(&mut self) {
150 if self.cursor_position < self.sql_input.len() {
151 self.sql_input.remove(self.cursor_position);
152 }
153 }
154
155 pub fn move_cursor_left(&mut self) {
157 self.cursor_position = self.cursor_position.saturating_sub(1);
158 }
159
160 pub fn move_cursor_right(&mut self) {
162 if self.cursor_position < self.sql_input.len() {
163 self.cursor_position += 1;
164 }
165 }
166
167 pub fn move_cursor_start(&mut self) {
169 self.cursor_position = 0;
170 }
171
172 pub fn move_cursor_end(&mut self) {
174 self.cursor_position = self.sql_input.len();
175 }
176
177 pub fn clear_input(&mut self) {
179 self.sql_input.clear();
180 self.cursor_position = 0;
181 }
182
183 pub fn toggle_focus(&mut self) {
185 self.focus = match self.focus {
186 QueryFocus::SqlInput => QueryFocus::ResultsTable,
187 QueryFocus::ResultsTable => QueryFocus::SqlInput,
188 };
189 }
190
191 pub fn execute_initial_query(
193 &mut self,
194 session: &vortex::session::VortexSession,
195 file_path: &str,
196 ) {
197 self.running = true;
198 self.error = None;
199
200 let (base_sql, order_clause, limit) = self.parse_sql_parts();
202 self.base_query = base_sql;
203 self.order_clause = order_clause;
204 self.page_size = limit.unwrap_or(20);
205 self.current_page = 0;
206
207 self.total_row_count = get_row_count(session, file_path, &self.base_query).ok();
209
210 self.rebuild_and_execute(session, file_path);
212 }
213
214 pub fn next_page(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
216 let total_pages = self.total_pages();
217 if self.current_page + 1 < total_pages {
218 self.current_page += 1;
219 self.rebuild_and_execute(session, file_path);
220 }
221 }
222
223 pub fn prev_page(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
225 if self.current_page > 0 {
226 self.current_page -= 1;
227 self.rebuild_and_execute(session, file_path);
228 }
229 }
230
231 pub fn total_pages(&self) -> usize {
233 match self.total_row_count {
234 Some(total) if total > 0 => total.div_ceil(self.page_size),
235 _ => 1,
236 }
237 }
238
239 fn rebuild_and_execute(&mut self, session: &vortex::session::VortexSession, file_path: &str) {
241 let offset = self.current_page * self.page_size;
242
243 let new_sql = match &self.order_clause {
244 Some(order) => {
245 format!(
246 "{} {} LIMIT {} OFFSET {}",
247 self.base_query, order, self.page_size, offset
248 )
249 }
250 None => {
251 format!(
252 "{} LIMIT {} OFFSET {}",
253 self.base_query, self.page_size, offset
254 )
255 }
256 };
257
258 self.sql_input = new_sql;
259 self.cursor_position = self.sql_input.len();
260
261 self.running = true;
262 self.error = None;
263
264 match execute_query(session, file_path, &self.sql_input) {
265 Ok(results) => {
266 self.results = Some(results);
267 self.table_state.select(Some(0));
268 }
269 Err(e) => {
270 self.error = Some(e);
271 }
272 }
273 self.running = false;
274 }
275
276 fn parse_sql_parts(&self) -> (String, Option<String>, Option<usize>) {
278 let sql = &self.sql_input;
279 let sql_upper = sql.to_uppercase();
280
281 let order_idx = sql_upper.find(" ORDER BY ");
283 let limit_idx = sql_upper.find(" LIMIT ");
284 let offset_idx = sql_upper.find(" OFFSET ");
285
286 let limit_value = if let Some(li) = limit_idx {
288 let after_limit = &sql[li + 7..]; let end_idx = after_limit
290 .find(|c: char| !c.is_ascii_digit() && c != ' ')
291 .unwrap_or(after_limit.len());
292 after_limit[..end_idx].trim().parse::<usize>().ok()
293 } else {
294 None
295 };
296
297 let cut_idx = match (limit_idx, offset_idx) {
299 (Some(li), Some(oi)) => Some(li.min(oi)),
300 (Some(li), None) => Some(li),
301 (None, Some(oi)) => Some(oi),
302 (None, None) => None,
303 };
304
305 match (order_idx, cut_idx) {
306 (Some(oi), Some(ci)) if oi < ci => {
307 let base = sql[..oi].trim().to_string();
309 let order = sql[oi..ci].trim().to_string();
310 (base, Some(order), limit_value)
311 }
312 (Some(oi), None) => {
313 let base = sql[..oi].trim().to_string();
315 let order = sql[oi..].trim().to_string();
316 (base, Some(order), limit_value)
317 }
318 (None, Some(ci)) => {
319 let base = sql[..ci].trim().to_string();
321 (base, None, limit_value)
322 }
323 (Some(_oi), Some(ci)) => {
324 let base = sql[..ci].trim().to_string();
326 (base, None, limit_value)
327 }
328 (None, None) => {
329 (sql.clone(), None, limit_value)
331 }
332 }
333 }
334
335 pub fn selected_column(&self) -> usize {
337 self.horizontal_scroll
338 }
339
340 pub fn column_count(&self) -> usize {
342 self.results
343 .as_ref()
344 .and_then(|r| r.batches.first())
345 .map(|b| b.num_columns())
346 .unwrap_or(0)
347 }
348
349 pub fn apply_sort(
351 &mut self,
352 session: &vortex::session::VortexSession,
353 column: usize,
354 file_path: &str,
355 ) {
356 let column_name = match &self.results {
358 Some(results) if column < results.column_names.len() => {
359 results.column_names[column].clone()
360 }
361 _ => return,
362 };
363
364 if self.sort_column == Some(column) {
366 self.sort_direction = self.sort_direction.cycle();
367 if self.sort_direction == SortDirection::None {
368 self.sort_column = None;
369 }
370 } else {
371 self.sort_column = Some(column);
372 self.sort_direction = SortDirection::Ascending;
373 }
374
375 self.order_clause = if self.sort_direction == SortDirection::None {
377 None
378 } else {
379 let direction = match self.sort_direction {
380 SortDirection::Ascending => "ASC",
381 SortDirection::Descending => "DESC",
382 SortDirection::None => unreachable!(),
383 };
384 Some(format!("ORDER BY \"{column_name}\" {direction}"))
385 };
386
387 self.current_page = 0;
389 self.rebuild_and_execute(session, file_path);
390 }
391}
392
393pub struct QueryResults {
395 pub batches: Vec<RecordBatch>,
396 pub total_rows: usize,
397 pub column_names: Vec<String>,
398}
399
400pub fn execute_query(
402 session: &vortex::session::VortexSession,
403 file_path: &str,
404 sql: &str,
405) -> Result<QueryResults, String> {
406 block_in_place(|| {
407 Handle::current().block_on(async {
408 let batches = execute_vortex_query(session, file_path, sql).await?;
409
410 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
411
412 let column_names = if let Some(batch) = batches.first() {
413 let schema = batch.schema();
414 schema.fields().iter().map(|f| f.name().clone()).collect()
415 } else {
416 vec![]
417 };
418
419 Ok(QueryResults {
420 batches,
421 total_rows,
422 column_names,
423 })
424 })
425 })
426}
427
428pub fn get_row_count(
430 session: &vortex::session::VortexSession,
431 file_path: &str,
432 base_query: &str,
433) -> Result<usize, String> {
434 block_in_place(|| {
435 Handle::current().block_on(async {
436 let count_sql = format!("SELECT COUNT(*) as count FROM ({base_query}) AS subquery");
437
438 let batches = execute_vortex_query(session, file_path, &count_sql).await?;
439
440 if let Some(batch) = batches.first()
442 && batch.num_rows() > 0
443 && batch.num_columns() > 0
444 {
445 use arrow_array::Int64Array;
446 if let Some(arr) = batch.column(0).as_any().downcast_ref::<Int64Array>() {
447 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
448 return Ok(arr.value(0) as usize);
449 }
450 }
451
452 Ok(0)
453 })
454 })
455}
456
457pub fn render_query(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
459 let [input_area, results_area] =
460 Layout::vertical([Constraint::Length(5), Constraint::Min(10)]).areas(area);
461
462 render_sql_input(app, input_area, buf);
463 render_results_table(app, results_area, buf);
464}
465
466fn render_sql_input(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
467 let is_focused = app.query_state.focus == QueryFocus::SqlInput;
468
469 let border_color = if is_focused {
470 Color::Cyan
471 } else {
472 Color::DarkGray
473 };
474
475 let block = Block::default()
476 .title("SQL Query (Enter to execute, Esc to switch focus)")
477 .borders(Borders::ALL)
478 .border_type(BorderType::Rounded)
479 .border_style(Style::default().fg(border_color));
480
481 let inner = block.inner(area);
482 block.render(area, buf);
483
484 let sql = &app.query_state.sql_input;
486 let cursor_pos = app.query_state.cursor_position;
487
488 let (before_cursor, after_cursor) = sql.split_at(cursor_pos.min(sql.len()));
489
490 let first_char = after_cursor.chars().next();
491 let cursor_char = if is_focused {
492 match first_char {
493 None => Span::styled(" ", Style::default().bg(Color::White).fg(Color::Black)),
494 Some(c) => Span::styled(
495 c.to_string(),
496 Style::default().bg(Color::White).fg(Color::Black),
497 ),
498 }
499 } else {
500 match first_char {
501 None => Span::raw(""),
502 Some(c) => Span::raw(c.to_string()),
503 }
504 };
505
506 let rest = match first_char {
507 Some(c) if after_cursor.len() > c.len_utf8() => &after_cursor[c.len_utf8()..],
508 _ => "",
509 };
510
511 let line = Line::from(vec![Span::raw(before_cursor), cursor_char, Span::raw(rest)]);
512
513 let paragraph = Paragraph::new(line).style(Style::default().fg(Color::White));
514
515 paragraph.render(inner, buf);
516}
517
518fn render_results_table(app: &mut AppState<'_>, area: Rect, buf: &mut Buffer) {
519 let is_focused = app.query_state.focus == QueryFocus::ResultsTable;
520
521 let border_color = if is_focused {
522 Color::Cyan
523 } else {
524 Color::DarkGray
525 };
526
527 let title = if app.query_state.running {
529 "Results (running...)".to_string()
530 } else if let Some(ref error) = app.query_state.error {
531 format!("Results (error: {})", truncate_str(error, 50))
532 } else if let Some(ref _results) = app.query_state.results {
533 let total_rows = app.query_state.total_row_count.unwrap_or(0);
534 let total_pages = app.query_state.total_pages();
535 format!(
536 "Results ({} rows, page {}/{}) [hjkl navigate, [/] pages, s sort]",
537 total_rows,
538 app.query_state.current_page + 1,
539 total_pages,
540 )
541 } else {
542 "Results (press Enter to execute query)".to_string()
543 };
544
545 let block = Block::default()
546 .title(title)
547 .borders(Borders::ALL)
548 .border_type(BorderType::Rounded)
549 .border_style(Style::default().fg(border_color));
550
551 let inner = block.inner(area);
552 block.render(area, buf);
553
554 if let Some(ref error) = app.query_state.error {
555 let error_text = Paragraph::new(error.as_str())
556 .style(Style::default().fg(Color::Red))
557 .wrap(ratatui::widgets::Wrap { trim: true });
558 error_text.render(inner, buf);
559 return;
560 }
561
562 let Some(ref results) = app.query_state.results else {
563 let help = Paragraph::new("Enter a SQL query above and press Enter to execute.\nThe table is available as 'data'.\n\nExample: SELECT * FROM data WHERE column > 10 LIMIT 100")
564 .style(Style::default().fg(Color::Gray));
565 help.render(inner, buf);
566 return;
567 };
568
569 if results.batches.is_empty() || results.total_rows == 0 {
570 let empty =
571 Paragraph::new("Query returned no results.").style(Style::default().fg(Color::Yellow));
572 empty.render(inner, buf);
573 return;
574 }
575
576 let header_cells: Vec<Cell> = results
578 .column_names
579 .iter()
580 .enumerate()
581 .map(|(i, name)| {
582 let indicator = if app.query_state.sort_column == Some(i) {
583 app.query_state.sort_direction.indicator()
584 } else {
585 ""
586 };
587
588 let style = if is_focused && i == app.query_state.horizontal_scroll {
589 Style::default().fg(Color::Black).bg(Color::Cyan).bold()
590 } else {
591 Style::default().fg(Color::Green).bold()
592 };
593
594 Cell::from(format!("{name}{indicator}")).style(style)
595 })
596 .collect();
597
598 let header = Row::new(header_cells).height(1);
599
600 let rows = get_all_rows(results, &app.query_state);
603
604 #[allow(clippy::cast_possible_truncation)]
606 let widths: Vec<Constraint> = results
607 .column_names
608 .iter()
609 .map(|name| Constraint::Min((name.len() + 3).max(10) as u16))
610 .collect();
611
612 let table = Table::new(rows, widths)
613 .header(header)
614 .row_highlight_style(Style::default().bg(Color::DarkGray));
615
616 let [table_area, scrollbar_area] =
618 Layout::horizontal([Constraint::Min(0), Constraint::Length(1)]).areas(inner);
619
620 StatefulWidget::render(table, table_area, buf, &mut app.query_state.table_state);
621
622 let total_pages = app.query_state.total_pages();
624 if total_pages > 1 {
625 let mut scrollbar_state = ScrollbarState::new(total_pages)
626 .position(app.query_state.current_page)
627 .viewport_content_length(1);
628
629 Scrollbar::new(ScrollbarOrientation::VerticalRight)
630 .begin_symbol(Some("▲"))
631 .end_symbol(Some("▼"))
632 .render(scrollbar_area, buf, &mut scrollbar_state);
633 }
634}
635
636fn get_all_rows<'a>(results: &'a QueryResults, query_state: &QueryState) -> Vec<Row<'a>> {
638 let mut rows = Vec::new();
639
640 for batch in &results.batches {
641 for row_idx in 0..batch.num_rows() {
642 let cells: Vec<Cell> = (0..batch.num_columns())
643 .map(|col_idx| {
644 let json_value = arrow_value_to_json(batch.column(col_idx).as_ref(), row_idx);
645 let value = json_value_to_display(json_value);
646 let style = if query_state.sort_column == Some(col_idx) {
647 Style::default().fg(Color::Cyan)
648 } else {
649 Style::default()
650 };
651 Cell::from(truncate_str(&value, 30).to_string()).style(style)
652 })
653 .collect();
654 rows.push(Row::new(cells));
655 }
656 }
657
658 rows
659}
660
/// Shortens `s` for display when it is longer than `max_len` characters.
///
/// Strings of at most `max_len` characters are returned unchanged. Longer
/// strings are cut to their first `max_len - 3` characters (saturating at 0).
/// Lengths are counted in `char`s and the cut always lands on a character
/// boundary, so non-ASCII text (error messages, cell values) never causes
/// a slicing panic. For ASCII input this is identical to byte-based
/// truncation.
fn truncate_str(s: &str, max_len: usize) -> &str {
    // Fast path: a string with at most `max_len` bytes has at most `max_len` chars.
    if s.len() <= max_len || s.char_indices().nth(max_len).is_none() {
        return s;
    }
    let keep = max_len.saturating_sub(3);
    let end = s.char_indices().nth(keep).map_or(s.len(), |(idx, _)| idx);
    &s[..end]
}