zng_layout/unit/length/expr.rs

1use std::fmt;
2
3use zng_unit::{ByteLength, ByteUnits as _, Factor, Px};
4use zng_var::animation::Transitionable as _;
5
6use crate::{
7    context::LayoutMask,
8    unit::{Layout1d, LayoutAxis, Length, ParseCompositeError},
9};
10
11/// Represents an unresolved [`Length`] expression.
12#[derive(Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
13#[non_exhaustive]
14pub enum LengthExpr {
15    /// Sums the both layout length.
16    Add(Length, Length),
17    /// Subtracts the first layout length from the second.
18    Sub(Length, Length),
19    /// Multiplies the layout length by the factor.
20    Mul(Length, Factor),
21    /// Divide the layout length by the factor.
22    Div(Length, Factor),
23    /// Maximum layout length.
24    Max(Length, Length),
25    /// Minimum layout length.
26    Min(Length, Length),
27    /// Computes the absolute layout length.
28    Abs(Length),
29    /// Negate the layout length.
30    Neg(Length),
31    /// Linear interpolate between lengths by factor.
32    Lerp(Length, Length, Factor),
33}
34impl LengthExpr {
35    /// Gets the total memory allocated by this length expression.
36    ///
37    /// This includes the sum of all nested [`Length::Expr`] heap memory.
38    pub fn memory_used(&self) -> ByteLength {
39        use LengthExpr::*;
40        std::mem::size_of::<LengthExpr>().bytes()
41            + match self {
42                Add(a, b) => a.heap_memory_used() + b.heap_memory_used(),
43                Sub(a, b) => a.heap_memory_used() + b.heap_memory_used(),
44                Mul(a, _) => a.heap_memory_used(),
45                Div(a, _) => a.heap_memory_used(),
46                Max(a, b) => a.heap_memory_used() + b.heap_memory_used(),
47                Min(a, b) => a.heap_memory_used() + b.heap_memory_used(),
48                Abs(a) => a.heap_memory_used(),
49                Neg(a) => a.heap_memory_used(),
50                Lerp(a, b, _) => a.heap_memory_used() + b.heap_memory_used(),
51            }
52    }
53
54    /// Convert to [`Length::Expr`], logs warning for memory use above 1kB, logs error for use > 20kB and collapses to [`Length::zero`].
55    ///
56    /// Every length expression created using the [`std::ops`] uses this method to check the constructed expression. Some operations
57    /// like iterator fold can cause an *expression explosion* where two lengths of different units that cannot
58    /// be evaluated immediately start an expression that subsequently is wrapped in a new expression for each operation done on it.
59    pub fn to_length_checked(self) -> Length {
60        let bytes = self.memory_used();
61        if bytes > 20.kibibytes() {
62            tracing::error!(target: "to_length_checked", "length alloc > 20kB, replaced with zero");
63            return Length::zero();
64        }
65        Length::Expr(Box::new(self))
66    }
67
68    /// If contains a [`Length::Default`] value.
69    pub fn has_default(&self) -> bool {
70        match self {
71            LengthExpr::Add(a, b) | LengthExpr::Sub(a, b) | LengthExpr::Max(a, b) | LengthExpr::Min(a, b) | LengthExpr::Lerp(a, b, _) => {
72                a.has_default() || b.has_default()
73            }
74            LengthExpr::Mul(a, _) | LengthExpr::Div(a, _) | LengthExpr::Abs(a) | LengthExpr::Neg(a) => a.has_default(),
75        }
76    }
77
78    /// Replace all [`Length::Default`] values with `overwrite`.
79    pub fn replace_default(&mut self, overwrite: &Length) {
80        match self {
81            LengthExpr::Add(a, b) | LengthExpr::Sub(a, b) | LengthExpr::Max(a, b) | LengthExpr::Min(a, b) | LengthExpr::Lerp(a, b, _) => {
82                a.replace_default(overwrite);
83                b.replace_default(overwrite);
84            }
85            LengthExpr::Mul(a, _) | LengthExpr::Div(a, _) | LengthExpr::Abs(a) | LengthExpr::Neg(a) => a.replace_default(overwrite),
86        }
87    }
88
89    /// Convert [`PxF32`] to [`Px`] and [`DipF32`] to [`Dip`].
90    ///
91    /// [`PxF32`]: Length::PxF32
92    /// [`Px`]: Length::Px
93    /// [`DipF32`]: Length::DipF32
94    /// [`Dip`]: Length::Dip
95    pub fn round_exact(&mut self) {
96        match self {
97            LengthExpr::Add(a, b) | LengthExpr::Sub(a, b) | LengthExpr::Max(a, b) | LengthExpr::Min(a, b) | LengthExpr::Lerp(a, b, _) => {
98                a.round_exact();
99                b.round_exact();
100            }
101            LengthExpr::Mul(a, _) | LengthExpr::Div(a, _) | LengthExpr::Abs(a) | LengthExpr::Neg(a) => a.round_exact(),
102        }
103    }
104}
105impl Layout1d for LengthExpr {
106    fn layout_dft(&self, axis: LayoutAxis, default: Px) -> Px {
107        let l = self.layout_f32_dft(axis, default.0 as f32);
108        Px(l.round() as i32)
109    }
110
111    fn layout_f32_dft(&self, axis: LayoutAxis, default: f32) -> f32 {
112        use LengthExpr::*;
113        match self {
114            Add(a, b) => a.layout_f32_dft(axis, default) + b.layout_f32_dft(axis, default),
115            Sub(a, b) => a.layout_f32_dft(axis, default) - b.layout_f32_dft(axis, default),
116            Mul(l, s) => l.layout_f32_dft(axis, default) * s.0,
117            Div(l, s) => l.layout_f32_dft(axis, default) / s.0,
118            Max(a, b) => {
119                let a = a.layout_f32_dft(axis, default);
120                let b = b.layout_f32_dft(axis, default);
121                a.max(b)
122            }
123            Min(a, b) => {
124                let a = a.layout_f32_dft(axis, default);
125                let b = b.layout_f32_dft(axis, default);
126                a.min(b)
127            }
128            Abs(e) => e.layout_f32_dft(axis, default).abs(),
129            Neg(e) => -e.layout_f32_dft(axis, default),
130            Lerp(a, b, f) => a.layout_f32_dft(axis, default).lerp(&b.layout_f32_dft(axis, default), *f),
131        }
132    }
133
134    fn affect_mask(&self) -> LayoutMask {
135        use LengthExpr::*;
136        match self {
137            Add(a, b) => a.affect_mask() | b.affect_mask(),
138            Sub(a, b) => a.affect_mask() | b.affect_mask(),
139            Mul(a, _) => a.affect_mask(),
140            Div(a, _) => a.affect_mask(),
141            Max(a, b) => a.affect_mask() | b.affect_mask(),
142            Min(a, b) => a.affect_mask() | b.affect_mask(),
143            Abs(a) => a.affect_mask(),
144            Neg(a) => a.affect_mask(),
145            Lerp(a, b, _) => a.affect_mask() | b.affect_mask(),
146        }
147    }
148}
149impl fmt::Debug for LengthExpr {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        use LengthExpr::*;
152        if f.alternate() {
153            match self {
154                Add(a, b) => f.debug_tuple("LengthExpr::Add").field(a).field(b).finish(),
155                Sub(a, b) => f.debug_tuple("LengthExpr::Sub").field(a).field(b).finish(),
156                Mul(l, s) => f.debug_tuple("LengthExpr::Mul").field(l).field(s).finish(),
157                Div(l, s) => f.debug_tuple("LengthExpr::Div").field(l).field(s).finish(),
158                Max(a, b) => f.debug_tuple("LengthExpr::Max").field(a).field(b).finish(),
159                Min(a, b) => f.debug_tuple("LengthExpr::Min").field(a).field(b).finish(),
160                Abs(e) => f.debug_tuple("LengthExpr::Abs").field(e).finish(),
161                Neg(e) => f.debug_tuple("LengthExpr::Neg").field(e).finish(),
162                Lerp(a, b, n) => f.debug_tuple("LengthExpr::Lerp").field(a).field(b).field(n).finish(),
163            }
164        } else {
165            match self {
166                Add(a, b) => write!(f, "({a:.p$?} + {b:.p$?})", p = f.precision().unwrap_or(0)),
167                Sub(a, b) => write!(f, "({a:.p$?} - {b:.p$?})", p = f.precision().unwrap_or(0)),
168                Mul(l, s) => write!(f, "({l:.p$?} * {:.p$?}.pct())", s.0 * 100.0, p = f.precision().unwrap_or(0)),
169                Div(l, s) => write!(f, "({l:.p$?} / {:.p$?}.pct())", s.0 * 100.0, p = f.precision().unwrap_or(0)),
170                Max(a, b) => write!(f, "max({a:.p$?}, {b:.p$?})", p = f.precision().unwrap_or(0)),
171                Min(a, b) => write!(f, "min({a:.p$?}, {b:.p$?})", p = f.precision().unwrap_or(0)),
172                Abs(e) => write!(f, "abs({e:.p$?})", p = f.precision().unwrap_or(0)),
173                Neg(e) => write!(f, "-({e:.p$?})", p = f.precision().unwrap_or(0)),
174                Lerp(a, b, n) => write!(f, "lerp({a:.p$?}, {b:.p$?}, {n:.p$?})", p = f.precision().unwrap_or(0)),
175            }
176        }
177    }
178}
179impl fmt::Display for LengthExpr {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        use LengthExpr::*;
182        match self {
183            Add(a, b) => write!(f, "({a:.p$} + {b:.p$})", p = f.precision().unwrap_or(0)),
184            Sub(a, b) => write!(f, "({a:.p$} - {b:.p$})", p = f.precision().unwrap_or(0)),
185            Mul(l, s) => write!(f, "({l:.p$} * {:.p$}%)", s.0 * 100.0, p = f.precision().unwrap_or(0)),
186            Div(l, s) => write!(f, "({l:.p$} / {:.p$}%)", s.0 * 100.0, p = f.precision().unwrap_or(0)),
187            Max(a, b) => write!(f, "max({a:.p$}, {b:.p$})", p = f.precision().unwrap_or(0)),
188            Min(a, b) => write!(f, "min({a:.p$}, {b:.p$})", p = f.precision().unwrap_or(0)),
189            Abs(e) => write!(f, "abs({e:.p$})", p = f.precision().unwrap_or(0)),
190            Neg(e) => write!(f, "-({e:.p$})", p = f.precision().unwrap_or(0)),
191            Lerp(a, b, n) => write!(f, "lerp({a:.p$}, {b:.p$}, {n:.p$})", p = f.precision().unwrap_or(0)),
192        }
193    }
194}
195impl std::str::FromStr for LengthExpr {
196    type Err = ParseCompositeError;
197
198    fn from_str(s: &str) -> Result<Self, Self::Err> {
199        let expr = Parser::new(s).parse()?;
200        match Length::try_from(expr)? {
201            Length::Expr(expr) => Ok(*expr),
202            _ => Err(ParseCompositeError::MissingComponent),
203        }
204    }
205}
206
207impl<'a> TryFrom<Expr<'a>> for Length {
208    type Error = ParseCompositeError;
209
210    fn try_from(value: Expr) -> Result<Self, Self::Error> {
211        match value {
212            Expr::Value(l) => l.parse(),
213            Expr::UnaryOp { op, rhs } => match op {
214                '-' => Ok(LengthExpr::Neg(Length::try_from(*rhs)?).into()),
215                '+' => Length::try_from(*rhs),
216                _ => Err(ParseCompositeError::UnknownFormat),
217            },
218            Expr::BinaryOp { op, lhs, rhs } => match op {
219                '+' => Ok(LengthExpr::Add(Length::try_from(*lhs)?, Length::try_from(*rhs)?).into()),
220                '-' => Ok(LengthExpr::Sub(Length::try_from(*lhs)?, Length::try_from(*rhs)?).into()),
221                '*' => Ok(LengthExpr::Mul(Length::try_from(*lhs)?, try_into_scale(*rhs)?).into()),
222                '/' => Ok(LengthExpr::Div(Length::try_from(*lhs)?, try_into_scale(*rhs)?).into()),
223                _ => Err(ParseCompositeError::UnknownFormat),
224            },
225            Expr::Call { name, mut args } => match name {
226                "max" => {
227                    let [a, b] = try_args(args)?;
228                    Ok(LengthExpr::Max(a, b).into())
229                }
230                "min" => {
231                    let [a, b] = try_args(args)?;
232                    Ok(LengthExpr::Min(a, b).into())
233                }
234                "abs" => {
235                    let [a] = try_args(args)?;
236                    Ok(LengthExpr::Abs(a).into())
237                }
238                "lerp" => {
239                    let s = args.pop().ok_or(ParseCompositeError::MissingComponent)?;
240                    let [a, b] = try_args(args)?;
241                    let s = try_into_scale(s)?;
242                    Ok(LengthExpr::Lerp(a, b, s).into())
243                }
244                _ => Err(ParseCompositeError::UnknownFormat),
245            },
246        }
247    }
248}
249fn try_into_scale(rhs: Expr) -> Result<Factor, ParseCompositeError> {
250    if let Length::Factor(f) = Length::try_from(rhs)? {
251        Ok(f)
252    } else {
253        Err(ParseCompositeError::UnknownFormat)
254    }
255}
256fn try_args<const N: usize>(args: Vec<Expr>) -> Result<[Length; N], ParseCompositeError> {
257    match args.len().cmp(&N) {
258        std::cmp::Ordering::Less => Err(ParseCompositeError::MissingComponent),
259        std::cmp::Ordering::Equal => Ok(args
260            .into_iter()
261            .map(Length::try_from)
262            .collect::<Result<Vec<Length>, ParseCompositeError>>()?
263            .try_into()
264            .unwrap()),
265        std::cmp::Ordering::Greater => Err(ParseCompositeError::ExtraComponent),
266    }
267}
268
/// Basic syntax tree of a [`LengthExpr`] string. Function names and values are not validated.
///
/// Converted and validated into a [`Length`] by its `TryFrom` implementation.
#[derive(Debug, PartialEq)]
enum Expr<'a> {
    /// Raw value token, for example `3.14` or `default`.
    Value(&'a str),
    /// Prefix operator `op` applied to `rhs`.
    UnaryOp {
        op: char,
        rhs: Box<Expr<'a>>,
    },
    /// Infix operator, `lhs op rhs`.
    BinaryOp {
        op: char,
        lhs: Box<Expr<'a>>,
        rhs: Box<Expr<'a>>,
    },
    /// Function call, `name(args, ..)`.
    Call {
        name: &'a str,
        args: Vec<Expr<'a>>,
    },
}
288
289struct Parser<'a> {
290    input: &'a str,
291    pos: usize,
292    len: usize,
293}
294impl<'a> Parser<'a> {
295    pub fn new(input: &'a str) -> Self {
296        Self {
297            input,
298            pos: 0,
299            len: input.len(),
300        }
301    }
302
303    fn peek_char(&self) -> Option<char> {
304        self.input[self.pos..].chars().next()
305    }
306
307    fn next_char(&mut self) -> Option<char> {
308        if self.pos >= self.len {
309            return None;
310        }
311        let ch = self.peek_char()?;
312        self.pos += ch.len_utf8();
313        Some(ch)
314    }
315
316    fn consume_whitespace(&mut self) {
317        while let Some(ch) = self.peek_char() {
318            if ch.is_whitespace() {
319                self.next_char();
320            } else {
321                break;
322            }
323        }
324    }
325
326    fn starts_with_nonop(&self, ch: char) -> bool {
327        !ch.is_whitespace() && !matches!(ch, '+' | '-' | '*' | '/' | '(' | ')' | ',')
328    }
329
330    fn parse_value_token(&mut self) -> Result<&'a str, ParseCompositeError> {
331        self.consume_whitespace();
332        let start = self.pos;
333        while let Some(ch) = self.peek_char() {
334            if self.starts_with_nonop(ch) {
335                self.next_char();
336            } else {
337                break;
338            }
339        }
340        let s = &self.input[start..self.pos];
341        if s.is_empty() {
342            Err(ParseCompositeError::MissingComponent)
343        } else {
344            Ok(s)
345        }
346    }
347
348    pub fn parse(&mut self) -> Result<Expr<'a>, ParseCompositeError> {
349        self.consume_whitespace();
350        let expr = self.parse_expr_bp(0)?;
351        self.consume_whitespace();
352        if self.pos < self.len {
353            Err(ParseCompositeError::ExtraComponent)
354        } else {
355            Ok(expr)
356        }
357    }
358
359    fn infix_binding_power(op: char) -> Option<(u32, u32)> {
360        match op {
361            '+' | '-' => Some((10, 11)), // low precedence
362            '*' | '/' => Some((20, 21)), // higher precedence
363            _ => None,
364        }
365    }
366
367    fn parse_expr_bp(&mut self, min_bp: u32) -> Result<Expr<'a>, ParseCompositeError> {
368        self.consume_whitespace();
369
370        // --- prefix / primary ---
371        let mut lhs = match self.peek_char() {
372            Some('-') => {
373                // unary -
374                self.next_char();
375                let rhs = self.parse_expr_bp(100)?; // high precedence for unary
376                Expr::UnaryOp {
377                    op: '-',
378                    rhs: Box::new(rhs),
379                }
380            }
381            Some('(') => {
382                // parenthesized expression
383                self.next_char(); // consume '('
384                let inner = self.parse_expr_bp(0)?;
385                self.consume_whitespace();
386                match self.next_char() {
387                    Some(')') => inner,
388                    _ => return Err(ParseCompositeError::MissingComponent),
389                }
390            }
391            Some(ch) if self.starts_with_nonop(ch) => {
392                // value token or function call
393                let token = self.parse_value_token()?;
394                // check if function call: next non-space char is '('
395                self.consume_whitespace();
396                if let Some('(') = self.peek_char() {
397                    // function call: name(token) (must have at least one arg)
398                    let name = token;
399                    self.next_char(); // consume '('
400                    let mut args = Vec::new();
401                    self.consume_whitespace();
402                    if let Some(')') = self.peek_char() {
403                        return Err(ParseCompositeError::MissingComponent);
404                    }
405                    // parse first arg
406                    loop {
407                        self.consume_whitespace();
408                        let arg = self.parse_expr_bp(0)?;
409                        args.push(arg);
410                        self.consume_whitespace();
411                        match self.peek_char() {
412                            Some(',') => {
413                                self.next_char();
414                                continue;
415                            }
416                            Some(')') => {
417                                self.next_char();
418                                break;
419                            }
420                            Some(_) => return Err(ParseCompositeError::ExtraComponent),
421                            None => return Err(ParseCompositeError::MissingComponent),
422                        }
423                    }
424                    Expr::Call { name, args }
425                } else {
426                    Expr::Value(token)
427                }
428            }
429            Some(_) => return Err(ParseCompositeError::ExtraComponent),
430            None => return Err(ParseCompositeError::MissingComponent),
431        };
432
433        // --- infix loop: while there's an operator with precedence >= min_bp ---
434        loop {
435            self.consume_whitespace();
436            let op = match self.peek_char() {
437                Some(c) if matches!(c, '+' | '-' | '*' | '/') => c,
438                _ => break,
439            };
440
441            if let Some((l_bp, r_bp)) = Self::infix_binding_power(op) {
442                if l_bp < min_bp {
443                    break;
444                }
445                // consume operator
446                self.next_char();
447                // parse rhs with r_bp
448                let rhs = self.parse_expr_bp(r_bp)?;
449                lhs = Expr::BinaryOp {
450                    op,
451                    lhs: Box::new(lhs),
452                    rhs: Box::new(rhs),
453                };
454            } else {
455                break;
456            }
457        }
458
459        Ok(lhs)
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    fn parse_ok(s: &str) -> Expr<'_> {
        let mut p = Parser::new(s);
        p.parse().unwrap()
    }

    fn parse_err(s: &str) -> ParseCompositeError {
        Parser::new(s).parse().unwrap_err()
    }

    #[test]
    fn test_values() {
        assert_eq!(parse_ok("default"), Expr::Value("default"));
        assert_eq!(parse_ok("3.14"), Expr::Value("3.14"));
        assert_eq!(parse_ok("abc.def"), Expr::Value("abc.def"));
    }

    #[test]
    fn test_unary() {
        assert_eq!(
            parse_ok("-x"),
            Expr::UnaryOp {
                op: '-',
                rhs: Box::new(Expr::Value("x"))
            }
        );
        assert_eq!(
            parse_ok("--3"),
            Expr::UnaryOp {
                op: '-',
                rhs: Box::new(Expr::UnaryOp {
                    op: '-',
                    rhs: Box::new(Expr::Value("3"))
                })
            }
        );
    }

    #[test]
    fn test_unary_prec() {
        // -a * b => (-a) * b
        assert_eq!(
            parse_ok("-a * b"),
            Expr::BinaryOp {
                op: '*',
                lhs: Box::new(Expr::UnaryOp {
                    op: '-',
                    rhs: Box::new(Expr::Value("a"))
                }),
                rhs: Box::new(Expr::Value("b")),
            }
        );
    }

    #[test]
    fn test_binary_prec() {
        // 1 + 2 * 3 => 1 + (2 * 3)
        let e = parse_ok("1 + 2 * 3");
        assert_eq!(
            e,
            Expr::BinaryOp {
                op: '+',
                lhs: Box::new(Expr::Value("1")),
                rhs: Box::new(Expr::BinaryOp {
                    op: '*',
                    lhs: Box::new(Expr::Value("2")),
                    rhs: Box::new(Expr::Value("3")),
                })
            }
        );

        // (1 + 2) * 3
        let e = parse_ok("(1 + 2) * 3");
        assert_eq!(
            e,
            Expr::BinaryOp {
                op: '*',
                lhs: Box::new(Expr::BinaryOp {
                    op: '+',
                    lhs: Box::new(Expr::Value("1")),
                    rhs: Box::new(Expr::Value("2")),
                }),
                rhs: Box::new(Expr::Value("3"))
            }
        );
    }

    #[test]
    fn test_left_assoc() {
        // 1 - 2 - 3 => (1 - 2) - 3
        assert_eq!(
            parse_ok("1 - 2 - 3"),
            Expr::BinaryOp {
                op: '-',
                lhs: Box::new(Expr::BinaryOp {
                    op: '-',
                    lhs: Box::new(Expr::Value("1")),
                    rhs: Box::new(Expr::Value("2")),
                }),
                rhs: Box::new(Expr::Value("3")),
            }
        );
    }

    #[test]
    fn test_call() {
        let e = parse_ok("f(a, b + 2, -3)");
        assert_eq!(
            e,
            Expr::Call {
                name: "f",
                args: vec![
                    Expr::Value("a"),
                    Expr::BinaryOp {
                        op: '+',
                        lhs: Box::new(Expr::Value("b")),
                        rhs: Box::new(Expr::Value("2")),
                    },
                    Expr::UnaryOp {
                        op: '-',
                        rhs: Box::new(Expr::Value("3"))
                    },
                ],
            }
        );
    }

    #[test]
    fn test_whitespace() {
        assert_eq!(
            parse_ok("  max ( a ,b )  "),
            Expr::Call {
                name: "max",
                args: vec![Expr::Value("a"), Expr::Value("b")],
            }
        );
    }

    #[test]
    fn test_errors() {
        assert!(matches!(parse_err(""), ParseCompositeError::MissingComponent));
        assert!(matches!(parse_err("1 +"), ParseCompositeError::MissingComponent));
        assert!(matches!(parse_err("(1 + 2"), ParseCompositeError::MissingComponent));
        assert!(matches!(parse_err("f()"), ParseCompositeError::MissingComponent));
        assert!(matches!(parse_err("1 2"), ParseCompositeError::ExtraComponent));
        assert!(matches!(parse_err("f(a b)"), ParseCompositeError::ExtraComponent));
    }
}