// torsh_cli/commands/model/args.rs

//! Command-line argument structures for model operations
2
3use clap::Args;
4use std::path::PathBuf;
5
6/// Arguments for model conversion
7#[derive(Debug, Args)]
8pub struct ConvertArgs {
9    /// Input model file path
10    #[arg(short, long)]
11    pub input: PathBuf,
12
13    /// Output model file path
14    #[arg(short, long)]
15    pub output: PathBuf,
16
17    /// Target format for conversion
18    #[arg(short, long)]
19    pub format: String,
20
21    /// Optimization level (0-3)
22    #[arg(long, default_value = "1")]
23    pub optimization_level: u8,
24
25    /// Preserve model metadata during conversion
26    #[arg(long)]
27    pub preserve_metadata: bool,
28
29    /// Target device for optimization
30    #[arg(long, default_value = "cpu")]
31    pub target_device: String,
32
33    /// Enable verbose output
34    #[arg(long)]
35    pub verbose: bool,
36}
37
38/// Arguments for model optimization
39#[derive(Debug, Args)]
40pub struct OptimizeArgs {
41    /// Input model file path
42    #[arg(short, long)]
43    pub input: PathBuf,
44
45    /// Output optimized model file path
46    #[arg(short, long)]
47    pub output: PathBuf,
48
49    /// Optimization level (0-3)
50    #[arg(long, default_value = "2")]
51    pub level: u8,
52
53    /// Target device for optimization
54    #[arg(long, default_value = "cpu")]
55    pub target: String,
56
57    /// Enable operator fusion
58    #[arg(long)]
59    pub fusion: bool,
60
61    /// Enable constant folding
62    #[arg(long)]
63    pub constant_folding: bool,
64
65    /// Enable dead code elimination
66    #[arg(long)]
67    pub dead_code_elimination: bool,
68
69    /// Memory optimization passes
70    #[arg(long)]
71    pub memory_optimization: bool,
72}
73
74/// Arguments for model quantization
75#[derive(Debug, Args)]
76pub struct QuantizeArgs {
77    /// Input model file path
78    #[arg(short, long)]
79    pub input: PathBuf,
80
81    /// Output quantized model file path
82    #[arg(short, long)]
83    pub output: PathBuf,
84
85    /// Quantization method (dynamic, static, qat)
86    #[arg(short, long, default_value = "dynamic")]
87    pub method: String,
88
89    /// Target precision (int8, int4, fp16)
90    #[arg(long, default_value = "int8")]
91    pub precision: String,
92
93    /// Calibration dataset path for static quantization
94    #[arg(long)]
95    pub calibration_data: Option<PathBuf>,
96
97    /// Number of calibration samples
98    #[arg(long, default_value = "100")]
99    pub calibration_samples: usize,
100
101    /// Accuracy threshold for validation
102    #[arg(long, default_value = "0.95")]
103    pub accuracy_threshold: f64,
104}
105
106/// Arguments for model pruning
107#[derive(Debug, Args)]
108pub struct PruneArgs {
109    /// Input model file path
110    #[arg(short, long)]
111    pub input: PathBuf,
112
113    /// Output pruned model file path
114    #[arg(short, long)]
115    pub output: PathBuf,
116
117    /// Pruning method (magnitude, gradient, fisher)
118    #[arg(short, long, default_value = "magnitude")]
119    pub method: String,
120
121    /// Sparsity ratio (0.0-1.0)
122    #[arg(short, long, default_value = "0.5")]
123    pub sparsity: f64,
124
125    /// Structured pruning (channels, filters)
126    #[arg(long)]
127    pub structured: bool,
128
129    /// Fine-tuning epochs after pruning
130    #[arg(long, default_value = "10")]
131    pub finetune_epochs: usize,
132
133    /// Validation dataset path
134    #[arg(long)]
135    pub validation_data: Option<PathBuf>,
136}
137
138/// Arguments for model inspection
139#[derive(Debug, Args)]
140pub struct InspectArgs {
141    /// Input model file path
142    #[arg(short, long)]
143    pub input: PathBuf,
144
145    /// Show detailed information
146    #[arg(long)]
147    pub detailed: bool,
148
149    /// Show model statistics
150    #[arg(long)]
151    pub stats: bool,
152
153    /// Show memory analysis
154    #[arg(long)]
155    pub memory: bool,
156
157    /// Show computational complexity
158    #[arg(long)]
159    pub complexity: bool,
160
161    /// Export information to file
162    #[arg(long)]
163    pub export: Option<PathBuf>,
164}
165
166/// Arguments for model validation
167#[derive(Debug, Args)]
168pub struct ValidateArgs {
169    /// Input model file path
170    #[arg(short, long)]
171    pub input: PathBuf,
172
173    /// Validation dataset directory
174    #[arg(short, long)]
175    pub dataset: PathBuf,
176
177    /// Number of samples to validate
178    #[arg(short, long, default_value = "1000")]
179    pub samples: usize,
180
181    /// Target device for validation
182    #[arg(long, default_value = "cpu")]
183    pub device: String,
184
185    /// Batch size for validation
186    #[arg(long, default_value = "32")]
187    pub batch_size: usize,
188
189    /// Accuracy threshold
190    #[arg(long, default_value = "0.9")]
191    pub accuracy_threshold: f64,
192}
193
194/// Arguments for model benchmarking
195#[derive(Debug, Args)]
196pub struct BenchmarkArgs {
197    /// Input model file path
198    #[arg(short, long)]
199    pub input: PathBuf,
200
201    /// Target device for benchmarking
202    #[arg(long, default_value = "cpu")]
203    pub device: String,
204
205    /// Input shape for benchmarking
206    #[arg(long, value_delimiter = ',')]
207    pub input_shape: Vec<usize>,
208
209    /// Batch sizes to test
210    #[arg(long, value_delimiter = ',', default_values = ["1", "4", "8", "16"])]
211    pub batch_sizes: Vec<usize>,
212
213    /// Number of warmup iterations
214    #[arg(long, default_value = "10")]
215    pub warmup: usize,
216
217    /// Number of benchmark iterations
218    #[arg(long, default_value = "100")]
219    pub iterations: usize,
220
221    /// Profile memory usage
222    #[arg(long)]
223    pub profile_memory: bool,
224
225    /// Export results to file
226    #[arg(long)]
227    pub export: Option<PathBuf>,
228}
229
230/// Arguments for model compression
231#[derive(Debug, Args)]
232pub struct CompressArgs {
233    /// Input model file path
234    #[arg(short, long)]
235    pub input: PathBuf,
236
237    /// Output compressed model file path
238    #[arg(short, long)]
239    pub output: PathBuf,
240
241    /// Compression algorithm
242    #[arg(short, long, default_value = "gzip")]
243    pub algorithm: String,
244
245    /// Compression level (1-9)
246    #[arg(long, default_value = "6")]
247    pub level: u8,
248}
249
250/// Arguments for model component extraction
251#[derive(Debug, Args)]
252pub struct ExtractArgs {
253    /// Input model file path
254    #[arg(short, long)]
255    pub input: PathBuf,
256
257    /// Component to extract (weights, architecture, metadata)
258    #[arg(short = 'x', long)]
259    pub component: String,
260
261    /// Output file path
262    #[arg(short, long)]
263    pub output: PathBuf,
264}
265
266/// Arguments for model merging
267#[derive(Debug, Args)]
268pub struct MergeArgs {
269    /// Input model file paths
270    #[arg(short, long)]
271    pub inputs: Vec<PathBuf>,
272
273    /// Output merged model file path
274    #[arg(short, long)]
275    pub output: PathBuf,
276
277    /// Merge strategy (average, concatenate, ensemble)
278    #[arg(short, long, default_value = "average")]
279    pub strategy: String,
280
281    /// Weights for merging (if using weighted average)
282    #[arg(long, value_delimiter = ',')]
283    pub weights: Vec<f64>,
284}