Skip to main content

webots_proto_template/
render.rs

1use std::collections::HashMap;
2
3use derive_new::new;
4use derive_setters::Setters;
5use webots_proto_ast::proto::ast::{
6    ArrayValue, FieldType, FieldValue, NumberSequence, Proto, ProtoBodyItem,
7};
8use webots_proto_ast::proto::writer::ProtoWriter;
9
10use crate::template::types::{
11    TemplateContext, TemplateField, TemplateFieldBinding, TemplateWebotsVersion,
12};
13use crate::{TemplateError, TemplateEvaluator};
14
15#[derive(Debug, Clone, Default, Setters)]
16#[setters(prefix = "with_", strip_option, into)]
17pub struct RenderContext {
18    pub world: Option<String>,
19    pub proto: Option<String>,
20    pub project_path: Option<String>,
21    pub webots_version: Option<RenderWebotsVersion>,
22    pub webots_home: Option<String>,
23    pub temporary_files_path: Option<String>,
24    pub os: Option<String>,
25    pub id: Option<String>,
26    pub coordinate_system: Option<String>,
27}
28
29#[derive(Debug, Clone, Default, new, Setters)]
30#[setters(prefix = "with_", strip_option, into)]
31pub struct RenderWebotsVersion {
32    pub major: String,
33    pub revision: String,
34}
35
36#[derive(Debug, Clone, Default, new, Setters)]
37#[setters(prefix = "with_", strip_option)]
38pub struct RenderOptions {
39    pub field_overrides: HashMap<String, TemplateField>,
40    pub context: RenderContext,
41}
42
43impl RenderOptions {
44    pub fn with_field_overrides_from<I, K, V>(mut self, overrides: I) -> Self
45    where
46        I: IntoIterator<Item = (K, V)>,
47        K: Into<String>,
48        V: Into<TemplateField>,
49    {
50        self.field_overrides = overrides
51            .into_iter()
52            .map(|(key, value)| (key.into(), value.into()))
53            .collect();
54        self
55    }
56}
57
58pub fn render(proto: &Proto, options: &RenderOptions) -> Result<String, TemplateError> {
59    let Some(proto_definition) = &proto.proto else {
60        return Err(TemplateError::ValidationError(
61            "Document does not contain a PROTO definition".to_string(),
62        ));
63    };
64
65    let mut context_fields = HashMap::new();
66    for field in &proto_definition.fields {
67        let default_field_value = field
68            .default_value
69            .as_ref()
70            .map(|default_value| {
71                convert_to_template_field(default_value, &field.field_type).map_err(|error| {
72                    TemplateError::ValidationError(format!("Field '{}': {}", field.name, error))
73                })
74            })
75            .transpose()?;
76
77        if let Some(override_field_value) = options.field_overrides.get(&field.name) {
78            validate_override_type(&field.name, override_field_value, &field.field_type)?;
79            let default_value_for_binding =
80                default_field_value.unwrap_or_else(|| override_field_value.clone());
81            context_fields.insert(
82                field.name.clone(),
83                TemplateFieldBinding::new(override_field_value.clone(), default_value_for_binding),
84            );
85            continue;
86        }
87
88        if let Some(default_value) = default_field_value {
89            context_fields.insert(
90                field.name.clone(),
91                TemplateFieldBinding::new(default_value.clone(), default_value),
92            );
93        } else {
94            return Err(TemplateError::ValidationError(format!(
95                "Field '{}' has no default and no override value",
96                field.name
97            )));
98        }
99    }
100
101    for override_field_name in options.field_overrides.keys() {
102        if proto_definition
103            .fields
104            .iter()
105            .all(|field| field.name != *override_field_name)
106        {
107            return Err(TemplateError::ValidationError(format!(
108                "Unknown field override '{}': no matching PROTO interface field",
109                override_field_name
110            )));
111        }
112    }
113
114    let body_content = if let Some(source) = &proto.source_content {
115        if let (Some(first_item), Some(last_item)) =
116            (proto_definition.body.first(), proto_definition.body.last())
117        {
118            let start = match first_item {
119                ProtoBodyItem::Node(node) => node.span.start,
120                ProtoBodyItem::Template(template) => template.span.start,
121            };
122            let end = match last_item {
123                ProtoBodyItem::Node(node) => node.span.end,
124                ProtoBodyItem::Template(template) => template.span.end,
125            };
126            source.get(start..end).unwrap_or("").to_string()
127        } else {
128            String::new()
129        }
130    } else {
131        let writer = ProtoWriter::new();
132        let mut body_content = String::new();
133        for item in &proto_definition.body {
134            match item {
135                ProtoBodyItem::Node(node) => writer
136                    .write_node(&mut body_content, node)
137                    .map_err(|error| TemplateError::ValidationError(error.to_string()))?,
138                ProtoBodyItem::Template(template) => writer
139                    .write_template(&mut body_content, template)
140                    .map_err(|error| TemplateError::ValidationError(error.to_string()))?,
141            }
142        }
143        body_content
144    };
145
146    let template_context = convert_render_context(&options.context);
147    TemplateEvaluator::with_context(template_context)
148        .evaluate_with_environment(&body_content, &context_fields)
149}
150
151fn convert_render_context(context: &RenderContext) -> TemplateContext {
152    TemplateContext {
153        world: context.world.clone(),
154        proto: context.proto.clone(),
155        project_path: context.project_path.clone(),
156        webots_home: context.webots_home.clone(),
157        temporary_files_path: context.temporary_files_path.clone(),
158        os: context.os.clone(),
159        id: context.id.clone(),
160        coordinate_system: context.coordinate_system.clone(),
161        webots_version: context.webots_version.as_ref().map(|version| {
162            TemplateWebotsVersion::new(version.major.clone(), version.revision.clone())
163        }),
164    }
165}
166
167fn convert_to_template_field(
168    value: &FieldValue,
169    field_type: &FieldType,
170) -> Result<TemplateField, String> {
171    if let (FieldValue::Bool(boolean), FieldType::SFBool) = (value, field_type) {
172        return Ok(TemplateField::SFBool(*boolean));
173    }
174    if let (FieldValue::Int(integer, _), FieldType::SFInt32) = (value, field_type) {
175        return Ok(TemplateField::SFInt32(convert_i64_to_i32(*integer)?));
176    }
177    if let (FieldValue::Int(integer, _), FieldType::SFFloat) = (value, field_type) {
178        return Ok(TemplateField::SFFloat(*integer as f64));
179    }
180    if let (FieldValue::Float(float, _), FieldType::SFFloat) = (value, field_type) {
181        return Ok(TemplateField::SFFloat(*float));
182    }
183    if let (FieldValue::String(string), FieldType::SFString) = (value, field_type) {
184        return Ok(TemplateField::SFString(string.clone()));
185    }
186    if let (FieldValue::Vec2f(vector), FieldType::SFVec2f) = (value, field_type) {
187        return Ok(TemplateField::SFVec2f(vector[0], vector[1]));
188    }
189    if let (FieldValue::Vec3f(vector), FieldType::SFVec3f) = (value, field_type) {
190        return Ok(TemplateField::SFVec3f(vector[0], vector[1], vector[2]));
191    }
192    if let (FieldValue::Rotation(vector), FieldType::SFRotation) = (value, field_type) {
193        return Ok(TemplateField::SFRotation(
194            vector[0], vector[1], vector[2], vector[3],
195        ));
196    }
197    if let (FieldValue::Color(vector), FieldType::SFColor) = (value, field_type) {
198        return Ok(TemplateField::SFColor(vector[0], vector[1], vector[2]));
199    }
200    if let (FieldValue::Node(node), FieldType::SFNode) = (value, field_type) {
201        let mut content = String::new();
202        ProtoWriter::new()
203            .write_node(&mut content, node)
204            .map_err(|_| "Failed to serialize node".to_string())?;
205        return Ok(TemplateField::SFNode(content));
206    }
207    if let (FieldValue::Null, FieldType::SFNode) = (value, field_type) {
208        return Ok(TemplateField::SFNode("NULL".into()));
209    }
210
211    if let (FieldValue::NumberSequence(sequence), FieldType::SFVec2f) = (value, field_type)
212        && sequence.elements.len() == 2
213    {
214        let numbers = extract_numbers_as_vec(sequence)?;
215        return Ok(TemplateField::SFVec2f(numbers[0], numbers[1]));
216    }
217    if let (FieldValue::NumberSequence(sequence), FieldType::SFVec3f) = (value, field_type)
218        && sequence.elements.len() == 3
219    {
220        let numbers = extract_numbers_as_vec(sequence)?;
221        return Ok(TemplateField::SFVec3f(numbers[0], numbers[1], numbers[2]));
222    }
223    if let (FieldValue::NumberSequence(sequence), FieldType::SFColor) = (value, field_type)
224        && sequence.elements.len() == 3
225    {
226        let numbers = extract_numbers_as_vec(sequence)?;
227        return Ok(TemplateField::SFColor(numbers[0], numbers[1], numbers[2]));
228    }
229    if let (FieldValue::NumberSequence(sequence), FieldType::SFRotation) = (value, field_type)
230        && sequence.elements.len() == 4
231    {
232        let numbers = extract_numbers_as_vec(sequence)?;
233        return Ok(TemplateField::SFRotation(
234            numbers[0], numbers[1], numbers[2], numbers[3],
235        ));
236    }
237    if let (FieldValue::NumberSequence(sequence), FieldType::SFFloat) = (value, field_type)
238        && sequence.elements.len() == 1
239    {
240        let numbers = extract_numbers_as_vec(sequence)?;
241        return Ok(TemplateField::SFFloat(numbers[0]));
242    }
243    if let (FieldValue::NumberSequence(sequence), FieldType::SFInt32) = (value, field_type)
244        && sequence.elements.len() == 1
245    {
246        let numbers = extract_numbers_as_int(sequence)?;
247        return Ok(TemplateField::SFInt32(numbers[0]));
248    }
249
250    if let (FieldValue::Array(array), FieldType::MFBool) = (value, field_type) {
251        return Ok(TemplateField::MFBool(extract_mf_values(
252            array,
253            |field_value| {
254                if let FieldValue::Bool(boolean) = field_value {
255                    Ok(*boolean)
256                } else {
257                    Err("Expected Bool".to_string())
258                }
259            },
260        )?));
261    }
262    if let (FieldValue::Array(array), FieldType::MFInt32) = (value, field_type) {
263        return Ok(TemplateField::MFInt32(extract_mf_values(
264            array,
265            |field_value| {
266                if let FieldValue::Int(integer, _) = field_value {
267                    Ok(convert_i64_to_i32(*integer)?)
268                } else {
269                    Err("Expected Int".to_string())
270                }
271            },
272        )?));
273    }
274    if let (FieldValue::Array(array), FieldType::MFFloat) = (value, field_type) {
275        return Ok(TemplateField::MFFloat(extract_mf_values(
276            array,
277            |field_value| {
278                if let FieldValue::Float(float, _) = field_value {
279                    Ok(*float)
280                } else if let FieldValue::Int(integer, _) = field_value {
281                    Ok(*integer as f64)
282                } else {
283                    Err("Expected Float".to_string())
284                }
285            },
286        )?));
287    }
288    if let (FieldValue::Array(array), FieldType::MFString) = (value, field_type) {
289        return Ok(TemplateField::MFString(extract_mf_values(
290            array,
291            |field_value| {
292                if let FieldValue::String(string) = field_value {
293                    Ok(string.clone())
294                } else {
295                    Err("Expected String".to_string())
296                }
297            },
298        )?));
299    }
300    if let (FieldValue::Array(array), FieldType::MFVec2f) = (value, field_type) {
301        return Ok(TemplateField::MFVec2f(
302            extract_grouped_vectors(array, 2)?
303                .into_iter()
304                .map(|values| (values[0], values[1]))
305                .collect(),
306        ));
307    }
308    if let (FieldValue::Array(array), FieldType::MFVec3f) = (value, field_type) {
309        return Ok(TemplateField::MFVec3f(
310            extract_grouped_vectors(array, 3)?
311                .into_iter()
312                .map(|values| (values[0], values[1], values[2]))
313                .collect(),
314        ));
315    }
316    if let (FieldValue::Array(array), FieldType::MFRotation) = (value, field_type) {
317        return Ok(TemplateField::MFRotation(extract_mf_values(
318            array,
319            |field_value| {
320                if let FieldValue::Rotation(vector) = field_value {
321                    Ok((vector[0], vector[1], vector[2], vector[3]))
322                } else {
323                    Err("Expected Rotation".to_string())
324                }
325            },
326        )?));
327    }
328    if let (FieldValue::Array(array), FieldType::MFColor) = (value, field_type) {
329        return Ok(TemplateField::MFColor(extract_mf_values(
330            array,
331            |field_value| {
332                if let FieldValue::Color(vector) = field_value {
333                    Ok((vector[0], vector[1], vector[2]))
334                } else {
335                    Err("Expected Color".to_string())
336                }
337            },
338        )?));
339    }
340    if let (FieldValue::Array(array), FieldType::MFNode) = (value, field_type) {
341        return Ok(TemplateField::MFNode(extract_mf_values(
342            array,
343            |field_value| {
344                if let FieldValue::Node(node) = field_value {
345                    let mut content = String::new();
346                    ProtoWriter::new()
347                        .write_node(&mut content, node)
348                        .map_err(|_| "Failed to serialize node".to_string())?;
349                    Ok(content)
350                } else if matches!(field_value, FieldValue::Null) {
351                    Ok("NULL".to_string())
352                } else {
353                    Err("Expected Node".to_string())
354                }
355            },
356        )?));
357    }
358
359    Err(format!(
360        "Unsupported field conversion from {:?} to {:?}",
361        value, field_type
362    ))
363}
364
365fn validate_override_type(
366    field_name: &str,
367    value: &TemplateField,
368    field_type: &FieldType,
369) -> Result<(), TemplateError> {
370    let type_matches = matches!(
371        (value, field_type),
372        (TemplateField::SFBool(_), FieldType::SFBool)
373            | (TemplateField::SFInt32(_), FieldType::SFInt32)
374            | (TemplateField::SFFloat(_), FieldType::SFFloat)
375            | (TemplateField::SFString(_), FieldType::SFString)
376            | (TemplateField::SFVec2f(_, _), FieldType::SFVec2f)
377            | (TemplateField::SFVec3f(_, _, _), FieldType::SFVec3f)
378            | (TemplateField::SFRotation(_, _, _, _), FieldType::SFRotation)
379            | (TemplateField::SFColor(_, _, _), FieldType::SFColor)
380            | (TemplateField::SFNode(_), FieldType::SFNode)
381            | (TemplateField::MFBool(_), FieldType::MFBool)
382            | (TemplateField::MFInt32(_), FieldType::MFInt32)
383            | (TemplateField::MFFloat(_), FieldType::MFFloat)
384            | (TemplateField::MFString(_), FieldType::MFString)
385            | (TemplateField::MFVec2f(_), FieldType::MFVec2f)
386            | (TemplateField::MFVec3f(_), FieldType::MFVec3f)
387            | (TemplateField::MFRotation(_), FieldType::MFRotation)
388            | (TemplateField::MFColor(_), FieldType::MFColor)
389            | (TemplateField::MFNode(_), FieldType::MFNode)
390    );
391
392    if type_matches {
393        Ok(())
394    } else {
395        Err(TemplateError::ValidationError(format!(
396            "Field '{}' override type mismatch: expected {:?}",
397            field_name, field_type
398        )))
399    }
400}
401
402fn extract_numbers_as_vec(sequence: &NumberSequence) -> Result<Vec<f64>, String> {
403    sequence
404        .elements
405        .iter()
406        .map(|element| match &element.value {
407            FieldValue::Float(value, _) => Ok(*value),
408            FieldValue::Int(value, _) => Ok(*value as f64),
409            other => Err(format!("Expected numeric element, got {:?}", other)),
410        })
411        .collect()
412}
413
414fn extract_numbers_as_int(sequence: &NumberSequence) -> Result<Vec<i32>, String> {
415    sequence
416        .elements
417        .iter()
418        .map(|element| match &element.value {
419            FieldValue::Int(value, _) => convert_i64_to_i32(*value),
420            other => Err(format!("Expected integer element, got {:?}", other)),
421        })
422        .collect()
423}
424
425fn extract_mf_values<T>(
426    array: &ArrayValue,
427    converter: impl Fn(&FieldValue) -> Result<T, String>,
428) -> Result<Vec<T>, String> {
429    array
430        .elements
431        .iter()
432        .map(|element| converter(&element.value))
433        .collect()
434}
435
436fn extract_grouped_vectors(array: &ArrayValue, width: usize) -> Result<Vec<Vec<f64>>, String> {
437    let mut numbers = Vec::new();
438    for element in &array.elements {
439        match &element.value {
440            FieldValue::Vec2f(vector) if width == 2 => numbers.push(vec![vector[0], vector[1]]),
441            FieldValue::Vec3f(vector) if width == 3 => {
442                numbers.push(vec![vector[0], vector[1], vector[2]])
443            }
444            FieldValue::NumberSequence(sequence) => {
445                let values = extract_numbers_as_vec(sequence)?;
446                if values.len() != width {
447                    return Err(format!("Expected Vec{width}f"));
448                }
449                numbers.push(values);
450            }
451            FieldValue::Int(value, _) => {
452                if let Some(last) = numbers.last_mut()
453                    && last.len() < width
454                {
455                    last.push(*value as f64);
456                } else {
457                    numbers.push(vec![*value as f64]);
458                }
459            }
460            FieldValue::Float(value, _) => {
461                if let Some(last) = numbers.last_mut()
462                    && last.len() < width
463                {
464                    last.push(*value);
465                } else {
466                    numbers.push(vec![*value]);
467                }
468            }
469            _ => return Err(format!("Expected Vec{width}f")),
470        }
471    }
472
473    if numbers.iter().any(|group| group.len() != width) {
474        return Err(format!("Expected Vec{width}f"));
475    }
476
477    Ok(numbers)
478}
479
/// Narrows a parsed integer literal to `i32`, the range of VRML/Webots `Int32` fields.
///
/// # Errors
///
/// Returns a message quoting `value` when it lies outside `i32::MIN..=i32::MAX`.
fn convert_i64_to_i32(value: i64) -> Result<i32, String> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(format!("Integer {} does not fit in Int32", value)),
    }
}