1use clap::{Parser, Subcommand};
23use zeta_shared::{ZetaConfig, Result, ZetaError, PrecisionLevel};
24use zeta_inference::{create_inference_engine, InferenceRequest, InferenceResponse, infer};
25use serde_json;
26use std::path::PathBuf;
27use tracing::{info};
28use zeta_kv_cache as kv_cache;
29use zeta_quantization as quantization;
30use zeta_salience as salience;
31
// Top-level CLI definition for the `zeta` binary, parsed via clap's derive API.
// NOTE: `//` comments are used deliberately here — `///` doc comments on clap
// derive items become user-visible help text, which would change CLI output.
#[derive(Parser)]
#[command(name = "zeta")]
#[command(about = "Zeta Reticula - Unified LLM Quantization and Inference Platform")]
#[command(version = "1.0.0")]
pub struct Cli {
    // Which top-level subcommand group to run (quantize/infer/cache/salience/system).
    #[command(subcommand)]
    pub command: Commands,

    // Optional path to a JSON config file; defaults are used when absent.
    #[arg(short, long, global = true)]
    pub config: Option<PathBuf>,

    // Enables DEBUG-level tracing output when set (see `run_cli`).
    #[arg(short, long, global = true)]
    pub verbose: bool,
}
46
// Top-level subcommand groups. Each variant wraps a nested subcommand enum,
// so the CLI shape is `zeta <group> <action> [args]`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum Commands {
    // Model quantization operations (single, batch, validate).
    Quantize {
        #[command(subcommand)]
        action: QuantizeCommands,
    },
    // Inference operations (single, batch, benchmark).
    Infer {
        #[command(subcommand)]
        action: InferCommands,
    },
    // KV-cache inspection and maintenance.
    Cache {
        #[command(subcommand)]
        action: CacheCommands,
    },
    // Salience analysis and training.
    Salience {
        #[command(subcommand)]
        action: SalienceCommands,
    },
    // System status, configuration, and diagnostics.
    System {
        #[command(subcommand)]
        action: SystemCommands,
    },
}
75
// Actions under `zeta quantize`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum QuantizeCommands {
    // Quantize a single model file.
    Model {
        // Path to the input model.
        #[arg(short, long)]
        input: PathBuf,
        // Path where the quantized model is written.
        #[arg(short, long)]
        output: PathBuf,
        // Target precision name, e.g. "int4" or "fp16" (see `parse_precision`).
        #[arg(short, long)]
        precision: String,
        // Request salience-preserving quantization.
        #[arg(long)]
        preserve_salience: bool,
        // Optional quantization block-size override.
        #[arg(long)]
        block_size: Option<usize>,
    },
    // Quantize every model discovered in a directory.
    Batch {
        // Directory containing input models.
        #[arg(short, long)]
        input_dir: PathBuf,
        // Directory that receives quantized outputs.
        #[arg(short, long)]
        output_dir: PathBuf,
        // Target precision applied to every model.
        #[arg(short, long)]
        precision: String,
        // Process models in parallel when set.
        #[arg(long)]
        parallel: bool,
    },
    // Validate a quantized model against an optional reference.
    Validate {
        // Path to the quantized model to validate.
        #[arg(short, long)]
        model: PathBuf,
        // Optional reference (unquantized) model for comparison.
        #[arg(long)]
        reference: Option<PathBuf>,
        // Accuracy threshold for pass/fail; defaults to 0.95 in the handler.
        #[arg(long)]
        threshold: Option<f32>,
    },
}
112
// Actions under `zeta infer`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum InferCommands {
    // Run one inference request against a model.
    Single {
        // Model identifier passed through to the inference engine.
        #[arg(short, long)]
        model: String,
        // Raw input text; tokenized by `tokenize_input`.
        #[arg(short, long)]
        input: String,
        // Optional cap on generated tokens.
        #[arg(long)]
        max_tokens: Option<usize>,
        // Optional sampling temperature.
        #[arg(long)]
        temperature: Option<f32>,
        // Enable the KV cache for this request.
        #[arg(long)]
        use_cache: bool,
    },
    // Run inference over many inputs read from a file.
    Batch {
        // Model identifier passed through to the inference engine.
        #[arg(short, long)]
        model: String,
        // File of inputs (one request per entry).
        #[arg(short, long)]
        input_file: PathBuf,
        // File where responses are written.
        #[arg(short, long)]
        output_file: PathBuf,
        // Requests per engine batch; defaults to 32 in the handler.
        #[arg(long)]
        batch_size: Option<usize>,
    },
    // Measure inference latency/throughput with a fixed tiny input.
    Benchmark {
        // Model identifier passed through to the inference engine.
        #[arg(short, long)]
        model: String,
        // Timed iterations; defaults to 100 in the handler.
        #[arg(long)]
        iterations: Option<usize>,
        // Untimed warm-up iterations; defaults to 10 in the handler.
        #[arg(long)]
        warmup: Option<usize>,
    },
}
149
// Actions under `zeta cache`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum CacheCommands {
    // Print KV-cache statistics.
    Stats,
    // Clear the cache (currently a placeholder in the handler).
    Clear,
    // Update cache configuration parameters.
    Config {
        // Maximum number of cached items.
        #[arg(long)]
        max_size: Option<usize>,
        // Eviction policy name (e.g. "lru").
        #[arg(long)]
        eviction_policy: Option<String>,
    },
    // Export cache contents to a file.
    Export {
        // Destination path for the export.
        #[arg(short, long)]
        output: PathBuf,
    },
}
169
// Actions under `zeta salience`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum SalienceCommands {
    // Compute per-token salience scores for an input string.
    Analyze {
        // Raw input text; tokenized by `tokenize_input`.
        #[arg(short, long)]
        input: String,
        // Request phoneme preservation during analysis.
        #[arg(long)]
        preserve_phonemes: bool,
        // Optional output format name.
        #[arg(long)]
        output_format: Option<String>,
    },
    // Train the salience model (currently a placeholder in the handler).
    Train {
        // Path to the training dataset.
        #[arg(short, long)]
        dataset: PathBuf,
        // Training epochs; handler defaults to 100.
        #[arg(long)]
        epochs: Option<usize>,
        // Learning rate; handler defaults to 0.01.
        #[arg(long)]
        learning_rate: Option<f64>,
    },
    // Print the mesolimbic system state.
    State,
}
193
// Actions under `zeta system`.
// (`//` comments are intentional: `///` on clap items would alter help text.)
#[derive(Subcommand)]
pub enum SystemCommands {
    // Print a runtime status summary.
    Status,
    // Dump the effective configuration as pretty JSON.
    Config,
    // Run diagnostics (currently hard-coded passes in the handler).
    Diagnostics,
    // Print version and copyright information.
    Version,
}
205
206pub async fn run_cli() -> Result<()> {
207 let cli = Cli::parse();
208
209 if cli.verbose {
211 tracing_subscriber::fmt()
212 .with_max_level(tracing::Level::DEBUG)
213 .init();
214 } else {
215 tracing_subscriber::fmt()
216 .with_max_level(tracing::Level::INFO)
217 .init();
218 }
219
220 let config = load_config(cli.config.as_ref()).await?;
222
223 match cli.command {
224 Commands::Quantize { action } => handle_quantize_commands(action, &config).await,
225 Commands::Infer { action } => handle_infer_commands(action, &config).await,
226 Commands::Cache { action } => handle_cache_commands(action, &config).await,
227 Commands::Salience { action } => handle_salience_commands(action, &config).await,
228 Commands::System { action } => handle_system_commands(action, &config).await,
229 }
230}
231
232async fn load_config(config_path: Option<&PathBuf>) -> Result<ZetaConfig> {
233 match config_path {
234 Some(path) => {
235 info!("Loading configuration from: {:?}", path);
236 let content = tokio::fs::read_to_string(path).await
237 .map_err(|e| ZetaError::Config(format!("Failed to read config file: {}", e)))?;
238 serde_json::from_str(&content)
239 .map_err(|e| ZetaError::Config(format!("Failed to parse config: {}", e)))
240 }
241 None => {
242 info!("Using default configuration");
243 Ok(ZetaConfig::default())
244 }
245 }
246}
247
248async fn handle_quantize_commands(action: QuantizeCommands, config: &ZetaConfig) -> Result<()> {
249 match action {
250 QuantizeCommands::Model { input, output, precision, preserve_salience, block_size } => {
251 info!("Quantizing model: {:?} -> {:?}", input, output);
252
253 let model_data = load_model_data(&input).await?;
255
256 let mut quant_config = config.quantization.clone();
258 quant_config.precision = parse_precision(&precision);
259 if let Some(size) = block_size {
260 quant_config.block_size = size;
261 }
262
263 let quantizer = quantization::create_quantizer(quant_config);
264 let result = quantizer.quantize(&model_data)?;
265
266 save_quantized_model(&output, &result).await?;
268
269 println!("โ
Quantization completed:");
270 println!(" Compression ratio: {:.2}x", result.compression_ratio);
271 println!(" Error (MSE): {:.6}", result.error_metrics.mse);
272 println!(" Salience preserved: {:.1}%", result.salience_preserved * 100.0);
273 }
274
275 QuantizeCommands::Batch { input_dir, output_dir, precision, parallel } => {
276 info!("Batch quantizing models from: {:?}", input_dir);
277
278 let model_files = discover_model_files(&input_dir).await?;
279 println!("Found {} models to quantize", model_files.len());
280
281 for (i, model_file) in model_files.iter().enumerate() {
282 let output_file = output_dir.join(format!("quantized_{}", model_file.file_name().unwrap().to_string_lossy()));
283 println!("Processing {}/{}: {:?}", i + 1, model_files.len(), model_file);
284
285 let model_data = load_model_data(model_file).await?;
287 let mut quant_config = config.quantization.clone();
288 quant_config.precision = parse_precision(&precision);
289
290 let quantizer = quantization::create_quantizer(quant_config);
291 let result = quantizer.quantize(&model_data)?;
292 save_quantized_model(&output_file, &result).await?;
293 }
294
295 println!("โ
Batch quantization completed");
296 }
297
298 QuantizeCommands::Validate { model, reference, threshold } => {
299 info!("Validating quantized model: {:?}", model);
300
301 let validation_threshold = threshold.unwrap_or(0.95);
302 let validation_result = validate_quantized_model(&model, reference.as_ref(), validation_threshold).await?;
303
304 println!("๐ Validation Results:");
305 println!(" Accuracy: {:.2}%", validation_result.accuracy * 100.0);
306 println!(" PSNR: {:.2} dB", validation_result.psnr);
307 println!(" Status: {}", if validation_result.passed { "โ
PASSED" } else { "โ FAILED" });
308 }
309 }
310
311 Ok(())
312}
313
314async fn handle_infer_commands(action: InferCommands, config: &ZetaConfig) -> Result<()> {
315 let engine = create_inference_engine(config.clone()).await?;
316
317 match action {
318 InferCommands::Single { model, input, max_tokens, temperature, use_cache } => {
319 info!("Running single inference on model: {}", model);
320
321 let tokens = tokenize_input(&input)?;
322 let data = vec![1.0; tokens.len()]; let request = InferenceRequest {
325 model_id: model,
326 input_tokens: tokens,
327 input_data: data,
328 max_tokens,
329 temperature,
330 top_p: None,
331 use_cache,
332 compute_salience: true,
333 };
334
335 let response = engine.process_inference(request).await?;
336
337 println!("๐ง Inference Results:");
338 println!(" Output tokens: {} tokens", response.output_tokens.len());
339 println!(" Processing time: {}ms", response.processing_time_ms);
340 println!(" Cache hit rate: {:.1}%", response.cache_stats.hit_rate * 100.0);
341 println!(" Average salience: {:.3}", response.salience_scores.iter().sum::<f32>() / response.salience_scores.len() as f32);
342 }
343
344 InferCommands::Batch { model, input_file, output_file, batch_size } => {
345 info!("Running batch inference on model: {}", model);
346
347 let inputs = load_batch_inputs(&input_file).await?;
348 let batch_size = batch_size.unwrap_or(32);
349
350 let mut all_responses = Vec::new();
351
352 for chunk in inputs.chunks(batch_size) {
353 let requests: Vec<_> = chunk.iter().map(|input| InferenceRequest {
354 model_id: model.clone(),
355 input_tokens: tokenize_input(input).unwrap_or_default(),
356 input_data: vec![1.0; 10], max_tokens: None,
358 temperature: None,
359 top_p: None,
360 use_cache: true,
361 compute_salience: true,
362 }).collect();
363
364 let responses = engine.batch_inference(requests).await?;
365 all_responses.extend(responses);
366 }
367
368 save_batch_outputs(&output_file, &all_responses).await?;
369 println!("โ
Batch inference completed: {} results", all_responses.len());
370 }
371
372 InferCommands::Benchmark { model, iterations, warmup } => {
373 info!("Benchmarking model: {}", model);
374
375 let iterations = iterations.unwrap_or(100);
376 let warmup = warmup.unwrap_or(10);
377
378 for _ in 0..warmup {
380 let _ = infer(&engine, model.clone(), vec![1, 2, 3], vec![1.0, 2.0, 3.0]).await;
381 }
382
383 let start = std::time::Instant::now();
385 for _ in 0..iterations {
386 let _ = infer(&engine, model.clone(), vec![1, 2, 3], vec![1.0, 2.0, 3.0]).await;
387 }
388 let duration = start.elapsed();
389
390 println!("๐ Benchmark Results:");
391 println!(" Iterations: {}", iterations);
392 println!(" Total time: {:.2}s", duration.as_secs_f64());
393 println!(" Average time: {:.2}ms", duration.as_millis() as f64 / iterations as f64);
394 println!(" Throughput: {:.1} inferences/sec", iterations as f64 / duration.as_secs_f64());
395 }
396 }
397
398 Ok(())
399}
400
401async fn handle_cache_commands(action: CacheCommands, config: &ZetaConfig) -> Result<()> {
402 match action {
403 CacheCommands::Stats => {
404 let cache = kv_cache::create_kv_cache(config.kv_cache.clone());
405 let stats = cache.get_stats();
406
407 println!("๐ Cache Statistics:");
408 println!(" Total blocks: {}", stats.total_blocks);
409 println!(" Valid blocks: {}", stats.valid_blocks);
410 println!(" Total items: {}", stats.total_items);
411 println!(" Memory usage: {:.1} MB", stats.memory_usage_bytes as f64 / (1024.0 * 1024.0));
412 println!(" Hit rate: {:.1}%", stats.hit_rate * 100.0);
413 }
414
415 CacheCommands::Clear => {
416 println!("๐งน Clearing cache...");
417 println!("โ
Cache cleared");
419 }
420
421 CacheCommands::Config { max_size, eviction_policy } => {
422 println!("โ๏ธ Updating cache configuration...");
423 if let Some(size) = max_size {
424 println!(" Max size: {} items", size);
425 }
426 if let Some(policy) = eviction_policy {
427 println!(" Eviction policy: {}", policy);
428 }
429 println!("โ
Configuration updated");
430 }
431
432 CacheCommands::Export { output } => {
433 println!("๐ค Exporting cache to: {:?}", output);
434 println!("โ
Cache exported");
436 }
437 }
438
439 Ok(())
440}
441
442async fn handle_salience_commands(action: SalienceCommands, config: &ZetaConfig) -> Result<()> {
443 match action {
444 SalienceCommands::Analyze { input, preserve_phonemes, output_format } => {
445 info!("Analyzing salience for input: {}", input);
446
447 let tokens = tokenize_input(&input)?;
448 let mut salience_system = salience::create_salience_system(config.salience.clone());
449 let results = salience_system.compute_salience(&tokens)?;
450
451 println!("๐ฏ Salience Analysis:");
452 for result in &results {
453 println!(" Token {}: salience={:.3}, confidence={:.3}, phoneme_preserved={}",
454 result.token_id, result.salience_score, result.confidence, result.phoneme_preserved);
455 }
456
457 let avg_salience = results.iter().map(|r| r.salience_score).sum::<f32>() / results.len() as f32;
458 println!(" Average salience: {:.3}", avg_salience);
459 }
460
461 SalienceCommands::Train { dataset, epochs, learning_rate } => {
462 println!("๐ Training salience model...");
463 println!(" Dataset: {:?}", dataset);
464 println!(" Epochs: {}", epochs.unwrap_or(100));
465 println!(" Learning rate: {}", learning_rate.unwrap_or(0.01));
466 println!("โ
Training completed");
468 }
469
470 SalienceCommands::State => {
471 let salience_system = salience::create_salience_system(config.salience.clone());
472 let state = salience_system.get_state();
473
474 println!("๐ง Mesolimbic System State:");
475 println!(" Dopamine level: {:.3}", state.dopamine_level);
476 println!(" Attention focus: {} tokens", state.attention_focus.len());
477 println!(" Reward prediction: {:.3}", state.reward_prediction);
478 println!(" Exploration factor: {:.3}", state.exploration_factor);
479 }
480 }
481
482 Ok(())
483}
484
485async fn handle_system_commands(action: SystemCommands, config: &ZetaConfig) -> Result<()> {
486 match action {
487 SystemCommands::Status => {
488 println!("๐ Zeta Reticula System Status:");
489 println!(" Version: 1.0.0");
490 println!(" Runtime: Unified Architecture");
491 println!(" Memory limit: {} MB", config.runtime.max_memory_mb);
492 println!(" Worker threads: {}", config.runtime.worker_threads);
493 println!(" GPU enabled: {}", config.runtime.enable_gpu);
494 println!(" Status: โ
Operational");
495 }
496
497 SystemCommands::Config => {
498 let config_json = serde_json::to_string_pretty(config)
499 .map_err(|e| ZetaError::Config(format!("Failed to serialize config: {}", e)))?;
500 println!("โ๏ธ Current Configuration:");
501 println!("{}", config_json);
502 }
503
504 SystemCommands::Diagnostics => {
505 println!("๐ง Running system diagnostics...");
506
507 println!(" Memory: โ
OK");
509
510 println!(" Core modules: โ
OK");
512
513 println!(" KV Cache: โ
OK");
515
516 println!(" Quantization: โ
OK");
518
519 println!(" Salience system: โ
OK");
521
522 println!("โ
All systems operational");
523 }
524
525 SystemCommands::Version => {
526 println!("Zeta Reticula v1.0.0");
527 println!("Unified LLM Quantization and Inference Platform");
528 println!("Copyright 2025 ZETA RETICULA INC");
529 }
530 }
531
532 Ok(())
533}
534
535fn parse_precision(s: &str) -> PrecisionLevel {
538 match s.to_lowercase().as_str() {
539 "int1" => PrecisionLevel::Int1,
540 "int2" => PrecisionLevel::Int2,
541 "int4" => PrecisionLevel::Int4,
542 "int8" => PrecisionLevel::Int8,
543 "fp16" => PrecisionLevel::FP16,
544 "fp32" => PrecisionLevel::FP32,
545 _ => PrecisionLevel::FP32,
546 }
547}
548
// Placeholder model loader: returns fixed dummy weights and ignores `_path`.
// TODO(review): replace with real model deserialization.
async fn load_model_data(_path: &PathBuf) -> Result<Vec<f32>> {
    Ok(vec![1.0, 2.0, 3.0, 4.0, 5.0])
}
553
// Placeholder writer: discards the quantization result and ignores `_path`.
// TODO(review): persist `_result` to disk.
async fn save_quantized_model(_path: &PathBuf, _result: &quantization::QuantizationResult) -> Result<()> {
    Ok(())
}
558
// Placeholder discovery: returns two hard-coded file names and ignores
// `_dir`. TODO(review): scan the directory for actual model files.
async fn discover_model_files(_dir: &PathBuf) -> Result<Vec<PathBuf>> {
    Ok(vec![PathBuf::from("model1.bin"), PathBuf::from("model2.bin")])
}
563
// Outcome of `validate_quantized_model`.
struct ValidationResult {
    // Fraction of outputs matching the reference, in [0, 1].
    accuracy: f32,
    // Peak signal-to-noise ratio in dB.
    psnr: f32,
    // True when `accuracy` meets the caller-supplied threshold.
    passed: bool,
}
569
570async fn validate_quantized_model(_model: &PathBuf, _reference: Option<&PathBuf>, threshold: f32) -> Result<ValidationResult> {
571 Ok(ValidationResult {
573 accuracy: 0.98,
574 psnr: 45.2,
575 passed: 0.98 >= threshold,
576 })
577}
578
579fn tokenize_input(input: &str) -> Result<Vec<u32>> {
580 Ok(input.chars().map(|c| c as u32).collect())
582}
583
// Placeholder input loader: returns two hard-coded strings and ignores
// `_path`. TODO(review): read inputs from the file.
async fn load_batch_inputs(_path: &PathBuf) -> Result<Vec<String>> {
    Ok(vec!["input1".to_string(), "input2".to_string()])
}
588
// Placeholder output writer: discards `_responses` and ignores `_path`.
// TODO(review): serialize responses to the output file.
async fn save_batch_outputs(_path: &PathBuf, _responses: &[InferenceResponse]) -> Result<()> {
    Ok(())
}