1mod syn_ext;
2pub mod r#trait;
3pub mod types;
4
5use std::borrow::Cow;
6use std::{fs, io};
7
8use proc_macro2::TokenStream;
9use quote::quote;
10use sha2::{Digest, Sha256};
11use stellar_xdr::curr as stellar_xdr;
12use stellar_xdr::{ScSpecEntry, ScSpecTypeDef, ScSpecTypeUdt, ScSpecUdtUnionCaseV0};
13use syn::Error;
14
15use soroban_spec::read::{from_wasm, FromWasmError};
16
17use types::{
18 generate_enum_with_options, generate_error_enum_with_options, generate_event_with_options,
19 generate_struct_with_options, generate_union_with_options,
20};
21pub use types::{GenerateError, GenerateOptions};
22
23#[derive(thiserror::Error, Debug)]
29pub enum GenerateFromFileError {
30 #[error("reading file: {0}")]
31 Io(io::Error),
32 #[error("sha256 does not match, expected: {expected}")]
33 VerifySha256 { expected: String },
34 #[error("parsing contract spec: {0}")]
35 Parse(stellar_xdr::Error),
36 #[error("getting contract spec: {0}")]
37 GetSpec(FromWasmError),
38 #[error("generating code: {0}")]
39 Generate(GenerateError),
40}
41
42pub fn generate_from_file(
43 file: &str,
44 verify_sha256: Option<&str>,
45) -> Result<TokenStream, GenerateFromFileError> {
46 let wasm = fs::read(file).map_err(GenerateFromFileError::Io)?;
48
49 let code = generate_from_wasm(&wasm, file, verify_sha256)?;
51 Ok(code)
52}
53
54pub fn generate_from_wasm(
55 wasm: &[u8],
56 file: &str,
57 verify_sha256: Option<&str>,
58) -> Result<TokenStream, GenerateFromFileError> {
59 generate_from_wasm_with_options(wasm, file, verify_sha256, &GenerateOptions::default())
60}
61
62pub fn generate_from_wasm_with_options(
63 wasm: &[u8],
64 file: &str,
65 verify_sha256: Option<&str>,
66 opts: &GenerateOptions,
67) -> Result<TokenStream, GenerateFromFileError> {
68 let sha256 = Sha256::digest(wasm);
69 let sha256 = format!("{:x}", sha256);
70 if let Some(verify_sha256) = verify_sha256 {
71 if verify_sha256 != sha256 {
72 return Err(GenerateFromFileError::VerifySha256 { expected: sha256 });
73 }
74 }
75
76 let spec = from_wasm(wasm).map_err(GenerateFromFileError::GetSpec)?;
77 let code = generate_with_options(&spec, file, &sha256, opts)
78 .map_err(GenerateFromFileError::Generate)?;
79 Ok(code)
80}
81
82pub fn generate(
83 specs: &[ScSpecEntry],
84 file: &str,
85 sha256: &str,
86) -> Result<TokenStream, GenerateError> {
87 generate_with_options(specs, file, sha256, &GenerateOptions::default())
88}
89
90pub fn generate_with_options(
91 specs: &[ScSpecEntry],
92 file: &str,
93 sha256: &str,
94 opts: &GenerateOptions,
95) -> Result<TokenStream, GenerateError> {
96 let generated = generate_without_file_with_options(specs, opts)?;
97 Ok(quote! {
98 pub const WASM: &[u8] = soroban_sdk::contractfile!(file = #file, sha256 = #sha256);
99 #generated
100 })
101}
102
103pub fn generate_without_file(specs: &[ScSpecEntry]) -> Result<TokenStream, GenerateError> {
104 generate_without_file_with_options(specs, &GenerateOptions::default())
105}
106
107pub fn generate_without_file_with_options(
108 specs: &[ScSpecEntry],
109 opts: &GenerateOptions,
110) -> Result<TokenStream, GenerateError> {
111 let specs = apply_error_udt_override(specs);
112 let specs: &[ScSpecEntry] = &specs;
113
114 let mut spec_fns = Vec::new();
115 let mut spec_structs = Vec::new();
116 let mut spec_unions = Vec::new();
117 let mut spec_enums = Vec::new();
118 let mut spec_error_enums = Vec::new();
119 let mut spec_events = Vec::new();
120 for s in specs {
121 match s {
122 ScSpecEntry::FunctionV0(f) => spec_fns.push(f),
123 ScSpecEntry::UdtStructV0(s) => spec_structs.push(s),
124 ScSpecEntry::UdtUnionV0(u) => spec_unions.push(u),
125 ScSpecEntry::UdtEnumV0(e) => spec_enums.push(e),
126 ScSpecEntry::UdtErrorEnumV0(e) => spec_error_enums.push(e),
127 ScSpecEntry::EventV0(e) => spec_events.push(e),
128 }
129 }
130
131 let trait_name = "Contract";
132
133 let trait_ = r#trait::generate_trait(trait_name, &spec_fns)?;
134 let structs = spec_structs
135 .iter()
136 .map(|s| generate_struct_with_options(s, opts))
137 .collect::<Result<Vec<_>, _>>()?;
138 let unions = spec_unions
139 .iter()
140 .map(|s| generate_union_with_options(s, opts))
141 .collect::<Result<Vec<_>, _>>()?;
142 let enums = spec_enums
143 .iter()
144 .map(|s| generate_enum_with_options(s, opts))
145 .collect::<Result<Vec<_>, _>>()?;
146 let error_enums = spec_error_enums
147 .iter()
148 .map(|s| generate_error_enum_with_options(s, opts))
149 .collect::<Result<Vec<_>, _>>()?;
150 let events = spec_events
151 .iter()
152 .map(|s| generate_event_with_options(s, opts))
153 .collect::<Result<Vec<_>, _>>()?;
154
155 Ok(quote! {
156 #[soroban_sdk::contractargs(name = "Args")]
157 #[soroban_sdk::contractclient(name = "Client")]
158 #trait_
159
160 #(#structs)*
161 #(#unions)*
162 #(#enums)*
163 #(#error_enums)*
164 #(#events)*
165 })
166}
167
168fn apply_error_udt_override(specs: &[ScSpecEntry]) -> Cow<'_, [ScSpecEntry]> {
184 let has_error_udt = specs.iter().any(|e| {
185 matches!(
186 e,
187 ScSpecEntry::UdtErrorEnumV0(err) if err.name.to_utf8_string_lossy() == "Error"
188 )
189 });
190 if has_error_udt {
191 let mut v = specs.to_vec();
192 rewrite_error_to_udt(&mut v);
193 Cow::Owned(v)
194 } else {
195 Cow::Borrowed(specs)
196 }
197}
198
199fn rewrite_error_to_udt(entries: &mut [ScSpecEntry]) {
204 fn rewrite_ty(t: &mut ScSpecTypeDef) {
205 match t {
206 ScSpecTypeDef::Error => {
207 *t = ScSpecTypeDef::Udt(ScSpecTypeUdt {
208 name: "Error".try_into().unwrap(),
209 });
210 }
211 ScSpecTypeDef::Option(o) => rewrite_ty(&mut o.value_type),
212 ScSpecTypeDef::Result(r) => {
213 rewrite_ty(&mut r.ok_type);
214 rewrite_ty(&mut r.error_type);
215 }
216 ScSpecTypeDef::Vec(v) => rewrite_ty(&mut v.element_type),
217 ScSpecTypeDef::Map(m) => {
218 rewrite_ty(&mut m.key_type);
219 rewrite_ty(&mut m.value_type);
220 }
221 ScSpecTypeDef::Tuple(tu) => {
222 for vt in tu.value_types.iter_mut() {
223 rewrite_ty(vt);
224 }
225 }
226 _ => {}
227 }
228 }
229 for entry in entries.iter_mut() {
230 match entry {
231 ScSpecEntry::FunctionV0(f) => {
232 for input in f.inputs.iter_mut() {
233 rewrite_ty(&mut input.type_);
234 }
235 for output in f.outputs.iter_mut() {
236 rewrite_ty(output);
237 }
238 }
239 ScSpecEntry::UdtStructV0(s) => {
240 for field in s.fields.iter_mut() {
241 rewrite_ty(&mut field.type_);
242 }
243 }
244 ScSpecEntry::UdtUnionV0(u) => {
245 for case in u.cases.iter_mut() {
246 if let ScSpecUdtUnionCaseV0::TupleV0(t) = case {
247 for ty in t.type_.iter_mut() {
248 rewrite_ty(ty);
249 }
250 }
251 }
252 }
253 ScSpecEntry::UdtEnumV0(_) | ScSpecEntry::UdtErrorEnumV0(_) => {}
254 ScSpecEntry::EventV0(e) => {
255 for p in e.params.iter_mut() {
256 rewrite_ty(&mut p.type_);
257 }
258 }
259 }
260 }
261}
262
263pub trait ToFormattedString {
266 fn to_formatted_string(&self) -> Result<String, Error>;
270}
271
272impl ToFormattedString for TokenStream {
273 fn to_formatted_string(&self) -> Result<String, Error> {
274 let file = syn::parse2(self.clone())?;
275 Ok(prettyplease::unparse(&file))
276 }
277}
278
// Snapshot tests for the generator: each test renders bindings with
// `generate` + `to_formatted_string` and compares the exact text.
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::{generate, ToFormattedString};
    use soroban_spec::read::from_wasm;

    // Prebuilt fixture contract; the build must have produced this file
    // before these tests can compile.
    const EXAMPLE_WASM: &[u8] = include_bytes!("../../target/wasm32v1-none/release/test_udt.wasm");

    /// Renders bindings for the `test_udt` fixture (structs, tuple structs,
    /// unions, recursive types and an integer enum) and compares them with
    /// the expected pretty-printed source.
    #[test]
    fn example() {
        let entries = from_wasm(EXAMPLE_WASM).unwrap();
        let rust = generate(&entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap();
        assert_eq!(
            rust,
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn add(env: soroban_sdk::Env, a: UdtEnum, b: UdtEnum) -> i64;
    fn recursive(env: soroban_sdk::Env, a: UdtRecursive) -> Option<UdtRecursive>;
    fn recursive_enum(
        env: soroban_sdk::Env,
        a: RecursiveEnum,
        key: u32,
    ) -> Result<Option<RecursiveEnum>, soroban_sdk::Error>;
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtTuple(pub i64, pub soroban_sdk::Vec<i64>);
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtStruct {
    pub a: i64,
    pub b: i64,
    pub c: soroban_sdk::Vec<i64>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct UdtRecursive {
    pub a: soroban_sdk::Symbol,
    pub b: soroban_sdk::Vec<UdtRecursive>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct RecursiveToEnum {
    pub a: soroban_sdk::Symbol,
    pub b: soroban_sdk::Map<u32, RecursiveEnum>,
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum UdtEnum {
    UdtA,
    UdtB(UdtStruct),
    UdtC(UdtEnum2),
    UdtD(UdtTuple),
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum RecursiveEnum {
    NotRecursive,
    Recursive(RecursiveToEnum),
}
#[soroban_sdk::contracttype(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum UdtEnum2 {
    A = 10,
    B = 15,
}
"#,
        );
    }

    // Prebuilt fixture declaring `safe_add` (error enum named `Error`) and
    // `safe_add_two` (error enum named `MyError`).
    const ADD_U64_WASM: &[u8] =
        include_bytes!("../../target/wasm32v1-none/release/test_add_u64.wasm");

    /// The fixture declares an error enum named `Error`, so `safe_add`'s
    /// built-in error type is rendered as that `Error` rather than
    /// `soroban_sdk::Error`. `safe_add_two` already references its `MyError`
    /// UDT and is unaffected.
    #[test]
    fn test_add_u64_result_types() {
        let entries = from_wasm(ADD_U64_WASM).unwrap();
        let rust = generate(&entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap();
        assert_eq!(
            rust,
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn add(env: soroban_sdk::Env, a: u64, b: u64) -> u64;
    fn safe_add(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, Error>;
    fn safe_add_two(env: soroban_sdk::Env, a: u64, b: u64) -> Result<u64, MyError>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum MyError {
    Overflow = 1,
}
"#,
        );
    }

    /// Inspects the raw spec read from the fixture, before any rewriting.
    /// `safe_add` returns `Result<u64, _>` with the built-in
    /// `ScSpecTypeDef::Error`, which is why the override in
    /// `test_add_u64_result_types` is needed. `safe_add_two` returns
    /// `Result<u64, _>` with the `MyError` UDT.
    #[test]
    fn test_add_u64_spec_entries() {
        use super::ScSpecEntry;
        use stellar_xdr::curr::ScSpecTypeDef;

        let entries = from_wasm(ADD_U64_WASM).unwrap();

        let safe_add_fn = entries
            // Locate the `safe_add` function entry by name.
            .iter()
            .find_map(|e| match e {
                ScSpecEntry::FunctionV0(f) if f.name.to_utf8_string().unwrap() == "safe_add" => {
                    Some(f)
                }
                _ => None,
            })
            .expect("safe_add function not found");

        let output = safe_add_fn.outputs.to_option().expect("should have output");
        let ScSpecTypeDef::Result(r) = output else {
            panic!("output should be a Result type");
        };
        assert!(
            matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
            "ok_type should be U64"
        );
        assert!(
            matches!(r.error_type.as_ref(), ScSpecTypeDef::Error),
            "error_type should be the built-in Error in the wasm spec, got {:?}",
            r.error_type
        );

        let safe_add_two_fn = entries
            // Locate the `safe_add_two` function entry by name.
            .iter()
            .find_map(|e| match e {
                ScSpecEntry::FunctionV0(f)
                    if f.name.to_utf8_string().unwrap() == "safe_add_two" =>
                {
                    Some(f)
                }
                _ => None,
            })
            .expect("safe_add_two function not found");

        let output = safe_add_two_fn
            .outputs
            .to_option()
            .expect("should have output");
        let ScSpecTypeDef::Result(r) = output else {
            panic!("output should be a Result type");
        };
        assert!(
            matches!(r.ok_type.as_ref(), ScSpecTypeDef::U64),
            "ok_type should be U64"
        );
        let ScSpecTypeDef::Udt(u) = r.error_type.as_ref() else {
            panic!(
                "error_type should be a UDT for MyError, got {:?}",
                r.error_type
            );
        };
        assert_eq!(
            u.name.to_utf8_string().unwrap(),
            "MyError",
            "error_type should be MyError UDT"
        );
    }

    /// With no error enum named `Error` in the spec, the built-in error type
    /// is left alone and rendered as `soroban_sdk::Error`.
    #[test]
    fn test_missing_error_udt_falls_back_to_sdk_error() {
        use super::ScSpecEntry;
        use stellar_xdr::curr::{ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeResult};

        let func = ScSpecFunctionV0 {
            doc: "".try_into().unwrap(),
            name: "safe_add".try_into().unwrap(),
            inputs: [].try_into().unwrap(),
            outputs: [ScSpecTypeDef::Result(Box::new(ScSpecTypeResult {
                ok_type: Box::new(ScSpecTypeDef::U64),
                error_type: Box::new(ScSpecTypeDef::Error),
            }))]
            .try_into()
            .unwrap(),
        };
        let entries = [ScSpecEntry::FunctionV0(func)];
        let rust = generate(&entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap();
        assert_eq!(
            rust,
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn safe_add(env: soroban_sdk::Env) -> Result<u64, soroban_sdk::Error>;
}
"#,
        );
    }

    /// With an error enum named `Error` in the spec, the built-in error type
    /// in a function's `Result` is rendered as that enum.
    #[test]
    fn test_error_udt_overrides_sdk_error() {
        use super::ScSpecEntry;
        use stellar_xdr::curr::{
            ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeResult, ScSpecUdtErrorEnumCaseV0,
            ScSpecUdtErrorEnumV0,
        };

        let func = ScSpecFunctionV0 {
            doc: "".try_into().unwrap(),
            name: "safe_add".try_into().unwrap(),
            inputs: [].try_into().unwrap(),
            outputs: [ScSpecTypeDef::Result(Box::new(ScSpecTypeResult {
                ok_type: Box::new(ScSpecTypeDef::U64),
                error_type: Box::new(ScSpecTypeDef::Error),
            }))]
            .try_into()
            .unwrap(),
        };
        let error_enum = ScSpecUdtErrorEnumV0 {
            doc: "".try_into().unwrap(),
            lib: "".try_into().unwrap(),
            name: "Error".try_into().unwrap(),
            cases: [ScSpecUdtErrorEnumCaseV0 {
                doc: "".try_into().unwrap(),
                name: "Overflow".try_into().unwrap(),
                value: 1,
            }]
            .try_into()
            .unwrap(),
        };
        let entries = [
            ScSpecEntry::FunctionV0(func),
            ScSpecEntry::UdtErrorEnumV0(error_enum),
        ];
        let rust = generate(&entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap();
        assert_eq!(
            rust,
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn safe_add(env: soroban_sdk::Env) -> Result<u64, Error>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
"#,
        );
    }

    /// The override also reaches a built-in `Error` nested inside a container
    /// type (here the element type of a `Vec` return value).
    #[test]
    fn test_error_udt_override_rewrites_nested_vec() {
        use super::ScSpecEntry;
        use stellar_xdr::curr::{
            ScSpecFunctionV0, ScSpecTypeDef, ScSpecTypeVec, ScSpecUdtErrorEnumCaseV0,
            ScSpecUdtErrorEnumV0,
        };

        let func = ScSpecFunctionV0 {
            doc: "".try_into().unwrap(),
            name: "errors".try_into().unwrap(),
            inputs: [].try_into().unwrap(),
            outputs: [ScSpecTypeDef::Vec(Box::new(ScSpecTypeVec {
                element_type: Box::new(ScSpecTypeDef::Error),
            }))]
            .try_into()
            .unwrap(),
        };
        let error_enum = ScSpecUdtErrorEnumV0 {
            doc: "".try_into().unwrap(),
            lib: "".try_into().unwrap(),
            name: "Error".try_into().unwrap(),
            cases: [ScSpecUdtErrorEnumCaseV0 {
                doc: "".try_into().unwrap(),
                name: "Overflow".try_into().unwrap(),
                value: 1,
            }]
            .try_into()
            .unwrap(),
        };
        let entries = [
            ScSpecEntry::FunctionV0(func),
            ScSpecEntry::UdtErrorEnumV0(error_enum),
        ];
        let rust = generate(&entries, "<file>", "<sha256>")
            .unwrap()
            .to_formatted_string()
            .unwrap();
        assert_eq!(
            rust,
            r#"pub const WASM: &[u8] = soroban_sdk::contractfile!(file = "<file>", sha256 = "<sha256>");
#[soroban_sdk::contractargs(name = "Args")]
#[soroban_sdk::contractclient(name = "Client")]
pub trait Contract {
    fn errors(env: soroban_sdk::Env) -> soroban_sdk::Vec<Error>;
}
#[soroban_sdk::contracterror(export = false)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum Error {
    Overflow = 1,
}
"#,
        );
    }
}