1use alloc::vec;
8use alloc::vec::Vec;
9use codec::{Compact, Input};
10use core::marker::PhantomData;
11use scale_bits::{
12 Bits,
13 scale::format::{Format, OrderFormat, StoreFormat},
14};
15use scale_decode::{IntoVisitor, TypeResolver};
16
17pub trait BitStore {
22 const FORMAT: StoreFormat;
24 const BITS: u32;
26}
27macro_rules! impl_store {
28 ($ty:ident, $wrapped:ty) => {
29 impl BitStore for $wrapped {
30 const FORMAT: StoreFormat = StoreFormat::$ty;
31 const BITS: u32 = <$wrapped>::BITS;
32 }
33 };
34}
35impl_store!(U8, u8);
36impl_store!(U16, u16);
37impl_store!(U32, u32);
38impl_store!(U64, u64);
39
40pub trait BitOrder {
45 const FORMAT: OrderFormat;
47}
48macro_rules! impl_order {
49 ($ty:ident) => {
50 #[doc = concat!("Type-level value that corresponds to `scale_bits::OrderFormat::", stringify!($ty), "` at run-time")]
51 #[doc = concat!(" and `bitvec::order::BitOrder::", stringify!($ty), "` at the type level.")]
52 #[derive(Clone, Debug, PartialEq, Eq)]
53 pub enum $ty {}
54 impl BitOrder for $ty {
55 const FORMAT: OrderFormat = OrderFormat::$ty;
56 }
57 };
58}
59impl_order!(Lsb0);
60impl_order!(Msb0);
61
62fn bit_format<Store: BitStore, Order: BitOrder>() -> Format {
64 Format {
65 order: Order::FORMAT,
66 store: Store::FORMAT,
67 }
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
73pub struct DecodedBits<Store, Order> {
74 bits: Bits,
75 _marker: PhantomData<(Store, Order)>,
76}
77
78impl<Store, Order> DecodedBits<Store, Order> {
79 pub fn into_bits(self) -> Bits {
81 self.bits
82 }
83
84 pub fn as_bits(&self) -> &Bits {
86 &self.bits
87 }
88}
89
90impl<Store, Order> core::iter::FromIterator<bool> for DecodedBits<Store, Order> {
91 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
92 DecodedBits {
93 bits: Bits::from_iter(iter),
94 _marker: PhantomData,
95 }
96 }
97}
98
99impl<Store: BitStore, Order: BitOrder> codec::Decode for DecodedBits<Store, Order> {
100 fn decode<I: Input>(input: &mut I) -> Result<Self, codec::Error> {
101 const ARCH32BIT_BITSLICE_MAX_BITS: u32 = 0x1fff_ffff;
103
104 let Compact(bits) = <Compact<u32>>::decode(input)?;
105 if bits > ARCH32BIT_BITSLICE_MAX_BITS {
107 return Err("Attempt to decode a BitVec with too many bits".into());
108 }
109 let elements = (bits / Store::BITS) + u32::from(bits % Store::BITS != 0);
111 let bytes_in_elem = Store::BITS.saturating_div(u8::BITS);
112 let bytes_needed = (elements * bytes_in_elem) as usize;
113
114 let mut storage = codec::Encode::encode(&Compact(bits));
118 let prefix_len = storage.len();
119 storage.reserve_exact(bytes_needed);
120 storage.extend(vec![0; bytes_needed]);
121 input.read(&mut storage[prefix_len..])?;
122
123 let decoder = scale_bits::decode_using_format_from(&storage, bit_format::<Store, Order>())?;
124 let bits = decoder.collect::<Result<Vec<_>, _>>()?;
125 let bits = Bits::from_iter(bits);
126
127 Ok(DecodedBits {
128 bits,
129 _marker: PhantomData,
130 })
131 }
132}
133
134impl<Store: BitStore, Order: BitOrder> codec::Encode for DecodedBits<Store, Order> {
135 fn size_hint(&self) -> usize {
136 self.bits.size_hint()
137 }
138
139 fn encoded_size(&self) -> usize {
140 self.bits.encoded_size()
141 }
142
143 fn encode(&self) -> Vec<u8> {
144 scale_bits::encode_using_format(self.bits.iter(), bit_format::<Store, Order>())
145 }
146}
147
148#[doc(hidden)]
149pub struct DecodedBitsVisitor<S, O, R: TypeResolver>(core::marker::PhantomData<(S, O, R)>);
150
151impl<Store, Order, R: TypeResolver> scale_decode::Visitor for DecodedBitsVisitor<Store, Order, R> {
152 type Value<'scale, 'info> = DecodedBits<Store, Order>;
153 type Error = scale_decode::Error;
154 type TypeResolver = R;
155
156 fn unchecked_decode_as_type<'scale, 'info>(
157 self,
158 input: &mut &'scale [u8],
159 type_id: R::TypeId,
160 types: &'info R,
161 ) -> scale_decode::visitor::DecodeAsTypeResult<
162 Self,
163 Result<Self::Value<'scale, 'info>, Self::Error>,
164 > {
165 let res =
166 scale_decode::visitor::decode_with_visitor(input, type_id, types, Bits::into_visitor())
167 .map(|bits| DecodedBits {
168 bits,
169 _marker: PhantomData,
170 });
171 scale_decode::visitor::DecodeAsTypeResult::Decoded(res)
172 }
173}
174impl<Store, Order> scale_decode::IntoVisitor for DecodedBits<Store, Order> {
175 type AnyVisitor<R: scale_decode::TypeResolver> = DecodedBitsVisitor<Store, Order, R>;
176 fn into_visitor<R: TypeResolver>() -> DecodedBitsVisitor<Store, Order, R> {
177 DecodedBitsVisitor(PhantomData)
178 }
179}
180
181impl<Store, Order> scale_encode::EncodeAsType for DecodedBits<Store, Order> {
182 fn encode_as_type_to<R: TypeResolver>(
183 &self,
184 type_id: R::TypeId,
185 types: &R,
186 out: &mut Vec<u8>,
187 ) -> Result<(), scale_encode::Error> {
188 self.bits.encode_as_type_to(type_id, types, out)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    use core::fmt::Debug;

    use bitvec::vec::BitVec;
    use codec::Decode as _;

    /// Maps this module's type-level bit orders onto their `bitvec` counterparts
    /// so that the same test can drive both implementations.
    trait ToBitVec {
        type Order: bitvec::order::BitOrder;
    }

    impl ToBitVec for Lsb0 {
        type Order = bitvec::order::Lsb0;
    }

    impl ToBitVec for Msb0 {
        type Order = bitvec::order::Msb0;
    }

    /// Checks that `DecodedBits` encodes to the same bytes as the equivalent
    /// `BitVec`, and that both decode back to their original values.
    fn scales_like_bitvec_and_roundtrips<'a, Store, Order>(
        input: impl IntoIterator<Item = &'a bool>,
    ) where
        Store: BitStore + bitvec::store::BitStore + PartialEq,
        Order: BitOrder + ToBitVec + Debug + PartialEq,
        BitVec<Store, <Order as ToBitVec>::Order>: codec::Encode + codec::Decode,
    {
        let bools: Vec<bool> = input.into_iter().copied().collect();

        // Build both representations from the same bits.
        let ours: DecodedBits<Store, Order> = bools.iter().copied().collect();
        let theirs: BitVec<Store, <Order as ToBitVec>::Order> = bools.into_iter().collect();

        // The two encodings must agree byte for byte.
        let ours_encoded = codec::Encode::encode(&ours);
        let theirs_encoded = codec::Encode::encode(&theirs);
        assert_eq!(ours_encoded, theirs_encoded);

        // Each value must survive an encode/decode round trip.
        let ours_decoded = DecodedBits::<Store, Order>::decode(&mut ours_encoded.as_slice())
            .expect("SCALE-encoding DecodedBits to roundtrip");
        let theirs_decoded =
            BitVec::<Store, <Order as ToBitVec>::Order>::decode(&mut theirs_encoded.as_slice())
                .expect("SCALE-encoding BitVec to roundtrip");
        assert_eq!(ours, ours_decoded);
        assert_eq!(theirs, theirs_decoded);
    }

    #[test]
    fn decoded_bitvec_scales_and_roundtrips() {
        let test_cases = [
            vec![],
            vec![true],
            vec![false],
            vec![true, false, true],
            vec![true, false, true, false, false, false, false, false, true],
            [vec![true; 5], vec![false; 5], vec![true; 1], vec![false; 3]].concat(),
            [vec![true; 9], vec![false; 9], vec![true; 9], vec![false; 9]].concat(),
        ];

        // Every case is exercised against every store width and both bit orders.
        for case in test_cases.iter() {
            scales_like_bitvec_and_roundtrips::<u8, Lsb0>(case);
            scales_like_bitvec_and_roundtrips::<u16, Lsb0>(case);
            scales_like_bitvec_and_roundtrips::<u32, Lsb0>(case);
            scales_like_bitvec_and_roundtrips::<u64, Lsb0>(case);
            scales_like_bitvec_and_roundtrips::<u8, Msb0>(case);
            scales_like_bitvec_and_roundtrips::<u16, Msb0>(case);
            scales_like_bitvec_and_roundtrips::<u32, Msb0>(case);
            scales_like_bitvec_and_roundtrips::<u64, Msb0>(case);
        }
    }
}