1#![no_std]
2extern crate alloc;
3
4use alloc::boxed::Box;
5use core::marker::PhantomData;
6
7pub mod fields;
8pub mod ops;
9
10pub use fields::{constant, function, identity, wrap};
11
12pub trait ScalarField {
13 type Point;
14 type Scalar;
15 fn value(&self, point: Self::Point) -> Self::Scalar;
16
17 fn map<S, F>(self, func: F) -> MapField<Self, F>
18 where
19 Self: Sized,
20 F: Fn(Self::Scalar) -> S,
21 {
22 MapField { inner: self, func }
23 }
24
25 fn transform<P, F>(self, func: F) -> TransformField<Self, P, F>
26 where
27 Self: Sized,
28 F: Fn(P) -> Self::Point,
29 {
30 TransformField {
31 inner: self,
32 func,
33 _point: PhantomData,
34 }
35 }
36}
37
/// Field adapter that post-processes the values of `inner` with `func`.
///
/// Created by [`ScalarField::map`]. Sampling it at a point samples `inner`
/// at that point and returns `func` applied to the result.
#[must_use = "field adapters are lazy and do nothing unless evaluated with `value`"]
#[derive(Copy, Clone, Debug)]
pub struct MapField<I, F> {
    /// The wrapped field whose values are post-processed.
    inner: I,
    /// Applied to every value produced by `inner`.
    func: F,
}
43
44impl<P, S> ScalarField for Box<dyn ScalarField<Point = P, Scalar = S>> {
45 type Point = P;
46 type Scalar = S;
47
48 fn value(&self, point: Self::Point) -> Self::Scalar {
49 (**self).value(point)
50 }
51}
52
53impl<I, S, F> ScalarField for MapField<I, F>
54where
55 I: ScalarField,
56 F: Fn(I::Scalar) -> S,
57{
58 type Point = I::Point;
59 type Scalar = S;
60
61 fn value(&self, point: Self::Point) -> Self::Scalar {
62 (self.func)(self.inner.value(point))
63 }
64}
65
/// Field adapter that converts every query point with `func` before sampling
/// `inner`.
///
/// Created by [`ScalarField::transform`]. Sampling it at a point `p` of type
/// `P` samples `inner` at `func(p)`.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand rather than derived:
/// `#[derive]` would require `P: Clone` / `P: Copy` / `P: Debug` even though
/// no value of type `P` is ever stored (it only appears in [`PhantomData`]).
#[must_use = "field adapters are lazy and do nothing unless evaluated with `value`"]
pub struct TransformField<I, P, F> {
    /// The wrapped field, sampled at converted points.
    inner: I,
    /// Converts an outer point of type `P` into a point of `inner`.
    func: F,
    /// Records the outer point type `P`; holds no data.
    _point: PhantomData<P>,
}

impl<I: Clone, P, F: Clone> Clone for TransformField<I, P, F> {
    fn clone(&self) -> Self {
        TransformField {
            inner: self.inner.clone(),
            func: self.func.clone(),
            _point: PhantomData,
        }
    }
}

impl<I: Copy, P, F: Copy> Copy for TransformField<I, P, F> {}

impl<I: core::fmt::Debug, P, F: core::fmt::Debug> core::fmt::Debug for TransformField<I, P, F> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Same shape as the derived output, including the marker field.
        f.debug_struct("TransformField")
            .field("inner", &self.inner)
            .field("func", &self.func)
            .field("_point", &self._point)
            .finish()
    }
}
72
73impl<I, P, F> ScalarField for TransformField<I, P, F>
74where
75 I: ScalarField,
76 F: Fn(P) -> I::Point,
77{
78 type Point = P;
79 type Scalar = I::Scalar;
80
81 fn value(&self, point: Self::Point) -> Self::Scalar {
82 self.inner.value((self.func)(point))
83 }
84}
85
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn test_constant() {
        assert_eq!(
            ((constant(1) + constant(5) * constant(3)) / constant(4)).value(()),
            4,
        );
    }

    #[test]
    fn test_function() {
        let field = function(|(x, y, z)| x + y - z);
        assert_eq!(field.value((4, 5, 6)), 3);
        assert_eq!(field.value((10, 11, 12)), 9);
        assert_eq!(field.value((-2, -3, 4)), -9);
    }

    #[test]
    fn test_identity() {
        assert_eq!(identity().value(1), 1);
        assert_eq!(identity().value(-2), -2);
        assert_eq!(identity().value([1, 2, 3]), [1, 2, 3]);
    }

    #[test]
    fn test_map() {
        assert_eq!(constant(3).map(|x| x * 2).value(()), 6);
        assert_eq!(identity().map(|(x, y)| x + y).value((-4, 11)), 7);
    }

    #[test]
    fn test_transform() {
        // The value of `constant(3)` does not depend on the point, so
        // transforming the point leaves the result at 3.
        assert_eq!(constant(3).transform(|p| p * 2).value(4), 3);
        // `transform` rewrites the point before the inner field sees it.
        assert_eq!(
            identity()
                .transform(|(x, y): (f64, f64)| (x * 2.0, y / 2.0))
                .value((3.0, 4.0)),
            (6.0, 2.0)
        );
        // With a point-dependent field the order is observable:
        // (3, 4) -> (6, 2) -> 6 - 2.
        assert_eq!(
            function(|(x, y): (i32, i32)| x - y)
                .transform(|(a, b): (i32, i32)| (a * 2, b / 2))
                .value((3, 4)),
            4
        );
    }
}