// scalar_field/lib.rs

1#![no_std]
2extern crate alloc;
3
4use alloc::boxed::Box;
5use core::marker::PhantomData;
6
7pub mod fields;
8pub mod ops;
9
10pub use fields::{constant, function, identity, wrap};
11
12pub trait ScalarField {
13	type Point;
14	type Scalar;
15	fn value(&self, point: Self::Point) -> Self::Scalar;
16
17	fn map<S, F>(self, func: F) -> MapField<Self, F>
18	where
19		Self: Sized,
20		F: Fn(Self::Scalar) -> S,
21	{
22		MapField { inner: self, func }
23	}
24
25	fn transform<P, F>(self, func: F) -> TransformField<Self, P, F>
26	where
27		Self: Sized,
28		F: Fn(P) -> Self::Point,
29	{
30		TransformField {
31			inner: self,
32			func,
33			_point: PhantomData,
34		}
35	}
36}
37
/// Field adapter returned by [`ScalarField::map`].
///
/// Evaluating it evaluates `inner` and feeds the result through `func`.
#[derive(Clone, Copy, Debug)]
pub struct MapField<I, F> {
	/// The wrapped field.
	inner: I,
	/// Applied to every value produced by `inner`.
	func: F,
}
43
44impl<P, S> ScalarField for Box<dyn ScalarField<Point = P, Scalar = S>> {
45	type Point = P;
46	type Scalar = S;
47
48	fn value(&self, point: Self::Point) -> Self::Scalar {
49		(**self).value(point)
50	}
51}
52
53impl<I, S, F> ScalarField for MapField<I, F>
54where
55	I: ScalarField,
56	F: Fn(I::Scalar) -> S,
57{
58	type Point = I::Point;
59	type Scalar = S;
60
61	fn value(&self, point: Self::Point) -> Self::Scalar {
62		(self.func)(self.inner.value(point))
63	}
64}
65
/// Field adapter returned by [`ScalarField::transform`].
///
/// Evaluating it at a point of type `P` first converts the point with `func`,
/// then evaluates `inner` at the converted point.
///
/// No value of type `P` is ever stored, so the marker is `PhantomData<fn() -> P>`
/// rather than `PhantomData<P>`. This keeps the struct covariant in `P`, as
/// before, but makes `Send`/`Sync` and drop-checking independent of `P`. For
/// the same reason `Clone`, `Copy` and `Debug` are implemented by hand:
/// `#[derive]` would require `P: Clone` / `P: Copy` / `P: Debug` even though
/// no `P` is held.
pub struct TransformField<I, P, F> {
	/// The field evaluated at the converted point.
	inner: I,
	/// Converts incoming points of type `P` into `inner`'s point type.
	func: F,
	/// Records the input point type without owning a `P`.
	_point: PhantomData<fn() -> P>,
}

impl<I: Clone, P, F: Clone> Clone for TransformField<I, P, F> {
	fn clone(&self) -> Self {
		TransformField {
			inner: self.inner.clone(),
			func: self.func.clone(),
			_point: PhantomData,
		}
	}
}

impl<I: Copy, P, F: Copy> Copy for TransformField<I, P, F> {}

impl<I: core::fmt::Debug, P, F: core::fmt::Debug> core::fmt::Debug for TransformField<I, P, F> {
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		// Print the marker as `PhantomData<P>` so the output matches what the
		// previously derived impl produced. `PhantomData<T>: Debug` for any `T`.
		f.debug_struct("TransformField")
			.field("inner", &self.inner)
			.field("func", &self.func)
			.field("_point", &PhantomData::<P>)
			.finish()
	}
}
72
73impl<I, P, F> ScalarField for TransformField<I, P, F>
74where
75	I: ScalarField,
76	F: Fn(P) -> I::Point,
77{
78	type Point = P;
79	type Scalar = I::Scalar;
80
81	fn value(&self, point: Self::Point) -> Self::Scalar {
82		self.inner.value((self.func)(point))
83	}
84}
85
#[cfg(test)]
mod tests {
	use crate::*;

	#[test]
	fn test_constant() {
		assert_eq!(
			((constant(1) + constant(5) * constant(3)) / constant(4)).value(()),
			4,
		);
	}

	#[test]
	fn test_function() {
		let field = function(|(x, y, z)| x + y - z);
		assert_eq!(field.value((4, 5, 6)), 3);
		assert_eq!(field.value((10, 11, 12)), 9);
		assert_eq!(field.value((-2, -3, 4)), -9);
	}

	#[test]
	fn test_identity() {
		assert_eq!(identity().value(1), 1);
		assert_eq!(identity().value(-2), -2);
		assert_eq!(identity().value([1, 2, 3]), [1, 2, 3]);
	}

	#[test]
	fn test_map() {
		assert_eq!(constant(3).map(|x| x * 2).value(()), 6);
		assert_eq!(identity().map(|(x, y)| x + y).value((-4, 11)), 7);
	}

	#[test]
	fn test_transform() {
		// A constant field ignores its (transformed) point.
		assert_eq!(constant(3).transform(|p| p * 2).value(4), 3);
		// The point is converted *before* the inner field sees it. This case
		// previously called `.map` by mistake and so never exercised `transform`.
		assert_eq!(
			identity()
				.transform(|(x, y): (f64, f64)| (x * 2.0, y / 2.0))
				.value((3.0, 4.0)),
			(6.0, 2.0)
		);
		// `transform` (input side) and `map` (output side) compose.
		assert_eq!(
			identity().transform(|p: i32| p * 3).map(|s| s + 1).value(5),
			16
		);
	}

	#[test]
	fn test_boxed() {
		let field: Box<dyn ScalarField<Point = i32, Scalar = i32>> =
			Box::new(function(|x: i32| x + 1));
		assert_eq!(field.value(2), 3);
		// A boxed field is itself a field, so the combinators apply to it.
		assert_eq!(field.map(|s| s * 10).value(2), 30);
	}
}