1use std::collections::HashMap;
13
14use vibesql_ast::{
15 CloseCursorStmt, DeclareCursorStmt, FetchOrientation, FetchStmt, OpenCursorStmt, SelectStmt,
16};
17use vibesql_storage::{Database, Row};
18
19use crate::{errors::ExecutorError, SelectExecutor};
20
/// A cursor created by `DECLARE CURSOR`, together with its fetch state.
#[derive(Debug, Clone)]
pub struct Cursor {
    /// Cursor name, upper-cased by [`CursorStore::declare`].
    pub name: String,
    /// The `SELECT` query the cursor iterates over.
    pub query: Box<SelectStmt>,
    /// Materialised query result; `None` until the cursor is opened.
    pub result: Option<CursorResult>,
    /// 1-based position of the current row: `0` means "before the first row"
    /// and any value greater than the row count means "past the last row".
    pub position: usize,
    /// Whether orientations other than forward movement (PRIOR, LAST,
    /// ABSOLUTE, negative RELATIVE) are permitted on this cursor.
    pub scroll: bool,
    /// Whether the cursor survives [`CursorStore::clear_non_holdable`].
    pub holdable: bool,
    /// Insensitivity flag from the declaration.
    /// NOTE(review): stored but not consulted anywhere in this module.
    pub insensitive: bool,
}
39
/// Query output materialised when a cursor is opened.
#[derive(Debug, Clone)]
pub struct CursorResult {
    /// Output column names of the cursor's query.
    pub columns: Vec<String>,
    /// All result rows, in the order produced by the select executor.
    pub rows: Vec<Row>,
}
48
/// The outcome of a single `FETCH`: the cursor's column names plus the
/// fetched row(s). Within this module it always holds zero or one row.
#[derive(Debug, Clone)]
pub struct FetchResult {
    /// Column names of the cursor's query (present even when no row is returned).
    pub columns: Vec<String>,
    /// The fetched rows; empty when the cursor is outside the result set.
    pub rows: Vec<Row>,
}
57
58impl FetchResult {
59 pub fn empty(columns: Vec<String>) -> Self {
61 Self { columns, rows: vec![] }
62 }
63
64 pub fn single(columns: Vec<String>, row: Row) -> Self {
66 Self { columns, rows: vec![row] }
67 }
68}
69
/// Registry of declared cursors, keyed by upper-cased cursor name so that
/// lookups are case-insensitive.
#[derive(Debug, Default)]
pub struct CursorStore {
    cursors: HashMap<String, Cursor>,
}
75
76impl CursorStore {
77 pub fn new() -> Self {
79 Self { cursors: HashMap::new() }
80 }
81
82 pub fn declare(&mut self, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
84 let name = stmt.cursor_name.to_uppercase();
85
86 if self.cursors.contains_key(&name) {
87 return Err(ExecutorError::CursorAlreadyExists(name));
88 }
89
90 let cursor = Cursor {
91 name: name.clone(),
92 query: stmt.query.clone(),
93 result: None,
94 position: 0,
95 scroll: stmt.scroll,
96 holdable: stmt.hold.unwrap_or(false),
97 insensitive: stmt.insensitive,
98 };
99
100 self.cursors.insert(name, cursor);
101 Ok(())
102 }
103
104 pub fn open(&mut self, stmt: &OpenCursorStmt, db: &Database) -> Result<(), ExecutorError> {
106 let name = stmt.cursor_name.to_uppercase();
107
108 let cursor = self
109 .cursors
110 .get_mut(&name)
111 .ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
112
113 if cursor.result.is_some() {
114 return Err(ExecutorError::CursorAlreadyOpen(name));
115 }
116
117 let executor = SelectExecutor::new(db);
119 let select_result = executor.execute_with_columns(&cursor.query)?;
120
121 cursor.result =
122 Some(CursorResult { columns: select_result.columns, rows: select_result.rows });
123 cursor.position = 0; Ok(())
126 }
127
128 pub fn fetch(&mut self, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
130 let name = stmt.cursor_name.to_uppercase();
131
132 let cursor = self
133 .cursors
134 .get_mut(&name)
135 .ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
136
137 let result =
138 cursor.result.as_ref().ok_or_else(|| ExecutorError::CursorNotOpen(name.clone()))?;
139
140 let row_count = result.rows.len();
141
142 let new_position = match &stmt.orientation {
144 FetchOrientation::Next => cursor.position.saturating_add(1),
145 FetchOrientation::Prior => {
146 if !cursor.scroll {
147 return Err(ExecutorError::CursorNotScrollable(name));
148 }
149 cursor.position.saturating_sub(1)
150 }
151 FetchOrientation::First => {
152 if !cursor.scroll && cursor.position > 0 {
153 return Err(ExecutorError::CursorNotScrollable(name));
154 }
155 1
156 }
157 FetchOrientation::Last => {
158 if !cursor.scroll {
159 return Err(ExecutorError::CursorNotScrollable(name));
160 }
161 row_count
162 }
163 FetchOrientation::Absolute(n) => {
164 if !cursor.scroll {
165 return Err(ExecutorError::CursorNotScrollable(name));
166 }
167 if *n >= 0 {
168 *n as usize
169 } else {
170 row_count.saturating_sub((-*n - 1) as usize)
172 }
173 }
174 FetchOrientation::Relative(n) => {
175 if !cursor.scroll && *n < 0 {
176 return Err(ExecutorError::CursorNotScrollable(name));
177 }
178 if *n >= 0 {
179 cursor.position.saturating_add(*n as usize)
180 } else {
181 cursor.position.saturating_sub((-*n) as usize)
182 }
183 }
184 };
185
186 cursor.position = new_position;
187
188 if new_position > 0 && new_position <= row_count {
190 Ok(FetchResult::single(result.columns.clone(), result.rows[new_position - 1].clone()))
191 } else {
192 Ok(FetchResult::empty(result.columns.clone()))
194 }
195 }
196
197 pub fn close(&mut self, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
199 let name = stmt.cursor_name.to_uppercase();
200
201 if self.cursors.remove(&name).is_none() {
202 return Err(ExecutorError::CursorNotFound(name));
203 }
204
205 Ok(())
206 }
207
208 pub fn exists(&self, name: &str) -> bool {
210 self.cursors.contains_key(&name.to_uppercase())
211 }
212
213 pub fn is_open(&self, name: &str) -> bool {
215 self.cursors.get(&name.to_uppercase()).map(|c| c.result.is_some()).unwrap_or(false)
216 }
217
218 pub fn count(&self) -> usize {
220 self.cursors.len()
221 }
222
223 pub fn clear_non_holdable(&mut self) {
225 self.cursors.retain(|_, cursor| cursor.holdable);
226 }
227
228 pub fn clear(&mut self) {
230 self.cursors.clear();
231 }
232}
233
/// Stateless entry point for cursor statements; each associated function
/// forwards to the corresponding [`CursorStore`] method.
pub struct CursorExecutor;
236
237impl CursorExecutor {
238 pub fn declare(store: &mut CursorStore, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
240 store.declare(stmt)
241 }
242
243 pub fn open(
245 store: &mut CursorStore,
246 stmt: &OpenCursorStmt,
247 db: &Database,
248 ) -> Result<(), ExecutorError> {
249 store.open(stmt, db)
250 }
251
252 pub fn fetch(store: &mut CursorStore, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
254 store.fetch(stmt)
255 }
256
257 pub fn close(store: &mut CursorStore, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
259 store.close(stmt)
260 }
261}
262
#[cfg(test)]
mod tests {
    use vibesql_ast::{
        CloseCursorStmt, CursorUpdatability, DeclareCursorStmt, FetchOrientation, FetchStmt,
        FromClause, OpenCursorStmt, SelectItem, SelectStmt,
    };
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::{DataType, SqlValue};

    use super::*;

    /// Builds a database holding an `employees(id, name, salary)` table with
    /// three rows whose ids are 1, 2 and 3.
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
            ColumnSchema::new("salary".to_string(), DataType::Integer, true),
        ];
        let schema =
            TableSchema::with_primary_key("employees".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        let employees = [(1, "Alice", 50000), (2, "Bob", 60000), (3, "Carol", 55000)];
        for (id, name, salary) in employees {
            db.insert_row(
                "employees",
                Row::new(vec![
                    SqlValue::Integer(id),
                    SqlValue::Varchar(arcstr::ArcStr::from(name)),
                    SqlValue::Integer(salary),
                ]),
            )
            .unwrap();
        }

        db
    }

    /// `SELECT * FROM employees`.
    fn create_select_stmt() -> SelectStmt {
        SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![SelectItem::Wildcard { alias: None }],
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "employees".to_string(),
                alias: None,
                column_aliases: None,
                quoted: false,
            }),
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
            values: None,
        }
    }

    /// Builds a non-insensitive `DECLARE` over [`create_select_stmt`].
    fn declare_stmt(name: &str, scroll: bool, hold: Option<bool>) -> DeclareCursorStmt {
        DeclareCursorStmt {
            cursor_name: name.to_string(),
            insensitive: false,
            scroll,
            hold,
            query: Box::new(create_select_stmt()),
            updatability: CursorUpdatability::Unspecified,
        }
    }

    /// Builds an `OPEN` statement for `name`.
    fn open_stmt(name: &str) -> OpenCursorStmt {
        OpenCursorStmt { cursor_name: name.to_string() }
    }

    /// Builds a `CLOSE` statement for `name`.
    fn close_stmt(name: &str) -> CloseCursorStmt {
        CloseCursorStmt { cursor_name: name.to_string() }
    }

    /// Builds a `FETCH` statement for `name` without `INTO` variables.
    fn fetch_stmt(name: &str, orientation: FetchOrientation) -> FetchStmt {
        FetchStmt { cursor_name: name.to_string(), orientation, into_variables: None }
    }

    #[test]
    fn test_declare_cursor() {
        let mut store = CursorStore::new();

        assert!(store.declare(&declare_stmt("emp_cursor", false, None)).is_ok());
        assert!(store.exists("emp_cursor"));
        assert!(!store.is_open("emp_cursor"));
    }

    #[test]
    fn test_declare_cursor_already_exists() {
        let mut store = CursorStore::new();
        let stmt = declare_stmt("emp_cursor", false, None);

        store.declare(&stmt).unwrap();
        let result = store.declare(&stmt);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyExists(_))));
    }

    #[test]
    fn test_open_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        assert!(store.open(&open_stmt("emp_cursor"), &db).is_ok());
        assert!(store.is_open("emp_cursor"));
    }

    #[test]
    fn test_open_cursor_not_found() {
        let db = create_test_db();
        let mut store = CursorStore::new();

        let result = store.open(&open_stmt("nonexistent"), &db);
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_open_cursor_already_open() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        let open = open_stmt("emp_cursor");
        store.open(&open, &db).unwrap();

        let result = store.open(&open, &db);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyOpen(_))));
    }

    #[test]
    fn test_fetch_next() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();
        store.open(&open_stmt("emp_cursor"), &db).unwrap();

        let next = fetch_stmt("emp_cursor", FetchOrientation::Next);
        for expected_id in 1..=3 {
            let result = store.fetch(&next).unwrap();
            assert_eq!(result.rows.len(), 1);
            assert_eq!(result.rows[0].values[0], SqlValue::Integer(expected_id));
        }

        // The fourth NEXT moves past the last row and yields no row.
        let result = store.fetch(&next).unwrap();
        assert_eq!(result.rows.len(), 0);
    }

    #[test]
    fn test_fetch_from_unopened_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next));
        assert!(matches!(result, Err(ExecutorError::CursorNotOpen(_))));
    }

    #[test]
    fn test_fetch_prior_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();
        store.open(&open_stmt("emp_cursor"), &db).unwrap();

        store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Prior));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_scroll_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("scroll_cursor", true, None)).unwrap();
        store.open(&open_stmt("scroll_cursor"), &db).unwrap();

        let cases = [
            (FetchOrientation::Last, 3),
            (FetchOrientation::First, 1),
            (FetchOrientation::Absolute(2), 2),
        ];
        for (orientation, expected_id) in cases {
            let result = store.fetch(&fetch_stmt("scroll_cursor", orientation)).unwrap();
            assert_eq!(result.rows.len(), 1);
            assert_eq!(result.rows[0].values[0], SqlValue::Integer(expected_id));
        }
    }

    #[test]
    fn test_close_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();
        store.open(&open_stmt("emp_cursor"), &db).unwrap();

        assert!(store.close(&close_stmt("emp_cursor")).is_ok());
        assert!(!store.exists("emp_cursor"));
    }

    #[test]
    fn test_close_nonexistent_cursor() {
        let mut store = CursorStore::new();

        let result = store.close(&close_stmt("nonexistent"));
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_case_insensitive_cursor_names() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("My_Cursor", false, None)).unwrap();

        // Open and close using differently-cased spellings of the declared name.
        assert!(store.open(&open_stmt("MY_CURSOR"), &db).is_ok());
        assert!(store.close(&close_stmt("my_cursor")).is_ok());
    }

    #[test]
    fn test_holdable_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("holdable", false, Some(true))).unwrap();
        store.declare(&declare_stmt("non_holdable", false, Some(false))).unwrap();

        assert_eq!(store.count(), 2);

        store.clear_non_holdable();

        assert_eq!(store.count(), 1);
        assert!(store.exists("holdable"));
        assert!(!store.exists("non_holdable"));
    }
}