// sp_runtime/traits/vers_tx_ext/at_vers.rs

// This file is part of Substrate.

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
17
//! Type to define a versioned transaction extension pipeline for a specific version.
19
20use crate::{
21	traits::{
22		AsTransactionAuthorizedOrigin, DecodeWithVersion, DecodeWithVersionWithMemTracking,
23		DispatchInfoOf, DispatchOriginOf, DispatchTransaction, Dispatchable, Pipeline,
24		PipelineMetadataBuilder, PipelineVersion, PostDispatchInfoOf, TransactionExtension,
25	},
26	transaction_validity::{TransactionSource, TransactionValidityError, ValidTransaction},
27};
28use codec::{Decode, DecodeWithMemTracking, Encode};
29use core::fmt::Debug;
30use scale_info::TypeInfo;
31use sp_weights::Weight;
32
33/// A transaction extension pipeline defined for a single version.
34#[derive(Encode, Clone, Debug, TypeInfo, PartialEq, Eq)]
35pub struct PipelineAtVers<const VERSION: u8, Extension> {
36	/// The transaction extension pipeline for the version `VERSION`.
37	pub extension: Extension,
38}
39
40impl<const VERSION: u8, Extension> PipelineAtVers<VERSION, Extension> {
41	/// Create a new versioned extension.
42	pub fn new(extension: Extension) -> Self {
43		Self { extension }
44	}
45}
46
47impl<const VERSION: u8, Extension: Decode> DecodeWithVersion
48	for PipelineAtVers<VERSION, Extension>
49{
50	fn decode_with_version<I: codec::Input>(
51		extension_version: u8,
52		input: &mut I,
53	) -> Result<Self, codec::Error> {
54		if extension_version == VERSION {
55			Ok(PipelineAtVers { extension: Extension::decode(input)? })
56		} else {
57			Err(codec::Error::from("Invalid extension version"))
58		}
59	}
60}
61
62impl<const VERSION: u8, Extension: DecodeWithMemTracking> DecodeWithVersionWithMemTracking
63	for PipelineAtVers<VERSION, Extension>
64{
65}
66
67impl<const VERSION: u8, Extension> PipelineVersion for PipelineAtVers<VERSION, Extension> {
68	fn version(&self) -> u8 {
69		VERSION
70	}
71}
72
73impl<const VERSION: u8, Call, Extension> Pipeline<Call> for PipelineAtVers<VERSION, Extension>
74where
75	Call: Dispatchable<RuntimeOrigin: AsTransactionAuthorizedOrigin> + Encode,
76	Extension: TransactionExtension<Call>,
77{
78	fn build_metadata(builder: &mut PipelineMetadataBuilder) {
79		builder.push_versioned_extension(VERSION, Extension::metadata());
80	}
81	fn validate_only(
82		&self,
83		origin: DispatchOriginOf<Call>,
84		call: &Call,
85		info: &DispatchInfoOf<Call>,
86		len: usize,
87		source: TransactionSource,
88	) -> Result<ValidTransaction, TransactionValidityError> {
89		self.extension
90			.validate_only(origin, call, info, len, source, VERSION)
91			.map(|x| x.0)
92	}
93	fn dispatch_transaction(
94		self,
95		origin: DispatchOriginOf<Call>,
96		call: Call,
97		info: &DispatchInfoOf<Call>,
98		len: usize,
99	) -> crate::ApplyExtrinsicResultWithInfo<PostDispatchInfoOf<Call>> {
100		self.extension.dispatch_transaction(origin, call, info, len, VERSION)
101	}
102	fn weight(&self, call: &Call) -> Weight {
103		self.extension.weight(call)
104	}
105}
106
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{
		traits::{
			Dispatchable, Implication, TransactionExtension, TransactionSource, ValidateResult,
		},
		transaction_validity::{InvalidTransaction, TransactionValidityError, ValidTransaction},
		DispatchError,
	};
	use codec::{Decode, DecodeWithMemTracking, Encode};
	use sp_weights::Weight;

	// --- Mock types ---

	/// A mock call type implementing `Dispatchable`. Its payload is never inspected.
	#[derive(Clone, Debug, Encode, Decode, PartialEq, Eq)]
	pub struct MockCall(pub u64);

	/// A mock origin. Dispatching a `MockCall` with `MockOrigin(0)` fails; any other value
	/// succeeds.
	#[derive(Debug)]
	pub struct MockOrigin(pub u64);

	impl AsTransactionAuthorizedOrigin for MockOrigin {
		fn is_transaction_authorized(&self) -> bool {
			true
		}
	}

	impl Dispatchable for MockCall {
		type RuntimeOrigin = MockOrigin;
		type Config = ();
		type Info = ();
		type PostInfo = ();

		fn dispatch(
			self,
			origin: Self::RuntimeOrigin,
		) -> crate::DispatchResultWithInfo<Self::PostInfo> {
			if origin.0 == 0 {
				return Err(DispatchError::Other("origin is 0").into());
			}
			Ok(Default::default())
		}
	}

	// A trivial extension that sets a known weight and does minimal logic.
	// We simply store an integer "token" and do check logic on it.
	#[derive(PartialEq, Eq, Clone, Debug, Encode, Decode, DecodeWithMemTracking, TypeInfo)]
	pub struct SimpleExtension {
		/// The token for validation logic; `0` makes validation fail.
		pub token: u32,
		/// The "weight" (ref time) that this extension claims to cost.
		pub w: u64,
	}

	impl TransactionExtension<MockCall> for SimpleExtension {
		const IDENTIFIER: &'static str = "SimpleExtension";

		type Implicit = ();
		fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
			Ok(())
		}

		type Val = ();
		type Pre = ();

		fn weight(&self, _call: &MockCall) -> Weight {
			Weight::from_parts(self.w, 0)
		}

		fn validate(
			&self,
			origin: MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
			_self_implicit: Self::Implicit,
			_inherited_implication: &impl Implication,
			_source: TransactionSource,
		) -> ValidateResult<Self::Val, MockCall> {
			// any origin is permitted, but if `token == 0` => invalid
			if self.token == 0 {
				Err(InvalidTransaction::Custom(1).into())
			} else {
				Ok((ValidTransaction::default(), (), origin))
			}
		}

		fn prepare(
			self,
			_val: Self::Val,
			_origin: &MockOrigin,
			_call: &MockCall,
			_info: &DispatchInfoOf<MockCall>,
			_len: usize,
		) -> Result<Self::Pre, TransactionValidityError> {
			Ok(())
		}
	}

	// This type represents the versioned extension pipeline for version=3.
	pub type ExtV3 = PipelineAtVers<3, SimpleExtension>;

	// This type represents the versioned extension pipeline for version=10.
	pub type ExtV10 = PipelineAtVers<10, SimpleExtension>;

	// --- Tests ---

	#[test]
	fn decode_with_correct_version_succeeds() {
		let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
		let encoded = ext_v3.encode();

		let mut input = &encoded[..];
		let decoded = <ExtV3 as DecodeWithVersion>::decode_with_version(3, &mut input)
			.expect("should decode fine with matching version");
		assert_eq!(decoded, ext_v3);
		// The whole encoding belongs to the extension, so nothing may be left over.
		assert!(input.is_empty());
	}

	#[test]
	fn decode_with_incorrect_version_fails() {
		let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
		let encoded = ext_v3.encode();

		// Attempt decode with version=10
		let mut input = &encoded[..];
		let decode_err = <ExtV3 as DecodeWithVersion>::decode_with_version(10, &mut input)
			.expect_err("should fail decode due to invalid version");
		let decode_err_str = format!("{}", decode_err);
		assert!(decode_err_str.contains("Invalid extension version"));
		// The version is checked before anything is read, so the input must be untouched.
		assert_eq!(input.len(), encoded.len());
	}

	#[test]
	fn version_is_correct() {
		let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
		assert_eq!(ext_v3.version(), 3);

		let ext_v10 = ExtV10 { extension: SimpleExtension { token: 1, w: 1 } };
		assert_eq!(ext_v10.version(), 10);
	}

	#[test]
	fn pipeline_functions_work() {
		let ext_v3 = ExtV3 { extension: SimpleExtension { token: 999, w: 50 } };

		// test "weight" function
		let call = MockCall(0xf00);
		assert_eq!(ext_v3.weight(&call).ref_time(), 50);

		// test validating logic
		{
			// token = 0 => invalid
			let invalid_ext_v3 = ExtV3 { extension: SimpleExtension { token: 0, w: 10 } };
			let validity = invalid_ext_v3.validate_only(
				MockOrigin(1),
				&call,
				&Default::default(),
				0,
				TransactionSource::External,
			);
			assert_eq!(
				validity,
				Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(1)))
			);
		}

		// ok scenario: token != 0 => OK
		let valid = ext_v3
			.validate_only(
				MockOrigin(2),
				&call,
				&Default::default(),
				0,
				TransactionSource::Local,
			)
			.expect("token != 0 should validate");
		assert_eq!(valid, ValidTransaction::default());
	}

	#[test]
	fn dispatch_transaction_works() {
		// This extension is valid => token=1
		let ext_v3 = ExtV3 { extension: SimpleExtension { token: 1, w: 10 } };
		let call = MockCall(123);
		let info = Default::default();
		let len = 0usize;

		// dispatch => OK
		ext_v3
			.clone()
			.dispatch_transaction(MockOrigin(1), call.clone(), &info, len)
			.expect("valid dispatch")
			.expect("should be OK");

		// With `MockOrigin(0)` the extension still accepts the transaction, but the
		// dispatched call itself fails.
		let res_fail = ext_v3.dispatch_transaction(MockOrigin(0), call, &info, len);
		let block_err = res_fail.expect("valid").expect_err("should fail");
		assert_eq!(block_err.error, DispatchError::Other("origin is 0"));
	}
}