Skip to main content

tw_storage_extra/cow/conditional_multi_index.rs

use cosmwasm_std::{StdError, StdResult, Storage};
use cw_storage_plus::{Index, Map, Prefix, Prefixer, PrimaryKey};
use serde::{de::DeserializeOwned, Serialize};
use std::borrow::Cow;

use super::{helpers::deserialize_multi_kv, DeserializeFn};

8#[derive(Clone)]
9pub struct ConditionalMultiIndex<'a, K, T> {
10    pub(crate) idx_namespace: Cow<'a, str>,
11    pub(crate) pk_namespace: Cow<'a, str>,
12    idx_fn: fn(&T, Vec<u8>) -> K,
13    cond_fn: fn(&T) -> bool,
14    dese_fn: Option<DeserializeFn<T>>,
15}
16
17impl<'a, K, T> ConditionalMultiIndex<'a, K, T> {
18    /// Only if result of `cond_fn` is `true`, data will be added to this `ConditionalMultiIndex`.
19    ///
20    /// Result of `cond_fn` **must be constant**, otherwise might raise unexpected behavior.
21    pub const fn new_ref(
22        idx_fn: fn(&T, Vec<u8>) -> K,
23        cond_fn: fn(&T) -> bool,
24        dese_fn: Option<DeserializeFn<T>>,
25        pk_namespace: &'a str,
26        idx_namespace: &'a str,
27    ) -> Self {
28        Self {
29            idx_fn,
30            cond_fn,
31            dese_fn,
32            idx_namespace: Cow::Borrowed(idx_namespace),
33            pk_namespace: Cow::Borrowed(pk_namespace),
34        }
35    }
36
37    /// Only if result of `cond_fn` is `true`, data will be added to this `ConditionalMultiIndex`.
38    ///
39    /// Result of `cond_fn` **must be constant**, otherwise might raise unexpected behavior.
40    pub const fn new_owned(
41        idx_fn: fn(&T, Vec<u8>) -> K,
42        cond_fn: fn(&T) -> bool,
43        dese_fn: Option<DeserializeFn<T>>,
44        pk_namespace: String,
45        idx_namespace: String,
46    ) -> Self {
47        Self {
48            idx_fn,
49            cond_fn,
50            dese_fn,
51            idx_namespace: Cow::Owned(idx_namespace),
52            pk_namespace: Cow::Owned(pk_namespace),
53        }
54    }
55}
56
57impl<'a, K, T> Index<T> for ConditionalMultiIndex<'a, K, T>
58where
59    T: Serialize + DeserializeOwned + Clone,
60    K: for<'key> PrimaryKey<'key>,
61{
62    fn save(&self, store: &mut dyn Storage, pk: &[u8], data: &T) -> StdResult<()> {
63        if (self.cond_fn)(data) {
64            let idx = (self.idx_fn)(data, pk.to_vec());
65            self.idx_map().save(store, idx, &(pk.len() as u32))?;
66        }
67
68        Ok(())
69    }
70
71    fn remove(&self, store: &mut dyn Storage, pk: &[u8], old_data: &T) -> StdResult<()> {
72        if (self.cond_fn)(old_data) {
73            let idx = (self.idx_fn)(old_data, pk.to_vec());
74            self.idx_map().remove(store, idx);
75        };
76
77        Ok(())
78    }
79}
80
81impl<'a, K, T> ConditionalMultiIndex<'a, K, T>
82where
83    T: Serialize + DeserializeOwned + Clone,
84    K: for<'key> PrimaryKey<'key>,
85{
86    fn idx_map(&self) -> Map<'_, K, u32> {
87        Map::new(&self.idx_namespace)
88    }
89
90    pub fn prefix(&self, p: <K as PrimaryKey<'_>>::Prefix) -> Prefix<T> {
91        Prefix::with_deserialization_function(
92            self.idx_namespace.as_bytes(),
93            &p.prefix(),
94            self.pk_namespace.as_bytes(),
95            match self.dese_fn {
96                Some(f) => f,
97                None => deserialize_multi_kv,
98            },
99        )
100    }
101
102    pub fn sub_prefix(&self, p: <K as PrimaryKey<'_>>::SubPrefix) -> Prefix<T> {
103        Prefix::with_deserialization_function(
104            self.idx_namespace.as_bytes(),
105            &p.prefix(),
106            self.pk_namespace.as_bytes(),
107            match self.dese_fn {
108                Some(f) => f,
109                None => deserialize_multi_kv,
110            },
111        )
112    }
113
114    pub fn index_key(&self, k: K) -> Vec<u8> {
115        k.joined_key()
116    }
117}
118
#[cfg(test)]
mod test {
    use cosmwasm_std::{testing::MockStorage, Order, Uint128};
    use cw_storage_plus::{
        Index, IndexList, IndexedMap, MultiIndex, Prefix, PrimaryKey, U128Key, U64Key,
    };
    use serde::{Deserialize, Serialize};

    use crate::cow::deserialize_multi_kv_custom_pk;

    use super::ConditionalMultiIndex;

    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, PartialOrd)]
    struct Test {
        id: u64,
        val: Uint128,
    }

    struct TestIndexes<'a> {
        /// Conditional index keyed by `(val, pk)`; only holds items with `val > 100`.
        val: ConditionalMultiIndex<'a, (U128Key, Vec<u8>), Test>,
        /// Conditional index keyed by `(val, u64::MAX - id)`; only holds items with `val > 100`.
        val_inv: ConditionalMultiIndex<'a, (U128Key, Vec<u8>), Test>,
        /// Unconditional reference index the conditional ones are compared against.
        val_n: MultiIndex<'a, (U128Key, Vec<u8>), Test>,
    }

    impl IndexList<Test> for TestIndexes<'_> {
        fn get_indexes(&'_ self) -> Box<dyn Iterator<Item = &'_ dyn Index<Test>> + '_> {
            let v: Vec<&dyn Index<Test>> = vec![&self.val, &self.val_n, &self.val_inv];
            Box::new(v.into_iter())
        }
    }

    fn idm<'a>() -> IndexedMap<'a, U64Key, Test, TestIndexes<'a>> {
        IndexedMap::new(
            "test",
            TestIndexes {
                val: ConditionalMultiIndex::new_ref(
                    |t, k| (t.val.u128().into(), k),
                    // only add to val if t.val > 100
                    |t| t.val.u128() > 100,
                    None,
                    "test",
                    "test__val",
                ),
                val_inv: ConditionalMultiIndex::new_ref(
                    // Key by the inverted id, so a descending range yields equal values in
                    // ascending id order.
                    |t, _| {
                        (
                            t.val.u128().into(),
                            U64Key::new(u64::max_value() - t.id).joined_key(),
                        )
                    },
                    // only add to val_inv if t.val > 100
                    |t| t.val.u128() > 100,
                    // Map the stored inverted id back to the real primary key.
                    Some(|s, pk, kv| {
                        deserialize_multi_kv_custom_pk(s, pk, kv, |old_kv| {
                            U64Key::new(
                                u64::max_value()
                                    - u64::from_be_bytes(old_kv.as_slice().try_into().unwrap()),
                            )
                            .joined_key()
                        })
                    }),
                    "test",
                    "test__inv",
                ),
                val_n: MultiIndex::new(|t, k| (t.val.u128().into(), k), "test", "test__normal"),
            },
        )
    }

    /// Saves the three fixtures shared by the indexing tests.
    ///
    /// Ids 0 and 2 satisfy the `val > 100` condition; id 1 does not.
    fn seed(storage: &mut MockStorage) {
        let map = idm();
        for (id, val) in [(0u64, 101u64), (1, 100), (2, 101)] {
            map.save(
                storage,
                id.into(),
                &Test {
                    id,
                    val: Uint128::from(val),
                },
            )
            .unwrap();
        }
    }

    /// Ranges over `prefix` in descending order and returns `(id, val)` for every entry.
    fn collect_desc(prefix: Prefix<Test>, storage: &MockStorage) -> Vec<(u64, u128)> {
        prefix
            .range(storage, None, None, Order::Descending)
            .map(|entry| {
                let (_, item) = entry.unwrap();
                (item.id, item.val.u128())
            })
            .collect()
    }

    #[test]
    fn correct_namespace() {
        let idm = idm();

        assert_eq!(idm.idx.val.pk_namespace, "test");
        assert_eq!(idm.idx.val.idx_namespace, "test__val");
        assert_eq!(idm.idx.val_inv.pk_namespace, "test");
        assert_eq!(idm.idx.val_inv.idx_namespace, "test__inv");
    }

    #[test]
    fn correctly_add_to_index() {
        let mut storage = MockStorage::new();
        seed(&mut storage);

        let v = collect_desc(idm().idx.val.sub_prefix(()), &storage);
        assert_eq!(v, vec![(2, 101), (0, 101)]);

        let v_n = collect_desc(idm().idx.val_n.sub_prefix(()), &storage);
        assert_eq!(v_n, vec![(2, 101), (0, 101), (1, 100)]);
    }

    #[test]
    fn correctly_add_to_index_custom_dese() {
        let mut storage = MockStorage::new();
        seed(&mut storage);

        let v_inv = collect_desc(idm().idx.val_inv.sub_prefix(()), &storage);
        assert_eq!(v_inv, vec![(0, 101), (2, 101)]);

        let v_n = collect_desc(idm().idx.val_n.sub_prefix(()), &storage);
        assert_eq!(v_n, vec![(2, 101), (0, 101), (1, 100)]);
    }
}