1use std::collections::{HashSet, VecDeque};
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use indexmap::IndexMap;
9use tracing::{debug, instrument};
10
11pub trait Graph: Debug {
13 type Node: Clone + Debug + Hash + Eq;
15
16 type Error: std::error::Error;
18
19 fn is_ancestor(
27 &self,
28 ancestor: Self::Node,
29 descendant: Self::Node,
30 ) -> Result<bool, Self::Error>;
31
32 #[instrument]
39 fn simplify_success_bounds(
40 &self,
41 nodes: HashSet<Self::Node>,
42 ) -> Result<HashSet<Self::Node>, Self::Error> {
43 Ok(nodes)
44 }
45
46 #[instrument]
53 fn simplify_failure_bounds(
54 &self,
55 nodes: HashSet<Self::Node>,
56 ) -> Result<HashSet<Self::Node>, Self::Error> {
57 Ok(nodes)
58 }
59}
60
/// The recorded outcome of testing a single node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Status {
    /// Not tested yet. Every node handed to `Search::new` starts here.
    Untested,

    /// The node was tested and succeeded. The search also treats ancestors
    /// of a successful node as already classified successes.
    Success,

    /// The node was tested and failed. The search also treats descendants of
    /// a failed node as already classified failures.
    Failure,

    /// The node was tested without a usable verdict (presumably; confirm).
    /// Such nodes contribute to neither the success nor the failure bounds.
    Indeterminate,
}
81
/// The known limits of a search: which nodes succeeded and which failed.
#[derive(PartialEq, Eq, Debug)]
pub struct Bounds<Node>
where
    Node: Debug + Eq + Hash,
{
    /// Nodes classified as successful. The search treats ancestors of these
    /// nodes as successful too.
    pub success: HashSet<Node>,

    /// Nodes classified as failed. The search treats descendants of these
    /// nodes as failed too.
    pub failure: HashSet<Node>,
}
93
94impl<Node: Debug + Eq + Hash> Default for Bounds<Node> {
95 fn default() -> Self {
96 Bounds {
97 success: Default::default(),
98 failure: Default::default(),
99 }
100 }
101}
102
103pub trait Strategy<G: Graph>: Debug {
105 type Error: std::error::Error;
107
108 fn midpoint(
122 &self,
123 graph: &G,
124 bounds: &Bounds<G::Node>,
125 statuses: &IndexMap<G::Node, Status>,
126 ) -> Result<Option<G::Node>, Self::Error>;
127}
128
129pub struct LazySolution<'a, TNode: Debug + Eq + Hash + 'a, TError> {
131 pub bounds: Bounds<TNode>,
133
134 pub next_to_search: Box<dyn Iterator<Item = Result<TNode, TError>> + 'a>,
142}
143
144impl<'a, TNode: Debug + Eq + Hash + 'a, TError> LazySolution<'a, TNode, TError> {
145 pub fn into_eager(self) -> Result<EagerSolution<TNode>, TError> {
147 let LazySolution {
148 bounds,
149 next_to_search,
150 } = self;
151 Ok(EagerSolution {
152 bounds,
153 next_to_search: next_to_search.collect::<Result<Vec<_>, TError>>()?,
154 })
155 }
156}
157
158#[derive(Debug, Eq, PartialEq)]
161pub struct EagerSolution<Node: Debug + Hash + Eq> {
162 pub(crate) bounds: Bounds<Node>,
163 pub(crate) next_to_search: Vec<Node>,
164}
165
166#[allow(missing_docs)]
167#[derive(Debug, thiserror::Error)]
168pub enum SearchError<TNode, TGraphError, TStrategyError> {
169 #[error("node {node:?} has already been classified as a {status:?} node, but was returned as a new midpoint to search; this would loop indefinitely")]
170 AlreadySearchedMidpoint { node: TNode, status: Status },
171
172 #[error(transparent)]
173 Graph(TGraphError),
174
175 #[error(transparent)]
176 Strategy(TStrategyError),
177}
178
179#[allow(missing_docs)]
181#[derive(Debug, thiserror::Error)]
182pub enum NotifyError<TNode, TGraphError> {
183 #[error("inconsistent state transition: {ancestor_node:?} ({ancestor_status:?}) was marked as an ancestor of {descendant_node:?} ({descendant_status:?}")]
184 InconsistentStateTransition {
185 ancestor_node: TNode,
186 ancestor_status: Status,
187 descendant_node: TNode,
188 descendant_status: Status,
189 },
190
191 #[error("illegal state transition for {node:?}: {from:?} -> {to:?}")]
192 IllegalStateTransition {
193 node: TNode,
194 from: Status,
195 to: Status,
196 },
197
198 #[error(transparent)]
199 Graph(TGraphError),
200}
201
202#[derive(Clone, Debug)]
204pub struct Search<G: Graph> {
205 graph: G,
206 nodes: IndexMap<G::Node, Status>,
207}
208
209impl<G: Graph> Search<G> {
210 pub fn new(graph: G, search_nodes: impl IntoIterator<Item = G::Node>) -> Self {
219 let nodes = search_nodes
220 .into_iter()
221 .map(|node| (node, Status::Untested))
222 .collect();
223 Self { graph, nodes }
224 }
225
226 #[instrument]
230 pub fn success_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
231 let success_nodes = self
232 .nodes
233 .iter()
234 .filter_map(|(node, status)| match status {
235 Status::Success => Some(node.clone()),
236 Status::Untested | Status::Failure | Status::Indeterminate => None,
237 })
238 .collect::<HashSet<_>>();
239 let success_bounds = self.graph.simplify_success_bounds(success_nodes)?;
240 Ok(success_bounds)
241 }
242
243 #[instrument]
247 pub fn failure_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
248 let failure_nodes = self
249 .nodes
250 .iter()
251 .filter_map(|(node, status)| match status {
252 Status::Failure => Some(node.clone()),
253 Status::Untested | Status::Success | Status::Indeterminate => None,
254 })
255 .collect::<HashSet<_>>();
256 let failure_bounds = self.graph.simplify_failure_bounds(failure_nodes)?;
257 Ok(failure_bounds)
258 }
259
260 #[instrument]
263 #[allow(clippy::type_complexity)]
264 pub fn search<'a, S: Strategy<G>>(
265 &'a self,
266 strategy: &'a S,
267 ) -> Result<
268 LazySolution<G::Node, SearchError<G::Node, G::Error, S::Error>>,
269 SearchError<G::Node, G::Error, S::Error>,
270 > {
271 let success_bounds = self.success_bounds().map_err(SearchError::Graph)?;
272 let failure_bounds = self.failure_bounds().map_err(SearchError::Graph)?;
273
274 #[derive(Debug)]
275 struct State<G: Graph> {
276 bounds: Bounds<G::Node>,
277 statuses: IndexMap<G::Node, Status>,
278 }
279
280 struct Iter<'a, G: Graph, S: Strategy<G>> {
281 graph: &'a G,
282 strategy: &'a S,
283 seen: HashSet<G::Node>,
284 states: VecDeque<State<G>>,
285 }
286
287 impl<'a, G: Graph, S: Strategy<G>> Iterator for Iter<'a, G, S> {
288 type Item = Result<G::Node, SearchError<G::Node, G::Error, S::Error>>;
289
290 fn next(&mut self) -> Option<Self::Item> {
291 while let Some(state) = self.states.pop_front() {
292 debug!(?state, "Popped speculation state");
293 let State { bounds, statuses } = state;
294
295 let node = match self.strategy.midpoint(self.graph, &bounds, &statuses) {
296 Ok(Some(node)) => node,
297 Ok(None) => continue,
298 Err(err) => return Some(Err(SearchError::Strategy(err))),
299 };
300
301 let Bounds { success, failure } = bounds;
302 for success_node in success.iter() {
303 match self.graph.is_ancestor(node.clone(), success_node.clone()) {
304 Ok(true) => {
305 return Some(Err(SearchError::AlreadySearchedMidpoint {
306 node,
307 status: Status::Success,
308 }));
309 }
310 Ok(false) => (),
311 Err(err) => return Some(Err(SearchError::Graph(err))),
312 }
313 }
314 for failure_node in failure.iter() {
315 match self.graph.is_ancestor(failure_node.clone(), node.clone()) {
316 Ok(true) => {
317 return Some(Err(SearchError::AlreadySearchedMidpoint {
318 node,
319 status: Status::Failure,
320 }));
321 }
322 Ok(false) => (),
323 Err(err) => return Some(Err(SearchError::Graph(err))),
324 }
325 }
326
327 self.states.push_back(State {
329 bounds: Bounds {
330 success: success.clone(),
331 failure: {
332 let mut failure_bounds = failure.clone();
333 failure_bounds.insert(node.clone());
334 match self.graph.simplify_failure_bounds(failure_bounds) {
335 Ok(bounds) => bounds,
336 Err(err) => return Some(Err(SearchError::Graph(err))),
337 }
338 },
339 },
340 statuses: {
341 let mut statuses = statuses.clone();
342 statuses.insert(node.clone(), Status::Failure);
343 statuses
344 },
345 });
346
347 self.states.push_back(State {
349 bounds: Bounds {
350 success: {
351 let mut success_bounds = success.clone();
352 success_bounds.insert(node.clone());
353 match self.graph.simplify_success_bounds(success_bounds) {
354 Ok(bounds) => bounds,
355 Err(err) => return Some(Err(SearchError::Graph(err))),
356 }
357 },
358 failure: failure.clone(),
359 },
360 statuses: {
361 let mut statuses = statuses.clone();
362 statuses.insert(node.clone(), Status::Success);
363 statuses
364 },
365 });
366
367 if self.seen.insert(node.clone()) {
368 return Some(Ok(node));
369 }
370 }
371 None
372 }
373 }
374
375 let initial_state = State {
376 bounds: Bounds {
377 success: success_bounds.clone(),
378 failure: failure_bounds.clone(),
379 },
380 statuses: self.nodes.clone(),
381 };
382 let iter = Iter {
383 graph: &self.graph,
384 strategy,
385 seen: Default::default(),
386 states: [initial_state].into_iter().collect(),
387 };
388
389 Ok(LazySolution {
390 bounds: Bounds {
391 success: success_bounds,
392 failure: failure_bounds,
393 },
394 next_to_search: Box::new(iter),
395 })
396 }
397
398 #[instrument]
400 pub fn notify(
401 &mut self,
402 node: G::Node,
403 status: Status,
404 ) -> Result<(), NotifyError<G::Node, G::Error>> {
405 match self.nodes.get(&node) {
406 Some(existing_status @ (Status::Success | Status::Failure))
407 if existing_status != &status =>
408 {
409 return Err(NotifyError::IllegalStateTransition {
410 node,
411 from: *existing_status,
412 to: status,
413 })
414 }
415 _ => {}
416 }
417
418 match status {
419 Status::Untested | Status::Indeterminate => {}
420
421 Status::Success => {
422 for failure_node in self.failure_bounds().map_err(NotifyError::Graph)? {
423 if self
424 .graph
425 .is_ancestor(failure_node.clone(), node.clone())
426 .map_err(NotifyError::Graph)?
427 {
428 return Err(NotifyError::InconsistentStateTransition {
429 ancestor_node: failure_node,
430 ancestor_status: Status::Failure,
431 descendant_node: node,
432 descendant_status: Status::Success,
433 });
434 }
435 }
436 }
437
438 Status::Failure => {
439 for success_node in self.success_bounds().map_err(NotifyError::Graph)? {
440 if self
441 .graph
442 .is_ancestor(node.clone(), success_node.clone())
443 .map_err(NotifyError::Graph)?
444 {
445 return Err(NotifyError::InconsistentStateTransition {
446 ancestor_node: node,
447 ancestor_status: Status::Failure,
448 descendant_node: success_node,
449 descendant_status: Status::Success,
450 });
451 }
452 }
453 }
454 }
455
456 self.nodes.insert(node, status);
457 Ok(())
458 }
459}