1use std::collections::{HashSet, VecDeque};
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use indexmap::IndexMap;
9use tracing::{debug, instrument};
10
11pub trait Graph: Debug {
13 type Node: Clone + Debug + Hash + Eq;
15
16 type Error: std::error::Error;
18
19 fn is_ancestor(
27 &self,
28 ancestor: Self::Node,
29 descendant: Self::Node,
30 ) -> Result<bool, Self::Error>;
31
32 #[instrument]
39 fn simplify_success_bounds(
40 &self,
41 nodes: HashSet<Self::Node>,
42 ) -> Result<HashSet<Self::Node>, Self::Error> {
43 Ok(nodes)
44 }
45
46 #[instrument]
53 fn simplify_failure_bounds(
54 &self,
55 nodes: HashSet<Self::Node>,
56 ) -> Result<HashSet<Self::Node>, Self::Error> {
57 Ok(nodes)
58 }
59}
60
/// The known outcome of testing a single node.
#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)]
pub enum Status {
    /// The node has not been tested yet.
    Untested,

    /// The node passed. The search treats ancestors of a successful node as
    /// already implied successful.
    Success,

    /// The node failed. The search treats descendants of a failing node as
    /// already implied failing.
    Failure,

    /// The node was neither a success nor a failure; it is excluded from both
    /// the success and failure bounds (presumably an inconclusive test).
    Indeterminate,
}
81
/// The current frontier of the search: which nodes are known to succeed and
/// which are known to fail.
#[derive(PartialEq, Eq, Debug)]
pub struct Bounds<Node: Debug + Eq + Hash> {
    /// Nodes known to be successful (possibly reduced by
    /// [`Graph::simplify_success_bounds`]).
    pub success: HashSet<Node>,

    /// Nodes known to be failing (possibly reduced by
    /// [`Graph::simplify_failure_bounds`]).
    pub failure: HashSet<Node>,
}

// Written by hand instead of derived: `#[derive(Default)]` would add an
// unnecessary `Node: Default` bound.
impl<Node: Debug + Eq + Hash> Default for Bounds<Node> {
    fn default() -> Self {
        Self {
            success: HashSet::new(),
            failure: HashSet::new(),
        }
    }
}
102
103pub trait Strategy<G: Graph>: Debug {
105 type Error: std::error::Error;
107
108 fn midpoint(
122 &self,
123 graph: &G,
124 bounds: &Bounds<G::Node>,
125 statuses: &IndexMap<G::Node, Status>,
126 ) -> Result<Option<G::Node>, Self::Error>;
127}
128
129pub struct LazySolution<'a, TNode: Debug + Eq + Hash + 'a, TError> {
131 pub bounds: Bounds<TNode>,
133
134 pub next_to_search: Box<dyn Iterator<Item = Result<TNode, TError>> + 'a>,
142}
143
144impl<'a, TNode: Debug + Eq + Hash + 'a, TError> LazySolution<'a, TNode, TError> {
145 pub fn into_eager(self) -> Result<EagerSolution<TNode>, TError> {
147 let LazySolution {
148 bounds,
149 next_to_search,
150 } = self;
151 Ok(EagerSolution {
152 bounds,
153 next_to_search: next_to_search.collect::<Result<Vec<_>, TError>>()?,
154 })
155 }
156}
157
158#[derive(Debug, Eq, PartialEq)]
161pub struct EagerSolution<Node: Debug + Hash + Eq> {
162 pub(crate) bounds: Bounds<Node>,
163 pub(crate) next_to_search: Vec<Node>,
164}
165
166#[allow(missing_docs)]
167#[derive(Debug, thiserror::Error)]
168pub enum SearchError<TNode, TGraphError, TStrategyError> {
169 #[error(
170 "node {node:?} has already been classified as a {status:?} node, but was returned as a new midpoint to search; this would loop indefinitely"
171 )]
172 AlreadySearchedMidpoint { node: TNode, status: Status },
173
174 #[error(transparent)]
175 Graph(TGraphError),
176
177 #[error(transparent)]
178 Strategy(TStrategyError),
179}
180
181#[allow(missing_docs)]
183#[derive(Debug, thiserror::Error)]
184pub enum NotifyError<TNode, TGraphError> {
185 #[error(
186 "inconsistent state transition: {ancestor_node:?} ({ancestor_status:?}) was marked as an ancestor of {descendant_node:?} ({descendant_status:?}"
187 )]
188 InconsistentStateTransition {
189 ancestor_node: TNode,
190 ancestor_status: Status,
191 descendant_node: TNode,
192 descendant_status: Status,
193 },
194
195 #[error("illegal state transition for {node:?}: {from:?} -> {to:?}")]
196 IllegalStateTransition {
197 node: TNode,
198 from: Status,
199 to: Status,
200 },
201
202 #[error(transparent)]
203 Graph(TGraphError),
204}
205
206#[derive(Clone, Debug)]
208pub struct Search<G: Graph> {
209 graph: G,
210 nodes: IndexMap<G::Node, Status>,
211}
212
213impl<G: Graph> Search<G> {
214 pub fn new(graph: G, search_nodes: impl IntoIterator<Item = G::Node>) -> Self {
223 let nodes = search_nodes
224 .into_iter()
225 .map(|node| (node, Status::Untested))
226 .collect();
227 Self { graph, nodes }
228 }
229
230 #[instrument]
234 pub fn success_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
235 let success_nodes = self
236 .nodes
237 .iter()
238 .filter_map(|(node, status)| match status {
239 Status::Success => Some(node.clone()),
240 Status::Untested | Status::Failure | Status::Indeterminate => None,
241 })
242 .collect::<HashSet<_>>();
243 let success_bounds = self.graph.simplify_success_bounds(success_nodes)?;
244 Ok(success_bounds)
245 }
246
247 #[instrument]
251 pub fn failure_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
252 let failure_nodes = self
253 .nodes
254 .iter()
255 .filter_map(|(node, status)| match status {
256 Status::Failure => Some(node.clone()),
257 Status::Untested | Status::Success | Status::Indeterminate => None,
258 })
259 .collect::<HashSet<_>>();
260 let failure_bounds = self.graph.simplify_failure_bounds(failure_nodes)?;
261 Ok(failure_bounds)
262 }
263
264 #[instrument]
267 #[allow(clippy::type_complexity)]
268 pub fn search<'a, S: Strategy<G>>(
269 &'a self,
270 strategy: &'a S,
271 ) -> Result<
272 LazySolution<'a, G::Node, SearchError<G::Node, G::Error, S::Error>>,
273 SearchError<G::Node, G::Error, S::Error>,
274 > {
275 let success_bounds = self.success_bounds().map_err(SearchError::Graph)?;
276 let failure_bounds = self.failure_bounds().map_err(SearchError::Graph)?;
277
278 #[derive(Debug)]
279 struct State<G: Graph> {
280 bounds: Bounds<G::Node>,
281 statuses: IndexMap<G::Node, Status>,
282 }
283
284 struct Iter<'a, G: Graph, S: Strategy<G>> {
285 graph: &'a G,
286 strategy: &'a S,
287 seen: HashSet<G::Node>,
288 states: VecDeque<State<G>>,
289 }
290
291 impl<G: Graph, S: Strategy<G>> Iterator for Iter<'_, G, S> {
292 type Item = Result<G::Node, SearchError<G::Node, G::Error, S::Error>>;
293
294 fn next(&mut self) -> Option<Self::Item> {
295 while let Some(state) = self.states.pop_front() {
296 debug!(?state, "Popped speculation state");
297 let State { bounds, statuses } = state;
298
299 let node = match self.strategy.midpoint(self.graph, &bounds, &statuses) {
300 Ok(Some(node)) => node,
301 Ok(None) => continue,
302 Err(err) => return Some(Err(SearchError::Strategy(err))),
303 };
304
305 let Bounds { success, failure } = bounds;
306 for success_node in success.iter() {
307 match self.graph.is_ancestor(node.clone(), success_node.clone()) {
308 Ok(true) => {
309 return Some(Err(SearchError::AlreadySearchedMidpoint {
310 node,
311 status: Status::Success,
312 }));
313 }
314 Ok(false) => (),
315 Err(err) => return Some(Err(SearchError::Graph(err))),
316 }
317 }
318 for failure_node in failure.iter() {
319 match self.graph.is_ancestor(failure_node.clone(), node.clone()) {
320 Ok(true) => {
321 return Some(Err(SearchError::AlreadySearchedMidpoint {
322 node,
323 status: Status::Failure,
324 }));
325 }
326 Ok(false) => (),
327 Err(err) => return Some(Err(SearchError::Graph(err))),
328 }
329 }
330
331 self.states.push_back(State {
333 bounds: Bounds {
334 success: success.clone(),
335 failure: {
336 let mut failure_bounds = failure.clone();
337 failure_bounds.insert(node.clone());
338 match self.graph.simplify_failure_bounds(failure_bounds) {
339 Ok(bounds) => bounds,
340 Err(err) => return Some(Err(SearchError::Graph(err))),
341 }
342 },
343 },
344 statuses: {
345 let mut statuses = statuses.clone();
346 statuses.insert(node.clone(), Status::Failure);
347 statuses
348 },
349 });
350
351 self.states.push_back(State {
353 bounds: Bounds {
354 success: {
355 let mut success_bounds = success.clone();
356 success_bounds.insert(node.clone());
357 match self.graph.simplify_success_bounds(success_bounds) {
358 Ok(bounds) => bounds,
359 Err(err) => return Some(Err(SearchError::Graph(err))),
360 }
361 },
362 failure: failure.clone(),
363 },
364 statuses: {
365 let mut statuses = statuses.clone();
366 statuses.insert(node.clone(), Status::Success);
367 statuses
368 },
369 });
370
371 if self.seen.insert(node.clone()) {
372 return Some(Ok(node));
373 }
374 }
375 None
376 }
377 }
378
379 let initial_state = State {
380 bounds: Bounds {
381 success: success_bounds.clone(),
382 failure: failure_bounds.clone(),
383 },
384 statuses: self.nodes.clone(),
385 };
386 let iter = Iter {
387 graph: &self.graph,
388 strategy,
389 seen: Default::default(),
390 states: [initial_state].into_iter().collect(),
391 };
392
393 Ok(LazySolution {
394 bounds: Bounds {
395 success: success_bounds,
396 failure: failure_bounds,
397 },
398 next_to_search: Box::new(iter),
399 })
400 }
401
402 #[instrument]
404 pub fn notify(
405 &mut self,
406 node: G::Node,
407 status: Status,
408 ) -> Result<(), NotifyError<G::Node, G::Error>> {
409 match self.nodes.get(&node) {
410 Some(existing_status @ (Status::Success | Status::Failure))
411 if existing_status != &status =>
412 {
413 return Err(NotifyError::IllegalStateTransition {
414 node,
415 from: *existing_status,
416 to: status,
417 });
418 }
419 _ => {}
420 }
421
422 match status {
423 Status::Untested | Status::Indeterminate => {}
424
425 Status::Success => {
426 for failure_node in self.failure_bounds().map_err(NotifyError::Graph)? {
427 if self
428 .graph
429 .is_ancestor(failure_node.clone(), node.clone())
430 .map_err(NotifyError::Graph)?
431 {
432 return Err(NotifyError::InconsistentStateTransition {
433 ancestor_node: failure_node,
434 ancestor_status: Status::Failure,
435 descendant_node: node,
436 descendant_status: Status::Success,
437 });
438 }
439 }
440 }
441
442 Status::Failure => {
443 for success_node in self.success_bounds().map_err(NotifyError::Graph)? {
444 if self
445 .graph
446 .is_ancestor(node.clone(), success_node.clone())
447 .map_err(NotifyError::Graph)?
448 {
449 return Err(NotifyError::InconsistentStateTransition {
450 ancestor_node: node,
451 ancestor_status: Status::Failure,
452 descendant_node: success_node,
453 descendant_status: Status::Success,
454 });
455 }
456 }
457 }
458 }
459
460 self.nodes.insert(node, status);
461 Ok(())
462 }
463}