sp_runtime/traits/vers_tx_ext/
at_vers.rs1use crate::{
21 traits::{
22 AsTransactionAuthorizedOrigin, DecodeWithVersion, DecodeWithVersionWithMemTracking,
23 DispatchInfoOf, DispatchOriginOf, DispatchTransaction, Dispatchable, Pipeline,
24 PipelineMetadataBuilder, PipelineVersion, PostDispatchInfoOf, TransactionExtension,
25 },
26 transaction_validity::{TransactionSource, TransactionValidityError, ValidTransaction},
27};
28use codec::{Decode, DecodeWithMemTracking, Encode};
29use core::fmt::Debug;
30use scale_info::TypeInfo;
31use sp_weights::Weight;
32
33#[derive(Encode, Clone, Debug, TypeInfo, PartialEq, Eq)]
35pub struct PipelineAtVers<const VERSION: u8, Extension> {
36 pub extension: Extension,
38}
39
40impl<const VERSION: u8, Extension> PipelineAtVers<VERSION, Extension> {
41 pub fn new(extension: Extension) -> Self {
43 Self { extension }
44 }
45}
46
47impl<const VERSION: u8, Extension: Decode> DecodeWithVersion
48 for PipelineAtVers<VERSION, Extension>
49{
50 fn decode_with_version<I: codec::Input>(
51 extension_version: u8,
52 input: &mut I,
53 ) -> Result<Self, codec::Error> {
54 if extension_version == VERSION {
55 Ok(PipelineAtVers { extension: Extension::decode(input)? })
56 } else {
57 Err(codec::Error::from("Invalid extension version"))
58 }
59 }
60}
61
62impl<const VERSION: u8, Extension: DecodeWithMemTracking> DecodeWithVersionWithMemTracking
63 for PipelineAtVers<VERSION, Extension>
64{
65}
66
67impl<const VERSION: u8, Extension> PipelineVersion for PipelineAtVers<VERSION, Extension> {
68 fn version(&self) -> u8 {
69 VERSION
70 }
71}
72
73impl<const VERSION: u8, Call, Extension> Pipeline<Call> for PipelineAtVers<VERSION, Extension>
74where
75 Call: Dispatchable<RuntimeOrigin: AsTransactionAuthorizedOrigin> + Encode,
76 Extension: TransactionExtension<Call>,
77{
78 fn build_metadata(builder: &mut PipelineMetadataBuilder) {
79 builder.push_versioned_extension(VERSION, Extension::metadata());
80 }
81 fn validate_only(
82 &self,
83 origin: DispatchOriginOf<Call>,
84 call: &Call,
85 info: &DispatchInfoOf<Call>,
86 len: usize,
87 source: TransactionSource,
88 ) -> Result<ValidTransaction, TransactionValidityError> {
89 self.extension
90 .validate_only(origin, call, info, len, source, VERSION)
91 .map(|x| x.0)
92 }
93 fn dispatch_transaction(
94 self,
95 origin: DispatchOriginOf<Call>,
96 call: Call,
97 info: &DispatchInfoOf<Call>,
98 len: usize,
99 ) -> crate::ApplyExtrinsicResultWithInfo<PostDispatchInfoOf<Call>> {
100 self.extension.dispatch_transaction(origin, call, info, len, VERSION)
101 }
102 fn weight(&self, call: &Call) -> Weight {
103 self.extension.weight(call)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{
            Dispatchable, Implication, TransactionExtension, TransactionSource, ValidateResult,
        },
        transaction_validity::{InvalidTransaction, TransactionValidityError, ValidTransaction},
        DispatchError,
    };
    use codec::{Decode, DecodeWithMemTracking, Encode};
    use sp_weights::Weight;

    /// Minimal call type. Its payload is never inspected by `dispatch`.
    #[derive(Clone, Debug, Encode, Decode, PartialEq, Eq)]
    pub struct MockCall(pub u64);

    /// Mock origin. An inner value of `0` makes `MockCall::dispatch` fail.
    #[derive(Debug)]
    pub struct MockOrigin(pub u64);

    impl AsTransactionAuthorizedOrigin for MockOrigin {
        fn is_transaction_authorized(&self) -> bool {
            true
        }
    }

    impl Dispatchable for MockCall {
        type RuntimeOrigin = MockOrigin;
        type Config = ();
        type Info = ();
        type PostInfo = ();

        fn dispatch(
            self,
            origin: Self::RuntimeOrigin,
        ) -> crate::DispatchResultWithInfo<Self::PostInfo> {
            if origin.0 == 0 {
                return Err(DispatchError::Other("origin is 0").into());
            }
            Ok(Default::default())
        }
    }

    /// Test extension.
    ///
    /// - `token == 0` fails validation.
    /// - `w` is reported as the ref-time weight.
    #[derive(PartialEq, Eq, Clone, Debug, Encode, Decode, DecodeWithMemTracking, TypeInfo)]
    pub struct SimpleExtension {
        /// Validation fails when this is zero.
        pub token: u32,
        /// Ref-time weight reported by `weight`.
        pub w: u64,
    }

    impl TransactionExtension<MockCall> for SimpleExtension {
        const IDENTIFIER: &'static str = "SimpleExtension";

        type Implicit = ();
        fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
            Ok(())
        }

        type Val = ();
        type Pre = ();

        fn weight(&self, _call: &MockCall) -> Weight {
            Weight::from_parts(self.w, 0)
        }

        fn validate(
            &self,
            origin: MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
            _self_implicit: Self::Implicit,
            _inherited_implication: &impl Implication,
            _source: TransactionSource,
        ) -> ValidateResult<Self::Val, MockCall> {
            if self.token == 0 {
                // A zero token is the only rejection path of this mock.
                Err(InvalidTransaction::Custom(1).into())
            } else {
                Ok((ValidTransaction::default(), (), origin))
            }
        }

        fn prepare(
            self,
            _val: Self::Val,
            _origin: &MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
        ) -> Result<Self::Pre, TransactionValidityError> {
            Ok(())
        }
    }

    /// Pipeline accepting only extension version 3.
    pub type ExtV3 = PipelineAtVers<3, SimpleExtension>;

    /// Pipeline accepting only extension version 10.
    pub type ExtV10 = PipelineAtVers<10, SimpleExtension>;

    #[test]
    fn new_wraps_extension() {
        let inner = SimpleExtension { token: 7, w: 8 };
        let ext = ExtV3::new(inner.clone());
        assert_eq!(ext, ExtV3 { extension: inner });
    }

    #[test]
    fn decode_with_correct_version_succeeds() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        let encoded = ext_v3.encode();
        let mut input = &encoded[..];

        let decoded = <ExtV3 as DecodeWithVersion>::decode_with_version(3, &mut input)
            .expect("should decode fine with matching version");
        assert_eq!(decoded, ext_v3);
        // The encoding holds only the inner extension, so it must be consumed entirely.
        assert!(input.is_empty());
    }

    #[test]
    fn decode_with_incorrect_version_fails() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        let encoded = ext_v3.encode();
        let mut input = &encoded[..];

        let decode_err = <ExtV3 as DecodeWithVersion>::decode_with_version(10, &mut input)
            .expect_err("should fail decode due to invalid version");
        assert!(decode_err.to_string().contains("Invalid extension version"));
        // A rejected version must not read anything from the input.
        assert_eq!(input.len(), encoded.len());
    }

    #[test]
    fn version_is_correct() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        assert_eq!(ext_v3.version(), 3);

        let ext_v10 = ExtV10 { extension: SimpleExtension { token: 1, w: 1 } };
        assert_eq!(ext_v10.version(), 10);
    }

    #[test]
    fn pipeline_functions_work() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 999, w: 50 } };
        let call = MockCall(0x_f00);

        // The weight is forwarded from the inner extension.
        assert_eq!(ext_v3.weight(&call).ref_time(), 50);

        // A zero token is rejected by `SimpleExtension::validate`.
        let invalid_ext_v3 = ExtV3 { extension: SimpleExtension { token: 0, w: 10 } };
        let validity = invalid_ext_v3.validate_only(
            MockOrigin(1),
            &call,
            &Default::default(),
            0,
            TransactionSource::External,
        );
        assert_eq!(
            validity,
            Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(1)))
        );

        // A non-zero token validates to the default `ValidTransaction`.
        let valid = ext_v3
            .validate_only(MockOrigin(2), &call, &Default::default(), 0, TransactionSource::Local)
            .expect("non-zero token should validate");
        assert_eq!(valid, ValidTransaction::default());
    }

    #[test]
    fn dispatch_transaction_works() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 1, w: 10 } };
        let call = MockCall(123);
        let info = Default::default();
        let len = 0usize;

        // A non-zero origin dispatches successfully.
        ext_v3
            .clone()
            .dispatch_transaction(MockOrigin(1), call.clone(), &info, len)
            .expect("valid dispatch")
            .expect("should be OK");

        // Validation passes, but `MockCall::dispatch` rejects origin 0.
        let res_fail = ext_v3.dispatch_transaction(MockOrigin(0), call, &info, len);
        let block_err = res_fail.expect("valid").expect_err("should fail");
        assert_eq!(block_err.error, DispatchError::Other("origin is 0"));
    }
}