Skip to main content

webots_proto_resolver/
resolve.rs

1//! PROTO resolver for expanding EXTERNPROTO dependencies.
2//!
3//! This module provides functionality to resolve and expand EXTERNPROTO declarations
4//! in PROTO files, producing a fully expanded robot definition with no external references.
5
6use super::{ProtoError, ProtoResult};
7use std::collections::{HashMap, HashSet};
8use std::path::{Path, PathBuf};
9use webots_proto_ast::proto::ast::{
10    AstNode, AstNodeKind, FieldValue, NodeBodyElement, Proto, ProtoBodyItem,
11};
12use webots_proto_ast::proto::parser::Parser;
13use webots_proto_template::RenderOptions;
14
15#[derive(Clone)]
16struct ProtoExpansion {
17    root_node: AstNode,
18    interface_defaults: HashMap<String, FieldValue>,
19}
20
/// Configuration options for PROTO resolution.
///
/// `ResolveOptions::default()` and [`ResolveOptions::new`] yield the same
/// value. `Default` is implemented by hand because a derived implementation
/// would set `max_depth` to `0`, which rejects every EXTERNPROTO.
#[derive(Debug, Clone)]
pub struct ResolveOptions {
    /// Optional path to Webots projects directory for resolving webots:// URLs.
    pub webots_projects_dir: Option<PathBuf>,
    /// Maximum recursion depth to prevent infinite loops.
    pub max_depth: usize,
}

impl Default for ResolveOptions {
    /// Returns options with no Webots projects directory and the default
    /// recursion limit.
    fn default() -> Self {
        Self {
            webots_projects_dir: None,
            max_depth: Self::DEFAULT_MAX_DEPTH,
        }
    }
}

impl ResolveOptions {
    /// Recursion limit used when the caller does not choose one.
    const DEFAULT_MAX_DEPTH: usize = 10;

    /// Creates a new ResolveOptions with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the Webots projects directory for resolving webots:// URLs.
    pub fn with_webots_projects_dir(mut self, path: PathBuf) -> Self {
        self.webots_projects_dir = Some(path);
        self
    }

    /// Sets the maximum recursion depth.
    pub fn with_max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }
}
51
52/// PROTO resolver for expanding EXTERNPROTO dependencies.
53pub struct ProtoResolver {
54    options: ResolveOptions,
55    /// Track visited files to detect circular dependencies.
56    visited: HashSet<PathBuf>,
57    /// Current recursion depth.
58    depth: usize,
59}
60
61impl ProtoResolver {
62    /// Creates a new ProtoResolver with the given options.
63    pub fn new(options: ResolveOptions) -> Self {
64        Self {
65            options,
66            visited: HashSet::new(),
67            depth: 0,
68        }
69    }
70
71    pub fn to_root_node(
72        &mut self,
73        input: &str,
74        base_path: Option<impl AsRef<Path>>,
75    ) -> ProtoResult<AstNode> {
76        let base_path = base_path.map(|p| p.as_ref().to_path_buf());
77        let mut parser = Parser::new(input);
78        let doc = parser
79            .parse_document()
80            .map_err(|e| ProtoError::ParseError(format!("{:?}", e)))?;
81
82        // Determine base path: use provided, or fall back to current directory
83        let base_path = base_path
84            .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
85
86        self.expand_document(&doc, &base_path)
87    }
88
89    /// Resolves an EXTERNPROTO URL to a local file path.
90    fn resolve_url(&self, url: &str, base_path: impl AsRef<Path>) -> ProtoResult<PathBuf> {
91        let base_path = base_path.as_ref();
92        // Handle webots:// URLs
93        if url.starts_with("webots://") {
94            if let Some(ref webots_dir) = self.options.webots_projects_dir {
95                let relative_path = url.strip_prefix("webots://").unwrap();
96                return Ok(webots_dir.join(relative_path));
97            } else {
98                return Err(ProtoError::ParseError(format!(
99                    "webots:// URL '{}' requires webots_projects_dir to be configured",
100                    url
101                )));
102            }
103        }
104
105        // Handle https:// URLs
106        if url.starts_with("https://") || url.starts_with("http://") {
107            return Err(ProtoError::ParseError(format!(
108                "Network URL '{}' is not supported by the resolver",
109                url
110            )));
111        }
112
113        // Handle local file paths (relative or absolute)
114        let path = Path::new(url);
115        if path.is_absolute() {
116            Ok(path.to_path_buf())
117        } else {
118            // Resolve relative to the base path
119            Ok(base_path.join(path))
120        }
121    }
122
123    /// Loads and parses a PROTO file from the given path.
124    fn load_proto(&mut self, path: impl AsRef<Path>) -> ProtoResult<Proto> {
125        let path = path.as_ref();
126        // Check for circular dependencies
127        let canonical_path = path
128            .canonicalize()
129            .map_err(|e| ProtoError::ParseError(format!("Failed to canonicalize path: {}", e)))?;
130
131        if self.visited.contains(&canonical_path) {
132            return Err(ProtoError::ParseError(format!(
133                "Circular dependency detected: {:?}",
134                canonical_path
135            )));
136        }
137
138        // Check recursion depth
139        if self.depth >= self.options.max_depth {
140            return Err(ProtoError::ParseError(format!(
141                "Maximum recursion depth ({}) exceeded",
142                self.options.max_depth
143            )));
144        }
145
146        self.visited.insert(canonical_path.clone());
147        self.depth += 1;
148
149        // Load and parse the file
150        let content = std::fs::read_to_string(&canonical_path).map_err(|e| {
151            ProtoError::ParseError(format!("Failed to read file {:?}: {}", canonical_path, e))
152        })?;
153
154        let mut parser = Parser::new(&content);
155        let doc = parser
156            .parse_document()
157            .map_err(|e| ProtoError::ParseError(format!("Parse error in {:?}: {:?}", path, e)))?;
158
159        self.depth -= 1;
160
161        Ok(doc)
162    }
163
164    /// Expands a PROTO document by resolving all EXTERNPROTO dependencies.
165    fn expand_document(
166        &mut self,
167        doc: &Proto,
168        base_path: impl AsRef<Path>,
169    ) -> ProtoResult<AstNode> {
170        let base_path = base_path.as_ref();
171        // First, render templates if this is a PROTO definition
172        let rendered_doc = if doc.proto.is_some() {
173            let rendered_content = webots_proto_template::render(doc, &RenderOptions::default())?;
174            let mut parser = Parser::new(&rendered_content);
175            parser.parse_document().map_err(|e| {
176                ProtoError::ParseError(format!("Failed to parse rendered template: {:?}", e))
177            })?
178        } else {
179            doc.clone()
180        };
181
182        // Build a map of EXTERNPROTO declarations: PROTO name -> expanded node + defaults.
183        let mut proto_definitions: HashMap<String, ProtoExpansion> = HashMap::new();
184
185        // Preserve EXTERNPROTO declarations from the original parsed document.
186        // Template rendering may output only PROTO body content and omit header directives.
187        for ext in &doc.externprotos {
188            let resolved_path = self.resolve_url(&ext.url, base_path)?;
189            let nested_doc = self.load_proto(&resolved_path)?;
190            let nested_base = resolved_path.parent().unwrap_or(Path::new("."));
191
192            // Recursively expand the nested document
193            let expanded_nested = self.expand_document(&nested_doc, nested_base)?;
194
195            // Get the PROTO name from the nested document
196            if let Some(proto) = &nested_doc.proto {
197                let mut interface_defaults = HashMap::new();
198                for field in &proto.fields {
199                    if let Some(default_value) = field.default_value.clone() {
200                        interface_defaults.insert(field.name.clone(), default_value);
201                    }
202                }
203                proto_definitions.insert(
204                    proto.name.clone(),
205                    ProtoExpansion {
206                        root_node: expanded_nested,
207                        interface_defaults,
208                    },
209                );
210            }
211        }
212
213        // Extract the root node from the PROTO body
214        let mut root_node = if let Some(proto) = &rendered_doc.proto {
215            // Find the first node in the PROTO body
216            let mut found_node = None;
217            for item in &proto.body {
218                if let ProtoBodyItem::Node(node) = item {
219                    found_node = Some(node.clone());
220                    break;
221                }
222            }
223
224            found_node.ok_or_else(|| {
225                ProtoError::ParseError("PROTO definition has no root node".to_string())
226            })?
227        } else {
228            // If not a PROTO, return the first root node
229            rendered_doc
230                .root_nodes
231                .first()
232                .ok_or_else(|| ProtoError::ParseError("Document has no root nodes".to_string()))?
233                .clone()
234        };
235
236        // Inline EXTERNPROTO references and resolve IS bindings
237        self.inline_proto_nodes(&mut root_node, &proto_definitions)?;
238        self.normalize_mesh_urls(&mut root_node, base_path);
239
240        Ok(root_node)
241    }
242
243    fn normalize_mesh_urls(&self, node: &mut AstNode, base_path: &Path) {
244        let AstNodeKind::Node {
245            type_name, fields, ..
246        } = &mut node.kind
247        else {
248            return;
249        };
250
251        if type_name == "CadShape" {
252            for element in fields.iter_mut() {
253                let NodeBodyElement::Field(field) = element else {
254                    continue;
255                };
256                if field.name != "url" {
257                    continue;
258                }
259                self.normalize_mesh_field_value(&mut field.value, base_path);
260            }
261        }
262
263        for element in fields.iter_mut() {
264            let NodeBodyElement::Field(field) = element else {
265                continue;
266            };
267            match &mut field.value {
268                FieldValue::Node(child_node) => self.normalize_mesh_urls(child_node, base_path),
269                FieldValue::Array(array) => {
270                    for item in &mut array.elements {
271                        if let FieldValue::Node(child_node) = &mut item.value {
272                            self.normalize_mesh_urls(child_node, base_path);
273                        }
274                    }
275                }
276                _ => {}
277            }
278        }
279    }
280
281    fn normalize_mesh_field_value(&self, field_value: &mut FieldValue, base_path: &Path) {
282        match field_value {
283            FieldValue::String(url) | FieldValue::Raw(url) => {
284                *url = normalize_local_url(url, base_path);
285            }
286            FieldValue::Array(array) => {
287                for element in &mut array.elements {
288                    if let FieldValue::String(url) | FieldValue::Raw(url) = &mut element.value {
289                        *url = normalize_local_url(url, base_path);
290                    }
291                }
292            }
293            _ => {}
294        }
295    }
296
297    /// Recursively inline EXTERNPROTO nodes and resolve IS bindings.
298    fn inline_proto_nodes(
299        &self,
300        node: &mut AstNode,
301        proto_definitions: &HashMap<String, ProtoExpansion>,
302    ) -> ProtoResult<()> {
303        if let AstNodeKind::Node {
304            type_name,
305            def_name,
306            fields,
307        } = &mut node.kind
308        {
309            // Check if this node type is a PROTO that needs inlining
310            if let Some(proto_expansion) = proto_definitions.get(type_name) {
311                // Clone the PROTO definition
312                let mut inlined_node = proto_expansion.root_node.clone();
313
314                // Resolve IS bindings in the inlined node with values from the current node
315                self.resolve_is_bindings(
316                    &mut inlined_node,
317                    fields,
318                    &proto_expansion.interface_defaults,
319                )?;
320
321                // Preserve caller DEF naming on the inlined root node so downstream
322                // frame/joint identifiers remain stable and human-readable.
323                if let Some(caller_def_name) = def_name.clone()
324                    && let AstNodeKind::Node {
325                        def_name: inlined_def_name,
326                        ..
327                    } = &mut inlined_node.kind
328                {
329                    *inlined_def_name = Some(caller_def_name);
330                }
331
332                // Replace the current node with the inlined version
333                *node = inlined_node;
334
335                // Continue processing the inlined node
336                if let AstNodeKind::Node { fields, .. } = &mut node.kind {
337                    // Recursively process child nodes in the inlined content
338                    for element in fields.iter_mut() {
339                        if let NodeBodyElement::Field(field) = element {
340                            if let FieldValue::Node(child_node) = &mut field.value {
341                                self.inline_proto_nodes(child_node.as_mut(), proto_definitions)?;
342                            } else if let FieldValue::Array(array) = &mut field.value {
343                                for item in &mut array.elements {
344                                    if let FieldValue::Node(child_node) = &mut item.value {
345                                        self.inline_proto_nodes(
346                                            child_node.as_mut(),
347                                            proto_definitions,
348                                        )?;
349                                    }
350                                }
351                            }
352                        }
353                    }
354                }
355            } else {
356                // Not a PROTO node, just recursively process children
357                for element in fields.iter_mut() {
358                    if let NodeBodyElement::Field(field) = element {
359                        if let FieldValue::Node(child_node) = &mut field.value {
360                            self.inline_proto_nodes(child_node.as_mut(), proto_definitions)?;
361                        } else if let FieldValue::Array(array) = &mut field.value {
362                            for item in &mut array.elements {
363                                if let FieldValue::Node(child_node) = &mut item.value {
364                                    self.inline_proto_nodes(
365                                        child_node.as_mut(),
366                                        proto_definitions,
367                                    )?;
368                                }
369                            }
370                        }
371                    }
372                }
373            }
374        }
375
376        Ok(())
377    }
378
379    /// Resolve IS bindings in a node by replacing them with actual values.
380    fn resolve_is_bindings(
381        &self,
382        node: &mut AstNode,
383        parent_fields: &[NodeBodyElement],
384        interface_defaults: &HashMap<String, FieldValue>,
385    ) -> ProtoResult<()> {
386        // Build a map of field name -> value from parent
387        let mut field_values: HashMap<String, FieldValue> = HashMap::new();
388        for element in parent_fields {
389            if let NodeBodyElement::Field(field) = element {
390                field_values.insert(field.name.clone(), field.value.clone());
391            }
392        }
393
394        // Recursively resolve IS bindings in this node
395        if let AstNodeKind::Node { fields, .. } = &mut node.kind {
396            for element in fields.iter_mut() {
397                if let NodeBodyElement::Field(field) = element {
398                    // Check if this field has an IS binding
399                    if let FieldValue::Is(ref field_name) = field.value {
400                        // Look up the value from the parent
401                        if let Some(value) = field_values.get(field_name) {
402                            field.value = value.clone();
403                        } else if let Some(default_value) = interface_defaults.get(field_name) {
404                            // If the caller omitted the field, use the PROTO interface default.
405                            field.value = default_value.clone();
406                        }
407                    }
408
409                    // Recursively process child nodes
410                    match &mut field.value {
411                        FieldValue::Node(child_node) => {
412                            self.resolve_is_bindings(
413                                child_node,
414                                parent_fields,
415                                interface_defaults,
416                            )?;
417                        }
418                        FieldValue::Array(array) => {
419                            for item in &mut array.elements {
420                                if let FieldValue::Node(child_node) = &mut item.value {
421                                    self.resolve_is_bindings(
422                                        child_node,
423                                        parent_fields,
424                                        interface_defaults,
425                                    )?;
426                                }
427                            }
428                        }
429                        _ => {}
430                    }
431                }
432            }
433        }
434
435        Ok(())
436    }
437}
438
/// Turns a local, relative URL into an absolute path string.
///
/// * Values containing `://` (any URL scheme) are returned unchanged.
/// * Absolute paths are returned unchanged (lossily converted to UTF-8).
/// * Relative paths are joined onto `base_path`; a relative `base_path` is
///   first anchored at the current working directory (or `"."` when that is
///   unavailable). The result is canonicalized when possible; otherwise the
///   joined path is returned as is.
fn normalize_local_url(url: &str, base_path: &Path) -> String {
    let has_scheme = url.contains("://");
    if has_scheme {
        return url.to_owned();
    }

    let local = Path::new(url);
    if local.is_absolute() {
        return local.to_string_lossy().into_owned();
    }

    let anchored: PathBuf = if base_path.is_absolute() {
        base_path.join(local)
    } else {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        cwd.join(base_path).join(local)
    };

    // Prefer the canonical form; keep the joined path if the target does not
    // exist or cannot be resolved.
    let resolved = match anchored.canonicalize() {
        Ok(canonical) => canonical,
        Err(_) => anchored,
    };
    resolved.to_string_lossy().into_owned()
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Base directory used by the URL-resolution tests.
    const BASE_DIR: &str = "/base/path";

    /// Resolver configured with `ResolveOptions::new()`.
    fn default_resolver() -> ProtoResolver {
        ProtoResolver::new(ResolveOptions::new())
    }

    #[test]
    fn test_resolve_options_builder() {
        let options = ResolveOptions::new().with_max_depth(5);

        assert!(options.webots_projects_dir.is_none());
        assert_eq!(options.max_depth, 5);
    }

    #[test]
    fn test_resolve_local_path() {
        let resolved = default_resolver()
            .resolve_url("Child.proto", Path::new(BASE_DIR))
            .unwrap();

        assert_eq!(resolved, PathBuf::from("/base/path/Child.proto"));
    }

    #[test]
    fn test_resolve_webots_url_without_config() {
        // No projects directory configured, so webots:// cannot be mapped.
        let outcome = default_resolver()
            .resolve_url("webots://projects/robots/Robot.proto", Path::new(BASE_DIR));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_resolve_webots_url_with_config() {
        let resolver = ProtoResolver::new(
            ResolveOptions::new().with_webots_projects_dir(PathBuf::from("/webots/assets")),
        );

        let resolved = resolver
            .resolve_url("webots://projects/robots/Robot.proto", Path::new(BASE_DIR))
            .unwrap();

        assert_eq!(
            resolved,
            PathBuf::from("/webots/assets/projects/robots/Robot.proto")
        );
    }

    #[test]
    fn test_reject_network_url_by_default() {
        let outcome =
            default_resolver().resolve_url("https://example.com/Robot.proto", Path::new(BASE_DIR));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_normalize_local_url_keeps_remote_urls() {
        let base = Path::new("/tmp");

        for remote in ["https://example.com/mesh.obj", "webots://projects/mesh.obj"] {
            assert_eq!(normalize_local_url(remote, base), remote);
        }
    }
}