solverforge_core/solver/
manager.rs

1//! SolverManager for managing multiple concurrent solves.
2//!
3//! Provides a high-level API for managing multiple planning problems simultaneously,
4//! similar to Timefold's SolverManager.
5
6use std::collections::HashMap;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use crate::error::{SolverForgeError, SolverForgeResult};
11use crate::solver::{
12    ListAccessorDto, SolveHandle, SolveRequest, SolveState, SolverConfig, SolverService,
13    TerminationConfig,
14};
15use crate::traits::PlanningSolution;
16use crate::wasm::{Expression, HostFunctionRegistry, WasmModuleBuilder};
17
18/// Manages multiple concurrent solves for planning problems.
19///
20/// Each solve is identified by a unique `ProblemId` that allows tracking
21/// and managing individual solves independently.
22///
23/// # Example
24///
25/// ```ignore
26/// let service = Arc::new(HttpSolverService::new("http://localhost:8080"));
27/// let mut manager = SolverManager::<Timetable, String>::new(service)
28///     .with_termination(TerminationConfig::new().with_spent_limit("PT5M"));
29///
30/// manager.solve("problem-1".to_string(), problem1)?;
31/// manager.solve("problem-2".to_string(), problem2)?;
32///
33/// // Check solutions later
34/// if let Some(solution) = manager.get_best_solution(&"problem-1".to_string())? {
35///     println!("Best score: {:?}", solution.score());
36/// }
37///
38/// manager.terminate_all();
39/// ```
40pub struct SolverManager<S: PlanningSolution, ProblemId: Eq + Hash + Clone> {
41    active_solves: HashMap<ProblemId, ManagedSolve<S>>,
42    service: Arc<dyn SolverService>,
43    config: SolverConfig,
44    /// Cascading update expressions: (class_name, field_name, expression)
45    cascading_expressions: Vec<(String, String, Expression)>,
46}
47
48struct ManagedSolve<S: PlanningSolution> {
49    handle: SolveHandle,
50    _phantom: std::marker::PhantomData<S>,
51}
52
53impl<S, I> SolverManager<S, I>
54where
55    S: PlanningSolution + Clone,
56    I: Eq + Hash + Clone,
57{
58    /// Creates a new SolverManager with the given solver service.
59    pub fn new(service: Arc<dyn SolverService>) -> Self {
60        Self {
61            active_solves: HashMap::new(),
62            service,
63            config: SolverConfig::default(),
64            cascading_expressions: Vec::new(),
65        }
66    }
67
68    /// Sets the solver configuration.
69    pub fn with_config(mut self, config: SolverConfig) -> Self {
70        self.config = config;
71        self
72    }
73
74    /// Sets the termination configuration.
75    pub fn with_termination(mut self, termination: TerminationConfig) -> Self {
76        self.config.termination = Some(termination);
77        self
78    }
79
80    /// Registers a cascading update expression for a shadow variable.
81    ///
82    /// This expression will be compiled to WASM and called by the solver
83    /// when the shadow variable needs to be recomputed.
84    ///
85    /// # Arguments
86    /// * `class_name` - The entity class name (e.g., "Visit")
87    /// * `field_name` - The field with the cascading update shadow variable
88    /// * `expression` - The expression to compute the shadow value
89    pub fn with_cascading_expression(
90        mut self,
91        class_name: impl Into<String>,
92        field_name: impl Into<String>,
93        expression: Expression,
94    ) -> Self {
95        self.cascading_expressions
96            .push((class_name.into(), field_name.into(), expression));
97        self
98    }
99
100    /// Starts solving a problem with the given ID.
101    ///
102    /// Returns an error if a solve with this ID is already in progress.
103    pub fn solve(&mut self, id: I, problem: S) -> SolverForgeResult<()> {
104        if self.active_solves.contains_key(&id) {
105            return Err(SolverForgeError::Solver(
106                "Solve already in progress for this problem ID".into(),
107            ));
108        }
109
110        // Build request using the same logic as TypedSolver
111        let mut domain_model = S::domain_model();
112        let constraints = S::constraints();
113
114        // Apply registered cascading update expressions to the domain model
115        for (class_name, field_name, expression) in &self.cascading_expressions {
116            domain_model.set_cascading_expression(class_name, field_name, expression.clone())?;
117        }
118
119        // Extract predicates from constraints (expressions that need to be compiled to WASM)
120        let predicates = constraints.extract_predicates();
121
122        // Build WASM module (with standard host functions for list operations)
123        let mut builder = WasmModuleBuilder::new()
124            .with_host_functions(HostFunctionRegistry::with_standard_functions())
125            .with_domain_model(domain_model.clone());
126
127        // Add all constraint predicates to the WASM module
128        for predicate in predicates {
129            builder = builder.add_predicate(predicate);
130        }
131
132        let wasm_base64 = builder.build_base64()?;
133
134        // Build the solve request (same as TypedSolver::solve)
135        let domain_dto = domain_model.to_dto();
136        let constraints_dto = constraints.to_dto();
137        let problem_json = problem.to_json()?;
138
139        let list_accessor = ListAccessorDto::new(
140            "newList", "getItem", "setItem", "size", "append", "insert", "remove", "dealloc",
141        );
142
143        let mut request = SolveRequest::new(
144            domain_dto,
145            constraints_dto,
146            wasm_base64,
147            "alloc".to_string(),
148            "dealloc".to_string(),
149            list_accessor,
150            problem_json,
151        );
152
153        if let Some(mode) = &self.config.environment_mode {
154            request = request.with_environment_mode(format!("{:?}", mode).to_uppercase());
155        }
156
157        if let Some(termination) = &self.config.termination {
158            request = request.with_termination(termination.clone());
159        }
160
161        // Start async solve
162        let handle = self.service.solve_async(&request)?;
163        self.active_solves.insert(
164            id,
165            ManagedSolve {
166                handle,
167                _phantom: std::marker::PhantomData,
168            },
169        );
170
171        Ok(())
172    }
173
174    /// Gets the best solution for a problem, if available.
175    pub fn get_best_solution(&self, id: &I) -> SolverForgeResult<Option<S>> {
176        let managed = self
177            .active_solves
178            .get(id)
179            .ok_or_else(|| SolverForgeError::Solver("No solve found for this problem ID".into()))?;
180
181        // Try to get the best solution from the handle
182        if let Ok(Some(response)) = self.service.get_best_solution(&managed.handle) {
183            let parsed = S::from_json(&response.solution)?;
184            return Ok(Some(parsed));
185        }
186
187        // No solution available yet
188        Ok(None)
189    }
190
191    /// Terminates a solve early.
192    pub fn terminate(&mut self, id: &I) -> SolverForgeResult<()> {
193        let managed = self
194            .active_solves
195            .get(id)
196            .ok_or_else(|| SolverForgeError::Solver("No solve found for this problem ID".into()))?;
197
198        self.service.stop(&managed.handle)?;
199        self.active_solves.remove(id);
200
201        Ok(())
202    }
203
204    /// Terminates all active solves.
205    pub fn terminate_all(&mut self) {
206        let ids: Vec<_> = self.active_solves.keys().cloned().collect();
207        for id in ids {
208            let _ = self.terminate(&id);
209        }
210    }
211
212    /// Checks if a solve is currently in progress for the given ID.
213    pub fn is_solving(&self, id: &I) -> bool {
214        if let Some(managed) = self.active_solves.get(id) {
215            if let Ok(status) = self.service.get_status(&managed.handle) {
216                return status.state == SolveState::Running;
217            }
218        }
219        false
220    }
221
222    /// Gets the number of active solves.
223    pub fn active_solve_count(&self) -> usize {
224        self.active_solves.len()
225    }
226
227    /// Removes completed solves from tracking.
228    pub fn cleanup_completed(&mut self) -> SolverForgeResult<()> {
229        let mut completed = Vec::new();
230
231        for (id, managed) in &self.active_solves {
232            let status = self.service.get_status(&managed.handle)?;
233            if status.state != SolveState::Running {
234                completed.push(id.clone());
235            }
236        }
237
238        for id in completed {
239            self.active_solves.remove(&id);
240        }
241
242        Ok(())
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_solver_manager_types_compile() {
        // The compiler type-checks this generic body, which is the actual check.
        // It proves that the builder API chains and the public query/terminate
        // methods resolve for any `PlanningSolution` keyed by `String`.
        // It is never called, because that needs a concrete solution type and a
        // `SolverService` implementation. Behavioural tests still require a mock
        // `SolverService`.
        #[allow(dead_code)] // intentionally uncalled; see the comment above
        fn exercise_api<S: PlanningSolution + Clone>(service: Arc<dyn SolverService>) {
            let mut manager =
                SolverManager::<S, String>::new(service).with_config(SolverConfig::default());
            let id = "problem".to_string();
            let _: bool = manager.is_solving(&id);
            let _: usize = manager.active_solve_count();
            let _: SolverForgeResult<Option<S>> = manager.get_best_solution(&id);
            let _: SolverForgeResult<()> = manager.cleanup_completed();
            let _: SolverForgeResult<()> = manager.terminate(&id);
            manager.terminate_all();
        }
    }
}
253}