1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3use std::cmp::Ordering;
4use thiserror::Error;
5
6use cosmwasm_std::Uint128;
7
/// Errors returned when validating a [`Curve`].
#[derive(Error, Debug, PartialEq)]
pub enum CurveError {
    /// The curve's `y` values both rise and fall, so it is neither monotonic
    /// increasing nor monotonic decreasing.
    #[error("Curve isn't monotonic")]
    NotMonotonic,

    /// Returned by the monotonic-decreasing checks when the curve rises.
    #[error("Curve is monotonic increasing")]
    MonotonicIncreasing,

    /// Returned by the monotonic-increasing checks when the curve falls.
    #[error("Curve is monotonic decreasing")]
    MonotonicDecreasing,

    /// The x coordinates of the curve's points are not strictly increasing.
    #[error("Later point must have higher X than previous point")]
    PointsOutOfOrder,

    /// A piecewise-linear curve was given an empty list of steps.
    #[error("No steps defined")]
    MissingSteps,
}
25
/// A function mapping a `u64` input `x` to a `Uint128` output.
///
/// Variants are (de)serialized with snake_case names: `constant`,
/// `saturating_linear` and `piecewise_linear`.
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Curve {
    /// Yields `y` for every input.
    Constant { y: Uint128 },
    /// Linear between two points, clamped to the end values outside of them.
    SaturatingLinear(SaturatingLinear),
    /// Linear interpolation between an ordered list of points, clamped at both ends.
    PiecewiseLinear(PiecewiseLinear),
}
33
34impl Curve {
35 pub fn saturating_linear((min_x, min_y): (u64, u128), (max_x, max_y): (u64, u128)) -> Self {
36 Curve::SaturatingLinear(SaturatingLinear {
37 min_x,
38 min_y: min_y.into(),
39 max_x,
40 max_y: max_y.into(),
41 })
42 }
43
44 pub fn constant(y: u128) -> Self {
45 Curve::Constant { y: Uint128::new(y) }
46 }
47}
48
49impl Curve {
50 pub fn value(&self, x: u64) -> Uint128 {
52 match self {
53 Curve::Constant { y } => *y,
54 Curve::SaturatingLinear(s) => s.value(x),
55 Curve::PiecewiseLinear(p) => p.value(x),
56 }
57 }
58
59 pub fn validate(&self) -> Result<(), CurveError> {
62 match self {
63 Curve::Constant { .. } => Ok(()),
64 Curve::SaturatingLinear(s) => s.validate(),
65 Curve::PiecewiseLinear(p) => p.validate(),
66 }
67 }
68
69 pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
71 match self {
72 Curve::Constant { .. } => Ok(()),
73 Curve::SaturatingLinear(s) => s.validate_monotonic_increasing(),
74 Curve::PiecewiseLinear(p) => p.validate_monotonic_increasing(),
75 }
76 }
77
78 pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
80 match self {
81 Curve::Constant { .. } => Ok(()),
82 Curve::SaturatingLinear(s) => s.validate_monotonic_decreasing(),
83 Curve::PiecewiseLinear(p) => p.validate_monotonic_decreasing(),
84 }
85 }
86
87 pub fn range(&self) -> (u128, u128) {
89 match self {
90 Curve::Constant { y } => (y.u128(), y.u128()),
91 Curve::SaturatingLinear(sat) => sat.range(),
92 Curve::PiecewiseLinear(p) => p.range(),
93 }
94 }
95}
96
/// A curve that is linear between `(min_x, min_y)` and `(max_x, max_y)` and
/// constant outside that interval.
///
/// `min_y` may be larger than `max_y`, in which case the curve decreases.
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
pub struct SaturatingLinear {
    /// x coordinate of the start point; inputs below it yield `min_y`.
    pub min_x: u64,
    /// Output at the start point and for all inputs before it.
    pub min_y: Uint128,
    /// x coordinate of the end point; `validate` requires it to be strictly
    /// greater than `min_x`.
    pub max_x: u64,
    /// Output at the end point and for all inputs after it.
    pub max_y: Uint128,
}
107
108impl SaturatingLinear {
109 pub fn value(&self, x: u64) -> Uint128 {
111 match (x < self.min_x, x > self.max_x) {
112 (true, _) => self.min_y,
113 (_, true) => self.max_y,
114 _ => interpolate((self.min_x, self.min_y), (self.max_x, self.max_y), x),
115 }
116 }
117
118 pub fn validate(&self) -> Result<(), CurveError> {
121 if self.max_x <= self.min_x {
122 return Err(CurveError::PointsOutOfOrder);
123 }
124 Ok(())
125 }
126
127 pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
129 self.validate()?;
130 if self.max_y < self.min_y {
131 return Err(CurveError::MonotonicDecreasing);
132 }
133 Ok(())
134 }
135
136 pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
138 self.validate()?;
139 if self.max_y > self.min_y {
140 return Err(CurveError::MonotonicIncreasing);
141 }
142 Ok(())
143 }
144
145 pub fn range(&self) -> (u128, u128) {
147 if self.max_y > self.min_y {
148 (self.min_y.u128(), self.max_y.u128())
149 } else {
150 (self.max_y.u128(), self.min_y.u128())
151 }
152 }
153}
154
155fn interpolate((min_x, min_y): (u64, Uint128), (max_x, max_y): (u64, Uint128), x: u64) -> Uint128 {
157 if max_y > min_y {
158 min_y + (max_y - min_y) * Uint128::from(x - min_x) / Uint128::from(max_x - min_x)
159 } else {
160 min_y - (min_y - max_y) * Uint128::from(x - min_x) / Uint128::from(max_x - min_x)
161 }
162}
163
/// A curve defined by a list of `(x, y)` points.
///
/// Values are linearly interpolated between neighbouring points. Before the first
/// point the curve is constant at its `y`, and after the last point it is constant
/// at that point's `y`.
#[derive(Serialize, Deserialize, JsonSchema, Debug, Clone, PartialEq)]
pub struct PiecewiseLinear {
    /// The defining points as `(x, y)` pairs. `validate` requires at least one
    /// point and strictly increasing x values.
    pub steps: Vec<(u64, Uint128)>,
}
173
174impl PiecewiseLinear {
175 pub fn value(&self, x: u64) -> Uint128 {
177 let (mut prev, mut next): (Option<&(u64, Uint128)>, _) = (None, &self.steps[0]);
179 for step in &self.steps[1..] {
180 if x >= next.0 {
182 prev = Some(next);
183 next = step;
184 } else {
185 break;
186 }
187 }
188 if let Some(last) = prev {
194 if x == last.0 {
195 last.1
197 } else if x >= next.0 {
198 next.1
200 } else {
201 interpolate(*last, *next, x)
203 }
204 } else {
205 next.1
207 }
208 }
209
210 pub fn validate(&self) -> Result<(), CurveError> {
213 if self.steps.is_empty() {
214 return Err(CurveError::MissingSteps);
215 }
216 self.steps.iter().fold(Ok(0u64), |acc, (x, _)| {
217 acc.and_then(|last| {
218 if *x > last {
219 Ok(*x)
220 } else {
221 Err(CurveError::PointsOutOfOrder)
222 }
223 })
224 })?;
225 Ok(())
226 }
227
228 pub fn validate_monotonic_increasing(&self) -> Result<(), CurveError> {
230 self.validate()?;
231 match self.classify_curve() {
232 Shape::NotMonotonic => Err(CurveError::NotMonotonic),
233 Shape::MonotonicDecreasing => Err(CurveError::MonotonicDecreasing),
234 _ => Ok(()),
235 }
236 }
237
238 pub fn validate_monotonic_decreasing(&self) -> Result<(), CurveError> {
240 self.validate()?;
241 match self.classify_curve() {
242 Shape::NotMonotonic => Err(CurveError::NotMonotonic),
243 Shape::MonotonicIncreasing => Err(CurveError::MonotonicIncreasing),
244 _ => Ok(()),
245 }
246 }
247
248 fn classify_curve(&self) -> Shape {
250 let mut iter = self.steps.iter();
251 let (_, first) = iter.next().unwrap();
252 let (_, shape) = iter.fold((*first, Shape::Constant), |(last, shape), (_, y)| {
253 let shape = match (shape, y.cmp(&last)) {
254 (Shape::NotMonotonic, _) => Shape::NotMonotonic,
255 (Shape::MonotonicDecreasing, Ordering::Greater) => Shape::NotMonotonic,
256 (Shape::MonotonicDecreasing, _) => Shape::MonotonicDecreasing,
257 (Shape::MonotonicIncreasing, Ordering::Less) => Shape::NotMonotonic,
258 (Shape::MonotonicIncreasing, _) => Shape::MonotonicIncreasing,
259 (Shape::Constant, Ordering::Greater) => Shape::MonotonicIncreasing,
260 (Shape::Constant, Ordering::Less) => Shape::MonotonicDecreasing,
261 (Shape::Constant, Ordering::Equal) => Shape::Constant,
262 };
263 (*y, shape)
264 });
265 shape
266 }
267
268 pub fn range(&self) -> (u128, u128) {
270 let low = self.steps.iter().map(|(_, y)| *y).min().unwrap().u128();
271 let high = self.steps.iter().map(|(_, y)| *y).max().unwrap().u128();
272 (low, high)
273 }
274}
275
/// Trend of a piecewise-linear curve's `y` values, as computed by `classify_curve`.
enum Shape {
    /// Every point has the same `y`.
    Constant,
    /// `y` never decreases and increases at least once.
    MonotonicIncreasing,
    /// `y` never increases and decreases at least once.
    MonotonicDecreasing,
    /// `y` both increases and decreases somewhere along the curve.
    NotMonotonic,
}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant() {
        let y = 524;
        let curve = Curve::constant(y);

        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        curve.validate_monotonic_decreasing().unwrap();

        assert_eq!(curve.value(1).u128(), y);
        assert_eq!(curve.value(1000000).u128(), y);

        assert_eq!(curve.range(), (y, y));
    }

    #[test]
    fn test_increasing_linear() {
        let low = (100, 0);
        let high = (200, 50);
        let curve = Curve::saturating_linear(low, high);

        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        assert_eq!(curve.value(1).u128(), low.1);
        assert_eq!(curve.value(1000000).u128(), high.1);
        assert_eq!(curve.value(150).u128(), 25);
        assert_eq!(curve.value(103).u128(), 1);

        assert_eq!(curve.range(), (low.1, high.1));
    }

    #[test]
    fn test_decreasing_linear() {
        let low = (1700, 500);
        let high = (2000, 200);
        let curve = Curve::saturating_linear(low, high);

        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        assert_eq!(curve.value(low.0 - 5).u128(), low.1);
        assert_eq!(curve.value(high.0 + 5).u128(), high.1);
        assert_eq!(curve.value(1800).u128(), 400);
        assert_eq!(curve.value(1997).u128(), 203);

        assert_eq!(curve.range(), (high.1, low.1));
    }

    #[test]
    fn test_invalid_linear() {
        let low = (15000, 100);
        let high = (12000, 200);
        let curve = Curve::saturating_linear(low, high);

        let err = curve.validate().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
    }

    #[test]
    fn test_piecewise_one_step() {
        let y = 524;
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![(12345, Uint128::new(y))],
        });

        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        curve.validate_monotonic_decreasing().unwrap();

        assert_eq!(curve.value(1).u128(), y);
        assert_eq!(curve.value(1000000).u128(), y);

        assert_eq!(curve.range(), (y, y));
    }

    #[test]
    fn test_piecewise_two_point_increasing() {
        let low = (100, Uint128::new(0));
        let high = (200, Uint128::new(50));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);
        assert_eq!(curve.value(150).u128(), 25);
        assert_eq!(curve.value(103).u128(), 1);
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(high.0), high.1);

        assert_eq!(curve.range(), (low.1.u128(), high.1.u128()));
    }

    #[test]
    fn test_piecewise_two_point_decreasing() {
        let low = (1700, Uint128::new(500));
        let high = (2000, Uint128::new(200));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        assert_eq!(curve.value(low.0 - 5), low.1);
        assert_eq!(curve.value(high.0 + 5), high.1);
        assert_eq!(curve.value(1800).u128(), 400);
        assert_eq!(curve.value(1997).u128(), 203);
        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(high.0), high.1);

        assert_eq!(curve.range(), (high.1.u128(), low.1.u128()));
    }

    #[test]
    fn test_piecewise_two_point_invalid() {
        let low = (15000, Uint128::new(100));
        let high = (12000, Uint128::new(200));
        // Build an actual piecewise curve so that `PiecewiseLinear::validate` is what
        // gets exercised here; `test_invalid_linear` already covers the saturating case.
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high],
        });

        let err = curve.validate().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(CurveError::PointsOutOfOrder, err);
    }

    #[test]
    fn test_piecewise_three_point_increasing() {
        let low = (100, Uint128::new(0));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(400));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        curve.validate().unwrap();
        curve.validate_monotonic_increasing().unwrap();
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicIncreasing);

        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);

        assert_eq!(curve.value(172).u128(), 72);
        assert_eq!(curve.value(240).u128(), 220);

        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(mid.0), mid.1);
        assert_eq!(curve.value(high.0), high.1);

        assert_eq!(curve.range(), (low.1.u128(), high.1.u128()));
    }

    #[test]
    fn test_piecewise_three_point_decreasing() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(0));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        curve.validate().unwrap();
        curve.validate_monotonic_decreasing().unwrap();
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MonotonicDecreasing);

        assert_eq!(curve.value(1), low.1);
        assert_eq!(curve.value(1000000), high.1);

        assert_eq!(curve.value(172).u128(), 184);
        assert_eq!(curve.value(245).u128(), 55);

        assert_eq!(curve.value(low.0), low.1);
        assert_eq!(curve.value(mid.0), mid.1);
        assert_eq!(curve.value(high.0), high.1);

        assert_eq!(curve.range(), (high.1.u128(), low.1.u128()));
    }

    #[test]
    fn test_piecewise_three_point_invalid_not_monotonic() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(300));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, mid, high],
        });

        curve.validate().unwrap();
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::NotMonotonic);
    }

    #[test]
    fn test_piecewise_three_point_invalid_out_of_order() {
        let low = (100, Uint128::new(400));
        let mid = (200, Uint128::new(100));
        let high = (300, Uint128::new(300));
        let curve = Curve::PiecewiseLinear(PiecewiseLinear {
            steps: vec![low, high, mid],
        });

        let err = curve.validate().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::PointsOutOfOrder);
    }

    #[test]
    fn test_piecewise_no_steps() {
        let curve = Curve::PiecewiseLinear(PiecewiseLinear { steps: vec![] });

        let err = curve.validate().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
        let err = curve.validate_monotonic_increasing().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
        let err = curve.validate_monotonic_decreasing().unwrap_err();
        assert_eq!(err, CurveError::MissingSteps);
    }
}