tachyon_math_library/functions/
function_logic.rs

1use crate::{FunctionData, FunctionDataAccessors, FunctionType, Interpolation, ValueCode};
2use anchor_lang::prelude::*;
3use num_traits::FromPrimitive;
4
5use crate::error::ErrorCode;
6use crate::math::interpolation;
7use rust_decimal::Decimal;
8
9pub trait FunctionLogic {
10    const FUNCTION_TYPE: FunctionType;
11
12    fn validate_load(x_in: Decimal, y_in: Decimal) -> Result<(Decimal, ValueCode)>;
13    fn eval(fd: &FunctionData, x: Decimal, interp: Interpolation, saturating: bool) -> Result<Decimal>;
14
15    fn proportion_difference(a: Decimal, b: Decimal) -> Result<Decimal> {
16        if a.is_zero() || b.is_zero() {
17            return Ok((a - b).abs());
18        }
19
20        Ok(((a / b) - Decimal::ONE).abs())
21    }
22
23    fn interpolate(fd: &FunctionData, x: Decimal, interp: Interpolation) -> Result<Decimal> {
24        // get indices for the x value
25        let (lower_index, upper_index, x_index_decimal) = fd.get_index_bounds(x)?;
26
27        let (y, code) = if lower_index == upper_index {
28            // if x lands exactly on an index
29
30            let y = fd.get_value(lower_index)?;
31            let code = fd.reduce_value_codes_from_indices(Vec::from([lower_index]))?;
32
33            (y, code)
34        } else {
35            // get the data using the indices
36            let lower_val = fd.get_value(lower_index)?;
37            let upper_val = fd.get_value(upper_index)?;
38
39            let point_a = (Decimal::from_u32(lower_index).unwrap(), lower_val);
40            let point_b = (Decimal::from_u32(upper_index).unwrap(), upper_val);
41
42            match interp {
43                Interpolation::Linear => {
44                    let value_code = fd.reduce_value_codes_from_indices(Vec::from([lower_index, upper_index]))?;
45
46                    (interpolation::linear(point_a, point_b, x_index_decimal)?, value_code)
47                }
48                Interpolation::Quadratic => {
49                    // determine if we can grab the point before the index pair or the point after
50                    // if the index is 1 we still want to grab the 3rd index on the right side to avoid the NaN value at 0
51                    let (point_c, code) = if lower_index < 2_u32 {
52                        let next_index = upper_index.checked_add(1u32).unwrap();
53                        let next_val = fd.get_value(next_index)?;
54
55                        let value_code = fd.reduce_value_codes_from_indices(Vec::from([lower_index, upper_index, next_index]))?;
56
57                        ((Decimal::from_u32(next_index).unwrap(), next_val), value_code)
58                    } else {
59                        let prev_index = lower_index.checked_sub(1u32).unwrap();
60                        let prev_val = fd.get_value(prev_index)?;
61
62                        let value_code = fd.reduce_value_codes_from_indices(Vec::from([lower_index, upper_index, prev_index]))?;
63
64                        ((Decimal::from_u32(prev_index).unwrap(), prev_val), value_code)
65                    };
66
67                    (interpolation::quadratic(point_a, point_b, point_c, x_index_decimal)?, code)
68                }
69            }
70        };
71
72        if code == ValueCode::Empty {
73            return err!(ErrorCode::EmptyData);
74        }
75
76        if code == ValueCode::NaN {
77            return err!(ErrorCode::NaNData);
78        }
79
80        Ok(y)
81    }
82}