venus_core/salsa_db/
queries.rs

1//! Salsa tracked query functions.
2//!
3//! These functions are memoized by Salsa. Results are cached and only
4//! recomputed when their inputs change.
5//!
6//! # Error Handling
7//!
//! Query functions that can fail return a [`QueryResult`] enum that holds
//! either the successful result or an error message. This allows callers to
//! distinguish between "no results" and "error occurred" cases.
11
12use std::collections::hash_map::DefaultHasher;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16use crate::compile::DependencyParser;
17use crate::graph::{CellId, CellInfo, CellParser, GraphEngine};
18
19use super::conversions::{CellData, CompilationStatus};
20use super::inputs::{CompilerSettings, SourceFile};
21
/// Outcome of a fallible Salsa query: either a value or an error message.
///
/// # Why not `std::result::Result`?
///
/// Salsa's memoization requires return types to implement `Clone`, `PartialEq`,
/// `Eq`, and `Hash`. `std::result::Result<T, E>` would satisfy that whenever
/// `T` and `E` do, but a dedicated type buys us:
///
/// 1. **Simpler error type**: the error is always a `String`
/// 2. **Explicit intent**: the type marks Salsa-compatible error boundaries
/// 3. **Consistent API**: every Venus query reports failure the same way
///
/// # Usage Pattern
///
/// Most queries come in pairs: a `*_result` variant that surfaces the error,
/// and a convenience variant that falls back to a default on failure:
///
/// ```ignore
/// // Option 1: Handle errors explicitly
/// match db.get_execution_order_result(source) {
///     QueryResult::Ok(order) => process(order),
///     QueryResult::Err(e) => log_error(e),
/// }
///
/// // Option 2: Use default on error (empty vec)
/// let order = db.get_execution_order(source); // Returns empty vec on error
/// ```
///
/// Reach for the `*_result` variant when you need to:
/// - Display error messages to users
/// - Distinguish "no cells found" from "parse error"
/// - Propagate errors to callers
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QueryResult<T> {
    /// The query produced a value.
    Ok(T),
    /// The query failed; the payload is the error message.
    Err(String),
}

impl<T> QueryResult<T> {
    /// Returns `true` when this is the [`Ok`](Self::Ok) variant.
    pub fn is_ok(&self) -> bool {
        !self.is_err()
    }

    /// Returns `true` when this is the [`Err`](Self::Err) variant.
    pub fn is_err(&self) -> bool {
        match self {
            Self::Err(_) => true,
            Self::Ok(_) => false,
        }
    }

    /// Borrows the value on success, or yields `None` on failure.
    pub fn ok(&self) -> Option<&T> {
        if let Self::Ok(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Borrows the error message on failure, or yields `None` on success.
    pub fn err(&self) -> Option<&str> {
        if let Self::Err(message) = self {
            Some(message)
        } else {
            None
        }
    }

    /// Consumes the result and returns the value.
    ///
    /// # Panics
    ///
    /// Panics with the stored error message if the query failed. That is
    /// deliberate: use this only where a failure means a programming error
    /// (e.g., in tests). Production code should prefer [`ok`](Self::ok),
    /// [`unwrap_or`](Self::unwrap_or), or pattern matching.
    pub fn unwrap(self) -> T {
        match self {
            Self::Err(message) => panic!("Query failed: {}", message),
            Self::Ok(value) => value,
        }
    }

    /// Consumes the result and returns the value, or `default` on failure.
    ///
    /// This is the recommended way to handle errors when a sensible fallback
    /// exists (e.g., an empty vector for missing data). The error message is
    /// discarded.
    pub fn unwrap_or(self, default: T) -> T {
        if let Self::Ok(value) = self {
            value
        } else {
            default
        }
    }
}
114
115/// Tracked function: Parse cells from a source file with error reporting.
116///
117/// This query extracts all `#[venus::cell]` functions from the source.
118/// Results are memoized and only recomputed when the source changes.
119///
120/// Unlike [`parse_cells`], this version reports parsing errors instead of
121/// silently returning an empty vector.
122#[salsa::tracked]
123pub fn parse_cells_result(
124    db: &dyn salsa::Database,
125    source: SourceFile,
126) -> QueryResult<Vec<CellData>> {
127    let path = source.path(db);
128    let text = source.text(db);
129
130    let mut parser = CellParser::new();
131    match parser.parse_str(&text, &path) {
132        Ok(parse_result) => {
133            QueryResult::Ok(parse_result.code_cells.into_iter().map(CellData::from).collect())
134        }
135        Err(e) => {
136            let error_msg = format!("Failed to parse '{}': {}", path.display(), e);
137            tracing::error!("{}", error_msg);
138            QueryResult::Err(error_msg)
139        }
140    }
141}
142
143/// Tracked function: Parse cells from a source file.
144///
145/// This query extracts all `#[venus::cell]` functions from the source.
146/// Results are memoized and only recomputed when the source changes.
147///
148/// Returns an empty vector on parse errors. Use [`parse_cells_result`] if you
149/// need to distinguish between "no cells" and "parse error".
150#[salsa::tracked]
151pub fn parse_cells(db: &dyn salsa::Database, source: SourceFile) -> Vec<CellData> {
152    parse_cells_result(db, source).unwrap_or(Vec::new())
153}
154
155/// Tracked function: Get cell names for quick lookup.
156#[salsa::tracked]
157pub fn cell_names(db: &dyn salsa::Database, source: SourceFile) -> Vec<String> {
158    parse_cells(db, source)
159        .iter()
160        .map(|c| c.name.clone())
161        .collect()
162}
163
164/// Build a GraphEngine from parsed cells.
165///
166/// Takes ownership of cells to avoid cloning during conversion.
167/// Returns an error message if dependency resolution fails.
168fn build_graph_engine(cells: Vec<CellData>) -> Result<GraphEngine, String> {
169    let mut engine = GraphEngine::new();
170
171    for cell_data in cells {
172        engine.add_cell(cell_data.into());
173    }
174
175    engine.resolve_dependencies().map_err(|e| {
176        format!("Failed to resolve dependencies: {}", e)
177    })?;
178
179    Ok(engine)
180}
181
182/// Tracked function: Analyze dependency graph and compute execution metadata.
183///
184/// This is the central query for graph analysis. It builds the GraphEngine once
185/// and computes both execution order and parallel levels together, eliminating
186/// redundant graph construction. Other queries like `execution_order` and
187/// `parallel_levels` extract their results from this cached analysis.
188///
189/// Returns a `QueryResult` with the analysis on success, or an error describing
190/// what went wrong (parse errors, missing dependencies, cycles, etc.).
191#[salsa::tracked]
192pub fn graph_analysis_result(
193    db: &dyn salsa::Database,
194    source: SourceFile,
195) -> QueryResult<super::conversions::GraphAnalysis> {
196    // First check for parse errors
197    let cells_result = parse_cells_result(db, source);
198    let cells = match cells_result {
199        QueryResult::Ok(c) => c,
200        QueryResult::Err(e) => return QueryResult::Err(e),
201    };
202
203    if cells.is_empty() {
204        return QueryResult::Ok(super::conversions::GraphAnalysis::empty());
205    }
206
207    // Build the graph once (takes ownership of cells)
208    let engine = match build_graph_engine(cells) {
209        Ok(e) => e,
210        Err(e) => {
211            tracing::error!("{}", e);
212            return QueryResult::Err(e);
213        }
214    };
215
216    // Get topological order
217    let order = match engine.topological_order() {
218        Ok(order) => order,
219        Err(e) => {
220            let error_msg = format!("Failed to compute execution order: {}", e);
221            tracing::error!("{}", error_msg);
222            return QueryResult::Err(error_msg);
223        }
224    };
225
226    // Compute parallel levels from the same graph (no rebuild!)
227    let parallel_levels = engine
228        .topological_levels(&order)
229        .into_iter()
230        .map(|level| level.into_iter().map(|id| id.as_usize()).collect())
231        .collect();
232
233    let execution_order = order.into_iter().map(|id| id.as_usize()).collect();
234
235    QueryResult::Ok(super::conversions::GraphAnalysis {
236        execution_order,
237        parallel_levels,
238    })
239}
240
241/// Tracked function: Get cached graph analysis.
242///
243/// Returns the combined execution order and parallel levels.
244/// Use [`graph_analysis_result`] if you need error details.
245#[salsa::tracked]
246pub fn graph_analysis(
247    db: &dyn salsa::Database,
248    source: SourceFile,
249) -> super::conversions::GraphAnalysis {
250    graph_analysis_result(db, source).unwrap_or(super::conversions::GraphAnalysis::empty())
251}
252
253/// Tracked function: Build and validate dependency graph with error reporting.
254///
255/// Returns the topological execution order if the graph is valid,
256/// or an error describing what went wrong.
257///
258/// This query extracts execution order from the cached [`graph_analysis_result`].
259#[salsa::tracked]
260pub fn execution_order_result(
261    db: &dyn salsa::Database,
262    source: SourceFile,
263) -> QueryResult<Vec<usize>> {
264    match graph_analysis_result(db, source) {
265        QueryResult::Ok(analysis) => QueryResult::Ok(analysis.execution_order),
266        QueryResult::Err(e) => QueryResult::Err(e),
267    }
268}
269
270/// Tracked function: Build and validate dependency graph.
271///
272/// Returns the topological execution order if the graph is valid,
273/// or an empty vec if there are cycles or missing dependencies.
274///
275/// Use [`execution_order_result`] if you need to distinguish between
276/// "no cells" and "graph error".
277#[salsa::tracked]
278pub fn execution_order(db: &dyn salsa::Database, source: SourceFile) -> Vec<usize> {
279    graph_analysis(db, source).execution_order
280}
281
282/// Tracked function: Get cells invalidated by a change.
283///
284/// Note: This query still builds its own graph because it needs to call
285/// `invalidated_cells()` which isn't part of the standard graph analysis.
286#[salsa::tracked]
287pub fn invalidated_by(
288    db: &dyn salsa::Database,
289    source: SourceFile,
290    changed_idx: usize,
291) -> Vec<usize> {
292    let cells = parse_cells(db, source);
293
294    let engine = match build_graph_engine(cells) {
295        Ok(e) => e,
296        Err(_) => return Vec::new(),
297    };
298
299    engine
300        .invalidated_cells(CellId::new(changed_idx))
301        .into_iter()
302        .map(|id| id.as_usize())
303        .collect()
304}
305
306/// Tracked function: Get parallel execution levels.
307///
308/// Returns groups of cell indices that can be executed in parallel.
309/// This query extracts parallel levels from the cached [`graph_analysis`].
310#[salsa::tracked]
311pub fn parallel_levels(db: &dyn salsa::Database, source: SourceFile) -> Vec<Vec<usize>> {
312    graph_analysis(db, source).parallel_levels
313}
314
315/// Tracked function: Compute dependency hash from source.
316///
317/// This hash represents all external dependencies declared in the notebook.
318/// Changes to dependencies will invalidate compiled cells.
319#[salsa::tracked]
320pub fn dependency_hash(db: &dyn salsa::Database, source: SourceFile) -> u64 {
321    let text = source.text(db);
322
323    let mut parser = DependencyParser::new();
324    parser.parse(&text);
325
326    let mut hasher = DefaultHasher::new();
327
328    // Hash each dependency's name, version, features, and path
329    for dep in parser.dependencies() {
330        dep.name.hash(&mut hasher);
331        dep.version.hash(&mut hasher);
332        dep.features.hash(&mut hasher);
333        if let Some(path) = &dep.path {
334            path.hash(&mut hasher);
335        }
336    }
337
338    hasher.finish()
339}
340
341/// Tracked function: Compile a cell.
342///
343/// This query compiles a cell to a dynamic library. Results are memoized
344/// by Salsa, so repeated calls with the same inputs return cached results.
345///
346/// The compilation depends on:
347/// - The cell's source code (via CellData from parse_cells)
348/// - The dependency hash (via dependency_hash)
349/// - The compiler settings (via CompilerSettings input)
350#[salsa::tracked]
351pub fn compiled_cell(
352    db: &dyn salsa::Database,
353    source: SourceFile,
354    cell_idx: usize,
355    settings: CompilerSettings,
356) -> CompilationStatus {
357    let cells = parse_cells(db, source);
358
359    // Find the cell
360    let Some(cell_data) = cells.get(cell_idx) else {
361        return CompilationStatus::Failed(format!("Cell index {} not found", cell_idx));
362    };
363
364    // Get dependency hash
365    let deps_hash = dependency_hash(db, source);
366
367    // Convert to CellInfo for the compiler
368    let cell_info: CellInfo = cell_data.clone().into();
369
370    // Create compiler configuration
371    let config = crate::compile::CompilerConfig {
372        build_dir: settings.build_dir(db),
373        cache_dir: settings.cache_dir(db),
374        use_cranelift: settings.use_cranelift(db),
375        debug_info: true,
376        opt_level: settings.opt_level(db),
377        extra_rustc_flags: Vec::new(),
378        venus_crate_path: crate::compile::CompilerConfig::default().venus_crate_path,
379    };
380
381    // Create the compiler
382    let toolchain = match crate::compile::ToolchainManager::new() {
383        Ok(tc) => tc,
384        Err(e) => {
385            return CompilationStatus::Failed(format!("Toolchain error: {}", e));
386        }
387    };
388
389    let mut compiler = crate::compile::CellCompiler::new(config, toolchain);
390
391    // Set universe path if available
392    if let Some(universe_path) = settings.universe_path(db) {
393        compiler = compiler.with_universe(universe_path);
394    }
395
396    // Compile the cell
397    match compiler.compile(&cell_info, deps_hash) {
398        crate::compile::CompilationResult::Success(compiled) => {
399            CompilationStatus::Success(compiled.into())
400        }
401        crate::compile::CompilationResult::Cached(compiled) => {
402            CompilationStatus::Cached(compiled.into())
403        }
404        crate::compile::CompilationResult::Failed { errors, .. } => {
405            let error_msg = errors
406                .iter()
407                .map(|e| e.message.clone())
408                .collect::<Vec<_>>()
409                .join("\n");
410            CompilationStatus::Failed(error_msg)
411        }
412    }
413}
414
415/// Tracked function: Compile all cells in execution order.
416///
417/// Returns a list of compilation results for all cells, wrapped in `Arc` for
418/// efficient sharing across clones of the query result without deep copying
419/// the potentially large compilation results vector.
420#[salsa::tracked]
421pub fn compile_all_cells(
422    db: &dyn salsa::Database,
423    source: SourceFile,
424    settings: CompilerSettings,
425) -> Arc<Vec<CompilationStatus>> {
426    let order = execution_order(db, source);
427
428    let results: Vec<CompilationStatus> = order
429        .iter()
430        .map(|&idx| compiled_cell(db, source, idx, settings))
431        .collect();
432
433    Arc::new(results)
434}
435
436/// Tracked function: Get the execution status for a specific cell.
437///
438/// Returns the current execution status (pending, running, success, or failed)
439/// for the specified cell. This query depends on the CellOutputs input, so
440/// it will be recomputed when outputs are updated.
441#[salsa::tracked]
442pub fn cell_output(
443    db: &dyn salsa::Database,
444    outputs: super::inputs::CellOutputs,
445    cell_idx: usize,
446) -> super::conversions::ExecutionStatus {
447    let statuses = outputs.statuses(db);
448    statuses
449        .get(cell_idx)
450        .cloned()
451        .unwrap_or(super::conversions::ExecutionStatus::Pending)
452}
453
454/// Tracked function: Check if all cells have completed execution.
455///
456/// Returns true if all cells have either succeeded or failed.
457#[salsa::tracked]
458pub fn all_cells_executed(
459    db: &dyn salsa::Database,
460    outputs: super::inputs::CellOutputs,
461) -> bool {
462    let statuses = outputs.statuses(db);
463    statuses.iter().all(|s| {
464        matches!(
465            s,
466            super::conversions::ExecutionStatus::Success(_)
467                | super::conversions::ExecutionStatus::Failed(_)
468        )
469    })
470}
471
472/// Tracked function: Get successful output for a cell.
473///
474/// Returns the output data if the cell executed successfully,
475/// or None if pending, running, or failed.
476#[salsa::tracked]
477pub fn cell_output_data(
478    db: &dyn salsa::Database,
479    outputs: super::inputs::CellOutputs,
480    cell_idx: usize,
481) -> Option<super::conversions::CellOutputData> {
482    let status = cell_output(db, outputs, cell_idx);
483    status.output().cloned()
484}
485
// Unit tests for the queries above. Every test except the `QueryResult`
// accessor test drives the queries through `VenusDatabase`'s `set_source` and
// `get_*` methods.
#[cfg(test)]
mod tests {
    use crate::salsa_db::VenusDatabase;
    use std::path::PathBuf;

    /// Two cells are parsed and reported in source order with their names.
    #[test]
    fn test_parse_cells_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }

                #[venus::cell]
                pub fn b(a: &i32) -> i32 { *a + 1 }
            "#
            .to_string(),
        );

        let cells = db.get_cells(source);

        assert_eq!(cells.len(), 2);
        assert_eq!(cells[0].name, "a");
        assert_eq!(cells[1].name, "b");
    }

    /// A cell that takes another cell's output is ordered after it.
    #[test]
    fn test_execution_order_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }

                #[venus::cell]
                pub fn b(a: &i32) -> i32 { *a + 1 }
            "#
            .to_string(),
        );

        let order = db.get_execution_order(source);
        assert_eq!(order.len(), 2);
        // 'a' should come before 'b'
        assert_eq!(order[0], 0); // a
        assert_eq!(order[1], 1); // b
    }

    /// Changing the root of a chain a -> b -> c invalidates all three cells.
    #[test]
    fn test_invalidated_cells_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }

                #[venus::cell]
                pub fn b(a: &i32) -> i32 { *a + 1 }

                #[venus::cell]
                pub fn c(b: &i32) -> i32 { *b + 1 }
            "#
            .to_string(),
        );

        // If 'a' changes, a, b, and c all need to re-execute
        // (the changed cell plus all its transitive dependents)
        let invalidated = db.get_invalidated(source, 0);
        assert_eq!(invalidated.len(), 3);
        assert_eq!(invalidated, vec![0, 1, 2]); // a -> b -> c in topological order
    }

    /// Two independent cells share a level; their common dependent is alone
    /// in the next level.
    #[test]
    fn test_parallel_levels() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }

                #[venus::cell]
                pub fn b() -> i32 { 2 }

                #[venus::cell]
                pub fn c(a: &i32, b: &i32) -> i32 { *a + *b }
            "#
            .to_string(),
        );

        let levels = db.get_parallel_levels(source);
        assert_eq!(levels.len(), 2);
        // First level: a and b (can run in parallel)
        assert_eq!(levels[0].len(), 2);
        // Second level: c (depends on both)
        assert_eq!(levels[1].len(), 1);
    }

    /// The dependency hash follows the declared dependencies only: sources
    /// with different paths and cells but identical `cargo` blocks hash the
    /// same, while a different `cargo` block hashes differently.
    #[test]
    fn test_dependency_hash() {
        let db = VenusDatabase::new();

        // Source with dependencies (using correct ```cargo block format)
        let source1 = db.set_source(
            PathBuf::from("test1.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! tokio = "1"
//! serde = { version = "1.0", features = ["derive"] }
//! ```

#[venus::cell]
pub fn a() -> i32 { 1 }
            "#
            .to_string(),
        );

        // Same dependencies - should produce same hash
        let source2 = db.set_source(
            PathBuf::from("test2.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! tokio = "1"
//! serde = { version = "1.0", features = ["derive"] }
//! ```

#[venus::cell]
pub fn b() -> i32 { 2 }
            "#
            .to_string(),
        );

        // Different dependencies - should produce different hash
        let source3 = db.set_source(
            PathBuf::from("test3.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! anyhow = "1.0"
//! ```

#[venus::cell]
pub fn c() -> i32 { 3 }
            "#
            .to_string(),
        );

        let hash1 = db.get_dependency_hash(source1);
        let hash2 = db.get_dependency_hash(source2);
        let hash3 = db.get_dependency_hash(source3);

        // Same dependencies should have same hash
        assert_eq!(hash1, hash2);
        // Different dependencies should have different hash
        assert_ne!(hash1, hash3);
    }

    /// Exercises the `QueryResult` accessors on both variants. The panicking
    /// path of `unwrap` is not covered here.
    #[test]
    fn test_query_result_methods() {
        use super::QueryResult;

        let ok: QueryResult<i32> = QueryResult::Ok(42);
        assert!(ok.is_ok());
        assert!(!ok.is_err());
        assert_eq!(ok.ok(), Some(&42));
        assert_eq!(ok.err(), None);
        assert_eq!(ok.unwrap(), 42);

        let err: QueryResult<i32> = QueryResult::Err("error".to_string());
        assert!(!err.is_ok());
        assert!(err.is_err());
        assert_eq!(err.ok(), None);
        assert_eq!(err.err(), Some("error"));
        assert_eq!(err.unwrap_or(0), 0);
    }

    /// The error-reporting parse query succeeds on a source with one cell.
    #[test]
    fn test_parse_cells_result_success() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }
            "#
            .to_string(),
        );

        let result = db.get_cells_result(source);
        assert!(result.is_ok());
        assert_eq!(result.ok().unwrap().len(), 1);
    }

    /// The error-reporting execution-order query succeeds on a valid
    /// two-cell graph.
    #[test]
    fn test_execution_order_result_success() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
                #[venus::cell]
                pub fn a() -> i32 { 1 }

                #[venus::cell]
                pub fn b(a: &i32) -> i32 { *a + 1 }
            "#
            .to_string(),
        );

        let result = db.get_execution_order_result(source);
        assert!(result.is_ok());
        assert_eq!(result.ok().unwrap().len(), 2);
    }
}