1use std::collections::HashMap;
2
3use derive_new::new;
4use derive_setters::Setters;
5use webots_proto_ast::proto::ast::{
6 ArrayValue, FieldType, FieldValue, NumberSequence, Proto, ProtoBodyItem,
7};
8use webots_proto_ast::proto::writer::ProtoWriter;
9
10use crate::template::types::{
11 TemplateContext, TemplateField, TemplateFieldBinding, TemplateWebotsVersion,
12};
13use crate::{TemplateError, TemplateEvaluator};
14
/// Optional environment values that PROTO template code can read while [`render`] runs.
///
/// `render` copies every field unchanged into the evaluator's `TemplateContext`; fields
/// left as `None` are forwarded as `None`. The `with_*` setters are generated by
/// `derive_setters` with `strip_option` and `into`, so they take the inner value (or
/// anything convertible into it) rather than an `Option`.
#[derive(Debug, Clone, Default, Setters)]
#[setters(prefix = "with_", strip_option, into)]
pub struct RenderContext {
    /// Forwarded as `TemplateContext::world` (presumably the world file path — confirm).
    pub world: Option<String>,
    /// Forwarded as `TemplateContext::proto` (presumably the PROTO file path — confirm).
    pub proto: Option<String>,
    /// Forwarded as `TemplateContext::project_path`.
    pub project_path: Option<String>,
    /// Rebuilt as a `TemplateWebotsVersion` (major, revision) before being forwarded.
    pub webots_version: Option<RenderWebotsVersion>,
    /// Forwarded as `TemplateContext::webots_home`.
    pub webots_home: Option<String>,
    /// Forwarded as `TemplateContext::temporary_files_path`.
    pub temporary_files_path: Option<String>,
    /// Forwarded as `TemplateContext::os`; this module does not validate the value.
    pub os: Option<String>,
    /// Forwarded as `TemplateContext::id`.
    pub id: Option<String>,
    /// Forwarded as `TemplateContext::coordinate_system`; this module does not validate it.
    pub coordinate_system: Option<String>,
}
28
/// Webots version made available to template code through [`RenderContext::webots_version`].
///
/// Both parts are opaque strings; `render` copies them verbatim into a
/// `TemplateWebotsVersion`. No parsing or validation happens in this module.
#[derive(Debug, Clone, Default, new, Setters)]
#[setters(prefix = "with_", strip_option, into)]
pub struct RenderWebotsVersion {
    /// Major version component, forwarded as-is.
    pub major: String,
    /// Revision component, forwarded as-is.
    pub revision: String,
}
35
/// Inputs controlling a single [`render`] call.
#[derive(Debug, Clone, Default, new, Setters)]
#[setters(prefix = "with_", strip_option)]
pub struct RenderOptions {
    /// Values that replace PROTO interface field defaults, keyed by field name.
    ///
    /// `render` rejects keys that do not name an interface field. It also rejects values
    /// whose `TemplateField` variant does not match the field's declared type.
    pub field_overrides: HashMap<String, TemplateField>,
    /// Environment values exposed to template code.
    pub context: RenderContext,
}
42
43impl RenderOptions {
44 pub fn with_field_overrides_from<I, K, V>(mut self, overrides: I) -> Self
45 where
46 I: IntoIterator<Item = (K, V)>,
47 K: Into<String>,
48 V: Into<TemplateField>,
49 {
50 self.field_overrides = overrides
51 .into_iter()
52 .map(|(key, value)| (key.into(), value.into()))
53 .collect();
54 self
55 }
56}
57
58pub fn render(proto: &Proto, options: &RenderOptions) -> Result<String, TemplateError> {
59 let Some(proto_definition) = &proto.proto else {
60 return Err(TemplateError::ValidationError(
61 "Document does not contain a PROTO definition".to_string(),
62 ));
63 };
64
65 let mut context_fields = HashMap::new();
66 for field in &proto_definition.fields {
67 let default_field_value = field
68 .default_value
69 .as_ref()
70 .map(|default_value| {
71 convert_to_template_field(default_value, &field.field_type).map_err(|error| {
72 TemplateError::ValidationError(format!("Field '{}': {}", field.name, error))
73 })
74 })
75 .transpose()?;
76
77 if let Some(override_field_value) = options.field_overrides.get(&field.name) {
78 validate_override_type(&field.name, override_field_value, &field.field_type)?;
79 let default_value_for_binding =
80 default_field_value.unwrap_or_else(|| override_field_value.clone());
81 context_fields.insert(
82 field.name.clone(),
83 TemplateFieldBinding::new(override_field_value.clone(), default_value_for_binding),
84 );
85 continue;
86 }
87
88 if let Some(default_value) = default_field_value {
89 context_fields.insert(
90 field.name.clone(),
91 TemplateFieldBinding::new(default_value.clone(), default_value),
92 );
93 } else {
94 return Err(TemplateError::ValidationError(format!(
95 "Field '{}' has no default and no override value",
96 field.name
97 )));
98 }
99 }
100
101 for override_field_name in options.field_overrides.keys() {
102 if proto_definition
103 .fields
104 .iter()
105 .all(|field| field.name != *override_field_name)
106 {
107 return Err(TemplateError::ValidationError(format!(
108 "Unknown field override '{}': no matching PROTO interface field",
109 override_field_name
110 )));
111 }
112 }
113
114 let body_content = if let Some(source) = &proto.source_content {
115 if let (Some(first_item), Some(last_item)) =
116 (proto_definition.body.first(), proto_definition.body.last())
117 {
118 let start = match first_item {
119 ProtoBodyItem::Node(node) => node.span.start,
120 ProtoBodyItem::Template(template) => template.span.start,
121 };
122 let end = match last_item {
123 ProtoBodyItem::Node(node) => node.span.end,
124 ProtoBodyItem::Template(template) => template.span.end,
125 };
126 source.get(start..end).unwrap_or("").to_string()
127 } else {
128 String::new()
129 }
130 } else {
131 let writer = ProtoWriter::new();
132 let mut body_content = String::new();
133 for item in &proto_definition.body {
134 match item {
135 ProtoBodyItem::Node(node) => writer
136 .write_node(&mut body_content, node)
137 .map_err(|error| TemplateError::ValidationError(error.to_string()))?,
138 ProtoBodyItem::Template(template) => writer
139 .write_template(&mut body_content, template)
140 .map_err(|error| TemplateError::ValidationError(error.to_string()))?,
141 }
142 }
143 body_content
144 };
145
146 let template_context = convert_render_context(&options.context);
147 TemplateEvaluator::with_context(template_context)
148 .evaluate_with_environment(&body_content, &context_fields)
149}
150
151fn convert_render_context(context: &RenderContext) -> TemplateContext {
152 TemplateContext {
153 world: context.world.clone(),
154 proto: context.proto.clone(),
155 project_path: context.project_path.clone(),
156 webots_home: context.webots_home.clone(),
157 temporary_files_path: context.temporary_files_path.clone(),
158 os: context.os.clone(),
159 id: context.id.clone(),
160 coordinate_system: context.coordinate_system.clone(),
161 webots_version: context.webots_version.as_ref().map(|version| {
162 TemplateWebotsVersion::new(version.major.clone(), version.revision.clone())
163 }),
164 }
165}
166
167fn convert_to_template_field(
168 value: &FieldValue,
169 field_type: &FieldType,
170) -> Result<TemplateField, String> {
171 if let (FieldValue::Bool(boolean), FieldType::SFBool) = (value, field_type) {
172 return Ok(TemplateField::SFBool(*boolean));
173 }
174 if let (FieldValue::Int(integer, _), FieldType::SFInt32) = (value, field_type) {
175 return Ok(TemplateField::SFInt32(convert_i64_to_i32(*integer)?));
176 }
177 if let (FieldValue::Int(integer, _), FieldType::SFFloat) = (value, field_type) {
178 return Ok(TemplateField::SFFloat(*integer as f64));
179 }
180 if let (FieldValue::Float(float, _), FieldType::SFFloat) = (value, field_type) {
181 return Ok(TemplateField::SFFloat(*float));
182 }
183 if let (FieldValue::String(string), FieldType::SFString) = (value, field_type) {
184 return Ok(TemplateField::SFString(string.clone()));
185 }
186 if let (FieldValue::Vec2f(vector), FieldType::SFVec2f) = (value, field_type) {
187 return Ok(TemplateField::SFVec2f(vector[0], vector[1]));
188 }
189 if let (FieldValue::Vec3f(vector), FieldType::SFVec3f) = (value, field_type) {
190 return Ok(TemplateField::SFVec3f(vector[0], vector[1], vector[2]));
191 }
192 if let (FieldValue::Rotation(vector), FieldType::SFRotation) = (value, field_type) {
193 return Ok(TemplateField::SFRotation(
194 vector[0], vector[1], vector[2], vector[3],
195 ));
196 }
197 if let (FieldValue::Color(vector), FieldType::SFColor) = (value, field_type) {
198 return Ok(TemplateField::SFColor(vector[0], vector[1], vector[2]));
199 }
200 if let (FieldValue::Node(node), FieldType::SFNode) = (value, field_type) {
201 let mut content = String::new();
202 ProtoWriter::new()
203 .write_node(&mut content, node)
204 .map_err(|_| "Failed to serialize node".to_string())?;
205 return Ok(TemplateField::SFNode(content));
206 }
207 if let (FieldValue::Null, FieldType::SFNode) = (value, field_type) {
208 return Ok(TemplateField::SFNode("NULL".into()));
209 }
210
211 if let (FieldValue::NumberSequence(sequence), FieldType::SFVec2f) = (value, field_type)
212 && sequence.elements.len() == 2
213 {
214 let numbers = extract_numbers_as_vec(sequence)?;
215 return Ok(TemplateField::SFVec2f(numbers[0], numbers[1]));
216 }
217 if let (FieldValue::NumberSequence(sequence), FieldType::SFVec3f) = (value, field_type)
218 && sequence.elements.len() == 3
219 {
220 let numbers = extract_numbers_as_vec(sequence)?;
221 return Ok(TemplateField::SFVec3f(numbers[0], numbers[1], numbers[2]));
222 }
223 if let (FieldValue::NumberSequence(sequence), FieldType::SFColor) = (value, field_type)
224 && sequence.elements.len() == 3
225 {
226 let numbers = extract_numbers_as_vec(sequence)?;
227 return Ok(TemplateField::SFColor(numbers[0], numbers[1], numbers[2]));
228 }
229 if let (FieldValue::NumberSequence(sequence), FieldType::SFRotation) = (value, field_type)
230 && sequence.elements.len() == 4
231 {
232 let numbers = extract_numbers_as_vec(sequence)?;
233 return Ok(TemplateField::SFRotation(
234 numbers[0], numbers[1], numbers[2], numbers[3],
235 ));
236 }
237 if let (FieldValue::NumberSequence(sequence), FieldType::SFFloat) = (value, field_type)
238 && sequence.elements.len() == 1
239 {
240 let numbers = extract_numbers_as_vec(sequence)?;
241 return Ok(TemplateField::SFFloat(numbers[0]));
242 }
243 if let (FieldValue::NumberSequence(sequence), FieldType::SFInt32) = (value, field_type)
244 && sequence.elements.len() == 1
245 {
246 let numbers = extract_numbers_as_int(sequence)?;
247 return Ok(TemplateField::SFInt32(numbers[0]));
248 }
249
250 if let (FieldValue::Array(array), FieldType::MFBool) = (value, field_type) {
251 return Ok(TemplateField::MFBool(extract_mf_values(
252 array,
253 |field_value| {
254 if let FieldValue::Bool(boolean) = field_value {
255 Ok(*boolean)
256 } else {
257 Err("Expected Bool".to_string())
258 }
259 },
260 )?));
261 }
262 if let (FieldValue::Array(array), FieldType::MFInt32) = (value, field_type) {
263 return Ok(TemplateField::MFInt32(extract_mf_values(
264 array,
265 |field_value| {
266 if let FieldValue::Int(integer, _) = field_value {
267 Ok(convert_i64_to_i32(*integer)?)
268 } else {
269 Err("Expected Int".to_string())
270 }
271 },
272 )?));
273 }
274 if let (FieldValue::Array(array), FieldType::MFFloat) = (value, field_type) {
275 return Ok(TemplateField::MFFloat(extract_mf_values(
276 array,
277 |field_value| {
278 if let FieldValue::Float(float, _) = field_value {
279 Ok(*float)
280 } else if let FieldValue::Int(integer, _) = field_value {
281 Ok(*integer as f64)
282 } else {
283 Err("Expected Float".to_string())
284 }
285 },
286 )?));
287 }
288 if let (FieldValue::Array(array), FieldType::MFString) = (value, field_type) {
289 return Ok(TemplateField::MFString(extract_mf_values(
290 array,
291 |field_value| {
292 if let FieldValue::String(string) = field_value {
293 Ok(string.clone())
294 } else {
295 Err("Expected String".to_string())
296 }
297 },
298 )?));
299 }
300 if let (FieldValue::Array(array), FieldType::MFVec2f) = (value, field_type) {
301 return Ok(TemplateField::MFVec2f(
302 extract_grouped_vectors(array, 2)?
303 .into_iter()
304 .map(|values| (values[0], values[1]))
305 .collect(),
306 ));
307 }
308 if let (FieldValue::Array(array), FieldType::MFVec3f) = (value, field_type) {
309 return Ok(TemplateField::MFVec3f(
310 extract_grouped_vectors(array, 3)?
311 .into_iter()
312 .map(|values| (values[0], values[1], values[2]))
313 .collect(),
314 ));
315 }
316 if let (FieldValue::Array(array), FieldType::MFRotation) = (value, field_type) {
317 return Ok(TemplateField::MFRotation(extract_mf_values(
318 array,
319 |field_value| {
320 if let FieldValue::Rotation(vector) = field_value {
321 Ok((vector[0], vector[1], vector[2], vector[3]))
322 } else {
323 Err("Expected Rotation".to_string())
324 }
325 },
326 )?));
327 }
328 if let (FieldValue::Array(array), FieldType::MFColor) = (value, field_type) {
329 return Ok(TemplateField::MFColor(extract_mf_values(
330 array,
331 |field_value| {
332 if let FieldValue::Color(vector) = field_value {
333 Ok((vector[0], vector[1], vector[2]))
334 } else {
335 Err("Expected Color".to_string())
336 }
337 },
338 )?));
339 }
340 if let (FieldValue::Array(array), FieldType::MFNode) = (value, field_type) {
341 return Ok(TemplateField::MFNode(extract_mf_values(
342 array,
343 |field_value| {
344 if let FieldValue::Node(node) = field_value {
345 let mut content = String::new();
346 ProtoWriter::new()
347 .write_node(&mut content, node)
348 .map_err(|_| "Failed to serialize node".to_string())?;
349 Ok(content)
350 } else if matches!(field_value, FieldValue::Null) {
351 Ok("NULL".to_string())
352 } else {
353 Err("Expected Node".to_string())
354 }
355 },
356 )?));
357 }
358
359 Err(format!(
360 "Unsupported field conversion from {:?} to {:?}",
361 value, field_type
362 ))
363}
364
365fn validate_override_type(
366 field_name: &str,
367 value: &TemplateField,
368 field_type: &FieldType,
369) -> Result<(), TemplateError> {
370 let type_matches = matches!(
371 (value, field_type),
372 (TemplateField::SFBool(_), FieldType::SFBool)
373 | (TemplateField::SFInt32(_), FieldType::SFInt32)
374 | (TemplateField::SFFloat(_), FieldType::SFFloat)
375 | (TemplateField::SFString(_), FieldType::SFString)
376 | (TemplateField::SFVec2f(_, _), FieldType::SFVec2f)
377 | (TemplateField::SFVec3f(_, _, _), FieldType::SFVec3f)
378 | (TemplateField::SFRotation(_, _, _, _), FieldType::SFRotation)
379 | (TemplateField::SFColor(_, _, _), FieldType::SFColor)
380 | (TemplateField::SFNode(_), FieldType::SFNode)
381 | (TemplateField::MFBool(_), FieldType::MFBool)
382 | (TemplateField::MFInt32(_), FieldType::MFInt32)
383 | (TemplateField::MFFloat(_), FieldType::MFFloat)
384 | (TemplateField::MFString(_), FieldType::MFString)
385 | (TemplateField::MFVec2f(_), FieldType::MFVec2f)
386 | (TemplateField::MFVec3f(_), FieldType::MFVec3f)
387 | (TemplateField::MFRotation(_), FieldType::MFRotation)
388 | (TemplateField::MFColor(_), FieldType::MFColor)
389 | (TemplateField::MFNode(_), FieldType::MFNode)
390 );
391
392 if type_matches {
393 Ok(())
394 } else {
395 Err(TemplateError::ValidationError(format!(
396 "Field '{}' override type mismatch: expected {:?}",
397 field_name, field_type
398 )))
399 }
400}
401
402fn extract_numbers_as_vec(sequence: &NumberSequence) -> Result<Vec<f64>, String> {
403 sequence
404 .elements
405 .iter()
406 .map(|element| match &element.value {
407 FieldValue::Float(value, _) => Ok(*value),
408 FieldValue::Int(value, _) => Ok(*value as f64),
409 other => Err(format!("Expected numeric element, got {:?}", other)),
410 })
411 .collect()
412}
413
414fn extract_numbers_as_int(sequence: &NumberSequence) -> Result<Vec<i32>, String> {
415 sequence
416 .elements
417 .iter()
418 .map(|element| match &element.value {
419 FieldValue::Int(value, _) => convert_i64_to_i32(*value),
420 other => Err(format!("Expected integer element, got {:?}", other)),
421 })
422 .collect()
423}
424
425fn extract_mf_values<T>(
426 array: &ArrayValue,
427 converter: impl Fn(&FieldValue) -> Result<T, String>,
428) -> Result<Vec<T>, String> {
429 array
430 .elements
431 .iter()
432 .map(|element| converter(&element.value))
433 .collect()
434}
435
436fn extract_grouped_vectors(array: &ArrayValue, width: usize) -> Result<Vec<Vec<f64>>, String> {
437 let mut numbers = Vec::new();
438 for element in &array.elements {
439 match &element.value {
440 FieldValue::Vec2f(vector) if width == 2 => numbers.push(vec![vector[0], vector[1]]),
441 FieldValue::Vec3f(vector) if width == 3 => {
442 numbers.push(vec![vector[0], vector[1], vector[2]])
443 }
444 FieldValue::NumberSequence(sequence) => {
445 let values = extract_numbers_as_vec(sequence)?;
446 if values.len() != width {
447 return Err(format!("Expected Vec{width}f"));
448 }
449 numbers.push(values);
450 }
451 FieldValue::Int(value, _) => {
452 if let Some(last) = numbers.last_mut()
453 && last.len() < width
454 {
455 last.push(*value as f64);
456 } else {
457 numbers.push(vec![*value as f64]);
458 }
459 }
460 FieldValue::Float(value, _) => {
461 if let Some(last) = numbers.last_mut()
462 && last.len() < width
463 {
464 last.push(*value);
465 } else {
466 numbers.push(vec![*value]);
467 }
468 }
469 _ => return Err(format!("Expected Vec{width}f")),
470 }
471 }
472
473 if numbers.iter().any(|group| group.len() != width) {
474 return Err(format!("Expected Vec{width}f"));
475 }
476
477 Ok(numbers)
478}
479
/// Narrows a parsed integer literal to `i32`, as required by `SFInt32` / `MFInt32` fields.
///
/// # Errors
///
/// Returns a message naming the value when it lies outside the `i32` range.
fn convert_i64_to_i32(value: i64) -> Result<i32, String> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(format!("Integer {} does not fit in Int32", value)),
    }
}