Skip to main content

torsh_package/
dependency_solver.rs

//! Dependency resolution using SAT-based constraint solving.
//!
//! Each candidate `(package, version)` pair becomes a Boolean variable, and
//! version exclusivity, dependency requirements and root requirements become
//! CNF clauses that are handed to a small CDCL-style SAT solver. This is the
//! same general approach taken by SAT-based resolvers such as libsolv.
6
7use std::collections::{HashMap, HashSet};
8use std::fmt;
9
10use serde::{Deserialize, Serialize};
11use torsh_core::error::{Result, TorshError};
12
13use crate::dependency::DependencySpec;
14
/// A Boolean variable of the SAT problem, identified by a numeric index.
///
/// The index field is private, so variables can only be constructed inside
/// this module; `DependencySatSolver` allocates one per `(package, version)`
/// pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SatVariable(usize);
18
/// A SAT literal: a variable, optionally negated.
///
/// A positive literal (`negated == false`) is true when its variable is true;
/// a negative literal is true when its variable is false.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SatLiteral {
    /// The variable this literal refers to.
    variable: SatVariable,
    /// `true` for `¬variable`, `false` for `variable`.
    negated: bool,
}
25
/// A SAT clause: the disjunction (logical OR) of its literals.
///
/// A clause is satisfied when at least one literal is true. An empty clause
/// has no literal that could become true, so it always counts as conflicting.
#[derive(Debug, Clone)]
pub struct SatClause {
    /// Literals of the disjunction, in insertion order.
    literals: Vec<SatLiteral>,
}
31
/// A (possibly partial) assignment of Boolean values to SAT variables.
#[derive(Debug, Clone)]
pub struct Assignment {
    /// Value of every assigned variable; a variable missing from the map is
    /// unassigned.
    values: HashMap<SatVariable, bool>,
}
37
/// Conflict-Driven Clause Learning (CDCL) style SAT solver.
///
/// Clauses are added with `add_clause`; `solve` searches for an assignment that
/// satisfies all of them and leaves it available through `get_assignment`.
#[derive(Debug)]
pub struct CdclSolver {
    /// Clauses supplied through `add_clause`.
    clauses: Vec<SatClause>,
    /// Current (partial) assignment; complete after `solve` returns `Ok(true)`.
    assignment: Assignment,
    /// Decision level at which each currently assigned variable was assigned.
    decision_levels: HashMap<SatVariable, usize>,
    /// Current decision level (0 = no decision has been made yet).
    current_level: usize,
    /// Clauses learned from conflicts during `solve`.
    learned_clauses: Vec<SatClause>,
    /// Per-variable activity score, bumped for each occurrence in an added or
    /// learned clause; the unassigned variable with the highest score is
    /// decided next.
    activity: HashMap<SatVariable, f64>,
}
54
/// A package version paired with the SAT variable that encodes choosing it.
///
/// NOTE(review): this type is not constructed or read by the resolver code in
/// this file; confirm whether other modules still need it.
#[derive(Debug, Clone)]
pub struct VersionConstraint {
    /// Package name.
    pub package: String,
    /// Version string of `package`.
    pub version: String,
    /// SAT variable that is true when this version is selected.
    pub variable: SatVariable,
}
65
66/// Dependency constraint solver using SAT
67pub struct DependencySatSolver {
68    /// Mapping of package versions to SAT variables
69    version_vars: HashMap<(String, String), SatVariable>,
70    /// Reverse mapping
71    var_to_version: HashMap<SatVariable, (String, String)>,
72    /// Next available variable ID
73    next_var_id: usize,
74    /// SAT solver instance
75    solver: CdclSolver,
76    /// All available package versions
77    available_versions: HashMap<String, Vec<String>>,
78}
79
/// Result of `DependencySatSolver::solve`.
///
/// On success `conflicts` is empty. When the constraints are unsatisfiable,
/// `selected_versions` and `install_order` are empty and `conflicts` holds
/// human-readable explanations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencySolution {
    /// Selected version for each package, keyed by package name.
    pub selected_versions: HashMap<String, String>,
    /// Package names in installation order.
    ///
    /// NOTE: `DependencySatSolver::compute_install_order` currently sorts the
    /// names alphabetically; the order is not yet derived from dependencies.
    pub install_order: Vec<String>,
    /// Reasons resolution failed; empty when resolution succeeded.
    pub conflicts: Vec<String>,
}
90
91impl SatLiteral {
92    /// Create a positive literal
93    pub fn positive(var: SatVariable) -> Self {
94        Self {
95            variable: var,
96            negated: false,
97        }
98    }
99
100    /// Create a negative literal
101    pub fn negative(var: SatVariable) -> Self {
102        Self {
103            variable: var,
104            negated: true,
105        }
106    }
107
108    /// Negate this literal
109    pub fn negate(&self) -> Self {
110        Self {
111            variable: self.variable,
112            negated: !self.negated,
113        }
114    }
115
116    /// Check if literal is satisfied by assignment
117    pub fn is_satisfied(&self, assignment: &Assignment) -> Option<bool> {
118        assignment
119            .get(self.variable)
120            .map(|value| if self.negated { !value } else { value })
121    }
122}
123
124impl SatClause {
125    /// Create a new clause
126    pub fn new(literals: Vec<SatLiteral>) -> Self {
127        Self { literals }
128    }
129
130    /// Check if clause is satisfied by assignment
131    pub fn is_satisfied(&self, assignment: &Assignment) -> bool {
132        self.literals
133            .iter()
134            .any(|lit| lit.is_satisfied(assignment) == Some(true))
135    }
136
137    /// Check if clause is conflicting (all literals false)
138    pub fn is_conflicting(&self, assignment: &Assignment) -> bool {
139        self.literals
140            .iter()
141            .all(|lit| lit.is_satisfied(assignment) == Some(false))
142    }
143
144    /// Get unit literal if clause is unit (all but one literal assigned false)
145    pub fn get_unit_literal(&self, assignment: &Assignment) -> Option<SatLiteral> {
146        let mut unassigned = None;
147        let mut unassigned_count = 0;
148
149        for literal in &self.literals {
150            match literal.is_satisfied(assignment) {
151                Some(true) => return None, // Clause already satisfied
152                Some(false) => continue,   // Literal is false
153                None => {
154                    unassigned = Some(*literal);
155                    unassigned_count += 1;
156                    if unassigned_count > 1 {
157                        return None; // More than one unassigned
158                    }
159                }
160            }
161        }
162
163        if unassigned_count == 1 {
164            unassigned
165        } else {
166            None
167        }
168    }
169}
170
171impl Assignment {
172    /// Create an empty assignment
173    pub fn new() -> Self {
174        Self {
175            values: HashMap::new(),
176        }
177    }
178
179    /// Get the value of a variable
180    pub fn get(&self, var: SatVariable) -> Option<bool> {
181        self.values.get(&var).copied()
182    }
183
184    /// Set the value of a variable
185    pub fn set(&mut self, var: SatVariable, value: bool) {
186        self.values.insert(var, value);
187    }
188
189    /// Unset a variable
190    pub fn unset(&mut self, var: SatVariable) {
191        self.values.remove(&var);
192    }
193
194    /// Check if variable is assigned
195    pub fn is_assigned(&self, var: SatVariable) -> bool {
196        self.values.contains_key(&var)
197    }
198}
199
200impl Default for Assignment {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
206impl CdclSolver {
207    /// Create a new CDCL solver
208    pub fn new() -> Self {
209        Self {
210            clauses: Vec::new(),
211            assignment: Assignment::new(),
212            decision_levels: HashMap::new(),
213            current_level: 0,
214            learned_clauses: Vec::new(),
215            activity: HashMap::new(),
216        }
217    }
218
219    /// Add a clause to the solver
220    pub fn add_clause(&mut self, clause: SatClause) {
221        // Update activity scores for variables in the clause
222        for literal in &clause.literals {
223            *self.activity.entry(literal.variable).or_insert(0.0) += 1.0;
224        }
225        self.clauses.push(clause);
226    }
227
228    /// Solve the SAT problem using CDCL algorithm
229    pub fn solve(&mut self) -> Result<bool> {
230        // Initial unit propagation
231        if self.unit_propagate()? {
232            return Ok(false); // Conflict at decision level 0 - UNSAT
233        }
234
235        loop {
236            // Check if all variables are assigned
237            if self.is_complete() {
238                return Ok(true); // SAT
239            }
240
241            // Make a decision
242            let decision_var = self.choose_decision_variable();
243            self.current_level += 1;
244            self.assign(decision_var, true, self.current_level);
245
246            // Propagate and handle conflicts
247            loop {
248                if self.unit_propagate()? {
249                    // Conflict occurred
250                    if self.current_level == 0 {
251                        return Ok(false); // UNSAT
252                    }
253
254                    // Analyze conflict and learn clause
255                    let (learned_clause, backtrack_level) = self.analyze_conflict()?;
256                    self.learned_clauses.push(learned_clause.clone());
257                    self.add_clause(learned_clause);
258
259                    // Backtrack
260                    self.backtrack(backtrack_level);
261                } else {
262                    break; // No conflict
263                }
264            }
265        }
266    }
267
268    /// Unit propagation: propagate all unit clauses
269    fn unit_propagate(&mut self) -> Result<bool> {
270        loop {
271            let mut propagated = false;
272
273            // Collect unit literals first to avoid borrow checker issues
274            let mut unit_literals = Vec::new();
275            let mut conflicts = Vec::new();
276
277            // Check all clauses for unit clauses
278            for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
279                if let Some(unit_literal) = clause.get_unit_literal(&self.assignment) {
280                    unit_literals.push(unit_literal);
281                } else if clause.is_conflicting(&self.assignment) {
282                    conflicts.push(true);
283                }
284            }
285
286            // Apply assignments after iteration
287            for unit_literal in unit_literals {
288                self.assign(
289                    unit_literal.variable,
290                    !unit_literal.negated,
291                    self.current_level,
292                );
293                propagated = true;
294            }
295
296            // Check for conflicts
297            if !conflicts.is_empty() {
298                return Ok(true); // Conflict
299            }
300
301            if !propagated {
302                break;
303            }
304        }
305
306        Ok(false) // No conflict
307    }
308
309    /// Assign a variable at a decision level
310    fn assign(&mut self, var: SatVariable, value: bool, level: usize) {
311        self.assignment.set(var, value);
312        self.decision_levels.insert(var, level);
313    }
314
315    /// Check if assignment is complete
316    fn is_complete(&self) -> bool {
317        // Get all variables from clauses
318        let mut all_vars = HashSet::new();
319        for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
320            for literal in &clause.literals {
321                all_vars.insert(literal.variable);
322            }
323        }
324
325        all_vars.iter().all(|var| self.assignment.is_assigned(*var))
326    }
327
328    /// Choose next decision variable using activity heuristics
329    fn choose_decision_variable(&self) -> SatVariable {
330        // Get all variables
331        let mut unassigned_vars: Vec<_> = self
332            .activity
333            .iter()
334            .filter(|(var, _)| !self.assignment.is_assigned(**var))
335            .collect();
336
337        if unassigned_vars.is_empty() {
338            // Fallback: find any unassigned variable from clauses
339            for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
340                for literal in &clause.literals {
341                    if !self.assignment.is_assigned(literal.variable) {
342                        return literal.variable;
343                    }
344                }
345            }
346            // Should never reach here if is_complete() check is correct
347            panic!("No unassigned variables found");
348        }
349
350        // Sort by activity (highest first)
351        unassigned_vars.sort_by(|a, b| {
352            b.1.partial_cmp(a.1)
353                .expect("Activity values should be valid floats (not NaN)")
354        });
355
356        *unassigned_vars[0].0
357    }
358
359    /// Analyze conflict and learn a new clause
360    fn analyze_conflict(&self) -> Result<(SatClause, usize)> {
361        // Simplified conflict analysis - in production, use 1UIP analysis
362        // For now, just learn a clause that prevents the current decision
363
364        let mut learned_literals = Vec::new();
365        let mut backtrack_level = 0;
366
367        // Find variables assigned at current level
368        for (var, level) in &self.decision_levels {
369            if *level == self.current_level {
370                if let Some(value) = self.assignment.get(*var) {
371                    learned_literals.push(if value {
372                        SatLiteral::negative(*var)
373                    } else {
374                        SatLiteral::positive(*var)
375                    });
376                }
377            } else if *level > backtrack_level {
378                backtrack_level = *level;
379            }
380        }
381
382        if learned_literals.is_empty() {
383            // Add at least one literal to prevent empty clause
384            for (var, _) in &self.decision_levels {
385                if let Some(value) = self.assignment.get(*var) {
386                    learned_literals.push(if value {
387                        SatLiteral::negative(*var)
388                    } else {
389                        SatLiteral::positive(*var)
390                    });
391                    break;
392                }
393            }
394        }
395
396        Ok((
397            SatClause::new(learned_literals),
398            backtrack_level.saturating_sub(1),
399        ))
400    }
401
402    /// Backtrack to a decision level
403    fn backtrack(&mut self, level: usize) {
404        // Remove assignments made after the backtrack level
405        let vars_to_remove: Vec<_> = self
406            .decision_levels
407            .iter()
408            .filter(|(_, &var_level)| var_level > level)
409            .map(|(var, _)| *var)
410            .collect();
411
412        for var in vars_to_remove {
413            self.assignment.unset(var);
414            self.decision_levels.remove(&var);
415        }
416
417        self.current_level = level;
418    }
419
420    /// Get the current assignment
421    pub fn get_assignment(&self) -> &Assignment {
422        &self.assignment
423    }
424}
425
426impl Default for CdclSolver {
427    fn default() -> Self {
428        Self::new()
429    }
430}
431
432impl DependencySatSolver {
433    /// Create a new dependency SAT solver
434    pub fn new() -> Self {
435        Self {
436            version_vars: HashMap::new(),
437            var_to_version: HashMap::new(),
438            next_var_id: 0,
439            solver: CdclSolver::new(),
440            available_versions: HashMap::new(),
441        }
442    }
443
444    /// Get or create a SAT variable for a package version
445    fn get_or_create_variable(&mut self, package: &str, version: &str) -> SatVariable {
446        let key = (package.to_string(), version.to_string());
447        if let Some(&var) = self.version_vars.get(&key) {
448            return var;
449        }
450
451        let var = SatVariable(self.next_var_id);
452        self.next_var_id += 1;
453        self.version_vars.insert(key.clone(), var);
454        self.var_to_version.insert(var, key);
455        var
456    }
457
458    /// Add available versions for a package
459    pub fn add_available_versions(&mut self, package: &str, versions: Vec<String>) {
460        self.available_versions
461            .insert(package.to_string(), versions.clone());
462
463        // Create variables for all versions
464        for version in &versions {
465            self.get_or_create_variable(package, version);
466        }
467
468        // Add constraint: exactly one version must be selected (if package is needed)
469        // This is encoded as: at most one version (pairwise conflicts)
470        for i in 0..versions.len() {
471            for j in (i + 1)..versions.len() {
472                let var_i = self.get_or_create_variable(package, &versions[i]);
473                let var_j = self.get_or_create_variable(package, &versions[j]);
474
475                // Clause: NOT var_i OR NOT var_j (can't have both versions)
476                self.solver.add_clause(SatClause::new(vec![
477                    SatLiteral::negative(var_i),
478                    SatLiteral::negative(var_j),
479                ]));
480            }
481        }
482    }
483
484    /// Add dependency constraint: if package A version X is selected,
485    /// then one of the compatible versions of package B must be selected
486    pub fn add_dependency_constraint(
487        &mut self,
488        package: &str,
489        version: &str,
490        dep_spec: &DependencySpec,
491    ) -> Result<()> {
492        let package_var = self.get_or_create_variable(package, version);
493
494        // Find all compatible versions of the dependency (clone to avoid borrow issues)
495        let dep_versions = self
496            .available_versions
497            .get(&dep_spec.name)
498            .ok_or_else(|| {
499                TorshError::InvalidArgument(format!(
500                    "No versions available for dependency: {}",
501                    dep_spec.name
502                ))
503            })?
504            .clone();
505
506        let mut compatible_vars = Vec::new();
507        for dep_version in &dep_versions {
508            if dep_spec.is_satisfied_by(dep_version)? {
509                let dep_var = self.get_or_create_variable(&dep_spec.name, dep_version);
510                compatible_vars.push(dep_var);
511            }
512        }
513
514        if compatible_vars.is_empty() {
515            return Err(TorshError::InvalidArgument(format!(
516                "No compatible versions found for dependency: {} with requirement: {}",
517                dep_spec.name, dep_spec.version_req
518            )));
519        }
520
521        // Add clause: NOT package_var OR dep_var1 OR dep_var2 OR ...
522        // Meaning: if this package version is selected, at least one compatible dep version must be selected
523        let mut clause_literals = vec![SatLiteral::negative(package_var)];
524        for dep_var in compatible_vars {
525            clause_literals.push(SatLiteral::positive(dep_var));
526        }
527
528        self.solver.add_clause(SatClause::new(clause_literals));
529        Ok(())
530    }
531
532    /// Add root dependency constraint: at least one version of the root package must be selected
533    pub fn add_root_constraint(&mut self, package: &str) -> Result<()> {
534        let versions = self
535            .available_versions
536            .get(package)
537            .ok_or_else(|| {
538                TorshError::InvalidArgument(format!(
539                    "No versions available for root package: {}",
540                    package
541                ))
542            })?
543            .clone();
544
545        // Add clause: var1 OR var2 OR ... (at least one version must be selected)
546        let clause_literals: Vec<_> = versions
547            .iter()
548            .map(|v| {
549                let var = self.get_or_create_variable(package, v);
550                SatLiteral::positive(var)
551            })
552            .collect();
553
554        self.solver.add_clause(SatClause::new(clause_literals));
555        Ok(())
556    }
557
558    /// Solve the dependency constraints
559    pub fn solve(&mut self) -> Result<DependencySolution> {
560        let is_sat = self.solver.solve()?;
561
562        if !is_sat {
563            return Ok(DependencySolution {
564                selected_versions: HashMap::new(),
565                install_order: Vec::new(),
566                conflicts: vec!["Dependency constraints are unsatisfiable".to_string()],
567            });
568        }
569
570        // Extract solution from assignment
571        let assignment = self.solver.get_assignment();
572        let mut selected_versions = HashMap::new();
573
574        for (var, &value) in &assignment.values {
575            if value {
576                if let Some((package, version)) = self.var_to_version.get(var) {
577                    selected_versions.insert(package.clone(), version.clone());
578                }
579            }
580        }
581
582        // Compute topological order for installation
583        let install_order = self.compute_install_order(&selected_versions)?;
584
585        Ok(DependencySolution {
586            selected_versions,
587            install_order,
588            conflicts: Vec::new(),
589        })
590    }
591
592    /// Compute installation order using topological sort
593    fn compute_install_order(&self, selected: &HashMap<String, String>) -> Result<Vec<String>> {
594        // Simple topological sort - in production, use the dependency graph
595        let mut order: Vec<_> = selected.keys().cloned().collect();
596        order.sort(); // Simplified - should use actual dependency order
597        Ok(order)
598    }
599}
600
601impl Default for DependencySatSolver {
602    fn default() -> Self {
603        Self::new()
604    }
605}
606
607impl fmt::Display for DependencySolution {
608    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609        if !self.conflicts.is_empty() {
610            writeln!(f, "Dependency resolution failed:")?;
611            for conflict in &self.conflicts {
612                writeln!(f, "  - {}", conflict)?;
613            }
614            return Ok(());
615        }
616
617        writeln!(f, "Dependency resolution successful:")?;
618        writeln!(f, "Selected versions:")?;
619        for (package, version) in &self.selected_versions {
620            writeln!(f, "  {} = {}", package, version)?;
621        }
622        writeln!(f, "Installation order:")?;
623        for (i, package) in self.install_order.iter().enumerate() {
624            writeln!(f, "  {}. {}", i + 1, package)?;
625        }
626        Ok(())
627    }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sat_literal() {
        let var = SatVariable(0);
        let pos = SatLiteral::positive(var);
        let neg = SatLiteral::negative(var);

        assert!(!pos.negated);
        assert!(neg.negated);
        assert_eq!(pos.negate(), neg);
    }

    #[test]
    fn test_sat_clause_satisfaction() {
        let var1 = SatVariable(0);
        let var2 = SatVariable(1);

        let clause = SatClause::new(vec![SatLiteral::positive(var1), SatLiteral::negative(var2)]);

        let mut assignment = Assignment::new();
        assignment.set(var1, true);
        assignment.set(var2, false);

        assert!(clause.is_satisfied(&assignment));
    }

    #[test]
    fn test_unit_literal_detection() {
        let a = SatVariable(0);
        let b = SatVariable(1);
        let clause = SatClause::new(vec![SatLiteral::positive(a), SatLiteral::negative(b)]);

        let mut assignment = Assignment::new();
        // Two unassigned literals: not unit yet.
        assert_eq!(clause.get_unit_literal(&assignment), None);

        // ¬b is false, so `a` is forced.
        assignment.set(b, true);
        assert_eq!(
            clause.get_unit_literal(&assignment),
            Some(SatLiteral::positive(a))
        );

        // Every literal false: conflicting, not unit.
        assignment.set(a, false);
        assert!(clause.is_conflicting(&assignment));
        assert_eq!(clause.get_unit_literal(&assignment), None);
    }

    #[test]
    fn test_simple_sat_solving() {
        let mut solver = CdclSolver::new();

        let var1 = SatVariable(0);
        let var2 = SatVariable(1);

        // Add clause: var1 OR var2
        solver.add_clause(SatClause::new(vec![
            SatLiteral::positive(var1),
            SatLiteral::positive(var2),
        ]));

        // Add clause: NOT var1 OR var2 (if var1 then var2)
        solver.add_clause(SatClause::new(vec![
            SatLiteral::negative(var1),
            SatLiteral::positive(var2),
        ]));

        let result = solver.solve().unwrap();
        assert!(result); // Should be SAT

        // Together the two clauses force var2 = true in every model.
        assert_eq!(solver.get_assignment().get(var2), Some(true));
    }

    #[test]
    fn test_unsatisfiable_unit_clauses() {
        let mut solver = CdclSolver::new();
        let var = SatVariable(0);

        // var AND NOT var has no model.
        solver.add_clause(SatClause::new(vec![SatLiteral::positive(var)]));
        solver.add_clause(SatClause::new(vec![SatLiteral::negative(var)]));

        assert!(!solver.solve().unwrap());
    }

    #[test]
    fn test_dependency_sat_solver() {
        let mut solver = DependencySatSolver::new();

        // Add package A with versions 1.0.0 and 2.0.0
        solver.add_available_versions("pkg-a", vec!["1.0.0".to_string(), "2.0.0".to_string()]);

        // Add package B with versions 1.0.0
        solver.add_available_versions("pkg-b", vec!["1.0.0".to_string()]);

        // Package A 1.0.0 depends on B ^1.0.0
        let dep_spec = DependencySpec::new("pkg-b".to_string(), "^1.0.0".to_string());
        solver
            .add_dependency_constraint("pkg-a", "1.0.0", &dep_spec)
            .unwrap();

        // We want package A
        solver.add_root_constraint("pkg-a").unwrap();

        let solution = solver.solve().unwrap();
        assert!(solution.conflicts.is_empty());
        assert!(solution.selected_versions.contains_key("pkg-a"));
        assert_eq!(
            solution.install_order.len(),
            solution.selected_versions.len()
        );

        // Whichever version of pkg-a is chosen, its dependency must be honoured.
        if solution.selected_versions.get("pkg-a").map(String::as_str) == Some("1.0.0") {
            assert_eq!(
                solution.selected_versions.get("pkg-b").map(String::as_str),
                Some("1.0.0")
            );
        }
    }

    #[test]
    fn test_version_conflict_detection() {
        let mut solver = DependencySatSolver::new();

        // Add package A with version 1.0.0
        solver.add_available_versions("pkg-a", vec!["1.0.0".to_string()]);

        // Add package B with version 1.0.0 that requires A 2.0.0 (not available)
        solver.add_available_versions("pkg-b", vec!["1.0.0".to_string()]);

        let dep_spec = DependencySpec::new("pkg-a".to_string(), "^2.0.0".to_string());
        let result = solver.add_dependency_constraint("pkg-b", "1.0.0", &dep_spec);

        // Should fail because no compatible version of A exists
        assert!(result.is_err());
    }
}