1#![allow(clippy::excessive_nesting)]
72
73use std::string::ToString;
74use std::vec::Vec;
75
76pub mod layers;
77pub mod models;
78#[cfg(feature = "web-workers")]
79pub mod runtime;
80
81pub mod core;
83pub use core::{model, pipeline, tensor, tokenizer, utils};
84
85pub mod compute;
87
88#[cfg(feature = "web-workers")]
89pub use compute::web_workers;
90
91#[cfg(feature = "shared-memory")]
92pub use compute::threads;
93
94#[cfg(feature = "webgpu")]
95pub use compute::webgpu_simple;
96
97#[cfg(feature = "webgpu")]
98pub use compute::webgpu;
99
100#[cfg(feature = "webgpu")]
101pub use compute::gpu_tensor;
102
103#[cfg(feature = "indexeddb")]
105pub mod storage;
106
107#[cfg(feature = "memory64")]
108pub use storage::memory64;
109
110#[cfg(feature = "streaming-loader")]
111pub use storage::streaming_loader;
112
113#[cfg(feature = "model-splitting")]
114pub use storage::model_splitting;
115
116#[cfg(feature = "react-components")]
117pub mod react_components;
118
119#[cfg(feature = "vue-components")]
120pub mod vue_components;
121
122#[cfg(feature = "angular-components")]
123pub mod angular_components;
124
125#[cfg(feature = "web-components")]
126pub mod web_components;
127
128#[cfg(feature = "playground")]
129pub mod playground;
130
131#[cfg(feature = "streaming-generation")]
132pub mod streaming_generation;
133
134#[cfg(feature = "mobile-optimization")]
135pub mod mobile;
136
137#[cfg(feature = "mobile-optimization")]
138pub mod touch_gestures;
139
140#[cfg(feature = "mobile-optimization")]
141pub mod camera_integration;
142
143#[cfg(feature = "mobile-optimization")]
144pub mod device_capability;
145
146#[cfg(feature = "mobile-optimization")]
147pub mod device_capability_detection;
148
149pub mod debug;
150pub mod error;
151pub mod events;
152pub mod export;
153pub mod optimization;
155pub use optimization::{
156 batch_processing, memory_pool, quantization, simd_tensor_ops, weight_compression,
157};
158
159pub mod auto_docs;
160pub mod multi_model_manager;
161pub mod performance;
162pub mod performance_profiler;
163pub mod plugin_framework;
164pub mod plugins;
165
166use std::sync::Mutex;
167use wasm_bindgen::prelude::*;
168
// Process-wide tracker of *reported* GPU memory usage. It is updated only via
// the `track_gpu_allocation` / `track_gpu_deallocation` exports below and is
// never read back from an actual GPU driver.
static GPU_MEMORY_TRACKER: Mutex<GpuMemoryTracker> = Mutex::new(GpuMemoryTracker::new());
171
/// Bookkeeping for GPU memory usage as reported by callers.
///
/// All values are simple counters driven by `allocate`/`deallocate`; nothing
/// here queries real device memory.
#[derive(Debug)]
struct GpuMemoryTracker {
    current_usage: usize,      // bytes currently reported as allocated
    peak_usage: usize,         // high-water mark of `current_usage`
    total_allocated: usize,    // lifetime sum of all allocations
    total_deallocated: usize,  // lifetime sum of all deallocations
    allocation_count: usize,   // number of `allocate` calls
    deallocation_count: usize, // number of `deallocate` calls
}

impl GpuMemoryTracker {
    /// Creates a zeroed tracker; `const` so it can initialize a static `Mutex`.
    const fn new() -> Self {
        Self {
            current_usage: 0,
            peak_usage: 0,
            total_allocated: 0,
            total_deallocated: 0,
            allocation_count: 0,
            deallocation_count: 0,
        }
    }

    /// Records an allocation of `size` bytes and updates the peak.
    ///
    /// Uses saturating addition so extreme or mismatched reports can never
    /// overflow (which would panic in debug builds); this mirrors the
    /// saturating subtraction already used in `deallocate`.
    fn allocate(&mut self, size: usize) {
        self.current_usage = self.current_usage.saturating_add(size);
        self.total_allocated = self.total_allocated.saturating_add(size);
        self.allocation_count += 1;

        if self.current_usage > self.peak_usage {
            self.peak_usage = self.current_usage;
        }
    }

    /// Records a deallocation of `size` bytes; clamps at zero so a stray
    /// double-free report cannot wrap `current_usage` around.
    fn deallocate(&mut self, size: usize) {
        self.current_usage = self.current_usage.saturating_sub(size);
        self.total_deallocated = self.total_deallocated.saturating_add(size);
        self.deallocation_count += 1;
    }

    /// Currently reported usage in bytes.
    fn get_current_usage(&self) -> usize {
        self.current_usage
    }

    /// Highest usage observed since creation or the last `reset_peak`.
    fn get_peak_usage(&self) -> usize {
        self.peak_usage
    }

    /// Resets the peak to the current usage level.
    fn reset_peak(&mut self) {
        self.peak_usage = self.current_usage;
    }
}
223
/// JavaScript-facing entry point object for the Trustformers WASM bindings.
#[wasm_bindgen]
pub struct TrustformersWasm {
    // Set to true by the constructor; exposed through the `initialized` getter.
    initialized: bool,
}
228
impl Default for TrustformersWasm {
    /// Delegates to `new`, which also installs the panic hook when the
    /// `console_panic` feature is enabled.
    fn default() -> Self {
        Self::new()
    }
}
234
235#[wasm_bindgen]
236impl TrustformersWasm {
237 #[wasm_bindgen(constructor)]
238 pub fn new() -> Self {
239 #[cfg(feature = "console_panic")]
240 console_error_panic_hook::set_once();
241
242 TrustformersWasm { initialized: true }
243 }
244
245 #[wasm_bindgen(getter)]
246 pub fn version(&self) -> String {
247 "0.1.0".to_string()
248 }
249
250 #[wasm_bindgen(getter)]
251 pub fn initialized(&self) -> bool {
252 self.initialized
253 }
254}
255
256pub use tensor::WasmTensor;
258
259#[cfg(feature = "webgpu")]
260pub use webgpu::{
261 AsyncExecutor, DeviceCapabilities, DeviceSelector, DeviceType, ExecutionStatus, FusableOp,
262 KernelFusion, OperationType, Priority, ShaderManager, WorkgroupTuner,
263};
264
265#[cfg(feature = "web-workers")]
266pub use runtime::edge_runtime::{
267 EdgeCapabilities, EdgeInferenceConfig, EdgeRuntime, EdgeRuntimeDetector,
268};
269
270#[cfg(feature = "web-workers")]
271pub use runtime::geo_distribution::{
272 create_geo_distribution_manager, estimate_network_latency, get_distance_between_points,
273 EdgeLocation, GeoDistributionManager, GeoRegion, RoutingDecision, RoutingWeights, UserLocation,
274};
275
276#[cfg(feature = "web-workers")]
277pub use runtime::edge_caching::{
278 create_edge_cache_manager, create_edge_computing_cache_config,
279 create_memory_efficient_cache_config, create_performance_cache_config, estimate_cache_overhead,
280 CacheConfig, CacheEntry, CacheEntryType, CacheStatistics, ConsistencyLevel, EdgeCacheManager,
281 EvictionPolicy, ReplicationStrategy,
282};
283
284#[cfg(feature = "web-workers")]
285pub use web_workers::{
286 get_optimal_worker_count, is_web_workers_supported, WorkerCoordinator, WorkerPool,
287 WorkerPriority, WorkerTaskType,
288};
289
290#[cfg(feature = "shared-memory")]
291pub use threads::{
292 get_optimal_thread_count, is_cross_origin_isolated, is_threading_supported, AtomicOperations,
293 ThreadPool, ThreadSync, ThreadTaskType,
294};
295
296#[cfg(feature = "indexeddb")]
297pub use storage::{CompressionType, ModelMetadata, ModelStorage, StoredModel};
298
299#[cfg(feature = "memory64")]
300pub use memory64::{
301 can_load_model_size, get_memory64_capabilities, is_memory64_supported, AllocationStrategy,
302 Memory64Capabilities, Memory64Manager,
303};
304
305#[cfg(feature = "streaming-loader")]
306pub use streaming_loader::{
307 get_optimal_chunk_size_kb, is_cache_api_available, is_streaming_compilation_supported,
308 LoadingProgress, StreamingConfig, StreamingLoader,
309};
310
311#[cfg(feature = "model-splitting")]
312pub use model_splitting::{
313 get_recommended_chunk_size_mb, should_split_model, ChunkConfig, ChunkPriority, ChunkType,
314 LoadingStrategy, ModelLoadingSession, ModelSplitter,
315};
316
317#[cfg(feature = "react-components")]
318pub use react_components::{
319 generate_react_package, is_react_available, ComponentType, InferenceState, ModelLoadingState,
320 ReactComponentFactory, ReactConfig,
321};
322
323#[cfg(feature = "vue-components")]
324pub use vue_components::{
325 generate_vue_package, is_vue_available, VueComponentFactory, VueComponentType, VueConfig,
326 VueInferenceState, VueModelState,
327};
328
329#[cfg(feature = "angular-components")]
330pub use angular_components::{
331 generate_angular_package, is_angular_available, AngularComponentType, AngularConfig,
332 AngularInferenceState, AngularModelState, AngularServiceFactory, AngularServiceType,
333};
334
335#[cfg(feature = "web-components")]
336pub use web_components::{
337 create_web_component_html_template, generate_web_components_package,
338 is_web_components_supported, ComponentType as WebComponentType,
339 InferenceState as WebInferenceState, ModelState as WebModelState, WebComponentConfig,
340 WebComponentFactory,
341};
342
343#[cfg(feature = "playground")]
344pub use playground::{
345 create_playground_config, create_playground_example, generate_playground_package,
346 ExampleCategory, InteractivePlayground, PlaygroundConfig, PlaygroundExample,
347};
348
349#[cfg(feature = "streaming-generation")]
350pub use streaming_generation::{
351 get_optimal_streaming_config, is_streaming_supported, CompletionReason, GenerationProgress,
352 StreamingConfig as GenerationStreamingConfig, StreamingGenerator, StreamingStats,
353 StreamingToken,
354};
355
356#[cfg(feature = "mobile-optimization")]
357pub use mobile::{
358 create_mobile_optimizer, get_device_memory_gb, get_optimal_model_for_device, is_low_data_mode,
359 is_mobile_device, is_tablet_device, AdaptiveModelConfig, BatteryInfo, BatteryStatus,
360 DeviceClass, MobileCapabilities, MobileOptimizer, ModelSize, NetworkStatus, NetworkType,
361};
362
363pub use auto_docs::{
364 create_default_doc_generator,
365 create_html_doc_generator,
366 create_markdown_doc_generator,
367 get_version_info,
369 AutoDocGenerator,
370 DocConfig,
371 DocFormat,
372 DocTheme,
373 VersionInfo,
374};
375pub use batch_processing::{
376 BatchConfig, BatchProcessor, BatchResponse, BatchingStrategy, Priority as BatchPriority,
377};
378pub use debug::{DebugConfig, DebugLogger, LogLevel, PerformanceMetrics};
379pub use error::{
380 ErrorBuilder, ErrorCode, ErrorCollection, ErrorContext, ErrorHandler, ErrorSeverity,
381 TrustformersError, TrustformersResult,
382};
383pub use events::{EventData, EventEmittable, EventManager, EventPriority, EventType};
384pub use multi_model_manager::{
385 create_development_multi_model_manager, create_production_multi_model_manager,
386 DeploymentEnvironment, ModelPriority, ModelStatus, MultiModelConfig, MultiModelManager,
387};
388pub use performance::{
389 BottleneckType, OperationType as ProfilerOperationType, ProfilerConfig, ResourceType,
390};
391pub use performance_profiler::{
392 create_development_profiler, create_production_profiler, PerformanceProfiler,
393};
394pub use plugin_framework::{
395 create_default_plugin_config, create_plugin_context, ExecutionMetrics, ExecutionPriority,
396 ModelMetadata as PluginModelMetadata, PerformanceBudget, Plugin, PluginConfig, PluginContext,
397 PluginError, PluginErrorCode, PluginManager, PluginMetadata, PluginPermission, PluginRegistry,
398 PluginResult, PluginType, ResourceLimits,
399};
400pub use plugins::{ModelOptimizerPlugin, TextProcessorPlugin, VisualizationPlugin};
401pub use quantization::{
402 QuantizationConfig, QuantizationPrecision, QuantizationStrategy, QuantizedModelData,
403 WebQuantizer,
404};
405pub use weight_compression::{
406 CompressedModelData, CompressionConfig, CompressionLevel, CompressionStrategy, SparsityPattern,
407 WeightCompressor,
408};
409
/// Installs `console_error_panic_hook` so Rust panics are printed to the
/// browser console. No-op unless the `console_panic` feature is enabled;
/// safe to call repeatedly (`set_once`).
#[wasm_bindgen]
pub fn init_panic_hook() {
    #[cfg(feature = "console_panic")]
    console_error_panic_hook::set_once();
}
415
/// Returns the current byte length of the module's WASM linear-memory buffer.
///
/// NOTE: WASM linear memory only grows, so this is a high-water mark of the
/// module's memory, not live heap usage.
#[wasm_bindgen]
pub fn get_wasm_memory_usage() -> usize {
    let memory = wasm_bindgen::memory();
    // `wasm_bindgen::memory()` is documented to return the module's
    // WebAssembly.Memory, so the unchecked casts below hold by contract.
    let memory_obj: &js_sys::WebAssembly::Memory = memory.unchecked_ref();
    let buffer = js_sys::WebAssembly::Memory::buffer(memory_obj);
    let array_buffer: &js_sys::ArrayBuffer = buffer.unchecked_ref();
    js_sys::ArrayBuffer::byte_length(array_buffer) as usize
}
427
/// Point-in-time memory snapshot handed to JavaScript (see `get_memory_stats`).
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct MemoryStats {
    wasm_memory: usize,     // WASM linear-memory size, bytes
    gpu_memory: usize,      // tracked GPU usage, bytes
    peak_gpu_memory: usize, // tracked GPU peak, bytes
}
436
#[wasm_bindgen]
impl MemoryStats {
    /// WASM linear-memory size in bytes.
    #[wasm_bindgen(getter)]
    pub fn wasm_memory(&self) -> usize {
        self.wasm_memory
    }

    /// Tracked GPU memory usage in bytes.
    #[wasm_bindgen(getter)]
    pub fn gpu_memory(&self) -> usize {
        self.gpu_memory
    }

    /// Peak tracked GPU memory usage in bytes.
    #[wasm_bindgen(getter)]
    pub fn peak_gpu_memory(&self) -> usize {
        self.peak_gpu_memory
    }

    /// WASM memory usage converted to megabytes.
    #[wasm_bindgen(getter)]
    pub fn used_mb(&self) -> f32 {
        self.wasm_memory as f32 / (1024.0 * 1024.0)
    }

    /// Memory budget in megabytes.
    ///
    /// NOTE(review): hard-coded to 256 — presumably a nominal browser-tab
    /// budget, not a measured limit; confirm before callers rely on it.
    #[wasm_bindgen(getter)]
    pub fn limit_mb(&self) -> f32 {
        256.0
    }
}
465
466#[wasm_bindgen]
468pub fn track_gpu_allocation(size: usize) {
469 if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
470 tracker.allocate(size);
471 }
472}
473
474#[wasm_bindgen]
476pub fn track_gpu_deallocation(size: usize) {
477 if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
478 tracker.deallocate(size);
479 }
480}
481
482#[wasm_bindgen]
484pub fn get_gpu_memory_usage() -> usize {
485 GPU_MEMORY_TRACKER
486 .lock()
487 .map(|tracker| tracker.get_current_usage())
488 .unwrap_or(0)
489}
490
491#[wasm_bindgen]
493pub fn get_peak_gpu_memory_usage() -> usize {
494 GPU_MEMORY_TRACKER.lock().map(|tracker| tracker.get_peak_usage()).unwrap_or(0)
495}
496
497#[wasm_bindgen]
499pub fn reset_peak_gpu_memory() {
500 if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
501 tracker.reset_peak();
502 }
503}
504
505#[wasm_bindgen]
507pub fn get_memory_stats() -> MemoryStats {
508 MemoryStats {
509 wasm_memory: get_wasm_memory_usage(),
510 gpu_memory: get_gpu_memory_usage(),
511 peak_gpu_memory: get_peak_gpu_memory_usage(),
512 }
513}
514
/// Reports whether this binary was compiled with WASM SIMD128 enabled.
/// `cfg!` is resolved at compile time, so this reflects build flags,
/// not a runtime probe of the host browser.
#[wasm_bindgen]
pub fn enable_simd() -> bool {
    cfg!(target_feature = "simd128")
}
520
521#[wasm_bindgen]
/// A single model-inference session exposed to JavaScript, bundling the model
/// identifier with optional, feature-gated subsystems (device selection, edge
/// runtime, persistent storage, debug logging, quantization, batching, events).
pub struct InferenceSession {
    // Model/architecture name supplied by the caller; used in logs and events.
    model_type: String,
    #[cfg(feature = "webgpu")]
    device_selector: Option<webgpu::DeviceSelector>,
    #[cfg(feature = "webgpu")]
    current_device: webgpu::DeviceType,
    #[cfg(feature = "web-workers")]
    edge_detector: runtime::edge_runtime::EdgeRuntimeDetector,
    #[cfg(feature = "web-workers")]
    edge_config: runtime::edge_runtime::EdgeInferenceConfig,
    #[cfg(feature = "indexeddb")]
    storage: Option<storage::ModelStorage>,
    // The remaining subsystems start as None and are switched on via the
    // corresponding enable_* methods.
    debug_logger: Option<debug::DebugLogger>,
    quantizer: Option<quantization::WebQuantizer>,
    batch_processor: Option<batch_processing::BatchProcessor>,
    event_emitter: Option<events::EventEmitter>,
}
540
541#[wasm_bindgen]
542impl InferenceSession {
    /// Creates a session for `model_type`. Edge-runtime detection runs eagerly
    /// (when the `web-workers` feature is on); all other subsystems start
    /// disabled until their `enable_*`/`initialize_*` methods are called.
    #[wasm_bindgen(constructor)]
    pub fn new(model_type: String) -> Result<InferenceSession, JsValue> {
        #[cfg(feature = "web-workers")]
        let edge_detector = runtime::edge_runtime::EdgeRuntimeDetector::new();
        #[cfg(feature = "web-workers")]
        let edge_config =
            runtime::edge_runtime::EdgeInferenceConfig::for_runtime(&edge_detector.capabilities());

        // Currently infallible; the Result keeps the JS-facing API stable.
        Ok(InferenceSession {
            model_type,
            #[cfg(feature = "webgpu")]
            device_selector: None,
            #[cfg(feature = "webgpu")]
            current_device: webgpu::DeviceType::CPU,
            #[cfg(feature = "web-workers")]
            edge_detector,
            #[cfg(feature = "web-workers")]
            edge_config,
            #[cfg(feature = "indexeddb")]
            storage: None,
            debug_logger: None,
            quantizer: None,
            batch_processor: None,
            event_emitter: None,
        })
    }
569
    /// Probes WebGPU and records the selected device. Without the `webgpu`
    /// feature this is a no-op that returns `Ok(())`.
    pub async fn initialize_with_auto_device(&mut self) -> Result<(), JsValue> {
        #[cfg(feature = "webgpu")]
        {
            let mut selector = webgpu::DeviceSelector::new().await?;
            selector.initialize_device().await?;
            self.current_device = selector.selected_device();
            self.device_selector = Some(selector);
        }
        Ok(())
    }
581
    /// The device type chosen by `initialize_with_auto_device` (CPU until then).
    #[cfg(feature = "webgpu")]
    #[wasm_bindgen(getter)]
    pub fn current_device_type(&self) -> DeviceType {
        self.current_device
    }
588
589 #[cfg(feature = "webgpu")]
591 pub fn get_device_capabilities(&self) -> Option<DeviceCapabilities> {
592 self.device_selector.as_ref().map(|s| s.capabilities())
593 }
594
    /// Validates `model_data` and logs/emits events around "loading" it.
    /// (No model parsing is visible here — this method performs size checks,
    /// device-selection logging, timing, and event emission only.)
    ///
    /// # Errors
    /// - `E1001` if `model_data` is empty.
    /// - `E1002` if the model exceeds the 2 GiB hard cap.
    /// - `E4002` for the (currently unreachable, see below) memory check.
    /// - `E3003` if the GPU was selected but lacks the estimated headroom.
    pub async fn load_model(&mut self, model_data: &[u8]) -> Result<(), JsValue> {
        use crate::error::{ErrorBuilder, ErrorCode};

        let model_size = model_data.len();
        let model_size_mb = model_size as f64 / 1024.0 / 1024.0;

        // Announce the load before validation so listeners see failed attempts too.
        if let Some(ref mut emitter) = self.event_emitter {
            let event = events::EventData::model_load_start(&self.model_type, model_size_mb);
            emitter.emit(event);
        }

        // Reject obviously invalid input before doing any work.
        if model_data.is_empty() {
            let error = ErrorBuilder::new(ErrorCode::E1001, "Model data is empty")
                .operation("load_model")
                .component("inference_session")
                .build();

            if let Some(ref mut emitter) = self.event_emitter {
                let event = events::EventData::error_occurred(&error.message, "load_model");
                emitter.emit(event);
            }

            return Err(error.into());
        }

        // Hard cap: 2 GiB (2^31 bytes).
        if model_size > 2_147_483_648 {
            let error = ErrorBuilder::new(ErrorCode::E1002, "Model exceeds maximum size limit")
                .operation("load_model")
                .component("inference_session")
                .memory_usage_mb(model_size_mb)
                .additional_info("Consider using quantization or model splitting")
                .build();

            if let Some(ref mut emitter) = self.event_emitter {
                let event = events::EventData::error_occurred(&error.message, "load_model")
                    .with_data("size_mb", &format!("{model_size_mb:.2}"));
                emitter.emit(event);
            }

            return Err(error.into());
        }

        let start_time = js_sys::Date::now();

        if let Some(ref mut logger) = self.debug_logger {
            logger.start_timer("model_loading");
            logger.log_model_loading(&self.model_type, model_size, "memory");
            logger.log_memory_usage("Before model loading");
        }

        let memory_stats = crate::get_memory_stats();
        // NOTE(review): this duplicates the 2 GiB check above with the same
        // constant, so this E4002 branch is unreachable as written — it only
        // becomes meaningful if `available_memory` is ever computed for real.
        let available_memory = 2_147_483_648; if model_size > available_memory {
            return Err(
                ErrorBuilder::new(ErrorCode::E4002, "Insufficient memory for model")
                    .operation("load_model")
                    .component("inference_session")
                    .memory_usage_mb(memory_stats.wasm_memory as f64 / 1024.0 / 1024.0)
                    .additional_info("Try enabling quantization or using a smaller model")
                    .build()
                    .into(),
            );
        }

        // Device routing only affects logging plus the GPU headroom check;
        // 0.8 is the GPU-efficiency threshold passed to the selector.
        #[cfg(feature = "webgpu")]
        if let Some(selector) = &self.device_selector {
            let should_use_gpu = selector.should_use_gpu(model_size, 0.8); if should_use_gpu {
                web_sys::console::log_1(
                    &format!("Loading model on GPU (size: {model_size} bytes)").into(),
                );

                // Rough 2x headroom estimate for weights + working buffers.
                let gpu_memory_required = model_size * 2; let capabilities = selector.capabilities();
                if gpu_memory_required > capabilities.gpu_memory_limit as usize {
                    return Err(
                        ErrorBuilder::new(ErrorCode::E3003, "Insufficient GPU memory")
                            .operation("load_model")
                            .component("webgpu_device")
                            .device_type("GPU")
                            .memory_usage_mb(gpu_memory_required as f64 / 1024.0 / 1024.0)
                            .additional_info("Try using CPU mode or enabling quantization")
                            .build()
                            .into(),
                    );
                }
            } else {
                web_sys::console::log_1(
                    &format!("Loading model on CPU (size: {model_size} bytes)").into(),
                );
            }
        } else {
            web_sys::console::log_1(
                &format!("Loading model on CPU (size: {model_size} bytes)").into(),
            );
        }

        #[cfg(not(feature = "webgpu"))]
        web_sys::console::log_1(&format!("Loading model on CPU (size: {model_size} bytes)").into());

        // Heuristic sanity warning only; does not fail the load.
        if model_size < 1024 {
            if let Some(ref mut logger) = self.debug_logger {
                logger.warn(
                    "Model size is very small, may not be a valid model",
                    "model_validation",
                );
            }
        }

        if let Some(ref mut logger) = self.debug_logger {
            logger.log_memory_usage("After model loading");
            logger.end_timer("model_loading");
            logger.info(
                &format!(
                    "Model loaded successfully (size: {} MB)",
                    model_size / 1024 / 1024
                ),
                "model_loading",
            );
        }

        let duration_ms = js_sys::Date::now() - start_time;
        if let Some(ref mut emitter) = self.event_emitter {
            let event = events::EventData::model_load_complete(&self.model_type, duration_ms);
            emitter.emit(event);
        }

        Ok(())
    }
738
    /// Runs "inference" on `input`.
    ///
    /// NOTE(review): no model computation happens here — the result is a clone
    /// of the input tensor (identity placeholder). The method only logs, times,
    /// and emits events around that pass-through.
    pub fn predict(&mut self, input: &tensor::WasmTensor) -> Result<tensor::WasmTensor, JsValue> {
        let input_size = input.len();
        let start_time = js_sys::Date::now();

        if let Some(ref mut emitter) = self.event_emitter {
            let event = events::EventData::inference_start(&input.shape());
            emitter.emit(event);
        }

        if let Some(ref mut logger) = self.debug_logger {
            logger.start_timer("inference");
            logger.log_inference(&self.model_type, &input.shape(), "auto");
            logger.log_memory_usage("Before inference");
        }

        // Device choice currently affects only the log line; 0.6 is the
        // GPU-efficiency threshold handed to the selector.
        #[cfg(feature = "webgpu")]
        if let Some(selector) = &self.device_selector {
            let should_use_gpu = selector.should_use_gpu(input_size, 0.6); if should_use_gpu {
                web_sys::console::log_1(
                    &format!("Running inference on GPU (input size: {input_size})").into(),
                );
            } else {
                web_sys::console::log_1(
                    &format!("Running inference on CPU (input size: {input_size})").into(),
                );
            }
        } else {
            web_sys::console::log_1(
                &format!("Running inference on CPU (input size: {input_size})").into(),
            );
        }

        #[cfg(not(feature = "webgpu"))]
        web_sys::console::log_1(
            &format!("Running inference on CPU (input size: {input_size})").into(),
        );

        // Placeholder: identity transform until real model execution is wired in.
        let result = input.clone();

        if let Some(ref mut logger) = self.debug_logger {
            logger.log_memory_usage("After inference");
            logger.end_timer("inference");
        }

        let duration_ms = js_sys::Date::now() - start_time;
        if let Some(ref mut emitter) = self.event_emitter {
            let event = events::EventData::inference_complete(duration_ms, result.len());
            emitter.emit(event);
        }

        Ok(result)
    }
799
    /// Overrides automatic device selection with an explicit device type.
    #[cfg(feature = "webgpu")]
    pub fn force_device_type(&mut self, device_type: DeviceType) {
        self.current_device = device_type;
    }
805
    /// Capabilities detected for the current edge runtime environment.
    #[cfg(feature = "web-workers")]
    #[wasm_bindgen(getter)]
    pub fn edge_capabilities(&self) -> EdgeCapabilities {
        self.edge_detector.capabilities()
    }
812
    /// The edge-inference configuration derived at construction time.
    #[cfg(feature = "web-workers")]
    #[wasm_bindgen(getter)]
    pub fn edge_config(&self) -> EdgeInferenceConfig {
        self.edge_config.clone()
    }
819
    /// Whether the detected edge runtime is suitable for ML workloads.
    #[cfg(feature = "web-workers")]
    pub fn is_edge_suitable(&self) -> bool {
        self.edge_detector.is_ml_suitable()
    }
825
    /// Model-size recommendation (MB) from the edge-runtime detector.
    #[cfg(feature = "web-workers")]
    pub fn recommended_model_size_mb(&self) -> u32 {
        self.edge_detector.recommended_model_size_mb()
    }
831
    /// Cold-start optimization hints from the edge-runtime detector, as a JS array.
    #[cfg(feature = "web-workers")]
    pub fn get_cold_start_optimizations(&self) -> js_sys::Array {
        self.edge_detector.get_cold_start_optimizations()
    }
837
    /// Opens the IndexedDB-backed model store ("trustformers-models") with the
    /// given size budget and attaches it to this session.
    #[cfg(feature = "indexeddb")]
    pub async fn initialize_storage(&mut self, max_storage_mb: f64) -> Result<(), JsValue> {
        let mut storage =
            storage::ModelStorage::new("trustformers-models".to_string(), max_storage_mb);
        storage.initialize().await?;
        self.storage = Some(storage);
        web_sys::console::log_1(
            &format!("Initialized model storage (max: {max_storage_mb} MB)").into(),
        );
        Ok(())
    }
850
    /// Persists `data` under `model_id` in the attached store.
    /// Errors if `initialize_storage` has not been called.
    #[cfg(feature = "indexeddb")]
    pub async fn store_model(
        &self,
        model_id: &str,
        model_name: &str,
        architecture: &str,
        version: &str,
        data: &[u8],
    ) -> Result<(), JsValue> {
        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
        storage.store_model(model_id, model_name, architecture, version, data).await
    }
864
    /// Fetches a cached model's bytes, `Ok(None)` if not cached.
    /// Errors if `initialize_storage` has not been called.
    #[cfg(feature = "indexeddb")]
    pub async fn load_cached_model(&self, model_id: &str) -> Result<Option<Vec<u8>>, JsValue> {
        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
        storage.get_model(model_id).await
    }
871
    /// Whether a model with `model_id` exists in the attached store.
    #[cfg(feature = "indexeddb")]
    pub async fn has_cached_model(&self, model_id: &str) -> Result<bool, JsValue> {
        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
        storage.has_model(model_id).await
    }
878
    /// Cache-first model load: serves from IndexedDB when present, otherwise
    /// downloads from `model_url`, stores the bytes (if storage is attached),
    /// and then loads. Works without storage too — it just always downloads.
    #[cfg(feature = "indexeddb")]
    pub async fn load_model_with_cache(
        &mut self,
        model_id: &str,
        model_url: &str,
        model_name: &str,
        architecture: &str,
        version: &str,
    ) -> Result<(), JsValue> {
        // Cache hit: load directly and skip the network entirely.
        if let Some(storage) = &self.storage {
            if let Some(cached_data) = storage.get_model(model_id).await? {
                web_sys::console::log_1(&format!("Loaded model '{model_name}' from cache").into());
                return self.load_model(&cached_data).await;
            }
        }

        web_sys::console::log_1(
            &format!("Downloading model '{model_name}' from {model_url}").into(),
        );

        let model_data = self.fetch_model_from_url(model_url).await.map_err(|e| {
            JsValue::from_str(&format!(
                "Failed to fetch model from {}: {:?}",
                model_url, e
            ))
        })?;

        // Populate the cache for next time before loading.
        if let Some(storage) = &self.storage {
            storage
                .store_model(model_id, model_name, architecture, version, &model_data)
                .await?;
        }

        self.load_model(&model_data).await
    }
918
    /// Human-readable storage summary ("<bytes>, <count> models"), or
    /// `Ok(None)` when storage was never initialized.
    #[cfg(feature = "indexeddb")]
    pub async fn get_storage_stats(&self) -> Result<Option<String>, JsValue> {
        if let Some(storage) = &self.storage {
            let usage = storage.get_storage_usage().await?;
            let models_js = storage.list_models().await?;
            // Deserialization failure degrades to a count of 0 rather than an error.
            let models_count = if let Ok(models) =
                serde_wasm_bindgen::from_value::<Vec<ModelMetadata>>(models_js)
            {
                models.len()
            } else {
                0
            };
            Ok(Some(format!(
                "Storage: {} bytes, {} models",
                usage, models_count
            )))
        } else {
            Ok(None)
        }
    }
940
    /// Deletes every cached model from the attached store.
    #[cfg(feature = "indexeddb")]
    pub async fn clear_model_cache(&self) -> Result<(), JsValue> {
        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
        storage.clear_all().await
    }
947
    /// Downloads `url` via the browser `fetch` API and returns the body bytes.
    /// Requires a `window` object, so this only works on the main thread.
    #[allow(dead_code)]
    async fn fetch_model_from_url(&self, url: &str) -> Result<Vec<u8>, JsValue> {
        use wasm_bindgen::JsCast;
        use wasm_bindgen_futures::JsFuture;

        let request = web_sys::Request::new_with_str(url)?;
        request.headers().set("Accept", "application/octet-stream")?;

        let window = web_sys::window().ok_or("No global window object")?;

        let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
        let resp: web_sys::Response = resp_value.dyn_into()?;

        // Non-2xx statuses are surfaced as an explicit error, not empty data.
        if !resp.ok() {
            return Err(JsValue::from_str(&format!(
                "HTTP error: {} {}",
                resp.status(),
                resp.status_text()
            )));
        }

        let array_buffer = JsFuture::from(resp.array_buffer()?).await?;
        let uint8_array = js_sys::Uint8Array::new(&array_buffer);

        // Copy out of JS memory into a Rust-owned buffer.
        let mut data = vec![0u8; uint8_array.length() as usize];
        uint8_array.copy_to(&mut data);

        web_sys::console::log_1(
            &format!("Successfully downloaded {len} bytes", len = data.len()).into(),
        );

        Ok(data)
    }
988
989 pub fn enable_debug_logging(&mut self, config: debug::DebugConfig) {
991 self.debug_logger = Some(debug::DebugLogger::new(config));
992 if let Some(ref mut logger) = self.debug_logger {
993 logger.info(
994 "Debug logging enabled for inference session",
995 "initialization",
996 );
997 }
998 }
999
    /// Logs a final message (if a logger exists) and drops the debug logger.
    pub fn disable_debug_logging(&mut self) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.info("Debug logging disabled", "cleanup");
        }
        self.debug_logger = None;
    }
1007
    /// Starts a named debug timer; no-op when debug logging is disabled.
    pub fn start_timer(&mut self, operation: &str) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.start_timer(operation);
        }
    }
1014
1015 pub fn end_timer(&mut self, operation: &str) -> Option<f64> {
1017 if let Some(ref mut logger) = self.debug_logger {
1018 logger.end_timer(operation)
1019 } else {
1020 None
1021 }
1022 }
1023
    /// Performance summary from the debug logger, `None` when logging is off.
    pub fn get_performance_summary(&self) -> Option<String> {
        self.debug_logger.as_ref().map(|logger| logger.get_performance_summary())
    }
1028
    /// Exports accumulated debug logs, `None` when logging is off.
    pub fn export_debug_logs(&self) -> Option<String> {
        self.debug_logger.as_ref().map(|logger| logger.export_logs())
    }
1033
    /// Clears accumulated debug logs; no-op when logging is off.
    pub fn clear_debug_logs(&mut self) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.clear();
        }
    }
1040
    /// Logs current memory usage tagged with `context`; no-op when logging is off.
    pub fn log_memory_usage(&mut self, context: &str) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.log_memory_usage(context);
        }
    }
1047
1048 pub fn enable_quantization(&mut self, config: quantization::QuantizationConfig) {
1050 self.quantizer = Some(quantization::WebQuantizer::new(config));
1051 if let Some(ref mut logger) = self.debug_logger {
1052 logger.info("Quantization enabled for inference session", "quantization");
1053 }
1054 }
1055
    /// Logs the change (when debug logging is on) and drops the quantizer.
    pub fn disable_quantization(&mut self) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.info("Quantization disabled", "quantization");
        }
        self.quantizer = None;
    }
1063
    /// Loads `model_data`, first quantizing it when a quantizer is installed
    /// and deems the model worth quantizing. Quantization failure is
    /// non-fatal: the original bytes are loaded instead.
    pub async fn load_model_with_quantization(&mut self, model_data: &[u8]) -> Result<(), JsValue> {
        let mut final_data = model_data.to_vec();
        let mut quantization_summary = "No quantization applied".to_string();

        if let Some(ref quantizer) = self.quantizer {
            if quantizer.should_quantize(model_data.len()) {
                if let Some(ref mut logger) = self.debug_logger {
                    logger.start_timer("quantization");
                    logger.info(
                        &format!(
                            "Starting quantization for {len} bytes",
                            len = model_data.len()
                        ),
                        "quantization",
                    );
                }

                match quantizer.quantize_model(model_data) {
                    Ok(quantized) => {
                        final_data = quantized.data();
                        quantization_summary = quantized.summary();
                        // log_quantization_* helpers are defined elsewhere in this impl.
                        self.log_quantization_success(&quantization_summary);
                    },
                    Err(e) => {
                        // Fall back to the unquantized bytes already in final_data.
                        let error_msg = format!("Quantization failed: {e:?}, using original model");
                        self.log_quantization_failure(&error_msg);
                    },
                }
            } else {
                self.log_quantization_skipped();
            }
        }

        self.load_model(&final_data).await?;

        if let Some(ref mut logger) = self.debug_logger {
            logger.info(
                &format!("Model loaded: {quantization_summary}"),
                "model_loading",
            );
        }

        Ok(())
    }
1112
    /// Recommended quantization settings for a model of the given size,
    /// or `None` when no quantizer is installed.
    pub fn get_quantization_recommendations(
        &self,
        model_size_bytes: usize,
    ) -> Option<quantization::QuantizationConfig> {
        self.quantizer.as_ref().map(|q| q.get_recommended_settings(model_size_bytes))
    }
1120
    /// Whether the installed quantizer would quantize a model of this size
    /// (false when no quantizer is installed).
    pub fn should_quantize_model(&self, model_size_bytes: usize) -> bool {
        self.quantizer.as_ref().is_some_and(|q| q.should_quantize(model_size_bytes))
    }
1125
    /// Installs a batch processor with the given configuration, replacing any
    /// previous one.
    pub fn enable_batch_processing(&mut self, config: batch_processing::BatchConfig) {
        self.batch_processor = Some(batch_processing::BatchProcessor::new(config));
        if let Some(ref mut logger) = self.debug_logger {
            logger.info(
                "Batch processing enabled for inference session",
                "batch_processing",
            );
        }
    }
1136
    /// Logs the change (when debug logging is on) and drops the batch
    /// processor, discarding its queue.
    pub fn disable_batch_processing(&mut self) {
        if let Some(ref mut logger) = self.debug_logger {
            logger.info("Batch processing disabled", "batch_processing");
        }
        self.batch_processor = None;
    }
1144
    /// Queues `input` for batched inference and returns the request id, or
    /// `None` (with a warning log) when batch processing is not enabled.
    pub fn add_batch_request(
        &mut self,
        input: &tensor::WasmTensor,
        priority: batch_processing::Priority,
        timeout_ms: Option<u32>,
    ) -> Option<String> {
        if let Some(ref mut processor) = self.batch_processor {
            let request_id = processor.add_request(input.clone(), priority, timeout_ms);

            if let Some(ref mut logger) = self.debug_logger {
                logger.debug(
                    &format!("Added batch request {request_id} with priority {priority:?}"),
                    "batch_processing",
                );
            }

            Some(request_id)
        } else {
            if let Some(ref mut logger) = self.debug_logger {
                logger.warn("Batch processing not enabled", "batch_processing");
            }
            None
        }
    }
1170
    /// Drains and processes the queued batch, returning one response per
    /// request. Errors when batch processing is not enabled.
    pub async fn process_batch(&mut self) -> Result<Vec<batch_processing::BatchResponse>, JsValue> {
        if let Some(ref mut processor) = self.batch_processor {
            if let Some(ref mut logger) = self.debug_logger {
                logger.start_timer("batch_processing");
                logger.debug(
                    &format!(
                        "Processing batch with {} pending requests",
                        processor.queue_length()
                    ),
                    "batch_processing",
                );
            }

            let responses = processor.process_batch().await?;

            if let Some(ref mut logger) = self.debug_logger {
                logger.info(
                    &format!("Processed batch: {len} responses", len = responses.len()),
                    "batch_processing",
                );
                logger.end_timer("batch_processing");
            }

            Ok(responses)
        } else {
            Err("Batch processing not enabled".into())
        }
    }
1200
    /// Whether the processor considers the current queue ready to run
    /// (false when batch processing is not enabled).
    pub fn is_batch_ready(&self) -> bool {
        self.batch_processor.as_ref().is_some_and(|p| p.is_batch_ready())
    }
1205
    /// Number of queued batch requests (0 when batch processing is not enabled).
    pub fn get_batch_queue_length(&self) -> usize {
        self.batch_processor.as_ref().map_or(0, |p| p.queue_length())
    }
1210
    /// Processor statistics string, `None` when batch processing is not enabled.
    pub fn get_batch_stats(&self) -> Option<String> {
        self.batch_processor.as_ref().map(|p| p.get_stats())
    }
1215
    /// Discards all queued batch requests; no-op when batching is not enabled.
    pub fn clear_batch_queue(&mut self) {
        if let Some(ref mut processor) = self.batch_processor {
            processor.clear_queue();
            if let Some(ref mut logger) = self.debug_logger {
                logger.info("Batch queue cleared", "batch_processing");
            }
        }
    }
1225
1226 pub fn enable_events(&mut self) {
1228 self.event_emitter = Some(events::EventEmitter::new());
1229 if let Some(ref mut logger) = self.debug_logger {
1230 logger.info("Event system enabled for inference session", "events");
1231 }
1232 }
1233
1234 pub fn disable_events(&mut self) {
1236 if let Some(ref mut logger) = self.debug_logger {
1237 logger.info("Event system disabled", "events");
1238 }
1239 self.event_emitter = None;
1240 }
1241
1242 pub fn get_event_history(&self) -> Option<String> {
1244 self.event_emitter
1245 .as_ref()
1246 .and_then(|emitter| serde_json::to_string(emitter.get_history()).ok())
1247 }
1248
1249 pub fn clear_event_history(&mut self) {
1251 if let Some(ref mut emitter) = self.event_emitter {
1252 emitter.clear_history();
1253 }
1254 }
1255
1256 pub fn emit_custom_event(&mut self, event_type: u32, source: &str, data: Option<String>) {
1258 if let Some(ref mut emitter) = self.event_emitter {
1259 if let Ok(event_type) = Self::event_type_from_u32(event_type) {
1260 let mut event = events::EventData::new(event_type, source);
1261 if let Some(data_str) = data {
1262 event = event.with_data("custom_data", &data_str);
1263 }
1264 emitter.emit(event);
1265 }
1266 }
1267 }
1268
1269 fn event_type_from_u32(event_type: u32) -> Result<events::EventType, ()> {
1270 match event_type {
1271 1000 => Ok(events::EventType::ModelLoadStart),
1272 1001 => Ok(events::EventType::ModelLoadProgress),
1273 1002 => Ok(events::EventType::ModelLoadComplete),
1274 1003 => Ok(events::EventType::ModelLoadError),
1275 2000 => Ok(events::EventType::InferenceStart),
1276 2001 => Ok(events::EventType::InferenceProgress),
1277 2002 => Ok(events::EventType::InferenceComplete),
1278 2003 => Ok(events::EventType::InferenceError),
1279 6000 => Ok(events::EventType::ErrorOccurred),
1280 _ => Err(()),
1281 }
1282 }
1283
1284 pub async fn predict_with_batching(
1286 &mut self,
1287 input: &tensor::WasmTensor,
1288 priority: batch_processing::Priority,
1289 ) -> Result<tensor::WasmTensor, JsValue> {
1290 if let Some(ref mut processor) = self.batch_processor {
1291 let request_id = processor.add_request(input.clone(), priority, None);
1293
1294 if processor.is_batch_ready() || priority >= batch_processing::Priority::High {
1296 let responses = processor.process_batch().await?;
1297
1298 return self.find_batch_response(&responses, &request_id);
1300 }
1301
1302 return self.predict(input);
1304 }
1305
1306 self.predict(input)
1308 }
1309
1310 pub async fn flush_batches(&mut self) -> Result<Vec<batch_processing::BatchResponse>, JsValue> {
1312 if let Some(ref mut processor) = self.batch_processor {
1313 let mut all_responses = Vec::new();
1314
1315 while processor.queue_length() > 0 {
1316 let responses = processor.process_batch().await?;
1317 all_responses.extend(responses);
1318 }
1319
1320 Ok(all_responses)
1321 } else {
1322 Ok(Vec::new())
1323 }
1324 }
1325
1326 fn log_quantization_success(&mut self, summary: &str) {
1328 if let Some(ref mut logger) = self.debug_logger {
1329 logger.info(summary, "quantization");
1330 logger.end_timer("quantization");
1331 }
1332 }
1333
1334 fn log_quantization_failure(&mut self, error: &str) {
1335 if let Some(ref mut logger) = self.debug_logger {
1336 logger.warn(error, "quantization");
1337 logger.end_timer("quantization");
1338 }
1339 }
1340
1341 fn log_quantization_skipped(&mut self) {
1342 if let Some(ref mut logger) = self.debug_logger {
1343 logger.info(
1344 "Model size is optimal, skipping quantization",
1345 "quantization",
1346 );
1347 }
1348 }
1349
1350 fn find_batch_response(
1351 &self,
1352 responses: &[batch_processing::BatchResponse],
1353 request_id: &str,
1354 ) -> Result<tensor::WasmTensor, JsValue> {
1355 for response in responses.iter() {
1356 if response.request_id() == request_id {
1357 if let Some(_result) = response.result() {
1358 return tensor::WasmTensor::new(vec![1.0], vec![1]);
1361 } else if let Some(error) = response.error() {
1362 return Err(error.into());
1363 }
1364 }
1365 }
1366 Err("Response not found for request ID".into())
1367 }
1368}
1369
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a freshly constructed session reports itself initialized
    /// and exposes the expected crate version string.
    #[test]
    fn test_initialization() {
        let session = TrustformersWasm::new();
        assert!(session.initialized());
        assert_eq!(session.version(), "0.1.0");
    }
}