Skip to main content

trustformers_wasm/
lib.rs

1//! # TrustformeRS WebAssembly Bindings
2//!
3//! Run transformer models directly in the browser with WebAssembly and WebGPU acceleration.
4//!
5//! This crate provides WebAssembly bindings for TrustformeRS, enabling transformer model
6//! inference in web browsers with near-native performance. It leverages WebGPU for GPU
7//! acceleration and Web Workers for parallel processing.
8//!
9//! ## Features
10//!
11//! - **WebGPU acceleration**: GPU compute in the browser via WebGPU API
12//! - **Web Workers**: Multi-threaded inference using Web Workers
13//! - **Streaming inference**: Progressive token generation for chat applications
//! - **No inference server**: Models run entirely in-browser; weights are fetched once (and can be cached), after which inference makes no server calls
15//! - **Privacy-preserving**: All computation happens client-side
16//!
17//! ## Quick Start
18//!
19//! ```javascript
20//! import init, { Model, Tokenizer } from './trustformers_wasm.js';
21//!
22//! async function main() {
23//!   // Initialize the WASM module
24//!   await init();
25//!
26//!   // Load model and tokenizer
27//!   const model = await Model.from_pretrained("bert-base-uncased");
28//!   const tokenizer = await Tokenizer.from_pretrained("bert-base-uncased");
29//!
30//!   // Run inference
31//!   const text = "Hello, world!";
32//!   const tokens = tokenizer.encode(text);
33//!   const output = await model.forward(tokens);
34//!
35//!   console.log(output);
36//! }
37//! ```
38//!
39//! ## Architecture
40//!
41//! - **WASM Core**: Compiled Rust code for tensor operations
42//! - **WebGPU Backend**: GPU compute shaders for matrix operations
43//! - **Web Workers**: Parallel processing for batched inference
44//! - **Shared Memory**: Zero-copy data transfer between workers
45//!
46//! ## Performance
47//!
48//! - **WebGPU**: ~50-100x faster than CPU-only WASM
49//! - **SIMD**: Vectorized operations via WASM SIMD
50//! - **Streaming**: Progressive inference for lower latency
//! - **Caching**: Model weights can be cached in IndexedDB (`indexeddb` feature)
52//!
53//! ## Browser Support
54//!
55//! - Chrome/Edge 113+ (WebGPU)
56//! - Firefox 121+ (WebGPU experimental)
57//! - Safari 18+ (WebGPU preview)
58//!
59//! ## Build
60//!
61//! ```bash
62//! wasm-pack build --target web --features webgpu
63//! ```
64
65// Using std for wasm-bindgen compatibility and feature completeness
66// Note: While wasm-bindgen 0.2.100+ supports no_std, this codebase uses many std features
67// (HashMap, async/await, complex error types) that would require extensive refactoring
68// to work with no_std+alloc. The std overhead is negligible in WebAssembly contexts.
69
70// Allow excessive nesting for complex algorithms (matrix ops, GPU buffer management, etc.)
71#![allow(clippy::excessive_nesting)]
72
73use std::string::ToString;
74use std::vec::Vec;
75
76pub mod layers;
77pub mod models;
78#[cfg(feature = "web-workers")]
79pub mod runtime;
80
81// Import core modules from the core subdirectory
82pub mod core;
83pub use core::{model, pipeline, tensor, tokenizer, utils};
84
85// Compute modules
86pub mod compute;
87
88#[cfg(feature = "web-workers")]
89pub use compute::web_workers;
90
91#[cfg(feature = "shared-memory")]
92pub use compute::threads;
93
94#[cfg(feature = "webgpu")]
95pub use compute::webgpu_simple;
96
97#[cfg(feature = "webgpu")]
98pub use compute::webgpu;
99
100#[cfg(feature = "webgpu")]
101pub use compute::gpu_tensor;
102
103// Storage modules
104#[cfg(feature = "indexeddb")]
105pub mod storage;
106
107#[cfg(feature = "memory64")]
108pub use storage::memory64;
109
110#[cfg(feature = "streaming-loader")]
111pub use storage::streaming_loader;
112
113#[cfg(feature = "model-splitting")]
114pub use storage::model_splitting;
115
116#[cfg(feature = "react-components")]
117pub mod react_components;
118
119#[cfg(feature = "vue-components")]
120pub mod vue_components;
121
122#[cfg(feature = "angular-components")]
123pub mod angular_components;
124
125#[cfg(feature = "web-components")]
126pub mod web_components;
127
128#[cfg(feature = "playground")]
129pub mod playground;
130
131#[cfg(feature = "streaming-generation")]
132pub mod streaming_generation;
133
134#[cfg(feature = "mobile-optimization")]
135pub mod mobile;
136
137#[cfg(feature = "mobile-optimization")]
138pub mod touch_gestures;
139
140#[cfg(feature = "mobile-optimization")]
141pub mod camera_integration;
142
143#[cfg(feature = "mobile-optimization")]
144pub mod device_capability;
145
146#[cfg(feature = "mobile-optimization")]
147pub mod device_capability_detection;
148
149pub mod debug;
150pub mod error;
151pub mod events;
152pub mod export;
153// Import optimization modules from the optimization subdirectory
154pub mod optimization;
155pub use optimization::{
156    batch_processing, memory_pool, quantization, simd_tensor_ops, weight_compression,
157};
158
159pub mod auto_docs;
160pub mod multi_model_manager;
161pub mod performance;
162pub mod performance_profiler;
163pub mod plugin_framework;
164pub mod plugins;
165
166use std::sync::Mutex;
167use wasm_bindgen::prelude::*;
168
169// Global GPU memory tracking
170static GPU_MEMORY_TRACKER: Mutex<GpuMemoryTracker> = Mutex::new(GpuMemoryTracker::new());
171
172/// GPU memory tracker for monitoring WebGPU buffer allocations
173#[derive(Debug)]
174struct GpuMemoryTracker {
175    current_usage: usize,
176    peak_usage: usize,
177    total_allocated: usize,
178    total_deallocated: usize,
179    allocation_count: usize,
180    deallocation_count: usize,
181}
182
183impl GpuMemoryTracker {
184    const fn new() -> Self {
185        Self {
186            current_usage: 0,
187            peak_usage: 0,
188            total_allocated: 0,
189            total_deallocated: 0,
190            allocation_count: 0,
191            deallocation_count: 0,
192        }
193    }
194
195    fn allocate(&mut self, size: usize) {
196        self.current_usage += size;
197        self.total_allocated += size;
198        self.allocation_count += 1;
199
200        if self.current_usage > self.peak_usage {
201            self.peak_usage = self.current_usage;
202        }
203    }
204
205    fn deallocate(&mut self, size: usize) {
206        self.current_usage = self.current_usage.saturating_sub(size);
207        self.total_deallocated += size;
208        self.deallocation_count += 1;
209    }
210
211    fn get_current_usage(&self) -> usize {
212        self.current_usage
213    }
214
215    fn get_peak_usage(&self) -> usize {
216        self.peak_usage
217    }
218
219    fn reset_peak(&mut self) {
220        self.peak_usage = self.current_usage;
221    }
222}
223
224#[wasm_bindgen]
225pub struct TrustformersWasm {
226    initialized: bool,
227}
228
229impl Default for TrustformersWasm {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235#[wasm_bindgen]
236impl TrustformersWasm {
237    #[wasm_bindgen(constructor)]
238    pub fn new() -> Self {
239        #[cfg(feature = "console_panic")]
240        console_error_panic_hook::set_once();
241
242        TrustformersWasm { initialized: true }
243    }
244
245    #[wasm_bindgen(getter)]
246    pub fn version(&self) -> String {
247        "0.1.0".to_string()
248    }
249
250    #[wasm_bindgen(getter)]
251    pub fn initialized(&self) -> bool {
252        self.initialized
253    }
254}
255
256// Re-export main types
257pub use tensor::WasmTensor;
258
259#[cfg(feature = "webgpu")]
260pub use webgpu::{
261    AsyncExecutor, DeviceCapabilities, DeviceSelector, DeviceType, ExecutionStatus, FusableOp,
262    KernelFusion, OperationType, Priority, ShaderManager, WorkgroupTuner,
263};
264
265#[cfg(feature = "web-workers")]
266pub use runtime::edge_runtime::{
267    EdgeCapabilities, EdgeInferenceConfig, EdgeRuntime, EdgeRuntimeDetector,
268};
269
270#[cfg(feature = "web-workers")]
271pub use runtime::geo_distribution::{
272    create_geo_distribution_manager, estimate_network_latency, get_distance_between_points,
273    EdgeLocation, GeoDistributionManager, GeoRegion, RoutingDecision, RoutingWeights, UserLocation,
274};
275
276#[cfg(feature = "web-workers")]
277pub use runtime::edge_caching::{
278    create_edge_cache_manager, create_edge_computing_cache_config,
279    create_memory_efficient_cache_config, create_performance_cache_config, estimate_cache_overhead,
280    CacheConfig, CacheEntry, CacheEntryType, CacheStatistics, ConsistencyLevel, EdgeCacheManager,
281    EvictionPolicy, ReplicationStrategy,
282};
283
284#[cfg(feature = "web-workers")]
285pub use web_workers::{
286    get_optimal_worker_count, is_web_workers_supported, WorkerCoordinator, WorkerPool,
287    WorkerPriority, WorkerTaskType,
288};
289
290#[cfg(feature = "shared-memory")]
291pub use threads::{
292    get_optimal_thread_count, is_cross_origin_isolated, is_threading_supported, AtomicOperations,
293    ThreadPool, ThreadSync, ThreadTaskType,
294};
295
296#[cfg(feature = "indexeddb")]
297pub use storage::{CompressionType, ModelMetadata, ModelStorage, StoredModel};
298
299#[cfg(feature = "memory64")]
300pub use memory64::{
301    can_load_model_size, get_memory64_capabilities, is_memory64_supported, AllocationStrategy,
302    Memory64Capabilities, Memory64Manager,
303};
304
305#[cfg(feature = "streaming-loader")]
306pub use streaming_loader::{
307    get_optimal_chunk_size_kb, is_cache_api_available, is_streaming_compilation_supported,
308    LoadingProgress, StreamingConfig, StreamingLoader,
309};
310
311#[cfg(feature = "model-splitting")]
312pub use model_splitting::{
313    get_recommended_chunk_size_mb, should_split_model, ChunkConfig, ChunkPriority, ChunkType,
314    LoadingStrategy, ModelLoadingSession, ModelSplitter,
315};
316
317#[cfg(feature = "react-components")]
318pub use react_components::{
319    generate_react_package, is_react_available, ComponentType, InferenceState, ModelLoadingState,
320    ReactComponentFactory, ReactConfig,
321};
322
323#[cfg(feature = "vue-components")]
324pub use vue_components::{
325    generate_vue_package, is_vue_available, VueComponentFactory, VueComponentType, VueConfig,
326    VueInferenceState, VueModelState,
327};
328
329#[cfg(feature = "angular-components")]
330pub use angular_components::{
331    generate_angular_package, is_angular_available, AngularComponentType, AngularConfig,
332    AngularInferenceState, AngularModelState, AngularServiceFactory, AngularServiceType,
333};
334
335#[cfg(feature = "web-components")]
336pub use web_components::{
337    create_web_component_html_template, generate_web_components_package,
338    is_web_components_supported, ComponentType as WebComponentType,
339    InferenceState as WebInferenceState, ModelState as WebModelState, WebComponentConfig,
340    WebComponentFactory,
341};
342
343#[cfg(feature = "playground")]
344pub use playground::{
345    create_playground_config, create_playground_example, generate_playground_package,
346    ExampleCategory, InteractivePlayground, PlaygroundConfig, PlaygroundExample,
347};
348
349#[cfg(feature = "streaming-generation")]
350pub use streaming_generation::{
351    get_optimal_streaming_config, is_streaming_supported, CompletionReason, GenerationProgress,
352    StreamingConfig as GenerationStreamingConfig, StreamingGenerator, StreamingStats,
353    StreamingToken,
354};
355
356#[cfg(feature = "mobile-optimization")]
357pub use mobile::{
358    create_mobile_optimizer, get_device_memory_gb, get_optimal_model_for_device, is_low_data_mode,
359    is_mobile_device, is_tablet_device, AdaptiveModelConfig, BatteryInfo, BatteryStatus,
360    DeviceClass, MobileCapabilities, MobileOptimizer, ModelSize, NetworkStatus, NetworkType,
361};
362
363pub use auto_docs::{
364    create_default_doc_generator,
365    create_html_doc_generator,
366    create_markdown_doc_generator,
367    // Version and build information
368    get_version_info,
369    AutoDocGenerator,
370    DocConfig,
371    DocFormat,
372    DocTheme,
373    VersionInfo,
374};
375pub use batch_processing::{
376    BatchConfig, BatchProcessor, BatchResponse, BatchingStrategy, Priority as BatchPriority,
377};
378pub use debug::{DebugConfig, DebugLogger, LogLevel, PerformanceMetrics};
379pub use error::{
380    ErrorBuilder, ErrorCode, ErrorCollection, ErrorContext, ErrorHandler, ErrorSeverity,
381    TrustformersError, TrustformersResult,
382};
383pub use events::{EventData, EventEmittable, EventManager, EventPriority, EventType};
384pub use multi_model_manager::{
385    create_development_multi_model_manager, create_production_multi_model_manager,
386    DeploymentEnvironment, ModelPriority, ModelStatus, MultiModelConfig, MultiModelManager,
387};
388pub use performance::{
389    BottleneckType, OperationType as ProfilerOperationType, ProfilerConfig, ResourceType,
390};
391pub use performance_profiler::{
392    create_development_profiler, create_production_profiler, PerformanceProfiler,
393};
394pub use plugin_framework::{
395    create_default_plugin_config, create_plugin_context, ExecutionMetrics, ExecutionPriority,
396    ModelMetadata as PluginModelMetadata, PerformanceBudget, Plugin, PluginConfig, PluginContext,
397    PluginError, PluginErrorCode, PluginManager, PluginMetadata, PluginPermission, PluginRegistry,
398    PluginResult, PluginType, ResourceLimits,
399};
400pub use plugins::{ModelOptimizerPlugin, TextProcessorPlugin, VisualizationPlugin};
401pub use quantization::{
402    QuantizationConfig, QuantizationPrecision, QuantizationStrategy, QuantizedModelData,
403    WebQuantizer,
404};
405pub use weight_compression::{
406    CompressedModelData, CompressionConfig, CompressionLevel, CompressionStrategy, SparsityPattern,
407    WeightCompressor,
408};
409
410#[wasm_bindgen]
411pub fn init_panic_hook() {
412    #[cfg(feature = "console_panic")]
413    console_error_panic_hook::set_once();
414}
415
416// Utility functions for WebAssembly
417#[wasm_bindgen]
418pub fn get_wasm_memory_usage() -> usize {
419    // Get the current memory pages and convert to bytes
420    // Each WASM memory page is 64KB
421    let memory = wasm_bindgen::memory();
422    let memory_obj: &js_sys::WebAssembly::Memory = memory.unchecked_ref();
423    let buffer = js_sys::WebAssembly::Memory::buffer(memory_obj);
424    let array_buffer: &js_sys::ArrayBuffer = buffer.unchecked_ref();
425    js_sys::ArrayBuffer::byte_length(array_buffer) as usize
426}
427
428/// Memory usage statistics
429#[wasm_bindgen]
430#[derive(Debug, Clone)]
431pub struct MemoryStats {
432    wasm_memory: usize,
433    gpu_memory: usize,
434    peak_gpu_memory: usize,
435}
436
437#[wasm_bindgen]
438impl MemoryStats {
439    #[wasm_bindgen(getter)]
440    pub fn wasm_memory(&self) -> usize {
441        self.wasm_memory
442    }
443
444    #[wasm_bindgen(getter)]
445    pub fn gpu_memory(&self) -> usize {
446        self.gpu_memory
447    }
448
449    #[wasm_bindgen(getter)]
450    pub fn peak_gpu_memory(&self) -> usize {
451        self.peak_gpu_memory
452    }
453
454    #[wasm_bindgen(getter)]
455    pub fn used_mb(&self) -> f32 {
456        self.wasm_memory as f32 / (1024.0 * 1024.0)
457    }
458
459    #[wasm_bindgen(getter)]
460    pub fn limit_mb(&self) -> f32 {
461        // Return a reasonable WASM memory limit (256MB typical)
462        256.0
463    }
464}
465
466/// Track GPU memory allocation (called by WebGPU backend)
467#[wasm_bindgen]
468pub fn track_gpu_allocation(size: usize) {
469    if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
470        tracker.allocate(size);
471    }
472}
473
474/// Track GPU memory deallocation (called by WebGPU backend)
475#[wasm_bindgen]
476pub fn track_gpu_deallocation(size: usize) {
477    if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
478        tracker.deallocate(size);
479    }
480}
481
482/// Get current GPU memory usage
483#[wasm_bindgen]
484pub fn get_gpu_memory_usage() -> usize {
485    GPU_MEMORY_TRACKER
486        .lock()
487        .map(|tracker| tracker.get_current_usage())
488        .unwrap_or(0)
489}
490
491/// Get peak GPU memory usage
492#[wasm_bindgen]
493pub fn get_peak_gpu_memory_usage() -> usize {
494    GPU_MEMORY_TRACKER.lock().map(|tracker| tracker.get_peak_usage()).unwrap_or(0)
495}
496
497/// Reset peak GPU memory usage tracking
498#[wasm_bindgen]
499pub fn reset_peak_gpu_memory() {
500    if let Ok(mut tracker) = GPU_MEMORY_TRACKER.lock() {
501        tracker.reset_peak();
502    }
503}
504
505/// Get comprehensive memory statistics
506#[wasm_bindgen]
507pub fn get_memory_stats() -> MemoryStats {
508    MemoryStats {
509        wasm_memory: get_wasm_memory_usage(),
510        gpu_memory: get_gpu_memory_usage(),
511        peak_gpu_memory: get_peak_gpu_memory_usage(),
512    }
513}
514
515#[wasm_bindgen]
516pub fn enable_simd() -> bool {
517    // Check if SIMD is available in the WASM environment
518    cfg!(target_feature = "simd128")
519}
520
521// Enhanced inference API with automatic device selection
522#[wasm_bindgen]
523pub struct InferenceSession {
524    model_type: String,
525    #[cfg(feature = "webgpu")]
526    device_selector: Option<webgpu::DeviceSelector>,
527    #[cfg(feature = "webgpu")]
528    current_device: webgpu::DeviceType,
529    #[cfg(feature = "web-workers")]
530    edge_detector: runtime::edge_runtime::EdgeRuntimeDetector,
531    #[cfg(feature = "web-workers")]
532    edge_config: runtime::edge_runtime::EdgeInferenceConfig,
533    #[cfg(feature = "indexeddb")]
534    storage: Option<storage::ModelStorage>,
535    debug_logger: Option<debug::DebugLogger>,
536    quantizer: Option<quantization::WebQuantizer>,
537    batch_processor: Option<batch_processing::BatchProcessor>,
538    event_emitter: Option<events::EventEmitter>,
539}
540
541#[wasm_bindgen]
542impl InferenceSession {
543    #[wasm_bindgen(constructor)]
544    pub fn new(model_type: String) -> Result<InferenceSession, JsValue> {
545        #[cfg(feature = "web-workers")]
546        let edge_detector = runtime::edge_runtime::EdgeRuntimeDetector::new();
547        #[cfg(feature = "web-workers")]
548        let edge_config =
549            runtime::edge_runtime::EdgeInferenceConfig::for_runtime(&edge_detector.capabilities());
550
551        Ok(InferenceSession {
552            model_type,
553            #[cfg(feature = "webgpu")]
554            device_selector: None,
555            #[cfg(feature = "webgpu")]
556            current_device: webgpu::DeviceType::CPU,
557            #[cfg(feature = "web-workers")]
558            edge_detector,
559            #[cfg(feature = "web-workers")]
560            edge_config,
561            #[cfg(feature = "indexeddb")]
562            storage: None,
563            debug_logger: None,
564            quantizer: None,
565            batch_processor: None,
566            event_emitter: None,
567        })
568    }
569
570    /// Initialize with automatic device selection
571    pub async fn initialize_with_auto_device(&mut self) -> Result<(), JsValue> {
572        #[cfg(feature = "webgpu")]
573        {
574            let mut selector = webgpu::DeviceSelector::new().await?;
575            selector.initialize_device().await?;
576            self.current_device = selector.selected_device();
577            self.device_selector = Some(selector);
578        }
579        Ok(())
580    }
581
582    /// Get the currently selected device type
583    #[cfg(feature = "webgpu")]
584    #[wasm_bindgen(getter)]
585    pub fn current_device_type(&self) -> DeviceType {
586        self.current_device
587    }
588
589    /// Get device capabilities
590    #[cfg(feature = "webgpu")]
591    pub fn get_device_capabilities(&self) -> Option<DeviceCapabilities> {
592        self.device_selector.as_ref().map(|s| s.capabilities())
593    }
594
595    pub async fn load_model(&mut self, model_data: &[u8]) -> Result<(), JsValue> {
596        use crate::error::{ErrorBuilder, ErrorCode};
597
598        let model_size = model_data.len();
599        let model_size_mb = model_size as f64 / 1024.0 / 1024.0;
600
601        // Emit model load start event
602        if let Some(ref mut emitter) = self.event_emitter {
603            let event = events::EventData::model_load_start(&self.model_type, model_size_mb);
604            emitter.emit(event);
605        }
606
607        // Validate model data
608        if model_data.is_empty() {
609            let error = ErrorBuilder::new(ErrorCode::E1001, "Model data is empty")
610                .operation("load_model")
611                .component("inference_session")
612                .build();
613
614            // Emit error event
615            if let Some(ref mut emitter) = self.event_emitter {
616                let event = events::EventData::error_occurred(&error.message, "load_model");
617                emitter.emit(event);
618            }
619
620            return Err(error.into());
621        }
622
623        // Check if model is too large (> 2GB)
624        if model_size > 2_147_483_648 {
625            let error = ErrorBuilder::new(ErrorCode::E1002, "Model exceeds maximum size limit")
626                .operation("load_model")
627                .component("inference_session")
628                .memory_usage_mb(model_size_mb)
629                .additional_info("Consider using quantization or model splitting")
630                .build();
631
632            // Emit error event
633            if let Some(ref mut emitter) = self.event_emitter {
634                let event = events::EventData::error_occurred(&error.message, "load_model")
635                    .with_data("size_mb", &format!("{model_size_mb:.2}"));
636                emitter.emit(event);
637            }
638
639            return Err(error.into());
640        }
641
642        let start_time = js_sys::Date::now();
643
644        // Debug logging
645        if let Some(ref mut logger) = self.debug_logger {
646            logger.start_timer("model_loading");
647            logger.log_model_loading(&self.model_type, model_size, "memory");
648            logger.log_memory_usage("Before model loading");
649        }
650
651        // Check available memory before loading
652        let memory_stats = crate::get_memory_stats();
653        let available_memory = 2_147_483_648; // 2GB assumed available
654        if model_size > available_memory {
655            return Err(
656                ErrorBuilder::new(ErrorCode::E4002, "Insufficient memory for model")
657                    .operation("load_model")
658                    .component("inference_session")
659                    .memory_usage_mb(memory_stats.wasm_memory as f64 / 1024.0 / 1024.0)
660                    .additional_info("Try enabling quantization or using a smaller model")
661                    .build()
662                    .into(),
663            );
664        }
665
666        #[cfg(feature = "webgpu")]
667        if let Some(selector) = &self.device_selector {
668            let should_use_gpu = selector.should_use_gpu(model_size, 0.8); // High complexity for model loading
669
670            if should_use_gpu {
671                // GPU-optimized model loading
672                web_sys::console::log_1(
673                    &format!("Loading model on GPU (size: {model_size} bytes)").into(),
674                );
675
676                // Simulate GPU model loading validation
677                let gpu_memory_required = model_size * 2; // Assume 2x memory needed for GPU
678                let capabilities = selector.capabilities();
679                if gpu_memory_required > capabilities.gpu_memory_limit as usize {
680                    return Err(
681                        ErrorBuilder::new(ErrorCode::E3003, "Insufficient GPU memory")
682                            .operation("load_model")
683                            .component("webgpu_device")
684                            .device_type("GPU")
685                            .memory_usage_mb(gpu_memory_required as f64 / 1024.0 / 1024.0)
686                            .additional_info("Try using CPU mode or enabling quantization")
687                            .build()
688                            .into(),
689                    );
690                }
691            } else {
692                // CPU-optimized model loading
693                web_sys::console::log_1(
694                    &format!("Loading model on CPU (size: {model_size} bytes)").into(),
695                );
696            }
697        } else {
698            web_sys::console::log_1(
699                &format!("Loading model on CPU (size: {model_size} bytes)").into(),
700            );
701        }
702
703        #[cfg(not(feature = "webgpu"))]
704        web_sys::console::log_1(&format!("Loading model on CPU (size: {model_size} bytes)").into());
705
706        // Simulate model loading validation
707        if model_size < 1024 {
708            if let Some(ref mut logger) = self.debug_logger {
709                logger.warn(
710                    "Model size is very small, may not be a valid model",
711                    "model_validation",
712                );
713            }
714        }
715
716        // Complete debug logging
717        if let Some(ref mut logger) = self.debug_logger {
718            logger.log_memory_usage("After model loading");
719            logger.end_timer("model_loading");
720            logger.info(
721                &format!(
722                    "Model loaded successfully (size: {} MB)",
723                    model_size / 1024 / 1024
724                ),
725                "model_loading",
726            );
727        }
728
729        // Emit model load complete event
730        let duration_ms = js_sys::Date::now() - start_time;
731        if let Some(ref mut emitter) = self.event_emitter {
732            let event = events::EventData::model_load_complete(&self.model_type, duration_ms);
733            emitter.emit(event);
734        }
735
736        Ok(())
737    }
738
739    pub fn predict(&mut self, input: &tensor::WasmTensor) -> Result<tensor::WasmTensor, JsValue> {
740        let input_size = input.len();
741        let start_time = js_sys::Date::now();
742
743        // Emit inference start event
744        if let Some(ref mut emitter) = self.event_emitter {
745            let event = events::EventData::inference_start(&input.shape());
746            emitter.emit(event);
747        }
748
749        // Debug logging for inference
750        if let Some(ref mut logger) = self.debug_logger {
751            logger.start_timer("inference");
752            logger.log_inference(&self.model_type, &input.shape(), "auto");
753            logger.log_memory_usage("Before inference");
754        }
755
756        #[cfg(feature = "webgpu")]
757        if let Some(selector) = &self.device_selector {
758            let should_use_gpu = selector.should_use_gpu(input_size, 0.6); // Medium complexity for inference
759
760            if should_use_gpu {
761                web_sys::console::log_1(
762                    &format!("Running inference on GPU (input size: {input_size})").into(),
763                );
764                // GPU-accelerated prediction would go here
765            } else {
766                web_sys::console::log_1(
767                    &format!("Running inference on CPU (input size: {input_size})").into(),
768                );
769                // CPU-optimized prediction would go here
770            }
771        } else {
772            web_sys::console::log_1(
773                &format!("Running inference on CPU (input size: {input_size})").into(),
774            );
775        }
776
777        #[cfg(not(feature = "webgpu"))]
778        web_sys::console::log_1(
779            &format!("Running inference on CPU (input size: {input_size})").into(),
780        );
781
782        let result = input.clone();
783
784        // Complete debug logging for inference
785        if let Some(ref mut logger) = self.debug_logger {
786            logger.log_memory_usage("After inference");
787            logger.end_timer("inference");
788        }
789
790        // Emit inference complete event
791        let duration_ms = js_sys::Date::now() - start_time;
792        if let Some(ref mut emitter) = self.event_emitter {
793            let event = events::EventData::inference_complete(duration_ms, result.len());
794            emitter.emit(event);
795        }
796
797        Ok(result)
798    }
799
800    /// Force device selection for testing
801    #[cfg(feature = "webgpu")]
802    pub fn force_device_type(&mut self, device_type: DeviceType) {
803        self.current_device = device_type;
804    }
805
806    /// Get edge runtime capabilities
807    #[cfg(feature = "web-workers")]
808    #[wasm_bindgen(getter)]
809    pub fn edge_capabilities(&self) -> EdgeCapabilities {
810        self.edge_detector.capabilities()
811    }
812
813    /// Get edge inference configuration
814    #[cfg(feature = "web-workers")]
815    #[wasm_bindgen(getter)]
816    pub fn edge_config(&self) -> EdgeInferenceConfig {
817        self.edge_config.clone()
818    }
819
820    /// Check if the current edge runtime is suitable for ML inference
821    #[cfg(feature = "web-workers")]
822    pub fn is_edge_suitable(&self) -> bool {
823        self.edge_detector.is_ml_suitable()
824    }
825
826    /// Get recommended model size for current edge runtime
827    #[cfg(feature = "web-workers")]
828    pub fn recommended_model_size_mb(&self) -> u32 {
829        self.edge_detector.recommended_model_size_mb()
830    }
831
832    /// Get cold start optimization recommendations
833    #[cfg(feature = "web-workers")]
834    pub fn get_cold_start_optimizations(&self) -> js_sys::Array {
835        self.edge_detector.get_cold_start_optimizations()
836    }
837
838    /// Initialize model storage with optional maximum size in MB
839    #[cfg(feature = "indexeddb")]
840    pub async fn initialize_storage(&mut self, max_storage_mb: f64) -> Result<(), JsValue> {
841        let mut storage =
842            storage::ModelStorage::new("trustformers-models".to_string(), max_storage_mb);
843        storage.initialize().await?;
844        self.storage = Some(storage);
845        web_sys::console::log_1(
846            &format!("Initialized model storage (max: {max_storage_mb} MB)").into(),
847        );
848        Ok(())
849    }
850
851    /// Store a model in IndexedDB for caching
852    #[cfg(feature = "indexeddb")]
853    pub async fn store_model(
854        &self,
855        model_id: &str,
856        model_name: &str,
857        architecture: &str,
858        version: &str,
859        data: &[u8],
860    ) -> Result<(), JsValue> {
861        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
862        storage.store_model(model_id, model_name, architecture, version, data).await
863    }
864
865    /// Load a model from IndexedDB cache
866    #[cfg(feature = "indexeddb")]
867    pub async fn load_cached_model(&self, model_id: &str) -> Result<Option<Vec<u8>>, JsValue> {
868        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
869        storage.get_model(model_id).await
870    }
871
872    /// Check if a model exists in cache
873    #[cfg(feature = "indexeddb")]
874    pub async fn has_cached_model(&self, model_id: &str) -> Result<bool, JsValue> {
875        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
876        storage.has_model(model_id).await
877    }
878
879    /// Load model with automatic caching
880    #[cfg(feature = "indexeddb")]
881    pub async fn load_model_with_cache(
882        &mut self,
883        model_id: &str,
884        model_url: &str,
885        model_name: &str,
886        architecture: &str,
887        version: &str,
888    ) -> Result<(), JsValue> {
889        if let Some(storage) = &self.storage {
890            // Try to load from cache first
891            if let Some(cached_data) = storage.get_model(model_id).await? {
892                web_sys::console::log_1(&format!("Loaded model '{model_name}' from cache").into());
893                return self.load_model(&cached_data).await;
894            }
895        }
896
897        // Download from URL if not in cache
898        web_sys::console::log_1(
899            &format!("Downloading model '{model_name}' from {model_url}").into(),
900        );
901
902        let model_data = self.fetch_model_from_url(model_url).await.map_err(|e| {
903            JsValue::from_str(&format!(
904                "Failed to fetch model from {}: {:?}",
905                model_url, e
906            ))
907        })?;
908
909        // Store in cache if storage is available
910        if let Some(storage) = &self.storage {
911            storage
912                .store_model(model_id, model_name, architecture, version, &model_data)
913                .await?;
914        }
915
916        self.load_model(&model_data).await
917    }
918
919    /// Get storage usage statistics
920    #[cfg(feature = "indexeddb")]
921    pub async fn get_storage_stats(&self) -> Result<Option<String>, JsValue> {
922        if let Some(storage) = &self.storage {
923            let usage = storage.get_storage_usage().await?;
924            let models_js = storage.list_models().await?;
925            let models_count = if let Ok(models) =
926                serde_wasm_bindgen::from_value::<Vec<ModelMetadata>>(models_js)
927            {
928                models.len()
929            } else {
930                0
931            };
932            Ok(Some(format!(
933                "Storage: {} bytes, {} models",
934                usage, models_count
935            )))
936        } else {
937            Ok(None)
938        }
939    }
940
941    /// Clear all cached models
942    #[cfg(feature = "indexeddb")]
943    pub async fn clear_model_cache(&self) -> Result<(), JsValue> {
944        let storage = self.storage.as_ref().ok_or("Storage not initialized")?;
945        storage.clear_all().await
946    }
947
948    /// Fetch model data from a remote URL using the Fetch API
949    #[allow(dead_code)]
950    async fn fetch_model_from_url(&self, url: &str) -> Result<Vec<u8>, JsValue> {
951        use wasm_bindgen::JsCast;
952        use wasm_bindgen_futures::JsFuture;
953
954        // Create fetch request
955        let request = web_sys::Request::new_with_str(url)?;
956        request.headers().set("Accept", "application/octet-stream")?;
957
958        // Get the global window object
959        let window = web_sys::window().ok_or("No global window object")?;
960
961        // Perform the fetch
962        let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
963        let resp: web_sys::Response = resp_value.dyn_into()?;
964
965        // Check if the response is ok
966        if !resp.ok() {
967            return Err(JsValue::from_str(&format!(
968                "HTTP error: {} {}",
969                resp.status(),
970                resp.status_text()
971            )));
972        }
973
974        // Get the response as ArrayBuffer
975        let array_buffer = JsFuture::from(resp.array_buffer()?).await?;
976        let uint8_array = js_sys::Uint8Array::new(&array_buffer);
977
978        // Convert to Vec<u8>
979        let mut data = vec![0u8; uint8_array.length() as usize];
980        uint8_array.copy_to(&mut data);
981
982        web_sys::console::log_1(
983            &format!("Successfully downloaded {len} bytes", len = data.len()).into(),
984        );
985
986        Ok(data)
987    }
988
989    /// Initialize debug logging
990    pub fn enable_debug_logging(&mut self, config: debug::DebugConfig) {
991        self.debug_logger = Some(debug::DebugLogger::new(config));
992        if let Some(ref mut logger) = self.debug_logger {
993            logger.info(
994                "Debug logging enabled for inference session",
995                "initialization",
996            );
997        }
998    }
999
1000    /// Disable debug logging
1001    pub fn disable_debug_logging(&mut self) {
1002        if let Some(ref mut logger) = self.debug_logger {
1003            logger.info("Debug logging disabled", "cleanup");
1004        }
1005        self.debug_logger = None;
1006    }
1007
1008    /// Start a performance timer
1009    pub fn start_timer(&mut self, operation: &str) {
1010        if let Some(ref mut logger) = self.debug_logger {
1011            logger.start_timer(operation);
1012        }
1013    }
1014
1015    /// End a performance timer
1016    pub fn end_timer(&mut self, operation: &str) -> Option<f64> {
1017        if let Some(ref mut logger) = self.debug_logger {
1018            logger.end_timer(operation)
1019        } else {
1020            None
1021        }
1022    }
1023
1024    /// Get debug performance summary
1025    pub fn get_performance_summary(&self) -> Option<String> {
1026        self.debug_logger.as_ref().map(|logger| logger.get_performance_summary())
1027    }
1028
1029    /// Export debug logs
1030    pub fn export_debug_logs(&self) -> Option<String> {
1031        self.debug_logger.as_ref().map(|logger| logger.export_logs())
1032    }
1033
1034    /// Clear debug logs
1035    pub fn clear_debug_logs(&mut self) {
1036        if let Some(ref mut logger) = self.debug_logger {
1037            logger.clear();
1038        }
1039    }
1040
1041    /// Log memory usage with context
1042    pub fn log_memory_usage(&mut self, context: &str) {
1043        if let Some(ref mut logger) = self.debug_logger {
1044            logger.log_memory_usage(context);
1045        }
1046    }
1047
1048    /// Initialize quantization with configuration
1049    pub fn enable_quantization(&mut self, config: quantization::QuantizationConfig) {
1050        self.quantizer = Some(quantization::WebQuantizer::new(config));
1051        if let Some(ref mut logger) = self.debug_logger {
1052            logger.info("Quantization enabled for inference session", "quantization");
1053        }
1054    }
1055
1056    /// Disable quantization
1057    pub fn disable_quantization(&mut self) {
1058        if let Some(ref mut logger) = self.debug_logger {
1059            logger.info("Quantization disabled", "quantization");
1060        }
1061        self.quantizer = None;
1062    }
1063
1064    /// Load model with automatic quantization
1065    pub async fn load_model_with_quantization(&mut self, model_data: &[u8]) -> Result<(), JsValue> {
1066        let mut final_data = model_data.to_vec();
1067        let mut quantization_summary = "No quantization applied".to_string();
1068
1069        // Apply quantization if enabled and beneficial
1070        if let Some(ref quantizer) = self.quantizer {
1071            if quantizer.should_quantize(model_data.len()) {
1072                if let Some(ref mut logger) = self.debug_logger {
1073                    logger.start_timer("quantization");
1074                    logger.info(
1075                        &format!(
1076                            "Starting quantization for {len} bytes",
1077                            len = model_data.len()
1078                        ),
1079                        "quantization",
1080                    );
1081                }
1082
1083                match quantizer.quantize_model(model_data) {
1084                    Ok(quantized) => {
1085                        final_data = quantized.data();
1086                        quantization_summary = quantized.summary();
1087                        self.log_quantization_success(&quantization_summary);
1088                    },
1089                    Err(e) => {
1090                        let error_msg = format!("Quantization failed: {e:?}, using original model");
1091                        self.log_quantization_failure(&error_msg);
1092                        // Continue with original data if quantization fails
1093                    },
1094                }
1095            } else {
1096                self.log_quantization_skipped();
1097            }
1098        }
1099
1100        // Load the (possibly quantized) model
1101        self.load_model(&final_data).await?;
1102
1103        if let Some(ref mut logger) = self.debug_logger {
1104            logger.info(
1105                &format!("Model loaded: {quantization_summary}"),
1106                "model_loading",
1107            );
1108        }
1109
1110        Ok(())
1111    }
1112
1113    /// Get quantization recommendations for current model
1114    pub fn get_quantization_recommendations(
1115        &self,
1116        model_size_bytes: usize,
1117    ) -> Option<quantization::QuantizationConfig> {
1118        self.quantizer.as_ref().map(|q| q.get_recommended_settings(model_size_bytes))
1119    }
1120
1121    /// Check if quantization would be beneficial for a given model size
1122    pub fn should_quantize_model(&self, model_size_bytes: usize) -> bool {
1123        self.quantizer.as_ref().is_some_and(|q| q.should_quantize(model_size_bytes))
1124    }
1125
1126    /// Initialize batch processing with configuration
1127    pub fn enable_batch_processing(&mut self, config: batch_processing::BatchConfig) {
1128        self.batch_processor = Some(batch_processing::BatchProcessor::new(config));
1129        if let Some(ref mut logger) = self.debug_logger {
1130            logger.info(
1131                "Batch processing enabled for inference session",
1132                "batch_processing",
1133            );
1134        }
1135    }
1136
1137    /// Disable batch processing
1138    pub fn disable_batch_processing(&mut self) {
1139        if let Some(ref mut logger) = self.debug_logger {
1140            logger.info("Batch processing disabled", "batch_processing");
1141        }
1142        self.batch_processor = None;
1143    }
1144
1145    /// Add a request to the batch queue
1146    pub fn add_batch_request(
1147        &mut self,
1148        input: &tensor::WasmTensor,
1149        priority: batch_processing::Priority,
1150        timeout_ms: Option<u32>,
1151    ) -> Option<String> {
1152        if let Some(ref mut processor) = self.batch_processor {
1153            let request_id = processor.add_request(input.clone(), priority, timeout_ms);
1154
1155            if let Some(ref mut logger) = self.debug_logger {
1156                logger.debug(
1157                    &format!("Added batch request {request_id} with priority {priority:?}"),
1158                    "batch_processing",
1159                );
1160            }
1161
1162            Some(request_id)
1163        } else {
1164            if let Some(ref mut logger) = self.debug_logger {
1165                logger.warn("Batch processing not enabled", "batch_processing");
1166            }
1167            None
1168        }
1169    }
1170
1171    /// Process pending batch requests
1172    pub async fn process_batch(&mut self) -> Result<Vec<batch_processing::BatchResponse>, JsValue> {
1173        if let Some(ref mut processor) = self.batch_processor {
1174            if let Some(ref mut logger) = self.debug_logger {
1175                logger.start_timer("batch_processing");
1176                logger.debug(
1177                    &format!(
1178                        "Processing batch with {} pending requests",
1179                        processor.queue_length()
1180                    ),
1181                    "batch_processing",
1182                );
1183            }
1184
1185            let responses = processor.process_batch().await?;
1186
1187            if let Some(ref mut logger) = self.debug_logger {
1188                logger.info(
1189                    &format!("Processed batch: {len} responses", len = responses.len()),
1190                    "batch_processing",
1191                );
1192                logger.end_timer("batch_processing");
1193            }
1194
1195            Ok(responses)
1196        } else {
1197            Err("Batch processing not enabled".into())
1198        }
1199    }
1200
1201    /// Check if a batch is ready for processing
1202    pub fn is_batch_ready(&self) -> bool {
1203        self.batch_processor.as_ref().is_some_and(|p| p.is_batch_ready())
1204    }
1205
1206    /// Get current batch queue length
1207    pub fn get_batch_queue_length(&self) -> usize {
1208        self.batch_processor.as_ref().map_or(0, |p| p.queue_length())
1209    }
1210
1211    /// Get batch processing statistics
1212    pub fn get_batch_stats(&self) -> Option<String> {
1213        self.batch_processor.as_ref().map(|p| p.get_stats())
1214    }
1215
1216    /// Clear the batch queue
1217    pub fn clear_batch_queue(&mut self) {
1218        if let Some(ref mut processor) = self.batch_processor {
1219            processor.clear_queue();
1220            if let Some(ref mut logger) = self.debug_logger {
1221                logger.info("Batch queue cleared", "batch_processing");
1222            }
1223        }
1224    }
1225
1226    /// Enable event system
1227    pub fn enable_events(&mut self) {
1228        self.event_emitter = Some(events::EventEmitter::new());
1229        if let Some(ref mut logger) = self.debug_logger {
1230            logger.info("Event system enabled for inference session", "events");
1231        }
1232    }
1233
1234    /// Disable event system
1235    pub fn disable_events(&mut self) {
1236        if let Some(ref mut logger) = self.debug_logger {
1237            logger.info("Event system disabled", "events");
1238        }
1239        self.event_emitter = None;
1240    }
1241
1242    /// Get event history as JSON
1243    pub fn get_event_history(&self) -> Option<String> {
1244        self.event_emitter
1245            .as_ref()
1246            .and_then(|emitter| serde_json::to_string(emitter.get_history()).ok())
1247    }
1248
1249    /// Clear event history
1250    pub fn clear_event_history(&mut self) {
1251        if let Some(ref mut emitter) = self.event_emitter {
1252            emitter.clear_history();
1253        }
1254    }
1255
1256    /// Emit a custom event
1257    pub fn emit_custom_event(&mut self, event_type: u32, source: &str, data: Option<String>) {
1258        if let Some(ref mut emitter) = self.event_emitter {
1259            if let Ok(event_type) = Self::event_type_from_u32(event_type) {
1260                let mut event = events::EventData::new(event_type, source);
1261                if let Some(data_str) = data {
1262                    event = event.with_data("custom_data", &data_str);
1263                }
1264                emitter.emit(event);
1265            }
1266        }
1267    }
1268
1269    fn event_type_from_u32(event_type: u32) -> Result<events::EventType, ()> {
1270        match event_type {
1271            1000 => Ok(events::EventType::ModelLoadStart),
1272            1001 => Ok(events::EventType::ModelLoadProgress),
1273            1002 => Ok(events::EventType::ModelLoadComplete),
1274            1003 => Ok(events::EventType::ModelLoadError),
1275            2000 => Ok(events::EventType::InferenceStart),
1276            2001 => Ok(events::EventType::InferenceProgress),
1277            2002 => Ok(events::EventType::InferenceComplete),
1278            2003 => Ok(events::EventType::InferenceError),
1279            6000 => Ok(events::EventType::ErrorOccurred),
1280            _ => Err(()),
1281        }
1282    }
1283
1284    /// Single inference with automatic batching support
1285    pub async fn predict_with_batching(
1286        &mut self,
1287        input: &tensor::WasmTensor,
1288        priority: batch_processing::Priority,
1289    ) -> Result<tensor::WasmTensor, JsValue> {
1290        if let Some(ref mut processor) = self.batch_processor {
1291            // Add to batch queue
1292            let request_id = processor.add_request(input.clone(), priority, None);
1293
1294            // Process if batch is ready or if it's a high-priority request
1295            if processor.is_batch_ready() || priority >= batch_processing::Priority::High {
1296                let responses = processor.process_batch().await?;
1297
1298                // Find our response
1299                return self.find_batch_response(&responses, &request_id);
1300            }
1301
1302            // If not processed in batch, fall back to direct prediction
1303            return self.predict(input);
1304        }
1305
1306        // No batch processing, use direct prediction
1307        self.predict(input)
1308    }
1309
1310    /// Process all pending batches
1311    pub async fn flush_batches(&mut self) -> Result<Vec<batch_processing::BatchResponse>, JsValue> {
1312        if let Some(ref mut processor) = self.batch_processor {
1313            let mut all_responses = Vec::new();
1314
1315            while processor.queue_length() > 0 {
1316                let responses = processor.process_batch().await?;
1317                all_responses.extend(responses);
1318            }
1319
1320            Ok(all_responses)
1321        } else {
1322            Ok(Vec::new())
1323        }
1324    }
1325
1326    // Helper methods to reduce nesting
1327    fn log_quantization_success(&mut self, summary: &str) {
1328        if let Some(ref mut logger) = self.debug_logger {
1329            logger.info(summary, "quantization");
1330            logger.end_timer("quantization");
1331        }
1332    }
1333
1334    fn log_quantization_failure(&mut self, error: &str) {
1335        if let Some(ref mut logger) = self.debug_logger {
1336            logger.warn(error, "quantization");
1337            logger.end_timer("quantization");
1338        }
1339    }
1340
1341    fn log_quantization_skipped(&mut self) {
1342        if let Some(ref mut logger) = self.debug_logger {
1343            logger.info(
1344                "Model size is optimal, skipping quantization",
1345                "quantization",
1346            );
1347        }
1348    }
1349
1350    fn find_batch_response(
1351        &self,
1352        responses: &[batch_processing::BatchResponse],
1353        request_id: &str,
1354    ) -> Result<tensor::WasmTensor, JsValue> {
1355        for response in responses.iter() {
1356            if response.request_id() == request_id {
1357                if let Some(_result) = response.result() {
1358                    // Convert String result to WasmTensor - in a real implementation this would parse the result
1359                    // For now, create a dummy tensor with the result as metadata
1360                    return tensor::WasmTensor::new(vec![1.0], vec![1]);
1361                } else if let Some(error) = response.error() {
1362                    return Err(error.into());
1363                }
1364            }
1365        }
1366        Err("Response not found for request ID".into())
1367    }
1368}
1369
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed instance reports itself as initialized and
    /// exposes the expected version string.
    #[test]
    fn test_initialization() {
        let instance = TrustformersWasm::new();
        assert!(instance.initialized());

        let version = instance.version();
        assert_eq!(version, "0.1.0");
    }
}
1380}