1use crate::{
21 traits::{
22 DecodeWithVersion, DecodeWithVersionWithMemTracking, DispatchInfoOf, DispatchOriginOf,
23 Dispatchable, InvalidVersion, Pipeline, PipelineAtVers, PipelineMetadataBuilder,
24 PipelineVersion, PostDispatchInfoOf,
25 },
26 transaction_validity::{TransactionSource, TransactionValidityError, ValidTransaction},
27};
28use alloc::vec::Vec;
29use codec::Encode;
30use core::fmt::Debug;
31use scale_info::TypeInfo;
32use sp_weights::Weight;
33
/// Associates a pipeline type with the single extension version it is decoded for.
///
/// `MultiVersion` consults this constant while decoding: it selects the first slot whose
/// `VERSION` equals `Some(extension_version)`.
pub trait MultiVersionItem {
    /// The extension version handled by this item.
    ///
    /// `None` never equals `Some(extension_version)`, so an item reporting `None` is never
    /// selected when decoding.
    const VERSION: Option<u8>;
}
42
// `InvalidVersion` is the default filler for unused `MultiVersion` type parameters. Reporting
// `None` guarantees the version check during decoding never matches such a slot.
impl MultiVersionItem for InvalidVersion {
    const VERSION: Option<u8> = None;
}
46
47impl<const VERSION: u8, Extension> MultiVersionItem for PipelineAtVers<VERSION, Extension> {
48 const VERSION: Option<u8> = Some(VERSION);
49}
50
// Generates the `MultiVersion` enum, with one type parameter and one same-named variant per
// identifier passed in, plus impls that forward every trait method to the wrapped pipeline.
macro_rules! declare_multi_version_enum {
    ($( $variant:tt, )*) => {
        /// A transaction extension pipeline that can be any one of several versions.
        ///
        /// Every type parameter is one alternative (a "slot"). Slots left unspecified default to
        /// `InvalidVersion`, whose `MultiVersionItem::VERSION` is `None`, so they are never
        /// chosen when decoding.
        ///
        /// The `Encode` impl writes exactly the bytes of the wrapped pipeline; no variant index
        /// is emitted. The version therefore travels separately and must be supplied to
        /// `DecodeWithVersion::decode_with_version` to pick the slot.
        // NOTE(review): the allow presumably covers a type in the generics/bounds being less
        // visible than `MultiVersion` itself — confirm.
        #[allow(private_interfaces)]
        #[derive(PartialEq, Eq, Clone, Debug, TypeInfo)]
        pub enum MultiVersion<
            $( $variant = InvalidVersion, )*
        > {
            $( $variant($variant), )*
        }

        // The reported version is whatever the wrapped pipeline reports.
        impl<$( $variant, )*> PipelineVersion for MultiVersion<$( $variant, )*>
        where
            $( $variant: PipelineVersion, )*
        {
            fn version(&self) -> u8 {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as PipelineVersion>::version(inner),
                    )*
                }
            }
        }

        // Transparent encoding: every method delegates to the wrapped pipeline, so the output is
        // byte-for-byte the inner encoding with no discriminant in front of it.
        impl<$( $variant, )*> Encode for MultiVersion<$( $variant, )*>
        where
            $( $variant: Encode, )*
        {
            fn size_hint(&self) -> usize {
                match self {
                    $( MultiVersion::$variant(inner) => <$variant as Encode>::size_hint(inner), )*
                }
            }

            fn encode(&self) -> Vec<u8> {
                match self {
                    $( MultiVersion::$variant(inner) => <$variant as Encode>::encode(inner), )*
                }
            }

            fn encode_to<CodecOutput: codec::Output + ?Sized>(&self, dest: &mut CodecOutput) {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as Encode>::encode_to(inner, dest),
                    )*
                }
            }

            fn encoded_size(&self) -> usize {
                match self {
                    $( MultiVersion::$variant(inner) => <$variant as Encode>::encoded_size(inner), )*
                }
            }

            fn using_encoded<FunctionResult, Function: FnOnce(&[u8]) -> FunctionResult>(
                &self,
                f: Function,
            ) -> FunctionResult {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as Encode>::using_encoded(inner, f),
                    )*
                }
            }
        }

        // Decoding checks each slot's declared version first and only decodes with the first
        // slot that claims `extension_version`. This avoids trial-decoding every slot. If two
        // slots declare the same version, the earlier one wins.
        impl<$( $variant, )*> DecodeWithVersion for MultiVersion<$( $variant, )*>
        where
            $( $variant: DecodeWithVersion + MultiVersionItem, )*
        {
            fn decode_with_version<CodecInput: codec::Input>(
                extension_version: u8,
                input: &mut CodecInput,
            ) -> Result<Self, codec::Error> {
                $(
                    if <$variant as MultiVersionItem>::VERSION == Some(extension_version) {
                        let inner = <$variant as DecodeWithVersion>::decode_with_version(
                            extension_version,
                            input,
                        )?;
                        return Ok(MultiVersion::$variant(inner));
                    }
                )*

                // No slot claimed this version (unused `InvalidVersion` slots report `None`).
                Err(codec::Error::from("Invalid extension version"))
            }
        }

        // Marker impl: available when every slot implements
        // `DecodeWithVersionWithMemTracking`.
        impl<$( $variant, )*> DecodeWithVersionWithMemTracking for MultiVersion<$( $variant, )*>
        where
            $( $variant: DecodeWithVersionWithMemTracking + MultiVersionItem, )*
        {}

        // Pipeline behavior is delegated to whichever slot is populated. Metadata is the
        // exception: it is built from every slot.
        impl<$( $variant, )* Call> Pipeline<Call> for MultiVersion<$( $variant, )*>
        where
            $( $variant: Pipeline<Call> + MultiVersionItem, )*
            Call: Dispatchable,
        {
            fn build_metadata(builder: &mut PipelineMetadataBuilder) {
                // Invoked for every slot, including unused `InvalidVersion` fillers.
                $( <$variant as Pipeline<Call>>::build_metadata(builder); )*
            }

            fn validate_only(
                &self,
                origin: DispatchOriginOf<Call>,
                call: &Call,
                info: &DispatchInfoOf<Call>,
                len: usize,
                source: TransactionSource,
            ) -> Result<ValidTransaction, TransactionValidityError> {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as Pipeline<Call>>::validate_only(
                            inner, origin, call, info, len, source,
                        ),
                    )*
                }
            }

            fn dispatch_transaction(
                self,
                origin: DispatchOriginOf<Call>,
                call: Call,
                info: &DispatchInfoOf<Call>,
                len: usize,
            ) -> crate::ApplyExtrinsicResultWithInfo<PostDispatchInfoOf<Call>> {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as Pipeline<Call>>::dispatch_transaction(
                            inner, origin, call, info, len,
                        ),
                    )*
                }
            }

            fn weight(&self, call: &Call) -> Weight {
                match self {
                    $(
                        MultiVersion::$variant(inner) => <$variant as Pipeline<Call>>::weight(inner, call),
                    )*
                }
            }
        }
    };
}
214
215declare_multi_version_enum! {
216 A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z,
217}
218
#[cfg(test)]
mod tests {
    // Unit tests for `MultiVersion`, using two toy extensions pinned to versions 4 and 7 and
    // placed in slots `A` and `B` respectively.
    use super::*;
    use crate::{
        traits::{
            AsTransactionAuthorizedOrigin, DecodeWithVersion, DispatchInfoOf, Dispatchable,
            Implication, Pipeline, PipelineVersion, TransactionExtension, TransactionSource,
            ValidateResult,
        },
        transaction_validity::{InvalidTransaction, TransactionValidityError, ValidTransaction},
        DispatchError,
    };
    use codec::{Decode, DecodeWithMemTracking, Encode};
    use core::fmt::Debug;
    use scale_info::TypeInfo;
    use sp_weights::Weight;

    /// Toy call: dispatching fails with `"call=0"` when the wrapped value is `0`.
    #[derive(Clone, Debug, Encode, Decode, PartialEq, Eq, TypeInfo)]
    pub struct MockCall(pub u32);

    /// Toy origin: the value `0` is treated as unauthorized.
    #[derive(Clone, Debug)]
    pub struct MockOrigin(pub u8);

    impl AsTransactionAuthorizedOrigin for MockOrigin {
        fn is_transaction_authorized(&self) -> bool {
            self.0 != 0
        }
    }

    impl Dispatchable for MockCall {
        type RuntimeOrigin = MockOrigin;
        type Config = ();
        type Info = ();
        type PostInfo = ();

        fn dispatch(
            self,
            origin: Self::RuntimeOrigin,
        ) -> crate::DispatchResultWithInfo<Self::PostInfo> {
            // The origin is checked before the call, so each failure can be triggered on its own.
            if origin.0 == 0 {
                return Err(DispatchError::Other("Unauthorized origin=0").into());
            }
            if self.0 == 0 {
                return Err(DispatchError::Other("call=0").into());
            }
            Ok(Default::default())
        }
    }

    /// Version-4 test extension.
    ///
    /// Validation fails with `Custom(44)` when `token` is `0`. The weight is `declared_weight`
    /// ref-time.
    #[derive(Clone, Debug, Encode, Decode, DecodeWithMemTracking, PartialEq, Eq, TypeInfo)]
    pub struct SimpleExtensionV4 {
        pub token: u8,
        pub declared_weight: u64,
    }

    impl TransactionExtension<MockCall> for SimpleExtensionV4 {
        const IDENTIFIER: &'static str = "SimpleExtV4";
        type Implicit = ();
        type Val = ();
        type Pre = ();

        fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
            Ok(())
        }

        fn weight(&self, _call: &MockCall) -> Weight {
            Weight::from_parts(self.declared_weight, 0)
        }

        fn validate(
            &self,
            origin: MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
            _self_implicit: Self::Implicit,
            _inherited_implication: &impl Implication,
            _source: TransactionSource,
        ) -> ValidateResult<Self::Val, MockCall> {
            if self.token == 0 {
                Err(InvalidTransaction::Custom(44).into())
            } else {
                Ok((ValidTransaction::default(), (), origin))
            }
        }

        fn prepare(
            self,
            _val: Self::Val,
            _origin: &MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
        ) -> Result<Self::Pre, TransactionValidityError> {
            Ok(())
        }
    }

    pub type PipelineV4 = PipelineAtVers<4, SimpleExtensionV4>;

    /// Version-7 test extension.
    ///
    /// Validation fails with `Custom(77)` when `token` is `0`. The weight is `declared_weight`
    /// ref-time.
    #[derive(Clone, Debug, Encode, Decode, DecodeWithMemTracking, PartialEq, Eq, TypeInfo)]
    pub struct SimpleExtensionV7 {
        pub token: u8,
        pub declared_weight: u64,
    }

    impl TransactionExtension<MockCall> for SimpleExtensionV7 {
        const IDENTIFIER: &'static str = "SimpleExtV7";
        type Implicit = ();
        type Val = ();
        type Pre = ();

        fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
            Ok(())
        }

        fn weight(&self, _call: &MockCall) -> Weight {
            Weight::from_parts(self.declared_weight, 0)
        }

        fn validate(
            &self,
            origin: MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
            _self_implicit: Self::Implicit,
            _inherited_implication: &impl Implication,
            _source: TransactionSource,
        ) -> ValidateResult<Self::Val, MockCall> {
            if self.token == 0 {
                Err(InvalidTransaction::Custom(77).into())
            } else {
                Ok((ValidTransaction::default(), (), origin))
            }
        }

        fn prepare(
            self,
            _val: Self::Val,
            _origin: &MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
        ) -> Result<Self::Pre, TransactionValidityError> {
            Ok(())
        }
    }

    pub type PipelineV7 = PipelineAtVers<7, SimpleExtensionV7>;

    /// Slot `A` holds the version-4 pipeline and slot `B` the version-7 one. All other slots
    /// keep their default.
    pub type MyMultiExt = MultiVersion<PipelineV4, PipelineV7>;

    #[test]
    fn decode_with_version_works_for_known_versions() {
        let pipeline_v4 = PipelineV4::new(SimpleExtensionV4 { token: 99, declared_weight: 123 });
        let encoded_v4 = pipeline_v4.encode();
        let decoded_v4 = MyMultiExt::decode_with_version(4, &mut &encoded_v4[..])
            .expect("decode with version=4");
        let expected_v4 = MultiVersion::A(pipeline_v4);
        assert_eq!(decoded_v4, expected_v4);

        let pipeline_v7 = PipelineV7::new(SimpleExtensionV7 { token: 55, declared_weight: 777 });
        let encoded_v7 = pipeline_v7.encode();
        let decoded_v7 = MyMultiExt::decode_with_version(7, &mut &encoded_v7[..])
            .expect("decode with version=7");
        let expected_v7 = MultiVersion::B(pipeline_v7);
        assert_eq!(decoded_v7, expected_v7);
    }

    #[test]
    fn decode_with_unknown_version_fails() {
        let pipeline_v4 = PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 100 });
        let encoded_v4 = pipeline_v4.encode();

        // No slot declares version 123, so decoding must be rejected regardless of the bytes.
        let decode_err = MyMultiExt::decode_with_version(123, &mut &encoded_v4[..])
            .expect_err("decode must fail with unknown version=123");
        assert!(format!("{}", decode_err).contains("Invalid extension version"));
    }

    #[test]
    fn version_is_correct() {
        let multi_a =
            MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 10 }));
        assert_eq!(multi_a.version(), 4);

        let multi_b =
            MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 2, declared_weight: 20 }));
        assert_eq!(multi_b.version(), 7);
    }

    #[test]
    fn weight_check_works() {
        let multi_a =
            MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 500 }));
        let multi_b =
            MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 1, declared_weight: 999 }));

        let call = MockCall(0);
        assert_eq!(multi_a.weight(&call).ref_time(), 500);
        assert_eq!(multi_b.weight(&call).ref_time(), 999);
    }

    #[test]
    fn validate_only_logic_works() {
        // `token: 0` makes each extension reject with its own custom code.
        let invalid_a =
            MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 0, declared_weight: 123 }));
        let call = MockCall(42);
        let validity = invalid_a.validate_only(
            MockOrigin(42),
            &call,
            &Default::default(),
            0,
            TransactionSource::Local,
        );
        assert_eq!(
            validity,
            Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(44)))
        );

        let invalid_b =
            MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 0, declared_weight: 456 }));
        let validity_b = invalid_b.validate_only(
            MockOrigin(42),
            &call,
            &Default::default(),
            0,
            TransactionSource::Local,
        );
        assert_eq!(
            validity_b,
            Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(77)))
        );

        let valid_a =
            MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 55, declared_weight: 10 }));
        let result_ok_a = valid_a.validate_only(
            MockOrigin(1),
            &call,
            &Default::default(),
            0,
            TransactionSource::External,
        );
        assert!(result_ok_a.is_ok(), "valid scenario for pipeline A");
    }

    #[test]
    fn dispatch_transaction_works() {
        let pipeline_a = PipelineV4::new(SimpleExtensionV4 { token: 33, declared_weight: 1 });
        let multi_a = MyMultiExt::A(pipeline_a);
        let call_good = MockCall(42);
        multi_a
            .dispatch_transaction(MockOrigin(9), call_good.clone(), &Default::default(), 0)
            .expect("Should not fail validity")
            .expect("Success");

        // The extension accepts the transaction, but the call itself fails to dispatch.
        let fail_res =
            MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 10 }))
                .dispatch_transaction(MockOrigin(9), MockCall(0), &Default::default(), 0)
                .expect("Should be a valid transaction from viewpoint of extension");
        let block_err = fail_res.expect_err("actual dispatch error");
        assert_eq!(block_err.error, DispatchError::Other("call=0"));

        let pipeline_b = PipelineV7::new(SimpleExtensionV7 { token: 2, declared_weight: 99 });
        let multi_b = MyMultiExt::B(pipeline_b);
        let outcome_ok = multi_b
            .dispatch_transaction(MockOrigin(1), call_good, &Default::default(), 0)
            .expect("Should pass validity");
        assert!(outcome_ok.is_ok());
    }

    /// The `Encode` impl must be transparent.
    ///
    /// Every encoding method of `MultiVersion` yields exactly what the wrapped pipeline yields.
    /// Those bytes must decode back to the same value when paired with the value's own version.
    #[test]
    fn encode_is_transparent_and_round_trips() {
        let make_v4 = || PipelineV4::new(SimpleExtensionV4 { token: 3, declared_weight: 42 });
        let multi_a = MyMultiExt::A(make_v4());
        let expected_a = make_v4().encode();
        assert_eq!(multi_a.encode(), expected_a);
        assert_eq!(multi_a.encoded_size(), make_v4().encoded_size());
        assert_eq!(multi_a.size_hint(), make_v4().size_hint());

        let mut buffer: Vec<u8> = Vec::new();
        multi_a.encode_to(&mut buffer);
        assert_eq!(buffer, expected_a);

        let make_v7 = || PipelineV7::new(SimpleExtensionV7 { token: 8, declared_weight: 64 });
        let multi_b = MyMultiExt::B(make_v7());
        assert_eq!(multi_b.using_encoded(|bytes| bytes.to_vec()), make_v7().encode());

        let encoded_b = multi_b.encode();
        let decoded_b = MyMultiExt::decode_with_version(multi_b.version(), &mut &encoded_b[..])
            .expect("round trip with the value's own version");
        assert_eq!(decoded_b, multi_b);
    }
}