1use std::collections::BTreeMap;
2
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4
5use crate::{
6 reduce::{Apply, ArgValue},
7 Node, Visitor,
8};
9
10#[derive(Debug, thiserror::Error)]
11pub enum Error {
12 #[error("unknown TIR version: {0}")]
13 UnknownTirVersion(String),
14
15 #[error("deprecated TIR version: {0}")]
16 DeprecatedTirVersion(String),
17
18 #[error("TIR deserialize error: {0}")]
19 TirDeserializeError(String),
20}
21
22#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
23#[serde(rename_all = "lowercase")]
24pub enum TirVersion {
25 V1Alpha8,
26 V1Beta0,
27}
28
29pub const MIN_SUPPORTED_VERSION: TirVersion = TirVersion::V1Beta0;
30
31impl std::fmt::Display for TirVersion {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 let version = match self {
34 TirVersion::V1Alpha8 => "v1alpha8",
35 TirVersion::V1Beta0 => "v1beta0",
36 };
37
38 write!(f, "{}", version)
39 }
40}
41
42impl TryFrom<&str> for TirVersion {
43 type Error = Error;
44
45 fn try_from(value: &str) -> Result<Self, Self::Error> {
46 match value {
47 "v1alpha8" => Ok(TirVersion::V1Alpha8),
48 "v1beta0" => Ok(TirVersion::V1Beta0),
49 x => Err(Error::UnknownTirVersion(x.to_string())),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub enum AnyTir {
56 V1Beta0(crate::model::v1beta0::Tx),
57}
58
59impl AnyTir {
60 pub fn version(&self) -> TirVersion {
61 match self {
62 AnyTir::V1Beta0(_) => TirVersion::V1Beta0,
63 }
64 }
65}
66
67impl Node for AnyTir {
68 fn apply<V: Visitor>(self, visitor: &mut V) -> Result<Self, crate::reduce::Error> {
69 match self {
70 AnyTir::V1Beta0(tx) => Ok(AnyTir::V1Beta0(tx.apply(visitor)?)),
71 }
72 }
73}
74
75impl Apply for AnyTir {
76 fn apply_args(self, args: &BTreeMap<String, ArgValue>) -> Result<Self, crate::reduce::Error> {
77 match self {
78 AnyTir::V1Beta0(tx) => Ok(AnyTir::V1Beta0(tx.apply_args(args)?)),
79 }
80 }
81
82 fn apply_inputs(
83 self,
84 args: &BTreeMap<String, std::collections::HashSet<crate::model::core::Utxo>>,
85 ) -> Result<Self, crate::reduce::Error> {
86 match self {
87 AnyTir::V1Beta0(tx) => Ok(AnyTir::V1Beta0(tx.apply_inputs(args)?)),
88 }
89 }
90
91 fn apply_fees(self, fees: u64) -> Result<Self, crate::reduce::Error> {
92 match self {
93 AnyTir::V1Beta0(tx) => Ok(AnyTir::V1Beta0(tx.apply_fees(fees)?)),
94 }
95 }
96
97 fn is_constant(&self) -> bool {
98 match self {
99 AnyTir::V1Beta0(tx) => tx.is_constant(),
100 }
101 }
102
103 fn params(&self) -> BTreeMap<String, crate::model::core::Type> {
104 match self {
105 AnyTir::V1Beta0(tx) => tx.params(),
106 }
107 }
108
109 fn queries(&self) -> BTreeMap<String, crate::model::v1beta0::InputQuery> {
110 match self {
111 AnyTir::V1Beta0(tx) => tx.queries(),
112 }
113 }
114
115 fn reduce(self) -> Result<Self, crate::reduce::Error> {
116 match self {
117 AnyTir::V1Beta0(tx) => Ok(AnyTir::V1Beta0(tx.reduce()?)),
118 }
119 }
120}
121
122pub trait TirRoot: Serialize + DeserializeOwned {
123 const VERSION: TirVersion;
124}
125
126pub fn to_bytes<T: TirRoot>(tx: &T) -> (Vec<u8>, TirVersion) {
127 let mut buffer = Vec::new();
128 ciborium::into_writer(tx, &mut buffer).unwrap(); (buffer, T::VERSION)
130}
131
132fn decode_root<T: TirRoot>(bytes: &[u8]) -> Result<T, Error> {
133 let root: T =
134 ciborium::from_reader(bytes).map_err(|e| Error::TirDeserializeError(e.to_string()))?;
135 Ok(root)
136}
137
138pub fn from_bytes(bytes: &[u8], version: TirVersion) -> Result<AnyTir, Error> {
139 match version {
140 TirVersion::V1Beta0 => {
141 let tx: crate::model::v1beta0::Tx = decode_root(bytes)?;
142 Ok(AnyTir::V1Beta0(tx))
143 }
144 x => Err(Error::DeprecatedTirVersion(x.to_string())),
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot names under `test_data/backwards/` (two directories above this crate's
    /// manifest) that the current decoder must still accept.
    const BACKWARDS_SUPPORTED_VERSIONS: &[&str] = &["v1alpha9"];

    /// Loads the hex snapshot `<version>.tir.hex` and asserts it decodes as `v1beta0`.
    ///
    /// The snapshot is deliberately decoded with `TirVersion::V1Beta0` rather than with
    /// the snapshot's own version. The check is that the current decoder still reads the
    /// older payload.
    fn decode_version_snapshot(version: &str) {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../../test_data/backwards")
            .join(format!("{version}.tir.hex"));

        let hex_text = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read snapshot {}: {e}", path.display()));

        // Tolerate surrounding whitespace such as a trailing newline in the fixture file;
        // `hex::decode` rejects any non-hex character.
        let bytes = hex::decode(hex_text.trim())
            .unwrap_or_else(|e| panic!("snapshot {} is not valid hex: {e}", path.display()));

        let tir = from_bytes(&bytes, TirVersion::V1Beta0)
            .unwrap_or_else(|e| panic!("snapshot {version} failed to decode as v1beta0: {e}"));

        assert_eq!(tir.version(), TirVersion::V1Beta0);
    }

    #[test]
    fn test_decoding_is_backward_compatible() {
        for version in BACKWARDS_SUPPORTED_VERSIONS {
            decode_version_snapshot(version);
        }
    }
}
175}