solverforge_core/solver/
manager.rs1use std::collections::HashMap;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use crate::error::{SolverForgeError, SolverForgeResult};
11use crate::solver::{
12 ListAccessorDto, SolveHandle, SolveRequest, SolveState, SolverConfig, SolverService,
13 TerminationConfig,
14};
15use crate::traits::PlanningSolution;
16use crate::wasm::{Expression, HostFunctionRegistry, WasmModuleBuilder};
17
18pub struct SolverManager<S: PlanningSolution, ProblemId: Eq + Hash + Clone> {
41 active_solves: HashMap<ProblemId, ManagedSolve<S>>,
42 service: Arc<dyn SolverService>,
43 config: SolverConfig,
44 cascading_expressions: Vec<(String, String, Expression)>,
46}
47
48struct ManagedSolve<S: PlanningSolution> {
49 handle: SolveHandle,
50 _phantom: std::marker::PhantomData<S>,
51}
52
53impl<S, I> SolverManager<S, I>
54where
55 S: PlanningSolution + Clone,
56 I: Eq + Hash + Clone,
57{
58 pub fn new(service: Arc<dyn SolverService>) -> Self {
60 Self {
61 active_solves: HashMap::new(),
62 service,
63 config: SolverConfig::default(),
64 cascading_expressions: Vec::new(),
65 }
66 }
67
68 pub fn with_config(mut self, config: SolverConfig) -> Self {
70 self.config = config;
71 self
72 }
73
74 pub fn with_termination(mut self, termination: TerminationConfig) -> Self {
76 self.config.termination = Some(termination);
77 self
78 }
79
80 pub fn with_cascading_expression(
90 mut self,
91 class_name: impl Into<String>,
92 field_name: impl Into<String>,
93 expression: Expression,
94 ) -> Self {
95 self.cascading_expressions
96 .push((class_name.into(), field_name.into(), expression));
97 self
98 }
99
100 pub fn solve(&mut self, id: I, problem: S) -> SolverForgeResult<()> {
104 if self.active_solves.contains_key(&id) {
105 return Err(SolverForgeError::Solver(
106 "Solve already in progress for this problem ID".into(),
107 ));
108 }
109
110 let mut domain_model = S::domain_model();
112 let constraints = S::constraints();
113
114 for (class_name, field_name, expression) in &self.cascading_expressions {
116 domain_model.set_cascading_expression(class_name, field_name, expression.clone())?;
117 }
118
119 let predicates = constraints.extract_predicates();
121
122 let mut builder = WasmModuleBuilder::new()
124 .with_host_functions(HostFunctionRegistry::with_standard_functions())
125 .with_domain_model(domain_model.clone());
126
127 for predicate in predicates {
129 builder = builder.add_predicate(predicate);
130 }
131
132 let wasm_base64 = builder.build_base64()?;
133
134 let domain_dto = domain_model.to_dto();
136 let constraints_dto = constraints.to_dto();
137 let problem_json = problem.to_json()?;
138
139 let list_accessor = ListAccessorDto::new(
140 "newList", "getItem", "setItem", "size", "append", "insert", "remove", "dealloc",
141 );
142
143 let mut request = SolveRequest::new(
144 domain_dto,
145 constraints_dto,
146 wasm_base64,
147 "alloc".to_string(),
148 "dealloc".to_string(),
149 list_accessor,
150 problem_json,
151 );
152
153 if let Some(mode) = &self.config.environment_mode {
154 request = request.with_environment_mode(format!("{:?}", mode).to_uppercase());
155 }
156
157 if let Some(termination) = &self.config.termination {
158 request = request.with_termination(termination.clone());
159 }
160
161 let handle = self.service.solve_async(&request)?;
163 self.active_solves.insert(
164 id,
165 ManagedSolve {
166 handle,
167 _phantom: std::marker::PhantomData,
168 },
169 );
170
171 Ok(())
172 }
173
174 pub fn get_best_solution(&self, id: &I) -> SolverForgeResult<Option<S>> {
176 let managed = self
177 .active_solves
178 .get(id)
179 .ok_or_else(|| SolverForgeError::Solver("No solve found for this problem ID".into()))?;
180
181 if let Ok(Some(response)) = self.service.get_best_solution(&managed.handle) {
183 let parsed = S::from_json(&response.solution)?;
184 return Ok(Some(parsed));
185 }
186
187 Ok(None)
189 }
190
191 pub fn terminate(&mut self, id: &I) -> SolverForgeResult<()> {
193 let managed = self
194 .active_solves
195 .get(id)
196 .ok_or_else(|| SolverForgeError::Solver("No solve found for this problem ID".into()))?;
197
198 self.service.stop(&managed.handle)?;
199 self.active_solves.remove(id);
200
201 Ok(())
202 }
203
204 pub fn terminate_all(&mut self) {
206 let ids: Vec<_> = self.active_solves.keys().cloned().collect();
207 for id in ids {
208 let _ = self.terminate(&id);
209 }
210 }
211
212 pub fn is_solving(&self, id: &I) -> bool {
214 if let Some(managed) = self.active_solves.get(id) {
215 if let Ok(status) = self.service.get_status(&managed.handle) {
216 return status.state == SolveState::Running;
217 }
218 }
219 false
220 }
221
222 pub fn active_solve_count(&self) -> usize {
224 self.active_solves.len()
225 }
226
227 pub fn cleanup_completed(&mut self) -> SolverForgeResult<()> {
229 let mut completed = Vec::new();
230
231 for (id, managed) in &self.active_solves {
232 let status = self.service.get_status(&managed.handle)?;
233 if status.state != SolveState::Running {
234 completed.push(id.clone());
235 }
236 }
237
238 for id in completed {
239 self.active_solves.remove(&id);
240 }
241
242 Ok(())
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the public construction/query surface of [`SolverManager`] for any
    /// `S`/`I` satisfying the impl bounds.
    ///
    /// Generic functions are type-checked even when never instantiated. This catches
    /// signature regressions at compile time without a concrete `PlanningSolution`
    /// or a live `SolverService`.
    #[allow(dead_code)] // exists only to be type-checked; never called
    fn exercise_api<S, I>(service: Arc<dyn SolverService>, id: I) -> usize
    where
        S: PlanningSolution + Clone,
        I: Eq + Hash + Clone,
    {
        let manager: SolverManager<S, I> =
            SolverManager::new(service).with_config(SolverConfig::default());
        let _ = manager.is_solving(&id);
        manager.active_solve_count()
    }

    #[test]
    fn test_solver_manager_types_compile() {
        // The meaningful check is that `exercise_api` above type-checks. This test
        // exists so the check is reported by the test harness.
    }
}