wdl_engine/
eval.rs

1//! Module for evaluation.
2
3use std::borrow::Cow;
4use std::collections::BTreeMap;
5use std::collections::HashMap;
6use std::fs;
7use std::io::BufRead;
8use std::path::Component;
9use std::path::Path;
10use std::path::PathBuf;
11
12use anyhow::Context;
13use anyhow::Result;
14use anyhow::bail;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use rev_buf_reader::RevBufReader;
18use wdl_analysis::Document;
19use wdl_analysis::document::Task;
20use wdl_analysis::types::Type;
21use wdl_ast::Diagnostic;
22use wdl_ast::Span;
23use wdl_ast::SupportedVersion;
24use wdl_ast::v1::TASK_REQUIREMENT_RETURN_CODES;
25use wdl_ast::v1::TASK_REQUIREMENT_RETURN_CODES_ALIAS;
26
27use crate::CompoundValue;
28use crate::Outputs;
29use crate::PrimitiveValue;
30use crate::TaskExecutionResult;
31use crate::Value;
32use crate::http::Downloader;
33use crate::http::Location;
34use crate::path::EvaluationPath;
35use crate::stdlib::download_file;
36
37pub mod v1;
38
/// Upper bound on the number of trailing stderr lines that are included in
/// an error message.
const MAX_STDERR_LINES: usize = 10;
41
42/// Represents the location of a call in an evaluation error.
43#[derive(Debug, Clone)]
44pub struct CallLocation {
45    /// The document containing the call statement.
46    pub document: Document,
47    /// The span of the call statement.
48    pub span: Span,
49}
50
51/// Represents an error that originates from WDL source.
52#[derive(Debug)]
53pub struct SourceError {
54    /// The document originating the diagnostic.
55    pub document: Document,
56    /// The evaluation diagnostic.
57    pub diagnostic: Diagnostic,
58    /// The call backtrace for the error.
59    ///
60    /// An empty backtrace denotes that the error was encountered outside of
61    /// a call.
62    ///
63    /// The call locations are stored as most recent to least recent.
64    pub backtrace: Vec<CallLocation>,
65}
66
67/// Represents an error that may occur when evaluating a workflow or task.
68#[derive(Debug)]
69pub enum EvaluationError {
70    /// The error came from WDL source evaluation.
71    Source(Box<SourceError>),
72    /// The error came from another source.
73    Other(anyhow::Error),
74}
75
76impl EvaluationError {
77    /// Creates a new evaluation error from the given document and diagnostic.
78    pub fn new(document: Document, diagnostic: Diagnostic) -> Self {
79        Self::Source(Box::new(SourceError {
80            document,
81            diagnostic,
82            backtrace: Default::default(),
83        }))
84    }
85
86    /// Helper for tests for converting an evaluation error to a string.
87    #[cfg(feature = "codespan-reporting")]
88    #[allow(clippy::inherent_to_string)]
89    pub fn to_string(&self) -> String {
90        use codespan_reporting::diagnostic::Label;
91        use codespan_reporting::diagnostic::LabelStyle;
92        use codespan_reporting::files::SimpleFiles;
93        use codespan_reporting::term::Config;
94        use codespan_reporting::term::termcolor::Buffer;
95        use codespan_reporting::term::{self};
96        use wdl_ast::AstNode;
97
98        match self {
99            Self::Source(e) => {
100                let mut files = SimpleFiles::new();
101                let mut map = HashMap::new();
102
103                let file_id = files.add(e.document.path(), e.document.root().text().to_string());
104
105                let diagnostic =
106                    e.diagnostic
107                        .to_codespan(file_id)
108                        .with_labels_iter(e.backtrace.iter().map(|l| {
109                            let id = l.document.id();
110                            let file_id = *map.entry(id).or_insert_with(|| {
111                                files.add(l.document.path(), l.document.root().text().to_string())
112                            });
113
114                            Label {
115                                style: LabelStyle::Secondary,
116                                file_id,
117                                range: l.span.start()..l.span.end(),
118                                message: "called from this location".into(),
119                            }
120                        }));
121
122                let mut buffer = Buffer::no_color();
123                term::emit(&mut buffer, &Config::default(), &files, &diagnostic)
124                    .expect("failed to emit diagnostic");
125
126                String::from_utf8(buffer.into_inner()).expect("should be UTF-8")
127            }
128            Self::Other(e) => format!("{e:?}"),
129        }
130    }
131}
132
133impl From<anyhow::Error> for EvaluationError {
134    fn from(e: anyhow::Error) -> Self {
135        Self::Other(e)
136    }
137}
138
139/// Represents a result from evaluating a workflow or task.
140pub type EvaluationResult<T> = Result<T, EvaluationError>;
141
142/// Represents context to an expression evaluator.
143pub trait EvaluationContext: Send + Sync {
144    /// Gets the supported version of the document being evaluated.
145    fn version(&self) -> SupportedVersion;
146
147    /// Gets the value of the given name in scope.
148    fn resolve_name(&self, name: &str, span: Span) -> Result<Value, Diagnostic>;
149
150    /// Resolves a type name to a type.
151    fn resolve_type_name(&self, name: &str, span: Span) -> Result<Type, Diagnostic>;
152
153    /// Gets the working directory for the evaluation.
154    ///
155    /// Returns `None` if the task execution hasn't occurred yet.
156    fn work_dir(&self) -> Option<&EvaluationPath>;
157
158    /// Gets the temp directory for the evaluation.
159    fn temp_dir(&self) -> &Path;
160
161    /// Gets the value to return for a call to the `stdout` function.
162    ///
163    /// This is `Some` only when evaluating task outputs.
164    fn stdout(&self) -> Option<&Value>;
165
166    /// Gets the value to return for a call to the `stderr` function.
167    ///
168    /// This is `Some` only when evaluating task outputs.
169    fn stderr(&self) -> Option<&Value>;
170
171    /// Gets the task associated with the evaluation context.
172    ///
173    /// This is only `Some` when evaluating task hints sections.
174    fn task(&self) -> Option<&Task>;
175
176    /// Translates a host path to a guest path.
177    ///
178    /// Returns `None` if no translation is available.
179    fn translate_path(&self, path: &str) -> Option<Cow<'_, Path>>;
180
181    /// Gets the downloader to use for evaluating expressions.
182    fn downloader(&self) -> &dyn Downloader;
183}
184
/// The index of a scope within a collection of scopes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScopeIndex(usize);

impl ScopeIndex {
    /// Wraps a raw index as a scope index.
    pub const fn new(index: usize) -> Self {
        ScopeIndex(index)
    }
}

impl From<usize> for ScopeIndex {
    /// Converts a raw index into a scope index.
    fn from(index: usize) -> Self {
        Self::new(index)
    }
}

impl From<ScopeIndex> for usize {
    /// Extracts the raw index from a scope index.
    fn from(ScopeIndex(raw): ScopeIndex) -> Self {
        raw
    }
}
207
208/// Represents an evaluation scope in a WDL document.
209#[derive(Default, Debug)]
210pub struct Scope {
211    /// The index of the parent scope.
212    ///
213    /// This is `None` for the root scopes.
214    parent: Option<ScopeIndex>,
215    /// The map of names in scope to their values.
216    names: IndexMap<String, Value>,
217}
218
219impl Scope {
220    /// Creates a new scope given the parent scope.
221    pub fn new(parent: ScopeIndex) -> Self {
222        Self {
223            parent: Some(parent),
224            names: Default::default(),
225        }
226    }
227
228    /// Inserts a name into the scope.
229    pub fn insert(&mut self, name: impl Into<String>, value: impl Into<Value>) {
230        let prev = self.names.insert(name.into(), value.into());
231        assert!(prev.is_none(), "conflicting name in scope");
232    }
233
234    /// Iterates over the local names and values in the scope.
235    pub fn local(&self) -> impl Iterator<Item = (&str, &Value)> + use<'_> {
236        self.names.iter().map(|(k, v)| (k.as_str(), v))
237    }
238
239    /// Gets a mutable reference to an existing name in scope.
240    pub(crate) fn get_mut(&mut self, name: &str) -> Option<&mut Value> {
241        self.names.get_mut(name)
242    }
243
244    /// Clears the scope.
245    pub(crate) fn clear(&mut self) {
246        self.parent = None;
247        self.names.clear();
248    }
249
250    /// Sets the scope's parent.
251    pub(crate) fn set_parent(&mut self, parent: ScopeIndex) {
252        self.parent = Some(parent);
253    }
254}
255
256impl From<Scope> for IndexMap<String, Value> {
257    fn from(scope: Scope) -> Self {
258        scope.names
259    }
260}
261
262/// Represents a reference to a scope.
263#[derive(Debug, Clone, Copy)]
264pub struct ScopeRef<'a> {
265    /// The reference to the scopes collection.
266    scopes: &'a [Scope],
267    /// The index of the scope in the collection.
268    index: ScopeIndex,
269}
270
271impl<'a> ScopeRef<'a> {
272    /// Creates a new scope reference given the scope index.
273    pub fn new(scopes: &'a [Scope], index: impl Into<ScopeIndex>) -> Self {
274        Self {
275            scopes,
276            index: index.into(),
277        }
278    }
279
280    /// Gets the parent scope.
281    ///
282    /// Returns `None` if there is no parent scope.
283    pub fn parent(&self) -> Option<Self> {
284        self.scopes[self.index.0].parent.map(|p| Self {
285            scopes: self.scopes,
286            index: p,
287        })
288    }
289
290    /// Gets all of the name and values available at this scope.
291    pub fn names(&self) -> impl Iterator<Item = (&str, &Value)> + use<'_> {
292        self.scopes[self.index.0]
293            .names
294            .iter()
295            .map(|(n, name)| (n.as_str(), name))
296    }
297
298    /// Iterates over each name and value visible to the scope and calls the
299    /// provided callback.
300    ///
301    /// Stops iterating and returns an error if the callback returns an error.
302    pub fn for_each(&self, mut cb: impl FnMut(&str, &Value) -> Result<()>) -> Result<()> {
303        let mut current = Some(self.index);
304
305        while let Some(index) = current {
306            for (n, v) in self.scopes[index.0].local() {
307                cb(n, v)?;
308            }
309
310            current = self.scopes[index.0].parent;
311        }
312
313        Ok(())
314    }
315
316    /// Gets the value of a name local to this scope.
317    ///
318    /// Returns `None` if a name local to this scope was not found.
319    pub fn local(&self, name: &str) -> Option<&Value> {
320        self.scopes[self.index.0].names.get(name)
321    }
322
323    /// Lookups a name in the scope.
324    ///
325    /// Returns `None` if the name is not available in the scope.
326    pub fn lookup(&self, name: &str) -> Option<&Value> {
327        let mut current = Some(self.index);
328
329        while let Some(index) = current {
330            if let Some(name) = self.scopes[index.0].names.get(name) {
331                return Some(name);
332            }
333
334            current = self.scopes[index.0].parent;
335        }
336
337        None
338    }
339}
340
341/// Represents an evaluated task.
342#[derive(Debug)]
343pub struct EvaluatedTask {
344    /// The task attempt directory.
345    attempt_dir: PathBuf,
346    /// The task execution result.
347    result: TaskExecutionResult,
348    /// The evaluated outputs of the task.
349    ///
350    /// This is `Ok` when the task executes successfully and all of the task's
351    /// outputs evaluated without error.
352    ///
353    /// Otherwise, this contains the error that occurred while attempting to
354    /// evaluate the task's outputs.
355    outputs: EvaluationResult<Outputs>,
356}
357
358impl EvaluatedTask {
359    /// Constructs a new evaluated task.
360    ///
361    /// Returns an error if the stdout or stderr paths are not UTF-8.
362    fn new(attempt_dir: PathBuf, result: TaskExecutionResult) -> anyhow::Result<Self> {
363        Ok(Self {
364            result,
365            attempt_dir,
366            outputs: Ok(Default::default()),
367        })
368    }
369
370    /// Gets the exit code of the evaluated task.
371    pub fn exit_code(&self) -> i32 {
372        self.result.exit_code
373    }
374
375    /// Gets the attempt directory of the task.
376    pub fn attempt_dir(&self) -> &Path {
377        &self.attempt_dir
378    }
379
380    /// Gets the inputs that were given to the task.
381    pub fn inputs(&self) -> &[Input] {
382        &self.result.inputs
383    }
384
385    /// Gets the working directory of the evaluated task.
386    pub fn work_dir(&self) -> &EvaluationPath {
387        &self.result.work_dir
388    }
389
390    /// Gets the stdout value of the evaluated task.
391    pub fn stdout(&self) -> &Value {
392        &self.result.stdout
393    }
394
395    /// Gets the stderr value of the evaluated task.
396    pub fn stderr(&self) -> &Value {
397        &self.result.stderr
398    }
399
400    /// Gets the outputs of the evaluated task.
401    ///
402    /// This is `Ok` when the task executes successfully and all of the task's
403    /// outputs evaluated without error.
404    ///
405    /// Otherwise, this contains the error that occurred while attempting to
406    /// evaluate the task's outputs.
407    pub fn outputs(&self) -> &EvaluationResult<Outputs> {
408        &self.outputs
409    }
410
411    /// Converts the evaluated task into an evaluation result.
412    ///
413    /// Returns `Ok(_)` if the task outputs were evaluated.
414    ///
415    /// Returns `Err(_)` if the task outputs could not be evaluated.
416    pub fn into_result(self) -> EvaluationResult<Outputs> {
417        self.outputs
418    }
419
420    /// Handles the exit of a task execution.
421    ///
422    /// Returns an error if the task failed.
423    async fn handle_exit(
424        &self,
425        requirements: &HashMap<String, Value>,
426        downloader: &dyn Downloader,
427    ) -> anyhow::Result<()> {
428        let mut error = true;
429        if let Some(return_codes) = requirements
430            .get(TASK_REQUIREMENT_RETURN_CODES)
431            .or_else(|| requirements.get(TASK_REQUIREMENT_RETURN_CODES_ALIAS))
432        {
433            match return_codes {
434                Value::Primitive(PrimitiveValue::String(s)) if s.as_ref() == "*" => {
435                    error = false;
436                }
437                Value::Primitive(PrimitiveValue::String(s)) => {
438                    bail!(
439                        "invalid return code value `{s}`: only `*` is accepted when the return \
440                         code is specified as a string"
441                    );
442                }
443                Value::Primitive(PrimitiveValue::Integer(ok)) => {
444                    if self.result.exit_code == i32::try_from(*ok).unwrap_or_default() {
445                        error = false;
446                    }
447                }
448                Value::Compound(CompoundValue::Array(codes)) => {
449                    error = !codes.as_slice().iter().any(|v| {
450                        v.as_integer()
451                            .map(|i| i32::try_from(i).unwrap_or_default() == self.result.exit_code)
452                            .unwrap_or(false)
453                    });
454                }
455                _ => unreachable!("unexpected return codes value"),
456            }
457        } else {
458            error = self.result.exit_code != 0;
459        }
460
461        if error {
462            // Read the last `MAX_STDERR_LINES` number of lines from stderr
463            // If there's a problem reading stderr, don't output it
464            let stderr = download_file(downloader, None, self.stderr().as_file().unwrap())
465                .await
466                .ok()
467                .and_then(|l| {
468                    fs::File::open(l).ok().map(|f| {
469                        // Buffer the last N number of lines
470                        let reader = RevBufReader::new(f);
471                        let lines: Vec<_> = reader
472                            .lines()
473                            .take(MAX_STDERR_LINES)
474                            .map_while(|l| l.ok())
475                            .collect();
476
477                        // Iterate the lines in reverse order as we read them in reverse
478                        lines
479                            .iter()
480                            .rev()
481                            .format_with("\n", |l, f| f(&format_args!("  {l}")))
482                            .to_string()
483                    })
484                })
485                .unwrap_or_default();
486
487            // If the work directory is remote,
488            bail!(
489                "process terminated with exit code {code}: see `{stdout_path}` and \
490                 `{stderr_path}` for task output and the related files in \
491                 `{dir}`{header}{stderr}{trailer}",
492                code = self.result.exit_code,
493                dir = self.attempt_dir().display(),
494                stdout_path = self.stdout().as_file().expect("must be file"),
495                stderr_path = self.stderr().as_file().expect("must be file"),
496                header = if stderr.is_empty() {
497                    Cow::Borrowed("")
498                } else {
499                    format!("\n\ntask stderr output (last {MAX_STDERR_LINES} lines):\n\n").into()
500                },
501                trailer = if stderr.is_empty() { "" } else { "\n" }
502            );
503        }
504
505        Ok(())
506    }
507}
508
/// The kind of an input to a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputKind {
    /// A single file input.
    File,
    /// A directory input.
    Directory,
}
517
518impl From<InputKind> for crankshaft::engine::task::input::Type {
519    fn from(value: InputKind) -> Self {
520        match value {
521            InputKind::File => Self::File,
522            InputKind::Directory => Self::Directory,
523        }
524    }
525}
526
527/// Represents a `File` or `Directory` input to a task.
528#[derive(Debug, Clone)]
529pub struct Input {
530    /// The input kind.
531    kind: InputKind,
532    /// The path for the input.
533    path: EvaluationPath,
534    /// The download location for the input.
535    ///
536    /// This is `Some` if the input has been downloaded to a known location.
537    location: Option<Location<'static>>,
538    /// The guest path for the input.
539    guest_path: Option<String>,
540}
541
542impl Input {
543    /// Creates a new input with the given path and access.
544    pub fn new(kind: InputKind, path: EvaluationPath) -> Self {
545        Self {
546            kind,
547            path,
548            location: None,
549            guest_path: None,
550        }
551    }
552
553    /// Creates an input from a primitive value.
554    pub fn from_primitive(value: &PrimitiveValue) -> Result<Self> {
555        let (kind, path) = match value {
556            PrimitiveValue::File(path) => (InputKind::File, path),
557            PrimitiveValue::Directory(path) => (InputKind::Directory, path),
558            _ => bail!("value is not a `File` or `Directory`"),
559        };
560
561        Ok(Self {
562            kind,
563            path: path.parse()?,
564            location: None,
565            guest_path: None,
566        })
567    }
568
569    /// Gets the kind of the input.
570    pub fn kind(&self) -> InputKind {
571        self.kind
572    }
573
574    /// Gets the path to the input.
575    pub fn path(&self) -> &EvaluationPath {
576        &self.path
577    }
578
579    /// Gets the location of the input if it has been downloaded.
580    pub fn location(&self) -> Option<&Path> {
581        self.location.as_deref()
582    }
583
584    /// Sets the location of the input.
585    pub fn set_location(&mut self, location: Location<'static>) {
586        self.location = Some(location);
587    }
588
589    /// Gets the guest path for the input.
590    pub fn guest_path(&self) -> Option<&str> {
591        self.guest_path.as_deref()
592    }
593
594    /// Sets the guest path for the input.
595    pub fn set_guest_path(&mut self, path: impl Into<String>) {
596        self.guest_path = Some(path.into());
597    }
598}
599
600/// Represents a node in an input trie.
601#[derive(Debug)]
602struct InputTrieNode<'a> {
603    /// The children of this node.
604    ///
605    /// A `BTreeMap` is used here to get a consistent walk of the tree.
606    children: BTreeMap<&'a str, Self>,
607    /// The identifier of the node in the trie.
608    ///
609    /// A node's identifier is used when formatting guest paths of children.
610    id: usize,
611    /// The input represented by this node.
612    ///
613    /// This is `Some` only for terminal nodes in the trie.
614    ///
615    /// The first element in the tuple is the index of the input.
616    input: Option<(usize, &'a Input)>,
617}
618
619impl InputTrieNode<'_> {
620    /// Constructs a new input trie node with the given component.
621    fn new(id: usize) -> Self {
622        Self {
623            children: Default::default(),
624            id,
625            input: None,
626        }
627    }
628
629    /// Calculates the guest path for all terminal nodes in the trie.
630    fn calculate_guest_paths(
631        &self,
632        root: &str,
633        parent_id: usize,
634        paths: &mut Vec<(usize, String)>,
635    ) -> Result<()> {
636        // Invoke the callback for any terminal node in the trie
637        if let Some((index, input)) = self.input {
638            let file_name = input.path.file_name()?.unwrap_or("");
639
640            // If the file name is empty, it means this is a root URL
641            let guest_path = if file_name.is_empty() {
642                format!(
643                    "{root}{sep}{parent_id}/.root",
644                    root = root,
645                    sep = if root.as_bytes().last() == Some(&b'/') {
646                        ""
647                    } else {
648                        "/"
649                    }
650                )
651            } else {
652                format!(
653                    "{root}{sep}{parent_id}/{file_name}",
654                    root = root,
655                    sep = if root.as_bytes().last() == Some(&b'/') {
656                        ""
657                    } else {
658                        "/"
659                    },
660                )
661            };
662
663            paths.push((index, guest_path));
664        }
665
666        // Traverse into the children
667        for child in self.children.values() {
668            child.calculate_guest_paths(root, self.id, paths)?;
669        }
670
671        Ok(())
672    }
673}
674
675/// Represents a prefix trie based on input paths.
676///
677/// This is used to determine guest paths for inputs.
678///
679/// From the root to a terminal node represents a unique input.
680#[derive(Debug)]
681pub struct InputTrie<'a> {
682    /// The URL path children of the tree.
683    ///
684    /// The key in the map is the scheme of each URL.
685    ///
686    /// A `BTreeMap` is used here to get a consistent walk of the tree.
687    urls: BTreeMap<&'a str, InputTrieNode<'a>>,
688    /// The local path children of the tree.
689    ///
690    /// The key in the map is the first component of each path.
691    ///
692    /// A `BTreeMap` is used here to get a consistent walk of the tree.
693    paths: BTreeMap<&'a str, InputTrieNode<'a>>,
694    /// The next node identifier.
695    next_id: usize,
696    /// The number of inputs in the trie.
697    count: usize,
698}
699
700impl<'a> InputTrie<'a> {
701    /// Inserts a new input into the trie.
702    pub fn insert(&mut self, input: &'a Input) -> Result<()> {
703        let node = match &input.path {
704            EvaluationPath::Local(path) => {
705                // Don't both inserting anything into the trie for relative paths
706                // We still consider the input part of the trie, but it will never have a guest
707                // path
708                if path.is_relative() {
709                    self.count += 1;
710                    return Ok(());
711                }
712
713                let mut components = path.components();
714
715                let component = components
716                    .next()
717                    .context("input path cannot be empty")?
718                    .as_os_str()
719                    .to_str()
720                    .with_context(|| {
721                        format!("input path `{path}` is not UTF-8", path = path.display())
722                    })?;
723                let mut node = self.paths.entry(component).or_insert_with(|| {
724                    let node = InputTrieNode::new(self.next_id);
725                    self.next_id += 1;
726                    node
727                });
728
729                for component in components {
730                    match component {
731                        Component::CurDir | Component::ParentDir => {
732                            bail!(
733                                "input path `{path}` may not contain `.` or `..`",
734                                path = path.display()
735                            );
736                        }
737                        _ => {}
738                    }
739
740                    let component = component.as_os_str().to_str().with_context(|| {
741                        format!("input path `{path}` is not UTF-8", path = path.display())
742                    })?;
743                    node = node.children.entry(component).or_insert_with(|| {
744                        let node = InputTrieNode::new(self.next_id);
745                        self.next_id += 1;
746                        node
747                    });
748                }
749
750                node
751            }
752            EvaluationPath::Remote(url) => {
753                // Insert for scheme
754                let mut node = self.urls.entry(url.scheme()).or_insert_with(|| {
755                    let node = InputTrieNode::new(self.next_id);
756                    self.next_id += 1;
757                    node
758                });
759
760                // Insert the authority
761                node = node.children.entry(url.authority()).or_insert_with(|| {
762                    let node = InputTrieNode::new(self.next_id);
763                    self.next_id += 1;
764                    node
765                });
766
767                // Insert the path segments
768                if let Some(segments) = url.path_segments() {
769                    for segment in segments {
770                        node = node.children.entry(segment).or_insert_with(|| {
771                            let node = InputTrieNode::new(self.next_id);
772                            self.next_id += 1;
773                            node
774                        });
775                    }
776                }
777
778                // Ignore query parameters and fragments
779                node
780            }
781        };
782
783        node.input = Some((self.count, input));
784        self.count += 1;
785        Ok(())
786    }
787
788    /// Calculates guest paths for the inputs in the trie.
789    ///
790    /// Returns a collection of input insertion index paired with the calculated
791    /// guest path.
792    pub fn calculate_guest_paths(&self, root: &str) -> Result<Vec<(usize, String)>> {
793        let mut paths = Vec::with_capacity(self.count);
794        for child in self.urls.values() {
795            child.calculate_guest_paths(root, 0, &mut paths)?;
796        }
797
798        for child in self.paths.values() {
799            child.calculate_guest_paths(root, 0, &mut paths)?;
800        }
801
802        Ok(paths)
803    }
804}
805
806impl Default for InputTrie<'_> {
807    fn default() -> Self {
808        Self {
809            urls: Default::default(),
810            paths: Default::default(),
811            // The first id starts at 1 as 0 is considered the "virtual root" of the trie
812            next_id: 1,
813            count: 0,
814        }
815    }
816}
817
818#[cfg(test)]
819mod test {
820    use pretty_assertions::assert_eq;
821
822    use super::*;
823
824    #[test]
825    fn empty_trie() {
826        let empty = InputTrie::default();
827        let paths = empty.calculate_guest_paths("/mnt/").unwrap();
828        assert!(paths.is_empty());
829    }
830
831    #[cfg(unix)]
832    #[test]
833    fn non_empty_trie_unix() {
834        let mut trie = InputTrie::default();
835        let inputs = [
836            Input::new(InputKind::Directory, "/".parse().unwrap()),
837            Input::new(InputKind::File, "/foo/bar/foo.txt".parse().unwrap()),
838            Input::new(InputKind::File, "/foo/bar/bar.txt".parse().unwrap()),
839            Input::new(InputKind::File, "/foo/baz/foo.txt".parse().unwrap()),
840            Input::new(InputKind::File, "/foo/baz/bar.txt".parse().unwrap()),
841            Input::new(InputKind::File, "/bar/foo/foo.txt".parse().unwrap()),
842            Input::new(InputKind::File, "/bar/foo/bar.txt".parse().unwrap()),
843            Input::new(InputKind::Directory, "/baz".parse().unwrap()),
844            Input::new(InputKind::File, "https://example.com/".parse().unwrap()),
845            Input::new(
846                InputKind::File,
847                "https://example.com/foo/bar/foo.txt".parse().unwrap(),
848            ),
849            Input::new(
850                InputKind::File,
851                "https://example.com/foo/bar/bar.txt".parse().unwrap(),
852            ),
853            Input::new(
854                InputKind::File,
855                "https://example.com/foo/baz/foo.txt".parse().unwrap(),
856            ),
857            Input::new(
858                InputKind::File,
859                "https://example.com/foo/baz/bar.txt".parse().unwrap(),
860            ),
861            Input::new(
862                InputKind::File,
863                "https://example.com/bar/foo/foo.txt".parse().unwrap(),
864            ),
865            Input::new(
866                InputKind::File,
867                "https://example.com/bar/foo/bar.txt".parse().unwrap(),
868            ),
869            Input::new(InputKind::File, "https://foo.com/bar".parse().unwrap()),
870        ];
871
872        for input in &inputs {
873            trie.insert(input).unwrap();
874        }
875
876        // The important part of the guest paths are:
877        // 1) The guest file name should be the same (or `.root` if the path is
878        //    considered to be root)
879        // 2) Paths with the same parent should have the same guest parent
880        let paths = trie.calculate_guest_paths("/mnt/").unwrap();
881        let paths: Vec<_> = paths
882            .iter()
883            .map(|(index, guest)| (inputs[*index].path().to_str().unwrap(), guest.as_str()))
884            .collect();
885
886        assert_eq!(
887            paths,
888            [
889                ("https://example.com/", "/mnt/15/.root"),
890                ("https://example.com/bar/foo/bar.txt", "/mnt/25/bar.txt"),
891                ("https://example.com/bar/foo/foo.txt", "/mnt/25/foo.txt"),
892                ("https://example.com/foo/bar/bar.txt", "/mnt/18/bar.txt"),
893                ("https://example.com/foo/bar/foo.txt", "/mnt/18/foo.txt"),
894                ("https://example.com/foo/baz/bar.txt", "/mnt/21/bar.txt"),
895                ("https://example.com/foo/baz/foo.txt", "/mnt/21/foo.txt"),
896                ("https://foo.com/bar", "/mnt/28/bar"),
897                ("/", "/mnt/0/.root"),
898                ("/bar/foo/bar.txt", "/mnt/10/bar.txt"),
899                ("/bar/foo/foo.txt", "/mnt/10/foo.txt"),
900                ("/baz", "/mnt/1/baz"),
901                ("/foo/bar/bar.txt", "/mnt/3/bar.txt"),
902                ("/foo/bar/foo.txt", "/mnt/3/foo.txt"),
903                ("/foo/baz/bar.txt", "/mnt/6/bar.txt"),
904                ("/foo/baz/foo.txt", "/mnt/6/foo.txt"),
905            ]
906        );
907    }
908
909    #[cfg(windows)]
910    #[test]
911    fn non_empty_trie_windows() {
912        let mut trie = InputTrie::default();
913        let inputs = [
914            Input::new(InputKind::Directory, "C:\\".parse().unwrap()),
915            Input::new(InputKind::File, "C:\\foo\\bar\\foo.txt".parse().unwrap()),
916            Input::new(InputKind::File, "C:\\foo\\bar\\bar.txt".parse().unwrap()),
917            Input::new(InputKind::File, "C:\\foo\\baz\\foo.txt".parse().unwrap()),
918            Input::new(InputKind::File, "C:\\foo\\baz\\bar.txt".parse().unwrap()),
919            Input::new(InputKind::File, "C:\\bar\\foo\\foo.txt".parse().unwrap()),
920            Input::new(InputKind::File, "C:\\bar\\foo\\bar.txt".parse().unwrap()),
921            Input::new(InputKind::Directory, "C:\\baz".parse().unwrap()),
922            Input::new(InputKind::File, "https://example.com/".parse().unwrap()),
923            Input::new(
924                InputKind::File,
925                "https://example.com/foo/bar/foo.txt".parse().unwrap(),
926            ),
927            Input::new(
928                InputKind::File,
929                "https://example.com/foo/bar/bar.txt".parse().unwrap(),
930            ),
931            Input::new(
932                InputKind::File,
933                "https://example.com/foo/baz/foo.txt".parse().unwrap(),
934            ),
935            Input::new(
936                InputKind::File,
937                "https://example.com/foo/baz/bar.txt".parse().unwrap(),
938            ),
939            Input::new(
940                InputKind::File,
941                "https://example.com/bar/foo/foo.txt".parse().unwrap(),
942            ),
943            Input::new(
944                InputKind::File,
945                "https://example.com/bar/foo/bar.txt".parse().unwrap(),
946            ),
947            Input::new(InputKind::File, "https://foo.com/bar".parse().unwrap()),
948        ];
949
950        for input in &inputs {
951            trie.insert(input).unwrap();
952        }
953
954        // The important part of the guest paths are:
955        // 1) The guest file name should be the same (or `.root` if the path is
956        //    considered to be root)
957        // 2) Paths with the same parent should have the same guest parent
958        let paths = trie.calculate_guest_paths("/mnt/").unwrap();
959        let paths: Vec<_> = paths
960            .iter()
961            .map(|(index, guest)| (inputs[*index].path().to_str().unwrap(), guest.as_str()))
962            .collect();
963
964        assert_eq!(
965            paths,
966            [
967                ("https://example.com/", "/mnt/16/.root"),
968                ("https://example.com/bar/foo/bar.txt", "/mnt/26/bar.txt"),
969                ("https://example.com/bar/foo/foo.txt", "/mnt/26/foo.txt"),
970                ("https://example.com/foo/bar/bar.txt", "/mnt/19/bar.txt"),
971                ("https://example.com/foo/bar/foo.txt", "/mnt/19/foo.txt"),
972                ("https://example.com/foo/baz/bar.txt", "/mnt/22/bar.txt"),
973                ("https://example.com/foo/baz/foo.txt", "/mnt/22/foo.txt"),
974                ("https://foo.com/bar", "/mnt/29/bar"),
975                ("C:\\", "/mnt/1/.root"),
976                ("C:\\bar\\foo\\bar.txt", "/mnt/11/bar.txt"),
977                ("C:\\bar\\foo\\foo.txt", "/mnt/11/foo.txt"),
978                ("C:\\baz", "/mnt/2/baz"),
979                ("C:\\foo\\bar\\bar.txt", "/mnt/4/bar.txt"),
980                ("C:\\foo\\bar\\foo.txt", "/mnt/4/foo.txt"),
981                ("C:\\foo\\baz\\bar.txt", "/mnt/7/bar.txt"),
982                ("C:\\foo\\baz\\foo.txt", "/mnt/7/foo.txt"),
983            ]
984        );
985    }
986}