rust_overture/
curry.rs

1use std::sync::Arc;
2
// Curry functions for Rust

/// Curries a two-argument function, so that `curry2(f)(a1)(a2)` computes `f(a1, a2)`.
///
/// The returned closure accepts the first argument and yields a reference-counted,
/// `Send + Sync` closure that accepts the second. A partially applied closure may be
/// called any number of times; the captured first argument is cloned on every call
/// because `function` takes its arguments by value.
///
/// `function` is stored once behind an [`Arc`] and shared by every partial
/// application, so it does not have to be `Copy`: closures that capture owned data
/// (for example a `String`) are accepted.
pub fn curry2<A1, A2, R, F>(function: F) -> impl Fn(A1) -> Arc<dyn Fn(A2) -> R + Send + Sync>
where
    F: Fn(A1, A2) -> R + Send + Sync + 'static,
    // `a1` is captured inside every `Send + Sync` partial application.
    A1: Clone + Send + Sync + 'static,
    // Arguments and results are never captured, so they need no `Send`/`Sync`.
    A2: 'static,
    R: 'static,
{
    // One shared copy of `function` instead of requiring `F: Copy`.
    let shared = Arc::new(function);
    move |a1: A1| {
        let shared = Arc::clone(&shared);
        // `a1` is owned here, so it moves straight into the partial application.
        Arc::new(move |a2: A2| shared(a1.clone(), a2))
    }
}
16
/// Curries a fallible two-argument function, so that `curry2_throwing(f)(a1)(a2)`
/// computes `f(a1, a2)` and returns its `Result` unchanged.
///
/// The returned closure accepts the first argument and yields a reference-counted,
/// `Send + Sync` closure that accepts the second. Errors are not inspected here:
/// each call simply returns whatever `function` returns. The captured first argument
/// is cloned on every call because `function` takes its arguments by value.
///
/// `function` is stored once behind an [`Arc`] and shared by every partial
/// application, so it does not have to be `Copy`.
pub fn curry2_throwing<A1, A2, R, E, F>(
    function: F,
) -> impl Fn(A1) -> Arc<dyn Fn(A2) -> Result<R, E> + Send + Sync>
where
    F: Fn(A1, A2) -> Result<R, E> + Send + Sync + 'static,
    // `a1` is captured inside every `Send + Sync` partial application.
    A1: Clone + Send + Sync + 'static,
    // Arguments, values and errors are never captured, so they need no `Send`/`Sync`.
    A2: 'static,
    R: 'static,
    E: 'static,
{
    // One shared copy of `function` instead of requiring `F: Copy`.
    let shared = Arc::new(function);
    move |a1: A1| {
        let shared = Arc::clone(&shared);
        // `a1` is owned here, so it moves straight into the partial application.
        Arc::new(move |a2: A2| shared(a1.clone(), a2))
    }
}
32
/// Curries a three-argument function, so that `curry3(f)(a1)(a2)(a3)` computes
/// `f(a1, a2, a3)`.
///
/// Every intermediate stage is a reference-counted, `Send + Sync` closure that can
/// be called any number of times. Captured arguments are cloned on each call
/// because `function` takes its arguments by value.
///
/// `function` is stored once behind an [`Arc`] and shared by every partial
/// application, so it does not have to be `Copy`.
pub fn curry3<A1, A2, A3, R, F>(
    function: F,
) -> impl Fn(A1) -> Arc<dyn Fn(A2) -> Arc<dyn Fn(A3) -> R + Send + Sync> + Send + Sync>
where
    F: Fn(A1, A2, A3) -> R + Send + Sync + 'static,
    // `a1` and `a2` are captured inside `Send + Sync` partial applications.
    A1: Clone + Send + Sync + 'static,
    A2: Clone + Send + Sync + 'static,
    // The last argument and the result are never captured.
    A3: 'static,
    R: 'static,
{
    // One shared copy of `function` instead of requiring `F: Copy`.
    let shared = Arc::new(function);
    move |a1: A1| {
        let shared = Arc::clone(&shared);
        Arc::new(move |a2: A2| {
            let shared = Arc::clone(&shared);
            // This stage may run many times, so it hands each inner closure its
            // own copy of `a1`; `a2` is owned here and simply moves in.
            let a1 = a1.clone();
            Arc::new(move |a3: A3| shared(a1.clone(), a2.clone(), a3))
        })
    }
}
52
// Generates fixed-arity function adapters `curry4` ... `curry10`.
//
// NOTE: despite their names these do NOT curry. `curryN(f)` returns a closure that
// takes all N arguments at once and forwards them to `f` unchanged, i.e.
// `curry4(f)(a, b, c, d) == f(a, b, c, d)`. The call shape is kept for API
// compatibility; `curry2`/`curry3` provide real one-argument-at-a-time application.
//
// Each argument is given as `binding: TypeParam` so the generated closure gets
// snake_case parameter names instead of reusing the type-parameter identifiers
// as variables.
macro_rules! curry {
    ($name:ident, $($arg:ident : $ty:ident),+) => {
        /// Returns a closure that takes all of `function`'s arguments at once and
        /// forwards them to `function` as-is (moved, not cloned).
        ///
        /// Despite the name, this does not curry; it exists for API compatibility
        /// with the other `curry*` functions.
        pub fn $name<F, R, $($ty),+>(function: F) -> impl Fn($($ty),+) -> R
        where
            F: Fn($($ty),+) -> R,
        {
            move |$($arg: $ty),+| function($($arg),+)
        }
    };
}

// Generate the fixed-arity adapters.
curry!(curry4, a1: A1, a2: A2, a3: A3, a4: A4);
curry!(curry5, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5);
curry!(curry6, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6);
curry!(curry7, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7);
curry!(curry8, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8);
curry!(curry9, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9);
curry!(
    curry10, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10
);
75
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_curry2() {
        let add2 = curry2(|x: i32, y: i32| x + y)(2);
        assert_eq!(add2(3), 5);
        assert_eq!(add2(7), 9);
    }

    #[test]
    fn test_curry2_throwing() {
        // Fixes the numerator at 10.0; the divisor is supplied later.
        let ten_divided_by = curry2_throwing(|numerator: f64, denominator: f64| {
            if denominator != 0.0 {
                Ok(numerator / denominator)
            } else {
                Err("Division by zero".to_string())
            }
        })(10.0);

        assert_eq!(ten_divided_by(2.0), Ok(5.0));
        assert_eq!(ten_divided_by(0.0), Err("Division by zero".to_string()));
    }

    #[test]
    fn test_curry3() {
        let multiply_by_2_add = curry3(|x: i32, y: i32, z: i32| x * y + z)(2)(3);
        assert_eq!(multiply_by_2_add(4), 10); // 2*3 + 4 = 10
    }

    #[test]
    fn test_curry4_macro() {
        let complex_calc = curry4(|a: i32, b: i32, c: i32, d: i32| (a + b) * (c - d));
        assert_eq!(complex_calc(1, 2, 5, 3), 6); // (1+2)*(5-3) = 6
    }

    #[test]
    fn test_curry5_macro() {
        let sum5 = curry5(|a: i32, b: i32, c: i32, d: i32, e: i32| a + b + c + d + e);
        assert_eq!(sum5(1, 2, 3, 4, 5), 15);
    }

    #[test]
    fn test_string_operations() {
        let hello_prefix =
            curry2(|head: String, tail: String| format!("{}-{}", head, tail))("hello".to_string());
        assert_eq!(hello_prefix("world".to_string()), "hello-world");
    }

    #[test]
    fn test_partial_application() {
        // Fix the first two arguments up front, then vary only the third.
        let add_to_10_and_5 = curry3(|x: i32, y: i32, z: i32| x + y + z)(10)(5);

        assert_eq!(add_to_10_and_5(3), 18); // 10 + 5 + 3 = 18
        assert_eq!(add_to_10_and_5(7), 22); // 10 + 5 + 7 = 22
    }

    #[test]
    fn test_different_types() {
        let create_tuple = curry3(|n: i32, s: String, flag: bool| (n, s, flag));
        let tuple = create_tuple(42)("hello".to_string())(true);
        assert_eq!(tuple, (42, "hello".to_string(), true));
    }

    #[test]
    fn test_curry6_macro() {
        let sum6 = curry6(|a: i32, b: i32, c: i32, d: i32, e: i32, f: i32| a + b + c + d + e + f);
        assert_eq!(sum6(1, 2, 3, 4, 5, 6), 21);
    }

    #[test]
    fn test_curry7_macro() {
        let sum7 = curry7(|a: i32, b: i32, c: i32, d: i32, e: i32, f: i32, g: i32| {
            a + b + c + d + e + f + g
        });
        assert_eq!(sum7(1, 2, 3, 4, 5, 6, 7), 28);
    }

    #[test]
    fn test_thread_safety() {
        // A partially applied closure is `Send`, so it can be moved into another thread.
        let add5 = curry2(|x: i32, y: i32| x + y)(5);
        let worker = std::thread::spawn(move || add5(3));
        assert_eq!(worker.join().unwrap(), 8);
    }
}