// torsh_core/examples.rs

//! Comprehensive examples for torsh-core modules
//!
//! This module provides real-world examples and tutorials for using
//! the core functionality of the ToRSh tensor library.
5
6use crate::{
7    backend_detection::BackendFeatureDetector,
8    device::CpuDevice,
9    memory_monitor::{MemoryPressure, SystemMemoryMonitor},
10    ConversionUtils, DType, Device, InteropDocs, NumpyArrayInfo, Result, Shape,
11};
12
13/// Examples for device operations
14pub struct DeviceExamples;
15
16impl DeviceExamples {
17    /// Basic device creation and usage
18    pub fn basic_device_usage() -> Result<()> {
19        // Create a CPU device
20        let cpu_device = CpuDevice::new();
21        println!("Created CPU device: {:?}", cpu_device.device_type());
22
23        // Check device availability
24        if cpu_device.is_available().unwrap_or(false) {
25            println!("Device is available for computation");
26        }
27
28        // Synchronize device operations
29        cpu_device.synchronize()?;
30        println!("Device synchronized successfully");
31
32        Ok(())
33    }
34
35    /// Device capability detection
36    pub fn device_capabilities() -> Result<()> {
37        let detector = BackendFeatureDetector::new()?;
38
39        // Access system capabilities
40        let runtime_features = &detector.runtime_features;
41        println!("Runtime features: {runtime_features:#?}");
42
43        // Access available devices
44        let available_devices = &detector.available_devices;
45        println!("Available devices: {available_devices:#?}");
46
47        // Check specific capabilities
48        if runtime_features.cpu_features.simd.avx2 {
49            println!("AVX2 is available for accelerated operations");
50        }
51
52        if runtime_features.cpu_features.simd.neon {
53            println!("ARM NEON is available for vectorized operations");
54        }
55
56        Ok(())
57    }
58
59    /// Device synchronization patterns
60    pub fn synchronization_patterns() -> Result<()> {
61        let device = CpuDevice::new();
62
63        // Basic synchronization
64        device.synchronize()?;
65
66        // Synchronization with timeout
67        device.synchronize()?; // Timeout version not available
68
69        // Wait for device to become idle
70        device.synchronize()?; // Wait for idle version not available
71
72        println!("All synchronization patterns completed successfully");
73        Ok(())
74    }
75}
76
77/// Examples for shape operations
78pub struct ShapeExamples;
79
80impl ShapeExamples {
81    /// Basic shape creation and manipulation
82    pub fn basic_shape_operations() -> Result<()> {
83        // Create shapes
84        let shape1 = Shape::new(vec![2, 3, 4]);
85        let shape2 = Shape::new(vec![24]);
86
87        println!("Shape 1: {:?}, elements: {}", shape1.dims(), shape1.numel());
88        println!("Shape 2: {:?}, elements: {}", shape2.dims(), shape2.numel());
89
90        // Check shape properties
91        if shape1.is_contiguous() {
92            println!("Shape 1 is contiguous");
93        }
94
95        if shape2.is_scalar() {
96            println!("Shape 2 is scalar");
97        } else {
98            println!("Shape 2 is not scalar");
99        }
100
101        Ok(())
102    }
103
104    /// Shape broadcasting examples
105    pub fn broadcasting_examples() -> Result<()> {
106        let shape1 = Shape::new(vec![3, 1, 4]);
107        let shape2 = Shape::new(vec![1, 2, 1]);
108
109        // Check if shapes are broadcastable
110        match shape1.broadcast_with(&shape2) {
111            Ok(result_shape) => {
112                println!(
113                    "Broadcasting {:?} with {:?} gives {:?}",
114                    shape1.dims(),
115                    shape2.dims(),
116                    result_shape.dims()
117                );
118            }
119            Err(e) => {
120                println!("Broadcasting failed: {e}");
121            }
122        }
123
124        Ok(())
125    }
126
127    /// Advanced shape operations
128    pub fn advanced_shape_operations() -> Result<()> {
129        let shape = Shape::new(vec![2, 3, 4, 5]);
130
131        // Get shape information
132        println!("Shape: {:?}", shape.dims());
133        println!("Number of dimensions: {}", shape.ndim());
134        println!("Total elements: {}", shape.numel());
135        println!("Is contiguous: {}", shape.is_contiguous());
136
137        // Get strides
138        let strides = shape.strides();
139        println!("Strides: {strides:?}");
140
141        // Create reshaped views (using new shape with compatible numel)
142        let reshaped = Shape::new(vec![6, 20]);
143        if shape.numel() == reshaped.numel() {
144            println!("Reshaped to: {:?}", reshaped.dims());
145        } else {
146            println!(
147                "Cannot reshape {} to {} - incompatible element count",
148                shape.numel(),
149                reshaped.numel()
150            );
151        }
152
153        Ok(())
154    }
155}
156
157/// Examples for data type operations
158pub struct DTypeExamples;
159
160impl DTypeExamples {
161    /// Basic data type usage
162    pub fn basic_dtype_operations() {
163        // Different data types
164        let dtypes = vec![
165            DType::F32,
166            DType::F64,
167            DType::I32,
168            DType::I64,
169            DType::U8,
170            DType::Bool,
171            DType::C64,
172            DType::C128,
173        ];
174
175        for dtype in dtypes {
176            let name = dtype.name();
177            let size = dtype.size_bytes();
178            let is_float = dtype.is_float();
179            let is_int = dtype.is_int();
180            let is_complex = dtype.is_complex();
181            println!("Type: {name}, Size: {size} bytes, Float: {is_float}, Int: {is_int}, Complex: {is_complex}");
182        }
183    }
184
185    /// Type promotion examples
186    pub fn type_promotion_examples() {
187        use crate::dtype::TypePromotion;
188
189        // Promote types for mixed operations
190        let common_type = <DType as TypePromotion>::common_type(&[DType::I32, DType::F32]);
191        println!("Common type of I32 and F32: {common_type:?}");
192
193        let common_type = <DType as TypePromotion>::common_type(&[DType::F32, DType::F64]);
194        println!("Common type of F32 and F64: {common_type:?}");
195
196        // Complex type promotion
197        let common_type = <DType as TypePromotion>::common_type(&[DType::F32, DType::C64]);
198        println!("Common type of F32 and C64: {common_type:?}");
199    }
200
201    /// Quantized type examples
202    pub fn quantized_types() {
203        use crate::dtype::{QInt8, QUInt8};
204
205        // Create quantized types
206        let qint8 = QInt8 {
207            value: -100,
208            scale: 0.5,
209            zero_point: 0,
210        };
211        let quint8 = QUInt8 {
212            value: 50,
213            scale: 0.25,
214            zero_point: 128,
215        };
216
217        println!("QInt8: value={}, scale={}", qint8.value, qint8.scale);
218        println!("QUInt8: value={}, scale={}", quint8.value, quint8.scale);
219
220        // Convert to/from float
221        let float_val = qint8.value as f32 * qint8.scale;
222        let back_to_qint8 = QInt8 {
223            value: (float_val / 0.5) as i8,
224            scale: 0.5,
225            zero_point: 0,
226        };
227
228        println!("QInt8 -> f32: {float_val}");
229        println!("f32 -> QInt8: value={}", back_to_qint8.value);
230    }
231}
232
233/// Examples for memory management
234pub struct MemoryExamples;
235
236impl MemoryExamples {
237    /// Memory pool usage
238    pub fn memory_pool_usage() -> Result<()> {
239        use crate::storage::{allocate_pooled, deallocate_pooled};
240
241        // Allocate memory from pool
242        let size = 1024;
243        let _alignment = 64;
244        let ptr: Vec<f32> = allocate_pooled(size);
245
246        println!("Allocated {size} bytes from memory pool");
247
248        // Check pool statistics
249        println!("Pool allocation successful");
250
251        // Deallocate memory
252        deallocate_pooled(ptr);
253        println!("Memory deallocated");
254
255        Ok(())
256    }
257
258    /// System memory monitoring
259    pub fn memory_monitoring() -> Result<()> {
260        let monitor = SystemMemoryMonitor::new()?;
261
262        // Get system memory stats
263        let stats = monitor.current_stats();
264        println!(
265            "System memory: {} MB total, {} MB available",
266            stats.total_physical / 1024 / 1024,
267            stats.available_physical / 1024 / 1024
268        );
269
270        // Check memory pressure
271        let pressure = stats.pressure;
272        match pressure {
273            crate::memory_monitor::MemoryPressure::Normal => println!("Memory pressure: Normal"),
274            crate::memory_monitor::MemoryPressure::Moderate => {
275                println!("Memory pressure: Moderate")
276            }
277            crate::memory_monitor::MemoryPressure::High => println!("Memory pressure: High"),
278            crate::memory_monitor::MemoryPressure::Critical => {
279                println!("Memory pressure: Critical")
280            }
281        }
282
283        Ok(())
284    }
285}
286
287/// Examples for interoperability features
288pub struct InteropExamples;
289
290impl InteropExamples {
291    /// NumPy compatibility examples
292    pub fn numpy_compatibility() {
293        // Create NumPy-compatible array info
294        let numpy_info = NumpyArrayInfo::new(vec![10, 20, 30], DType::F32);
295
296        println!("NumPy array info:");
297        println!("  Shape: {:?}", numpy_info.shape);
298        println!("  Strides: {:?}", numpy_info.strides);
299        println!("  C-contiguous: {}", numpy_info.c_contiguous);
300        println!("  F-contiguous: {}", numpy_info.f_contiguous);
301        println!("  Size in bytes: {}", numpy_info.nbytes);
302
303        // Check layout efficiency
304        let efficiency = ConversionUtils::layout_efficiency_score(
305            &numpy_info.shape,
306            &numpy_info.strides,
307            numpy_info.dtype.size(),
308        );
309        println!("  Layout efficiency: {efficiency:.2}");
310    }
311
312    /// ONNX type conversion examples
313    pub fn onnx_conversion() {
314        use crate::interop::{OnnxDataType, OnnxTensorInfo};
315
316        // Convert ToRSh types to ONNX
317        let torsh_types = vec![DType::F32, DType::I64, DType::Bool, DType::C64];
318
319        for dtype in torsh_types {
320            let onnx_type = OnnxDataType::from(dtype);
321            println!("ToRSh {dtype:?} -> ONNX {onnx_type:?}");
322
323            // Convert back - should always succeed for standard types
324            let back_to_torsh = DType::try_from(onnx_type)
325                .expect("Standard DType to ONNX conversion should be bi-directional");
326            assert_eq!(dtype, back_to_torsh);
327        }
328
329        // Create ONNX tensor info
330        let tensor_info = OnnxTensorInfo {
331            elem_type: OnnxDataType::Float,
332            shape: vec![Some(10), None, Some(20)], // None for dynamic dimensions
333            name: Some("example_tensor".to_string()),
334        };
335
336        println!("ONNX tensor info: {tensor_info:#?}");
337    }
338
339    /// Arrow integration examples
340    pub fn arrow_integration() {
341        use crate::interop::{ArrowDataType, ArrowTypeInfo};
342        use std::collections::HashMap;
343
344        // Convert ToRSh types to Arrow
345        let dtype = DType::F64;
346        let arrow_type = ArrowDataType::from(dtype);
347        println!("ToRSh {dtype:?} -> Arrow {arrow_type:?}");
348
349        // Complex type conversion
350        let complex_dtype = DType::C128;
351        let arrow_complex = ArrowDataType::from(complex_dtype);
352        println!("ToRSh {complex_dtype:?} -> Arrow {arrow_complex:?}");
353
354        // Create Arrow type info with metadata
355        let mut metadata = HashMap::new();
356        metadata.insert("origin".to_string(), "torsh".to_string());
357        metadata.insert("version".to_string(), "0.1.0".to_string());
358
359        let arrow_info = ArrowTypeInfo {
360            data_type: arrow_type,
361            metadata,
362        };
363
364        println!("Arrow type info: {arrow_info:#?}");
365    }
366}
367
368/// Complete workflow examples
369pub struct WorkflowExamples;
370
371impl WorkflowExamples {
372    /// Basic tensor creation workflow
373    pub fn basic_tensor_workflow() -> Result<()> {
374        println!("=== Basic Tensor Workflow ===");
375
376        // 1. Create device
377        let device = CpuDevice::new();
378        println!("1. Created device: {:?}", device.device_type());
379
380        // 2. Define shape and dtype
381        let shape = Shape::new(vec![3, 4, 5]);
382        let dtype = DType::F32;
383        println!("2. Defined shape: {:?}, dtype: {:?}", shape.dims(), dtype);
384
385        // 3. Check memory requirements
386        let bytes_needed = shape.numel() * dtype.size_bytes();
387        println!("3. Memory needed: {bytes_needed} bytes");
388
389        // 4. Create NumPy-compatible info
390        let numpy_info = NumpyArrayInfo::new(shape.dims().to_vec(), dtype);
391        println!(
392            "4. NumPy compatible: C-order={}, size={} bytes",
393            numpy_info.c_contiguous, numpy_info.nbytes
394        );
395
396        // 5. Synchronize device
397        device.synchronize()?;
398        println!("5. Device synchronized");
399
400        Ok(())
401    }
402
403    /// Memory-aware tensor processing
404    pub fn memory_aware_processing() -> Result<()> {
405        println!("=== Memory-Aware Processing ===");
406
407        // 1. Check system memory
408        let monitor = SystemMemoryMonitor::new()?;
409        let stats = monitor.current_stats();
410
411        println!(
412            "1. System memory: {:.1} GB available",
413            stats.available_physical as f64 / 1024.0 / 1024.0 / 1024.0
414        );
415        println!("   Memory pressure: {:?}", stats.pressure);
416
417        // 2. Decide on tensor size based on available memory
418        let max_elements = match stats.pressure {
419            MemoryPressure::Normal => 1_000_000,
420            MemoryPressure::Moderate => 500_000,
421            MemoryPressure::High => 100_000,
422            MemoryPressure::Critical => 10_000,
423        };
424
425        // 3. Create appropriately sized tensor
426        let shape = if max_elements >= 1_000_000 {
427            Shape::new(vec![100, 100, 100])
428        } else if max_elements >= 100_000 {
429            Shape::new(vec![50, 50, 40])
430        } else {
431            Shape::new(vec![20, 20, 25])
432        };
433
434        println!(
435            "2. Selected shape: {:?} ({} elements)",
436            shape.dims(),
437            shape.numel()
438        );
439
440        // 4. Allocate using memory pool
441        let size = shape.numel();
442        let data: Vec<f32> = vec![0.0; size];
443        println!("3. Allocated {size} elements");
444
445        // 5. Process (simulated)
446        println!("4. Processing tensor...");
447
448        // 6. Clean up (automatic when Vec<f32> goes out of scope)
449        drop(data);
450        println!("5. Memory deallocated");
451
452        Ok(())
453    }
454
455    /// Cross-platform compatibility workflow
456    pub fn cross_platform_workflow() -> Result<()> {
457        println!("=== Cross-Platform Compatibility ===");
458
459        // 1. Detect platform capabilities
460        let detector = BackendFeatureDetector::new()?;
461        let features = &detector.runtime_features;
462
463        println!("1. Platform detection:");
464        println!("   Architecture: {:?}", features.cpu_features.architecture);
465        println!(
466            "   SIMD: AVX2={}, NEON={}",
467            features.cpu_features.simd.avx2, features.cpu_features.simd.neon
468        );
469
470        // 2. Choose optimal data types
471        let dtype = if features.cpu_features.simd.avx512f {
472            println!("2. Using F32 (AVX-512 available)");
473            DType::F32
474        } else if features.cpu_features.simd.avx2 {
475            println!("2. Using F32 (AVX2 available)");
476            DType::F32
477        } else {
478            println!("2. Using F64 (fallback for precision)");
479            DType::F64
480        };
481
482        // 3. Create tensors with optimal layout
483        let shape = Shape::new(vec![32, 32, 32]); // Power of 2 for better SIMD
484        let numpy_info = NumpyArrayInfo::new(shape.dims().to_vec(), dtype);
485
486        println!("3. Created tensor:");
487        println!("   Shape: {:?}", numpy_info.shape);
488        println!(
489            "   Layout efficiency: {:.2}",
490            ConversionUtils::layout_efficiency_score(
491                &numpy_info.shape,
492                &numpy_info.strides,
493                dtype.size()
494            )
495        );
496
497        // 4. Show interoperability info
498        println!("4. Interoperability:");
499        let onnx_type = crate::interop::OnnxDataType::from(dtype);
500        let arrow_type = crate::interop::ArrowDataType::from(dtype);
501        println!("   ONNX type: {onnx_type:?}");
502        println!("   Arrow type: {arrow_type:?}");
503
504        Ok(())
505    }
506}
507
508/// Performance optimization examples
509pub struct PerformanceExamples;
510
511impl PerformanceExamples {
512    /// Memory layout optimization
513    pub fn memory_layout_optimization() {
514        println!("=== Memory Layout Optimization ===");
515
516        let shapes_and_layouts = vec![
517            ("C-contiguous", vec![1000, 1000], vec![4000, 4]),
518            ("F-contiguous", vec![1000, 1000], vec![4, 4000]),
519            ("Strided", vec![1000, 1000], vec![8000, 8]),
520        ];
521
522        for (name, shape, strides) in shapes_and_layouts {
523            let efficiency = ConversionUtils::layout_efficiency_score(&shape, &strides, 4);
524            println!("{name}: efficiency = {efficiency:.3}");
525
526            if efficiency > 0.9 {
527                println!("  ✓ Excellent layout for performance");
528            } else if efficiency > 0.7 {
529                println!("  ⚠ Good layout, some optimization possible");
530            } else {
531                println!("  ⚡ Consider layout optimization");
532            }
533        }
534    }
535
536    /// SIMD optimization guidance
537    pub fn simd_optimization_guidance() -> Result<()> {
538        println!("=== SIMD Optimization Guidance ===");
539
540        let detector = BackendFeatureDetector::new()?;
541        let features = &detector.runtime_features;
542
543        // Vector widths for different SIMD instruction sets
544        let vector_widths = if features.cpu_features.simd.avx512f {
545            println!("AVX-512 detected: 16 floats per vector");
546            16
547        } else if features.cpu_features.simd.avx2 {
548            println!("AVX2 detected: 8 floats per vector");
549            8
550        } else if features.cpu_features.simd.neon {
551            println!("NEON detected: 4 floats per vector");
552            4
553        } else {
554            println!("No SIMD detected: scalar operations");
555            1
556        };
557
558        // Recommend tensor sizes
559        let recommended_sizes = vec![vector_widths * 32, vector_widths * 64, vector_widths * 128];
560
561        println!("Recommended tensor sizes for optimal SIMD usage:");
562        for size in recommended_sizes {
563            println!(
564                "  {} elements ({}x vector width)",
565                size,
566                size / vector_widths
567            );
568        }
569
570        Ok(())
571    }
572}
573
574/// Documentation and help utilities
575pub struct DocumentationExamples;
576
577impl DocumentationExamples {
578    /// Print comprehensive help
579    pub fn print_help() {
580        println!("{}", InteropDocs::supported_conversions());
581        println!("{}", InteropDocs::conversion_examples());
582    }
583
584    /// Print API overview
585    pub fn api_overview() {
586        println!(
587            r#"
588ToRSh Core API Overview
589======================
590
591Core Modules:
592• device    - Device abstraction and management
593• dtype     - Data type definitions and operations
594• shape     - Shape and stride utilities
595• storage   - Memory management and allocation
596• interop   - Interoperability with other libraries
597• error     - Error handling and reporting
598
599Key Types:
600• Device    - Hardware device abstraction
601• DType     - Tensor data types (F32, I64, C64, etc.)
602• Shape     - Tensor dimensions and layout
603• TorshError - Comprehensive error handling
604
605Getting Started:
6061. Create a device: let device = CpuDevice::new();
6072. Define shape: let shape = Shape::new(vec![2, 3, 4]);
6083. Choose dtype: let dtype = DType::F32;
6094. Check interop: let numpy_info = NumpyArrayInfo::new(...);
610
611For detailed examples, see the examples module.
612"#
613        );
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Every device example must finish with `Ok(())`.
    #[test]
    fn test_device_examples() {
        assert!(matches!(DeviceExamples::basic_device_usage(), Ok(())));
        assert!(matches!(DeviceExamples::device_capabilities(), Ok(())));
        assert!(matches!(DeviceExamples::synchronization_patterns(), Ok(())));
    }

    /// Every shape example must finish with `Ok(())`.
    #[test]
    fn test_shape_examples() {
        assert!(matches!(ShapeExamples::basic_shape_operations(), Ok(())));
        assert!(matches!(ShapeExamples::broadcasting_examples(), Ok(())));
        assert!(matches!(ShapeExamples::advanced_shape_operations(), Ok(())));
    }

    /// The dtype examples return nothing; the test passes if none panics.
    #[test]
    fn test_dtype_examples() {
        DTypeExamples::basic_dtype_operations();
        DTypeExamples::type_promotion_examples();
        DTypeExamples::quantized_types();
    }

    /// Every memory example must finish with `Ok(())`.
    #[test]
    fn test_memory_examples() {
        assert!(matches!(MemoryExamples::memory_pool_usage(), Ok(())));
        assert!(matches!(MemoryExamples::memory_monitoring(), Ok(())));
    }

    /// The interop examples return nothing; the test passes if none panics.
    #[test]
    fn test_interop_examples() {
        InteropExamples::numpy_compatibility();
        InteropExamples::onnx_conversion();
        InteropExamples::arrow_integration();
    }

    /// Every workflow example must finish with `Ok(())`.
    #[test]
    fn test_workflow_examples() {
        assert!(matches!(WorkflowExamples::basic_tensor_workflow(), Ok(())));
        assert!(matches!(WorkflowExamples::memory_aware_processing(), Ok(())));
        assert!(matches!(WorkflowExamples::cross_platform_workflow(), Ok(())));
    }

    /// The layout example must not panic and the SIMD guidance must
    /// finish with `Ok(())`.
    #[test]
    fn test_performance_examples() {
        PerformanceExamples::memory_layout_optimization();
        assert!(matches!(
            PerformanceExamples::simd_optimization_guidance(),
            Ok(())
        ));
    }

    /// The documentation printers return nothing; the test passes if
    /// neither panics.
    #[test]
    fn test_documentation_examples() {
        DocumentationExamples::print_help();
        DocumentationExamples::api_overview();
    }
}