Skip to main content

torsh_cli/commands/
init.rs

1//! Project initialization commands
2
3use anyhow::Result;
4use clap::Args;
5use std::path::{Path, PathBuf};
6
7use crate::config::Config;
8use crate::utils::output;
9
10#[derive(Debug, Args)]
11pub struct InitCommand {
12    /// Project name
13    #[arg(short, long)]
14    pub name: Option<String>,
15
16    /// Project directory
17    #[arg(short, long)]
18    pub directory: Option<PathBuf>,
19
20    /// Project template (basic, vision, nlp, custom)
21    #[arg(short, long, default_value = "basic")]
22    pub template: String,
23
24    /// Enable Git repository initialization
25    #[arg(long)]
26    pub git: bool,
27
28    /// Use interactive mode
29    #[arg(short, long)]
30    pub interactive: bool,
31}
32
33pub async fn execute(args: InitCommand, _config: &Config, _output_format: &str) -> Result<()> {
34    output::print_info("Initializing new ToRSh project...");
35
36    let project_name = args.name.unwrap_or_else(|| "torsh-project".to_string());
37    let project_dir = args
38        .directory
39        .unwrap_or_else(|| PathBuf::from(&project_name));
40
41    // Create project directory
42    tokio::fs::create_dir_all(&project_dir).await?;
43
44    // Create basic project structure
45    create_project_structure(&project_dir, &args.template).await?;
46
47    output::print_success(&format!(
48        "Project '{}' initialized successfully!",
49        project_name
50    ));
51    output::print_info(&format!("Location: {}", project_dir.display()));
52
53    Ok(())
54}
55
56async fn create_project_structure(dir: &Path, template: &str) -> Result<()> {
57    // Create basic directories
58    let src_dir = dir.join("src");
59    tokio::fs::create_dir_all(&src_dir).await?;
60
61    // Create main.rs with appropriate template
62    let main_content = match template {
63        "vision" => {
64            r#"//! Vision project template
65//!
66//! This template provides a starting point for computer vision projects using ToRSh.
67//! Uncomment and modify the code below to train your vision model.
68
69use anyhow::Result;
70// Uncomment these imports when you're ready to use them:
71// use torsh::prelude::*;
72// use torsh_models::vision::*;
73// use torsh_optim::Adam;
74
75fn main() -> Result<()> {
76    println!("ToRSh Vision Project");
77
78    // Step 1: Create or load a vision model
79    // Example - ResNet for image classification:
80    // let config = ResNetConfig {
81    //     variant: ResNetVariant::ResNet18,
82    //     num_classes: 10,  // e.g., CIFAR-10
83    //     ..Default::default()
84    // };
85    // let mut model = ResNet::new(config)?;
86
87    // Step 2: Prepare your dataset
88    // let train_loader = create_image_dataloader("path/to/train", batch_size: 32)?;
89    // let val_loader = create_image_dataloader("path/to/val", batch_size: 32)?;
90
91    // Step 3: Setup optimizer and loss function
92    // let mut optimizer = Adam::new(model.parameters(), 0.001)?;
93    // let loss_fn = CrossEntropyLoss::new();
94
95    // Step 4: Training loop
96    // for epoch in 1..=10 {
97    //     model.train();
98    //     for (images, labels) in train_loader.iter() {
99    //         let predictions = model.forward(&images)?;
100    //         let loss = loss_fn.forward(&predictions, &labels)?;
101    //         loss.backward()?;
102    //         optimizer.step()?;
103    //         optimizer.zero_grad()?;
104    //     }
105    //
106    //     // Validation
107    //     model.eval();
108    //     let accuracy = evaluate(&model, &val_loader)?;
109    //     println!("Epoch {}: Validation Accuracy = {:.2}%", epoch, accuracy * 100.0);
110    // }
111
112    println!("Tip: Check torsh-models documentation for available vision models!");
113    println!("Available models: ResNet, VisionTransformer, EfficientNet (v0.2.0+)");
114
115    Ok(())
116}
117"#
118        }
119        "nlp" => {
120            r#"//! NLP project template
121//!
122//! This template provides a starting point for natural language processing projects using ToRSh.
123//! Uncomment and modify the code below to train your NLP model.
124
125use anyhow::Result;
126// Uncomment these imports when you're ready to use them:
127// use torsh::prelude::*;
128// use torsh_models::nlp::*;
129// use torsh_optim::AdamW;
130
131fn main() -> Result<()> {
132    println!("ToRSh NLP Project");
133
134    // Step 1: Create or load an NLP model
135    // Example - RoBERTa for text classification:
136    // let config = RobertaConfig {
137    //     vocab_size: 50265,
138    //     hidden_size: 768,
139    //     num_hidden_layers: 12,
140    //     num_attention_heads: 12,
141    //     ..Default::default()
142    // };
143    // let mut model = RobertaForSequenceClassification::new(config, num_labels: 2)?;
144
145    // Step 2: Prepare your text dataset
146    // let tokenizer = load_tokenizer("roberta-base")?;
147    // let train_loader = create_text_dataloader("path/to/train.csv", tokenizer, batch_size: 16)?;
148    // let val_loader = create_text_dataloader("path/to/val.csv", tokenizer, batch_size: 16)?;
149
150    // Step 3: Setup optimizer and loss function
151    // let mut optimizer = AdamW::new(model.parameters(), learning_rate: 2e-5)?;
152    // let loss_fn = CrossEntropyLoss::new();
153
154    // Step 4: Training loop
155    // for epoch in 1..=5 {
156    //     model.train();
157    //     for (input_ids, attention_mask, labels) in train_loader.iter() {
158    //         let logits = model.forward(&input_ids)?;
159    //         let loss = loss_fn.forward(&logits, &labels)?;
160    //         loss.backward()?;
161    //         optimizer.step()?;
162    //         optimizer.zero_grad()?;
163    //     }
164    //
165    //     // Validation
166    //     model.eval();
167    //     let accuracy = evaluate(&model, &val_loader)?;
168    //     println!("Epoch {}: Validation Accuracy = {:.2}%", epoch, accuracy * 100.0);
169    // }
170
171    println!("Tip: Check torsh-models documentation for available NLP models!");
172    println!("Available models: RoBERTa, BERT (v0.2.0+), GPT-2 (v0.2.0+)");
173
174    Ok(())
175}
176"#
177        }
178        _ => {
179            r#"//! Basic ToRSh project template
180//!
181//! This template provides a minimal starting point for machine learning projects using ToRSh.
182//! Uncomment and modify the code below to build your model.
183
184use anyhow::Result;
185// Uncomment these imports when you're ready to use them:
186// use torsh::prelude::*;
187// use torsh_nn::{Linear, Module, Sequential};
188// use torsh_optim::SGD;
189
190fn main() -> Result<()> {
191    println!("ToRSh Basic Project - Getting Started");
192
193    // Step 1: Create a simple neural network
194    // Example - Basic feedforward network:
195    // let model = Sequential::new(vec![
196    //     Linear::new(input_dim: 784, output_dim: 128, bias: true)?,
197    //     ReLU::new(),
198    //     Linear::new(128, 10, true)?,
199    // ]);
200
201    // Step 2: Prepare your dataset
202    // let train_data = load_dataset("path/to/train.csv")?;
203    // let val_data = load_dataset("path/to/val.csv")?;
204
205    // Step 3: Setup optimizer and loss function
206    // let mut optimizer = SGD::new(model.parameters(), learning_rate: 0.01)?;
207    // let loss_fn = MSELoss::new();
208
209    // Step 4: Training loop
210    // let epochs = 10;
211    // for epoch in 1..=epochs {
212    //     model.train();
213    //     let predictions = model.forward(&train_data.inputs)?;
214    //     let loss = loss_fn.forward(&predictions, &train_data.targets)?;
215    //
216    //     loss.backward()?;
217    //     optimizer.step()?;
218    //     optimizer.zero_grad()?;
219    //
220    //     println!("Epoch {}/{}: Loss = {:.4}", epoch, epochs, loss.item());
221    // }
222
223    println!("\nNext steps:");
224    println!("1. Add torsh dependencies to Cargo.toml");
225    println!("2. Import necessary modules (torsh::prelude::*, torsh_nn::*, torsh_optim::*)");
226    println!("3. Define your model architecture");
227    println!("4. Load your dataset");
228    println!("5. Train and evaluate your model");
229    println!("\nFor more examples, see: https://github.com/cool-japan/torsh/tree/main/examples");
230
231    Ok(())
232}
233"#
234        }
235    };
236
237    // Write the template content to main.rs
238    let main_rs = src_dir.join("main.rs");
239    tokio::fs::write(&main_rs, main_content).await?;
240
241    // Create Cargo.toml with correct version
242    let cargo_toml = dir.join("Cargo.toml");
243    let cargo_content = r#"[package]
244name = "torsh-project"
245version = "0.1.0"
246edition = "2021"
247
248[dependencies]
249torsh = "0.1.0"
250anyhow = "1.0"
251
252# Uncomment these dependencies as needed:
253# torsh-models = "0.1.0"
254# torsh-optim = "0.1.0"
255# torsh-data = "0.1.0"
256"#;
257    tokio::fs::write(&cargo_toml, cargo_content).await?;
258
259    Ok(())
260}