Skip to main content

soroban_spec_rust/
lib.rs

1mod syn_ext;
2pub mod r#trait;
3pub mod types;
4
5use std::borrow::Cow;
6use std::{fs, io};
7
8use proc_macro2::TokenStream;
9use quote::quote;
10use sha2::{Digest, Sha256};
11use stellar_xdr::curr as stellar_xdr;
12use stellar_xdr::{ScSpecEntry, ScSpecTypeDef, ScSpecTypeUdt, ScSpecUdtUnionCaseV0};
13use syn::Error;
14
15use soroban_spec::read::{from_wasm, FromWasmError};
16
17use types::{
18    generate_enum_with_options, generate_error_enum_with_options, generate_event_with_options,
19    generate_struct_with_options, generate_union_with_options,
20};
21pub use types::{GenerateError, GenerateOptions};
22
// IMPORTANT: The "docs" fields of spec entries are not output in Rust token
// streams as rustdocs, because rustdocs can contain Rust code (doctests), and
// that code will be executed. Code may be generated from untrusted Wasm
// containing untrusted spec docs.
27
28#[derive(thiserror::Error, Debug)]
29pub enum GenerateFromFileError {
30    #[error("reading file: {0}")]
31    Io(io::Error),
32    #[error("sha256 does not match, expected: {expected}")]
33    VerifySha256 { expected: String },
34    #[error("parsing contract spec: {0}")]
35    Parse(stellar_xdr::Error),
36    #[error("getting contract spec: {0}")]
37    GetSpec(FromWasmError),
38    #[error("generating code: {0}")]
39    Generate(GenerateError),
40}
41
42pub fn generate_from_file(
43    file: &str,
44    verify_sha256: Option<&str>,
45) -> Result<TokenStream, GenerateFromFileError> {
46    // Read file.
47    let wasm = fs::read(file).map_err(GenerateFromFileError::Io)?;
48
49    // Generate code.
50    let code = generate_from_wasm(&wasm, file, verify_sha256)?;
51    Ok(code)
52}
53
54pub fn generate_from_wasm(
55    wasm: &[u8],
56    file: &str,
57    verify_sha256: Option<&str>,
58) -> Result<TokenStream, GenerateFromFileError> {
59    generate_from_wasm_with_options(wasm, file, verify_sha256, &GenerateOptions::default())
60}
61
62pub fn generate_from_wasm_with_options(
63    wasm: &[u8],
64    file: &str,
65    verify_sha256: Option<&str>,
66    opts: &GenerateOptions,
67) -> Result<TokenStream, GenerateFromFileError> {
68    let sha256 = Sha256::digest(wasm);
69    let sha256 = format!("{:x}", sha256);
70    if let Some(verify_sha256) = verify_sha256 {
71        if verify_sha256 != sha256 {
72            return Err(GenerateFromFileError::VerifySha256 { expected: sha256 });
73        }
74    }
75
76    let spec = from_wasm(wasm).map_err(GenerateFromFileError::GetSpec)?;
77    let code = generate_with_options(&spec, file, &sha256, opts)
78        .map_err(GenerateFromFileError::Generate)?;
79    Ok(code)
80}
81
82pub fn generate(
83    specs: &[ScSpecEntry],
84    file: &str,
85    sha256: &str,
86) -> Result<TokenStream, GenerateError> {
87    generate_with_options(specs, file, sha256, &GenerateOptions::default())
88}
89
90pub fn generate_with_options(
91    specs: &[ScSpecEntry],
92    file: &str,
93    sha256: &str,
94    opts: &GenerateOptions,
95) -> Result<TokenStream, GenerateError> {
96    let generated = generate_without_file_with_options(specs, opts)?;
97    Ok(quote! {
98        pub const WASM: &[u8] = soroban_sdk::contractfile!(file = #file, sha256 = #sha256);
99        #generated
100    })
101}
102
103pub fn generate_without_file(specs: &[ScSpecEntry]) -> Result<TokenStream, GenerateError> {
104    generate_without_file_with_options(specs, &GenerateOptions::default())
105}
106
107pub fn generate_without_file_with_options(
108    specs: &[ScSpecEntry],
109    opts: &GenerateOptions,
110) -> Result<TokenStream, GenerateError> {
111    let specs = apply_error_udt_override(specs);
112    let specs: &[ScSpecEntry] = &specs;
113
114    let mut spec_fns = Vec::new();
115    let mut spec_structs = Vec::new();
116    let mut spec_unions = Vec::new();
117    let mut spec_enums = Vec::new();
118    let mut spec_error_enums = Vec::new();
119    let mut spec_events = Vec::new();
120    for s in specs {
121        match s {
122            ScSpecEntry::FunctionV0(f) => spec_fns.push(f),
123            ScSpecEntry::UdtStructV0(s) => spec_structs.push(s),
124            ScSpecEntry::UdtUnionV0(u) => spec_unions.push(u),
125            ScSpecEntry::UdtEnumV0(e) => spec_enums.push(e),
126            ScSpecEntry::UdtErrorEnumV0(e) => spec_error_enums.push(e),
127            ScSpecEntry::EventV0(e) => spec_events.push(e),
128        }
129    }
130
131    let trait_name = "Contract";
132
133    let trait_ = r#trait::generate_trait(trait_name, &spec_fns)?;
134    let structs = spec_structs
135        .iter()
136        .map(|s| generate_struct_with_options(s, opts))
137        .collect::<Result<Vec<_>, _>>()?;
138    let unions = spec_unions
139        .iter()
140        .map(|s| generate_union_with_options(s, opts))
141        .collect::<Result<Vec<_>, _>>()?;
142    let enums = spec_enums
143        .iter()
144        .map(|s| generate_enum_with_options(s, opts))
145        .collect::<Result<Vec<_>, _>>()?;
146    let error_enums = spec_error_enums
147        .iter()
148        .map(|s| generate_error_enum_with_options(s, opts))
149        .collect::<Result<Vec<_>, _>>()?;
150    let events = spec_events
151        .iter()
152        .map(|s| generate_event_with_options(s, opts))
153        .collect::<Result<Vec<_>, _>>()?;
154
155    Ok(quote! {
156        #[soroban_sdk::contractargs(name = "Args")]
157        #[soroban_sdk::contractclient(name = "Client")]
158        #trait_
159
160        #(#structs)*
161        #(#unions)*
162        #(#enums)*
163        #(#error_enums)*
164        #(#events)*
165    })
166}
167
168/// The `#[contractimpl]` macro emits any type named `Error` in a contract's
169/// function signatures as the built-in `ScSpecTypeDef::Error` in the spec,
170/// regardless of whether the contract defined its own error enum named `Error`
171/// or used `soroban_sdk::Error` directly. To let clients of contracts that
172/// define their own `Error` enum see the user-defined type instead of
173/// `soroban_sdk::Error`, this pass rewrites every `ScSpecTypeDef::Error`
174/// reference in the spec to `Udt { name: "Error" }` whenever the spec also
175/// contains a `UdtErrorEnumV0` named `Error`.
176///
177/// This keeps the on-the-wire spec format unchanged (so already-deployed
178/// contracts benefit without redeployment) and shifts the resolution to the
179/// client generator.
180///
181/// Returns a borrowed slice when no rewrite is needed, otherwise a
182/// freshly-owned `Vec` with the rewrite applied.
183fn apply_error_udt_override(specs: &[ScSpecEntry]) -> Cow<'_, [ScSpecEntry]> {
184    let has_error_udt = specs.iter().any(|e| {
185        matches!(
186            e,
187            ScSpecEntry::UdtErrorEnumV0(err) if err.name.to_utf8_string_lossy() == "Error"
188        )
189    });
190    if has_error_udt {
191        let mut v = specs.to_vec();
192        rewrite_error_to_udt(&mut v);
193        Cow::Owned(v)
194    } else {
195        Cow::Borrowed(specs)
196    }
197}
198
199/// Rewrites every `ScSpecTypeDef::Error` reference in the given entries to
200/// `ScSpecTypeDef::Udt { name: "Error" }`. Called only when the spec contains
201/// a user-defined error enum named `Error`, so the UDT reference resolves to
202/// that enum during code generation.
203fn rewrite_error_to_udt(entries: &mut [ScSpecEntry]) {
204    fn rewrite_ty(t: &mut ScSpecTypeDef) {
205        match t {
206            ScSpecTypeDef::Error => {
207                *t = ScSpecTypeDef::Udt(ScSpecTypeUdt {
208                    name: "Error".try_into().unwrap(),
209                });
210            }
211            ScSpecTypeDef::Option(o) => rewrite_ty(&mut o.value_type),
212            ScSpecTypeDef::Result(r) => {
213                rewrite_ty(&mut r.ok_type);
214                rewrite_ty(&mut r.error_type);
215            }
216            ScSpecTypeDef::Vec(v) => rewrite_ty(&mut v.element_type),
217            ScSpecTypeDef::Map(m) => {
218                rewrite_ty(&mut m.key_type);
219                rewrite_ty(&mut m.value_type);
220            }
221            ScSpecTypeDef::Tuple(tu) => {
222                for vt in tu.value_types.iter_mut() {
223                    rewrite_ty(vt);
224                }
225            }
226            _ => {}
227        }
228    }
229    for entry in entries.iter_mut() {
230        match entry {
231            ScSpecEntry::FunctionV0(f) => {
232                for input in f.inputs.iter_mut() {
233                    rewrite_ty(&mut input.type_);
234                }
235                for output in f.outputs.iter_mut() {
236                    rewrite_ty(output);
237                }
238            }
239            ScSpecEntry::UdtStructV0(s) => {
240                for field in s.fields.iter_mut() {
241                    rewrite_ty(&mut field.type_);
242                }
243            }
244            ScSpecEntry::UdtUnionV0(u) => {
245                for case in u.cases.iter_mut() {
246                    if let ScSpecUdtUnionCaseV0::TupleV0(t) = case {
247                        for ty in t.type_.iter_mut() {
248                            rewrite_ty(ty);
249                        }
250                    }
251                }
252            }
253            ScSpecEntry::UdtEnumV0(_) | ScSpecEntry::UdtErrorEnumV0(_) => {}
254            ScSpecEntry::EventV0(e) => {
255                for p in e.params.iter_mut() {
256                    rewrite_ty(&mut p.type_);
257                }
258            }
259        }
260    }
261}
262
263/// Implemented by types that can be converted into pretty formatted Strings of
264/// Rust code.
265pub trait ToFormattedString {
266    /// Converts the value to a String that is pretty formatted. If there is any
267    /// error parsing the token stream the raw String version of the code is
268    /// returned instead.
269    fn to_formatted_string(&self) -> Result<String, Error>;
270}
271
272impl ToFormattedString for TokenStream {
273    fn to_formatted_string(&self) -> Result<String, Error> {
274        let file = syn::parse2(self.clone())?;
275        Ok(prettyplease::unparse(&file))
276    }
277}
278
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::{generate, ScSpecEntry, ToFormattedString};
    use soroban_spec::read::from_wasm;
    use stellar_xdr::curr::{
        ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeResult, ScSpecTypeVec,
        ScSpecUdtErrorEnumCaseV0, ScSpecUdtErrorEnumV0,
    };

    const EXAMPLE_WASM: &[u8] = include_bytes!("../../target/wasm32v1-none/release/test_udt.wasm");

    const ADD_U64_WASM: &[u8] =
        include_bytes!("../../target/wasm32v1-none/release/test_add_u64.wasm");

    /// Generates code for `entries` using placeholder file and sha256 values
    /// and returns it pretty formatted.
    fn render(entries: &[ScSpecEntry]) -> String {
        generate(entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap()
    }

    /// Looks up the function spec entry called `name`.
    fn find_fn<'a>(entries: &'a [ScSpecEntry], name: &str) -> Option<&'a ScSpecFunctionV0> {
        entries.iter().find_map(|entry| match entry {
            ScSpecEntry::FunctionV0(f) if f.name.to_utf8_string().unwrap() == name => Some(f),
            _ => None,
        })
    }

    /// Returns the `Result` type that `f` outputs, panicking if `f` has no
    /// output or the output is not a `Result`.
    fn result_output(f: &ScSpecFunctionV0) -> Box<ScSpecTypeResult> {
        match f.outputs.to_option().expect("should have output") {
            ScSpecTypeDef::Result(r) => r,
            _ => panic!("output should be a Result type"),
        }
    }

    /// Builds a function spec entry that takes no inputs and returns `output`.
    fn nullary_fn(name: &str, output: ScSpecTypeDef) -> ScSpecEntry {
        ScSpecEntry::FunctionV0(ScSpecFunctionV0 {
            doc: "".try_into().unwrap(),
            name: name.try_into().unwrap(),
            inputs: [].try_into().unwrap(),
            outputs: [output].try_into().unwrap(),
        })
    }

    /// Builds the spec type `Result<u64, Error>` using the built-in `Error`.
    fn u64_result_with_builtin_error() -> ScSpecTypeDef {
        ScSpecTypeDef::Result(Box::new(ScSpecTypeResult {
            ok_type: Box::new(ScSpecTypeDef::U64),
            error_type: Box::new(ScSpecTypeDef::Error),
        }))
    }

    /// Builds an error enum spec entry named `Error` with one case,
    /// `Overflow = 1`.
    fn error_enum_named_error() -> ScSpecEntry {
        ScSpecEntry::UdtErrorEnumV0(ScSpecUdtErrorEnumV0 {
            doc: "".try_into().unwrap(),
            lib: "".try_into().unwrap(),
            name: "Error".try_into().unwrap(),
            cases: [ScSpecUdtErrorEnumCaseV0 {
                doc: "".try_into().unwrap(),
                name: "Overflow".try_into().unwrap(),
                value: 1,
            }]
            .try_into()
            .unwrap(),
        })
    }

    /// Generated code for the `test_udt` contract matches the expected output.
    #[test]
    fn example() {
        let entries = from_wasm(EXAMPLE_WASM).unwrap();
        assert_eq!(
            render(&entries),
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn add(env: soroban_sdk::Env, a: UdtEnum, b: UdtEnum) -> i64;
    fn recursive(env: soroban_sdk::Env, a: UdtRecursive) -> Option<UdtRecursive>;
    fn recursive_enum(
        env: soroban_sdk::Env,
        a: RecursiveEnum,
        key: u32,
    ) -> Result<Option<RecursiveEnum>, soroban_sdk::Error>;
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtTuple(pub i64, pub soroban_sdk::Vec<i64>);
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtStruct {
    pub a: i64,
    pub b: i64,
    pub c: soroban_sdk::Vec<i64>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtRecursive {
    pub a: soroban_sdk::Symbol,
    pub b: soroban_sdk::Vec<UdtRecursive>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct RecursiveToEnum {
    pub a: soroban_sdk::Symbol,
    pub b: soroban_sdk::Map<u32, RecursiveEnum>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum UdtEnum {
    UdtA,
    UdtB(UdtStruct),
    UdtC(UdtEnum2),
    UdtD(UdtTuple),
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum RecursiveEnum {
    NotRecursive,
    Recursive(RecursiveToEnum),
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum UdtEnum2 {
    A = 10,
    B = 15,
}
"#,
        );
    }

    /// Result types with user-defined error types are generated correctly:
    /// - An error enum named `Error` yields `Result<u64, Error>` (not
    ///   `Result<u64, soroban_sdk::Error>`).
    /// - An error enum named `MyError` yields `Result<u64, MyError>`.
    #[test]
    fn test_add_u64_result_types() {
        let entries = from_wasm(ADD_U64_WASM).unwrap();
        assert_eq!(
            render(&entries),
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn add(env: soroban_sdk::Env, a: u64, b: u64) -> u64;
    fn safe_add(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, Error>;
    fn safe_add_two(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, MyError>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum MyError {
    Overflow = 1,
}
"#,
        );
    }

    /// Inspects the raw spec entries read from the wasm to confirm the
    /// on-the-wire format is unchanged: an error enum named `Error` still
    /// appears as the built-in `ScSpecTypeDef::Error` in function signatures
    /// (the user-defined-vs-SDK disambiguation happens at client generation
    /// time, not here), while a differently named error enum (`MyError`)
    /// appears as a UDT reference.
    #[test]
    fn test_add_u64_spec_entries() {
        let entries = from_wasm(ADD_U64_WASM).unwrap();

        // `safe_add` returns Result<u64, built-in Error>.
        let safe_add = find_fn(&entries, "safe_add").expect("safe_add function not found");
        let r = result_output(safe_add);
        assert!(
            matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
            "ok_type should be U64"
        );
        assert!(
            matches!(r.error_type.as_ref(), ScSpecTypeDef::Error),
            "error_type should be the built-in Error in the wasm spec, got {:?}",
            r.error_type
        );

        // `safe_add_two` returns Result<u64, MyError UDT>.
        let safe_add_two =
            find_fn(&entries, "safe_add_two").expect("safe_add_two function not found");
        let r = result_output(safe_add_two);
        assert!(
            matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
            "ok_type should be U64"
        );
        let ScSpecTypeDef::Udt(u) = r.error_type.as_ref() else {
            panic!(
                "error_type should be a UDT for MyError, got {:?}",
                r.error_type
            );
        };
        assert_eq!(
            u.name.to_utf8_string().unwrap(),
            "MyError",
            "error_type should be MyError UDT"
        );
    }

    /// When the spec references `ScSpecTypeDef::Error` but contains no error
    /// enum named `Error`, the generator must leave it as
    /// `soroban_sdk::Error`. This covers contracts that use
    /// `soroban_sdk::Error` directly as their Result error type, including
    /// every contract compiled before the error-enum override was introduced.
    #[test]
    fn test_missing_error_udt_falls_back_to_sdk_error() {
        let entries = [nullary_fn("safe_add", u64_result_with_builtin_error())];
        assert_eq!(
            render(&entries),
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn safe_add(env: soroban_sdk::Env) -> Result<u64, soroban_sdk::Error>;
}
"#,
        );
    }

    /// When the spec contains a user-defined `Error` error enum, every
    /// `ScSpecTypeDef::Error` reference must be rewritten to reference that
    /// UDT instead of `soroban_sdk::Error`.
    #[test]
    fn test_error_udt_overrides_sdk_error() {
        let entries = [
            nullary_fn("safe_add", u64_result_with_builtin_error()),
            error_enum_named_error(),
        ];
        assert_eq!(
            render(&entries),
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn safe_add(env: soroban_sdk::Env) -> Result<u64, Error>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
"#,
        );
    }

    /// When the `Error` override applies, `ScSpecTypeDef::Error` references
    /// nested inside other types must be rewritten too.
    #[test]
    fn test_error_udt_override_rewrites_nested_vec() {
        let vec_of_errors = ScSpecTypeDef::Vec(Box::new(ScSpecTypeVec {
            element_type: Box::new(ScSpecTypeDef::Error),
        }));
        let entries = [nullary_fn("errors", vec_of_errors), error_enum_named_error()];
        assert_eq!(
            render(&entries),
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn errors(env: soroban_sdk::Env) -> soroban_sdk::Vec<Error>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
"#,
        );
    }
}