1use std::{cmp::Ordering, collections::*, mem::take};
2
3use crate::{Array, ArrayCmp, Boxed, Primitive, SigNode, Signature, Uiua, UiuaResult, Value};
4
5pub fn path(
6 neighbors: SigNode,
7 is_goal: SigNode,
8 heuristic: Option<SigNode>,
9 env: &mut Uiua,
10) -> UiuaResult {
11 path_impl(neighbors, is_goal, heuristic, PathMode::All, env)
12}
13
14pub fn path_first(
15 neighbors: SigNode,
16 is_goal: SigNode,
17 heuristic: Option<SigNode>,
18 env: &mut Uiua,
19) -> UiuaResult {
20 path_impl(neighbors, is_goal, heuristic, PathMode::First, env)
21}
22
23pub fn path_sign_len(
24 neighbors: SigNode,
25 is_goal: SigNode,
26 heuristic: Option<SigNode>,
27 env: &mut Uiua,
28) -> UiuaResult {
29 path_impl(neighbors, is_goal, heuristic, PathMode::Exists, env)
30}
31
32pub fn path_take(
33 neighbors: SigNode,
34 is_goal: SigNode,
35 heuristic: Option<SigNode>,
36 env: &mut Uiua,
37) -> UiuaResult {
38 let n = env
39 .pop("number of paths to take")?
40 .as_ints_or_infs(env, "Taken amount must be a list of integers or infinity")?;
41 if n.is_empty() {
42 path_impl(neighbors, is_goal, heuristic, PathMode::All, env)?;
43 } else {
44 match n.first() {
45 Some(Ok(n)) if *n >= 0 => path_impl(
46 neighbors,
47 is_goal,
48 heuristic,
49 PathMode::Take(*n as usize),
50 env,
51 )?,
52 _ => path_impl(neighbors, is_goal, heuristic, PathMode::All, env)?,
53 }
54 }
55 let path = env.pop("path")?.take_impl(&n, env)?;
56 env.push(path);
57 Ok(())
58}
59
60pub fn path_pop(
61 neighbors: SigNode,
62 is_goal: SigNode,
63 heuristic: Option<SigNode>,
64 env: &mut Uiua,
65) -> UiuaResult {
66 path_impl(neighbors, is_goal, heuristic, PathMode::CostOnly, env)
67}
68
/// Selects what `path_impl` computes and pushes, and when its search may stop.
#[derive(Clone, Copy, Debug)]
enum PathMode {
    /// Collect every shortest path (used by `path`).
    All,
    /// Stop at the first goal reached and build one path from it (used by `path_first`).
    First,
    /// Stop at the first goal reached and push whether one was found (used by `path_sign_len`).
    Exists,
    /// Push only the cost, when costs are used; no path value (used by `path_pop`).
    CostOnly,
    /// Build at most this many shortest paths (used by `path_take`).
    Take(usize),
}
77
78fn path_impl(
79 neighbors: SigNode,
80 is_goal: SigNode,
81 heuristic: Option<SigNode>,
82 mode: PathMode,
83 env: &mut Uiua,
84) -> UiuaResult {
85 let start = env.pop("start")?;
86 let nei_sig = neighbors.sig;
87 let heu_sig = heuristic
88 .as_ref()
89 .map(|h| h.sig)
90 .unwrap_or_else(|| Signature::new(0, 1));
91 let isg_sig = is_goal.sig;
92 for (name, sig, req_out) in &[
93 ("neighbors", nei_sig, [1, 2].as_slice()),
94 ("goal", isg_sig, &[1]),
95 ("heuristic", heu_sig, &[1]),
96 ] {
97 if !req_out.contains(&sig.outputs()) {
98 let count = if req_out.len() == 1 {
99 "1"
100 } else {
101 "either 1 or 2"
102 };
103 return Err(env.error(format!(
104 "{} {name} function must return {count} outputs \
105 but its signature is {sig}",
106 Primitive::Path.format()
107 )));
108 }
109 }
110 let has_costs = nei_sig.outputs() == 2;
111 let arg_count = nei_sig
112 .args()
113 .max(heu_sig.args())
114 .max(isg_sig.args())
115 .saturating_sub(1);
116 let mut args = Vec::with_capacity(arg_count);
117 for i in 0..arg_count {
118 args.push(env.pop(i + 1)?);
119 }
120
121 struct NodeCost {
122 node: usize,
123 cost: f64,
124 }
125 impl PartialEq for NodeCost {
126 fn eq(&self, other: &Self) -> bool {
127 self.cost.array_eq(&other.cost)
128 }
129 }
130 impl Eq for NodeCost {}
131 impl PartialOrd for NodeCost {
132 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
133 Some(self.cmp(other))
134 }
135 }
136 impl Ord for NodeCost {
137 fn cmp(&self, other: &Self) -> Ordering {
138 self.cost.array_cmp(&other.cost).reverse()
139 }
140 }
141
142 struct PathEnv<'a> {
143 env: &'a mut Uiua,
144 neighbors: SigNode,
145 is_goal: SigNode,
146 heuristic: Option<SigNode>,
147 args: Vec<Value>,
148 }
149
150 impl PathEnv<'_> {
151 fn heuristic(&mut self, node: &Value) -> UiuaResult<f64> {
152 Ok(if let Some(heuristic) = &self.heuristic {
153 let heu_args = heuristic.sig.args();
154 for arg in (self.args.iter()).take(heu_args.saturating_sub(1)).rev() {
155 self.env.push(arg.clone());
156 }
157 if heu_args > 0 {
158 self.env.push(node.clone());
159 }
160 self.env.exec(heuristic.clone())?;
161 let h = (self.env)
162 .pop("heuristic")?
163 .as_num(self.env, "Heuristic must be a number")?;
164 if h < 0.0 {
165 return Err(self
166 .env
167 .error("Negative heuristic values are not allowed in A*"));
168 }
169 h
170 } else {
171 0.0
172 })
173 }
174 fn neighbors(&mut self, node: &Value) -> UiuaResult<Vec<(Value, f64)>> {
175 let nei_args = self.neighbors.sig.args();
176 for arg in (self.args.iter()).take(nei_args.saturating_sub(1)).rev() {
177 self.env.push(arg.clone());
178 }
179 if nei_args > 0 {
180 self.env.push(node.clone());
181 }
182 self.env.exec(self.neighbors.clone())?;
183 let (nodes, costs) = if self.neighbors.sig.outputs() == 2 {
184 let costs = self.env.pop("neighbors costs")?;
185 let costs_rank = costs.rank();
186 let mut costs = costs
187 .as_nums(self.env, "Costs must be a number or list of numbers")?
188 .into_owned();
189 let nodes = self.env.pop("neighbors nodes")?;
190 if costs.len() != nodes.row_count() {
191 if costs_rank == 0 {
192 costs.resize(nodes.row_count(), costs[0]);
193 } else {
194 return Err(self.env.error(format!(
195 "Number of nodes {} does not match number of costs {}",
196 nodes.row_count(),
197 costs.len(),
198 )));
199 }
200 }
201 if costs.iter().any(|&c| c < 0.0) {
202 return Err(self.env.error("Negative costs are not allowed in A*"));
203 }
204 (nodes, costs)
205 } else {
206 let nodes = self.env.pop("neighbors nodes")?;
207 let costs = vec![1.0; nodes.row_count()];
208 (nodes, costs)
209 };
210 Ok(nodes.into_rows().zip(costs).collect())
211 }
212 fn is_goal(&mut self, node: &Value) -> UiuaResult<bool> {
213 let isg_args = self.is_goal.sig.args();
214 for arg in (self.args.iter()).take(isg_args.saturating_sub(1)).rev() {
215 self.env.push(arg.clone());
216 }
217 if isg_args > 0 {
218 self.env.push(node.clone());
219 }
220 self.env.exec(self.is_goal.clone())?;
221 let is_goal = (self.env.pop("is_goal")?)
222 .as_bool(self.env, "path goal function must return a boolean")?;
223 Ok(is_goal)
224 }
225 }
226
227 let mut if_empty = start.clone();
228 if_empty.fix();
229 if_empty = if_empty.first_dim_zero();
230 if_empty.shape.prepend(0);
231
232 let mut to_see = BinaryHeap::new();
234 let mut backing = vec![start.clone()];
235 let mut indices: HashMap<Value, usize> = [(start, 0)].into();
236 to_see.push(NodeCost { node: 0, cost: 0.0 });
237
238 let mut came_from: HashMap<usize, BTreeSet<usize>> = HashMap::new();
239 let mut full_cost: HashMap<usize, f64> = [(0, 0.0)].into();
240
241 let mut shortest_cost = f64::INFINITY;
242 let mut ends = BTreeSet::new();
243
244 fn count_paths(ends: &BTreeSet<usize>, came_from: &HashMap<usize, BTreeSet<usize>>) -> usize {
245 let mut queue = VecDeque::new();
246 let mut count = 0;
247 for &end in ends {
248 queue.clear();
249 queue.push_back(end);
250 while let Some(curr) = queue.pop_front() {
251 if let Some(parents) = came_from.get(&curr) {
252 for &parent in parents {
253 queue.push_back(parent);
254 }
255 } else {
256 count += 1;
257 }
258 }
259 }
260 count
261 }
262
263 env.without_fill(|env| -> UiuaResult {
265 let mut env = PathEnv {
266 env,
267 neighbors,
268 heuristic,
269 is_goal,
270 args,
271 };
272
273 'outer: while let Some(NodeCost { node: curr, .. }) = to_see.pop() {
274 env.env.respect_execution_limit()?;
275 let curr_cost = full_cost[&curr];
276 if curr_cost > shortest_cost || ends.contains(&curr) {
278 continue;
279 }
280 if env.is_goal(&backing[curr])? {
282 ends.insert(curr);
283 shortest_cost = curr_cost;
284 match mode {
285 PathMode::All => continue,
286 PathMode::Take(n) if n <= 1 => break,
287 PathMode::Take(n) if count_paths(&ends, &came_from) >= n => break,
288 _ => break,
289 }
290 }
291 for (nei, nei_cost) in env.neighbors(&backing[curr])? {
293 let nei = if let Some(index) = indices.get(&nei) {
295 *index
296 } else {
297 let index = backing.len();
298 indices.insert(nei.clone(), index);
299 backing.push(nei);
300 index
301 };
302 let from_curr_nei_cost = curr_cost + nei_cost;
303 let curr_nei_cost = full_cost.get(&nei).copied().unwrap_or(f64::INFINITY);
304 if from_curr_nei_cost <= curr_nei_cost {
305 if let PathMode::Take(n) = mode {
306 if ends.contains(&nei) && count_paths(&ends, &came_from) >= n {
307 break 'outer;
308 }
309 }
310 let parents = came_from.entry(nei).or_default();
311 if from_curr_nei_cost < curr_nei_cost {
313 parents.clear();
315 full_cost.insert(nei, from_curr_nei_cost);
317 to_see.push(NodeCost {
319 cost: from_curr_nei_cost + env.heuristic(&backing[nei])?,
320 node: nei,
321 });
322 }
323 parents.insert(curr);
324 }
325 }
326 }
327 Ok(())
328 })?;
329
330 if has_costs {
331 env.push(shortest_cost);
332 }
333
334 let make_path = |path: Vec<usize>| {
335 if let Some(&[a, b]) = path
336 .windows(2)
337 .find(|w| backing[w[0]].shape != backing[w[1]].shape)
338 {
339 return Err(env.error(format!(
340 "Cannot make path from nodes with incompatible shapes {} and {}",
341 backing[a].shape, backing[b].shape
342 )));
343 }
344 let path: Vec<_> = path.into_iter().map(|i| backing[i].clone()).collect();
345 Value::from_row_values(path, env)
346 };
347
348 match mode {
349 PathMode::All | PathMode::Take(_) => {
350 let mut paths = Vec::new();
351 'outer: for end in ends {
352 let mut currs = vec![vec![end]];
353 let mut these_paths = Vec::new();
354 while !currs.is_empty() {
355 env.respect_execution_limit()?;
356 let mut new_paths = Vec::new();
357 currs.retain_mut(|path| {
358 let parents = came_from.get(path.last().unwrap());
359 match parents.map(|p| p.len()).unwrap_or(0) {
360 0 => {
361 these_paths.push(take(path));
362 false
363 }
364 1 => {
365 path.push(*parents.unwrap().iter().next().unwrap());
366 true
367 }
368 _ => {
369 for &parent in parents.unwrap().iter().skip(1) {
370 let mut path = path.clone();
371 path.push(parent);
372 new_paths.push(path);
373 }
374 path.push(*parents.unwrap().iter().next().unwrap());
375 true
376 }
377 }
378 });
379 currs.extend(new_paths);
380 }
381 for mut path in these_paths {
382 path.reverse();
383 let path_val = make_path(path)?;
384 paths.push(if has_costs {
385 Boxed(path_val).into()
386 } else {
387 path_val
388 });
389 if let PathMode::Take(n) = mode {
390 if paths.len() >= n {
391 break 'outer;
392 }
393 }
394 }
395 }
396 let path_count = paths.len();
397 let mut paths_val = Value::from_row_values(paths, env)?;
398 if path_count == 0 {
399 paths_val = if has_costs {
400 Array::<Boxed>::default().into()
401 } else {
402 if_empty
403 }
404 } else if let PathMode::Take(0) = mode {
405 if paths_val.row_count() > 0 {
406 paths_val.drop_n(1);
407 }
408 }
409 env.push(paths_val);
410 }
411 PathMode::First => {
412 if let Some(mut curr) = ends.into_iter().next() {
413 let mut path = vec![curr];
414 while let Some(from) = came_from.get(&curr) {
415 let from = *from.iter().next().unwrap();
416 path.push(from);
417 curr = from;
418 }
419 path.reverse();
420 let path_val = make_path(path)?;
421 env.push(if has_costs {
422 Boxed(path_val).into()
423 } else {
424 path_val
425 });
426 } else if let Some(val) = env.value_fill().map(|fv| &fv.value) {
427 if val.rank() == 0 {
428 env.push(if has_costs {
429 Array::<Boxed>::default().into()
430 } else {
431 Value::default()
432 });
433 } else {
434 return Err(env.error("No path found. A fill is set, but it is not scalar."));
435 }
436 } else {
437 return Err(env.error("No path found"));
438 }
439 }
440 PathMode::Exists => env.push(!ends.is_empty()),
441 PathMode::CostOnly => {}
442 }
443 Ok(())
444}