1#![cfg_attr(not(feature = "tensorflow"), allow(unexpected_cfgs))]
7use serde::Deserialize;
86use std::fs::File;
87use std::io::Read;
88use std::path::{Path, PathBuf};
89use torsh_core::error::{Result, TorshError};
90
91pub mod access_control;
92pub mod analytics;
93pub mod bandwidth;
94pub mod cache;
95pub mod cli;
96pub mod community;
97pub mod debugging;
98pub mod download;
99pub mod enterprise;
100pub mod export;
101pub mod fine_tuning;
102pub mod huggingface;
103pub mod metadata;
104pub mod model_info;
105pub mod model_ops;
106pub mod models;
107pub mod onnx;
108pub mod profiling;
109pub mod quantization;
110pub mod registry;
111pub mod retry;
112pub mod security;
113#[cfg(feature = "tensorflow")]
114pub mod tensorflow;
115pub mod upload;
116pub mod utils;
117pub mod visualization;
118
119pub use access_control::{
121 AccessToken, PermissionChecker, RateLimit, TokenManager, TokenScope, TokenStats,
122};
123pub use analytics::{
124 ABTestingFramework, AnalyticsManager, AnalyticsReport, ExportFormat, ModelUsageStats,
125 PerformanceProfiler, RealTimeMetrics, RecommendationEngine, UserAnalytics,
126};
127pub use bandwidth::{
128 format_bytes, format_duration, AdaptiveBandwidthLimiter, BandwidthLimiter, BandwidthMonitor,
129 BandwidthStats,
130};
131pub use cache::{
132 compress_file, decompress_file, CacheCleanupResult, CacheManager, CacheStats,
133 CacheValidationResult, CompressionResult, FileCompressionStats,
134};
135pub use cli::{run_cli, Cli, CliApp, CliConfig, Commands};
136pub use community::{
137 Badge, Challenge, ChallengeId, ChallengeParticipant, ChallengeStatus, ChallengeSubmission,
138 ChallengeType, Comment, CommunityManager, Contribution, ContributionStatus, ContributionType,
139 Discussion, DiscussionCategory, DiscussionId, DiscussionStatus, EvaluationCriteria, MetricType,
140 ModelId, ModelRating, ModelRatingStats, RatingCategory, UserId, UserProfile,
141};
142pub use debugging::{
143 ActivationAnalyzer, ActivationPattern, Anomaly, AnomalyType, DebugAction, DebugCommand,
144 DebugConfig, DebugHook, DebugReport, DebugSession, GradientDebugger, GradientInfo, HookType,
145 InteractiveDebugState, ModelDebugger, Severity, TensorInspector, TensorSnapshot,
146 TensorStatistics, TriggerCondition,
147};
148pub use download::{
149 create_regional_cdn_config, create_regional_mirror_config, download_file_parallel,
150 download_file_streaming,
151 download_files_parallel, download_with_default_cdn, download_with_default_mirrors, validate_url, validate_urls,
153 CdnConfig, CdnEndpoint, CdnManager, CdnStatistics, FailoverStrategy,
154 MirrorAttempt, MirrorBenchmarkResult, MirrorCapacity,
155 MirrorConfig, MirrorDownloadResult, MirrorLocation, MirrorManager, MirrorSelectionStrategy,
156 MirrorServer, MirrorStatistics, MirrorWeights, ParallelDownloadConfig,
157};
158pub use enterprise::{
159 Action, AuditAction, AuditLogEntry, ComplianceLabel, ComplianceReport, DataClassification,
160 EnterpriseManager, OrganizationId, Permission, PermissionId, PermissionScope,
161 PrivateRepository, RepositoryAccessControl, RepositoryVisibility, ResourceType, Role, RoleId,
162 ServiceLevelAgreement, ServiceTier, SlaPerformanceReport, UserRoleAssignment,
163};
164pub use fine_tuning::{
165 CheckpointManager, EarlyStoppingConfig, FineTuner, FineTuningConfig, FineTuningFactory,
166 FineTuningStrategy, TrainingHistory, TrainingMetrics,
167};
168pub use huggingface::{
169 HfModelConfig, HfModelInfo, HfSearchParams, HfToTorshConverter, HuggingFaceHub,
170};
171pub use metadata::{
172 ExtendedMetadata, FileMetadata, MetadataManager, MetadataSearchCriteria, PerformanceMetrics,
173 QualityScores, UsageStatistics,
174};
175pub use model_info::{
176 ModelCard, ModelCardBuilder, ModelCardManager, ModelCardRenderer, ModelInfo, Version,
177 VersionHistory,
178};
179pub use model_ops::{
180 compare_models, create_model_ensemble, load_model_auto, ComparisonOptions, ConversionMetadata,
181 EnsembleConfig, ModelDiff, QuantizationStats, ShapeDifference, ValueDifference, VotingStrategy,
182};
183pub use models::{
184 audio, multimodal, nlp, nlp_pretrained, rl, vision, vision_pretrained, ActorCritic, BasicBlock,
185 BertEmbeddings, BertEncoder, EfficientNet, GPTDecoder, GPTEmbeddings, MultiHeadAttention,
186 PPOAgent, ResNet, TransformerBlock, VisionTransformer, DQN,
187};
188pub use onnx::{
189 InputShape, OnnxConfig, OnnxLoader, OnnxModel, OnnxModelMetadata, OnnxToTorshWrapper,
190 OutputShape,
191};
192pub use profiling::{
193 ExecutionContext, ExecutionMode, LayerProfile, MemoryAnalysis, MemorySnapshot, ModelProfiler,
194 OperationAnalysis, OperationTrace, OptimizationRecommendation, PerformanceBottleneck,
195 PerformanceCounters, PerformanceSummary, ProfilerConfig, ProfilingResult, ProfilingSession,
196 ResourceUtilizationSummary, TensorInfo,
197};
198pub use registry::{
199 HardwareFilter, ModelCategory, ModelRegistry, ModelStatus, RegistryAPI, RegistryEntry,
200 SearchQuery,
201};
202pub use retry::{
203 retry_with_backoff, retry_with_backoff_async, retry_with_policy, CircuitBreaker, CircuitState,
204 DefaultRetryPolicy, RetryConfig, RetryPolicy, RetryStats,
205};
206pub use security::{
207 calculate_file_hash, sandbox_model, scan_model_vulnerabilities, validate_model_source,
208 validate_signature_age, verify_file_integrity, KeyPair, ModelSandbox, ModelSignature,
209 ResourceUsage, RiskLevel, SandboxConfig, SandboxedModel, ScanMetadata, SecurityConfig,
210 SecurityManager, Severity as SecuritySeverity, SignatureAlgorithm, Vulnerability,
211 VulnerabilityScanResult, VulnerabilityScanner, VulnerabilityType,
212};
213#[cfg(feature = "tensorflow")]
214pub use tensorflow::{
215 TfConfig, TfLoader, TfModel, TfModelMetadata, TfModelType, TfTensorInfo, TfToTorshWrapper,
216};
217pub use upload::{
218 batch_publish_models, upload_model, upload_model_with_versioning, validate_version_change,
219 PublishResult, PublishStrategy, UploadConfig, VersionChangeInfo, VersionValidationRules,
220};
221pub use utils::{
222 cleanup_old_cache, compare_versions, estimate_parameters_from_size, extract_extension,
223 format_parameter_count, format_size, get_model_cache_dir, get_temp_dir, is_safe_path,
224 is_supported_model_format, parse_repo_string, sanitize_model_name, validate_semver,
225};
226pub use visualization::{
227 ChartData, ChartType, DashboardTemplate, PerformanceVisualization, TrainingVisualization,
228 UsageVisualization, VisualizationConfig, VisualizationEngine,
229};
230
231use torsh_nn::prelude::*;
233
/// Configuration for hub operations: caching, downloads, and authentication.
///
/// Construct via `HubConfig::default()` and override fields as needed.
#[derive(Debug, Clone)]
pub struct HubConfig {
    /// Directory where downloaded repositories and weights are cached.
    pub cache_dir: PathBuf,
    /// Base URL of the model hub (defaults to GitHub).
    pub hub_url: String,
    /// When true, re-download repositories even if a cached copy exists.
    pub force_reload: bool,
    /// When true, print progress/status messages to stdout.
    pub verbose: bool,
    /// When true, skip validation steps during loading.
    pub skip_validation: bool,
    /// Optional auth token (env var `TORSH_HUB_TOKEN` or token file).
    pub auth_token: Option<String>,
    /// Network timeout for downloads, in seconds.
    pub timeout_seconds: u64,
    /// Maximum number of retry attempts for failed downloads.
    pub max_retries: u32,
    /// User-Agent header sent with HTTP requests.
    pub user_agent: String,
}
247
/// Defaults: platform cache dir (falling back to "."), GitHub as the hub,
/// verbose output, a 300-second timeout, and up to 3 retries.
impl Default for HubConfig {
    fn default() -> Self {
        // Prefer the platform cache directory; fall back to the current
        // working directory when none is available.
        let cache_dir = dirs::cache_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join("torsh")
            .join("hub");

        Self {
            cache_dir,
            hub_url: "https://github.com".to_string(),
            force_reload: false,
            verbose: true,
            skip_validation: false,
            // Token resolution order: env var first, then config-dir file.
            auth_token: load_auth_token_from_env_or_file(),
            timeout_seconds: 300,
            max_retries: 3,
            user_agent: format!(
                "torsh-hub/{}",
                // Crate version baked in at compile time; hard-coded
                // fallback covers builds outside Cargo.
                option_env!("CARGO_PKG_VERSION").unwrap_or("0.1.0-alpha.2")
            ),
        }
    }
}
271
272fn load_auth_token_from_env_or_file() -> Option<String> {
274 if let Ok(token) = std::env::var("TORSH_HUB_TOKEN") {
276 return Some(token);
277 }
278
279 if let Some(config_dir) = dirs::config_dir() {
281 let token_file = config_dir.join("torsh").join("hub_token");
282 if token_file.exists() {
283 if let Ok(token) = std::fs::read_to_string(&token_file) {
284 return Some(token.trim().to_string());
285 }
286 }
287 }
288
289 None
290}
291
292pub fn set_auth_token(token: &str) -> Result<()> {
294 std::env::set_var("TORSH_HUB_TOKEN", token);
295
296 if let Some(config_dir) = dirs::config_dir() {
298 let torsh_config_dir = config_dir.join("torsh");
299 std::fs::create_dir_all(&torsh_config_dir)?;
300
301 let token_file = torsh_config_dir.join("hub_token");
302 std::fs::write(&token_file, token)?;
303
304 #[cfg(unix)]
306 {
307 use std::os::unix::fs::PermissionsExt;
308 let mut perms = std::fs::metadata(&token_file)?.permissions();
309 perms.set_mode(0o600); std::fs::set_permissions(&token_file, perms)?;
311 }
312 }
313
314 Ok(())
315}
316
317pub fn remove_auth_token() -> Result<()> {
319 std::env::remove_var("TORSH_HUB_TOKEN");
320
321 if let Some(config_dir) = dirs::config_dir() {
323 let token_file = config_dir.join("torsh").join("hub_token");
324 if token_file.exists() {
325 std::fs::remove_file(&token_file)?;
326 }
327 }
328
329 Ok(())
330}
331
/// Report whether an auth token is currently available (env var or file).
pub fn is_authenticated() -> bool {
    load_auth_token_from_env_or_file().is_some()
}
336
337pub fn auth_status() -> String {
339 if let Some(token) = load_auth_token_from_env_or_file() {
340 let visible_part = if token.len() > 8 {
342 format!("{}***", &token[..4])
343 } else {
344 "***".to_string()
345 };
346 format!("Authenticated with token: {}", visible_part)
347 } else {
348 "Not authenticated".to_string()
349 }
350}
351
/// Load an ONNX model from a local file and wrap it as a torsh `Module`.
///
/// # Errors
/// Returns an error if the file cannot be read or parsed as ONNX.
pub fn load_onnx_model<P: AsRef<Path>>(
    path: P,
    config: Option<crate::onnx::OnnxConfig>,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::onnx::{OnnxModel, OnnxToTorshWrapper};

    let onnx_model = OnnxModel::from_file(path, config)?;
    let wrapper = OnnxToTorshWrapper::new(onnx_model);
    Ok(Box::new(wrapper))
}
377
/// Load an ONNX model from an in-memory byte buffer and wrap it as a torsh
/// `Module`.
///
/// # Errors
/// Returns an error if the bytes cannot be parsed as ONNX.
pub fn load_onnx_model_from_bytes(
    model_bytes: &[u8],
    config: Option<crate::onnx::OnnxConfig>,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::onnx::{OnnxModel, OnnxToTorshWrapper};

    let onnx_model = OnnxModel::from_bytes(model_bytes, config)?;
    let wrapper = OnnxToTorshWrapper::new(onnx_model);
    Ok(Box::new(wrapper))
}
404
/// Download an ONNX model from `url` and wrap it as a torsh `Module`.
///
/// # Errors
/// Returns an error if the download fails or the payload is not valid ONNX.
pub async fn load_onnx_model_from_url(
    url: &str,
    config: Option<crate::onnx::OnnxConfig>,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::onnx::{OnnxLoader, OnnxToTorshWrapper};

    let onnx_model = OnnxLoader::from_url(url, config).await?;
    let wrapper = OnnxToTorshWrapper::new(onnx_model);
    Ok(Box::new(wrapper))
}
431
/// Fetch an ONNX model named `model_name` from hub repository `repo` and
/// wrap it as a torsh `Module`.
///
/// # Errors
/// Returns an error if the repository or model cannot be fetched, or the
/// payload is not valid ONNX.
pub fn load_onnx_model_from_hub(
    repo: &str,
    model_name: &str,
    config: Option<crate::onnx::OnnxConfig>,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::onnx::{OnnxLoader, OnnxToTorshWrapper};

    let onnx_model = OnnxLoader::from_hub(repo, model_name, config)?;
    let wrapper = OnnxToTorshWrapper::new(onnx_model);
    Ok(Box::new(wrapper))
}
459
460pub fn validate_auth_token(token: &str) -> Result<bool> {
462 if token.is_empty() {
463 return Ok(false);
464 }
465
466 if token.len() < 8 {
468 return Err(TorshError::InvalidArgument(
469 "Authentication token is too short".to_string(),
470 ));
471 }
472
473 if !token
476 .chars()
477 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
478 {
479 return Err(TorshError::InvalidArgument(
480 "Authentication token contains invalid characters".to_string(),
481 ));
482 }
483
484 Ok(true)
485}
486
/// Load a model from a hub repository (torch.hub-style entry point).
///
/// `repo` is either `owner/repo` or a full GitHub URL; `model` names an
/// entry in the repository's `models.toml`/`hubconf.toml`. When `pretrained`
/// is true, the model's weights are downloaded and applied.
///
/// # Errors
/// Fails when the repo string is malformed, the repository or weights
/// cannot be downloaded, or `model` is not declared in the repository.
pub fn load(
    repo: &str,
    model: &str,
    pretrained: bool,
    config: Option<HubConfig>,
) -> Result<Box<dyn torsh_nn::Module>> {
    let config = config.unwrap_or_default();

    let (owner, repo_name, branch) = parse_repo_info(repo)?;

    let repo_dir = download_repo(&owner, &repo_name, &branch, &config)?;

    // Factory closure resolved from the repo's config file.
    let model_fn = load_model_fn(&repo_dir, model)?;

    let model = model_fn(pretrained)?;

    Ok(model)
}
527
/// List the model names available in a hub repository.
///
/// Downloads (or reuses the cached copy of) the repository and reads its
/// config file.
///
/// # Errors
/// Fails when the repo string is malformed or the repository cannot be
/// downloaded or parsed.
pub fn list(repo: &str, config: Option<HubConfig>) -> Result<Vec<String>> {
    let config = config.unwrap_or_default();

    let (owner, repo_name, branch) = parse_repo_info(repo)?;

    let repo_dir = download_repo(&owner, &repo_name, &branch, &config)?;

    let models = list_available_models(&repo_dir)?;

    Ok(models)
}
543
/// Return the documentation string for `model` in a hub repository
/// (a `docs/<model>.md` file when present, otherwise text generated from
/// the repository config).
///
/// # Errors
/// Fails when the repo string is malformed or the repository cannot be
/// downloaded or parsed.
pub fn help(repo: &str, model: &str, config: Option<HubConfig>) -> Result<String> {
    let config = config.unwrap_or_default();

    let (owner, repo_name, branch) = parse_repo_info(repo)?;

    let repo_dir = download_repo(&owner, &repo_name, &branch, &config)?;

    let doc = get_model_doc(&repo_dir, model)?;

    Ok(doc)
}
559
560pub fn set_dir(path: impl AsRef<Path>) -> Result<()> {
562 let path = path.as_ref();
563 std::fs::create_dir_all(path)?;
564
565 std::env::set_var("TORSH_HUB_DIR", path);
567
568 Ok(())
569}
570
571pub fn get_dir() -> PathBuf {
573 std::env::var("TORSH_HUB_DIR")
574 .map(PathBuf::from)
575 .unwrap_or_else(|_| HubConfig::default().cache_dir)
576}
577
/// Download a weights file and load it as a state dict (torch.hub-style).
///
/// `model_dir` defaults to the hub cache directory; `map_location`
/// optionally selects the target device; `progress` toggles download
/// progress output. An already-downloaded file is reused.
///
/// # Errors
/// Fails when the download fails or the file cannot be parsed.
pub fn load_state_dict_from_url(
    url: &str,
    model_dir: Option<&Path>,
    map_location: Option<torsh_core::DeviceType>,
    progress: bool,
) -> Result<StateDict> {
    let default_dir = get_dir();
    let model_dir = model_dir.unwrap_or(&default_dir);

    let file_path = download_url_to_file(url, model_dir, progress)?;

    let state_dict = load_state_dict(&file_path, map_location)?;

    Ok(state_dict)
}
596
597pub type StateDict = std::collections::HashMap<String, torsh_tensor::Tensor<f32>>;
599
600pub fn parse_repo_info(repo: &str) -> Result<(String, String, String)> {
602 if repo.starts_with("https://") || repo.starts_with("http://") {
603 parse_github_url(repo)
605 } else if repo.contains('/') {
606 let parts: Vec<&str> = repo.split('/').collect();
608 if parts.len() != 2 {
609 return Err(TorshError::InvalidArgument(
610 "Repository should be in format 'owner/repo'".to_string(),
611 ));
612 }
613 Ok((
614 parts[0].to_string(),
615 parts[1].to_string(),
616 "main".to_string(),
617 ))
618 } else {
619 Err(TorshError::InvalidArgument(
620 "Invalid repository format".to_string(),
621 ))
622 }
623}
624
625fn parse_github_url(url: &str) -> Result<(String, String, String)> {
627 let url = url.trim_end_matches('/');
629 let parts: Vec<&str> = url.split('/').collect();
630
631 if parts.len() < 5 || parts[2] != "github.com" {
632 return Err(TorshError::InvalidArgument(
633 "Invalid GitHub URL".to_string(),
634 ));
635 }
636
637 let owner = parts[3].to_string();
638 let repo = parts[4].to_string();
639 let branch = if parts.len() >= 7 && parts[5] == "tree" {
640 parts[6].to_string()
641 } else {
642 "main".to_string()
643 };
644
645 Ok((owner, repo, branch))
646}
647
/// Ensure a GitHub repository is present in the local cache, downloading it
/// when missing or when `config.force_reload` is set. Returns the cached
/// repository directory.
///
/// # Errors
/// Fails when the cache cannot be initialized or the download fails.
pub fn download_repo(owner: &str, repo: &str, branch: &str, config: &HubConfig) -> Result<PathBuf> {
    let cache_manager = CacheManager::new(&config.cache_dir)?;
    let repo_dir = cache_manager.get_repo_dir(owner, repo, branch);

    // Cache hit: reuse the existing checkout unless a reload was forced.
    if repo_dir.exists() && !config.force_reload {
        if config.verbose {
            println!("Using cached repository at: {:?}", repo_dir);
        }
        return Ok(repo_dir);
    }

    download::download_github_repo(owner, repo, branch, &repo_dir, config.verbose)?;

    Ok(repo_dir)
}
665
/// Factory closure produced by `load_model_fn`: `pretrained -> model`.
type ModelFactoryFn = Box<dyn Fn(bool) -> Result<Box<dyn torsh_nn::Module>>>;

/// Look up `model` in the repository's config file (`models.toml` preferred
/// over `hubconf.toml`) and return a factory that builds it on demand.
///
/// # Errors
/// Fails when no config file exists, the config cannot be parsed, or the
/// model is not declared.
fn load_model_fn(repo_dir: &Path, model: &str) -> Result<ModelFactoryFn> {
    let models_toml = repo_dir.join("models.toml");
    let hubconf_toml = repo_dir.join("hubconf.toml");

    let config_path = if models_toml.exists() {
        models_toml
    } else if hubconf_toml.exists() {
        hubconf_toml
    } else {
        return Err(TorshError::IoError(
            "No model configuration file found (models.toml or hubconf.toml)".to_string(),
        ));
    };

    let config_content = std::fs::read_to_string(config_path)?;
    let config: ModelConfig = toml::from_str(&config_content)
        .map_err(|e| TorshError::ConfigError(format!("Failed to parse model config: {}", e)))?;

    let model_def = config
        .models
        .iter()
        .find(|m| m.name == model)
        .ok_or_else(|| {
            TorshError::InvalidArgument(format!("Model '{}' not found in repository", model))
        })?;

    // The closure must own its data: clone the definition and path so the
    // factory outlives this function's borrows.
    let model_def = model_def.clone();
    let repo_dir = repo_dir.to_path_buf();

    Ok(Box::new(move |pretrained: bool| {
        create_model_from_config(&model_def, &repo_dir, pretrained)
    }))
}
708
/// Top-level schema of a repository's `models.toml` / `hubconf.toml`.
#[derive(Debug, Deserialize, Clone)]
struct ModelConfig {
    // All model entries declared by the repository.
    models: Vec<ModelDefinition>,
}

/// One `[[models]]` entry in the repository config.
#[derive(Debug, Deserialize, Clone)]
struct ModelDefinition {
    // Public name used to look the model up.
    name: String,
    // Architecture key dispatched on by `create_model_from_config`.
    architecture: String,
    // Optional human-readable description (shown in generated docs).
    description: Option<String>,
    // Architecture-specific parameters (sizes, flags, file paths, ...).
    parameters: std::collections::HashMap<String, toml::Value>,
    // Remote URL for pretrained weights, downloaded on demand.
    weights_url: Option<String>,
    // Repo-relative path to pretrained weights; preferred over the URL.
    local_weights: Option<String>,
}
725
/// Dispatch on the `architecture` field to build the requested model.
///
/// Supported: "linear", "conv2d", "mlp", "resnet", "custom", "onnx", and
/// (with the `tensorflow` feature) "tensorflow"/"tf".
fn create_model_from_config(
    model_def: &ModelDefinition,
    repo_dir: &Path,
    pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    match model_def.architecture.as_str() {
        "linear" => create_linear_model(model_def, repo_dir, pretrained),
        "conv2d" => create_conv2d_model(model_def, repo_dir, pretrained),
        "mlp" => create_mlp_model(model_def, repo_dir, pretrained),
        "resnet" => create_resnet_model(model_def, repo_dir, pretrained),
        "custom" => create_custom_model(model_def, repo_dir, pretrained),
        "onnx" => create_onnx_model(model_def, repo_dir, pretrained),
        #[cfg(feature = "tensorflow")]
        "tensorflow" | "tf" => create_tensorflow_model(model_def, repo_dir, pretrained),
        // Without the feature, give a targeted error instead of the generic
        // "unsupported architecture" message.
        #[cfg(not(feature = "tensorflow"))]
        "tensorflow" | "tf" => Err(TorshError::Other(
            "TensorFlow support is disabled. Enable the 'tensorflow' feature to use TensorFlow models".to_string(),
        )),
        _ => Err(TorshError::InvalidArgument(format!(
            "Unsupported model architecture: {}",
            model_def.architecture
        ))),
    }
}
751
/// Build a single `Linear` layer from TOML parameters.
///
/// Required: `in_features`, `out_features`. Optional: `bias` (default true).
/// When `pretrained` is set, weights are resolved via
/// `load_pretrained_weights`.
fn create_linear_model(
    model_def: &ModelDefinition,
    repo_dir: &Path,
    pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    let in_features = extract_param_i64(&model_def.parameters, "in_features")? as usize;
    let out_features = extract_param_i64(&model_def.parameters, "out_features")? as usize;
    let bias = extract_param_bool(&model_def.parameters, "bias").unwrap_or(true);

    let mut model = Linear::new(in_features, out_features, bias);

    if pretrained {
        load_pretrained_weights(&mut model as &mut dyn torsh_nn::Module, model_def, repo_dir)?;
    }

    Ok(Box::new(model))
}
770
771fn create_conv2d_model(
773 model_def: &ModelDefinition,
774 repo_dir: &Path,
775 pretrained: bool,
776) -> Result<Box<dyn torsh_nn::Module>> {
777 let in_channels = extract_param_i64(&model_def.parameters, "in_channels")? as usize;
778 let out_channels = extract_param_i64(&model_def.parameters, "out_channels")? as usize;
779 let kernel_size = extract_param_i64(&model_def.parameters, "kernel_size")? as usize;
780 let stride = extract_param_i64(&model_def.parameters, "stride").unwrap_or(1) as usize;
781 let padding = extract_param_i64(&model_def.parameters, "padding").unwrap_or(0) as usize;
782 let bias = extract_param_bool(&model_def.parameters, "bias").unwrap_or(true);
783
784 let mut model = Conv2d::new(
785 in_channels,
786 out_channels,
787 (kernel_size, kernel_size),
788 (stride, stride),
789 (padding, padding),
790 (1, 1), bias,
792 1, );
794
795 if pretrained {
796 load_pretrained_weights(&mut model as &mut dyn torsh_nn::Module, model_def, repo_dir)?;
797 }
798
799 Ok(Box::new(model))
800}
801
802fn create_mlp_model(
804 model_def: &ModelDefinition,
805 repo_dir: &Path,
806 pretrained: bool,
807) -> Result<Box<dyn torsh_nn::Module>> {
808 let layers = extract_param_array(&model_def.parameters, "layers").ok_or_else(|| {
809 TorshError::InvalidArgument("Missing 'layers' parameter for MLP".to_string())
810 })?;
811 let activation = extract_param_string(&model_def.parameters, "activation")
812 .unwrap_or_else(|| "relu".to_string());
813 let dropout = extract_param_f64(&model_def.parameters, "dropout").unwrap_or(0.0);
814
815 let mut sequential = Sequential::new();
816
817 for i in 0..layers.len() - 1 {
818 let in_features = layers[i];
819 let out_features = layers[i + 1];
820
821 sequential = sequential.add(Linear::new(in_features, out_features, true));
822
823 if i < layers.len() - 2 {
824 match activation.as_str() {
826 "relu" => sequential = sequential.add(ReLU::new()),
827 "tanh" => sequential = sequential.add(Tanh::new()),
828 "sigmoid" => sequential = sequential.add(Sigmoid::new()),
829 _ => {
830 return Err(TorshError::InvalidArgument(format!(
831 "Unsupported activation: {}",
832 activation
833 )))
834 }
835 }
836
837 if dropout > 0.0 {
839 sequential = sequential.add(Dropout::new(dropout as f32));
840 }
841 }
842 }
843
844 let mut model = sequential;
845
846 if pretrained {
847 load_pretrained_weights(&mut model as &mut dyn torsh_nn::Module, model_def, repo_dir)?;
848 }
849
850 Ok(Box::new(model))
851}
852
/// Build a ResNet-style convolutional classifier from TOML parameters.
///
/// Optional: `num_classes` (default 1000), `layers` (blocks per stage,
/// default `[2, 2, 2, 2]`).
///
/// NOTE(review): this emits a plain conv/bn/relu stack — there are no
/// residual (skip) connections, so it is not a true ResNet. Also, every
/// block in stages after the first uses stride 2 (not just the stage's
/// first block), which downsamples very aggressively. Confirm both are
/// intended.
fn create_resnet_model(
    model_def: &ModelDefinition,
    repo_dir: &Path,
    pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    let num_classes =
        extract_param_i64(&model_def.parameters, "num_classes").unwrap_or(1000) as usize;
    let layers =
        extract_param_array(&model_def.parameters, "layers").unwrap_or_else(|| vec![2, 2, 2, 2]);

    let mut model = Sequential::new();

    // Stem: 7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max pool.
    model = model.add(Conv2d::new(3, 64, (7, 7), (2, 2), (3, 3), (1, 1), false, 1));
    model = model.add(BatchNorm2d::new(64).expect("Failed to create BatchNorm2d"));
    model = model.add(ReLU::new());
    model = model.add(MaxPool2d::new((3, 3), Some((2, 2)), (1, 1), (1, 1), false));

    let mut in_channels = 64;
    let mut out_channels = 64;

    // Stages: channel count doubles at each stage after the first.
    for (layer_idx, &num_blocks) in layers.iter().enumerate() {
        if layer_idx > 0 {
            out_channels *= 2;
        }

        for _block_idx in 0..num_blocks {
            let stride = if layer_idx > 0 { 2 } else { 1 };

            model = model.add(Conv2d::new(
                in_channels,
                out_channels,
                (3, 3),
                (stride, stride),
                (1, 1),
                (1, 1),
                false,
                1,
            ));
            model =
                model.add(BatchNorm2d::new(out_channels).expect("Failed to create BatchNorm2d"));
            model = model.add(ReLU::new());

            in_channels = out_channels;
        }
    }

    // Head: global average pool, flatten, linear classifier.
    model = model.add(AdaptiveAvgPool2d::with_output_size(1));
    model = model.add(Flatten::new());
    model = model.add(Linear::new(out_channels, num_classes, true));

    let mut model = model;

    if pretrained {
        load_pretrained_weights(&mut model as &mut dyn torsh_nn::Module, model_def, repo_dir)?;
    }

    Ok(Box::new(model))
}
918
/// Placeholder for user-defined architectures.
///
/// Always errors: arbitrary custom models would require compiling Rust code
/// shipped in the repository, which is not supported.
fn create_custom_model(
    _model_def: &ModelDefinition,
    _repo_dir: &Path,
    _pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    Err(TorshError::Other(
        "Custom model loading requires compilation of Rust code. Use predefined architectures instead.".to_string(),
    ))
}
929
/// Load an ONNX model shipped inside the repository.
///
/// The file is `parameters.model_file` (relative to the repo) when given,
/// otherwise `<name>.onnx`. `pretrained` is ignored: ONNX files already
/// contain their weights.
fn create_onnx_model(
    model_def: &ModelDefinition,
    repo_dir: &Path,
    _pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::onnx::{OnnxModel, OnnxToTorshWrapper};

    let model_file =
        if let Some(local_path) = extract_param_string(&model_def.parameters, "model_file") {
            repo_dir.join(local_path)
        } else {
            repo_dir.join(format!("{}.onnx", model_def.name))
        };

    if !model_file.exists() {
        return Err(TorshError::IoError(format!(
            "ONNX model file not found: {:?}",
            model_file
        )));
    }

    // Session options (threads, optimization level, ...) come from the
    // same TOML parameter table.
    let config = create_onnx_config_from_params(&model_def.parameters);

    let onnx_model = OnnxModel::from_file(&model_file, Some(config))?;

    let wrapper = OnnxToTorshWrapper::new(onnx_model);

    Ok(Box::new(wrapper))
}
965
/// Translate TOML parameters into an `OnnxConfig`.
///
/// Recognized keys: `execution_providers` (currently only logged),
/// `optimization_level` ("disable"/"basic"/"extended"/"all"; unknown values
/// fall back to "all"), `inter_op_threads`, `intra_op_threads`,
/// `enable_profiling`, `enable_mem_pattern`, `enable_cpu_mem_arena`.
/// Absent keys leave the defaults untouched.
fn create_onnx_config_from_params(
    params: &std::collections::HashMap<String, toml::Value>,
) -> crate::onnx::OnnxConfig {
    use crate::onnx::OnnxConfig;
    use ort::session::builder::GraphOptimizationLevel;

    let mut config = OnnxConfig::default();

    // NOTE(review): providers are only printed, never applied to the
    // config — confirm whether execution-provider wiring is still pending.
    if let Some(providers) = extract_param_array_strings(params, "execution_providers") {
        println!("Execution providers configured: {:?}", providers);
    }

    if let Some(opt_level) = extract_param_string(params, "optimization_level") {
        config.graph_optimization_level = match opt_level.as_str() {
            "disable" => GraphOptimizationLevel::Disable,
            "basic" => GraphOptimizationLevel::Level1,
            "extended" => GraphOptimizationLevel::Level2,
            "all" => GraphOptimizationLevel::Level3,
            // Unrecognized values silently get full optimization.
            _ => GraphOptimizationLevel::Level3,
        };
    }

    if let Ok(inter_threads) = extract_param_i64(params, "inter_op_threads") {
        config.inter_op_num_threads = Some(inter_threads as usize);
    }

    if let Ok(intra_threads) = extract_param_i64(params, "intra_op_threads") {
        config.intra_op_num_threads = Some(intra_threads as usize);
    }

    if let Some(enable_profiling) = extract_param_bool(params, "enable_profiling") {
        config.enable_profiling = enable_profiling;
    }

    if let Some(enable_mem_pattern) = extract_param_bool(params, "enable_mem_pattern") {
        config.enable_mem_pattern = enable_mem_pattern;
    }

    if let Some(enable_cpu_mem_arena) = extract_param_bool(params, "enable_cpu_mem_arena") {
        config.enable_cpu_mem_arena = enable_cpu_mem_arena;
    }

    config
}
1017
/// Load a TensorFlow SavedModel shipped inside the repository.
///
/// The model directory is `parameters.model_dir` (relative to the repo)
/// when given, otherwise a directory named after the model. `tags` selects
/// the SavedModel meta-graph (default `["serve"]`). `pretrained` is
/// ignored: SavedModels already contain their weights.
#[cfg(feature = "tensorflow")]
fn create_tensorflow_model(
    model_def: &ModelDefinition,
    repo_dir: &Path,
    _pretrained: bool,
) -> Result<Box<dyn torsh_nn::Module>> {
    use crate::tensorflow::{TfConfig, TfModel, TfToTorshWrapper};

    let model_dir =
        if let Some(local_path) = extract_param_string(&model_def.parameters, "model_dir") {
            repo_dir.join(local_path)
        } else {
            repo_dir.join(&model_def.name)
        };

    if !model_dir.exists() {
        return Err(TorshError::IoError(format!(
            "TensorFlow model directory not found: {:?}",
            model_dir
        )));
    }

    let config = create_tf_config_from_params(&model_def.parameters);

    let tags = if let Some(tags_array) = extract_param_array_strings(&model_def.parameters, "tags")
    {
        tags_array.into_iter().collect::<Vec<_>>()
    } else {
        vec!["serve".to_string()]
    };

    let tf_model = TfModel::from_saved_model(
        &model_dir,
        &tags.iter().map(|s| s.as_str()).collect::<Vec<_>>(),
        Some(config),
    )?;

    let wrapper = TfToTorshWrapper::new(tf_model);

    Ok(Box::new(wrapper))
}
1066
/// Translate TOML parameters into a `TfConfig`.
///
/// Recognized keys: `use_gpu`, `allow_growth`, `memory_limit`,
/// `gpu_memory_fraction`, `inter_op_threads`, `intra_op_threads`,
/// `device_placement`. Absent or ill-typed keys leave the defaults.
#[cfg(feature = "tensorflow")]
fn create_tf_config_from_params(
    params: &std::collections::HashMap<String, toml::Value>,
) -> crate::tensorflow::TfConfig {
    use crate::tensorflow::TfConfig;

    let mut config = TfConfig::default();

    if let Some(use_gpu) = extract_param_bool(params, "use_gpu") {
        config.use_gpu = use_gpu;
    }

    if let Some(allow_growth) = extract_param_bool(params, "allow_growth") {
        config.allow_growth = allow_growth;
    }

    // Match on `Ok` directly rather than the redundant
    // `if let Some(x) = result.ok()` pattern the old code used.
    if let Ok(memory_limit) = extract_param_i64(params, "memory_limit") {
        config.memory_limit = Some(memory_limit as usize);
    }

    if let Some(gpu_memory_fraction) = extract_param_f64(params, "gpu_memory_fraction") {
        config.gpu_memory_fraction = Some(gpu_memory_fraction);
    }

    if let Ok(inter_threads) = extract_param_i64(params, "inter_op_threads") {
        config.inter_op_parallelism_threads = Some(inter_threads as i32);
    }

    if let Ok(intra_threads) = extract_param_i64(params, "intra_op_threads") {
        config.intra_op_parallelism_threads = Some(intra_threads as i32);
    }

    if let Some(device_placement) = extract_param_bool(params, "device_placement") {
        config.device_placement = device_placement;
    }

    config
}
1112
1113fn load_pretrained_weights(
1115 _model: &mut dyn torsh_nn::Module,
1116 model_def: &ModelDefinition,
1117 repo_dir: &Path,
1118) -> Result<()> {
1119 let weights_path = if let Some(ref local_weights) = model_def.local_weights {
1120 repo_dir.join(local_weights)
1121 } else if let Some(ref weights_url) = model_def.weights_url {
1122 let cache_dir = repo_dir.join(".weights_cache");
1124 std::fs::create_dir_all(&cache_dir)?;
1125
1126 let weights_filename = weights_url
1127 .split('/')
1128 .next_back()
1129 .unwrap_or("weights.torsh");
1130 let weights_path = cache_dir.join(weights_filename);
1131
1132 if !weights_path.exists() {
1133 download::download_file(weights_url, &weights_path, true)?;
1134 }
1135 weights_path
1136 } else {
1137 return Err(TorshError::InvalidArgument(
1138 "No weights specified for pretrained model".to_string(),
1139 ));
1140 };
1141
1142 let state_dict = load_state_dict(&weights_path, None)?;
1144 _model.load_state_dict(&state_dict, true)?;
1145
1146 Ok(())
1147}
1148
1149fn extract_param_i64(
1151 params: &std::collections::HashMap<String, toml::Value>,
1152 key: &str,
1153) -> Result<i64> {
1154 params.get(key).and_then(|v| v.as_integer()).ok_or_else(|| {
1155 TorshError::InvalidArgument(format!("Missing or invalid parameter: {}", key))
1156 })
1157}
1158
1159fn extract_param_bool(
1161 params: &std::collections::HashMap<String, toml::Value>,
1162 key: &str,
1163) -> Option<bool> {
1164 params.get(key).and_then(|v| v.as_bool())
1165}
1166
1167fn extract_param_string(
1169 params: &std::collections::HashMap<String, toml::Value>,
1170 key: &str,
1171) -> Option<String> {
1172 params
1173 .get(key)
1174 .and_then(|v| v.as_str())
1175 .map(|s| s.to_string())
1176}
1177
1178fn extract_param_f64(
1180 params: &std::collections::HashMap<String, toml::Value>,
1181 key: &str,
1182) -> Option<f64> {
1183 params.get(key).and_then(|v| v.as_float())
1184}
1185
1186fn extract_param_array(
1188 params: &std::collections::HashMap<String, toml::Value>,
1189 key: &str,
1190) -> Option<Vec<usize>> {
1191 params.get(key).and_then(|v| {
1192 v.as_array().map(|arr| {
1193 arr.iter()
1194 .filter_map(|v| v.as_integer())
1195 .map(|i| i as usize)
1196 .collect()
1197 })
1198 })
1199}
1200
1201fn extract_param_array_strings(
1203 params: &std::collections::HashMap<String, toml::Value>,
1204 key: &str,
1205) -> Option<Vec<String>> {
1206 params.get(key).and_then(|v| {
1207 v.as_array().map(|arr| {
1208 arr.iter()
1209 .filter_map(|v| v.as_str())
1210 .map(|s| s.to_string())
1211 .collect()
1212 })
1213 })
1214}
1215
1216fn list_available_models(repo_dir: &Path) -> Result<Vec<String>> {
1218 let models_toml = repo_dir.join("models.toml");
1220 let hubconf_toml = repo_dir.join("hubconf.toml");
1221
1222 let config_path = if models_toml.exists() {
1223 models_toml
1224 } else if hubconf_toml.exists() {
1225 hubconf_toml
1226 } else {
1227 let legacy_models_file = repo_dir.join("models.toml");
1229 if legacy_models_file.exists() {
1230 let content = std::fs::read_to_string(legacy_models_file)?;
1231 let models: ModelList =
1232 toml::from_str(&content).map_err(|e| TorshError::ConfigError(e.to_string()))?;
1233 return Ok(models.models.into_iter().map(|m| m.name).collect());
1234 }
1235 return Ok(vec![]);
1236 };
1237
1238 let config_content = std::fs::read_to_string(config_path)?;
1240 let config: ModelConfig = toml::from_str(&config_content)
1241 .map_err(|e| TorshError::ConfigError(format!("Failed to parse model config: {}", e)))?;
1242
1243 Ok(config.models.into_iter().map(|m| m.name).collect())
1244}
1245
/// Build documentation text for `model`.
///
/// Prefers a hand-written `docs/<model>.md`; otherwise generates markdown
/// from the repository config (description, architecture, parameters,
/// weight availability). Unknown models produce a "not found" message
/// rather than an error.
fn get_model_doc(repo_dir: &Path, model: &str) -> Result<String> {
    let doc_file = repo_dir.join("docs").join(format!("{}.md", model));
    if doc_file.exists() {
        return std::fs::read_to_string(doc_file).map_err(Into::into);
    }

    let models_toml = repo_dir.join("models.toml");
    let hubconf_toml = repo_dir.join("hubconf.toml");

    let config_path = if models_toml.exists() {
        models_toml
    } else if hubconf_toml.exists() {
        hubconf_toml
    } else {
        return Ok(format!("No documentation available for model '{}'", model));
    };

    let config_content = std::fs::read_to_string(config_path)?;
    let config: ModelConfig = toml::from_str(&config_content)
        .map_err(|e| TorshError::ConfigError(format!("Failed to parse model config: {}", e)))?;

    if let Some(model_def) = config.models.iter().find(|m| m.name == model) {
        // Generate markdown from the structured definition.
        let mut doc = format!("# {}\n\n", model_def.name);

        if let Some(ref description) = model_def.description {
            doc.push_str(&format!("**Description:** {}\n\n", description));
        }

        doc.push_str(&format!("**Architecture:** {}\n\n", model_def.architecture));

        if !model_def.parameters.is_empty() {
            doc.push_str("**Parameters:**\n");
            for (key, value) in &model_def.parameters {
                doc.push_str(&format!("- {}: {:?}\n", key, value));
            }
            doc.push('\n');
        }

        if model_def.weights_url.is_some() || model_def.local_weights.is_some() {
            doc.push_str("**Pretrained weights available:** Yes\n\n");
        }

        Ok(doc)
    } else {
        Ok(format!("Model '{}' not found in repository", model))
    }
}
1297
1298fn download_url_to_file(url: &str, dst_dir: &Path, progress: bool) -> Result<PathBuf> {
1300 let filename = url
1301 .split('/')
1302 .next_back()
1303 .ok_or_else(|| TorshError::InvalidArgument("Invalid URL".to_string()))?;
1304
1305 let dst_path = dst_dir.join(filename);
1306
1307 if dst_path.exists() {
1308 if progress {
1309 println!("File already exists: {:?}", dst_path);
1310 }
1311 return Ok(dst_path);
1312 }
1313
1314 download::download_file(url, &dst_path, progress)?;
1315
1316 Ok(dst_path)
1317}
1318
1319fn load_state_dict(path: &Path, map_location: Option<torsh_core::DeviceType>) -> Result<StateDict> {
1321 use std::io::BufReader;
1322
1323 if !path.exists() {
1324 return Err(TorshError::IoError(format!(
1325 "State dict file not found: {:?}",
1326 path
1327 )));
1328 }
1329
1330 let file = File::open(path)?;
1331 let mut reader = BufReader::new(file);
1332
1333 let extension = path.extension().and_then(|s| s.to_str()).unwrap_or("");
1335
1336 match extension {
1337 "json" => {
1338 load_json_state_dict(&mut reader, map_location)
1340 }
1341 "torsh" => {
1342 load_torsh_state_dict(&mut reader, map_location)
1344 }
1345 "pt" | "pth" => {
1346 load_pytorch_compatible_state_dict(&mut reader, map_location)
1348 }
1349 _ => Err(TorshError::InvalidArgument(format!(
1350 "Unsupported state dict format: {}",
1351 extension
1352 ))),
1353 }
1354}
1355
1356fn load_torsh_state_dict(
1358 reader: &mut impl Read,
1359 map_location: Option<torsh_core::DeviceType>,
1360) -> Result<StateDict> {
1361 let mut magic = [0u8; 8];
1363 reader.read_exact(&mut magic)?;
1364
1365 if &magic != b"TORSH\x01\x00\x00" {
1366 return Err(TorshError::SerializationError(
1367 "Invalid torsh file format".to_string(),
1368 ));
1369 }
1370
1371 let mut version = [0u8; 4];
1373 reader.read_exact(&mut version)?;
1374 let version = u32::from_le_bytes(version);
1375
1376 if version > 1 {
1377 return Err(TorshError::SerializationError(format!(
1378 "Unsupported torsh file version: {}",
1379 version
1380 )));
1381 }
1382
1383 let mut num_tensors = [0u8; 8];
1385 reader.read_exact(&mut num_tensors)?;
1386 let num_tensors = u64::from_le_bytes(num_tensors);
1387
1388 let mut state_dict = StateDict::new();
1389
1390 for _ in 0..num_tensors {
1391 let mut name_len = [0u8; 4];
1393 reader.read_exact(&mut name_len)?;
1394 let name_len = u32::from_le_bytes(name_len) as usize;
1395
1396 let mut name_bytes = vec![0u8; name_len];
1398 reader.read_exact(&mut name_bytes)?;
1399 let name = String::from_utf8(name_bytes)
1400 .map_err(|e| TorshError::SerializationError(format!("Invalid tensor name: {}", e)))?;
1401
1402 let tensor = create_placeholder_tensor(&name, map_location)?;
1405 state_dict.insert(name, tensor);
1406 }
1407
1408 Ok(state_dict)
1409}
1410
1411fn load_pytorch_compatible_state_dict(
1413 _reader: &mut impl Read,
1414 _map_location: Option<torsh_core::DeviceType>,
1415) -> Result<StateDict> {
1416 Err(TorshError::Other(
1419 "PyTorch (.pt/.pth) format loading not yet implemented. Please convert to .json or .torsh format".to_string(),
1420 ))
1421}
1422
1423fn load_json_state_dict(
1425 reader: &mut impl Read,
1426 map_location: Option<torsh_core::DeviceType>,
1427) -> Result<StateDict> {
1428 use serde_json::Value;
1429 use torsh_tensor::creation::*;
1430
1431 let mut content = String::new();
1433 reader.read_to_string(&mut content)?;
1434
1435 let json: Value = serde_json::from_str(&content)
1437 .map_err(|e| TorshError::SerializationError(format!("Invalid JSON: {}", e)))?;
1438
1439 let mut state_dict = StateDict::new();
1440
1441 if let Value::Object(obj) = json {
1442 for (name, value) in obj {
1443 let tensor = match value {
1444 Value::Object(tensor_obj) => {
1445 let shape_value = tensor_obj.get("shape").ok_or_else(|| {
1447 TorshError::SerializationError("Missing 'shape' field".to_string())
1448 })?;
1449 let data_value = tensor_obj.get("data").ok_or_else(|| {
1450 TorshError::SerializationError("Missing 'data' field".to_string())
1451 })?;
1452
1453 let shape: Vec<usize> = shape_value
1455 .as_array()
1456 .ok_or_else(|| {
1457 TorshError::SerializationError("Shape must be an array".to_string())
1458 })?
1459 .iter()
1460 .map(|v| {
1461 v.as_u64()
1462 .ok_or_else(|| {
1463 TorshError::SerializationError(
1464 "Shape dimensions must be integers".to_string(),
1465 )
1466 })
1467 .map(|u| u as usize)
1468 })
1469 .collect::<Result<Vec<_>>>()?;
1470
1471 let data: Vec<f32> = data_value
1473 .as_array()
1474 .ok_or_else(|| {
1475 TorshError::SerializationError("Data must be an array".to_string())
1476 })?
1477 .iter()
1478 .map(|v| {
1479 v.as_f64()
1480 .ok_or_else(|| {
1481 TorshError::SerializationError(
1482 "Data elements must be numbers".to_string(),
1483 )
1484 })
1485 .map(|f| f as f32)
1486 })
1487 .collect::<Result<Vec<_>>>()?;
1488
1489 let expected_len: usize = shape.iter().product();
1491 if data.len() != expected_len {
1492 return Err(TorshError::SerializationError(format!(
1493 "Data length {} doesn't match shape {:?} (expected {})",
1494 data.len(),
1495 shape,
1496 expected_len
1497 )));
1498 }
1499
1500 let device = map_location.unwrap_or(torsh_core::DeviceType::Cpu);
1502 from_vec(data, &shape, device)?
1503 }
1504 Value::Array(arr) => {
1505 let data: Vec<f32> = arr
1507 .iter()
1508 .map(|v| {
1509 v.as_f64()
1510 .ok_or_else(|| {
1511 TorshError::SerializationError(
1512 "Array elements must be numbers".to_string(),
1513 )
1514 })
1515 .map(|f| f as f32)
1516 })
1517 .collect::<Result<Vec<_>>>()?;
1518
1519 let shape = vec![data.len()];
1520 let device = map_location.unwrap_or(torsh_core::DeviceType::Cpu);
1521 from_vec(data, &shape, device)?
1522 }
1523 _ => {
1524 return Err(TorshError::SerializationError(format!(
1525 "Unsupported tensor format for parameter '{}'",
1526 name
1527 )));
1528 }
1529 };
1530
1531 state_dict.insert(name, tensor);
1532 }
1533 } else {
1534 return Err(TorshError::SerializationError(
1535 "JSON state dict must be an object".to_string(),
1536 ));
1537 }
1538
1539 Ok(state_dict)
1540}
1541
/// Move every tensor in `state_dict` to `target_device`, returning the
/// remapped dict.
///
/// # Errors
/// Propagates the first device-transfer failure from `to`.
pub fn apply_device_mapping(
    mut state_dict: StateDict,
    target_device: torsh_core::DeviceType,
) -> Result<StateDict> {
    for (_name, tensor) in state_dict.iter_mut() {
        // NOTE(review): `to` appears to consume its receiver, hence the clone;
        // confirm whether a borrowing/in-place transfer exists to avoid the copy.
        *tensor = tensor.clone().to(target_device)?;
    }
    Ok(state_dict)
}
1553
1554fn create_placeholder_tensor(
1556 name: &str,
1557 _device: Option<torsh_core::DeviceType>,
1558) -> Result<torsh_tensor::Tensor<f32>> {
1559 use torsh_tensor::creation::*;
1560
1561 if name.contains("weight") {
1563 Ok(randn(&[64, 32])?)
1564 } else if name.contains("bias") {
1565 Ok(zeros(&[64])?)
1566 } else {
1567 Ok(zeros(&[1])?)
1568 }
1569}
1570
/// Legacy `models.toml` schema: a flat list of model entries.
#[derive(Deserialize)]
struct ModelList {
    // Parsed model entries from the legacy config format.
    models: Vec<ModelEntry>,
}
1575
/// One entry in the legacy model list.
#[derive(Deserialize)]
struct ModelEntry {
    // Model identifier.
    name: String,
    // Human-readable description; deserialized but not read anywhere yet.
    #[allow(dead_code)]
    description: String,
}
1582
/// Well-known model repository identifiers plus convenience loaders for
/// popular architectures.
pub mod sources {
    use super::*;

    /// Official vision model repository (`owner/repo` form).
    pub const TORSH_VISION: &str = "torsh/vision";
    /// Official text model repository.
    pub const TORSH_TEXT: &str = "torsh/text";
    /// Official audio model repository.
    pub const TORSH_AUDIO: &str = "torsh/audio";

    /// Load ResNet-18 from the vision repository, optionally with pretrained
    /// weights.
    pub fn resnet18(pretrained: bool) -> Result<Box<dyn torsh_nn::Module>> {
        load(TORSH_VISION, "resnet18", pretrained, None)
    }

    /// Load ResNet-50 from the vision repository, optionally with pretrained
    /// weights.
    pub fn resnet50(pretrained: bool) -> Result<Box<dyn torsh_nn::Module>> {
        load(TORSH_VISION, "resnet50", pretrained, None)
    }

    /// Load EfficientNet-B0 from the vision repository, optionally with
    /// pretrained weights.
    pub fn efficientnet_b0(pretrained: bool) -> Result<Box<dyn torsh_nn::Module>> {
        load(TORSH_VISION, "efficientnet_b0", pretrained, None)
    }
}
1606
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies repo-spec parsing for the three accepted forms:
    // bare "owner/repo", a full GitHub URL, and a URL with a /tree/<ref>
    // component selecting a branch or tag.
    #[test]
    fn test_parse_repo_info() {
        let (owner, repo, branch) = parse_repo_info("pytorch/vision").unwrap();
        assert_eq!(owner, "pytorch");
        assert_eq!(repo, "vision");
        assert_eq!(branch, "main");

        let (owner, repo, branch) = parse_repo_info("https://github.com/pytorch/vision").unwrap();
        assert_eq!(owner, "pytorch");
        assert_eq!(repo, "vision");
        assert_eq!(branch, "main");

        let (owner, repo, branch) =
            parse_repo_info("https://github.com/pytorch/vision/tree/v0.11.0").unwrap();
        assert_eq!(owner, "pytorch");
        assert_eq!(repo, "vision");
        assert_eq!(branch, "v0.11.0");
    }

    // Verifies that set_dir overrides the hub directory and that clearing
    // TORSH_HUB_DIR restores the default.
    //
    // NOTE(review): this test mutates process-wide environment state, so it
    // can race with other tests run in parallel; consider serializing
    // env-dependent tests.
    #[test]
    fn test_hub_dir() {
        let original = get_dir();

        let temp_dir = tempfile::tempdir().unwrap();
        set_dir(temp_dir.path()).unwrap();

        assert_eq!(get_dir(), temp_dir.path());

        std::env::remove_var("TORSH_HUB_DIR");
        assert_eq!(get_dir(), original);
    }
}
1644
1645pub const VERSION: &str = env!("CARGO_PKG_VERSION");
1647pub const VERSION_MAJOR: u32 = 0;
1648pub const VERSION_MINOR: u32 = 1;
1649pub const VERSION_PATCH: u32 = 0;
1650
/// Convenience prelude re-exporting the most commonly used hub APIs.
///
/// Several submodules export identically named items, so the glob
/// re-exports shadow one another; that is why `ambiguous_glob_reexports`
/// is explicitly allowed here.
#[allow(ambiguous_glob_reexports)]
pub mod prelude {
    pub use crate::{
        analytics::*, cache::*, community::*, debugging::*, download::*, enterprise::*,
        fine_tuning::*, huggingface::*, onnx::*, profiling::*, registry::*, security::*, utils::*,
        visualization::*,
    };
}