wynd_utils/
curve.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use thiserror::Error;
5
6use cosmwasm_std::Uint128;
7
/// Errors returned by the `validate*` methods of [`Curve`] and its variant types.
#[derive(Error, Debug, PartialEq)]
pub enum CurveError {
    /// A piecewise curve both rises and falls, so it is neither monotonic increasing
    /// nor monotonic decreasing.
    #[error("Curve isn't monotonic")]
    NotMonotonic,

    /// Returned by `validate_monotonic_decreasing` when the curve increases.
    #[error("Curve is monotonic increasing")]
    MonotonicIncreasing,

    /// Returned by `validate_monotonic_increasing` when the curve decreases.
    #[error("Curve is monotonic decreasing")]
    MonotonicDecreasing,

    /// The x coordinates of the curve's points failed the ordering check
    /// performed by the `validate` methods.
    #[error("Later point must have higher X than previous point")]
    PointsOutOfOrder,

    /// A piecewise curve was defined with an empty list of steps.
    #[error("No steps defined")]
    MissingSteps,
}
25
/// A function `y = f(x)` from a `u64` x (e.g. a point in time) to a `Uint128` y.
///
/// Variants are serialized with snake_case names (`constant`, `saturating_linear`,
/// `piecewise_linear`).
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Curve {
    /// Returns `y` for every x.
    Constant { y: Uint128 },
    /// `min_y` up to `min_x`, `max_y` from `max_x` on, linear in between.
    SaturatingLinear(SaturatingLinear),
    /// Linear segments through a list of `(x, y)` steps.
    PiecewiseLinear(PiecewiseLinear),
}
33
34impl Curve {
35    pub fn saturating_linear((min_x, min_y): (u64, u128), (max_x, max_y): (u64, u128)) -> Self {
36        Curve::SaturatingLinear(SaturatingLinear {
37            min_x,
38            min_y: min_y.into(),
39            max_x,
40            max_y: max_y.into(),
41        })
42    }
43
44    pub fn constant(y: u128) -> Self {
45        Curve::Constant { y: Uint128::new(y) }
46    }
47}
48
49impl Curve {
50    /// provides y = f(x) evaluation
51    pub fn value(&self, x: u64) -> Uint128 {
52        match self {
53            Curve::Constant { y } => *y,
54            Curve::SaturatingLinear(s) => s.value(x),
55            Curve::PiecewiseLinear(p) => p.value(x),
56        }
57    }
58
59    /// general sanity checks on input values to ensure this is valid.
60    /// these checks should be included by the other validate_* functions
61    pub fn validate(&self) -> Result<(), CurveError> {
62        match self {
63            Curve::Constant { .. } => Ok(()),
64            Curve::SaturatingLinear(s) => s.validate(),
65            Curve::PiecewiseLinear(p) => p.validate(),
66        }
67    }
68
69    /// returns an error if there is ever x2 > x1 such that value(x2) < value(x1)
70    pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
71        match self {
72            Curve::Constant { .. } => Ok(()),
73            Curve::SaturatingLinear(s) => s.validate_monotonic_increasing(),
74            Curve::PiecewiseLinear(p) => p.validate_monotonic_increasing(),
75        }
76    }
77
78    /// returns an error if there is ever x2 > x1 such that value(x1) < value(x2)
79    pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
80        match self {
81            Curve::Constant { .. } => Ok(()),
82            Curve::SaturatingLinear(s) => s.validate_monotonic_decreasing(),
83            Curve::PiecewiseLinear(p) => p.validate_monotonic_decreasing(),
84        }
85    }
86
87    /// return (min, max) that can ever be returned from value. These could potentially be u128::MIN and u128::MAX
88    pub fn range(&self) -> (u128, u128) {
89        match self {
90            Curve::Constant { y } => (y.u128(), y.u128()),
91            Curve::SaturatingLinear(sat) => sat.range(),
92            Curve::PiecewiseLinear(p) => p.range(),
93        }
94    }
95}
96
/// min_y for all x <= min_x, max_y for all x >= max_x, linear in between
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
pub struct SaturatingLinear {
    /// x at which the linear section starts; `value` returns `min_y` for every x <= `min_x`.
    pub min_x: u64,
    /// y returned at and before `min_x`.
    pub min_y: Uint128,
    /// x at which the linear section ends; `validate` requires it to be greater than `min_x`.
    pub max_x: u64,
    /// y returned at and after `max_x`.
    pub max_y: Uint128,
}
107
108impl SaturatingLinear {
109    /// provides y = f(x) evaluation
110    pub fn value(&self, x: u64) -> Uint128 {
111        match (x < self.min_x, x > self.max_x) {
112            (true, _) => self.min_y,
113            (_, true) => self.max_y,
114            _ => interpolate((self.min_x, self.min_y), (self.max_x, self.max_y), x),
115        }
116    }
117
118    /// general sanity checks on input values to ensure this is valid.
119    /// these checks should be included by the other validate_* functions
120    pub fn validate(&self) -> Result<(), CurveError> {
121        if self.max_x <= self.min_x {
122            return Err(CurveError::PointsOutOfOrder);
123        }
124        Ok(())
125    }
126
127    /// returns an error if there is ever x2 > x1 such that value(x2) < value(x1)
128    pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
129        self.validate()?;
130        if self.max_y < self.min_y {
131            return Err(CurveError::MonotonicDecreasing);
132        }
133        Ok(())
134    }
135
136    /// returns an error if there is ever x2 > x1 such that value(x1) < value(x2)
137    pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
138        self.validate()?;
139        if self.max_y > self.min_y {
140            return Err(CurveError::MonotonicIncreasing);
141        }
142        Ok(())
143    }
144
145    /// return (min, max) that can ever be returned from value. These could potentially be 0 and u64::MAX
146    pub fn range(&self) -> (u128, u128) {
147        if self.max_y > self.min_y {
148            (self.min_y.u128(), self.max_y.u128())
149        } else {
150            (self.max_y.u128(), self.min_y.u128())
151        }
152    }
153}
154
155// this requires min_x < x < max_x to have been previously validated
156fn interpolate((min_x, min_y): (u64, Uint128), (max_x, max_y): (u64, Uint128), x: u64) -> Uint128 {
157    if max_y > min_y {
158        min_y + (max_y - min_y) * Uint128::from(x - min_x) / Uint128::from(max_x - min_x)
159    } else {
160        min_y - (min_y - max_y) * Uint128::from(x - min_x) / Uint128::from(max_x - min_x)
161    }
162}
163
/// This is a generalization of SaturatingLinear: a list of `(x, y)` steps that must be arranged
/// with strictly increasing x (u64, time).
/// Any point before the first step gets the first value, and any point after the last step gets
/// the last value.
/// Otherwise, the value is a linear interpolation between the two closest steps.
/// A Vec of length 1 behaves like Constant.
/// A Vec of length 2 behaves like SaturatingLinear.
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
pub struct PiecewiseLinear {
    /// `(x, y)` points, ordered by strictly increasing x; must not be empty.
    pub steps: Vec<(u64, Uint128)>,
}
173
174impl PiecewiseLinear {
175    /// provides y = f(x) evaluation
176    pub fn value(&self, x: u64) -> Uint128 {
177        // figure out the pair of points it lies between
178        let (mut prev, mut next): (Option<&(u64, Uint128)>, _) = (None, &self.steps[0]);
179        for step in &self.steps[1..] {
180            // only break if x is not above prev
181            if x >= next.0 {
182                prev = Some(next);
183                next = step;
184            } else {
185                break;
186            }
187        }
188        // at this time:
189        // prev may be None (this was lower than first point)
190        // x may equal prev.0 (use this value)
191        // x may be greater than next (if higher than last item)
192        // OR x may be between prev and next (interpolate)
193        if let Some(last) = prev {
194            if x == last.0 {
195                // this handles exact match with low end
196                last.1
197            } else if x >= next.0 {
198                // this handles both higher than all and exact match
199                next.1
200            } else {
201                // here we do linear interpolation
202                interpolate(*last, *next, x)
203            }
204        } else {
205            // lower than all, use first
206            next.1
207        }
208    }
209
210    /// general sanity checks on input values to ensure this is valid.
211    /// these checks should be included by the other validate_* functions
212    pub fn validate(&self) -> Result<(), CurveError> {
213        if self.steps.is_empty() {
214            return Err(CurveError::MissingSteps);
215        }
216        self.steps.iter().fold(Ok(0u64), |acc, (x, _)| {
217            acc.and_then(|last| {
218                if *x > last {
219                    Ok(*x)
220                } else {
221                    Err(CurveError::PointsOutOfOrder)
222                }
223            })
224        })?;
225        Ok(())
226    }
227
228    /// returns an error if there is ever x2 > x1 such that value(x2) < value(x1)
229    pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
230        self.validate()?;
231        match self.classify_curve() {
232            Shape::NotMonotonic => Err(CurveError::NotMonotonic),
233            Shape::MonotonicDecreasing => Err(CurveError::MonotonicDecreasing),
234            _ => Ok(()),
235        }
236    }
237
238    /// returns an error if there is ever x2 > x1 such that value(x1) < value(x2)
239    pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
240        self.validate()?;
241        match self.classify_curve() {
242            Shape::NotMonotonic => Err(CurveError::NotMonotonic),
243            Shape::MonotonicIncreasing => Err(CurveError::MonotonicIncreasing),
244            _ => Ok(()),
245        }
246    }
247
248    // Gives monotonic info. Requires there be at least one item in steps
249    fn classify_curve(&self) -> Shape {
250        let mut iter = self.steps.iter();
251        let (_, first) = iter.next().unwrap();
252        let (_, shape) = iter.fold((*first, Shape::Constant), |(last, shape), (_, y)| {
253            let shape = match (shape, y.cmp(&last)) {
254                (Shape::NotMonotonic, _) => Shape::NotMonotonic,
255                (Shape::MonotonicDecreasing, Ordering::Greater) => Shape::NotMonotonic,
256                (Shape::MonotonicDecreasing, _) => Shape::MonotonicDecreasing,
257                (Shape::MonotonicIncreasing, Ordering::Less) => Shape::NotMonotonic,
258                (Shape::MonotonicIncreasing, _) => Shape::MonotonicIncreasing,
259                (Shape::Constant, Ordering::Greater) => Shape::MonotonicIncreasing,
260                (Shape::Constant, Ordering::Less) => Shape::MonotonicDecreasing,
261                (Shape::Constant, Ordering::Equal) => Shape::Constant,
262            };
263            (*y, shape)
264        });
265        shape
266    }
267
268    /// return (min, max) that can ever be returned from value. These could potentially be 0 and u64::MAX
269    pub fn range(&self) -> (u128, u128) {
270        let low = self.steps.iter().map(|(_, y)| *y).min().unwrap().u128();
271        let high = self.steps.iter().map(|(_, y)| *y).max().unwrap().u128();
272        (low, high)
273    }
274}
275
/// Monotonicity of a sequence of step values, as computed by `PiecewiseLinear::classify_curve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Shape {
    /// There is only one point, or all points have the same value.
    Constant,
    /// Values never decrease and increase at least once.
    MonotonicIncreasing,
    /// Values never increase and decrease at least once.
    MonotonicDecreasing,
    /// Values both increase and decrease.
    NotMonotonic,
}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant() {
        let y = 524;
        let curve = Curve::constant(y);

        // always valid
        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        curve.validate_monotonic_decreasing().unwrap();

        // always returns same value
        assert_eq!(curve.value(1).u128(), y);
        assert_eq!(curve.value(1000000).u128(), y);

        // range is constant
        assert_eq!(curve.range(), (y, y));
    }

    #[test]
    fn test_increasing_linear() {
        let low = (100, 0);
        let high = (200, 50);
        let curve = Curve::saturating_linear(low, high);

        // validly increasing
        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        // but not decreasing
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        // check extremes
        assert_eq!(curve.value(1).u128(), low.1);
        assert_eq!(curve.value(1000000).u128(), high.1);
        // check linear portion
        assert_eq!(curve.value(150).u128(), 25);
        // and rounding
        assert_eq!(curve.value(103).u128(), 1);

        // range is min to max
        assert_eq!(curve.range(), (low.1, high.1));
    }

    #[test]
    fn test_decreasing_linear() {
        let low = (1700, 500);
        let high = (2000, 200);
        let curve = Curve::saturating_linear(low, high);

        // validly decreasing
        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        // but not increasing
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        // check extremes
        assert_eq!(curve.value(low.0 - 5).u128(), low.1);
        assert_eq!(curve.value(high.0 + 5).u128(), high.1);
        // check linear portion
        assert_eq!(curve.value(1800).u128(), 400);
        assert_eq!(curve.value(1997).u128(), 203);

        // range is min to max
        assert_eq!(curve.range(), (high.1, low.1));
    }

    #[test]
    fn test_invalid_linear() {
        let low = (15000, 100);
        let high = (12000, 200);
        let curve = Curve::saturating_linear(low, high);

        // always invalid
        let err = curve.validate().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
    }

    #[test]
    fn test_piecewise_one_step() {
        let y = 524;
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![(12345, Uint128::new(y))],
        });

        // always valid
        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        curve.validate_monotonic_decreasing().unwrap();

        // always returns same value
        assert_eq!(curve.value(1).u128(), y);
        assert_eq!(curve.value(1000000).u128(), y);

        // range is constant
        assert_eq!(curve.range(), (y, y));
    }

    #[test]
    fn test_piecewise_two_point_increasing() {
        let low = (100, Uint128::new(0));
        let high = (200, Uint128::new(50));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        // validly increasing
        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        // but not decreasing
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        // check extremes
        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);
        // check linear portion
        assert_eq!(curve.value(150).u128(), 25);
        // and rounding
        assert_eq!(curve.value(103).u128(), 1);
        // check both edges
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(high.0), high.1);

        // range is min to max
        assert_eq!(curve.range(), (low.1.u128(), high.1.u128()));
    }

    #[test]
    fn test_piecewise_two_point_decreasing() {
        let low = (1700, Uint128::new(500));
        let high = (2000, Uint128::new(200));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        // validly decreasing
        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        // but not increasing
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        // check extremes
        assert_eq!(curve.value(low.0 - 5), low.1);
        assert_eq!(curve.value(high.0 + 5), high.1);
        // check linear portion
        assert_eq!(curve.value(1800).u128(), 400);
        assert_eq!(curve.value(1997).u128(), 203);
        // check edge matches
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(high.0), high.1);

        // range is min to max
        assert_eq!(curve.range(), (high.1.u128(), low.1.u128()));
    }

    #[test]
    fn test_piecewise_two_point_invalid() {
        // second point has a lower x than the first
        let low = (15000, Uint128::new(100));
        let high = (12000, Uint128::new(200));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        // always invalid
        let err = curve.validate().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
    }

    #[test]
    fn test_piecewise_three_point_increasing() {
        let low = (100, Uint128::new(0));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(400));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        // validly increasing
        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        // but not decreasing
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        // check extremes
        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);

        // check first portion
        assert_eq!(curve.value(172).u128(), 72);
        // check second portion (100 + 3 * 40) = 220
        assert_eq!(curve.value(240).u128(), 220);

        // check all exact matches
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(mid.0), mid.1);
        assert_eq!(curve.value(high.0), high.1);

        // range is min to max
        assert_eq!(curve.range(), (low.1.u128(), high.1.u128()));
    }

    #[test]
    fn test_piecewise_three_point_decreasing() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(0));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        // validly decreasing
        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        // but not increasing
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        // check extremes
        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);

        // check first portion (400 - 72 * 3 = 184)
        assert_eq!(curve.value(172).u128(), 184);
        // check second portion (100 - 45) = 55
        assert_eq!(curve.value(245).u128(), 55);

        // check all exact matches
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(mid.0), mid.1);
        assert_eq!(curve.value(high.0), high.1);

        // range is min to max
        assert_eq!(curve.range(), (high.1.u128(), low.1.u128()));
    }

    #[test]
    fn test_piecewise_three_point_invalid_not_monotonic() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(300));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        // valid ordering
        curve.validate().unwrap();
        // falls then rises, so neither monotonic increasing...
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);
        // ...nor monotonic decreasing
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);
    }

    #[test]
    fn test_piecewise_three_point_invalid_out_of_order() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(300));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high, mid],
        });

        // invalid ordering
        let err = curve.validate().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        // ordering is checked before monotonicity
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
    }

    #[test]
    fn test_piecewise_multi_step_invalid_not_monotonic() {
        // rises, stays flat, dips, then rises again
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![
                (100, Uint128::new(0)),
                (200, Uint128::new(300)),
                (300, Uint128::new(300)),
                (400, Uint128::new(100)),
                (500, Uint128::new(500)),
            ],
        });

        // valid ordering
        curve.validate().unwrap();
        // but not monotonic in either direction
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);

        // every segment is still interpolated
        assert_eq!(curve.value(250).u128(), 300);
        assert_eq!(curve.value(350).u128(), 200);
        assert_eq!(curve.value(450).u128(), 300);

        // range covers every step
        assert_eq!(curve.range(), (0, 500));
    }

    #[test]
    fn test_piecewise_multi_step_invalid_duplicate_x() {
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![
                (100, Uint128::new(0)),
                (200, Uint128::new(10)),
                (200, Uint128::new(20)),
                (300, Uint128::new(30)),
            ],
        });

        // x must strictly increase, so a repeated x is rejected
        let err = curve.validate().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
    }

    #[test]
    fn test_piecewise_empty_steps() {
        let curve = Curve::PiecewiseLinear(PiecewiseLinear { steps: vec![] });

        let err = curve.validate().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
    }
}
570
571    // TODO: multi-step bad
572}