// ruvector_sona/export/huggingface_hub.rs
//! HuggingFace Hub Integration
//!
//! Direct integration with HuggingFace Hub API for uploading SONA models,
//! patterns, and datasets.
use super::{
    DatasetExporter, ExportConfig, ExportError, ExportResult, ExportType, SafeTensorsExporter,
};
use crate::engine::SonaEngine;
use std::path::Path;

#[cfg(feature = "serde-support")]
use serde::{Deserialize, Serialize};
14
/// Client for the HuggingFace Hub REST API.
///
/// Uploads are performed by shelling out to external `git` and `curl`
/// binaries rather than via an HTTP client crate (see the `impl` below).
pub struct HuggingFaceHub {
    /// Bearer token for authenticated requests. `None` works for reading
    /// public repos, but repo creation and pushes require a token.
    token: Option<String>,
    /// Base URL of the Hub API (set to "https://huggingface.co/api" in `new`).
    api_url: String,
}
22
23impl HuggingFaceHub {
24    /// Create new Hub client
25    pub fn new(token: Option<&str>) -> Self {
26        Self {
27            token: token.map(|t| t.to_string()),
28            api_url: "https://huggingface.co/api".to_string(),
29        }
30    }
31
32    /// Create Hub client from environment variable
33    pub fn from_env() -> Self {
34        let token = std::env::var("HF_TOKEN")
35            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
36            .ok();
37        Self::new(token.as_deref())
38    }
39
40    /// Push all exports to HuggingFace Hub
41    pub fn push_all(
42        &self,
43        engine: &SonaEngine,
44        config: &ExportConfig,
45        repo_id: &str,
46    ) -> Result<ExportResult, ExportError> {
47        // Create temporary directory for exports
48        let temp_dir = std::env::temp_dir().join(format!("sona-export-{}", uuid_v4()));
49        std::fs::create_dir_all(&temp_dir).map_err(ExportError::Io)?;
50
51        // Export all components to temp directory
52        let safetensors_exporter = SafeTensorsExporter::new(config);
53        let dataset_exporter = DatasetExporter::new(config);
54
55        let mut total_items = 0;
56        let mut total_size = 0u64;
57
58        // Export LoRA weights
59        if config.include_lora {
60            let result = safetensors_exporter.export_engine(engine, temp_dir.join("lora"))?;
61            total_items += result.items_exported;
62            total_size += result.size_bytes;
63        }
64
65        // Export patterns
66        if config.include_patterns {
67            let result =
68                dataset_exporter.export_patterns(engine, temp_dir.join("patterns.jsonl"))?;
69            total_items += result.items_exported;
70            total_size += result.size_bytes;
71        }
72
73        // Export preferences
74        if config.include_preferences {
75            let result =
76                dataset_exporter.export_preferences(engine, temp_dir.join("preferences.jsonl"))?;
77            total_items += result.items_exported;
78            total_size += result.size_bytes;
79        }
80
81        // Create model card
82        let readme = self.create_model_card(engine, config);
83        let readme_path = temp_dir.join("README.md");
84        std::fs::write(&readme_path, readme).map_err(ExportError::Io)?;
85
86        // Create adapter config
87        let adapter_config = self.create_adapter_config(engine, config);
88        let config_path = temp_dir.join("adapter_config.json");
89        let config_json = serde_json::to_string_pretty(&adapter_config)?;
90        std::fs::write(&config_path, config_json).map_err(ExportError::Io)?;
91
92        // Upload to Hub (using git LFS approach)
93        self.upload_directory(&temp_dir, repo_id)?;
94
95        // Cleanup
96        let _ = std::fs::remove_dir_all(&temp_dir);
97
98        Ok(ExportResult {
99            export_type: ExportType::SafeTensors,
100            items_exported: total_items,
101            output_path: format!("https://huggingface.co/{}", repo_id),
102            size_bytes: total_size,
103        })
104    }
105
106    /// Upload directory to HuggingFace Hub
107    fn upload_directory(&self, local_path: &Path, repo_id: &str) -> Result<(), ExportError> {
108        // Check for git and git-lfs
109        let has_git = std::process::Command::new("git")
110            .arg("--version")
111            .output()
112            .is_ok();
113
114        if !has_git {
115            return Err(ExportError::HubError(
116                "git is required for HuggingFace Hub upload. Install git and git-lfs.".to_string(),
117            ));
118        }
119
120        // Clone or create repo
121        let repo_url = if let Some(ref token) = self.token {
122            format!("https://{}@huggingface.co/{}", token, repo_id)
123        } else {
124            format!("https://huggingface.co/{}", repo_id)
125        };
126
127        let clone_dir = local_path.parent().unwrap().join("hf-repo");
128
129        // Try to clone existing repo
130        let clone_result = std::process::Command::new("git")
131            .args(["clone", &repo_url, clone_dir.to_str().unwrap()])
132            .output();
133
134        if clone_result.is_err() {
135            // Create new repo via API
136            self.create_repo(repo_id)?;
137
138            // Try cloning again
139            std::process::Command::new("git")
140                .args(["clone", &repo_url, clone_dir.to_str().unwrap()])
141                .output()
142                .map_err(|e| ExportError::HubError(format!("Failed to clone repo: {}", e)))?;
143        }
144
145        // Copy files to cloned repo
146        copy_dir_recursive(local_path, &clone_dir)?;
147
148        // Add, commit, and push
149        std::process::Command::new("git")
150            .args(["-C", clone_dir.to_str().unwrap(), "add", "-A"])
151            .output()
152            .map_err(|e| ExportError::HubError(format!("git add failed: {}", e)))?;
153
154        std::process::Command::new("git")
155            .args([
156                "-C",
157                clone_dir.to_str().unwrap(),
158                "commit",
159                "-m",
160                "Upload SONA adapter",
161            ])
162            .output()
163            .map_err(|e| ExportError::HubError(format!("git commit failed: {}", e)))?;
164
165        let push_result = std::process::Command::new("git")
166            .args(["-C", clone_dir.to_str().unwrap(), "push"])
167            .output()
168            .map_err(|e| ExportError::HubError(format!("git push failed: {}", e)))?;
169
170        if !push_result.status.success() {
171            let stderr = String::from_utf8_lossy(&push_result.stderr);
172            return Err(ExportError::HubError(format!(
173                "git push failed: {}",
174                stderr
175            )));
176        }
177
178        // Cleanup
179        let _ = std::fs::remove_dir_all(&clone_dir);
180
181        Ok(())
182    }
183
184    /// Create a new repository on HuggingFace Hub
185    fn create_repo(&self, repo_id: &str) -> Result<(), ExportError> {
186        let token = self.token.as_ref().ok_or_else(|| {
187            ExportError::HubError("HuggingFace token required to create repos".to_string())
188        })?;
189
190        // Parse repo_id (org/name or just name)
191        let (organization, name) = if let Some(idx) = repo_id.find('/') {
192            (Some(&repo_id[..idx]), &repo_id[idx + 1..])
193        } else {
194            (None, repo_id)
195        };
196
197        let create_request = CreateRepoRequest {
198            name: name.to_string(),
199            organization: organization.map(|s| s.to_string()),
200            private: false,
201            repo_type: "model".to_string(),
202        };
203
204        let url = format!("{}/repos/create", self.api_url);
205
206        // Use simple HTTP client approach (blocking for simplicity)
207        // In production, you'd use reqwest or similar
208        let body = serde_json::to_string(&create_request)?;
209
210        let output = std::process::Command::new("curl")
211            .args([
212                "-X",
213                "POST",
214                "-H",
215                &format!("Authorization: Bearer {}", token),
216                "-H",
217                "Content-Type: application/json",
218                "-d",
219                &body,
220                &url,
221            ])
222            .output()
223            .map_err(|e| ExportError::HubError(format!("curl failed: {}", e)))?;
224
225        if !output.status.success() {
226            let stderr = String::from_utf8_lossy(&output.stderr);
227            // Repo might already exist, which is fine
228            if !stderr.contains("already exists") {
229                return Err(ExportError::HubError(format!(
230                    "Failed to create repo: {}",
231                    stderr
232                )));
233            }
234        }
235
236        Ok(())
237    }
238
239    /// Create model card content
240    fn create_model_card(&self, engine: &SonaEngine, config: &ExportConfig) -> String {
241        let stats = engine.stats();
242        format!(
243            r#"---
244license: mit
245library_name: peft
246base_model: {}
247tags:
248  - sona
249  - lora
250  - adaptive-learning
251  - ruvector
252---
253
254# {} SONA Adapter
255
256This adapter was generated using [SONA (Self-Optimizing Neural Architecture)](https://github.com/ruvnet/ruvector/tree/main/crates/sona) - a runtime-adaptive learning system.
257
258## Model Details
259
260- **Base Model**: {}
261- **PEFT Type**: LoRA (Two-Tier)
262- **MicroLoRA Rank**: {} (instant adaptation)
263- **BaseLoRA Rank**: {} (background learning)
264- **Patterns Learned**: {}
265- **Trajectories Processed**: {}
266
267## SONA Features
268
269### Two-Tier LoRA Architecture
270- **MicroLoRA**: Rank 1-2 for instant adaptation (<0.5ms latency)
271- **BaseLoRA**: Rank 4-16 for background learning
272
273### EWC++ (Elastic Weight Consolidation)
274Prevents catastrophic forgetting when learning new patterns.
275
276### ReasoningBank
277K-means++ clustering for efficient pattern storage and retrieval.
278
279## Performance Benchmarks
280
281| Metric | Value |
282|--------|-------|
283| Throughput | 2211 ops/sec |
284| Latency | <0.5ms per layer |
285| Quality Improvement | +55% max |
286
287## Usage with PEFT
288
289```python
290from peft import PeftModel, PeftConfig
291from transformers import AutoModelForCausalLM
292
293# Load adapter
294config = PeftConfig.from_pretrained("your-username/{}")
295model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
296model = PeftModel.from_pretrained(model, "your-username/{}")
297
298# Use for inference
299outputs = model.generate(input_ids)
300```
301
302## Training with Included Datasets
303
304### Patterns Dataset
305```python
306from datasets import load_dataset
307
308patterns = load_dataset("json", data_files="patterns.jsonl")
309```
310
311### Preference Pairs (for DPO/RLHF)
312```python
313preferences = load_dataset("json", data_files="preferences.jsonl")
314```
315
316## License
317
318MIT License - see [LICENSE](LICENSE) for details.
319
320---
321
322Generated with [ruvector-sona](https://crates.io/crates/ruvector-sona) v{}
323"#,
324            config.target_architecture,
325            config.model_name,
326            config.target_architecture,
327            engine.config().micro_lora_rank,
328            engine.config().base_lora_rank,
329            stats.patterns_stored,
330            stats.trajectories_buffered,
331            config.model_name,
332            config.model_name,
333            env!("CARGO_PKG_VERSION"),
334        )
335    }
336
337    /// Create PEFT-compatible adapter config
338    fn create_adapter_config(
339        &self,
340        engine: &SonaEngine,
341        config: &ExportConfig,
342    ) -> AdapterConfigJson {
343        let sona_config = engine.config();
344        AdapterConfigJson {
345            peft_type: "LORA".to_string(),
346            auto_mapping: None,
347            base_model_name_or_path: config.target_architecture.clone(),
348            revision: None,
349            task_type: "CAUSAL_LM".to_string(),
350            inference_mode: true,
351            r: sona_config.base_lora_rank,
352            lora_alpha: sona_config.base_lora_rank as f32,
353            lora_dropout: 0.0,
354            fan_in_fan_out: false,
355            bias: "none".to_string(),
356            target_modules: vec![
357                "q_proj".to_string(),
358                "k_proj".to_string(),
359                "v_proj".to_string(),
360                "o_proj".to_string(),
361            ],
362            modules_to_save: None,
363            layers_to_transform: None,
364            layers_pattern: None,
365        }
366    }
367}
368
/// Request payload for the Hub `POST /repos/create` endpoint.
///
/// The `#[serde(...)]` helper attributes are gated behind `serde-support`
/// because the derives themselves are — bare helper attributes would fail to
/// compile when the feature is disabled.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
struct CreateRepoRequest {
    // Repo name without the org prefix.
    name: String,
    // Owning organization; omitted from the JSON for personal namespaces.
    #[cfg_attr(
        feature = "serde-support",
        serde(skip_serializing_if = "Option::is_none")
    )]
    organization: Option<String>,
    // Whether the new repo should be private.
    private: bool,
    // Hub repo kind; serialized as "type" (e.g. "model").
    #[cfg_attr(feature = "serde-support", serde(rename = "type"))]
    repo_type: String,
}
380
381/// PEFT adapter config for JSON export
382#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
383#[derive(Clone, Debug)]
384pub struct AdapterConfigJson {
385    pub peft_type: String,
386    #[serde(skip_serializing_if = "Option::is_none")]
387    pub auto_mapping: Option<serde_json::Value>,
388    pub base_model_name_or_path: String,
389    #[serde(skip_serializing_if = "Option::is_none")]
390    pub revision: Option<String>,
391    pub task_type: String,
392    pub inference_mode: bool,
393    pub r: usize,
394    pub lora_alpha: f32,
395    pub lora_dropout: f32,
396    pub fan_in_fan_out: bool,
397    pub bias: String,
398    pub target_modules: Vec<String>,
399    #[serde(skip_serializing_if = "Option::is_none")]
400    pub modules_to_save: Option<Vec<String>>,
401    #[serde(skip_serializing_if = "Option::is_none")]
402    pub layers_to_transform: Option<Vec<usize>>,
403    #[serde(skip_serializing_if = "Option::is_none")]
404    pub layers_pattern: Option<String>,
405}
406
407/// Simple UUID v4 generator
408fn uuid_v4() -> String {
409    use rand::Rng;
410    let mut rng = rand::thread_rng();
411    let bytes: [u8; 16] = rng.gen();
412    format!(
413        "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}",
414        bytes[0], bytes[1], bytes[2], bytes[3],
415        bytes[4], bytes[5],
416        (bytes[6] & 0x0f) | 0x40, bytes[7],
417        (bytes[8] & 0x3f) | 0x80, bytes[9],
418        bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]
419    )
420}
421
422/// Copy directory recursively
423fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<(), ExportError> {
424    if !dst.exists() {
425        std::fs::create_dir_all(dst).map_err(ExportError::Io)?;
426    }
427
428    for entry in std::fs::read_dir(src).map_err(ExportError::Io)? {
429        let entry = entry.map_err(ExportError::Io)?;
430        let path = entry.path();
431        let file_name = path.file_name().unwrap();
432        let dest_path = dst.join(file_name);
433
434        if path.is_dir() {
435            copy_dir_recursive(&path, &dest_path)?;
436        } else {
437            std::fs::copy(&path, &dest_path).map_err(ExportError::Io)?;
438        }
439    }
440
441    Ok(())
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hub_from_env() {
        // Just ensure construction from the environment doesn't panic.
        let _hub = HuggingFaceHub::from_env();
    }

    #[test]
    fn test_uuid_v4() {
        let uuid = uuid_v4();
        // RFC 4122 textual form: 36 chars with hyphens at fixed offsets.
        assert_eq!(uuid.len(), 36);
        for idx in [8, 13, 18, 23] {
            assert_eq!(uuid.as_bytes()[idx], b'-', "hyphen expected at {}", idx);
        }
        // Version nibble must be 4 and the variant bits must be 10xx.
        assert_eq!(uuid.as_bytes()[14], b'4');
        assert!(matches!(uuid.as_bytes()[19], b'8' | b'9' | b'a' | b'b'));
    }

    // Serialization needs the serde derives, which only exist when the
    // `serde-support` feature is enabled — without the gate this test
    // breaks the whole test build in a default-feature configuration.
    #[cfg(feature = "serde-support")]
    #[test]
    fn test_adapter_config_json() {
        let config = AdapterConfigJson {
            peft_type: "LORA".to_string(),
            auto_mapping: None,
            base_model_name_or_path: "microsoft/phi-4".to_string(),
            revision: None,
            task_type: "CAUSAL_LM".to_string(),
            inference_mode: true,
            r: 8,
            lora_alpha: 8.0,
            lora_dropout: 0.0,
            fan_in_fan_out: false,
            bias: "none".to_string(),
            target_modules: vec!["q_proj".to_string()],
            modules_to_save: None,
            layers_to_transform: None,
            layers_pattern: None,
        };

        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("LORA"));
        assert!(json.contains("phi-4"));
    }
}
485}