scirs2_text/huggingface_compat/hub.rs

use crate::error::{Result, TextError};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

#[cfg(feature = "serde-support")]
use serde::{Deserialize, Serialize};

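/// Minimal, offline-friendly client for the Hugging Face Hub. It resolves a
/// local cache directory, serves a built-in model catalog, and writes
/// placeholder files instead of performing real network transfers.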
#[derive(Debug)]
pub struct HfHub {
    /// Root directory for cached model files.
    cache_dir: PathBuf,
    /// Optional authentication token for operations that require it.
    token: Option<String>,
    /// In-memory cache of model metadata, keyed by model ID.
    model_cache: HashMap<String, HfModelInfo>,
}

impl HfHub {
    pub fn new() -> Self {
        // Resolve the cache directory from HF_HOME or HUGGINGFACE_HUB_CACHE,
        // falling back to ~/.cache/huggingface/hub.
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let mut home = std::env::var("HOME")
                    .map(PathBuf::from)
                    .unwrap_or_else(|_| PathBuf::from("."));
                home.push(".cache");
                home.push("huggingface");
                home.push("hub");
                home
            });

        Self {
            cache_dir,
            token: None,
            model_cache: HashMap::new(),
        }
    }

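    /// Sets the authentication token used by operations that require auth
    /// (e.g. `create_repo`).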
    pub fn with_token(mut self, token: String) -> Self {
        self.token = Some(token);
        self
    }

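    /// Overrides the cache directory chosen in `new()`.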
    pub fn with_cache_dir<P: AsRef<Path>>(mut self, cache_dir: P) -> Self {
        self.cache_dir = cache_dir.as_ref().to_path_buf();
        self
    }

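    /// Lists known model IDs, optionally filtered by a case-insensitive
    /// substring match. The catalog is built in; no network request is made.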
    pub fn list_models(&self, filter: Option<&str>) -> Result<Vec<String>> {
        // Static catalog of well-known model IDs; a real client would query
        // the Hub API instead.
        let models = vec![
            "bert-base-uncased",
            "bert-large-uncased",
            "distilbert-base-uncased",
            "roberta-base",
            "roberta-large",
            "gpt2",
            "gpt2-medium",
            "gpt2-large",
            "t5-small",
            "t5-base",
            "t5-large",
            "facebook/bart-base",
            "facebook/bart-large",
            "microsoft/DialoGPT-medium",
            "microsoft/DialoGPT-large",
        ];

        let filtered_models: Vec<String> = models
            .into_iter()
            .filter(|model| {
                filter.is_none_or(|f| model.to_lowercase().contains(&f.to_lowercase()))
            })
            .map(|s| s.to_string())
            .collect();

        Ok(filtered_models)
    }

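    /// Returns (and memoizes) metadata for `model_id`. The metadata is
    /// synthesized locally with fixed download/like counts rather than
    /// fetched from the Hub.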
    pub fn model_info(&mut self, model_id: &str) -> Result<HfModelInfo> {
        // Serve from the in-memory cache when possible.
        if let Some(info) = self.model_cache.get(model_id) {
            return Ok(info.clone());
        }

        // Synthesize plausible metadata locally (mock values).
        let info = HfModelInfo {
            model_id: model_id.to_string(),
            tags: vec!["pytorch".to_string(), "transformers".to_string()],
            pipeline_tag: Some(self.infer_pipeline_tag(model_id)),
            downloads: 1_000_000,
            likes: 500,
            library_name: Some("transformers".to_string()),
        };

        self.model_cache.insert(model_id.to_string(), info.clone());
        Ok(info)
    }

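    /// "Downloads" a model into the cache: creates the target directory and
    /// writes placeholder files (a BERT-style `config.json` plus stubs for
    /// weights and tokenizer data) for any file that does not already exist.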
    pub fn download_model<P: AsRef<Path>>(
        &self,
        model_id: &str,
        cache_dir: Option<P>,
    ) -> Result<PathBuf> {
        let download_dir = cache_dir
            .map(|p| p.as_ref().to_path_buf())
            .unwrap_or_else(|| self.cache_dir.join(model_id));

        std::fs::create_dir_all(&download_dir)
            .map_err(|e| TextError::IoError(format!("Failed to create download directory: {e}")))?;

        // Files a typical transformer checkpoint would contain.
        let files = [
            "config.json",
            "pytorch_model.bin",
            "tokenizer.json",
            "vocab.txt",
        ];

        for file in &files {
            let file_path = download_dir.join(file);
            if !file_path.exists() {
                // Write a BERT-style config; all other files get a one-line
                // placeholder.
                let content = if *file == "config.json" {
                    r#"{
    "architectures": ["BertModel"],
    "model_type": "bert",
    "num_attention_heads": 12,
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_hidden_layers": 12,
    "vocab_size": 30522,
    "max_position_embeddings": 512,
    "extraconfig": {}
}"#
                    .to_string()
                } else {
                    format!("# Placeholder {file} for {model_id}")
                };
                std::fs::write(&file_path, content)
                    .map_err(|e| TextError::IoError(format!("Failed to create {file}: {e}")))?;
            }
        }

        Ok(download_dir)
    }

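    /// Validates that `model_path` exists and contains the required files,
    /// then reports what would be uploaded. No network upload is performed.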
    pub fn upload_model<P: AsRef<Path>>(
        &self,
        model_path: P,
        repo_id: &str,
        commit_message: Option<&str>,
    ) -> Result<()> {
        let model_path = model_path.as_ref();

        if !model_path.exists() {
            return Err(TextError::InvalidInput(
                "Model path does not exist".to_string(),
            ));
        }

        let required_files = ["config.json"];
        for file in &required_files {
            if !model_path.join(file).exists() {
                return Err(TextError::InvalidInput(format!(
                    "Required file {file} not found"
                )));
            }
        }

        println!(
            "Would upload model from {} to {} with message: {}",
            model_path.display(),
            repo_id,
            commit_message.unwrap_or("Upload model")
        );

        Ok(())
    }

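    /// Checks that an auth token is set, then reports the repository that
    /// would be created. No network request is performed.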
    pub fn create_repo(&self, repo_id: &str, private: bool) -> Result<()> {
        if self.token.is_none() {
            return Err(TextError::InvalidInput(
                "Authentication token required".to_string(),
            ));
        }

        println!("Would create repository {repo_id} (private: {private})");

        Ok(())
    }

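    /// Returns the local cache path for `model_id`, whether or not it has
    /// been downloaded.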
    pub fn get_cached_model_path(&self, model_id: &str) -> PathBuf {
        self.cache_dir.join(model_id)
    }

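    /// Heuristically maps a model ID to a pipeline tag based on substrings
    /// of its (lowercased) name.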
    fn infer_pipeline_tag(&self, model_id: &str) -> String {
        // Match case-insensitively so IDs like "microsoft/DialoGPT-medium"
        // are recognized as GPT-style models.
        let id = model_id.to_lowercase();
        if id.contains("bert") || id.contains("roberta") {
            "text-classification".to_string()
        } else if id.contains("gpt") || id.contains("t5") {
            "text-generation".to_string()
        } else if id.contains("bart") {
            "summarization".to_string()
        } else {
            "feature-extraction".to_string()
        }
    }
}

impl Default for HfHub {
    fn default() -> Self {
        Self::new()
    }
}

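/// Metadata describing a model on the Hub.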
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
pub struct HfModelInfo {
    /// Repository ID, e.g. "bert-base-uncased".
    pub model_id: String,
    /// Free-form tags such as "pytorch" or "transformers".
    pub tags: Vec<String>,
    /// Primary pipeline the model is intended for, if known.
    pub pipeline_tag: Option<String>,
    /// Total download count.
    pub downloads: u64,
    /// Total like count.
    pub likes: u64,
    /// Library the model targets, e.g. "transformers".
    pub library_name: Option<String>,
}
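
// A minimal usage sketch as a test. Assumptions: `TextError` implements
// `Debug` (so the test can return `Result<()>`), and writing under
// `std::env::temp_dir()` is acceptable in the test environment.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn offline_hub_roundtrip() -> Result<()> {
        // Point the hub at a throwaway cache directory.
        let cache = std::env::temp_dir().join("scirs2_hfhub_doc_test");
        let mut hub = HfHub::new().with_cache_dir(&cache);

        // The built-in catalog filters case-insensitively.
        let bert_models = hub.list_models(Some("bert"))?;
        assert!(!bert_models.is_empty());

        // Metadata is synthesized locally and memoized.
        let info = hub.model_info("bert-base-uncased")?;
        assert_eq!(info.pipeline_tag.as_deref(), Some("text-classification"));

        // "Downloading" writes placeholder files into the cache.
        let dir = hub.download_model::<&Path>("bert-base-uncased", None)?;
        assert!(dir.join("config.json").exists());

        Ok(())
    }
}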