1use std::path::Path;
10
11#[cfg(feature = "serde-support")]
12use serde::{Deserialize, Serialize};
13
14use crate::engine::SonaEngine;
15use super::{ExportConfig, ExportResult, ExportError, HuggingFaceExporter};
16
17#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
19#[derive(Clone, Debug)]
20pub struct PretrainConfig {
21 pub base_model: String,
23
24 pub lora: LoraPretrainConfig,
26
27 pub training: TrainingConfig,
29
30 pub dataset: DatasetConfig,
32
33 pub hardware: HardwareConfig,
35
36 pub sona: SonaOptimizations,
38}
39
40#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
42#[derive(Clone, Debug)]
43pub struct LoraPretrainConfig {
44 pub rank: usize,
46 pub alpha: f32,
48 pub dropout: f32,
50 pub target_modules: Vec<String>,
52 pub use_rslora: bool,
54}
55
56#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
58#[derive(Clone, Debug)]
59pub struct TrainingConfig {
60 pub learning_rate: f64,
62 pub batch_size: usize,
64 pub gradient_accumulation_steps: usize,
66 pub num_epochs: usize,
68 pub warmup_ratio: f32,
70 pub weight_decay: f32,
72 pub max_grad_norm: f32,
74 pub lr_scheduler_type: String,
76 pub save_steps: usize,
78 pub eval_steps: usize,
80 pub logging_steps: usize,
82}
83
84#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
86#[derive(Clone, Debug)]
87pub struct DatasetConfig {
88 pub patterns_path: Option<String>,
90 pub preferences_path: Option<String>,
92 pub distillation_path: Option<String>,
94 pub max_seq_length: usize,
96 pub validation_split: f32,
98}
99
100#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
102#[derive(Clone, Debug)]
103pub struct HardwareConfig {
104 pub mixed_precision: String,
106 pub num_gpus: usize,
108 pub gradient_checkpointing: bool,
110 pub deepspeed: Option<String>,
112 pub fsdp: bool,
114}
115
116#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
118#[derive(Clone, Debug)]
119pub struct SonaOptimizations {
120 pub two_tier_lora: bool,
122 pub micro_lora_rank: usize,
124 pub ewc_enabled: bool,
126 pub ewc_lambda: f32,
128 pub pattern_clusters: usize,
130 pub enable_simd: bool,
132}
133
134impl Default for PretrainConfig {
135 fn default() -> Self {
136 Self {
137 base_model: "microsoft/phi-4".to_string(),
138 lora: LoraPretrainConfig::default(),
139 training: TrainingConfig::default(),
140 dataset: DatasetConfig::default(),
141 hardware: HardwareConfig::default(),
142 sona: SonaOptimizations::default(),
143 }
144 }
145}
146
147impl Default for LoraPretrainConfig {
148 fn default() -> Self {
149 Self {
150 rank: 2,
152 alpha: 2.0,
153 dropout: 0.0,
154 target_modules: vec![
155 "q_proj".to_string(),
156 "k_proj".to_string(),
157 "v_proj".to_string(),
158 "o_proj".to_string(),
159 ],
160 use_rslora: false,
161 }
162 }
163}
164
165impl Default for TrainingConfig {
166 fn default() -> Self {
167 Self {
168 learning_rate: 0.002,
170 batch_size: 32,
172 gradient_accumulation_steps: 4,
173 num_epochs: 3,
174 warmup_ratio: 0.1,
175 weight_decay: 0.01,
176 max_grad_norm: 1.0,
177 lr_scheduler_type: "cosine".to_string(),
178 save_steps: 500,
179 eval_steps: 100,
180 logging_steps: 10,
181 }
182 }
183}
184
185impl Default for DatasetConfig {
186 fn default() -> Self {
187 Self {
188 patterns_path: None,
189 preferences_path: None,
190 distillation_path: None,
191 max_seq_length: 2048,
192 validation_split: 0.1,
193 }
194 }
195}
196
197impl Default for HardwareConfig {
198 fn default() -> Self {
199 Self {
200 mixed_precision: "bf16".to_string(),
201 num_gpus: 1,
202 gradient_checkpointing: true,
203 deepspeed: None,
204 fsdp: false,
205 }
206 }
207}
208
209impl Default for SonaOptimizations {
210 fn default() -> Self {
211 Self {
212 two_tier_lora: true,
213 micro_lora_rank: 1,
214 ewc_enabled: true,
215 ewc_lambda: 1000.0,
217 pattern_clusters: 100,
219 enable_simd: true,
220 }
221 }
222}
223
224pub struct PretrainPipeline<'a> {
226 engine: &'a SonaEngine,
228 config: PretrainConfig,
230}
231
232impl<'a> PretrainPipeline<'a> {
233 pub fn new(engine: &'a SonaEngine) -> Self {
235 Self {
236 engine,
237 config: PretrainConfig::default(),
238 }
239 }
240
241 pub fn with_config(engine: &'a SonaEngine, config: PretrainConfig) -> Self {
243 Self { engine, config }
244 }
245
246 pub fn from_engine_stats(engine: &'a SonaEngine) -> Self {
248 let sona_config = engine.config();
249
250 let config = PretrainConfig {
251 lora: LoraPretrainConfig {
252 rank: sona_config.base_lora_rank,
253 alpha: sona_config.base_lora_rank as f32,
254 ..Default::default()
255 },
256 sona: SonaOptimizations {
257 micro_lora_rank: sona_config.micro_lora_rank,
258 ewc_lambda: sona_config.ewc_lambda,
259 pattern_clusters: sona_config.pattern_clusters,
260 ..Default::default()
261 },
262 ..Default::default()
263 };
264
265 Self { engine, config }
266 }
267
268 pub fn export_package<P: AsRef<Path>>(&self, output_dir: P) -> Result<PretrainPackage, ExportError> {
270 let output_dir = output_dir.as_ref();
271 std::fs::create_dir_all(output_dir).map_err(ExportError::Io)?;
272
273 let export_config = ExportConfig {
275 model_name: self.config.base_model.replace('/', "-"),
276 target_architecture: self.config.base_model.clone(),
277 include_patterns: true,
278 include_lora: true,
279 include_preferences: true,
280 min_quality_threshold: 0.5,
281 ..Default::default()
282 };
283
284 let exporter = HuggingFaceExporter::with_config(self.engine, export_config);
285 let export_results = exporter.export_all(output_dir)?;
286
287 let script_path = output_dir.join("train.py");
289 let script = self.generate_training_script();
290 std::fs::write(&script_path, script).map_err(ExportError::Io)?;
291
292 let config_path = output_dir.join("pretrain_config.json");
294 let config_json = serde_json::to_string_pretty(&self.config)?;
295 std::fs::write(&config_path, config_json).map_err(ExportError::Io)?;
296
297 let requirements_path = output_dir.join("requirements.txt");
299 let requirements = self.generate_requirements();
300 std::fs::write(&requirements_path, requirements).map_err(ExportError::Io)?;
301
302 let accelerate_path = output_dir.join("accelerate_config.yaml");
304 let accelerate_config = self.generate_accelerate_config();
305 std::fs::write(&accelerate_path, accelerate_config).map_err(ExportError::Io)?;
306
307 Ok(PretrainPackage {
308 output_dir: output_dir.to_string_lossy().to_string(),
309 export_results,
310 script_path: script_path.to_string_lossy().to_string(),
311 config_path: config_path.to_string_lossy().to_string(),
312 })
313 }
314
315 fn generate_training_script(&self) -> String {
317 format!(r#"#!/usr/bin/env python3
318"""
319SONA-Optimized Pretraining Script
320
321Based on SONA benchmark results:
322- Throughput: 2211 ops/sec
323- Latency: <0.5ms per layer
324- Quality improvement: +55%
325
326Configuration optimized for:
327- LoRA Rank: {}
328- Learning Rate: {}
329- Batch Size: {}
330- EWC Lambda: {}
331- Pattern Clusters: {}
332"""
333
334import os
335import json
336import torch
337from datasets import load_dataset
338from transformers import (
339 AutoModelForCausalLM,
340 AutoTokenizer,
341 TrainingArguments,
342 Trainer,
343 DataCollatorForLanguageModeling,
344)
345from peft import (
346 LoraConfig,
347 get_peft_model,
348 prepare_model_for_kbit_training,
349 TaskType,
350)
351
352# Load SONA config
353with open("pretrain_config.json", "r") as f:
354 CONFIG = json.load(f)
355
356def main():
357 # Load base model
358 print(f"Loading base model: {{CONFIG['base_model']}}")
359 model = AutoModelForCausalLM.from_pretrained(
360 CONFIG["base_model"],
361 torch_dtype=torch.bfloat16 if CONFIG["hardware"]["mixed_precision"] == "bf16" else torch.float16,
362 device_map="auto",
363 trust_remote_code=True,
364 )
365
366 tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_model"])
367 if tokenizer.pad_token is None:
368 tokenizer.pad_token = tokenizer.eos_token
369
370 # Configure LoRA with SONA-optimal settings
371 lora_config = LoraConfig(
372 r=CONFIG["lora"]["rank"],
373 lora_alpha=CONFIG["lora"]["alpha"],
374 lora_dropout=CONFIG["lora"]["dropout"],
375 target_modules=CONFIG["lora"]["target_modules"],
376 task_type=TaskType.CAUSAL_LM,
377 bias="none",
378 )
379
380 # Prepare model
381 if CONFIG["hardware"]["gradient_checkpointing"]:
382 model.gradient_checkpointing_enable()
383
384 model = get_peft_model(model, lora_config)
385 model.print_trainable_parameters()
386
387 # Load SONA datasets
388 datasets = {{}}
389
390 if CONFIG["dataset"]["patterns_path"] and os.path.exists(CONFIG["dataset"]["patterns_path"]):
391 print("Loading patterns dataset...")
392 datasets["patterns"] = load_dataset("json", data_files=CONFIG["dataset"]["patterns_path"])
393
394 if CONFIG["dataset"]["preferences_path"] and os.path.exists(CONFIG["dataset"]["preferences_path"]):
395 print("Loading preferences dataset...")
396 datasets["preferences"] = load_dataset("json", data_files=CONFIG["dataset"]["preferences_path"])
397
398 # Use patterns dataset for pretraining if available
399 if "patterns" in datasets:
400 train_dataset = datasets["patterns"]["train"]
401 else:
402 # Fall back to sample data
403 print("Warning: No patterns dataset found, using sample data")
404 train_dataset = None
405
406 # Training arguments with SONA-optimal settings
407 training_args = TrainingArguments(
408 output_dir="./sona-output",
409 num_train_epochs=CONFIG["training"]["num_epochs"],
410 per_device_train_batch_size=CONFIG["training"]["batch_size"],
411 gradient_accumulation_steps=CONFIG["training"]["gradient_accumulation_steps"],
412 learning_rate=CONFIG["training"]["learning_rate"],
413 warmup_ratio=CONFIG["training"]["warmup_ratio"],
414 weight_decay=CONFIG["training"]["weight_decay"],
415 max_grad_norm=CONFIG["training"]["max_grad_norm"],
416 lr_scheduler_type=CONFIG["training"]["lr_scheduler_type"],
417 save_steps=CONFIG["training"]["save_steps"],
418 eval_steps=CONFIG["training"]["eval_steps"],
419 logging_steps=CONFIG["training"]["logging_steps"],
420 bf16=CONFIG["hardware"]["mixed_precision"] == "bf16",
421 fp16=CONFIG["hardware"]["mixed_precision"] == "fp16",
422 gradient_checkpointing=CONFIG["hardware"]["gradient_checkpointing"],
423 report_to="tensorboard",
424 save_total_limit=3,
425 push_to_hub=False,
426 )
427
428 # Data collator
429 data_collator = DataCollatorForLanguageModeling(
430 tokenizer=tokenizer,
431 mlm=False,
432 )
433
434 if train_dataset:
435 # Initialize trainer
436 trainer = Trainer(
437 model=model,
438 args=training_args,
439 train_dataset=train_dataset,
440 data_collator=data_collator,
441 )
442
443 # Train
444 print("Starting SONA-optimized training...")
445 trainer.train()
446
447 # Save
448 print("Saving model...")
449 trainer.save_model("./sona-output/final")
450 tokenizer.save_pretrained("./sona-output/final")
451 else:
452 print("No training data available. Please provide patterns.jsonl or preferences.jsonl")
453
454 print("Done!")
455
456if __name__ == "__main__":
457 main()
458"#,
459 self.config.lora.rank,
460 self.config.training.learning_rate,
461 self.config.training.batch_size,
462 self.config.sona.ewc_lambda,
463 self.config.sona.pattern_clusters,
464 )
465 }
466
467 fn generate_requirements(&self) -> String {
469 r#"# SONA Pretraining Requirements
470torch>=2.0.0
471transformers>=4.35.0
472datasets>=2.14.0
473peft>=0.6.0
474accelerate>=0.24.0
475bitsandbytes>=0.41.0
476safetensors>=0.4.0
477tensorboard>=2.14.0
478scipy>=1.11.0
479scikit-learn>=1.3.0
480tqdm>=4.66.0
481"#.to_string()
482 }
483
484 fn generate_accelerate_config(&self) -> String {
486 format!(r#"compute_environment: LOCAL_MACHINE
487debug: false
488distributed_type: {}
489downcast_bf16: 'no'
490gpu_ids: all
491machine_rank: 0
492main_training_function: main
493mixed_precision: {}
494num_machines: 1
495num_processes: {}
496rdzv_backend: static
497same_network: true
498tpu_env: []
499tpu_use_cluster: false
500tpu_use_sudo: false
501use_cpu: false
502"#,
503 if self.config.hardware.num_gpus > 1 { "MULTI_GPU" } else { "NO" },
504 self.config.hardware.mixed_precision,
505 self.config.hardware.num_gpus,
506 )
507 }
508
509 pub fn generate_dpo_script(&self) -> String {
511 format!(r#"#!/usr/bin/env python3
512"""
513SONA DPO (Direct Preference Optimization) Training Script
514
515Uses preference pairs exported from SONA ReasoningBank for RLHF-style training
516without requiring a reward model.
517"""
518
519import json
520import torch
521from datasets import load_dataset
522from transformers import AutoModelForCausalLM, AutoTokenizer
523from trl import DPOTrainer, DPOConfig
524from peft import LoraConfig, get_peft_model
525
526# Load config
527with open("pretrain_config.json", "r") as f:
528 CONFIG = json.load(f)
529
530def main():
531 # Load model
532 model = AutoModelForCausalLM.from_pretrained(
533 CONFIG["base_model"],
534 torch_dtype=torch.bfloat16,
535 device_map="auto",
536 )
537
538 tokenizer = AutoTokenizer.from_pretrained(CONFIG["base_model"])
539 if tokenizer.pad_token is None:
540 tokenizer.pad_token = tokenizer.eos_token
541
542 # Configure LoRA
543 lora_config = LoraConfig(
544 r=CONFIG["lora"]["rank"],
545 lora_alpha=CONFIG["lora"]["alpha"],
546 lora_dropout=CONFIG["lora"]["dropout"],
547 target_modules=CONFIG["lora"]["target_modules"],
548 bias="none",
549 )
550
551 model = get_peft_model(model, lora_config)
552
553 # Load preference dataset
554 if CONFIG["dataset"]["preferences_path"]:
555 dataset = load_dataset("json", data_files=CONFIG["dataset"]["preferences_path"])
556 else:
557 raise ValueError("Preferences dataset required for DPO training")
558
559 # DPO config
560 dpo_config = DPOConfig(
561 output_dir="./sona-dpo-output",
562 num_train_epochs=CONFIG["training"]["num_epochs"],
563 per_device_train_batch_size=CONFIG["training"]["batch_size"] // 2,
564 gradient_accumulation_steps=CONFIG["training"]["gradient_accumulation_steps"],
565 learning_rate=CONFIG["training"]["learning_rate"] / 10, # Lower LR for DPO
566 warmup_ratio=CONFIG["training"]["warmup_ratio"],
567 bf16=True,
568 logging_steps=CONFIG["training"]["logging_steps"],
569 save_steps=CONFIG["training"]["save_steps"],
570 beta=0.1, # DPO temperature
571 )
572
573 # Initialize DPO trainer
574 trainer = DPOTrainer(
575 model=model,
576 args=dpo_config,
577 train_dataset=dataset["train"],
578 tokenizer=tokenizer,
579 )
580
581 # Train
582 print("Starting SONA DPO training...")
583 trainer.train()
584
585 # Save
586 trainer.save_model("./sona-dpo-output/final")
587 print("Done!")
588
589if __name__ == "__main__":
590 main()
591"#)
592 }
593}
594
595#[derive(Clone, Debug)]
597pub struct PretrainPackage {
598 pub output_dir: String,
600 pub export_results: Vec<ExportResult>,
602 pub script_path: String,
604 pub config_path: String,
606}
607
608#[cfg(test)]
609mod tests {
610 use super::*;
611
612 #[test]
613 fn test_pretrain_config_default() {
614 let config = PretrainConfig::default();
615
616 assert_eq!(config.lora.rank, 2);
618 assert_eq!(config.training.learning_rate, 0.002);
619 assert_eq!(config.training.batch_size, 32);
620 assert_eq!(config.sona.ewc_lambda, 1000.0);
621 assert_eq!(config.sona.pattern_clusters, 100);
622 }
623
624 #[test]
625 fn test_config_serialization() {
626 let config = PretrainConfig::default();
627 let json = serde_json::to_string_pretty(&config).unwrap();
628
629 assert!(json.contains("\"rank\": 2"));
630 assert!(json.contains("\"learning_rate\": 0.002"));
631 assert!(json.contains("\"batch_size\": 32"));
632 }
633
634 #[test]
635 fn test_lora_config_default() {
636 let config = LoraPretrainConfig::default();
637
638 assert_eq!(config.rank, 2);
639 assert_eq!(config.alpha, 2.0);
640 assert_eq!(config.dropout, 0.0);
641 assert!(config.target_modules.contains(&"q_proj".to_string()));
642 }
643
644 #[test]
645 fn test_sona_optimizations_default() {
646 let config = SonaOptimizations::default();
647
648 assert!(config.two_tier_lora);
649 assert_eq!(config.micro_lora_rank, 1);
650 assert!(config.ewc_enabled);
651 assert_eq!(config.ewc_lambda, 1000.0);
652 assert_eq!(config.pattern_clusters, 100);
653 assert!(config.enable_simd);
654 }
655}