rust2fun/
flatmap.rs

1//! FlatMap.
2
3use core::marker::PhantomData;
4
5use crate::combinator::id;
6use crate::constant1;
7use crate::functor::Functor;
8use crate::higher::Higher;
9
10/// Gives access to the `flat_map` method. The motivation for separating this out of
11/// [Monad](super::monad::Monad) is that there are situations where `flat_map` can be implemented
12/// but not `pure`.
13pub trait FlatMap<B>: Higher {
14    /// Maps a function over a value in the context and flattens the resulting nested context.
15    /// This is the same  as `self.map(f).flatten()`.
16    /// This is also known as `bind` or `>>=` in other languages.
17    ///
18    /// # Examples
19    ///
20    /// ```
21    /// use rust2fun::prelude::*;
22    ///
23    /// let x = Some(1);
24    /// let actual = x.flat_map(|x| Some(x.to_string()));
25    /// assert_eq!(Some("1".to_string()), actual);
26    /// ```
27    fn flat_map<F>(self, f: F) -> Self::Target<B>
28    where
29        F: FnMut(Self::Param) -> Self::Target<B>;
30
31    /// Flattens a nested structure.
32    /// This is a convenience method for `flat_map(id)`.
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// use rust2fun::prelude::*;
38    ///
39    /// let actual = Some(Some(1)).flatten();
40    /// assert_eq!(Some(1), actual);
41    /// ```
42    #[inline]
43    fn flatten(self) -> Self::Target<B>
44    where
45        Self: FlatMap<B, Param = <Self as Higher>::Target<B>> + Sized,
46    {
47        self.flat_map(id)
48    }
49
50    /// Pair up the value with the result of applying the function to the value.
51    ///
52    /// # Examples
53    /// ```
54    /// use rust2fun::prelude::*;
55    ///
56    /// let x = Some(1);
57    /// let actual = x.m_product(|x| Some(x.to_string()));
58    /// assert_eq!(Some((1, "1".to_string())), actual);
59    /// ```
60    fn m_product<F>(self, mut f: F) -> Self::Target<(Self::Param, B)>
61    where
62        F: FnMut(Self::Param) -> Self::Target<B>,
63        Self: FlatMap<(<Self as Higher>::Param, B)> + Sized,
64        Self::Param: Copy,
65        Self::Target<B>:
66            Functor<(Self::Param, B), Target<(Self::Param, B)> = Self::Target<(Self::Param, B)>>,
67    {
68        self.flat_map(|a| f(a).map(|b| (a, b)))
69    }
70
71    /// `if` lifted into monad.
72    ///
73    /// # Examples
74    /// ```
75    /// use rust2fun::prelude::*;
76    ///
77    /// let x = Some(true);
78    /// let actual = x.if_m(constant!(Some(1)), constant!(Some(0)));
79    /// assert_eq!(Some(1), actual);
80    /// ```
81    #[inline]
82    fn if_m<T, F>(self, mut if_true: T, mut if_false: F) -> Self::Target<B>
83    where
84        T: FnMut() -> Self::Target<B>,
85        F: FnMut() -> Self::Target<B>,
86        Self: FlatMap<B, Param = bool> + Sized,
87    {
88        self.flat_map(|x| if x { if_true() } else { if_false() })
89    }
90
91    /// Apply a monadic function and discard the result while keeping the effect.
92    ///
93    /// # Examples
94    /// ```
95    /// use rust2fun::prelude::*;
96    ///
97    /// let x = Some(1);
98    /// let actual = x.flat_tap(|x| Some(x.to_string()));
99    /// assert_eq!(Some(1), actual);
100    /// ```
101    fn flat_tap<F>(self, mut f: F) -> Self
102    where
103        F: FnMut(Self::Param) -> Self::Target<B>,
104        Self: FlatMap<<Self as Higher>::Param, Target<<Self as Higher>::Param> = Self> + Sized,
105        Self::Param: Copy,
106        Self::Target<B>: Functor<Self::Param, Target<Self::Param> = Self>,
107    {
108        #[inline]
109        fn internal<FA: FlatMap<<FA as Higher>::Param, Target<<FA as Higher>::Param> = FA>>(
110            fa: FA,
111            g: impl FnMut(FA::Param) -> FA,
112        ) -> FA {
113            fa.flat_map(g)
114        }
115
116        internal(self, |a| f(a).map(constant1!(a)))
117    }
118}
119
/// Implements [FlatMap] for a collection type constructor `$name<_>`.
///
/// The generated `flat_map` turns the collection into an iterator, feeds every element to the
/// function, chains the collections it returns and collects everything into `$name<B>`. The
/// type therefore needs `$name<A>: IntoIterator<Item = A>`, `$name<B>: FromIterator<B>` and,
/// because `FlatMap` extends `Higher`, a `Higher` impl whose `Target<B>` is `$name<B>`.
///
/// An optional list of bounds after a comma constrains the *output* element type `B`. This is
/// needed when `$name<B>: FromIterator<B>` only holds for some `B`, e.g. `Ord` for ordered
/// collections or `Eq + Hash` for hashed ones. The bounds are taken verbatim as tokens, so
/// paths such as `core::hash::Hash` are accepted too.
///
/// ```ignore
/// flatmap_iter!(Vec);
/// flatmap_iter!(BTreeSet, Ord);
/// flatmap_iter!(HashSet, Eq + Hash);
/// ```
#[macro_export]
macro_rules! flatmap_iter {
    // A single arm covers both the bound-less and the bounded form; `$bound` captures raw
    // tokens so multi-token bounds (`Eq + Hash`, `std::fmt::Debug`) pass through unchanged.
    ($name:ident $(, $($bound:tt)+)?) => {
        impl<A, B $(: $($bound)+)?> $crate::flatmap::FlatMap<B> for $name<A> {
            #[inline]
            fn flat_map<F>(self, f: F) -> Self::Target<B>
            where
                F: FnMut(A) -> Self::Target<B>,
            {
                self.into_iter().flat_map(f).collect::<$name<B>>()
            }
        }
    };
}
147
148impl<A, B> FlatMap<B> for PhantomData<A> {
149    #[inline]
150    fn flat_map<F>(self, _f: F) -> PhantomData<B>
151    where
152        F: FnMut(A) -> PhantomData<B>,
153    {
154        PhantomData
155    }
156}
157
158impl<A, B> FlatMap<B> for Option<A> {
159    #[inline]
160    fn flat_map<F>(self, f: F) -> Option<B>
161    where
162        F: FnMut(A) -> Option<B>,
163    {
164        self.and_then(f)
165    }
166}
167
168impl<A, B, E> FlatMap<B> for Result<A, E> {
169    #[inline]
170    fn flat_map<F>(self, f: F) -> Result<B, E>
171    where
172        F: FnMut(A) -> Result<B, E>,
173    {
174        self.and_then(f)
175    }
176}
177
178if_std! {
179    use std::boxed::Box;
180    use std::collections::*;
181    use std::hash::Hash;
182    use std::vec::Vec;
183
184    impl<A, B> FlatMap<B> for Box<A> {
185        #[inline]
186        fn flat_map<F>(self, mut f: F) -> Box<B>
187        where
188            F: FnMut(A) -> Box<B>,
189        {
190            f(*self)
191        }
192    }
193
194    flatmap_iter!(Vec);
195    flatmap_iter!(LinkedList);
196    flatmap_iter!(VecDeque);
197    flatmap_iter!(BinaryHeap, Ord);
198    flatmap_iter!(BTreeSet, Ord);
199    flatmap_iter!(HashSet, Eq + Hash);
200
201    impl<A, B, K: Eq + Hash> FlatMap<B> for HashMap<K, A> {
202        #[inline]
203        fn flat_map<F>(self, mut f: F) -> HashMap<K, B>
204        where
205            F: FnMut(A) -> HashMap<K, B>,
206        {
207            self.into_iter().flat_map(|(_, v)|  f(v)).collect()
208        }
209    }
210}