1use std::collections::hash_map::DefaultHasher;
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16use crate::compile::DependencyParser;
17use crate::graph::{CellId, CellInfo, CellParser, GraphEngine};
18
19use super::conversions::{CellData, CompilationStatus};
20use super::inputs::{CompilerSettings, SourceFile};
21
/// Outcome of a query: either the computed value or an error message.
///
/// The error side is always a `String`, so the whole type can derive
/// `Clone`, `PartialEq`, `Eq` and `Hash` whenever `T` supports them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum QueryResult<T> {
    /// The query produced a value.
    Ok(T),
    /// The query failed; the payload is the error message.
    Err(String),
}

impl<T> QueryResult<T> {
    /// Returns `true` when this is the `Ok` variant.
    pub fn is_ok(&self) -> bool {
        self.ok().is_some()
    }

    /// Returns `true` when this is the `Err` variant.
    pub fn is_err(&self) -> bool {
        self.err().is_some()
    }

    /// Borrows the success value, or `None` for an error.
    pub fn ok(&self) -> Option<&T> {
        if let Self::Ok(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Borrows the error message, or `None` for a success.
    pub fn err(&self) -> Option<&str> {
        if let Self::Err(message) = self {
            Some(message.as_str())
        } else {
            None
        }
    }

    /// Consumes the result and returns the success value.
    ///
    /// # Panics
    ///
    /// Panics with `Query failed: <message>` when this is the `Err` variant.
    pub fn unwrap(self) -> T {
        match self {
            Self::Err(message) => panic!("Query failed: {}", message),
            Self::Ok(value) => value,
        }
    }

    /// Consumes the result, returning the success value or `default` on error.
    pub fn unwrap_or(self, default: T) -> T {
        match self {
            Self::Err(_) => default,
            Self::Ok(value) => value,
        }
    }
}
114
115#[salsa::tracked]
123pub fn parse_cells_result(
124 db: &dyn salsa::Database,
125 source: SourceFile,
126) -> QueryResult<Vec<CellData>> {
127 let path = source.path(db);
128 let text = source.text(db);
129
130 let mut parser = CellParser::new();
131 match parser.parse_str(&text, &path) {
132 Ok(parse_result) => {
133 QueryResult::Ok(parse_result.code_cells.into_iter().map(CellData::from).collect())
134 }
135 Err(e) => {
136 let error_msg = format!("Failed to parse '{}': {}", path.display(), e);
137 tracing::error!("{}", error_msg);
138 QueryResult::Err(error_msg)
139 }
140 }
141}
142
143#[salsa::tracked]
151pub fn parse_cells(db: &dyn salsa::Database, source: SourceFile) -> Vec<CellData> {
152 parse_cells_result(db, source).unwrap_or(Vec::new())
153}
154
155#[salsa::tracked]
157pub fn cell_names(db: &dyn salsa::Database, source: SourceFile) -> Vec<String> {
158 parse_cells(db, source)
159 .iter()
160 .map(|c| c.name.clone())
161 .collect()
162}
163
164fn build_graph_engine(cells: Vec<CellData>) -> Result<GraphEngine, String> {
169 let mut engine = GraphEngine::new();
170
171 for cell_data in cells {
172 engine.add_cell(cell_data.into());
173 }
174
175 engine.resolve_dependencies().map_err(|e| {
176 format!("Failed to resolve dependencies: {}", e)
177 })?;
178
179 Ok(engine)
180}
181
182#[salsa::tracked]
192pub fn graph_analysis_result(
193 db: &dyn salsa::Database,
194 source: SourceFile,
195) -> QueryResult<super::conversions::GraphAnalysis> {
196 let cells_result = parse_cells_result(db, source);
198 let cells = match cells_result {
199 QueryResult::Ok(c) => c,
200 QueryResult::Err(e) => return QueryResult::Err(e),
201 };
202
203 if cells.is_empty() {
204 return QueryResult::Ok(super::conversions::GraphAnalysis::empty());
205 }
206
207 let engine = match build_graph_engine(cells) {
209 Ok(e) => e,
210 Err(e) => {
211 tracing::error!("{}", e);
212 return QueryResult::Err(e);
213 }
214 };
215
216 let order = match engine.topological_order() {
218 Ok(order) => order,
219 Err(e) => {
220 let error_msg = format!("Failed to compute execution order: {}", e);
221 tracing::error!("{}", error_msg);
222 return QueryResult::Err(error_msg);
223 }
224 };
225
226 let parallel_levels = engine
228 .topological_levels(&order)
229 .into_iter()
230 .map(|level| level.into_iter().map(|id| id.as_usize()).collect())
231 .collect();
232
233 let execution_order = order.into_iter().map(|id| id.as_usize()).collect();
234
235 QueryResult::Ok(super::conversions::GraphAnalysis {
236 execution_order,
237 parallel_levels,
238 })
239}
240
241#[salsa::tracked]
246pub fn graph_analysis(
247 db: &dyn salsa::Database,
248 source: SourceFile,
249) -> super::conversions::GraphAnalysis {
250 graph_analysis_result(db, source).unwrap_or(super::conversions::GraphAnalysis::empty())
251}
252
253#[salsa::tracked]
260pub fn execution_order_result(
261 db: &dyn salsa::Database,
262 source: SourceFile,
263) -> QueryResult<Vec<usize>> {
264 match graph_analysis_result(db, source) {
265 QueryResult::Ok(analysis) => QueryResult::Ok(analysis.execution_order),
266 QueryResult::Err(e) => QueryResult::Err(e),
267 }
268}
269
270#[salsa::tracked]
278pub fn execution_order(db: &dyn salsa::Database, source: SourceFile) -> Vec<usize> {
279 graph_analysis(db, source).execution_order
280}
281
282#[salsa::tracked]
287pub fn invalidated_by(
288 db: &dyn salsa::Database,
289 source: SourceFile,
290 changed_idx: usize,
291) -> Vec<usize> {
292 let cells = parse_cells(db, source);
293
294 let engine = match build_graph_engine(cells) {
295 Ok(e) => e,
296 Err(_) => return Vec::new(),
297 };
298
299 engine
300 .invalidated_cells(CellId::new(changed_idx))
301 .into_iter()
302 .map(|id| id.as_usize())
303 .collect()
304}
305
306#[salsa::tracked]
311pub fn parallel_levels(db: &dyn salsa::Database, source: SourceFile) -> Vec<Vec<usize>> {
312 graph_analysis(db, source).parallel_levels
313}
314
315#[salsa::tracked]
320pub fn dependency_hash(db: &dyn salsa::Database, source: SourceFile) -> u64 {
321 let text = source.text(db);
322
323 let mut parser = DependencyParser::new();
324 parser.parse(&text);
325
326 let mut hasher = DefaultHasher::new();
327
328 for dep in parser.dependencies() {
330 dep.name.hash(&mut hasher);
331 dep.version.hash(&mut hasher);
332 dep.features.hash(&mut hasher);
333 if let Some(path) = &dep.path {
334 path.hash(&mut hasher);
335 }
336 }
337
338 hasher.finish()
339}
340
341#[salsa::tracked]
351pub fn compiled_cell(
352 db: &dyn salsa::Database,
353 source: SourceFile,
354 cell_idx: usize,
355 settings: CompilerSettings,
356) -> CompilationStatus {
357 let cells = parse_cells(db, source);
358
359 let Some(cell_data) = cells.get(cell_idx) else {
361 return CompilationStatus::Failed(format!("Cell index {} not found", cell_idx));
362 };
363
364 let deps_hash = dependency_hash(db, source);
366
367 let cell_info: CellInfo = cell_data.clone().into();
369
370 let config = crate::compile::CompilerConfig {
372 build_dir: settings.build_dir(db),
373 cache_dir: settings.cache_dir(db),
374 use_cranelift: settings.use_cranelift(db),
375 debug_info: true,
376 opt_level: settings.opt_level(db),
377 extra_rustc_flags: Vec::new(),
378 venus_crate_path: crate::compile::CompilerConfig::default().venus_crate_path,
379 };
380
381 let toolchain = match crate::compile::ToolchainManager::new() {
383 Ok(tc) => tc,
384 Err(e) => {
385 return CompilationStatus::Failed(format!("Toolchain error: {}", e));
386 }
387 };
388
389 let mut compiler = crate::compile::CellCompiler::new(config, toolchain);
390
391 if let Some(universe_path) = settings.universe_path(db) {
393 compiler = compiler.with_universe(universe_path);
394 }
395
396 match compiler.compile(&cell_info, deps_hash) {
398 crate::compile::CompilationResult::Success(compiled) => {
399 CompilationStatus::Success(compiled.into())
400 }
401 crate::compile::CompilationResult::Cached(compiled) => {
402 CompilationStatus::Cached(compiled.into())
403 }
404 crate::compile::CompilationResult::Failed { errors, .. } => {
405 let error_msg = errors
406 .iter()
407 .map(|e| e.message.clone())
408 .collect::<Vec<_>>()
409 .join("\n");
410 CompilationStatus::Failed(error_msg)
411 }
412 }
413}
414
415#[salsa::tracked]
421pub fn compile_all_cells(
422 db: &dyn salsa::Database,
423 source: SourceFile,
424 settings: CompilerSettings,
425) -> Arc<Vec<CompilationStatus>> {
426 let order = execution_order(db, source);
427
428 let results: Vec<CompilationStatus> = order
429 .iter()
430 .map(|&idx| compiled_cell(db, source, idx, settings))
431 .collect();
432
433 Arc::new(results)
434}
435
436#[salsa::tracked]
442pub fn cell_output(
443 db: &dyn salsa::Database,
444 outputs: super::inputs::CellOutputs,
445 cell_idx: usize,
446) -> super::conversions::ExecutionStatus {
447 let statuses = outputs.statuses(db);
448 statuses
449 .get(cell_idx)
450 .cloned()
451 .unwrap_or(super::conversions::ExecutionStatus::Pending)
452}
453
454#[salsa::tracked]
458pub fn all_cells_executed(
459 db: &dyn salsa::Database,
460 outputs: super::inputs::CellOutputs,
461) -> bool {
462 let statuses = outputs.statuses(db);
463 statuses.iter().all(|s| {
464 matches!(
465 s,
466 super::conversions::ExecutionStatus::Success(_)
467 | super::conversions::ExecutionStatus::Failed(_)
468 )
469 })
470}
471
472#[salsa::tracked]
477pub fn cell_output_data(
478 db: &dyn salsa::Database,
479 outputs: super::inputs::CellOutputs,
480 cell_idx: usize,
481) -> Option<super::conversions::CellOutputData> {
482 let status = cell_output(db, outputs, cell_idx);
483 status.output().cloned()
484}
485
// Tests for the queries in this file, driven through the `VenusDatabase`
// accessors (`set_source`, `get_cells`, ...). Each fixture is a small source
// text made of `#[venus::cell]` functions.
#[cfg(test)]
mod tests {
    use crate::salsa_db::VenusDatabase;
    use std::path::PathBuf;

    // Two cells are declared; both must be returned, named `a` then `b`.
    #[test]
    fn test_parse_cells_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }

 #[venus::cell]
 pub fn b(a: &i32) -> i32 { *a + 1 }
 "#
            .to_string(),
        );

        let cells = db.get_cells(source);

        assert_eq!(cells.len(), 2);
        assert_eq!(cells[0].name, "a");
        assert_eq!(cells[1].name, "b");
    }

    // Cell `b` takes `a` as a parameter, so the expected order is `a` (index 0)
    // followed by `b` (index 1).
    #[test]
    fn test_execution_order_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }

 #[venus::cell]
 pub fn b(a: &i32) -> i32 { *a + 1 }
 "#
            .to_string(),
        );

        let order = db.get_execution_order(source);
        assert_eq!(order.len(), 2);
        assert_eq!(order[0], 0);
        assert_eq!(order[1], 1);
    }

    // Chain a -> b -> c: changing cell 0 is expected to invalidate all three
    // cells, with the changed cell itself included in the result.
    #[test]
    fn test_invalidated_cells_query() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }

 #[venus::cell]
 pub fn b(a: &i32) -> i32 { *a + 1 }

 #[venus::cell]
 pub fn c(b: &i32) -> i32 { *b + 1 }
 "#
            .to_string(),
        );

        let invalidated = db.get_invalidated(source, 0);
        assert_eq!(invalidated.len(), 3);
        assert_eq!(invalidated, vec![0, 1, 2]);
    }

    // `a` and `b` take no parameters while `c` takes both, so two levels are
    // expected: one with two cells, then one with a single cell.
    #[test]
    fn test_parallel_levels() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }

 #[venus::cell]
 pub fn b() -> i32 { 2 }

 #[venus::cell]
 pub fn c(a: &i32, b: &i32) -> i32 { *a + *b }
 "#
            .to_string(),
        );

        let levels = db.get_parallel_levels(source);
        assert_eq!(levels.len(), 2);
        assert_eq!(levels[0].len(), 2);
        assert_eq!(levels[1].len(), 1);
    }

    // The dependency hash must depend only on the declared cargo dependencies:
    // sources 1 and 2 share the same dependency block but contain different
    // cells, while source 3 declares a different dependency.
    #[test]
    fn test_dependency_hash() {
        let db = VenusDatabase::new();

        // Declares tokio and serde (with the "derive" feature).
        let source1 = db.set_source(
            PathBuf::from("test1.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! tokio = "1"
//! serde = { version = "1.0", features = ["derive"] }
//! ```

#[venus::cell]
pub fn a() -> i32 { 1 }
 "#
            .to_string(),
        );

        // Same dependency block as source1, different cell.
        let source2 = db.set_source(
            PathBuf::from("test2.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! tokio = "1"
//! serde = { version = "1.0", features = ["derive"] }
//! ```

#[venus::cell]
pub fn b() -> i32 { 2 }
 "#
            .to_string(),
        );

        // Different dependency block (anyhow only).
        let source3 = db.set_source(
            PathBuf::from("test3.rs"),
            r#"
//! ```cargo
//! [dependencies]
//! anyhow = "1.0"
//! ```

#[venus::cell]
pub fn c() -> i32 { 3 }
 "#
            .to_string(),
        );

        let hash1 = db.get_dependency_hash(source1);
        let hash2 = db.get_dependency_hash(source2);
        let hash3 = db.get_dependency_hash(source3);

        // Identical dependencies hash the same regardless of the cell code.
        assert_eq!(hash1, hash2);
        // Different dependencies produce a different hash.
        assert_ne!(hash1, hash3);
    }

    // Exercises every accessor of `QueryResult` on both variants.
    #[test]
    fn test_query_result_methods() {
        use super::QueryResult;

        let ok: QueryResult<i32> = QueryResult::Ok(42);
        assert!(ok.is_ok());
        assert!(!ok.is_err());
        assert_eq!(ok.ok(), Some(&42));
        assert_eq!(ok.err(), None);
        // `unwrap` consumes the value, so it is checked last.
        assert_eq!(ok.unwrap(), 42);

        let err: QueryResult<i32> = QueryResult::Err("error".to_string());
        assert!(!err.is_ok());
        assert!(err.is_err());
        assert_eq!(err.ok(), None);
        assert_eq!(err.err(), Some("error"));
        // `unwrap_or` consumes the value, so it is checked last.
        assert_eq!(err.unwrap_or(0), 0);
    }

    // A source with a single valid cell must parse successfully into one cell.
    #[test]
    fn test_parse_cells_result_success() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }
 "#
            .to_string(),
        );

        let result = db.get_cells_result(source);
        assert!(result.is_ok());
        assert_eq!(result.ok().unwrap().len(), 1);
    }

    // A valid two-cell source must produce an `Ok` execution order covering
    // both cells.
    #[test]
    fn test_execution_order_result_success() {
        let db = VenusDatabase::new();

        let source = db.set_source(
            PathBuf::from("test.rs"),
            r#"
 #[venus::cell]
 pub fn a() -> i32 { 1 }

 #[venus::cell]
 pub fn b(a: &i32) -> i32 { *a + 1 }
 "#
            .to_string(),
        );

        let result = db.get_execution_order_result(source);
        assert!(result.is_ok());
        assert_eq!(result.ok().unwrap().len(), 2);
    }
}