1use std::collections::HashMap;
13
14use vibesql_ast::{
15 CloseCursorStmt, DeclareCursorStmt, FetchOrientation, FetchStmt, OpenCursorStmt, SelectStmt,
16};
17use vibesql_storage::{Database, Row};
18
19use crate::errors::ExecutorError;
20use crate::SelectExecutor;
21
/// A declared SQL cursor: the query it was declared over plus its runtime state.
///
/// The cursor counts as "open" while `result` is `Some`. Opening executes the
/// query once and stores every row; FETCH then only moves `position` over that
/// stored result.
#[derive(Debug, Clone)]
pub struct Cursor {
    /// Cursor name, normalized to uppercase when the cursor is declared.
    pub name: String,
    /// The SELECT statement from the DECLARE; executed when the cursor is opened.
    pub query: Box<SelectStmt>,
    /// Materialized result set; `None` until the cursor is opened.
    pub result: Option<CursorResult>,
    /// 1-based index of the current row. `0` means "before the first row";
    /// values greater than the row count mean the cursor is past the last row.
    pub position: usize,
    /// Whether the cursor was declared SCROLL. Non-scroll cursors reject PRIOR,
    /// LAST, ABSOLUTE and backward RELATIVE fetches, and FIRST once they have moved.
    pub scroll: bool,
    /// Taken from the DECLARE's `hold` option (absent means `false`); holdable
    /// cursors are kept by `CursorStore::clear_non_holdable`.
    pub holdable: bool,
    /// The INSENSITIVE flag from the DECLARE. It is stored but not read by any
    /// logic in this file.
    pub insensitive: bool,
}
40
/// Snapshot of a cursor query's output, captured when the cursor is opened.
#[derive(Debug, Clone)]
pub struct CursorResult {
    /// Column names of the result set.
    pub columns: Vec<String>,
    /// All result rows, in the order the SELECT executor returned them.
    pub rows: Vec<Row>,
}
49
/// Output of a single FETCH. `CursorStore::fetch` produces this with at most
/// one row.
#[derive(Debug, Clone)]
pub struct FetchResult {
    /// Column names of the cursor's result set, present even when no row is.
    pub columns: Vec<String>,
    /// The fetched row(s); empty when the cursor lands outside the result set.
    pub rows: Vec<Row>,
}
58
59impl FetchResult {
60 pub fn empty(columns: Vec<String>) -> Self {
62 Self { columns, rows: vec![] }
63 }
64
65 pub fn single(columns: Vec<String>, row: Row) -> Self {
67 Self { columns, rows: vec![row] }
68 }
69}
70
/// Registry of declared cursors, keyed by uppercased cursor name.
#[derive(Debug, Default)]
pub struct CursorStore {
    /// Declared cursors. Keys are uppercased, so lookups are case-insensitive.
    cursors: HashMap<String, Cursor>,
}
76
77impl CursorStore {
78 pub fn new() -> Self {
80 Self { cursors: HashMap::new() }
81 }
82
83 pub fn declare(&mut self, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
85 let name = stmt.cursor_name.to_uppercase();
86
87 if self.cursors.contains_key(&name) {
88 return Err(ExecutorError::CursorAlreadyExists(name));
89 }
90
91 let cursor = Cursor {
92 name: name.clone(),
93 query: stmt.query.clone(),
94 result: None,
95 position: 0,
96 scroll: stmt.scroll,
97 holdable: stmt.hold.unwrap_or(false),
98 insensitive: stmt.insensitive,
99 };
100
101 self.cursors.insert(name, cursor);
102 Ok(())
103 }
104
105 pub fn open(&mut self, stmt: &OpenCursorStmt, db: &Database) -> Result<(), ExecutorError> {
107 let name = stmt.cursor_name.to_uppercase();
108
109 let cursor = self
110 .cursors
111 .get_mut(&name)
112 .ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
113
114 if cursor.result.is_some() {
115 return Err(ExecutorError::CursorAlreadyOpen(name));
116 }
117
118 let executor = SelectExecutor::new(db);
120 let select_result = executor.execute_with_columns(&cursor.query)?;
121
122 cursor.result =
123 Some(CursorResult { columns: select_result.columns, rows: select_result.rows });
124 cursor.position = 0; Ok(())
127 }
128
129 pub fn fetch(&mut self, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
131 let name = stmt.cursor_name.to_uppercase();
132
133 let cursor = self
134 .cursors
135 .get_mut(&name)
136 .ok_or_else(|| ExecutorError::CursorNotFound(name.clone()))?;
137
138 let result =
139 cursor.result.as_ref().ok_or_else(|| ExecutorError::CursorNotOpen(name.clone()))?;
140
141 let row_count = result.rows.len();
142
143 let new_position = match &stmt.orientation {
145 FetchOrientation::Next => cursor.position.saturating_add(1),
146 FetchOrientation::Prior => {
147 if !cursor.scroll {
148 return Err(ExecutorError::CursorNotScrollable(name));
149 }
150 cursor.position.saturating_sub(1)
151 }
152 FetchOrientation::First => {
153 if !cursor.scroll && cursor.position > 0 {
154 return Err(ExecutorError::CursorNotScrollable(name));
155 }
156 1
157 }
158 FetchOrientation::Last => {
159 if !cursor.scroll {
160 return Err(ExecutorError::CursorNotScrollable(name));
161 }
162 row_count
163 }
164 FetchOrientation::Absolute(n) => {
165 if !cursor.scroll {
166 return Err(ExecutorError::CursorNotScrollable(name));
167 }
168 if *n >= 0 {
169 *n as usize
170 } else {
171 row_count.saturating_sub((-*n - 1) as usize)
173 }
174 }
175 FetchOrientation::Relative(n) => {
176 if !cursor.scroll && *n < 0 {
177 return Err(ExecutorError::CursorNotScrollable(name));
178 }
179 if *n >= 0 {
180 cursor.position.saturating_add(*n as usize)
181 } else {
182 cursor.position.saturating_sub((-*n) as usize)
183 }
184 }
185 };
186
187 cursor.position = new_position;
188
189 if new_position > 0 && new_position <= row_count {
191 Ok(FetchResult::single(result.columns.clone(), result.rows[new_position - 1].clone()))
192 } else {
193 Ok(FetchResult::empty(result.columns.clone()))
195 }
196 }
197
198 pub fn close(&mut self, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
200 let name = stmt.cursor_name.to_uppercase();
201
202 if self.cursors.remove(&name).is_none() {
203 return Err(ExecutorError::CursorNotFound(name));
204 }
205
206 Ok(())
207 }
208
209 pub fn exists(&self, name: &str) -> bool {
211 self.cursors.contains_key(&name.to_uppercase())
212 }
213
214 pub fn is_open(&self, name: &str) -> bool {
216 self.cursors.get(&name.to_uppercase()).map(|c| c.result.is_some()).unwrap_or(false)
217 }
218
219 pub fn count(&self) -> usize {
221 self.cursors.len()
222 }
223
224 pub fn clear_non_holdable(&mut self) {
226 self.cursors.retain(|_, cursor| cursor.holdable);
227 }
228
229 pub fn clear(&mut self) {
231 self.cursors.clear();
232 }
233}
234
/// Stateless entry points for cursor statements; each associated function
/// forwards to the matching `CursorStore` method.
pub struct CursorExecutor;
237
238impl CursorExecutor {
239 pub fn declare(store: &mut CursorStore, stmt: &DeclareCursorStmt) -> Result<(), ExecutorError> {
241 store.declare(stmt)
242 }
243
244 pub fn open(
246 store: &mut CursorStore,
247 stmt: &OpenCursorStmt,
248 db: &Database,
249 ) -> Result<(), ExecutorError> {
250 store.open(stmt, db)
251 }
252
253 pub fn fetch(store: &mut CursorStore, stmt: &FetchStmt) -> Result<FetchResult, ExecutorError> {
255 store.fetch(stmt)
256 }
257
258 pub fn close(store: &mut CursorStore, stmt: &CloseCursorStmt) -> Result<(), ExecutorError> {
260 store.close(stmt)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_ast::{
        CloseCursorStmt, CursorUpdatability, DeclareCursorStmt, FetchOrientation, FetchStmt,
        FromClause, OpenCursorStmt, SelectItem, SelectStmt,
    };
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::{DataType, SqlValue};

    /// Builds a database with an `employees(id, name, salary)` table holding
    /// three rows, inserted in id order 1, 2, 3.
    fn create_test_db() -> Database {
        let mut db = Database::new();
        db.catalog.set_case_sensitive_identifiers(false);

        let columns = vec![
            ColumnSchema::new("id".to_string(), DataType::Integer, false),
            ColumnSchema::new(
                "name".to_string(),
                DataType::Varchar { max_length: Some(100) },
                true,
            ),
            ColumnSchema::new("salary".to_string(), DataType::Integer, true),
        ];
        let schema =
            TableSchema::with_primary_key("employees".to_string(), columns, vec!["id".to_string()]);
        db.create_table(schema).unwrap();

        let employees = [(1, "Alice", 50000), (2, "Bob", 60000), (3, "Carol", 55000)];
        for (id, name, salary) in employees {
            db.insert_row(
                "employees",
                Row::new(vec![
                    SqlValue::Integer(id),
                    SqlValue::Varchar(name.into()),
                    SqlValue::Integer(salary),
                ]),
            )
            .unwrap();
        }

        db
    }

    /// `SELECT * FROM employees`: the query every test cursor is declared over.
    fn create_select_stmt() -> SelectStmt {
        SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![SelectItem::Wildcard { alias: None }],
            into_table: None,
            into_variables: None,
            from: Some(FromClause::Table {
                name: "employees".to_string(),
                alias: None,
                column_aliases: None,
            }),
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        }
    }

    /// Builds a non-insensitive DECLARE over `create_select_stmt()` with the
    /// given scrollability and hold option.
    fn declare_stmt(cursor_name: &str, scroll: bool, hold: Option<bool>) -> DeclareCursorStmt {
        DeclareCursorStmt {
            cursor_name: cursor_name.to_string(),
            insensitive: false,
            scroll,
            hold,
            query: Box::new(create_select_stmt()),
            updatability: CursorUpdatability::Unspecified,
        }
    }

    fn open_stmt(cursor_name: &str) -> OpenCursorStmt {
        OpenCursorStmt { cursor_name: cursor_name.to_string() }
    }

    fn fetch_stmt(cursor_name: &str, orientation: FetchOrientation) -> FetchStmt {
        FetchStmt { cursor_name: cursor_name.to_string(), orientation, into_variables: None }
    }

    fn close_stmt(cursor_name: &str) -> CloseCursorStmt {
        CloseCursorStmt { cursor_name: cursor_name.to_string() }
    }

    /// Declares `cursor_name` (without a hold option) and opens it against `db`.
    fn declare_and_open(store: &mut CursorStore, db: &Database, cursor_name: &str, scroll: bool) {
        store.declare(&declare_stmt(cursor_name, scroll, None)).unwrap();
        store.open(&open_stmt(cursor_name), db).unwrap();
    }

    /// Asserts that `result` holds exactly one row whose first column is `expected_id`.
    fn assert_single_row(result: &FetchResult, expected_id: SqlValue) {
        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0].values[0], expected_id);
    }

    #[test]
    fn test_declare_cursor() {
        let mut store = CursorStore::new();

        assert!(store.declare(&declare_stmt("emp_cursor", false, None)).is_ok());
        assert!(store.exists("emp_cursor"));
        assert!(!store.is_open("emp_cursor"));
    }

    #[test]
    fn test_declare_cursor_already_exists() {
        let mut store = CursorStore::new();
        let stmt = declare_stmt("emp_cursor", false, None);

        store.declare(&stmt).unwrap();
        let result = store.declare(&stmt);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyExists(_))));
    }

    #[test]
    fn test_open_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        assert!(store.open(&open_stmt("emp_cursor"), &db).is_ok());
        assert!(store.is_open("emp_cursor"));
    }

    #[test]
    fn test_open_cursor_not_found() {
        let db = create_test_db();
        let mut store = CursorStore::new();

        let result = store.open(&open_stmt("nonexistent"), &db);
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_open_cursor_already_open() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "emp_cursor", false);

        let result = store.open(&open_stmt("emp_cursor"), &db);
        assert!(matches!(result, Err(ExecutorError::CursorAlreadyOpen(_))));
    }

    #[test]
    fn test_fetch_next() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "emp_cursor", false);

        let next = fetch_stmt("emp_cursor", FetchOrientation::Next);
        assert_single_row(&store.fetch(&next).unwrap(), SqlValue::Integer(1));
        assert_single_row(&store.fetch(&next).unwrap(), SqlValue::Integer(2));
        assert_single_row(&store.fetch(&next).unwrap(), SqlValue::Integer(3));
        // Fetching past the last row succeeds but yields no row.
        assert_eq!(store.fetch(&next).unwrap().rows.len(), 0);
    }

    #[test]
    fn test_fetch_unknown_cursor() {
        let mut store = CursorStore::new();

        let result = store.fetch(&fetch_stmt("nonexistent", FetchOrientation::Next));
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_fetch_from_unopened_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("emp_cursor", false, None)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next));
        assert!(matches!(result, Err(ExecutorError::CursorNotOpen(_))));
    }

    #[test]
    fn test_fetch_prior_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "emp_cursor", false);

        store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Next)).unwrap();

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Prior));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_relative_backward_non_scrollable() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "emp_cursor", false);

        let result = store.fetch(&fetch_stmt("emp_cursor", FetchOrientation::Relative(-1)));
        assert!(matches!(result, Err(ExecutorError::CursorNotScrollable(_))));
    }

    #[test]
    fn test_scroll_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "scroll_cursor", true);

        let result = store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Last)).unwrap();
        assert_single_row(&result, SqlValue::Integer(3));

        let result = store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::First)).unwrap();
        assert_single_row(&result, SqlValue::Integer(1));

        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Absolute(2))).unwrap();
        assert_single_row(&result, SqlValue::Integer(2));
    }

    #[test]
    fn test_fetch_absolute_negative_counts_from_end() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "scroll_cursor", true);

        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Absolute(-1))).unwrap();
        assert_single_row(&result, SqlValue::Integer(3));

        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Absolute(-3))).unwrap();
        assert_single_row(&result, SqlValue::Integer(1));

        // Counting back past the first row positions the cursor before the first row.
        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Absolute(-4))).unwrap();
        assert_eq!(result.rows.len(), 0);
    }

    #[test]
    fn test_fetch_relative_moves_from_current_position() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "scroll_cursor", true);

        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Relative(2))).unwrap();
        assert_single_row(&result, SqlValue::Integer(2));

        let result =
            store.fetch(&fetch_stmt("scroll_cursor", FetchOrientation::Relative(-1))).unwrap();
        assert_single_row(&result, SqlValue::Integer(1));
    }

    #[test]
    fn test_close_cursor() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        declare_and_open(&mut store, &db, "emp_cursor", false);

        assert!(store.close(&close_stmt("emp_cursor")).is_ok());
        assert!(!store.exists("emp_cursor"));
    }

    #[test]
    fn test_close_nonexistent_cursor() {
        let mut store = CursorStore::new();

        let result = store.close(&close_stmt("nonexistent"));
        assert!(matches!(result, Err(ExecutorError::CursorNotFound(_))));
    }

    #[test]
    fn test_case_insensitive_cursor_names() {
        let db = create_test_db();
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("My_Cursor", false, None)).unwrap();

        // Different casings of the declared name all resolve to the same cursor.
        assert!(store.open(&open_stmt("MY_CURSOR"), &db).is_ok());
        assert!(store.close(&close_stmt("my_cursor")).is_ok());
    }

    #[test]
    fn test_holdable_cursor() {
        let mut store = CursorStore::new();
        store.declare(&declare_stmt("holdable", false, Some(true))).unwrap();
        store.declare(&declare_stmt("non_holdable", false, Some(false))).unwrap();

        assert_eq!(store.count(), 2);

        store.clear_non_holdable();

        assert_eq!(store.count(), 1);
        assert!(store.exists("holdable"));
        assert!(!store.exists("non_holdable"));
    }
}