1use std::collections::{HashMap, HashSet};
8use std::fmt;
9
10use serde::{Deserialize, Serialize};
11use torsh_core::error::{Result, TorshError};
12
13use crate::dependency::DependencySpec;
14
/// Opaque identifier of a boolean SAT variable.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SatVariable(usize);
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
21pub struct SatLiteral {
22 variable: SatVariable,
23 negated: bool,
24}
25
26#[derive(Debug, Clone)]
28pub struct SatClause {
29 literals: Vec<SatLiteral>,
30}
31
32#[derive(Debug, Clone)]
34pub struct Assignment {
35 values: HashMap<SatVariable, bool>,
36}
37
38#[derive(Debug)]
40pub struct CdclSolver {
41 clauses: Vec<SatClause>,
43 assignment: Assignment,
45 decision_levels: HashMap<SatVariable, usize>,
47 current_level: usize,
49 learned_clauses: Vec<SatClause>,
51 activity: HashMap<SatVariable, f64>,
53}
54
55#[derive(Debug, Clone)]
57pub struct VersionConstraint {
58 pub package: String,
60 pub version: String,
62 pub variable: SatVariable,
64}
65
66pub struct DependencySatSolver {
68 version_vars: HashMap<(String, String), SatVariable>,
70 var_to_version: HashMap<SatVariable, (String, String)>,
72 next_var_id: usize,
74 solver: CdclSolver,
76 available_versions: HashMap<String, Vec<String>>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct DependencySolution {
83 pub selected_versions: HashMap<String, String>,
85 pub install_order: Vec<String>,
87 pub conflicts: Vec<String>,
89}
90
91impl SatLiteral {
92 pub fn positive(var: SatVariable) -> Self {
94 Self {
95 variable: var,
96 negated: false,
97 }
98 }
99
100 pub fn negative(var: SatVariable) -> Self {
102 Self {
103 variable: var,
104 negated: true,
105 }
106 }
107
108 pub fn negate(&self) -> Self {
110 Self {
111 variable: self.variable,
112 negated: !self.negated,
113 }
114 }
115
116 pub fn is_satisfied(&self, assignment: &Assignment) -> Option<bool> {
118 assignment
119 .get(self.variable)
120 .map(|value| if self.negated { !value } else { value })
121 }
122}
123
124impl SatClause {
125 pub fn new(literals: Vec<SatLiteral>) -> Self {
127 Self { literals }
128 }
129
130 pub fn is_satisfied(&self, assignment: &Assignment) -> bool {
132 self.literals
133 .iter()
134 .any(|lit| lit.is_satisfied(assignment) == Some(true))
135 }
136
137 pub fn is_conflicting(&self, assignment: &Assignment) -> bool {
139 self.literals
140 .iter()
141 .all(|lit| lit.is_satisfied(assignment) == Some(false))
142 }
143
144 pub fn get_unit_literal(&self, assignment: &Assignment) -> Option<SatLiteral> {
146 let mut unassigned = None;
147 let mut unassigned_count = 0;
148
149 for literal in &self.literals {
150 match literal.is_satisfied(assignment) {
151 Some(true) => return None, Some(false) => continue, None => {
154 unassigned = Some(*literal);
155 unassigned_count += 1;
156 if unassigned_count > 1 {
157 return None; }
159 }
160 }
161 }
162
163 if unassigned_count == 1 {
164 unassigned
165 } else {
166 None
167 }
168 }
169}
170
171impl Assignment {
172 pub fn new() -> Self {
174 Self {
175 values: HashMap::new(),
176 }
177 }
178
179 pub fn get(&self, var: SatVariable) -> Option<bool> {
181 self.values.get(&var).copied()
182 }
183
184 pub fn set(&mut self, var: SatVariable, value: bool) {
186 self.values.insert(var, value);
187 }
188
189 pub fn unset(&mut self, var: SatVariable) {
191 self.values.remove(&var);
192 }
193
194 pub fn is_assigned(&self, var: SatVariable) -> bool {
196 self.values.contains_key(&var)
197 }
198}
199
200impl Default for Assignment {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl CdclSolver {
207 pub fn new() -> Self {
209 Self {
210 clauses: Vec::new(),
211 assignment: Assignment::new(),
212 decision_levels: HashMap::new(),
213 current_level: 0,
214 learned_clauses: Vec::new(),
215 activity: HashMap::new(),
216 }
217 }
218
219 pub fn add_clause(&mut self, clause: SatClause) {
221 for literal in &clause.literals {
223 *self.activity.entry(literal.variable).or_insert(0.0) += 1.0;
224 }
225 self.clauses.push(clause);
226 }
227
228 pub fn solve(&mut self) -> Result<bool> {
230 if self.unit_propagate()? {
232 return Ok(false); }
234
235 loop {
236 if self.is_complete() {
238 return Ok(true); }
240
241 let decision_var = self.choose_decision_variable();
243 self.current_level += 1;
244 self.assign(decision_var, true, self.current_level);
245
246 loop {
248 if self.unit_propagate()? {
249 if self.current_level == 0 {
251 return Ok(false); }
253
254 let (learned_clause, backtrack_level) = self.analyze_conflict()?;
256 self.learned_clauses.push(learned_clause.clone());
257 self.add_clause(learned_clause);
258
259 self.backtrack(backtrack_level);
261 } else {
262 break; }
264 }
265 }
266 }
267
268 fn unit_propagate(&mut self) -> Result<bool> {
270 loop {
271 let mut propagated = false;
272
273 let mut unit_literals = Vec::new();
275 let mut conflicts = Vec::new();
276
277 for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
279 if let Some(unit_literal) = clause.get_unit_literal(&self.assignment) {
280 unit_literals.push(unit_literal);
281 } else if clause.is_conflicting(&self.assignment) {
282 conflicts.push(true);
283 }
284 }
285
286 for unit_literal in unit_literals {
288 self.assign(
289 unit_literal.variable,
290 !unit_literal.negated,
291 self.current_level,
292 );
293 propagated = true;
294 }
295
296 if !conflicts.is_empty() {
298 return Ok(true); }
300
301 if !propagated {
302 break;
303 }
304 }
305
306 Ok(false) }
308
309 fn assign(&mut self, var: SatVariable, value: bool, level: usize) {
311 self.assignment.set(var, value);
312 self.decision_levels.insert(var, level);
313 }
314
315 fn is_complete(&self) -> bool {
317 let mut all_vars = HashSet::new();
319 for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
320 for literal in &clause.literals {
321 all_vars.insert(literal.variable);
322 }
323 }
324
325 all_vars.iter().all(|var| self.assignment.is_assigned(*var))
326 }
327
328 fn choose_decision_variable(&self) -> SatVariable {
330 let mut unassigned_vars: Vec<_> = self
332 .activity
333 .iter()
334 .filter(|(var, _)| !self.assignment.is_assigned(**var))
335 .collect();
336
337 if unassigned_vars.is_empty() {
338 for clause in self.clauses.iter().chain(self.learned_clauses.iter()) {
340 for literal in &clause.literals {
341 if !self.assignment.is_assigned(literal.variable) {
342 return literal.variable;
343 }
344 }
345 }
346 panic!("No unassigned variables found");
348 }
349
350 unassigned_vars.sort_by(|a, b| {
352 b.1.partial_cmp(a.1)
353 .expect("Activity values should be valid floats (not NaN)")
354 });
355
356 *unassigned_vars[0].0
357 }
358
359 fn analyze_conflict(&self) -> Result<(SatClause, usize)> {
361 let mut learned_literals = Vec::new();
365 let mut backtrack_level = 0;
366
367 for (var, level) in &self.decision_levels {
369 if *level == self.current_level {
370 if let Some(value) = self.assignment.get(*var) {
371 learned_literals.push(if value {
372 SatLiteral::negative(*var)
373 } else {
374 SatLiteral::positive(*var)
375 });
376 }
377 } else if *level > backtrack_level {
378 backtrack_level = *level;
379 }
380 }
381
382 if learned_literals.is_empty() {
383 for (var, _) in &self.decision_levels {
385 if let Some(value) = self.assignment.get(*var) {
386 learned_literals.push(if value {
387 SatLiteral::negative(*var)
388 } else {
389 SatLiteral::positive(*var)
390 });
391 break;
392 }
393 }
394 }
395
396 Ok((
397 SatClause::new(learned_literals),
398 backtrack_level.saturating_sub(1),
399 ))
400 }
401
402 fn backtrack(&mut self, level: usize) {
404 let vars_to_remove: Vec<_> = self
406 .decision_levels
407 .iter()
408 .filter(|(_, &var_level)| var_level > level)
409 .map(|(var, _)| *var)
410 .collect();
411
412 for var in vars_to_remove {
413 self.assignment.unset(var);
414 self.decision_levels.remove(&var);
415 }
416
417 self.current_level = level;
418 }
419
420 pub fn get_assignment(&self) -> &Assignment {
422 &self.assignment
423 }
424}
425
426impl Default for CdclSolver {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
432impl DependencySatSolver {
433 pub fn new() -> Self {
435 Self {
436 version_vars: HashMap::new(),
437 var_to_version: HashMap::new(),
438 next_var_id: 0,
439 solver: CdclSolver::new(),
440 available_versions: HashMap::new(),
441 }
442 }
443
444 fn get_or_create_variable(&mut self, package: &str, version: &str) -> SatVariable {
446 let key = (package.to_string(), version.to_string());
447 if let Some(&var) = self.version_vars.get(&key) {
448 return var;
449 }
450
451 let var = SatVariable(self.next_var_id);
452 self.next_var_id += 1;
453 self.version_vars.insert(key.clone(), var);
454 self.var_to_version.insert(var, key);
455 var
456 }
457
458 pub fn add_available_versions(&mut self, package: &str, versions: Vec<String>) {
460 self.available_versions
461 .insert(package.to_string(), versions.clone());
462
463 for version in &versions {
465 self.get_or_create_variable(package, version);
466 }
467
468 for i in 0..versions.len() {
471 for j in (i + 1)..versions.len() {
472 let var_i = self.get_or_create_variable(package, &versions[i]);
473 let var_j = self.get_or_create_variable(package, &versions[j]);
474
475 self.solver.add_clause(SatClause::new(vec![
477 SatLiteral::negative(var_i),
478 SatLiteral::negative(var_j),
479 ]));
480 }
481 }
482 }
483
484 pub fn add_dependency_constraint(
487 &mut self,
488 package: &str,
489 version: &str,
490 dep_spec: &DependencySpec,
491 ) -> Result<()> {
492 let package_var = self.get_or_create_variable(package, version);
493
494 let dep_versions = self
496 .available_versions
497 .get(&dep_spec.name)
498 .ok_or_else(|| {
499 TorshError::InvalidArgument(format!(
500 "No versions available for dependency: {}",
501 dep_spec.name
502 ))
503 })?
504 .clone();
505
506 let mut compatible_vars = Vec::new();
507 for dep_version in &dep_versions {
508 if dep_spec.is_satisfied_by(dep_version)? {
509 let dep_var = self.get_or_create_variable(&dep_spec.name, dep_version);
510 compatible_vars.push(dep_var);
511 }
512 }
513
514 if compatible_vars.is_empty() {
515 return Err(TorshError::InvalidArgument(format!(
516 "No compatible versions found for dependency: {} with requirement: {}",
517 dep_spec.name, dep_spec.version_req
518 )));
519 }
520
521 let mut clause_literals = vec![SatLiteral::negative(package_var)];
524 for dep_var in compatible_vars {
525 clause_literals.push(SatLiteral::positive(dep_var));
526 }
527
528 self.solver.add_clause(SatClause::new(clause_literals));
529 Ok(())
530 }
531
532 pub fn add_root_constraint(&mut self, package: &str) -> Result<()> {
534 let versions = self
535 .available_versions
536 .get(package)
537 .ok_or_else(|| {
538 TorshError::InvalidArgument(format!(
539 "No versions available for root package: {}",
540 package
541 ))
542 })?
543 .clone();
544
545 let clause_literals: Vec<_> = versions
547 .iter()
548 .map(|v| {
549 let var = self.get_or_create_variable(package, v);
550 SatLiteral::positive(var)
551 })
552 .collect();
553
554 self.solver.add_clause(SatClause::new(clause_literals));
555 Ok(())
556 }
557
558 pub fn solve(&mut self) -> Result<DependencySolution> {
560 let is_sat = self.solver.solve()?;
561
562 if !is_sat {
563 return Ok(DependencySolution {
564 selected_versions: HashMap::new(),
565 install_order: Vec::new(),
566 conflicts: vec!["Dependency constraints are unsatisfiable".to_string()],
567 });
568 }
569
570 let assignment = self.solver.get_assignment();
572 let mut selected_versions = HashMap::new();
573
574 for (var, &value) in &assignment.values {
575 if value {
576 if let Some((package, version)) = self.var_to_version.get(var) {
577 selected_versions.insert(package.clone(), version.clone());
578 }
579 }
580 }
581
582 let install_order = self.compute_install_order(&selected_versions)?;
584
585 Ok(DependencySolution {
586 selected_versions,
587 install_order,
588 conflicts: Vec::new(),
589 })
590 }
591
592 fn compute_install_order(&self, selected: &HashMap<String, String>) -> Result<Vec<String>> {
594 let mut order: Vec<_> = selected.keys().cloned().collect();
596 order.sort(); Ok(order)
598 }
599}
600
601impl Default for DependencySatSolver {
602 fn default() -> Self {
603 Self::new()
604 }
605}
606
607impl fmt::Display for DependencySolution {
608 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609 if !self.conflicts.is_empty() {
610 writeln!(f, "Dependency resolution failed:")?;
611 for conflict in &self.conflicts {
612 writeln!(f, " - {}", conflict)?;
613 }
614 return Ok(());
615 }
616
617 writeln!(f, "Dependency resolution successful:")?;
618 writeln!(f, "Selected versions:")?;
619 for (package, version) in &self.selected_versions {
620 writeln!(f, " {} = {}", package, version)?;
621 }
622 writeln!(f, "Installation order:")?;
623 for (i, package) in self.install_order.iter().enumerate() {
624 writeln!(f, " {}. {}", i + 1, package)?;
625 }
626 Ok(())
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sat_literal() {
        let var = SatVariable(0);
        let pos = SatLiteral::positive(var);
        let neg = SatLiteral::negative(var);

        assert!(!pos.negated);
        assert!(neg.negated);
        assert_eq!(pos.negate(), neg);
    }

    #[test]
    fn test_sat_clause_satisfaction() {
        let var1 = SatVariable(0);
        let var2 = SatVariable(1);

        let clause = SatClause::new(vec![SatLiteral::positive(var1), SatLiteral::negative(var2)]);

        let mut assignment = Assignment::new();
        assignment.set(var1, true);
        assignment.set(var2, false);

        assert!(clause.is_satisfied(&assignment));
    }

    #[test]
    fn test_simple_sat_solving() {
        let mut solver = CdclSolver::new();

        let var1 = SatVariable(0);
        let var2 = SatVariable(1);

        solver.add_clause(SatClause::new(vec![
            SatLiteral::positive(var1),
            SatLiteral::positive(var2),
        ]));

        solver.add_clause(SatClause::new(vec![
            SatLiteral::negative(var1),
            SatLiteral::positive(var2),
        ]));

        let result = solver.solve().unwrap();
        assert!(result);
    }

    #[test]
    fn test_contradictory_unit_clauses_are_unsat() {
        let mut solver = CdclSolver::new();
        let var = SatVariable(0);

        solver.add_clause(SatClause::new(vec![SatLiteral::positive(var)]));
        solver.add_clause(SatClause::new(vec![SatLiteral::negative(var)]));

        assert!(!solver.solve().unwrap());
    }

    #[test]
    fn test_empty_clause_is_unsat() {
        let mut solver = CdclSolver::new();
        solver.add_clause(SatClause::new(Vec::new()));

        assert!(!solver.solve().unwrap());
    }

    #[test]
    fn test_satisfying_assignment_satisfies_every_clause() {
        let (v0, v1, v2) = (SatVariable(0), SatVariable(1), SatVariable(2));
        let clauses = vec![
            SatClause::new(vec![SatLiteral::positive(v0), SatLiteral::positive(v1)]),
            SatClause::new(vec![SatLiteral::negative(v0), SatLiteral::positive(v1)]),
            SatClause::new(vec![SatLiteral::negative(v1), SatLiteral::positive(v2)]),
        ];

        let mut solver = CdclSolver::new();
        for clause in &clauses {
            solver.add_clause(clause.clone());
        }

        assert!(solver.solve().unwrap());
        let assignment = solver.get_assignment();
        assert!(clauses.iter().all(|clause| clause.is_satisfied(assignment)));
    }

    #[test]
    fn test_dependency_sat_solver() {
        let mut solver = DependencySatSolver::new();

        solver.add_available_versions("pkg-a", vec!["1.0.0".to_string(), "2.0.0".to_string()]);

        solver.add_available_versions("pkg-b", vec!["1.0.0".to_string()]);

        let dep_spec = DependencySpec::new("pkg-b".to_string(), "^1.0.0".to_string());
        solver
            .add_dependency_constraint("pkg-a", "1.0.0", &dep_spec)
            .unwrap();

        solver.add_root_constraint("pkg-a").unwrap();

        let solution = solver.solve().unwrap();
        assert!(solution.conflicts.is_empty());
        assert!(solution.selected_versions.contains_key("pkg-a"));

        // Selecting pkg-a 1.0.0 must pull in its dependency.
        if solution.selected_versions.get("pkg-a").map(String::as_str) == Some("1.0.0") {
            assert_eq!(
                solution.selected_versions.get("pkg-b").map(String::as_str),
                Some("1.0.0")
            );
        }
    }

    #[test]
    fn test_version_conflict_detection() {
        let mut solver = DependencySatSolver::new();

        solver.add_available_versions("pkg-a", vec!["1.0.0".to_string()]);

        solver.add_available_versions("pkg-b", vec!["1.0.0".to_string()]);

        let dep_spec = DependencySpec::new("pkg-a".to_string(), "^2.0.0".to_string());
        let result = solver.add_dependency_constraint("pkg-b", "1.0.0", &dep_spec);

        assert!(result.is_err());
    }

    #[test]
    fn test_failed_solution_display_lists_conflicts() {
        let solution = DependencySolution {
            selected_versions: HashMap::new(),
            install_order: Vec::new(),
            conflicts: vec!["boom".to_string()],
        };

        let rendered = solution.to_string();
        assert!(rendered.starts_with("Dependency resolution failed:"));
        assert!(rendered.contains("boom"));
    }
}