1use std::collections::HashMap;
13
14use vibesql_ast::{
15 CloseCursorStmt, DeclareCursorStmt, FetchOrientation, FetchStmt, OpenCursorStmt, SelectStmt,
16};
17use vibesql_storage::{Database, Row};
18
19use crate::errors::ExecutorError;
20use crate::SelectExecutor;
21
22#[derive(Debug, Clone)]
24pub struct Cursor {
25 pub name: String,
27 pub query: Box<SelectStmt>,
29 pub result: Option<CursorResult>,
31 pub position: usize,
33 pub scroll: bool,
35 pub holdable: bool,
37 pub insensitive: bool,
39}
40
41#[derive(Debug, Clone)]
43pub struct CursorResult {
44 pub columns: Vec<String>,
46 pub rows: Vec<Row>,
48}
49
50#[derive(Debug, Clone)]
52pub struct FetchResult {
53 pub columns: Vec<String>,
55 pub rows: Vec<Row>,
57}
58
59impl FetchResult {
60 pub fn empty(columns: Vec<String>) -> Self {
62 Self { columns, rows: vec![] }
63 }
64
65 pub fn single(columns: Vec<String>, row: Row) -> Self {
67 Self { columns, rows: vec![row] }
68 }
69}
70
71#[derive(Debug, Default)]
73pub struct CursorStore {
74 cursors: HashMap<String, Cursor>,
75}
76
77impl CursorStore {
78 pub fn new() -> Self {
80 Self { cursors: HashMap::new() }
81 }
82
83 pub fn declare(&mut self, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
85 let name = stmt.cursor_name.to_uppercase();
86
87 if self.cursors.contains_key(&name) {
88 return Err(ExecutorError::CursorAlreadyExists(name));
89 }
90
91 let cursor = Cursor {
92 name: name.clone(),
93 query: stmt.query.clone(),
94 result: None,
95 position: 0,
96 scroll: stmt.scroll,
97 holdable: stmt.hold.unwrap_or(false),
98 insensitive: stmt.insensitive,
99 };
100
101 self.cursors.insert(name, cursor);
102 Ok(())
103 }
104
105 pub fn open(&mut self, stmt: &OpenCursorStmt, db: &Database) -> Result<(), ExecutorError> {
107 let name = stmt.cursor_name.to_uppercase();
108
109 let cursor = self.cursors.get_mut(&name).ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
110
111 if cursor.result.is_some() {
112 return Err(ExecutorError::CursorAlreadyOpen(name));
113 }
114
115 let executor = SelectExecutor::new(db);
117 let select_result = executor.execute_with_columns(&cursor.query)?;
118
119 cursor.result = Some(CursorResult {
120 columns: select_result.columns,
121 rows: select_result.rows,
122 });
123 cursor.position = 0; Ok(())
126 }
127
128 pub fn fetch(&mut self, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
130 let name = stmt.cursor_name.to_uppercase();
131
132 let cursor = self.cursors.get_mut(&name).ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
133
134 let result = cursor.result.as_ref().ok_or_else(|| ExecutorError::CursorNotOpen(name.clone()))?;
135
136 let row_count = result.rows.len();
137
138 let new_position = match &stmt.orientation {
140 FetchOrientation::Next => cursor.position.saturating_add(1),
141 FetchOrientation::Prior => {
142 if !cursor.scroll {
143 return Err(ExecutorError::CursorNotScrollable(name));
144 }
145 cursor.position.saturating_sub(1)
146 }
147 FetchOrientation::First => {
148 if !cursor.scroll && cursor.position > 0 {
149 return Err(ExecutorError::CursorNotScrollable(name));
150 }
151 1
152 }
153 FetchOrientation::Last => {
154 if !cursor.scroll {
155 return Err(ExecutorError::CursorNotScrollable(name));
156 }
157 row_count
158 }
159 FetchOrientation::Absolute(n) => {
160 if !cursor.scroll {
161 return Err(ExecutorError::CursorNotScrollable(name));
162 }
163 if *n >= 0 {
164 *n as usize
165 } else {
166 row_count.saturating_sub((-*n - 1) as usize)
168 }
169 }
170 FetchOrientation::Relative(n) => {
171 if !cursor.scroll && *n < 0 {
172 return Err(ExecutorError::CursorNotScrollable(name));
173 }
174 if *n >= 0 {
175 cursor.position.saturating_add(*n as usize)
176 } else {
177 cursor.position.saturating_sub((-*n) as usize)
178 }
179 }
180 };
181
182 cursor.position = new_position;
183
184 if new_position > 0 && new_position <= row_count {
186 Ok(FetchResult::single(
187 result.columns.clone(),
188 result.rows[new_position - 1].clone(),
189 ))
190 } else {
191 Ok(FetchResult::empty(result.columns.clone()))
193 }
194 }
195
196 pub fn close(&mut self, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
198 let name = stmt.cursor_name.to_uppercase();
199
200 if self.cursors.remove(&name).is_none() {
201 return Err(ExecutorError::CursorNotFound(name));
202 }
203
204 Ok(())
205 }
206
207 pub fn exists(&self, name: &str) -> bool {
209 self.cursors.contains_key(&name.to_uppercase())
210 }
211
212 pub fn is_open(&self, name: &str) -> bool {
214 self.cursors.get(&name.to_uppercase()).map(|c| c.result.is_some()).unwrap_or(false)
215 }
216
217 pub fn count(&self) -> usize {
219 self.cursors.len()
220 }
221
222 pub fn clear_non_holdable(&mut self) {
224 self.cursors.retain(|_, cursor| cursor.holdable);
225 }
226
227 pub fn clear(&mut self) {
229 self.cursors.clear();
230 }
231}
232
/// Stateless facade that forwards cursor statements to a [`CursorStore`].
pub struct CursorExecutor;
235
236impl CursorExecutor {
237 pub fn declare(store: &mut CursorStore, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
239 store.declare(stmt)
240 }
241
242 pub fn open(store: &mut CursorStore, stmt: &OpenCursorStmt, db: &Database) -> Result<(), ExecutorError> {
244 store.open(stmt, db)
245 }
246
247 pub fn fetch(store: &mut CursorStore, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
249 store.fetch(stmt)
250 }
251
252 pub fn close(store: &mut CursorStore, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
254 store.close(stmt)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{
        CloseCursorStmt, CursorUpdatability, DeclareCursorStmt, FetchOrientation, FetchStmt,
        FromClause, OpenCursorStmt, SelectItem, SelectStmt,
    };
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::{DataType, SqlValue};

    /// Builds a database with an `employees` table holding ids 1, 2, 3.
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new("name".to_string(), DataType::Varchar { max_length: Some(100) }, true),
            ColumnSchema::new("salary".to_string(), DataType::Integer, true),
        ];
        let schema = TableSchema::with_primary_key(
            "employees".to_string(),
            columns,
            vec!["id".to_string()],
        );
        db.create_table(schema).unwrap();

        let employees = [(1, "Alice", 50000), (2, "Bob", 60000), (3, "Carol", 55000)];
        for (id, name, salary) in employees {
            db.insert_row(
                "employees",
                Row::new(vec![
                    SqlValue::Integer(id),
                    SqlValue::Varchar(name.into()),
                    SqlValue::Integer(salary),
                ]),
            )
            .unwrap();
        }

        db
    }

    /// `SELECT * FROM employees`.
    fn create_select_stmt() -> SelectStmt {
        SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![SelectItem::Wildcard { alias: None }],
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "employees".to_string(),
                alias: None,
                column_aliases: None,
            }),
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        }
    }

    /// `DECLARE name [SCROLL] CURSOR [WITH|WITHOUT HOLD] FOR SELECT * FROM employees`.
    fn declare_stmt(name: &str, scroll: bool, hold: Option<bool>) -> DeclareCursorStmt {
        DeclareCursorStmt {
            cursor_name: name.to_string(),
            insensitive: false,
            scroll,
            hold,
            query: Box::new(create_select_stmt()),
            updatability: CursorUpdatability::Unspecified,
        }
    }

    fn open_stmt(name: &str) -> OpenCursorStmt {
        OpenCursorStmt { cursor_name: name.to_string() }
    }

    fn fetch_stmt(name: &str, orientation: FetchOrientation) -> FetchStmt {
        FetchStmt { cursor_name: name.to_string(), orientation, into_variables: None }
    }

    fn close_stmt(name: &str) -> CloseCursorStmt {
        CloseCursorStmt { cursor_name: name.to_string() }
    }

    /// Declares and opens a cursor over the test table.
    fn open_cursor(db: &Database, store: &mut CursorStore, name: &str, scroll: bool) {
        store.declare(&declare_stmt(name, scroll, None)).unwrap();
        store.open(&open_stmt(name), db).unwrap();
    }

    /// Fetches with `orientation` and returns the `id` of the fetched row, if any.
    fn fetch_id(
        store: &mut CursorStore,
        name: &str,
        orientation: FetchOrientation,
    ) -> Option<SqlValue> {
        let result = store.fetch(&fetch_stmt(name, orientation)).unwrap();
        assert!(result.rows.len() <= 1, "a FETCH returns at most one row");
        result.rows.first().map(|row| row.values[0].clone())
    }

    #[test]
    fn test_declare_cursor() {
        let mut store = CursorStore::new();

        assert!(store.declare(&declare_stmt("emp_cursor", false, None)).is_ok());
        assert!(store.exists("emp_cursor"));
        assert!(!store.is_open("emp_cursor"));
    }

    #[test]
    fn test_declare_cursor_already_exists() {
        let mut store = CursorStore::new();
        let stmt = declare_stmt("emp_cursor", false, None);

        store.declare(&stmt).unwrap();
        let result = store.declare(&stmt);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyExists(_))));
    }

    #[test]
    fn test_open_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        assert!(store.open(&open_stmt("emp_cursor"), &db).is_ok());
        assert!(store.is_open("emp_cursor"));
    }

    #[test]
    fn test_open_cursor_not_found() {
        let db = create_test_db();
        let mut store = CursorStore::new();

        let result = store.open(&open_stmt("nonexistent"), &db);
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_open_cursor_already_open() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        let result = store.open(&open_stmt("emp_cursor"), &db);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyOpen(_))));
    }

    #[test]
    fn test_fetch_next() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        for expected in 1..=3 {
            assert_eq!(
                fetch_id(&mut store, "emp_cursor", FetchOrientation::Next),
                Some(SqlValue::Integer(expected))
            );
        }
        // Past the end: no row, repeatedly.
        assert_eq!(fetch_id(&mut store, "emp_cursor", FetchOrientation::Next), None);
        assert_eq!(fetch_id(&mut store, "emp_cursor", FetchOrientation::Next), None);
    }

    #[test]
    fn test_fetch_from_unopened_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next));
        assert!(matches!(result, Err(ExecutorError::CursorNotOpen(_))));
    }

    #[test]
    fn test_fetch_prior_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Prior));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_fetch_first_non_scrollable_only_before_moving() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        assert_eq!(
            fetch_id(&mut store, "emp_cursor", FetchOrientation::First),
            Some(SqlValue::Integer(1))
        );
        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::First));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_fetch_last_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Last));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_fetch_relative_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        // Forward relative movement is allowed on a forward-only cursor.
        assert_eq!(
            fetch_id(&mut store, "emp_cursor", FetchOrientation::Relative(1)),
            Some(SqlValue::Integer(1))
        );
        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Relative(-1)));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_scroll_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "scroll_cursor", true);

        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Last),
            Some(SqlValue::Integer(3))
        );
        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::First),
            Some(SqlValue::Integer(1))
        );
        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Absolute(2)),
            Some(SqlValue::Integer(2))
        );
    }

    #[test]
    fn test_fetch_absolute_negative() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "scroll_cursor", true);

        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Absolute(-1)),
            Some(SqlValue::Integer(3))
        );
        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Absolute(-3)),
            Some(SqlValue::Integer(1))
        );
        assert_eq!(fetch_id(&mut store, "scroll_cursor", FetchOrientation::Absolute(-4)), None);
        assert_eq!(fetch_id(&mut store, "scroll_cursor", FetchOrientation::Absolute(0)), None);
    }

    #[test]
    fn test_fetch_relative() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "scroll_cursor", true);

        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Relative(2)),
            Some(SqlValue::Integer(2))
        );
        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Relative(-1)),
            Some(SqlValue::Integer(1))
        );
        assert_eq!(
            fetch_id(&mut store, "scroll_cursor", FetchOrientation::Relative(0)),
            Some(SqlValue::Integer(1))
        );
        assert_eq!(fetch_id(&mut store, "scroll_cursor", FetchOrientation::Relative(-5)), None);
    }

    #[test]
    fn test_close_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);

        assert!(store.close(&close_stmt("emp_cursor")).is_ok());
        assert!(!store.exists("emp_cursor"));
    }

    #[test]
    fn test_fetch_after_close() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        open_cursor(&db, &mut store, "emp_cursor", false);
        store.close(&close_stmt("emp_cursor")).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next));
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_close_nonexistent_cursor() {
        let mut store = CursorStore::new();

        let result = store.close(&close_stmt("nonexistent"));
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_case_insensitive_cursor_names() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("My_Cursor", false, None)).unwrap();

        assert!(store.open(&open_stmt("MY_CURSOR"), &db).is_ok());
        assert!(store.close(&close_stmt("my_cursor")).is_ok());
    }

    #[test]
    fn test_holdable_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("holdable", false, Some(true))).unwrap();
        store.declare(&declare_stmt("non_holdable", false, Some(false))).unwrap();

        assert_eq!(store.count(), 2);

        store.clear_non_holdable();

        assert_eq!(store.count(), 1);
        assert!(store.exists("holdable"));
        assert!(!store.exists("non_holdable"));
    }

    #[test]
    fn test_clear_removes_all_cursors() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("holdable", false, Some(true))).unwrap();
        store.declare(&declare_stmt("plain", false, None)).unwrap();

        store.clear();

        assert_eq!(store.count(), 0);
        assert!(!store.exists("holdable"));
    }
}