stellar_scaffold_cli/commands/generate/contract/
mod.rs

1use clap::Parser;
2use flate2::read::GzDecoder;
3use reqwest;
4use serde::Deserialize;
5use std::{fs, path::Path};
6use stellar_cli::commands::global;
7use stellar_cli::print::Print;
8use tar::Archive;
9
10#[derive(Deserialize)]
11struct Release {
12    tag_name: String,
13}
14
15#[derive(Parser, Debug)]
16pub struct Cmd {
17    /// Clone contract from `OpenZeppelin` examples
18    #[arg(long, conflicts_with_all = ["ls", "from_wizard"])]
19    pub from: Option<String>,
20
21    /// List available contract examples
22    #[arg(long, conflicts_with_all = ["from", "from_wizard"])]
23    pub ls: bool,
24
25    /// Open contract generation wizard in browser
26    #[arg(long, conflicts_with_all = ["from", "ls"])]
27    pub from_wizard: bool,
28
29    /// Output directory for the generated contract (defaults to contracts/<example-name>)
30    #[arg(short, long)]
31    pub output: Option<String>,
32}
33
34#[derive(thiserror::Error, Debug)]
35pub enum Error {
36    #[error(transparent)]
37    Io(#[from] std::io::Error),
38    #[error(transparent)]
39    Reqwest(#[from] reqwest::Error),
40    #[error(transparent)]
41    CargoToml(#[from] cargo_toml::Error),
42    #[error(transparent)]
43    TomlDeserialize(#[from] toml::de::Error),
44    #[error(transparent)]
45    TomlSerialize(#[from] toml::ser::Error),
46    #[error("Git command failed: {0}")]
47    GitCloneFailed(String),
48    #[error("Example '{0}' not found in OpenZeppelin stellar-contracts")]
49    ExampleNotFound(String),
50    #[error("Failed to open browser: {0}")]
51    BrowserFailed(String),
52    #[error("No action specified. Use --from, --ls, or --from-wizard")]
53    NoActionSpecified,
54}
55
56impl Cmd {
57    pub async fn run(&self, global_args: &global::Args) -> Result<(), Error> {
58        match (&self.from, self.ls, self.from_wizard) {
59            (Some(example_name), _, _) => self.clone_example(example_name, global_args).await,
60            (_, true, _) => self.list_examples(global_args).await,
61            (_, _, true) => open_wizard(global_args),
62            _ => Err(Error::NoActionSpecified),
63        }
64    }
65
66    async fn clone_example(
67        &self,
68        example_name: &str,
69        global_args: &global::Args,
70    ) -> Result<(), Error> {
71        let printer = Print::new(global_args.quiet);
72
73        printer.infoln(format!("Downloading example '{example_name}'..."));
74
75        let dest_path = self
76            .output
77            .clone()
78            .unwrap_or_else(|| format!("contracts/{example_name}"));
79
80        let repo_cache_path = self.ensure_cache_updated(global_args).await?;
81
82        // Check if the example exists
83        let example_source_path = repo_cache_path.join(format!("examples/{example_name}"));
84        if !example_source_path.exists() {
85            return Err(Error::ExampleNotFound(example_name.to_string()));
86        }
87
88        // Create destination and copy example contents
89        fs::create_dir_all(&dest_path)?;
90        Self::copy_directory_contents(&example_source_path, Path::new(&dest_path))?;
91
92        // Get the latest release tag we're using
93        let Release { tag_name } = Self::fetch_latest_release().await?;
94
95        // Read and update workspace Cargo.toml
96        let workspace_cargo_path = Path::new("Cargo.toml");
97        if workspace_cargo_path.exists() {
98            Self::update_workspace_dependencies(
99                workspace_cargo_path,
100                &example_source_path,
101                &tag_name,
102                global_args,
103            )?;
104        } else {
105            printer.warnln("Warning: No workspace Cargo.toml found in current directory.");
106            printer
107                .println("   You'll need to manually add required dependencies to your workspace.");
108        }
109
110        printer.checkln(format!(
111            "Successfully downloaded example '{example_name}' to {dest_path}"
112        ));
113        printer
114            .infoln("You may need to modify your environments.toml to add constructor arguments!");
115        Ok(())
116    }
117
118    fn update_workspace_dependencies(
119        workspace_path: &Path,
120        example_path: &Path,
121        tag: &str,
122        global_args: &global::Args,
123    ) -> Result<(), Error> {
124        let printer = Print::new(global_args.quiet);
125
126        let example_cargo_content = fs::read_to_string(example_path.join("Cargo.toml"))?;
127        let deps = Self::extract_stellar_dependencies(&example_cargo_content)?;
128        if deps.is_empty() {
129            return Ok(());
130        }
131
132        // Parse the workspace Cargo.toml
133        let mut manifest = cargo_toml::Manifest::from_path(workspace_path)?;
134
135        // Ensure workspace.dependencies exists
136        if manifest.workspace.is_none() {
137            // Create a minimal workspace with just what we need
138            let workspace_toml = r"
139[workspace]
140members = []
141
142[workspace.dependencies]
143";
144            let workspace: cargo_toml::Workspace<toml::Value> = toml::from_str(workspace_toml)?;
145            manifest.workspace = Some(workspace);
146        }
147        let workspace = manifest.workspace.as_mut().unwrap();
148
149        let mut workspace_deps = workspace.dependencies.clone();
150
151        let mut added_deps = Vec::new();
152        let mut updated_deps = Vec::new();
153
154        for dep in deps {
155            let git_dep = cargo_toml::DependencyDetail {
156                git: Some("https://github.com/OpenZeppelin/stellar-contracts".to_string()),
157                tag: Some(tag.to_string()),
158                ..Default::default()
159            };
160
161            if let Some(existing_dep) = workspace_deps.clone().get(&dep) {
162                // Check if we need to update the tag
163                if let cargo_toml::Dependency::Detailed(detail) = existing_dep {
164                    if let Some(existing_tag) = &detail.tag {
165                        if existing_tag != tag {
166                            workspace_deps.insert(
167                                dep.clone(),
168                                cargo_toml::Dependency::Detailed(Box::new(git_dep)),
169                            );
170                            updated_deps.push((dep, existing_tag.clone()));
171                        }
172                    }
173                }
174            } else {
175                workspace_deps.insert(
176                    dep.clone(),
177                    cargo_toml::Dependency::Detailed(Box::new(git_dep)),
178                );
179                added_deps.push(dep);
180            }
181        }
182
183        if !added_deps.is_empty() || !updated_deps.is_empty() {
184            workspace.dependencies = workspace_deps;
185            // Write the updated manifest back to file
186            let toml_string = toml::to_string_pretty(&manifest)?;
187            fs::write(workspace_path, toml_string)?;
188
189            if !added_deps.is_empty() {
190                printer.infoln("Added the following dependencies to workspace:");
191                for dep in added_deps {
192                    printer.println(format!("   • {dep}"));
193                }
194            }
195
196            if !updated_deps.is_empty() {
197                printer.infoln("Updated the following dependencies:");
198                for (dep, old_tag) in updated_deps {
199                    printer.println(format!("   • {dep}: {old_tag} -> {tag}"));
200                }
201            }
202        }
203
204        Ok(())
205    }
206
207    fn extract_stellar_dependencies(cargo_toml_content: &str) -> Result<Vec<String>, Error> {
208        let manifest: cargo_toml::Manifest = toml::from_str(cargo_toml_content)?;
209
210        Ok(manifest
211            .dependencies
212            .iter()
213            .filter(|(dep_name, _)| dep_name.starts_with("stellar-"))
214            .filter_map(|(dep_name, dep_detail)| match dep_detail {
215                cargo_toml::Dependency::Detailed(detail)
216                    if !(detail.inherited || detail.git.is_some()) =>
217                {
218                    None
219                }
220                _ => Some(dep_name.clone()),
221            })
222            .collect())
223    }
224
225    async fn list_examples(&self, global_args: &global::Args) -> Result<(), Error> {
226        let printer = Print::new(global_args.quiet);
227
228        printer.infoln("Fetching available contract examples...");
229
230        let repo_cache_path = self.ensure_cache_updated(global_args).await?;
231        let examples_path = repo_cache_path.join("examples");
232
233        let mut examples: Vec<String> = if examples_path.exists() {
234            fs::read_dir(examples_path)?
235                .filter_map(std::result::Result::ok)
236                .filter(|entry| entry.path().is_dir())
237                .filter_map(|entry| {
238                    entry
239                        .file_name()
240                        .to_str()
241                        .map(std::string::ToString::to_string)
242                })
243                .collect()
244        } else {
245            Vec::new()
246        };
247
248        examples.sort();
249
250        printer.println("\nAvailable contract examples:");
251        printer.println("────────────────────────────────");
252
253        for example in &examples {
254            printer.println(format!("  📁 {example}"));
255        }
256
257        printer.println("\nUsage:");
258        printer.println("   stellar-scaffold contract generate --from <example-name>");
259        printer.println("   Example: stellar-scaffold contract generate --from nft-royalties");
260
261        Ok(())
262    }
263
264    async fn fetch_latest_release() -> Result<Release, Error> {
265        Self::fetch_latest_release_from_url(
266            "https://api.github.com/repos/OpenZeppelin/stellar-contracts/releases/latest",
267        )
268        .await
269    }
270
271    async fn fetch_latest_release_from_url(url: &str) -> Result<Release, Error> {
272        let client = reqwest::Client::new();
273        let response = client
274            .get(url)
275            .header("User-Agent", "stellar-scaffold-cli")
276            .send()
277            .await?;
278
279        if !response.status().is_success() {
280            return Err(Error::Reqwest(response.error_for_status().unwrap_err()));
281        }
282
283        let release: Release = response.json().await?;
284        Ok(release)
285    }
286
287    async fn cache_repository(
288        repo_cache_path: &Path,
289        cache_ref_file: &Path,
290        tag_name: &str,
291    ) -> Result<(), Error> {
292        fs::create_dir_all(repo_cache_path)?;
293
294        // Download and extract the specific tag directly
295        Self::download_and_extract_tag(repo_cache_path, tag_name).await?;
296
297        if repo_cache_path.read_dir()?.next().is_none() {
298            return Err(Error::GitCloneFailed(format!(
299                "Failed to download repository release {tag_name} to cache"
300            )));
301        }
302
303        fs::write(cache_ref_file, tag_name)?;
304        Ok(())
305    }
306
307    async fn download_and_extract_tag(dest_path: &Path, tag_name: &str) -> Result<(), Error> {
308        let url =
309            format!("https://github.com/OpenZeppelin/stellar-contracts/archive/{tag_name}.tar.gz",);
310
311        // Download the tar.gz file
312        let client = reqwest::Client::new();
313        let response = client
314            .get(&url)
315            .header("User-Agent", "stellar-scaffold-cli")
316            .send()
317            .await?;
318
319        if !response.status().is_success() {
320            return Err(Error::GitCloneFailed(format!(
321                "Failed to download release {tag_name}: HTTP {}",
322                response.status()
323            )));
324        }
325
326        // Get the response bytes
327        let bytes = response.bytes().await?;
328
329        // Extract the tar.gz in a blocking task to avoid blocking the async runtime
330        let dest_path = dest_path.to_path_buf();
331        tokio::task::spawn_blocking(move || {
332            let tar = GzDecoder::new(std::io::Cursor::new(bytes));
333            let mut archive = Archive::new(tar);
334
335            for entry in archive.entries()? {
336                let mut entry = entry?;
337                let path = entry.path()?;
338
339                // Strip the root directory (stellar-contracts-{tag}/)
340                let stripped_path = path.components().skip(1).collect::<std::path::PathBuf>();
341
342                if stripped_path.as_os_str().is_empty() {
343                    continue;
344                }
345
346                let dest_file_path = dest_path.join(&stripped_path);
347
348                if entry.header().entry_type().is_dir() {
349                    std::fs::create_dir_all(&dest_file_path)?;
350                } else {
351                    if let Some(parent) = dest_file_path.parent() {
352                        std::fs::create_dir_all(parent)?;
353                    }
354                    entry.unpack(&dest_file_path)?;
355                }
356            }
357
358            Ok::<(), std::io::Error>(())
359        })
360        .await
361        .map_err(|e| {
362            Error::Io(std::io::Error::new(
363                std::io::ErrorKind::Other,
364                e.to_string(),
365            ))
366        })?
367        .map_err(Error::Io)?;
368
369        Ok(())
370    }
371
372    async fn ensure_cache_updated(
373        &self,
374        global_args: &global::Args,
375    ) -> Result<std::path::PathBuf, Error> {
376        let printer = Print::new(global_args.quiet);
377
378        let cache_dir = dirs::cache_dir().ok_or_else(|| {
379            Error::Io(std::io::Error::new(
380                std::io::ErrorKind::NotFound,
381                "Cache directory not found",
382            ))
383        })?;
384
385        let base_cache_path = cache_dir.join("stellar-scaffold-cli/openzeppelin-stellar-contracts");
386
387        // Get the latest release tag
388        let Release { tag_name } = Self::fetch_latest_release().await?;
389        let repo_cache_path = base_cache_path.join(&tag_name);
390        let cache_ref_file = repo_cache_path.join(".release_ref");
391
392        let should_update_cache = if repo_cache_path.exists() {
393            if let Ok(cached_tag) = fs::read_to_string(&cache_ref_file) {
394                if cached_tag.trim() == tag_name {
395                    printer.infoln(format!("Using cached repository (release {tag_name})..."));
396                    false
397                } else {
398                    printer.infoln(format!(
399                        "New release available ({tag_name}). Updating cache..."
400                    ));
401                    true
402                }
403            } else {
404                printer.infoln("Cache metadata missing. Updating...");
405                true
406            }
407        } else {
408            printer.infoln(format!(
409                "Cache not found. Downloading release {tag_name}..."
410            ));
411            true
412        };
413
414        if should_update_cache {
415            if repo_cache_path.exists() {
416                fs::remove_dir_all(&repo_cache_path)?;
417            }
418            Self::cache_repository(&repo_cache_path, &cache_ref_file, &tag_name).await?;
419        }
420
421        Ok(repo_cache_path)
422    }
423
424    fn copy_directory_contents(source: &Path, dest: &Path) -> Result<(), Error> {
425        let copy_options = fs_extra::dir::CopyOptions::new()
426            .overwrite(true)
427            .content_only(true);
428
429        fs_extra::dir::copy(source, dest, &copy_options)
430            .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
431
432        Ok(())
433    }
434}
435
436fn open_wizard(global_args: &global::Args) -> Result<(), Error> {
437    let printer = Print::new(global_args.quiet);
438
439    printer.infoln("Opening OpenZeppelin Contract Wizard...");
440
441    let url = "https://wizard.openzeppelin.com/stellar";
442
443    webbrowser::open(url)
444        .map_err(|e| Error::BrowserFailed(format!("Failed to open browser: {e}")))?;
445
446    printer.checkln("Opened Contract Wizard in your default browser");
447    printer.println("\nInstructions:");
448    printer.println("   1. Configure your contract in the wizard");
449    printer.println("   2. Click 'Download' to get your contract files");
450    printer.println("   3. Extract the downloaded ZIP file");
451    printer.println("   4. Move the contract folder to your contracts/ directory");
452    printer.println("   5. Add the contract to your workspace Cargo.toml if needed");
453    printer.println(
454        "   6. You may need to modify your environments.toml file to add constructor arguments",
455    );
456    printer.infoln(
457        "The wizard will generate a complete Soroban contract with your selected features!",
458    );
459
460    Ok(())
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::{mock, server_url};

    /// Builds a `Cmd` directly, bypassing clap argument parsing.
    fn create_test_cmd(from: Option<String>, ls: bool, from_wizard: bool) -> Cmd {
        Cmd {
            from,
            ls,
            from_wizard,
            output: None,
        }
    }

    // `--ls` downloads the real release from GitHub and writes it to the user's
    // cache directory. No mock can intercept that, so this only runs on demand
    // (`cargo test -- --ignored`).
    #[tokio::test]
    #[ignore = "requires network access to GitHub"]
    async fn test_ls_command() {
        let cmd = create_test_cmd(None, true, false);
        let global_args = global::Args::default();

        let result = cmd.run(&global_args).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_fetch_latest_release() {
        let _m = mock(
            "GET",
            "/repos/OpenZeppelin/stellar-contracts/releases/latest",
        )
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(
            r#"{
                "tag_name": "v1.2.3",
                "name": "Release v1.2.3",
                "published_at": "2024-01-15T10:30:00Z"
            }"#,
        )
        .create();

        let mock_url = format!(
            "{}/repos/OpenZeppelin/stellar-contracts/releases/latest",
            server_url()
        );
        let result = Cmd::fetch_latest_release_from_url(&mock_url).await;

        assert!(result.is_ok());
        let release = result.unwrap();
        assert_eq!(release.tag_name, "v1.2.3");
    }

    #[tokio::test]
    async fn test_fetch_latest_release_error() {
        let _m = mock(
            "GET",
            "/repos/OpenZeppelin/stellar-contracts/releases/latest",
        )
        .with_status(404)
        .with_header("content-type", "application/json")
        .with_body(r#"{"message": "Not Found"}"#)
        .create();

        let mock_url = format!(
            "{}/repos/OpenZeppelin/stellar-contracts/releases/latest",
            server_url()
        );
        let result = Cmd::fetch_latest_release_from_url(&mock_url).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_no_action_specified() {
        let cmd = create_test_cmd(None, false, false);
        let global_args = global::Args::default();
        let result = cmd.run(&global_args).await;
        assert!(matches!(result, Err(Error::NoActionSpecified)));
    }

    #[test]
    fn test_extract_stellar_dependencies_keeps_workspace_and_git_deps() {
        let cargo_toml = r#"
[package]
name = "example"
version = "0.1.0"
edition = "2021"

[dependencies]
soroban-sdk = { workspace = true }
stellar-tokens = { workspace = true }
stellar-macros = { git = "https://github.com/OpenZeppelin/stellar-contracts", tag = "v0.1.0" }
stellar-local = { path = "../local" }
"#;

        let mut deps = Cmd::extract_stellar_dependencies(cargo_toml).unwrap();
        deps.sort();

        // `soroban-sdk` lacks the `stellar-` prefix, and path-only deps are skipped.
        assert_eq!(
            deps,
            vec!["stellar-macros".to_string(), "stellar-tokens".to_string()]
        );
    }
}