Skip to main content

sp_runtime/traits/vers_tx_ext/
multi.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Types and trait to aggregate multiple versioned transaction extension pipelines.
19
20use crate::{
21	traits::{
22		DecodeWithVersion, DecodeWithVersionWithMemTracking, DispatchInfoOf, DispatchOriginOf,
23		Dispatchable, InvalidVersion, Pipeline, PipelineAtVers, PipelineMetadataBuilder,
24		PipelineVersion, PostDispatchInfoOf,
25	},
26	transaction_validity::{TransactionSource, TransactionValidityError, ValidTransaction},
27};
28use alloc::vec::Vec;
29use codec::Encode;
30use core::fmt::Debug;
31use scale_info::TypeInfo;
32use sp_weights::Weight;
33
/// A single entry of [`MultiVersion`]: one transaction extension pipeline bound to exactly
/// one version.
pub trait MultiVersionItem {
	/// Version this pipeline is selected for.
	///
	/// A value of `None` marks an entry without any version; such an entry can never be
	/// decoded.
	const VERSION: Option<u8>;
}
42
43impl MultiVersionItem for InvalidVersion {
44	const VERSION: Option<u8> = None;
45}
46
47impl<const VERSION: u8, Extension> MultiVersionItem for PipelineAtVers<VERSION, Extension> {
48	const VERSION: Option<u8> = Some(VERSION);
49}
50
// Declares the `MultiVersion` enum with one generic parameter (and one same-named variant) per
// identifier given. It also implements `PipelineVersion`, `Encode`, `DecodeWithVersion`,
// `DecodeWithVersionWithMemTracking` and `Pipeline` on it, each by forwarding to the variant
// that is currently held.
macro_rules! declare_multi_version_enum {
	($( $variant:tt, )*) => {

		/// An implementation of [`Pipeline`] that aggregates several versioned transaction
		/// extension pipelines.
		///
		/// Each variant holds one pipeline with its own version. The pipelines are configured
		/// through the generic parameters; parameters left unset default to `InvalidVersion`,
		/// which can never be decoded.
		///
		/// Avoid giving two variants the same version. When decoding, the first variant whose
		/// version matches is chosen, and any later variant with that same version is never
		/// used.
		///
		/// # Example
		///
		/// ```
		/// use sp_runtime::traits::{MultiVersion, PipelineAtVers};
		///
		/// struct PaymentExt;
		/// struct PaymentExtV2;
		/// struct NonceExt;
		///
		/// type ExtV1 = PipelineAtVers<1, (NonceExt, PaymentExt)>;
		/// type ExtV4 = PipelineAtVers<4, (NonceExt, PaymentExtV2)>;
		///
		/// /// The transaction extension pipeline that supports both version 1 and 4.
		/// type TransactionExtension = MultiVersion<ExtV1, ExtV4>;
		/// ```
		#[allow(private_interfaces)]
		#[derive(PartialEq, Eq, Clone, Debug, TypeInfo)]
		pub enum MultiVersion<
			$( $variant = InvalidVersion, )*
		> {
			$(
				/// The transaction extension pipeline of a specific version.
				$variant($variant),
			)*
		}

		// The aggregate reports the version of whichever pipeline it currently holds.
		impl<$( $variant, )*> PipelineVersion for MultiVersion<$( $variant, )*>
		where
			$( $variant: PipelineVersion, )*
		{
			fn version(&self) -> u8 {
				match self {
					$( Self::$variant(inner) => inner.version(), )*
				}
			}
		}

		// Transparent encoding: only the inner pipeline is written, with no variant index.
		// The decoder instead receives the version as a separate argument (see
		// `DecodeWithVersion` below).
		impl<$( $variant, )*> Encode for MultiVersion<$( $variant, )*>
		where
			$( $variant: Encode, )*
		{
			fn size_hint(&self) -> usize {
				match self {
					$( Self::$variant(inner) => inner.size_hint(), )*
				}
			}

			fn encode(&self) -> Vec<u8> {
				match self {
					$( Self::$variant(inner) => inner.encode(), )*
				}
			}

			fn encode_to<CodecOutput: codec::Output + ?Sized>(&self, dest: &mut CodecOutput) {
				match self {
					$( Self::$variant(inner) => inner.encode_to(dest), )*
				}
			}

			fn encoded_size(&self) -> usize {
				match self {
					$( Self::$variant(inner) => inner.encoded_size(), )*
				}
			}

			fn using_encoded<FunctionResult, Function: FnOnce(&[u8]) -> FunctionResult>(
				&self,
				f: Function,
			) -> FunctionResult {
				match self {
					$( Self::$variant(inner) => inner.using_encoded(f), )*
				}
			}
		}

		// Decoding selects the variant from the version supplied by the caller.
		impl<$( $variant, )*> DecodeWithVersion for MultiVersion<$( $variant, )*>
		where
			$( $variant: DecodeWithVersion + MultiVersionItem, )*
		{
			fn decode_with_version<CodecInput: codec::Input>(
				extension_version: u8,
				input: &mut CodecInput,
			) -> Result<Self, codec::Error> {
				let requested = Some(extension_version);

				// Variants are checked in declaration order, so the first one declaring
				// `requested` wins. Matching on the version first, rather than trying every
				// variant's decoder in turn, keeps the selected pipeline's error intact. An
				// alternative would be a tri-state result from `DecodeWithVersion` (ok, error,
				// invalid version).
				$(
					if <$variant as MultiVersionItem>::VERSION == requested {
						let decoded = <$variant as DecodeWithVersion>::decode_with_version(
							extension_version,
							input,
						)?;
						return Ok(Self::$variant(decoded));
					}
				)*

				Err(codec::Error::from("Invalid extension version"))
			}
		}

		impl<$( $variant, )*> DecodeWithVersionWithMemTracking for MultiVersion<$( $variant, )*>
		where
			$( $variant: DecodeWithVersionWithMemTracking + MultiVersionItem, )*
		{
		}

		// Every pipeline operation is delegated to the held variant. For metadata, each
		// variant's `build_metadata` is called, in declaration order.
		impl<$( $variant, )* Call> Pipeline<Call> for MultiVersion<$( $variant, )*>
		where
			$( $variant: Pipeline<Call> + MultiVersionItem, )*
			Call: Dispatchable,
		{
			fn build_metadata(builder: &mut PipelineMetadataBuilder) {
				$( <$variant as Pipeline<Call>>::build_metadata(builder); )*
			}

			fn validate_only(
				&self,
				origin: DispatchOriginOf<Call>,
				call: &Call,
				info: &DispatchInfoOf<Call>,
				len: usize,
				source: TransactionSource,
			) -> Result<ValidTransaction, TransactionValidityError> {
				match self {
					$(
						Self::$variant(inner) =>
							inner.validate_only(origin, call, info, len, source),
					)*
				}
			}

			fn dispatch_transaction(
				self,
				origin: DispatchOriginOf<Call>,
				call: Call,
				info: &DispatchInfoOf<Call>,
				len: usize,
			) -> crate::ApplyExtrinsicResultWithInfo<PostDispatchInfoOf<Call>> {
				match self {
					$(
						Self::$variant(inner) =>
							inner.dispatch_transaction(origin, call, info, len),
					)*
				}
			}

			fn weight(&self, call: &Call) -> Weight {
				match self {
					$( Self::$variant(inner) => inner.weight(call), )*
				}
			}
		}
	};
}
214
215declare_multi_version_enum! {
216	A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z,
217}
218
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{
		traits::{
			AsTransactionAuthorizedOrigin, DecodeWithVersion, DispatchInfoOf, Dispatchable,
			Implication, Pipeline, PipelineVersion, TransactionExtension, TransactionSource,
			ValidateResult,
		},
		transaction_validity::{InvalidTransaction, TransactionValidityError, ValidTransaction},
		DispatchError,
	};
	use codec::{Decode, DecodeWithMemTracking, Encode};
	use core::fmt::Debug;
	use scale_info::TypeInfo;
	use sp_weights::Weight;

	// --------------------------------------------------------
	// A mock call type and origin used for testing
	// --------------------------------------------------------
	#[derive(Clone, Debug, Encode, Decode, PartialEq, Eq, TypeInfo)]
	pub struct MockCall(pub u32);

	#[derive(Clone, Debug)]
	pub struct MockOrigin(pub u8);

	impl AsTransactionAuthorizedOrigin for MockOrigin {
		fn is_transaction_authorized(&self) -> bool {
			// Any origin other than 0 counts as authorized.
			self.0 != 0
		}
	}

	impl Dispatchable for MockCall {
		type RuntimeOrigin = MockOrigin;
		type Config = ();
		type Info = ();
		type PostInfo = ();

		fn dispatch(
			self,
			origin: Self::RuntimeOrigin,
		) -> crate::DispatchResultWithInfo<Self::PostInfo> {
			// Dispatch fails if either the origin or the call is 0.
			if origin.0 == 0 {
				return Err(DispatchError::Other("Unauthorized origin=0").into());
			}
			if self.0 == 0 {
				return Err(DispatchError::Other("call=0").into());
			}
			Ok(Default::default())
		}
	}

	// --------------------------------------------------------
	// Two single-version pipelines, with versions 4 and 7
	// --------------------------------------------------------

	// A single-version extension pipeline that validates only if token != 0.
	#[derive(Clone, Debug, Encode, Decode, DecodeWithMemTracking, PartialEq, Eq, TypeInfo)]
	pub struct SimpleExtensionV4 {
		pub token: u8,
		pub declared_weight: u64,
	}

	impl TransactionExtension<MockCall> for SimpleExtensionV4 {
		const IDENTIFIER: &'static str = "SimpleExtV4";
		type Implicit = ();
		type Val = ();
		type Pre = ();

		fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
			Ok(())
		}

		fn weight(&self, _call: &MockCall) -> Weight {
			Weight::from_parts(self.declared_weight, 0)
		}

		fn validate(
			&self,
			origin: MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
			_self_implicit: Self::Implicit,
			_inherited_implication: &impl Implication,
			_source: TransactionSource,
		) -> ValidateResult<Self::Val, MockCall> {
			if self.token == 0 {
				Err(InvalidTransaction::Custom(44).into())
			} else {
				Ok((ValidTransaction::default(), (), origin))
			}
		}

		fn prepare(
			self,
			_val: Self::Val,
			_origin: &MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
		) -> Result<Self::Pre, TransactionValidityError> {
			Ok(())
		}
	}

	pub type PipelineV4 = PipelineAtVers<4, SimpleExtensionV4>;

	// Another single-version extension pipeline, version=7
	#[derive(Clone, Debug, Encode, Decode, DecodeWithMemTracking, PartialEq, Eq, TypeInfo)]
	pub struct SimpleExtensionV7 {
		pub token: u8,
		pub declared_weight: u64,
	}

	impl TransactionExtension<MockCall> for SimpleExtensionV7 {
		const IDENTIFIER: &'static str = "SimpleExtV7";
		type Implicit = ();
		type Val = ();
		type Pre = ();

		fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
			Ok(())
		}

		fn weight(&self, _call: &MockCall) -> Weight {
			Weight::from_parts(self.declared_weight, 0)
		}

		fn validate(
			&self,
			origin: MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
			_self_implicit: Self::Implicit,
			_inherited_implication: &impl Implication,
			_source: TransactionSource,
		) -> ValidateResult<Self::Val, MockCall> {
			if self.token == 0 {
				Err(InvalidTransaction::Custom(77).into())
			} else {
				Ok((ValidTransaction::default(), (), origin))
			}
		}

		fn prepare(
			self,
			_val: Self::Val,
			_origin: &MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
		) -> Result<Self::Pre, TransactionValidityError> {
			Ok(())
		}
	}

	pub type PipelineV7 = PipelineAtVers<7, SimpleExtensionV7>;

	// --------------------------------------------------------
	// The MultiVersion definition under test
	// --------------------------------------------------------

	pub type MyMultiExt = MultiVersion<PipelineV4, PipelineV7>;

	// --------------------------------------------------------
	// Actual tests
	// --------------------------------------------------------

	#[test]
	fn decode_with_version_works_for_known_versions() {
		// Build a pipeline for version=4
		let pipeline_v4 = PipelineV4::new(SimpleExtensionV4 { token: 99, declared_weight: 123 });
		let encoded_v4 = pipeline_v4.encode();
		let decoded_v4 = MyMultiExt::decode_with_version(4, &mut &encoded_v4[..])
			.expect("decode with version=4");
		let expected_v4 = MultiVersion::A(pipeline_v4);
		assert_eq!(decoded_v4, expected_v4);

		// Build a pipeline for version=7
		let pipeline_v7 = PipelineV7::new(SimpleExtensionV7 { token: 55, declared_weight: 777 });
		let encoded_v7 = pipeline_v7.encode();
		let decoded_v7 = MyMultiExt::decode_with_version(7, &mut &encoded_v7[..])
			.expect("decode with version=7");
		let expected_v7 = MultiVersion::B(pipeline_v7);
		assert_eq!(decoded_v7, expected_v7);
	}

	#[test]
	fn decode_with_unknown_version_fails() {
		let pipeline_v4 = PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 100 });
		let encoded_v4 = pipeline_v4.encode();

		// Attempt decode with version=123 => fails
		let decode_err = MyMultiExt::decode_with_version(123, &mut &encoded_v4[..])
			.expect_err("decode must fail with unknown version=123");
		assert!(format!("{}", decode_err).contains("Invalid extension version"));
	}

	#[test]
	fn decode_with_known_version_but_truncated_input_fails() {
		// Version 4 is known, but one byte cannot hold a `SimpleExtensionV4` (u8 + u64), so
		// the error raised by the selected pipeline must be propagated.
		let truncated: [u8; 1] = [1];
		assert!(MyMultiExt::decode_with_version(4, &mut &truncated[..]).is_err());
	}

	#[test]
	fn duplicated_version_decodes_into_first_variant() {
		// Both variants declare version 4; the documented rule is that the first one wins.
		type DuplicatedExt = MultiVersion<PipelineV4, PipelineV4>;

		let encoded =
			PipelineV4::new(SimpleExtensionV4 { token: 5, declared_weight: 6 }).encode();
		let decoded = DuplicatedExt::decode_with_version(4, &mut &encoded[..])
			.expect("decode with duplicated version=4");
		let expected =
			DuplicatedExt::A(PipelineV4::new(SimpleExtensionV4 { token: 5, declared_weight: 6 }));
		assert_eq!(decoded, expected);
	}

	#[test]
	fn encode_is_transparent_over_the_inner_pipeline() {
		// `MultiVersion` must not prepend a variant index: every `Encode` entry point must
		// match the inner pipeline exactly.
		let pipeline_v4 = PipelineV4::new(SimpleExtensionV4 { token: 3, declared_weight: 42 });
		let inner_bytes = pipeline_v4.encode();
		let inner_size = pipeline_v4.encoded_size();
		let inner_hint = pipeline_v4.size_hint();
		let multi_a = MyMultiExt::A(pipeline_v4);

		assert_eq!(multi_a.encode(), inner_bytes);
		assert_eq!(multi_a.encoded_size(), inner_size);
		assert_eq!(multi_a.size_hint(), inner_hint);
		multi_a.using_encoded(|bytes| assert_eq!(bytes, &inner_bytes[..]));
		let mut out = Vec::new();
		multi_a.encode_to(&mut out);
		assert_eq!(out, inner_bytes);

		// The same holds for a non-first variant.
		let pipeline_v7 = PipelineV7::new(SimpleExtensionV7 { token: 8, declared_weight: 7 });
		let inner_bytes_v7 = pipeline_v7.encode();
		assert_eq!(MyMultiExt::B(pipeline_v7).encode(), inner_bytes_v7);
	}

	#[test]
	fn version_is_correct() {
		// Variant "A" is the first in our MultiVersion and has version=4
		let multi_a =
			MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 10 }));
		assert_eq!(multi_a.version(), 4);

		// Variant "B" has version=7
		let multi_b =
			MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 2, declared_weight: 20 }));
		assert_eq!(multi_b.version(), 7);
	}

	#[test]
	fn weight_check_works() {
		let multi_a =
			MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 500 }));
		let multi_b =
			MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 1, declared_weight: 999 }));

		let call = MockCall(0);
		assert_eq!(multi_a.weight(&call).ref_time(), 500);
		assert_eq!(multi_b.weight(&call).ref_time(), 999);
	}

	#[test]
	fn validate_only_logic_works() {
		// A with token=0 => invalid
		let invalid_a =
			MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 0, declared_weight: 123 }));
		let call = MockCall(42);
		let validity = invalid_a.validate_only(
			MockOrigin(42),
			&call,
			&Default::default(),
			0,
			TransactionSource::Local,
		);
		assert_eq!(
			validity,
			Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(44)))
		);

		// B with token=0 => invalid
		let invalid_b =
			MyMultiExt::B(PipelineV7::new(SimpleExtensionV7 { token: 0, declared_weight: 456 }));
		let validity_b = invalid_b.validate_only(
			MockOrigin(42),
			&call,
			&Default::default(),
			0,
			TransactionSource::Local,
		);
		assert_eq!(
			validity_b,
			Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(77)))
		);

		// A with a non-zero token => ok
		let valid_a =
			MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 55, declared_weight: 10 }));
		let result_ok_a = valid_a.validate_only(
			MockOrigin(1),
			&call,
			&Default::default(),
			0,
			TransactionSource::External,
		);
		assert!(result_ok_a.is_ok(), "valid scenario for pipeline A");
	}

	#[test]
	fn dispatch_transaction_works() {
		// "A" with token != 0 => valid
		let pipeline_a = PipelineV4::new(SimpleExtensionV4 { token: 33, declared_weight: 1 });
		let multi_a = MyMultiExt::A(pipeline_a);
		let call_good = MockCall(42);
		multi_a
			.dispatch_transaction(MockOrigin(9), call_good.clone(), &Default::default(), 0)
			.expect("Should not fail validity")
			.expect("Success");

		// but call=0 => dispatch fails
		let fail_res =
			MyMultiExt::A(PipelineV4::new(SimpleExtensionV4 { token: 1, declared_weight: 10 }))
				.dispatch_transaction(MockOrigin(9), MockCall(0), &Default::default(), 0)
				.expect("Should be a valid transaction from viewpoint of extension");
		let block_err = fail_res.expect_err("actual dispatch error");
		assert_eq!(block_err.error, DispatchError::Other("call=0"));

		// "B" scenario
		let pipeline_b = PipelineV7::new(SimpleExtensionV7 { token: 2, declared_weight: 99 });
		let multi_b = MyMultiExt::B(pipeline_b);
		let outcome_ok = multi_b
			.dispatch_transaction(MockOrigin(1), call_good, &Default::default(), 0)
			.expect("Should pass validity");
		assert!(outcome_ok.is_ok());
	}
}