stellar_scaffold_cli/commands/generate/contract/
mod.rs

1use clap::Parser;
2use flate2::read::GzDecoder;
3use reqwest;
4use serde::Deserialize;
5use std::{fs, path::Path};
6use stellar_cli::commands::global;
7use stellar_cli::print::Print;
8use tar::Archive;
9
10#[derive(Deserialize)]
11struct Release {
12    tag_name: String,
13}
14
// Command-line arguments for `contract generate`. clap's derive turns the `///`
// field docs into `--help` text, so extra explanation is kept in `//` comments.
// `--from`, `--ls` and `--from-wizard` are mutually exclusive (`conflicts_with_all`).
#[derive(Parser, Debug)]
pub struct Cmd {
    /// Clone contract from `OpenZeppelin` examples
    // The value is an example directory name under `examples/` of the cached
    // release (e.g. `nft-royalties`); `--ls` prints the available names.
    #[arg(long, conflicts_with_all = ["ls", "from_wizard"])]
    pub from: Option<String>,

    /// List available contract examples
    #[arg(long, conflicts_with_all = ["from", "from_wizard"])]
    pub ls: bool,

    /// Open contract generation wizard in browser
    #[arg(long, conflicts_with_all = ["from", "ls"])]
    pub from_wizard: bool,

    /// Output directory for the generated contract (defaults to contracts/<example-name>)
    // Only read when cloning an example (`--from`); the other actions ignore it.
    #[arg(short, long)]
    pub output: Option<String>,
}
33
34#[derive(thiserror::Error, Debug)]
35pub enum Error {
36    #[error(transparent)]
37    Io(#[from] std::io::Error),
38    #[error(transparent)]
39    Reqwest(#[from] reqwest::Error),
40    #[error(transparent)]
41    CargoToml(#[from] cargo_toml::Error),
42    #[error(transparent)]
43    TomlDeserialize(#[from] toml::de::Error),
44    #[error(transparent)]
45    TomlSerialize(#[from] toml::ser::Error),
46    #[error("Git command failed: {0}")]
47    GitCloneFailed(String),
48    #[error("Example '{0}' not found in OpenZeppelin stellar-contracts")]
49    ExampleNotFound(String),
50    #[error("Failed to open browser: {0}")]
51    BrowserFailed(String),
52    #[error("No action specified. Use --from, --ls, or --from-wizard")]
53    NoActionSpecified,
54}
55
56impl Cmd {
57    pub async fn run(&self, global_args: &global::Args) -> Result<(), Error> {
58        match (&self.from, self.ls, self.from_wizard) {
59            (Some(example_name), _, _) => self.clone_example(example_name, global_args).await,
60            (_, true, _) => self.list_examples(global_args).await,
61            (_, _, true) => open_wizard(global_args),
62            _ => Err(Error::NoActionSpecified),
63        }
64    }
65
66    async fn clone_example(
67        &self,
68        example_name: &str,
69        global_args: &global::Args,
70    ) -> Result<(), Error> {
71        let printer = Print::new(global_args.quiet);
72
73        printer.infoln(format!("Downloading example '{example_name}'..."));
74
75        let dest_path = self
76            .output
77            .clone()
78            .unwrap_or_else(|| format!("contracts/{example_name}"));
79
80        let repo_cache_path = self.ensure_cache_updated().await?;
81
82        // Check if the example exists
83        let example_source_path = repo_cache_path.join(format!("examples/{example_name}"));
84        if !example_source_path.exists() {
85            return Err(Error::ExampleNotFound(example_name.to_string()));
86        }
87
88        // Create destination and copy example contents
89        fs::create_dir_all(&dest_path)?;
90        Self::copy_directory_contents(&example_source_path, Path::new(&dest_path))?;
91
92        // Get the latest release tag we're using
93        let Release { tag_name } = Self::fetch_latest_release().await?;
94
95        // Read and update workspace Cargo.toml
96        let workspace_cargo_path = Path::new("Cargo.toml");
97        if workspace_cargo_path.exists() {
98            Self::update_workspace_dependencies(
99                workspace_cargo_path,
100                &example_source_path,
101                &tag_name,
102                global_args,
103            )?;
104        } else {
105            printer.warnln("Warning: No workspace Cargo.toml found in current directory.");
106            printer
107                .println("   You'll need to manually add required dependencies to your workspace.");
108        }
109
110        printer.checkln(format!(
111            "Successfully downloaded example '{example_name}' to {dest_path}"
112        ));
113        printer
114            .infoln("You may need to modify your environments.toml to add constructor arguments!");
115        Ok(())
116    }
117
118    fn update_workspace_dependencies(
119        workspace_path: &Path,
120        example_path: &Path,
121        tag: &str,
122        global_args: &global::Args,
123    ) -> Result<(), Error> {
124        let printer = Print::new(global_args.quiet);
125
126        let example_cargo_content = fs::read_to_string(example_path.join("Cargo.toml"))?;
127        let deps = Self::extract_stellar_dependencies(&example_cargo_content)?;
128        if deps.is_empty() {
129            return Ok(());
130        }
131
132        // Parse the workspace Cargo.toml
133        let mut manifest = cargo_toml::Manifest::from_path(workspace_path)?;
134
135        // Ensure workspace.dependencies exists
136        if manifest.workspace.is_none() {
137            // Create a minimal workspace with just what we need
138            let workspace_toml = r"
139[workspace]
140members = []
141
142[workspace.dependencies]
143";
144            let workspace: cargo_toml::Workspace<toml::Value> = toml::from_str(workspace_toml)?;
145            manifest.workspace = Some(workspace);
146        }
147        let workspace = manifest.workspace.as_mut().unwrap();
148
149        let mut workspace_deps = workspace.dependencies.clone();
150
151        let mut added_deps = Vec::new();
152        let mut updated_deps = Vec::new();
153
154        for dep in deps {
155            let git_dep = cargo_toml::DependencyDetail {
156                git: Some("https://github.com/OpenZeppelin/stellar-contracts".to_string()),
157                tag: Some(tag.to_string()),
158                ..Default::default()
159            };
160
161            if let Some(existing_dep) = workspace_deps.clone().get(&dep) {
162                // Check if we need to update the tag
163                if let cargo_toml::Dependency::Detailed(detail) = existing_dep {
164                    if let Some(existing_tag) = &detail.tag {
165                        if existing_tag != tag {
166                            workspace_deps.insert(
167                                dep.clone(),
168                                cargo_toml::Dependency::Detailed(Box::new(git_dep)),
169                            );
170                            updated_deps.push((dep, existing_tag.clone()));
171                        }
172                    }
173                }
174            } else {
175                workspace_deps.insert(
176                    dep.clone(),
177                    cargo_toml::Dependency::Detailed(Box::new(git_dep)),
178                );
179                added_deps.push(dep);
180            }
181        }
182
183        if !added_deps.is_empty() || !updated_deps.is_empty() {
184            workspace.dependencies = workspace_deps;
185            // Write the updated manifest back to file
186            let toml_string = toml::to_string_pretty(&manifest)?;
187            fs::write(workspace_path, toml_string)?;
188
189            if !added_deps.is_empty() {
190                printer.infoln("Added the following dependencies to workspace:");
191                for dep in added_deps {
192                    printer.println(format!("   • {dep}"));
193                }
194            }
195
196            if !updated_deps.is_empty() {
197                printer.infoln("Updated the following dependencies:");
198                for (dep, old_tag) in updated_deps {
199                    printer.println(format!("   • {dep}: {old_tag} -> {tag}"));
200                }
201            }
202        }
203
204        Ok(())
205    }
206
207    fn extract_stellar_dependencies(cargo_toml_content: &str) -> Result<Vec<String>, Error> {
208        let manifest: cargo_toml::Manifest = toml::from_str(cargo_toml_content)?;
209
210        Ok(manifest
211            .dependencies
212            .iter()
213            .filter(|(dep_name, _)| dep_name.starts_with("stellar-"))
214            .filter_map(|(dep_name, dep_detail)| match dep_detail {
215                cargo_toml::Dependency::Detailed(detail)
216                    if !(detail.inherited || detail.git.is_some()) =>
217                {
218                    None
219                }
220                _ => Some(dep_name.clone()),
221            })
222            .collect())
223    }
224
225    async fn list_examples(&self, global_args: &global::Args) -> Result<(), Error> {
226        let printer = Print::new(global_args.quiet);
227
228        printer.infoln("Fetching available contract examples...");
229
230        let repo_cache_path = self.ensure_cache_updated().await?;
231        let examples_path = repo_cache_path.join("examples");
232
233        let mut examples: Vec<String> = if examples_path.exists() {
234            fs::read_dir(examples_path)?
235                .filter_map(std::result::Result::ok)
236                .filter(|entry| entry.path().is_dir())
237                .filter_map(|entry| {
238                    entry
239                        .file_name()
240                        .to_str()
241                        .map(std::string::ToString::to_string)
242                })
243                .collect()
244        } else {
245            Vec::new()
246        };
247
248        examples.sort();
249
250        printer.println("\nAvailable contract examples:");
251        printer.println("────────────────────────────────");
252
253        for example in &examples {
254            printer.println(format!("  📁 {example}"));
255        }
256
257        printer.println("\nUsage:");
258        printer.println("   stellar-scaffold contract generate --from <example-name>");
259        printer.println("   Example: stellar-scaffold contract generate --from nft-royalties");
260
261        Ok(())
262    }
263
264    async fn fetch_latest_release() -> Result<Release, Error> {
265        Self::fetch_latest_release_from_url(
266            "https://api.github.com/repos/OpenZeppelin/stellar-contracts/releases/latest",
267        )
268        .await
269    }
270
271    async fn fetch_latest_release_from_url(url: &str) -> Result<Release, Error> {
272        let client = reqwest::Client::new();
273        let response = client
274            .get(url)
275            .header("User-Agent", "stellar-scaffold-cli")
276            .send()
277            .await?;
278
279        if !response.status().is_success() {
280            return Err(Error::Reqwest(response.error_for_status().unwrap_err()));
281        }
282
283        let release: Release = response.json().await?;
284        Ok(release)
285    }
286
287    async fn cache_repository(repo_cache_path: &Path, tag_name: &str) -> Result<(), Error> {
288        fs::create_dir_all(repo_cache_path)?;
289
290        // Download and extract the specific tag directly
291        Self::download_and_extract_tag(repo_cache_path, tag_name).await?;
292
293        if repo_cache_path.read_dir()?.next().is_none() {
294            return Err(Error::GitCloneFailed(format!(
295                "Failed to download repository release {tag_name} to cache"
296            )));
297        }
298
299        Ok(())
300    }
301
302    async fn download_and_extract_tag(dest_path: &Path, tag_name: &str) -> Result<(), Error> {
303        let url =
304            format!("https://github.com/OpenZeppelin/stellar-contracts/archive/{tag_name}.tar.gz",);
305
306        // Download the tar.gz file
307        let client = reqwest::Client::new();
308        let response = client
309            .get(&url)
310            .header("User-Agent", "stellar-scaffold-cli")
311            .send()
312            .await?;
313
314        if !response.status().is_success() {
315            return Err(Error::GitCloneFailed(format!(
316                "Failed to download release {tag_name}: HTTP {}",
317                response.status()
318            )));
319        }
320
321        // Get the response bytes
322        let bytes = response.bytes().await?;
323
324        // Extract the tar.gz in a blocking task to avoid blocking the async runtime
325        let dest_path = dest_path.to_path_buf();
326        tokio::task::spawn_blocking(move || {
327            let tar = GzDecoder::new(std::io::Cursor::new(bytes));
328            let mut archive = Archive::new(tar);
329
330            for entry in archive.entries()? {
331                let mut entry = entry?;
332                let path = entry.path()?;
333
334                // Strip the root directory (stellar-contracts-{tag}/)
335                let stripped_path = path.components().skip(1).collect::<std::path::PathBuf>();
336
337                if stripped_path.as_os_str().is_empty() {
338                    continue;
339                }
340
341                let dest_file_path = dest_path.join(&stripped_path);
342
343                if entry.header().entry_type().is_dir() {
344                    std::fs::create_dir_all(&dest_file_path)?;
345                } else {
346                    if let Some(parent) = dest_file_path.parent() {
347                        std::fs::create_dir_all(parent)?;
348                    }
349                    entry.unpack(&dest_file_path)?;
350                }
351            }
352
353            Ok::<(), std::io::Error>(())
354        })
355        .await
356        .map_err(|e| {
357            Error::Io(std::io::Error::new(
358                std::io::ErrorKind::Other,
359                e.to_string(),
360            ))
361        })?
362        .map_err(Error::Io)?;
363
364        Ok(())
365    }
366
367    async fn ensure_cache_updated(&self) -> Result<std::path::PathBuf, Error> {
368        let cache_dir = dirs::cache_dir().ok_or_else(|| {
369            Error::Io(std::io::Error::new(
370                std::io::ErrorKind::NotFound,
371                "Cache directory not found",
372            ))
373        })?;
374
375        let base_cache_path = cache_dir.join("stellar-scaffold-cli/openzeppelin-stellar-contracts");
376
377        // Get the latest release tag
378        let Release { tag_name } = Self::fetch_latest_release().await?;
379        let repo_cache_path = base_cache_path.join(&tag_name);
380        if !repo_cache_path.exists() {
381            Self::cache_repository(&repo_cache_path, &tag_name).await?;
382        }
383
384        Ok(repo_cache_path)
385    }
386
387    fn copy_directory_contents(source: &Path, dest: &Path) -> Result<(), Error> {
388        let copy_options = fs_extra::dir::CopyOptions::new()
389            .overwrite(true)
390            .content_only(true);
391
392        fs_extra::dir::copy(source, dest, &copy_options)
393            .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
394
395        Ok(())
396    }
397}
398
399fn open_wizard(global_args: &global::Args) -> Result<(), Error> {
400    let printer = Print::new(global_args.quiet);
401
402    printer.infoln("Opening OpenZeppelin Contract Wizard...");
403
404    let url = "https://wizard.openzeppelin.com/stellar";
405
406    webbrowser::open(url)
407        .map_err(|e| Error::BrowserFailed(format!("Failed to open browser: {e}")))?;
408
409    printer.checkln("Opened Contract Wizard in your default browser");
410    printer.println("\nInstructions:");
411    printer.println("   1. Configure your contract in the wizard");
412    printer.println("   2. Click 'Download' to get your contract files");
413    printer.println("   3. Extract the downloaded ZIP file");
414    printer.println("   4. Move the contract folder to your contracts/ directory");
415    printer.println("   5. Add the contract to your workspace Cargo.toml if needed");
416    printer.println(
417        "   6. You may need to modify your environments.toml file to add constructor arguments",
418    );
419    printer.infoln(
420        "The wizard will generate a complete Soroban contract with your selected features!",
421    );
422
423    Ok(())
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{mock, server_url};

    /// Path of GitHub's "latest release" endpoint, served here by the mockito server.
    const LATEST_RELEASE_PATH: &str = "/repos/OpenZeppelin/stellar-contracts/releases/latest";

    /// Builds a `Cmd` with the given action flags and no `--output`.
    fn create_test_cmd(from: Option<String>, ls: bool, from_wizard: bool) -> Cmd {
        Cmd {
            from,
            ls,
            from_wizard,
            output: None,
        }
    }

    /// End-to-end `--ls`. `list_examples` queries the real GitHub releases API and
    /// downloads a release archive, so a local mock cannot stand in for it. (The mock
    /// of the GitHub *contents* API that used to live here was never requested.)
    #[tokio::test]
    #[ignore = "requires network access to GitHub"]
    async fn test_ls_command() {
        let cmd = create_test_cmd(None, true, false);
        let global_args = global::Args::default();

        let result = cmd.run(&global_args).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_fetch_latest_release() {
        let _m = mock("GET", LATEST_RELEASE_PATH)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                r#"{
                "tag_name": "v1.2.3",
                "name": "Release v1.2.3",
                "published_at": "2024-01-15T10:30:00Z"
            }"#,
            )
            .create();

        let mock_url = format!("{}{LATEST_RELEASE_PATH}", server_url());
        let release = Cmd::fetch_latest_release_from_url(&mock_url)
            .await
            .expect("mocked release response should decode");

        assert_eq!(release.tag_name, "v1.2.3");
    }

    #[tokio::test]
    async fn test_fetch_latest_release_error() {
        let _m = mock("GET", LATEST_RELEASE_PATH)
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(r#"{"message": "Not Found"}"#)
            .create();

        let mock_url = format!("{}{LATEST_RELEASE_PATH}", server_url());
        let result = Cmd::fetch_latest_release_from_url(&mock_url).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_no_action_specified() {
        let cmd = create_test_cmd(None, false, false);
        let global_args = global::Args::default();
        let result = cmd.run(&global_args).await;
        assert!(matches!(result, Err(Error::NoActionSpecified)));
    }

    /// Only `stellar-*` crates are reported, and plain `path` dependencies are skipped.
    #[test]
    fn test_extract_stellar_dependencies_filters_by_name_and_source() {
        let example_manifest = r#"
[package]
name = "example"
version = "0.1.0"
edition = "2021"

[dependencies]
soroban-sdk = { workspace = true }
stellar-access = { git = "https://github.com/OpenZeppelin/stellar-contracts", tag = "v0.1.0" }
stellar-macros = { path = "../../packages/macros" }
stellar-tokens = { workspace = true }
"#;

        let deps = Cmd::extract_stellar_dependencies(example_manifest)
            .expect("example manifest should parse");

        assert_eq!(
            deps,
            vec!["stellar-access".to_string(), "stellar-tokens".to_string()]
        );
    }
}