soroban_cli/commands/contract/
init.rs

1use std::borrow::Cow;
2use std::{
3    fs::{create_dir_all, metadata, write, Metadata},
4    io,
5    path::{Path, PathBuf},
6    str,
7};
8
9use clap::Parser;
10use rust_embed::RustEmbed;
11
12use crate::{commands::global, print};
13
14#[derive(Parser, Debug, Clone)]
15#[group(skip)]
16pub struct Cmd {
17    pub project_path: String,
18
19    #[arg(
20        long,
21        default_value = "hello-world",
22        long_help = "An optional flag to specify a new contract's name."
23    )]
24    pub name: String,
25
26    #[arg(long, long_help = "Overwrite all existing files.")]
27    pub overwrite: bool,
28}
29
30#[derive(thiserror::Error, Debug)]
31pub enum Error {
32    #[error("{0}: {1}")]
33    Io(String, io::Error),
34
35    #[error(transparent)]
36    Std(#[from] std::io::Error),
37
38    #[error("failed to convert bytes to string: {0}")]
39    ConvertBytesToString(#[from] str::Utf8Error),
40
41    #[error("contract package already exists: {0}")]
42    AlreadyExists(String),
43
44    #[error("provided project path exists and is not a directory")]
45    PathExistsNotDir,
46
47    #[error("provided project path exists and is not a cargo workspace root directory. Hint: run init on an empty or non-existing directory"
48    )]
49    PathExistsNotCargoProject,
50}
51
52impl Cmd {
53    #[allow(clippy::unused_self)]
54    pub fn run(&self, global_args: &global::Args) -> Result<(), Error> {
55        let runner = Runner {
56            args: self.clone(),
57            print: print::Print::new(global_args.quiet),
58        };
59
60        runner.run()
61    }
62}
63
64#[derive(RustEmbed)]
65#[folder = "src/utils/contract-workspace-template"]
66struct WorkspaceTemplateFiles;
67
68#[derive(RustEmbed)]
69#[folder = "src/utils/contract-template"]
70struct ContractTemplateFiles;
71
72struct Runner {
73    args: Cmd,
74    print: print::Print,
75}
76
77impl Runner {
78    fn run(&self) -> Result<(), Error> {
79        let project_path = PathBuf::from(&self.args.project_path);
80        self.print
81            .infoln(format!("Initializing workspace at {project_path:?}"));
82
83        // create a project dir, and copy the contents of the base template (contract-init-template) into it
84        Self::create_dir_all(&project_path)?;
85        self.copy_template_files(
86            project_path.as_path(),
87            &mut WorkspaceTemplateFiles::iter(),
88            WorkspaceTemplateFiles::get,
89        )?;
90
91        let contract_path = project_path.join("contracts").join(&self.args.name);
92        self.print
93            .infoln(format!("Initializing contract at {contract_path:?}"));
94
95        Self::create_dir_all(contract_path.as_path())?;
96        self.copy_template_files(
97            contract_path.as_path(),
98            &mut ContractTemplateFiles::iter(),
99            ContractTemplateFiles::get,
100        )?;
101
102        Ok(())
103    }
104
105    fn copy_template_files(
106        &self,
107        root_path: &Path,
108        files: &mut dyn Iterator<Item = Cow<str>>,
109        getter: fn(&str) -> Option<rust_embed::EmbeddedFile>,
110    ) -> Result<(), Error> {
111        for item in &mut *files {
112            let mut to = root_path.join(item.as_ref());
113            // We need to include the Cargo.toml file as Cargo.toml.removeextension in the template
114            // so that it will be included the package. This is making sure that the Cargo file is
115            // written as Cargo.toml in the new project. This is a workaround for this issue:
116            // https://github.com/rust-lang/cargo/issues/8597.
117            let item_path = Path::new(item.as_ref());
118            let is_toml = item_path.file_name().unwrap() == "Cargo.toml.removeextension";
119            if is_toml {
120                let item_parent_path = item_path.parent().unwrap();
121                to = root_path.join(item_parent_path).join("Cargo.toml");
122            }
123
124            let exists = Self::file_exists(&to);
125            if exists && !self.args.overwrite {
126                self.print
127                    .infoln(format!("Skipped creating {to:?} as it already exists"));
128                continue;
129            }
130
131            Self::create_dir_all(to.parent().unwrap())?;
132
133            let Some(file) = getter(item.as_ref()) else {
134                self.print
135                    .warnln(format!("Failed to read file: {}", item.as_ref()));
136                continue;
137            };
138
139            let mut file_contents = str::from_utf8(file.data.as_ref())
140                .map_err(Error::ConvertBytesToString)?
141                .to_string();
142
143            if is_toml {
144                let new_content = file_contents.replace("%contract-template%", &self.args.name);
145                file_contents = new_content;
146            }
147
148            if exists {
149                self.print
150                    .plusln(format!("Writing {to:?} (overwriting existing file)"));
151            } else {
152                self.print.plusln(format!("Writing {to:?}"));
153            }
154            Self::write(&to, &file_contents)?;
155        }
156
157        Ok(())
158    }
159
160    fn file_exists(file_path: &Path) -> bool {
161        metadata(file_path)
162            .as_ref()
163            .map(Metadata::is_file)
164            .unwrap_or(false)
165    }
166
167    fn create_dir_all(path: &Path) -> Result<(), Error> {
168        create_dir_all(path).map_err(|e| Error::Io(format!("creating directory: {path:?}"), e))
169    }
170
171    fn write(path: &Path, contents: &str) -> Result<(), Error> {
172        write(path, contents).map_err(|e| Error::Io(format!("writing file: {path:?}"), e))
173    }
174}
175
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;

    const TEST_PROJECT_NAME: &str = "test-project";

    #[test]
    fn test_init() {
        let temp_dir = tempfile::tempdir().unwrap();
        let project_dir = temp_dir.path().join(TEST_PROJECT_NAME);

        // First run creates the workspace plus the first contract.
        init_contract(&project_dir, "hello_world", false);
        assert_base_template_files_exist(&project_dir);
        assert_contract_files_exist(&project_dir, "hello_world");
        assert_contract_cargo_file_is_well_formed(&project_dir, "hello_world");
        assert_excluded_paths_do_not_exist(&project_dir);

        // Second run adds another contract to the same workspace.
        init_contract(&project_dir, "contract2", false);
        assert_contract_files_exist(&project_dir, "contract2");
        assert_contract_cargo_file_is_well_formed(&project_dir, "contract2");
        assert_excluded_paths_do_not_exist(&project_dir);

        temp_dir.close().unwrap();
    }

    #[test]
    fn test_init_respects_overwrite_flag() {
        // Assumes the template README differs from this marker text.
        const MARKER: &str = "locally modified";

        let temp_dir = tempfile::tempdir().unwrap();
        let project_dir = temp_dir.path().join(TEST_PROJECT_NAME);
        let readme = project_dir.join("README.md");

        init_contract(&project_dir, "hello_world", false);
        fs::write(&readme, MARKER).unwrap();

        // Without --overwrite an existing file must be left untouched.
        init_contract(&project_dir, "hello_world", false);
        assert_eq!(
            fs::read_to_string(&readme).unwrap(),
            MARKER,
            "existing file should be skipped when overwrite is off"
        );

        // With --overwrite the template content must replace it.
        init_contract(&project_dir, "hello_world", true);
        assert_ne!(
            fs::read_to_string(&readme).unwrap(),
            MARKER,
            "existing file should be rewritten when overwrite is on"
        );

        temp_dir.close().unwrap();
    }

    // test helpers

    /// Runs `init` for contract `name` inside `project_dir`.
    fn init_contract(project_dir: &Path, name: &str, overwrite: bool) {
        let runner = Runner {
            args: Cmd {
                project_path: project_dir.to_string_lossy().to_string(),
                name: name.to_string(),
                overwrite,
            },
            print: print::Print::new(false),
        };
        runner.run().unwrap();
    }

    fn assert_base_template_files_exist(project_dir: &Path) {
        let expected_paths = ["contracts", "Cargo.toml", "README.md"];
        for path in &expected_paths {
            let filepath = project_dir.join(path);
            assert!(filepath.exists(), "{filepath:?} should exist");
        }
    }

    fn assert_contract_files_exist(project_dir: &Path, contract_name: &str) {
        let contract_dir = project_dir.join("contracts").join(contract_name);

        assert!(contract_dir.exists());
        assert!(contract_dir.join("Cargo.toml").exists());
        assert!(contract_dir.join("src").join("lib.rs").exists());
        assert!(contract_dir.join("src").join("test.rs").exists());
    }

    fn assert_contract_cargo_file_is_well_formed(project_dir: &Path, contract_name: &str) {
        let cargo_toml_path = project_dir
            .join("contracts")
            .join(contract_name)
            .join("Cargo.toml");
        let cargo_toml_str = fs::read_to_string(cargo_toml_path).unwrap();
        let doc: toml_edit::DocumentMut = cargo_toml_str.parse().unwrap();
        assert!(
            doc.get("dependencies")
                .unwrap()
                .get("soroban-sdk")
                .unwrap()
                .get("workspace")
                .unwrap()
                .as_bool()
                .unwrap(),
            "expected [dependencies.soroban-sdk] to be a workspace dependency"
        );
        assert!(
            doc.get("dev-dependencies")
                .unwrap()
                .get("soroban-sdk")
                .unwrap()
                .get("workspace")
                .unwrap()
                .as_bool()
                .unwrap(),
            "expected [dev-dependencies.soroban-sdk] to be a workspace dependency"
        );
        assert_ne!(
            0,
            doc.get("dev-dependencies")
                .unwrap()
                .get("soroban-sdk")
                .unwrap()
                .get("features")
                .unwrap()
                .as_array()
                .unwrap()
                .len(),
            "expected [dev-dependencies.soroban-sdk] to have a features list"
        );
        assert!(
            doc.get("dev_dependencies").is_none(),
            "erroneous 'dev_dependencies' section"
        );
        assert_eq!(
            doc.get("lib")
                .unwrap()
                .get("crate-type")
                .unwrap()
                .as_array()
                .unwrap()
                .iter()
                .map(|v| v.as_str().unwrap())
                .collect::<Vec<_>>(),
            ["lib", "cdylib"],
            "expected [lib.crate-type] to be lib,cdylib"
        );
    }

    fn assert_excluded_paths_do_not_exist(project_dir: &Path) {
        let base_excluded_paths = [".git", ".github", "Makefile", ".vscode", "target"];
        for path in &base_excluded_paths {
            let filepath = project_dir.join(path);
            assert!(!filepath.exists(), "{filepath:?} should not exist");
        }

        // Every directory under `contracts/` is checked against every excluded name.
        let contract_excluded_paths = ["target", "Cargo.lock"];
        for entry in fs::read_dir(project_dir.join("contracts")).unwrap() {
            let contract_dir = entry.unwrap().path();
            for excluded_path in &contract_excluded_paths {
                let filepath = contract_dir.join(excluded_path);
                assert!(!filepath.exists(), "{filepath:?} should not exist");
            }
        }
    }
}