1use super::{ProtoError, ProtoResult};
7use std::collections::{HashMap, HashSet};
8use std::path::{Path, PathBuf};
9use webots_proto_ast::proto::ast::{
10 AstNode, AstNodeKind, FieldValue, NodeBodyElement, Proto, ProtoBodyItem,
11};
12use webots_proto_ast::proto::parser::Parser;
13use webots_proto_template::RenderOptions;
14
15#[derive(Clone)]
16struct ProtoExpansion {
17 root_node: AstNode,
18 interface_defaults: HashMap<String, FieldValue>,
19}
20
/// Configuration for `ProtoResolver`.
///
/// `ResolveOptions::new()` and `ResolveOptions::default()` produce the same
/// value: no Webots projects directory and a `max_depth` of 10.
#[derive(Debug, Clone)]
pub struct ResolveOptions {
    /// Directory substituted for the `webots://` URL scheme. While this is
    /// `None`, `webots://` URLs cannot be resolved.
    pub webots_projects_dir: Option<PathBuf>,
    /// Limit on nested `EXTERNPROTO` loading, checked before each file is
    /// loaded. A value of `0` rejects every `EXTERNPROTO`.
    pub max_depth: usize,
}

impl Default for ResolveOptions {
    // Implemented by hand (not derived) so that `Default` agrees with `new()`;
    // a derived impl would yield `max_depth == 0` and reject every EXTERNPROTO.
    fn default() -> Self {
        Self {
            webots_projects_dir: None,
            max_depth: Self::DEFAULT_MAX_DEPTH,
        }
    }
}

impl ResolveOptions {
    /// Default nesting limit for `EXTERNPROTO` chains.
    const DEFAULT_MAX_DEPTH: usize = 10;

    /// Creates the default options (see the type-level docs).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the directory that `webots://` URLs are resolved against.
    #[must_use]
    pub fn with_webots_projects_dir(mut self, path: PathBuf) -> Self {
        self.webots_projects_dir = Some(path);
        self
    }

    /// Sets the maximum `EXTERNPROTO` nesting depth.
    #[must_use]
    pub fn with_max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }
}
51
52pub struct ProtoResolver {
54 options: ResolveOptions,
55 visited: HashSet<PathBuf>,
57 depth: usize,
59}
60
61impl ProtoResolver {
62 pub fn new(options: ResolveOptions) -> Self {
64 Self {
65 options,
66 visited: HashSet::new(),
67 depth: 0,
68 }
69 }
70
71 pub fn to_root_node(
72 &mut self,
73 input: &str,
74 base_path: Option<impl AsRef<Path>>,
75 ) -> ProtoResult<AstNode> {
76 let base_path = base_path.map(|p| p.as_ref().to_path_buf());
77 let mut parser = Parser::new(input);
78 let doc = parser
79 .parse_document()
80 .map_err(|e| ProtoError::ParseError(format!("{:?}", e)))?;
81
82 let base_path = base_path
84 .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
85
86 self.expand_document(&doc, &base_path)
87 }
88
89 fn resolve_url(&self, url: &str, base_path: impl AsRef<Path>) -> ProtoResult<PathBuf> {
91 let base_path = base_path.as_ref();
92 if url.starts_with("webots://") {
94 if let Some(ref webots_dir) = self.options.webots_projects_dir {
95 let relative_path = url.strip_prefix("webots://").unwrap();
96 return Ok(webots_dir.join(relative_path));
97 } else {
98 return Err(ProtoError::ParseError(format!(
99 "webots:// URL '{}' requires webots_projects_dir to be configured",
100 url
101 )));
102 }
103 }
104
105 if url.starts_with("https://") || url.starts_with("http://") {
107 return Err(ProtoError::ParseError(format!(
108 "Network URL '{}' is not supported by the resolver",
109 url
110 )));
111 }
112
113 let path = Path::new(url);
115 if path.is_absolute() {
116 Ok(path.to_path_buf())
117 } else {
118 Ok(base_path.join(path))
120 }
121 }
122
123 fn load_proto(&mut self, path: impl AsRef<Path>) -> ProtoResult<Proto> {
125 let path = path.as_ref();
126 let canonical_path = path
128 .canonicalize()
129 .map_err(|e| ProtoError::ParseError(format!("Failed to canonicalize path: {}", e)))?;
130
131 if self.visited.contains(&canonical_path) {
132 return Err(ProtoError::ParseError(format!(
133 "Circular dependency detected: {:?}",
134 canonical_path
135 )));
136 }
137
138 if self.depth >= self.options.max_depth {
140 return Err(ProtoError::ParseError(format!(
141 "Maximum recursion depth ({}) exceeded",
142 self.options.max_depth
143 )));
144 }
145
146 self.visited.insert(canonical_path.clone());
147 self.depth += 1;
148
149 let content = std::fs::read_to_string(&canonical_path).map_err(|e| {
151 ProtoError::ParseError(format!("Failed to read file {:?}: {}", canonical_path, e))
152 })?;
153
154 let mut parser = Parser::new(&content);
155 let doc = parser
156 .parse_document()
157 .map_err(|e| ProtoError::ParseError(format!("Parse error in {:?}: {:?}", path, e)))?;
158
159 self.depth -= 1;
160
161 Ok(doc)
162 }
163
164 fn expand_document(
166 &mut self,
167 doc: &Proto,
168 base_path: impl AsRef<Path>,
169 ) -> ProtoResult<AstNode> {
170 let base_path = base_path.as_ref();
171 let rendered_doc = if doc.proto.is_some() {
173 let rendered_content = webots_proto_template::render(doc, &RenderOptions::default())?;
174 let mut parser = Parser::new(&rendered_content);
175 parser.parse_document().map_err(|e| {
176 ProtoError::ParseError(format!("Failed to parse rendered template: {:?}", e))
177 })?
178 } else {
179 doc.clone()
180 };
181
182 let mut proto_definitions: HashMap<String, ProtoExpansion> = HashMap::new();
184
185 for ext in &doc.externprotos {
188 let resolved_path = self.resolve_url(&ext.url, base_path)?;
189 let nested_doc = self.load_proto(&resolved_path)?;
190 let nested_base = resolved_path.parent().unwrap_or(Path::new("."));
191
192 let expanded_nested = self.expand_document(&nested_doc, nested_base)?;
194
195 if let Some(proto) = &nested_doc.proto {
197 let mut interface_defaults = HashMap::new();
198 for field in &proto.fields {
199 if let Some(default_value) = field.default_value.clone() {
200 interface_defaults.insert(field.name.clone(), default_value);
201 }
202 }
203 proto_definitions.insert(
204 proto.name.clone(),
205 ProtoExpansion {
206 root_node: expanded_nested,
207 interface_defaults,
208 },
209 );
210 }
211 }
212
213 let mut root_node = if let Some(proto) = &rendered_doc.proto {
215 let mut found_node = None;
217 for item in &proto.body {
218 if let ProtoBodyItem::Node(node) = item {
219 found_node = Some(node.clone());
220 break;
221 }
222 }
223
224 found_node.ok_or_else(|| {
225 ProtoError::ParseError("PROTO definition has no root node".to_string())
226 })?
227 } else {
228 rendered_doc
230 .root_nodes
231 .first()
232 .ok_or_else(|| ProtoError::ParseError("Document has no root nodes".to_string()))?
233 .clone()
234 };
235
236 self.inline_proto_nodes(&mut root_node, &proto_definitions)?;
238 self.normalize_mesh_urls(&mut root_node, base_path);
239
240 Ok(root_node)
241 }
242
243 fn normalize_mesh_urls(&self, node: &mut AstNode, base_path: &Path) {
244 let AstNodeKind::Node {
245 type_name, fields, ..
246 } = &mut node.kind
247 else {
248 return;
249 };
250
251 if type_name == "CadShape" {
252 for element in fields.iter_mut() {
253 let NodeBodyElement::Field(field) = element else {
254 continue;
255 };
256 if field.name != "url" {
257 continue;
258 }
259 self.normalize_mesh_field_value(&mut field.value, base_path);
260 }
261 }
262
263 for element in fields.iter_mut() {
264 let NodeBodyElement::Field(field) = element else {
265 continue;
266 };
267 match &mut field.value {
268 FieldValue::Node(child_node) => self.normalize_mesh_urls(child_node, base_path),
269 FieldValue::Array(array) => {
270 for item in &mut array.elements {
271 if let FieldValue::Node(child_node) = &mut item.value {
272 self.normalize_mesh_urls(child_node, base_path);
273 }
274 }
275 }
276 _ => {}
277 }
278 }
279 }
280
281 fn normalize_mesh_field_value(&self, field_value: &mut FieldValue, base_path: &Path) {
282 match field_value {
283 FieldValue::String(url) | FieldValue::Raw(url) => {
284 *url = normalize_local_url(url, base_path);
285 }
286 FieldValue::Array(array) => {
287 for element in &mut array.elements {
288 if let FieldValue::String(url) | FieldValue::Raw(url) = &mut element.value {
289 *url = normalize_local_url(url, base_path);
290 }
291 }
292 }
293 _ => {}
294 }
295 }
296
297 fn inline_proto_nodes(
299 &self,
300 node: &mut AstNode,
301 proto_definitions: &HashMap<String, ProtoExpansion>,
302 ) -> ProtoResult<()> {
303 if let AstNodeKind::Node {
304 type_name,
305 def_name,
306 fields,
307 } = &mut node.kind
308 {
309 if let Some(proto_expansion) = proto_definitions.get(type_name) {
311 let mut inlined_node = proto_expansion.root_node.clone();
313
314 self.resolve_is_bindings(
316 &mut inlined_node,
317 fields,
318 &proto_expansion.interface_defaults,
319 )?;
320
321 if let Some(caller_def_name) = def_name.clone()
324 && let AstNodeKind::Node {
325 def_name: inlined_def_name,
326 ..
327 } = &mut inlined_node.kind
328 {
329 *inlined_def_name = Some(caller_def_name);
330 }
331
332 *node = inlined_node;
334
335 if let AstNodeKind::Node { fields, .. } = &mut node.kind {
337 for element in fields.iter_mut() {
339 if let NodeBodyElement::Field(field) = element {
340 if let FieldValue::Node(child_node) = &mut field.value {
341 self.inline_proto_nodes(child_node.as_mut(), proto_definitions)?;
342 } else if let FieldValue::Array(array) = &mut field.value {
343 for item in &mut array.elements {
344 if let FieldValue::Node(child_node) = &mut item.value {
345 self.inline_proto_nodes(
346 child_node.as_mut(),
347 proto_definitions,
348 )?;
349 }
350 }
351 }
352 }
353 }
354 }
355 } else {
356 for element in fields.iter_mut() {
358 if let NodeBodyElement::Field(field) = element {
359 if let FieldValue::Node(child_node) = &mut field.value {
360 self.inline_proto_nodes(child_node.as_mut(), proto_definitions)?;
361 } else if let FieldValue::Array(array) = &mut field.value {
362 for item in &mut array.elements {
363 if let FieldValue::Node(child_node) = &mut item.value {
364 self.inline_proto_nodes(
365 child_node.as_mut(),
366 proto_definitions,
367 )?;
368 }
369 }
370 }
371 }
372 }
373 }
374 }
375
376 Ok(())
377 }
378
379 fn resolve_is_bindings(
381 &self,
382 node: &mut AstNode,
383 parent_fields: &[NodeBodyElement],
384 interface_defaults: &HashMap<String, FieldValue>,
385 ) -> ProtoResult<()> {
386 let mut field_values: HashMap<String, FieldValue> = HashMap::new();
388 for element in parent_fields {
389 if let NodeBodyElement::Field(field) = element {
390 field_values.insert(field.name.clone(), field.value.clone());
391 }
392 }
393
394 if let AstNodeKind::Node { fields, .. } = &mut node.kind {
396 for element in fields.iter_mut() {
397 if let NodeBodyElement::Field(field) = element {
398 if let FieldValue::Is(ref field_name) = field.value {
400 if let Some(value) = field_values.get(field_name) {
402 field.value = value.clone();
403 } else if let Some(default_value) = interface_defaults.get(field_name) {
404 field.value = default_value.clone();
406 }
407 }
408
409 match &mut field.value {
411 FieldValue::Node(child_node) => {
412 self.resolve_is_bindings(
413 child_node,
414 parent_fields,
415 interface_defaults,
416 )?;
417 }
418 FieldValue::Array(array) => {
419 for item in &mut array.elements {
420 if let FieldValue::Node(child_node) = &mut item.value {
421 self.resolve_is_bindings(
422 child_node,
423 parent_fields,
424 interface_defaults,
425 )?;
426 }
427 }
428 }
429 _ => {}
430 }
431 }
432 }
433 }
434
435 Ok(())
436 }
437}
438
/// Rewrites a local mesh URL into an absolute filesystem path.
///
/// - Empty strings and URLs containing `://` (e.g. `https://`, `webots://`)
///   are returned unchanged.
/// - Absolute paths are returned unchanged.
/// - Relative paths are joined onto `base_path`. If `base_path` is relative,
///   it is first anchored at the current working directory, or `.` if that is
///   unavailable. The result is canonicalized when the target exists;
///   otherwise the joined, non-canonical path is returned.
fn normalize_local_url(url: &str, base_path: &Path) -> String {
    // An empty URL means "no file". Joining it would produce the base directory itself.
    if url.is_empty() || url.contains("://") {
        return url.to_string();
    }

    let as_path = Path::new(url);
    if as_path.is_absolute() {
        return as_path.to_string_lossy().into_owned();
    }

    let absolute_base = if base_path.is_absolute() {
        base_path.to_path_buf()
    } else {
        std::env::current_dir()
            .unwrap_or_else(|_| PathBuf::from("."))
            .join(base_path)
    };

    let candidate = absolute_base.join(as_path);
    candidate
        .canonicalize()
        .unwrap_or(candidate)
        .to_string_lossy()
        .into_owned()
}
464
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_options_builder() {
        let options = ResolveOptions::new().with_max_depth(5);

        assert_eq!(options.max_depth, 5);
        assert!(options.webots_projects_dir.is_none());
    }

    #[test]
    fn test_resolve_options_new_defaults() {
        let options = ResolveOptions::new();

        assert_eq!(options.max_depth, 10);
        assert!(options.webots_projects_dir.is_none());
    }

    #[test]
    fn test_resolve_local_path() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("Child.proto", base).unwrap();
        assert_eq!(result, PathBuf::from("/base/path/Child.proto"));
    }

    #[test]
    fn test_resolve_nested_relative_path() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("../shared/Child.proto", base).unwrap();
        assert_eq!(result, PathBuf::from("/base/path/../shared/Child.proto"));
    }

    #[test]
    fn test_resolve_absolute_path_ignores_base() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("/abs/Child.proto", base).unwrap();
        assert_eq!(result, PathBuf::from("/abs/Child.proto"));
    }

    #[test]
    fn test_resolve_webots_url_without_config() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("webots://projects/robots/Robot.proto", base);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_webots_url_with_config() {
        let options =
            ResolveOptions::new().with_webots_projects_dir(PathBuf::from("/webots/assets"));
        let resolver = ProtoResolver::new(options);
        let base = Path::new("/base/path");

        let result = resolver
            .resolve_url("webots://projects/robots/Robot.proto", base)
            .unwrap();
        assert_eq!(
            result,
            PathBuf::from("/webots/assets/projects/robots/Robot.proto")
        );
    }

    #[test]
    fn test_reject_network_url_by_default() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("https://example.com/Robot.proto", base);
        assert!(result.is_err());
    }

    #[test]
    fn test_reject_plain_http_url() {
        let resolver = ProtoResolver::new(ResolveOptions::new());
        let base = Path::new("/base/path");

        let result = resolver.resolve_url("http://example.com/Robot.proto", base);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_local_url_keeps_remote_urls() {
        let base = Path::new("/tmp");
        assert_eq!(
            normalize_local_url("https://example.com/mesh.obj", base),
            "https://example.com/mesh.obj"
        );
        assert_eq!(
            normalize_local_url("webots://projects/mesh.obj", base),
            "webots://projects/mesh.obj"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_normalize_local_url_keeps_absolute_paths() {
        let base = Path::new("/tmp");
        assert_eq!(
            normalize_local_url("/meshes/robot.obj", base),
            "/meshes/robot.obj"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_normalize_local_url_joins_missing_relative_path() {
        // The target does not exist, so canonicalization fails and the joined path is kept.
        let base = Path::new("/nonexistent-gr-base-dir");
        assert_eq!(
            normalize_local_url("meshes/robot.obj", base),
            "/nonexistent-gr-base-dir/meshes/robot.obj"
        );
    }
}