uiua/algorithm/
path.rs

1use std::{cmp::Ordering, collections::*, mem::take};
2
3use crate::{Array, ArrayCmp, Boxed, Primitive, SigNode, Signature, Uiua, UiuaResult, Value};
4
5pub fn path(
6    neighbors: SigNode,
7    is_goal: SigNode,
8    heuristic: Option<SigNode>,
9    env: &mut Uiua,
10) -> UiuaResult {
11    path_impl(neighbors, is_goal, heuristic, PathMode::All, env)
12}
13
14pub fn path_first(
15    neighbors: SigNode,
16    is_goal: SigNode,
17    heuristic: Option<SigNode>,
18    env: &mut Uiua,
19) -> UiuaResult {
20    path_impl(neighbors, is_goal, heuristic, PathMode::First, env)
21}
22
23pub fn path_sign_len(
24    neighbors: SigNode,
25    is_goal: SigNode,
26    heuristic: Option<SigNode>,
27    env: &mut Uiua,
28) -> UiuaResult {
29    path_impl(neighbors, is_goal, heuristic, PathMode::Exists, env)
30}
31
32pub fn path_take(
33    neighbors: SigNode,
34    is_goal: SigNode,
35    heuristic: Option<SigNode>,
36    env: &mut Uiua,
37) -> UiuaResult {
38    let n = env
39        .pop("number of paths to take")?
40        .as_ints_or_infs(env, "Taken amount must be a list of integers or infinity")?;
41    if n.is_empty() {
42        path_impl(neighbors, is_goal, heuristic, PathMode::All, env)?;
43    } else {
44        match n.first() {
45            Some(Ok(n)) if *n >= 0 => path_impl(
46                neighbors,
47                is_goal,
48                heuristic,
49                PathMode::Take(*n as usize),
50                env,
51            )?,
52            _ => path_impl(neighbors, is_goal, heuristic, PathMode::All, env)?,
53        }
54    }
55    let path = env.pop("path")?.take_impl(&n, env)?;
56    env.push(path);
57    Ok(())
58}
59
60pub fn path_pop(
61    neighbors: SigNode,
62    is_goal: SigNode,
63    heuristic: Option<SigNode>,
64    env: &mut Uiua,
65) -> UiuaResult {
66    path_impl(neighbors, is_goal, heuristic, PathMode::CostOnly, env)
67}
68
/// Selects what `path_impl` produces once its search finishes.
#[derive(Clone, Copy, Debug)]
enum PathMode {
    /// Every shortest path to any goal.
    All,
    /// One shortest path; errors when none exists unless a scalar fill is set.
    First,
    /// Only whether any goal was reached.
    Exists,
    /// No paths at all; just the shortest cost, when costs are returned.
    CostOnly,
    /// At most this many shortest paths; the search may stop early once
    /// enough have been found.
    Take(usize),
}
77
78fn path_impl(
79    neighbors: SigNode,
80    is_goal: SigNode,
81    heuristic: Option<SigNode>,
82    mode: PathMode,
83    env: &mut Uiua,
84) -> UiuaResult {
85    let start = env.pop("start")?;
86    let nei_sig = neighbors.sig;
87    let heu_sig = heuristic
88        .as_ref()
89        .map(|h| h.sig)
90        .unwrap_or_else(|| Signature::new(0, 1));
91    let isg_sig = is_goal.sig;
92    for (name, sig, req_out) in &[
93        ("neighbors", nei_sig, [1, 2].as_slice()),
94        ("goal", isg_sig, &[1]),
95        ("heuristic", heu_sig, &[1]),
96    ] {
97        if !req_out.contains(&sig.outputs()) {
98            let count = if req_out.len() == 1 {
99                "1"
100            } else {
101                "either 1 or 2"
102            };
103            return Err(env.error(format!(
104                "{} {name} function must return {count} outputs \
105                but its signature is {sig}",
106                Primitive::Path.format()
107            )));
108        }
109    }
110    let has_costs = nei_sig.outputs() == 2;
111    let arg_count = nei_sig
112        .args()
113        .max(heu_sig.args())
114        .max(isg_sig.args())
115        .saturating_sub(1);
116    let mut args = Vec::with_capacity(arg_count);
117    for i in 0..arg_count {
118        args.push(env.pop(i + 1)?);
119    }
120
121    struct NodeCost {
122        node: usize,
123        cost: f64,
124    }
125    impl PartialEq for NodeCost {
126        fn eq(&self, other: &Self) -> bool {
127            self.cost.array_eq(&other.cost)
128        }
129    }
130    impl Eq for NodeCost {}
131    impl PartialOrd for NodeCost {
132        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
133            Some(self.cmp(other))
134        }
135    }
136    impl Ord for NodeCost {
137        fn cmp(&self, other: &Self) -> Ordering {
138            self.cost.array_cmp(&other.cost).reverse()
139        }
140    }
141
142    struct PathEnv<'a> {
143        env: &'a mut Uiua,
144        neighbors: SigNode,
145        is_goal: SigNode,
146        heuristic: Option<SigNode>,
147        args: Vec<Value>,
148    }
149
150    impl PathEnv<'_> {
151        fn heuristic(&mut self, node: &Value) -> UiuaResult<f64> {
152            Ok(if let Some(heuristic) = &self.heuristic {
153                let heu_args = heuristic.sig.args();
154                for arg in (self.args.iter()).take(heu_args.saturating_sub(1)).rev() {
155                    self.env.push(arg.clone());
156                }
157                if heu_args > 0 {
158                    self.env.push(node.clone());
159                }
160                self.env.exec(heuristic.clone())?;
161                let h = (self.env)
162                    .pop("heuristic")?
163                    .as_num(self.env, "Heuristic must be a number")?;
164                if h < 0.0 {
165                    return Err(self
166                        .env
167                        .error("Negative heuristic values are not allowed in A*"));
168                }
169                h
170            } else {
171                0.0
172            })
173        }
174        fn neighbors(&mut self, node: &Value) -> UiuaResult<Vec<(Value, f64)>> {
175            let nei_args = self.neighbors.sig.args();
176            for arg in (self.args.iter()).take(nei_args.saturating_sub(1)).rev() {
177                self.env.push(arg.clone());
178            }
179            if nei_args > 0 {
180                self.env.push(node.clone());
181            }
182            self.env.exec(self.neighbors.clone())?;
183            let (nodes, costs) = if self.neighbors.sig.outputs() == 2 {
184                let costs = self.env.pop("neighbors costs")?;
185                let costs_rank = costs.rank();
186                let mut costs = costs
187                    .as_nums(self.env, "Costs must be a number or list of numbers")?
188                    .into_owned();
189                let nodes = self.env.pop("neighbors nodes")?;
190                if costs.len() != nodes.row_count() {
191                    if costs_rank == 0 {
192                        costs.resize(nodes.row_count(), costs[0]);
193                    } else {
194                        return Err(self.env.error(format!(
195                            "Number of nodes {} does not match number of costs {}",
196                            nodes.row_count(),
197                            costs.len(),
198                        )));
199                    }
200                }
201                if costs.iter().any(|&c| c < 0.0) {
202                    return Err(self.env.error("Negative costs are not allowed in A*"));
203                }
204                (nodes, costs)
205            } else {
206                let nodes = self.env.pop("neighbors nodes")?;
207                let costs = vec![1.0; nodes.row_count()];
208                (nodes, costs)
209            };
210            Ok(nodes.into_rows().zip(costs).collect())
211        }
212        fn is_goal(&mut self, node: &Value) -> UiuaResult<bool> {
213            let isg_args = self.is_goal.sig.args();
214            for arg in (self.args.iter()).take(isg_args.saturating_sub(1)).rev() {
215                self.env.push(arg.clone());
216            }
217            if isg_args > 0 {
218                self.env.push(node.clone());
219            }
220            self.env.exec(self.is_goal.clone())?;
221            let is_goal = (self.env.pop("is_goal")?)
222                .as_bool(self.env, "path goal function must return a boolean")?;
223            Ok(is_goal)
224        }
225    }
226
227    let mut if_empty = start.clone();
228    if_empty.fix();
229    if_empty = if_empty.first_dim_zero();
230    if_empty.shape.prepend(0);
231
232    // Initialize state
233    let mut to_see = BinaryHeap::new();
234    let mut backing = vec![start.clone()];
235    let mut indices: HashMap<Value, usize> = [(start, 0)].into();
236    to_see.push(NodeCost { node: 0, cost: 0.0 });
237
238    let mut came_from: HashMap<usize, BTreeSet<usize>> = HashMap::new();
239    let mut full_cost: HashMap<usize, f64> = [(0, 0.0)].into();
240
241    let mut shortest_cost = f64::INFINITY;
242    let mut ends = BTreeSet::new();
243
244    fn count_paths(ends: &BTreeSet<usize>, came_from: &HashMap<usize, BTreeSet<usize>>) -> usize {
245        let mut queue = VecDeque::new();
246        let mut count = 0;
247        for &end in ends {
248            queue.clear();
249            queue.push_back(end);
250            while let Some(curr) = queue.pop_front() {
251                if let Some(parents) = came_from.get(&curr) {
252                    for &parent in parents {
253                        queue.push_back(parent);
254                    }
255                } else {
256                    count += 1;
257                }
258            }
259        }
260        count
261    }
262
263    // Main pathing loop
264    env.without_fill(|env| -> UiuaResult {
265        let mut env = PathEnv {
266            env,
267            neighbors,
268            heuristic,
269            is_goal,
270            args,
271        };
272
273        'outer: while let Some(NodeCost { node: curr, .. }) = to_see.pop() {
274            env.env.respect_execution_limit()?;
275            let curr_cost = full_cost[&curr];
276            // Early exit if found a shorter path
277            if curr_cost > shortest_cost || ends.contains(&curr) {
278                continue;
279            }
280            // Check if reached a goal
281            if env.is_goal(&backing[curr])? {
282                ends.insert(curr);
283                shortest_cost = curr_cost;
284                match mode {
285                    PathMode::All => continue,
286                    PathMode::Take(n) if n <= 1 => break,
287                    PathMode::Take(n) if count_paths(&ends, &came_from) >= n => break,
288                    _ => break,
289                }
290            }
291            // Check neighbors
292            for (nei, nei_cost) in env.neighbors(&backing[curr])? {
293                // Add to backing if needed
294                let nei = if let Some(index) = indices.get(&nei) {
295                    *index
296                } else {
297                    let index = backing.len();
298                    indices.insert(nei.clone(), index);
299                    backing.push(nei);
300                    index
301                };
302                let from_curr_nei_cost = curr_cost + nei_cost;
303                let curr_nei_cost = full_cost.get(&nei).copied().unwrap_or(f64::INFINITY);
304                if from_curr_nei_cost <= curr_nei_cost {
305                    if let PathMode::Take(n) = mode {
306                        if ends.contains(&nei) && count_paths(&ends, &came_from) >= n {
307                            break 'outer;
308                        }
309                    }
310                    let parents = came_from.entry(nei).or_default();
311                    // If a better path was found we...
312                    if from_curr_nei_cost < curr_nei_cost {
313                        // 1. Clear the parents
314                        parents.clear();
315                        // 2. Update the known cost
316                        full_cost.insert(nei, from_curr_nei_cost);
317                        // 3. Add to to see
318                        to_see.push(NodeCost {
319                            cost: from_curr_nei_cost + env.heuristic(&backing[nei])?,
320                            node: nei,
321                        });
322                    }
323                    parents.insert(curr);
324                }
325            }
326        }
327        Ok(())
328    })?;
329
330    if has_costs {
331        env.push(shortest_cost);
332    }
333
334    let make_path = |path: Vec<usize>| {
335        if let Some(&[a, b]) = path
336            .windows(2)
337            .find(|w| backing[w[0]].shape != backing[w[1]].shape)
338        {
339            return Err(env.error(format!(
340                "Cannot make path from nodes with incompatible shapes {} and {}",
341                backing[a].shape, backing[b].shape
342            )));
343        }
344        let path: Vec<_> = path.into_iter().map(|i| backing[i].clone()).collect();
345        Value::from_row_values(path, env)
346    };
347
348    match mode {
349        PathMode::All | PathMode::Take(_) => {
350            let mut paths = Vec::new();
351            'outer: for end in ends {
352                let mut currs = vec![vec![end]];
353                let mut these_paths = Vec::new();
354                while !currs.is_empty() {
355                    env.respect_execution_limit()?;
356                    let mut new_paths = Vec::new();
357                    currs.retain_mut(|path| {
358                        let parents = came_from.get(path.last().unwrap());
359                        match parents.map(|p| p.len()).unwrap_or(0) {
360                            0 => {
361                                these_paths.push(take(path));
362                                false
363                            }
364                            1 => {
365                                path.push(*parents.unwrap().iter().next().unwrap());
366                                true
367                            }
368                            _ => {
369                                for &parent in parents.unwrap().iter().skip(1) {
370                                    let mut path = path.clone();
371                                    path.push(parent);
372                                    new_paths.push(path);
373                                }
374                                path.push(*parents.unwrap().iter().next().unwrap());
375                                true
376                            }
377                        }
378                    });
379                    currs.extend(new_paths);
380                }
381                for mut path in these_paths {
382                    path.reverse();
383                    let path_val = make_path(path)?;
384                    paths.push(if has_costs {
385                        Boxed(path_val).into()
386                    } else {
387                        path_val
388                    });
389                    if let PathMode::Take(n) = mode {
390                        if paths.len() >= n {
391                            break 'outer;
392                        }
393                    }
394                }
395            }
396            let path_count = paths.len();
397            let mut paths_val = Value::from_row_values(paths, env)?;
398            if path_count == 0 {
399                paths_val = if has_costs {
400                    Array::<Boxed>::default().into()
401                } else {
402                    if_empty
403                }
404            } else if let PathMode::Take(0) = mode {
405                if paths_val.row_count() > 0 {
406                    paths_val.drop_n(1);
407                }
408            }
409            env.push(paths_val);
410        }
411        PathMode::First => {
412            if let Some(mut curr) = ends.into_iter().next() {
413                let mut path = vec![curr];
414                while let Some(from) = came_from.get(&curr) {
415                    let from = *from.iter().next().unwrap();
416                    path.push(from);
417                    curr = from;
418                }
419                path.reverse();
420                let path_val = make_path(path)?;
421                env.push(if has_costs {
422                    Boxed(path_val).into()
423                } else {
424                    path_val
425                });
426            } else if let Some(val) = env.value_fill().map(|fv| &fv.value) {
427                if val.rank() == 0 {
428                    env.push(if has_costs {
429                        Array::<Boxed>::default().into()
430                    } else {
431                        Value::default()
432                    });
433                } else {
434                    return Err(env.error("No path found. A fill is set, but it is not scalar."));
435                }
436            } else {
437                return Err(env.error("No path found"));
438            }
439        }
440        PathMode::Exists => env.push(!ends.is_empty()),
441        PathMode::CostOnly => {}
442    }
443    Ok(())
444}