1use crate::jet::Jet;
4use std::{cmp, fmt};
5
6use crate::value::Word;
7#[cfg(feature = "elements")]
8use elements::encode::Encodable;
9#[cfg(feature = "serde")]
10use serde::Serialize;
11#[cfg(feature = "elements")]
12use std::{convert::TryFrom, io};
13
/// Transaction weight in weight units, stored in a `u32`.
///
/// Conversions from wider weight types clamp at `u32::MAX`, and subtraction
/// clamps at zero.
#[cfg(feature = "bitcoin")]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct U32Weight(u32);
21
#[cfg(feature = "bitcoin")]
impl std::ops::Sub for U32Weight {
    type Output = U32Weight;

    /// Subtract two weights, clamping at zero instead of underflowing.
    fn sub(self, rhs: U32Weight) -> U32Weight {
        let U32Weight(lhs) = self;
        U32Weight(lhs.saturating_sub(rhs.0))
    }
}
30
#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for U32Weight {
    /// Narrow a `bitcoin::Weight` to 32 bits, clamping values above `u32::MAX`.
    fn from(value: bitcoin::Weight) -> Self {
        let clamped = value.to_wu().min(u64::from(u32::MAX));
        // Cannot truncate: `clamped` is at most `u32::MAX`.
        U32Weight(clamped as u32)
    }
}
37
#[cfg(feature = "bitcoin")]
impl From<U32Weight> for bitcoin::Weight {
    /// Widen a 32-bit weight into a `bitcoin::Weight`; this never loses precision.
    fn from(value: U32Weight) -> Self {
        let U32Weight(weight_units) = value;
        Self::from_wu(weight_units.into())
    }
}
44
/// Execution cost of a Simplicity program, measured in milliweight units
/// (1/1000 of a weight unit).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Cost(u32);

impl Cost {
    /// Cost charged for every executed node on top of its specific cost.
    const OVERHEAD: Self = Cost(100);

    /// Cost assigned to a node that is never executed.
    const NEVER_EXECUTED: Self = Cost(0);

    /// Largest cost accepted by [`Cost::is_consensus_valid`], in milliweight.
    ///
    /// Numerically this is `(4_000_000 + 50) * 1000`.
    pub const CONSENSUS_MAX: Self = Cost(4_000_050_000);

    /// Return the cost of a type with the given bit width.
    ///
    /// Widths that do not fit in a `u32` saturate to `u32::MAX`, consistent with
    /// the saturating addition of costs, instead of being silently truncated.
    pub const fn of_type(bit_width: usize) -> Self {
        // `u32::try_from` is not callable in a `const fn`, so compare explicitly.
        // Where `usize` is no wider than `u32` the condition is never true and the
        // cast below is lossless.
        if bit_width > u32::MAX as usize {
            Cost(u32::MAX)
        } else {
            Cost(bit_width as u32)
        }
    }

    /// Return the cost equal to the given number of milliweight units.
    pub const fn from_milliweight(milliweight: u32) -> Self {
        Cost(milliweight)
    }

    /// Return whether the cost does not exceed [`Cost::CONSENSUS_MAX`].
    pub fn is_consensus_valid(self) -> bool {
        self <= Self::CONSENSUS_MAX
    }

    /// Return the budget, in weight units, granted by the given script witness.
    ///
    /// The budget is the length of the consensus-serialized witness stack plus
    /// 50 weight units (the addition saturates).
    ///
    /// # Panics
    ///
    /// Panics if the serialized witness stack length does not fit in a `u32`.
    #[cfg(feature = "elements")]
    fn get_budget(script_witness: &Vec<Vec<u8>>) -> U32Weight {
        let mut sink = io::sink();
        let witness_stack_serialized_len = script_witness
            .consensus_encode(&mut sink)
            .expect("writing to sink never fails");
        let budget = u32::try_from(witness_stack_serialized_len)
            .expect("Serialized witness stack must be shorter than 2^32 elements")
            .saturating_add(50);
        U32Weight(budget)
    }

    /// Return whether the cost fits in the budget granted by `script_witness`.
    ///
    /// The budget is converted to milliweight with a saturating multiplication,
    /// so a budget too large for `u32` milliweight accepts every cost.
    #[cfg(feature = "elements")]
    pub fn is_budget_valid(self, script_witness: &Vec<Vec<u8>>) -> bool {
        let budget = Self::get_budget(script_witness);
        self.0 <= budget.0.saturating_mul(1000)
    }

    /// Return the annex needed to make the cost fit in the witness budget.
    ///
    /// Returns `None` if the cost already fits in the budget of `script_witness`.
    /// Otherwise, returns annex bytes (`0x50` followed by zero bytes). Appending
    /// these bytes to the witness stack raises the budget enough to cover the cost.
    #[cfg(feature = "elements")]
    pub fn get_padding(self, script_witness: &Vec<Vec<u8>>) -> Option<Vec<u8>> {
        let weight = U32Weight::from(self);
        let budget = Self::get_budget(script_witness);
        if weight <= budget {
            return None;
        }

        // Number of bytes by which the serialized witness stack must grow.
        let deficit = (weight - budget).0 as usize;
        // The annex is `0x50` plus `padding_len` zeroes, so its length is
        // `padding_len + 1`. Appending it grows the serialization by that length
        // plus its length prefix. Assuming Bitcoin CompactSize encoding, the prefix
        // takes 1 byte up to 252, 3 bytes up to 0xFFFF, and 5 bytes beyond.
        let padding_len = match deficit {
            // 1-byte prefix: growth = (padding_len + 1) + 1.
            0..=253 => deficit.saturating_sub(2),
            // A 1-byte prefix yields at most 253 bytes of growth, so use the
            // shortest annex with a 3-byte prefix (253 bytes, growth 256).
            254..=255 => 252,
            // 3-byte prefix: growth = (padding_len + 1) + 3.
            256..=65538 => deficit - 4,
            // Likewise, use the shortest annex with a 5-byte prefix (65536 bytes).
            65539..=65540 => 65535,
            // 5-byte prefix: growth = (padding_len + 1) + 5.
            _ => deficit - 6,
        };
        // NOTE(review): this assumes the witness item-count prefix does not grow
        // when the annex is added. If it does, the padding is still sufficient but
        // not minimal.
        let annex_bytes: Vec<u8> = std::iter::once(0x50)
            .chain(std::iter::repeat(0x00).take(padding_len))
            .collect();

        Some(annex_bytes)
    }
}
192
193impl fmt::Display for Cost {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 fmt::Display::fmt(&self.0, f)
196 }
197}
198
199impl std::ops::Add for Cost {
200 type Output = Self;
201
202 fn add(self, rhs: Self) -> Self::Output {
203 Cost(self.0.saturating_add(rhs.0))
204 }
205}
206
#[cfg(feature = "bitcoin")]
impl From<U32Weight> for Cost {
    /// Convert weight units to milliweight, clamping the product at `u32::MAX`.
    fn from(value: U32Weight) -> Self {
        let U32Weight(weight_units) = value;
        Cost(u32::saturating_mul(weight_units, 1000))
    }
}
213
#[cfg(feature = "bitcoin")]
impl From<Cost> for U32Weight {
    /// Convert a milliweight cost into whole weight units, rounding up.
    ///
    /// This is an exact ceiling for every `u32` input. The form
    /// `saturating_add(999) / 1000` rounds *down* for costs within 999 of
    /// `u32::MAX`, because the addition saturates before the division.
    fn from(value: Cost) -> Self {
        let whole = value.0 / 1000;
        let has_remainder = value.0 % 1000 != 0;
        // Cannot overflow: `u32::MAX / 1000 + 1` is far below `u32::MAX`.
        Self(whole + u32::from(has_remainder))
    }
}
223
#[cfg(feature = "bitcoin")]
impl From<bitcoin::Weight> for Cost {
    /// Convert a `bitcoin::Weight` to milliweight.
    ///
    /// The weight is first clamped to `u32::MAX` weight units, and the
    /// multiplication by 1000 is then clamped at `u32::MAX` as well.
    fn from(value: bitcoin::Weight) -> Self {
        let narrowed = U32Weight::from(value);
        Self::from(narrowed)
    }
}
230
#[cfg(feature = "bitcoin")]
impl From<Cost> for bitcoin::Weight {
    /// Convert a milliweight cost to a `bitcoin::Weight`, rounding up to whole
    /// weight units via the `U32Weight` conversion.
    fn from(value: Cost) -> Self {
        let rounded = U32Weight::from(value);
        Self::from(rounded)
    }
}
237
238#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
240#[cfg_attr(feature = "serde", derive(Serialize))]
241pub struct NodeBounds {
242 pub extra_cells: usize,
245 pub extra_frames: usize,
248 pub cost: Cost,
250}
251
252impl NodeBounds {
253 const NOP: Self = NodeBounds {
254 extra_cells: 0,
255 extra_frames: 0,
256 cost: Cost::OVERHEAD,
257 };
258 const NEVER_EXECUTED: Self = NodeBounds {
259 extra_cells: 0,
260 extra_frames: 0,
261 cost: Cost::NEVER_EXECUTED,
262 };
263
264 fn from_child(child: Self) -> Self {
265 NodeBounds {
266 extra_cells: child.extra_cells,
267 extra_frames: child.extra_frames,
268 cost: Cost::OVERHEAD + child.cost,
269 }
270 }
271
272 pub fn iden(target_type: usize) -> NodeBounds {
274 NodeBounds {
275 extra_cells: 0,
276 extra_frames: 0,
277 cost: Cost::OVERHEAD + Cost::of_type(target_type),
278 }
279 }
280
281 pub const fn unit() -> NodeBounds {
283 NodeBounds::NOP
284 }
285
286 pub fn injl(child: Self) -> NodeBounds {
288 Self::from_child(child)
289 }
290
291 pub fn injr(child: Self) -> NodeBounds {
293 Self::from_child(child)
294 }
295
296 pub fn take(child: Self) -> NodeBounds {
298 Self::from_child(child)
299 }
300
301 pub fn drop(child: Self) -> NodeBounds {
303 Self::from_child(child)
304 }
305
306 pub fn comp(left: Self, right: Self, mid_ty_bit_width: usize) -> NodeBounds {
308 NodeBounds {
309 extra_cells: mid_ty_bit_width + cmp::max(left.extra_cells, right.extra_cells),
310 extra_frames: 1 + cmp::max(left.extra_frames, right.extra_frames),
311 cost: Cost::OVERHEAD + Cost::of_type(mid_ty_bit_width) + left.cost + right.cost,
312 }
313 }
314
315 pub fn case(left: Self, right: Self) -> NodeBounds {
317 NodeBounds {
318 extra_cells: cmp::max(left.extra_cells, right.extra_cells),
319 extra_frames: cmp::max(left.extra_frames, right.extra_frames),
320 cost: Cost::OVERHEAD + cmp::max(left.cost, right.cost),
321 }
322 }
323
324 pub fn assertl(child: Self) -> NodeBounds {
326 Self::from_child(child)
327 }
328
329 pub fn assertr(child: Self) -> NodeBounds {
331 Self::from_child(child)
332 }
333
334 pub fn pair(left: Self, right: Self) -> NodeBounds {
336 NodeBounds {
337 extra_cells: cmp::max(left.extra_cells, right.extra_cells),
338 extra_frames: cmp::max(left.extra_frames, right.extra_frames),
339 cost: Cost::OVERHEAD + left.cost + right.cost,
340 }
341 }
342
343 pub fn disconnect(
346 left: Self,
347 right: Self,
348 left_target_b_bit_width: usize, left_source_bit_width: usize,
350 left_target_bit_width: usize,
351 ) -> NodeBounds {
352 NodeBounds {
353 extra_cells: left_source_bit_width
354 + left_target_bit_width
355 + cmp::max(left.extra_cells, right.extra_cells),
356 extra_frames: 2 + cmp::max(left.extra_frames, right.extra_frames),
357 cost: Cost::OVERHEAD
358 + Cost::of_type(left_source_bit_width)
359 + Cost::of_type(left_source_bit_width)
360 + Cost::of_type(left_target_bit_width)
361 + Cost::of_type(left_target_b_bit_width)
362 + left.cost
363 + right.cost,
364 }
365 }
366
367 pub fn witness(target_ty_bit_width: usize) -> NodeBounds {
369 NodeBounds {
370 extra_cells: target_ty_bit_width,
371 extra_frames: 0,
372 cost: Cost::OVERHEAD + Cost::of_type(target_ty_bit_width),
373 }
374 }
375
376 pub fn jet(jet: &dyn Jet) -> NodeBounds {
378 NodeBounds {
379 extra_cells: 0,
380 extra_frames: 0,
381 cost: Cost::OVERHEAD + jet.cost(),
382 }
383 }
384
385 pub fn const_word(word: &Word) -> NodeBounds {
387 NodeBounds {
388 extra_cells: 0,
389 extra_frames: 0,
390 cost: Cost::OVERHEAD + Cost::of_type(word.len()),
391 }
392 }
393
394 pub const fn fail() -> NodeBounds {
401 NodeBounds::NEVER_EXECUTED
402 }
403}
404
/// Number of extra frames accounted for program input and output.
///
/// NOTE(review): presumably one frame each for input and output; confirm
/// against the executor that consumes this constant.
pub(crate) const IO_EXTRA_FRAMES: usize = 2;
407
#[cfg(test)]
mod tests {
    use super::*;
    use simplicity_sys::ffi::bounded::cost_overhead;

    /// The per-node overhead must agree with the value exported by the C library.
    #[test]
    fn test_overhead() {
        let expected = cost_overhead();
        assert_eq!(Cost::OVERHEAD.0, expected);
    }

    /// Converting milliweight to weight rounds up to the next whole weight unit.
    #[test]
    #[cfg(feature = "bitcoin")]
    fn cost_to_weight() {
        let cases = [
            (Cost::NEVER_EXECUTED, 0),
            (Cost::from_milliweight(1), 1),
            (Cost::from_milliweight(999), 1),
            (Cost::from_milliweight(1_000), 1),
            (Cost::from_milliweight(1_001), 2),
            (Cost::from_milliweight(1_999), 2),
            (Cost::from_milliweight(2_000), 2),
            (Cost::CONSENSUS_MAX, 4_000_050),
        ];

        for &(cost, weight) in cases.iter() {
            assert_eq!(U32Weight::from(cost), U32Weight(weight));
        }
    }

    /// `get_padding` yields the shortest annex that brings the cost within budget.
    #[test]
    #[cfg(feature = "elements")]
    fn test_get_padding() {
        // Budget of an empty witness stack in milliweight: 50 WU allowance plus
        // 1 WU for the serialized empty stack.
        let empty = 51_000;

        let cases = vec![
            (Cost::from_milliweight(0), vec![], None),
            (Cost::from_milliweight(empty), vec![], None),
            (Cost::from_milliweight(empty + 1), vec![], Some(1)),
            (Cost::from_milliweight(empty + 2_000), vec![], Some(1)),
            (Cost::from_milliweight(empty + 2_001), vec![], Some(2)),
            (Cost::from_milliweight(empty + 3_000), vec![], Some(2)),
            (Cost::from_milliweight(empty + 3_001), vec![], Some(3)),
            (Cost::from_milliweight(empty + 4_000), vec![], Some(3)),
            (Cost::from_milliweight(empty + 4_001), vec![], Some(4)),
            (Cost::from_milliweight(empty + 50_000), vec![], Some(49)),
            // Largest deficit coverable with a 1-byte annex length prefix.
            (Cost::from_milliweight(empty + 253_000), vec![], Some(252)),
            // Deficits 254..=256 all need the shortest 3-byte-prefix annex.
            (Cost::from_milliweight(empty + 254_000), vec![], Some(253)),
            (Cost::from_milliweight(empty + 255_000), vec![], Some(253)),
            (Cost::from_milliweight(empty + 256_000), vec![], Some(253)),
            (Cost::from_milliweight(empty + 257_000), vec![], Some(254)),
            (Cost::from_milliweight(empty + 7_424_000), vec![], Some(7_421)),
            (
                Cost::from_milliweight(8_045_103),
                vec![vec![], vec![0; 497], vec![0; 32], vec![0; 33]],
                Some(7_424),
            ),
            (Cost::CONSENSUS_MAX, vec![], Some(3_999_994)),
        ];

        for (cost, witness, expected) in cases {
            let mut witness = witness;
            let expected_annex_len = match expected {
                None => {
                    assert!(cost.is_budget_valid(&witness));
                    assert!(cost.get_padding(&witness).is_none());
                    continue;
                }
                Some(len) => len,
            };

            assert!(!cost.is_budget_valid(&witness));
            let annex_bytes = cost.get_padding(&witness).expect("not enough budget");
            assert_eq!(expected_annex_len, annex_bytes.len());

            witness.push(annex_bytes);
            assert!(cost.is_budget_valid(&witness));

            witness.pop();
            assert!(!cost.is_budget_valid(&witness), "Padding must be minimal");
        }
    }
}