// torsh_utils/cpp_extension.rs
1//! C++ Extension utilities for ToRSh
2//!
3//! This module provides utilities for building C++ extensions that integrate
4//! with the ToRSh framework, similar to PyTorch's cpp_extension module.
5
6use std::collections::HashMap;
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
/// JIT compilation configuration.
///
/// The derived `Default` leaves JIT fully disabled (all flags false, no
/// cache directory); `CppExtensionConfig::enable_jit` /
/// `enable_cuda_jit` set working defaults.
#[derive(Debug, Clone, Default)]
pub struct JitCompilationConfig {
    /// Enable just-in-time compilation
    pub enabled: bool,
    /// Cache compiled kernels
    pub cache_enabled: bool,
    /// Cache directory; `None` falls back to `<build_dir>/jit_cache`
    pub cache_dir: Option<PathBuf>,
    /// Optimization level for JIT (0-3)
    pub optimization_level: u8,
    /// Enable CUDA JIT compilation
    pub cuda_jit: bool,
    /// CUDA JIT cache size (in MB)
    pub cuda_cache_size: usize,
    /// Maximum number of registers for CUDA kernels
    pub cuda_max_registers: Option<u32>,
}
30
/// Custom operation definition.
///
/// `cpu_source` and `cuda_source` hold C++ function *bodies* that are
/// spliced verbatim into generated wrapper functions (see the
/// `generate_*_op` code generators); at least one of the two should be set.
#[derive(Debug, Clone)]
pub struct CustomOpDefinition {
    /// Operation name
    pub name: String,
    /// Operation type (forward, backward, both)
    pub op_type: CustomOpType,
    /// Input tensor shapes (None means dynamic)
    pub input_shapes: Vec<Option<Vec<usize>>>,
    /// Output tensor shapes (None means dynamic)
    pub output_shapes: Vec<Option<Vec<usize>>>,
    /// CPU implementation source (C++ function body)
    pub cpu_source: Option<String>,
    /// CUDA implementation source (CUDA C++ function body)
    pub cuda_source: Option<String>,
    /// Custom compile flags for this operation
    pub compile_flags: Vec<String>,
    /// Operation schema for validation
    pub schema: OpSchema,
}
51
/// Custom operation type: which passes the generated wrapper code covers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CustomOpType {
    /// Forward pass only
    Forward,
    /// Backward (gradient) pass only
    Backward,
    /// Both forward and backward passes in one generated translation unit
    ForwardBackward,
}
59
/// Operation schema for validation and optimization.
///
/// NOTE(review): none of these fields are consulted by the build pipeline
/// in this module; presumably they are read by the runtime/validator —
/// confirm against the callers.
#[derive(Debug, Clone, Default)]
pub struct OpSchema {
    /// Input tensor types
    pub input_types: Vec<TensorType>,
    /// Output tensor types
    pub output_types: Vec<TensorType>,
    /// Whether the operation is elementwise
    pub is_elementwise: bool,
    /// Whether the operation is deterministic
    pub is_deterministic: bool,
    /// Memory requirement estimation
    pub memory_requirement: MemoryRequirement,
}
74
/// Tensor type information used by `OpSchema`.
#[derive(Debug, Clone)]
pub struct TensorType {
    /// Data type as a free-form string (e.g. "f32", "f64", "i32")
    pub dtype: String,
    /// Minimum number of dimensions
    pub min_dims: usize,
    /// Maximum number of dimensions (None means unlimited)
    pub max_dims: Option<usize>,
    /// Whether the tensor can be sparse
    pub supports_sparse: bool,
}
87
/// Memory requirement estimation for a custom operation.
#[derive(Debug, Clone, Default)]
pub enum MemoryRequirement {
    /// Requirement not known or not specified (the default)
    #[default]
    Unknown,
    /// O(1) memory
    Constant,
    /// O(n) memory where n is input size
    Linear,
    /// O(n²) memory
    Quadratic,
    /// Custom memory formula (free-form text; not interpreted in this module)
    Custom(String),
}
102
/// Cross-platform build configuration.
///
/// An empty `target_platforms` list means "build for the current host
/// only" (see `build_cpp_extension`).
#[derive(Debug, Clone, Default)]
pub struct CrossPlatformConfig {
    /// Target platforms to build for (empty = host only)
    pub target_platforms: Vec<TargetPlatform>,
    /// Windows-specific settings
    pub windows: WindowsConfig,
    /// macOS-specific settings
    pub macos: MacOsConfig,
    /// Linux-specific settings
    pub linux: LinuxConfig,
    /// Enable cross-compilation
    pub cross_compile: bool,
    /// Docker-based building
    pub use_docker: bool,
}
119
/// Target platform specification.
///
/// Note that `LinuxArm64` and `LinuxAarch64` are distinct variants;
/// `detect_current_platform` maps Rust's `aarch64` arch constant to
/// `LinuxAarch64`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetPlatform {
    WindowsX64,
    WindowsX86,
    MacOsX64,
    MacOsArm64,
    LinuxX64,
    LinuxArm64,
    LinuxAarch64,
}
131
/// Windows-specific build configuration.
///
/// NOTE(review): only `use_clang` is consulted by `build_for_platform`;
/// the remaining fields are not read by the visible build code in this
/// module — confirm whether they are used elsewhere.
#[derive(Debug, Clone, Default)]
pub struct WindowsConfig {
    /// Visual Studio version to use
    pub vs_version: Option<String>,
    /// Windows SDK version
    pub sdk_version: Option<String>,
    /// Use clang instead of MSVC
    pub use_clang: bool,
    /// Enable Windows-specific optimizations
    pub enable_simd: bool,
}
144
/// macOS-specific build configuration.
///
/// NOTE(review): these fields are not consulted by the visible build code
/// in this module — confirm whether they are used elsewhere.
#[derive(Debug, Clone, Default)]
pub struct MacOsConfig {
    /// Minimum macOS version
    pub min_version: Option<String>,
    /// Xcode version to use
    pub xcode_version: Option<String>,
    /// Enable Metal Performance Shaders
    pub enable_mps: bool,
    /// Universal binary (x64 + ARM64)
    pub universal_binary: bool,
}
157
/// Linux-specific build configuration.
///
/// `compiler_preference` selects the C++ compiler in `build_for_platform`;
/// the remaining fields are not consulted by the visible build code.
#[derive(Debug, Clone, Default)]
pub struct LinuxConfig {
    /// GCC/Clang version preference
    pub compiler_preference: CompilerPreference,
    /// Enable Intel MKL
    pub enable_mkl: bool,
    /// Enable OpenMP
    pub enable_openmp: bool,
    /// Distribution-specific packages
    pub distro_packages: Vec<String>,
}
170
/// Compiler preference on Linux.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CompilerPreference {
    /// Use `$CXX` if set, otherwise fall back to g++ (see `build_for_platform`)
    #[default]
    Auto,
    /// Force g++
    Gcc,
    /// Force clang++
    Clang,
    /// Force Intel icpc
    Intel,
}
180
/// Configuration for building a C++ extension.
///
/// Construct via `CppExtensionConfig::new` and refine with the builder
/// methods. `new` defaults `cuda_archs` to sm_70..sm_89 and `build_dir`
/// to a per-extension directory under the system temp dir.
#[derive(Debug, Clone)]
pub struct CppExtensionConfig {
    /// Name of the extension module
    pub name: String,
    /// Source files to compile
    pub sources: Vec<PathBuf>,
    /// Include directories
    pub include_dirs: Vec<PathBuf>,
    /// Library directories
    pub library_dirs: Vec<PathBuf>,
    /// Libraries to link
    pub libraries: Vec<String>,
    /// Extra compiler flags
    pub extra_compile_args: Vec<String>,
    /// Extra linker flags
    pub extra_link_args: Vec<String>,
    /// Whether to build with CUDA support
    pub with_cuda: bool,
    /// CUDA architectures to target (e.g. "sm_80")
    pub cuda_archs: Vec<String>,
    /// Whether to enable debug symbols
    pub debug: bool,
    /// Output directory
    pub build_dir: PathBuf,
    /// JIT compilation settings
    pub jit_config: JitCompilationConfig,
    /// Custom operation definitions
    pub custom_ops: Vec<CustomOpDefinition>,
    /// Cross-platform build settings
    pub cross_platform: CrossPlatformConfig,
}
213
214impl CppExtensionConfig {
215    /// Create a new C++ extension configuration
216    pub fn new(name: impl Into<String>, sources: Vec<PathBuf>) -> Self {
217        let name = name.into();
218        let build_dir = env::temp_dir().join("torsh_cpp_extensions").join(&name);
219
220        Self {
221            name,
222            sources,
223            include_dirs: vec![],
224            library_dirs: vec![],
225            libraries: vec![],
226            extra_compile_args: vec![],
227            extra_link_args: vec![],
228            with_cuda: false,
229            cuda_archs: vec![
230                "sm_70".to_string(),
231                "sm_75".to_string(),
232                "sm_80".to_string(),
233                "sm_86".to_string(),
234                "sm_89".to_string(),
235            ],
236            debug: false,
237            build_dir,
238            jit_config: JitCompilationConfig::default(),
239            custom_ops: vec![],
240            cross_platform: CrossPlatformConfig::default(),
241        }
242    }
243
244    /// Add include directory
245    pub fn include_dir(mut self, dir: impl AsRef<Path>) -> Self {
246        self.include_dirs.push(dir.as_ref().to_path_buf());
247        self
248    }
249
250    /// Add library directory
251    pub fn library_dir(mut self, dir: impl AsRef<Path>) -> Self {
252        self.library_dirs.push(dir.as_ref().to_path_buf());
253        self
254    }
255
256    /// Add library to link
257    pub fn library(mut self, lib: impl Into<String>) -> Self {
258        self.libraries.push(lib.into());
259        self
260    }
261
262    /// Add extra compile arguments
263    pub fn extra_compile_arg(mut self, arg: impl Into<String>) -> Self {
264        self.extra_compile_args.push(arg.into());
265        self
266    }
267
268    /// Add extra link arguments
269    pub fn extra_link_arg(mut self, arg: impl Into<String>) -> Self {
270        self.extra_link_args.push(arg.into());
271        self
272    }
273
274    /// Enable CUDA support
275    pub fn cuda(mut self, cuda_archs: Vec<String>) -> Self {
276        self.with_cuda = true;
277        self.cuda_archs = cuda_archs;
278        self
279    }
280
281    /// Enable debug symbols
282    pub fn debug(mut self) -> Self {
283        self.debug = true;
284        self
285    }
286
287    /// Set build directory
288    pub fn build_dir(mut self, dir: impl AsRef<Path>) -> Self {
289        self.build_dir = dir.as_ref().to_path_buf();
290        self
291    }
292
293    /// Enable JIT compilation
294    pub fn jit(mut self, config: JitCompilationConfig) -> Self {
295        self.jit_config = config;
296        self
297    }
298
299    /// Add custom operation
300    pub fn custom_op(mut self, op: CustomOpDefinition) -> Self {
301        self.custom_ops.push(op);
302        self
303    }
304
305    /// Set cross-platform build configuration
306    pub fn cross_platform(mut self, config: CrossPlatformConfig) -> Self {
307        self.cross_platform = config;
308        self
309    }
310
311    /// Enable JIT compilation with default settings
312    pub fn enable_jit(mut self) -> Self {
313        self.jit_config.enabled = true;
314        self.jit_config.cache_enabled = true;
315        self.jit_config.optimization_level = 2;
316        self
317    }
318
319    /// Enable CUDA JIT compilation
320    pub fn enable_cuda_jit(mut self) -> Self {
321        self.jit_config.cuda_jit = true;
322        self.jit_config.cuda_cache_size = 256; // 256 MB default
323        self
324    }
325}
326
/// Build result containing the path to the compiled extension.
///
/// `library_path` is the host platform's artifact when one was built,
/// otherwise an arbitrary entry from `platform_artifacts`.
#[derive(Debug)]
pub struct BuildResult {
    /// Path to the compiled shared library
    pub library_path: PathBuf,
    /// Include directories for using the extension
    pub include_dirs: Vec<PathBuf>,
    /// JIT compilation results (`None` when JIT was disabled)
    pub jit_info: Option<JitBuildInfo>,
    /// Names of the custom operations that were compiled
    pub compiled_ops: Vec<String>,
    /// Cross-platform build artifacts, keyed by target platform
    pub platform_artifacts: HashMap<TargetPlatform, PathBuf>,
}
341
/// JIT compilation build information.
#[derive(Debug)]
pub struct JitBuildInfo {
    /// JIT cache directory
    pub cache_dir: PathBuf,
    /// Number of kernels compiled (currently the number of custom ops)
    pub kernel_count: usize,
    /// CUDA JIT compilation info (`None` unless CUDA JIT was enabled)
    pub cuda_info: Option<CudaJitInfo>,
}
352
/// CUDA JIT compilation information.
///
/// The runtime statistics (`cache_hits`, `cache_misses`,
/// `compilation_time_ms`) are initialized to zero at setup time.
#[derive(Debug)]
pub struct CudaJitInfo {
    /// PTX cache size in bytes
    pub ptx_cache_size: usize,
    /// Number of CUDA kernels (custom ops with a CUDA source)
    pub kernel_count: usize,
    /// GPU compute capabilities available, as "sm_XY" strings
    pub compute_capability: Vec<String>,
    /// Runtime compilation cache hits
    pub cache_hits: usize,
    /// Runtime compilation cache misses
    pub cache_misses: usize,
    /// JIT compilation time in milliseconds
    pub compilation_time_ms: f64,
}
369
/// CUDA device information.
///
/// NOTE(review): the order of entries in `max_grid_size`/`max_block_size`
/// is presumably [x, y, z] — confirm against the device query code.
#[derive(Debug, Clone)]
pub struct CudaDeviceInfo {
    /// Device index
    pub device_id: u32,
    /// Device name
    pub name: String,
    /// Compute capability (e.g. "8.0")
    pub compute_capability: String,
    /// Total global memory in bytes
    pub total_memory: usize,
    /// Maximum threads per block
    pub max_threads_per_block: u32,
    /// Maximum grid dimensions
    pub max_grid_size: [u32; 3],
    /// Maximum block dimensions
    pub max_block_size: [u32; 3],
    /// Warp size
    pub warp_size: u32,
    /// Number of multiprocessors
    pub multiprocessor_count: u32,
    /// Maximum shared memory per block
    pub shared_memory_per_block: usize,
}
394
/// Advanced CUDA kernel compilation options.
///
/// See the `Default` impl for the baseline configuration (optimization
/// level 2, caching enabled, everything else off/unset).
#[derive(Debug, Clone)]
pub struct CudaKernelCompilationOptions {
    /// Optimization level (0-3)
    pub optimization_level: u8,
    /// Enable fast math operations
    pub fast_math: bool,
    /// Maximum register count per thread
    pub max_registers: Option<u32>,
    /// Use cache for global memory loads
    pub use_cache: bool,
    /// Generate debug information
    pub debug_info: bool,
    /// Compile for specific GPU architecture (None = no arch pin)
    pub target_arch: Option<String>,
    /// Custom compiler flags
    pub custom_flags: Vec<String>,
}
413
414impl Default for CudaKernelCompilationOptions {
415    fn default() -> Self {
416        Self {
417            optimization_level: 2,
418            fast_math: false,
419            max_registers: None,
420            use_cache: true,
421            debug_info: false,
422            target_arch: None,
423            custom_flags: vec![],
424        }
425    }
426}
427
/// Runtime CUDA kernel management.
///
/// The module/function handles are raw `usize` placeholders standing in
/// for CUDA driver types; they carry no driver state in this module.
#[derive(Debug)]
pub struct RuntimeCudaKernel {
    /// Kernel name
    pub name: String,
    /// PTX source code
    pub ptx_source: String,
    /// Compiled module handle (would be CUmodule in real implementation)
    pub module_handle: Option<usize>,
    /// Kernel function handle (would be CUfunction in real implementation)
    pub function_handle: Option<usize>,
    /// Compilation options used
    pub compilation_options: CudaKernelCompilationOptions,
    /// Grid and block configuration
    pub launch_config: CudaLaunchConfig,
}
444
/// CUDA kernel launch configuration.
#[derive(Debug, Clone)]
pub struct CudaLaunchConfig {
    /// Grid dimensions
    pub grid_size: [u32; 3],
    /// Block dimensions
    pub block_size: [u32; 3],
    /// Shared memory size in bytes
    pub shared_memory_size: usize,
    /// CUDA stream handle; `None` presumably means the default stream — confirm
    pub stream: Option<usize>,
}
457
458/// Build a C++ extension
459pub fn build_cpp_extension(config: &CppExtensionConfig) -> Result<BuildResult, String> {
460    // Create build directory
461    fs::create_dir_all(&config.build_dir)
462        .map_err(|e| format!("Failed to create build directory: {}", e))?;
463
464    // Setup JIT compilation if enabled
465    let jit_info = if config.jit_config.enabled {
466        Some(setup_jit_compilation(config)?)
467    } else {
468        None
469    };
470
471    // Generate custom operation sources
472    let mut generated_sources = vec![];
473    let mut compiled_ops = vec![];
474
475    for custom_op in &config.custom_ops {
476        let generated_source = generate_custom_op_source(custom_op)?;
477        generated_sources.push(generated_source);
478        compiled_ops.push(custom_op.name.clone());
479    }
480
481    // Build for each target platform
482    let mut platform_artifacts = HashMap::new();
483
484    if config.cross_platform.target_platforms.is_empty() {
485        // Build for current platform
486        let artifact = build_for_platform(config, None, &generated_sources, &jit_info)?;
487        platform_artifacts.insert(detect_current_platform(), artifact);
488    } else {
489        // Build for specified platforms
490        for platform in &config.cross_platform.target_platforms {
491            let artifact =
492                build_for_platform(config, Some(platform), &generated_sources, &jit_info)?;
493            platform_artifacts.insert(platform.clone(), artifact);
494        }
495    }
496
497    // Get the main artifact (current platform or first specified)
498    let main_artifact = platform_artifacts
499        .get(&detect_current_platform())
500        .or_else(|| platform_artifacts.values().next())
501        .ok_or("No artifacts built")?
502        .clone();
503
504    Ok(BuildResult {
505        library_path: main_artifact,
506        include_dirs: config.include_dirs.clone(),
507        jit_info,
508        compiled_ops,
509        platform_artifacts,
510    })
511}
512
513/// Setup JIT compilation
514fn setup_jit_compilation(config: &CppExtensionConfig) -> Result<JitBuildInfo, String> {
515    let cache_dir = config
516        .jit_config
517        .cache_dir
518        .clone()
519        .unwrap_or_else(|| config.build_dir.join("jit_cache"));
520
521    fs::create_dir_all(&cache_dir)
522        .map_err(|e| format!("Failed to create JIT cache directory: {}", e))?;
523
524    let cuda_info = if config.jit_config.cuda_jit && config.with_cuda {
525        Some(setup_cuda_jit(config, &cache_dir)?)
526    } else {
527        None
528    };
529
530    Ok(JitBuildInfo {
531        cache_dir,
532        kernel_count: config.custom_ops.len(),
533        cuda_info,
534    })
535}
536
537/// Setup CUDA JIT compilation
538fn setup_cuda_jit(config: &CppExtensionConfig, cache_dir: &Path) -> Result<CudaJitInfo, String> {
539    let cuda_cache_dir = cache_dir.join("cuda");
540    fs::create_dir_all(&cuda_cache_dir)
541        .map_err(|e| format!("Failed to create CUDA cache directory: {}", e))?;
542
543    // Initialize CUDA runtime and query device capabilities
544    let device_info = query_cuda_devices()?;
545    let available_archs = device_info
546        .iter()
547        .map(|dev| format!("sm_{}", dev.compute_capability.replace(".", "")))
548        .collect::<Vec<_>>();
549
550    // Setup PTX cache structure
551    let ptx_cache_dir = cuda_cache_dir.join("ptx");
552    let cubin_cache_dir = cuda_cache_dir.join("cubin");
553    fs::create_dir_all(&ptx_cache_dir)
554        .map_err(|e| format!("Failed to create PTX cache directory: {}", e))?;
555    fs::create_dir_all(&cubin_cache_dir)
556        .map_err(|e| format!("Failed to create CUBIN cache directory: {}", e))?;
557
558    // Configure JIT compilation options
559    configure_cuda_jit_options(config)?;
560
561    // Validate CUDA kernel sources for syntax
562    for op in &config.custom_ops {
563        if let Some(cuda_source) = &op.cuda_source {
564            validate_cuda_kernel_syntax(cuda_source, &op.name)?;
565        }
566    }
567
568    Ok(CudaJitInfo {
569        ptx_cache_size: config.jit_config.cuda_cache_size * 1024 * 1024, // Convert MB to bytes
570        kernel_count: config
571            .custom_ops
572            .iter()
573            .filter(|op| op.cuda_source.is_some())
574            .count(),
575        compute_capability: available_archs,
576        cache_hits: 0,
577        cache_misses: 0,
578        compilation_time_ms: 0.0,
579    })
580}
581
582/// Generate custom operation source code
583fn generate_custom_op_source(op: &CustomOpDefinition) -> Result<PathBuf, String> {
584    // Generate C++ source code for the custom operation
585    let source_content = match &op.op_type {
586        CustomOpType::Forward => generate_forward_op(&op.name, &op.cpu_source, &op.cuda_source)?,
587        CustomOpType::Backward => generate_backward_op(&op.name, &op.cpu_source, &op.cuda_source)?,
588        CustomOpType::ForwardBackward => {
589            generate_forward_backward_op(&op.name, &op.cpu_source, &op.cuda_source)?
590        }
591    };
592
593    // Write to temporary file
594    let temp_file = env::temp_dir().join(format!("{}_custom_op.cpp", op.name));
595    fs::write(&temp_file, source_content)
596        .map_err(|e| format!("Failed to write custom op source: {}", e))?;
597
598    Ok(temp_file)
599}
600
/// Generate C++ source for a forward-only custom operation.
///
/// Emits, in one translation unit:
/// * `{name}_cpu_forward` when `cpu_source` is provided,
/// * `{name}_cuda_forward` (inside `#ifdef TORSH_USE_CUDA`) when
///   `cuda_source` is provided,
/// * a `{name}_forward` dispatcher that only references the implementations
///   that were actually emitted — previously it referenced both symbols
///   unconditionally, so the generated C++ failed to compile whenever one
///   implementation was missing,
/// * a `TORSH_REGISTER_OP` registration.
///
/// # Errors
/// Returns `Err` when neither `cpu_source` nor `cuda_source` is given,
/// since there would be nothing for the dispatcher to call.
fn generate_forward_op(
    name: &str,
    cpu_source: &Option<String>,
    cuda_source: &Option<String>,
) -> Result<String, String> {
    if cpu_source.is_none() && cuda_source.is_none() {
        return Err(format!(
            "Custom op '{}' defines neither a CPU nor a CUDA implementation",
            name
        ));
    }

    // <stdexcept> is needed by the CUDA-only dispatcher's runtime error.
    let mut source = format!(
        r#"// Generated custom operation: {}
#include <torsh/tensor.h>
#include <torsh/autograd.h>
#include <stdexcept>

namespace torsh {{
namespace ops {{

"#,
        name
    );

    // Add CPU implementation
    if let Some(cpu_impl) = cpu_source {
        source.push_str(&format!(
            r#"
// CPU implementation
Tensor {}_cpu_forward(const std::vector<Tensor>& inputs) {{
    {}
}}
"#,
            name, cpu_impl
        ));
    }

    // Add CUDA implementation
    if let Some(cuda_impl) = cuda_source {
        source.push_str(&format!(
            r#"
#ifdef TORSH_USE_CUDA
// CUDA implementation
Tensor {}_cuda_forward(const std::vector<Tensor>& inputs) {{
    {}
}}
#endif
"#,
            name, cuda_impl
        ));
    }

    // Dispatcher: route CUDA tensors to the CUDA kernel when it exists and
    // fall back to the CPU kernel when it exists.
    source.push_str(&format!(
        "\n// Operation dispatcher\nTensor {}_forward(const std::vector<Tensor>& inputs) {{\n",
        name
    ));
    if cuda_source.is_some() {
        source.push_str(&format!(
            "#ifdef TORSH_USE_CUDA\n    if (inputs[0].is_cuda()) {{\n        return {}_cuda_forward(inputs);\n    }}\n#endif\n",
            name
        ));
    }
    if cpu_source.is_some() {
        source.push_str(&format!("    return {}_cpu_forward(inputs);\n", name));
    } else {
        // CUDA-only op: there is no CPU fallback to call.
        source.push_str(&format!(
            "    throw std::runtime_error(\"{}: no CPU implementation\");\n",
            name
        ));
    }
    source.push_str(&format!(
        "}}\n\n// Register operation\nTORSH_REGISTER_OP(\"{}\", {}_forward);\n\n}} // namespace ops\n}} // namespace torsh\n",
        name, name
    ));

    Ok(source)
}
671
/// Generate C++ source for a backward (gradient) custom operation.
///
/// Emits `{name}_cpu_backward` and/or `{name}_cuda_backward`, plus a
/// `{name}_backward` dispatcher that only references the implementations
/// that were actually emitted — previously the dispatcher referenced both
/// symbols unconditionally, producing C++ that could not compile whenever
/// one implementation was missing — and finally a
/// `TORSH_REGISTER_BACKWARD_OP` registration.
///
/// # Errors
/// Returns `Err` when neither `cpu_source` nor `cuda_source` is given.
fn generate_backward_op(
    name: &str,
    cpu_source: &Option<String>,
    cuda_source: &Option<String>,
) -> Result<String, String> {
    if cpu_source.is_none() && cuda_source.is_none() {
        return Err(format!(
            "Custom op '{}' defines neither a CPU nor a CUDA implementation",
            name
        ));
    }

    // <stdexcept> is needed by the CUDA-only dispatcher's runtime error.
    let mut source = format!(
        r#"// Generated custom backward operation: {}
#include <torsh/tensor.h>
#include <torsh/autograd.h>
#include <stdexcept>

namespace torsh {{
namespace ops {{
"#,
        name
    );

    if let Some(cpu_impl) = cpu_source {
        source.push_str(&format!(
            r#"
std::vector<Tensor> {}_cpu_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{
    {}
}}
"#,
            name, cpu_impl
        ));
    }

    if let Some(cuda_impl) = cuda_source {
        source.push_str(&format!(
            r#"
#ifdef TORSH_USE_CUDA
std::vector<Tensor> {}_cuda_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{
    {}
}}
#endif
"#,
            name, cuda_impl
        ));
    }

    // Dispatcher: only reference implementations that were emitted above.
    source.push_str(&format!(
        "\nstd::vector<Tensor> {}_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{\n",
        name
    ));
    if cuda_source.is_some() {
        source.push_str(&format!(
            "#ifdef TORSH_USE_CUDA\n    if (inputs[0].is_cuda()) {{\n        return {}_cuda_backward(grad_outputs, inputs);\n    }}\n#endif\n",
            name
        ));
    }
    if cpu_source.is_some() {
        source.push_str(&format!(
            "    return {}_cpu_backward(grad_outputs, inputs);\n",
            name
        ));
    } else {
        // CUDA-only op: there is no CPU fallback to call.
        source.push_str(&format!(
            "    throw std::runtime_error(\"{}: no CPU implementation\");\n",
            name
        ));
    }
    source.push_str(&format!(
        "}}\n\nTORSH_REGISTER_BACKWARD_OP(\"{}\", {}_backward);\n\n}} // namespace ops\n}} // namespace torsh\n",
        name, name
    ));

    Ok(source)
}
735
736/// Generate forward and backward operation source
737fn generate_forward_backward_op(
738    name: &str,
739    cpu_source: &Option<String>,
740    cuda_source: &Option<String>,
741) -> Result<String, String> {
742    // Combine forward and backward generation
743    let forward_source = generate_forward_op(name, cpu_source, cuda_source)?;
744    let backward_source =
745        generate_backward_op(&format!("{}_backward", name), cpu_source, cuda_source)?;
746
747    Ok(format!("{}\n\n{}", forward_source, backward_source))
748}
749
750/// Build for a specific platform
751fn build_for_platform(
752    config: &CppExtensionConfig,
753    target_platform: Option<&TargetPlatform>,
754    generated_sources: &[PathBuf],
755    _jit_info: &Option<JitBuildInfo>,
756) -> Result<PathBuf, String> {
757    // Determine compiler based on platform
758    let (compiler, extra_flags) =
759        match target_platform {
760            Some(TargetPlatform::WindowsX64) | Some(TargetPlatform::WindowsX86) => {
761                if config.cross_platform.windows.use_clang {
762                    (
763                        "clang++".to_string(),
764                        vec![
765                            "-target".to_string(),
766                            get_windows_target(target_platform.expect(
767                                "target_platform should be Some for Windows platform branch",
768                            )),
769                        ],
770                    )
771                } else {
772                    ("cl.exe".to_string(), vec!["/std:c++17".to_string()])
773                }
774            }
775            Some(TargetPlatform::MacOsX64) | Some(TargetPlatform::MacOsArm64) => {
776                let target = match target_platform
777                    .expect("target_platform should be Some for macOS platform branch")
778                {
779                    TargetPlatform::MacOsX64 => "x86_64-apple-darwin",
780                    TargetPlatform::MacOsArm64 => "arm64-apple-darwin",
781                    _ => unreachable!(),
782                };
783                (
784                    "clang++".to_string(),
785                    vec!["-target".to_string(), target.to_string()],
786                )
787            }
788            Some(TargetPlatform::LinuxX64)
789            | Some(TargetPlatform::LinuxArm64)
790            | Some(TargetPlatform::LinuxAarch64) => {
791                match config.cross_platform.linux.compiler_preference {
792                    CompilerPreference::Clang => ("clang++".to_string(), vec![]),
793                    CompilerPreference::Gcc => ("g++".to_string(), vec![]),
794                    CompilerPreference::Intel => ("icpc".to_string(), vec![]),
795                    CompilerPreference::Auto => (
796                        env::var("CXX").unwrap_or_else(|_| "g++".to_string()),
797                        vec![],
798                    ),
799                }
800            }
801            None => {
802                // Current platform
803                if config.with_cuda {
804                    ("nvcc".to_string(), vec![])
805                } else {
806                    (
807                        env::var("CXX").unwrap_or_else(|_| "c++".to_string()),
808                        vec![],
809                    )
810                }
811            }
812        };
813
814    // Build compile command
815    let mut cmd = Command::new(&compiler);
816
817    // Add platform-specific flags
818    cmd.args(&extra_flags);
819
820    // Add include directories
821    for include_dir in &config.include_dirs {
822        cmd.arg(format!("-I{}", include_dir.display()));
823    }
824
825    // Add ToRSh include directory
826    if let Ok(torsh_include) = env::var("TORSH_INCLUDE_DIR") {
827        cmd.arg(format!("-I{}", torsh_include));
828    }
829
830    // Add standard flags (platform-specific)
831    if compiler.contains("cl.exe") {
832        // MSVC flags
833        cmd.arg("/std:c++17");
834        if !config.debug {
835            cmd.arg("/O2");
836            cmd.arg("/DNDEBUG");
837        } else {
838            cmd.arg("/Od");
839            cmd.arg("/Zi");
840        }
841    } else {
842        // GCC/Clang flags
843        cmd.arg("-std=c++17");
844        cmd.arg("-fPIC");
845        if !config.debug {
846            cmd.arg("-O3");
847            cmd.arg("-DNDEBUG");
848        } else {
849            cmd.arg("-g");
850            cmd.arg("-O0");
851        }
852    }
853
854    // Add CUDA specific flags
855    if config.with_cuda && compiler.contains("nvcc") {
856        for arch in &config.cuda_archs {
857            cmd.arg(format!(
858                "-gencode=arch=compute_{},code={}",
859                &arch[3..],
860                arch
861            ));
862        }
863        cmd.arg("-x").arg("cu");
864    }
865
866    // Add extra compile args
867    for arg in &config.extra_compile_args {
868        cmd.arg(arg);
869    }
870
871    // Add source files (original + generated)
872    for source in &config.sources {
873        cmd.arg(source);
874    }
875    for source in generated_sources {
876        cmd.arg(source);
877    }
878
879    // Output file
880    let platform_suffix = target_platform
881        .map(|p| format!("_{:?}", p))
882        .unwrap_or_default();
883    let output_file = config
884        .build_dir
885        .join(format!("lib{}{}.so", config.name, platform_suffix));
886
887    if compiler.contains("cl.exe") {
888        cmd.arg("/Fe:").arg(&output_file);
889        cmd.arg("/LD"); // Create DLL
890    } else {
891        cmd.arg("-shared");
892        cmd.arg("-o").arg(&output_file);
893    }
894
895    // Add library directories
896    for lib_dir in &config.library_dirs {
897        if compiler.contains("cl.exe") {
898            cmd.arg(format!("/LIBPATH:{}", lib_dir.display()));
899        } else {
900            cmd.arg(format!("-L{}", lib_dir.display()));
901        }
902    }
903
904    // Add libraries
905    for lib in &config.libraries {
906        if compiler.contains("cl.exe") {
907            cmd.arg(format!("{}.lib", lib));
908        } else {
909            cmd.arg(format!("-l{}", lib));
910        }
911    }
912
913    // Add extra link args
914    for arg in &config.extra_link_args {
915        cmd.arg(arg);
916    }
917
918    // Execute build
919    let output = cmd
920        .output()
921        .map_err(|e| format!("Failed to execute compiler {}: {}", compiler, e))?;
922
923    if !output.status.success() {
924        let stderr = String::from_utf8_lossy(&output.stderr);
925        return Err(format!(
926            "Compilation failed for platform {:?}:\n{}",
927            target_platform, stderr
928        ));
929    }
930
931    Ok(output_file)
932}
933
934/// Detect current platform
935fn detect_current_platform() -> TargetPlatform {
936    match env::consts::OS {
937        "windows" => match env::consts::ARCH {
938            "x86_64" => TargetPlatform::WindowsX64,
939            "x86" => TargetPlatform::WindowsX86,
940            _ => TargetPlatform::WindowsX64, // Default
941        },
942        "macos" => match env::consts::ARCH {
943            "aarch64" => TargetPlatform::MacOsArm64,
944            _ => TargetPlatform::MacOsX64,
945        },
946        "linux" => match env::consts::ARCH {
947            "aarch64" => TargetPlatform::LinuxAarch64,
948            "arm64" => TargetPlatform::LinuxArm64,
949            _ => TargetPlatform::LinuxX64,
950        },
951        _ => TargetPlatform::LinuxX64, // Default fallback
952    }
953}
954
955/// Get Windows target string
956fn get_windows_target(platform: &TargetPlatform) -> String {
957    match platform {
958        TargetPlatform::WindowsX64 => "x86_64-pc-windows-msvc".to_string(),
959        TargetPlatform::WindowsX86 => "i686-pc-windows-msvc".to_string(),
960        _ => "x86_64-pc-windows-msvc".to_string(), // Default
961    }
962}
963
/// Load a C++ extension from a shared library.
///
/// Placeholder loader: a real implementation would dynamically load the
/// library (e.g. via `libloading`), register its custom operators, and
/// initialize global state. Currently it only verifies that the artifact
/// exists on disk.
pub fn load_cpp_extension(library_path: &Path) -> Result<(), String> {
    match library_path.exists() {
        true => Ok(()),
        false => Err(format!("Library not found: {}", library_path.display())),
    }
}
979
/// Generate a simple C++ extension template.
///
/// Creates three files inside `output_dir` (which is created if missing):
///   * `<name>.h`   — header declaring an example `<name>_forward` op
///   * `<name>.cpp` — source implementing and registering the op
///   * `build.rs`   — a Rust driver that builds the extension via
///     `CppExtensionConfig` / `build_cpp_extension`
///
/// # Errors
///
/// Returns an error string when the output directory cannot be created or
/// any of the three files cannot be written.
pub fn generate_extension_template(name: &str, output_dir: &Path) -> Result<(), String> {
    fs::create_dir_all(output_dir)
        .map_err(|e| format!("Failed to create output directory: {}", e))?;

    write_template(
        &output_dir.join(format!("{}.h", name)),
        &header_template(name),
        "header file",
    )?;
    write_template(
        &output_dir.join(format!("{}.cpp", name)),
        &source_template(name),
        "source file",
    )?;
    write_template(
        &output_dir.join("build.rs"),
        &setup_template(name),
        "setup script",
    )?;

    Ok(())
}

/// Write `content` to `path`, labelling any I/O error with `what`
/// (e.g. "header file") so the caller's message names the failing artifact.
fn write_template(path: &Path, content: &str, what: &str) -> Result<(), String> {
    fs::write(path, content).map_err(|e| format!("Failed to write {}: {}", what, e))
}

/// Render the C++ header declaring the extension's example forward op.
fn header_template(name: &str) -> String {
    format!(
        r#"#pragma once

#include <torsh/tensor.h>
#include <torsh/module.h>

namespace torsh {{
namespace ops {{

// Example custom operation
Tensor {}_forward(const Tensor& input);

}} // namespace ops
}} // namespace torsh
"#,
        name
    )
}

/// Render the C++ source implementing and registering the example op.
fn source_template(name: &str) -> String {
    format!(
        r#"#include "{}.h"
#include <torsh/autograd.h>
#include <iostream>

namespace torsh {{
namespace ops {{

Tensor {}_forward(const Tensor& input) {{
    // Example implementation
    auto output = input.clone();

    // Perform custom operation
    // This is where you would implement your custom logic

    return output;
}}

// Register the operation
TORSH_LIBRARY(TORCH_EXTENSION_NAME, m) {{
    m.def("{}_forward", &{}_forward);
}}

}} // namespace ops
}} // namespace torsh
"#,
        name, name, name, name
    )
}

/// Render the Rust build driver (`build.rs`) for the generated extension.
fn setup_template(name: &str) -> String {
    format!(
        r#"use torsh_utils::cpp_extension::{{CppExtensionConfig, build_cpp_extension}};
use std::path::PathBuf;

fn main() {{
    let config = CppExtensionConfig::new("{}", vec![
        PathBuf::from("{}.cpp"),
    ])
    .include_dir(".")
    .extra_compile_arg("-Wall")
    .extra_compile_arg("-Wextra");

    match build_cpp_extension(&config) {{
        Ok(result) => {{
            println!("Extension built successfully!");
            println!("Library: {{:?}}", result.library_path);
        }}
        Err(e) => {{
            eprintln!("Build failed: {{}}", e);
            std::process::exit(1);
        }}
    }}
}}
"#,
        name, name
    )
}
1076
/// Check if CUDA is available for building extensions.
///
/// Detection is based on whether `nvcc --version` can be spawned and exits
/// successfully; a missing or failing compiler counts as "unavailable".
pub fn cuda_is_available() -> bool {
    match Command::new("nvcc").arg("--version").output() {
        Ok(output) => output.status.success(),
        Err(_) => false,
    }
}
1085
/// Get CUDA architectures available on the system.
///
/// Ideally this would query the GPUs actually present; for now it returns
/// a fixed list of common architectures.
pub fn get_cuda_arch_list() -> Vec<String> {
    [
        "sm_70", // V100
        "sm_75", // T4, RTX 20xx
        "sm_80", // A100
        "sm_86", // RTX 30xx
        "sm_89", // RTX 40xx
    ]
    .iter()
    .map(|arch| arch.to_string())
    .collect()
}
1098
1099/// Query CUDA devices on the system
1100fn query_cuda_devices() -> Result<Vec<CudaDeviceInfo>, String> {
1101    // In a real implementation, this would use CUDA Driver API to query devices
1102    // For now, return mock data based on CUDA availability
1103    if !cuda_is_available() {
1104        return Err("CUDA is not available on this system".to_string());
1105    }
1106
1107    // Mock device information (in real implementation, would query actual devices)
1108    let mock_device = CudaDeviceInfo {
1109        device_id: 0,
1110        name: "NVIDIA GPU".to_string(),
1111        compute_capability: "8.0".to_string(),
1112        total_memory: 8 * 1024 * 1024 * 1024, // 8GB
1113        max_threads_per_block: 1024,
1114        max_grid_size: [65535, 65535, 65535],
1115        max_block_size: [1024, 1024, 64],
1116        warp_size: 32,
1117        multiprocessor_count: 80,
1118        shared_memory_per_block: 48 * 1024, // 48KB
1119    };
1120
1121    Ok(vec![mock_device])
1122}
1123
1124/// Configure CUDA JIT compilation options
1125fn configure_cuda_jit_options(config: &CppExtensionConfig) -> Result<(), String> {
1126    // In a real implementation, this would configure CUDA driver JIT options
1127    // Such as:
1128    // - cuLinkCreate with JIT options
1129    // - Setting optimization level
1130    // - Configuring cache behavior
1131    // - Setting debug/profiling options
1132
1133    if config.jit_config.cuda_jit {
1134        // Validate JIT configuration
1135        if config.jit_config.cuda_cache_size == 0 {
1136            return Err("CUDA JIT cache size must be greater than 0".to_string());
1137        }
1138
1139        if config.jit_config.optimization_level > 3 {
1140            return Err("CUDA JIT optimization level must be 0-3".to_string());
1141        }
1142
1143        // Configure JIT options based on config
1144        // This is where we would set:
1145        // - CU_JIT_OPTIMIZATION_LEVEL
1146        // - CU_JIT_CACHE_MODE
1147        // - CU_JIT_MAX_REGISTERS
1148        // - CU_JIT_THREADS_PER_BLOCK
1149    }
1150
1151    Ok(())
1152}
1153
/// Validate CUDA kernel syntax.
///
/// Performs lightweight, purely textual checks on `cuda_source`:
///   1. at least one CUDA function qualifier (`__global__`, `__device__`,
///      or `__host__`) must appear,
///   2. `{` / `}` counts must balance,
///   3. each code line must end in a way that can legally terminate or
///      continue a C/C++ statement.
///
/// This is a heuristic pre-check only and does not replace nvcc. The
/// line-ending check deliberately accepts continuation endings (`,`, `(`,
/// `)`, `:`, `\`), block-comment closers (`*/`), and bare `else`/`do`, so
/// that legal multi-line constructs (e.g. a parameter list split across
/// lines) are not falsely rejected.
///
/// # Errors
///
/// Returns a descriptive error string naming `op_name` when a check fails.
fn validate_cuda_kernel_syntax(cuda_source: &str, op_name: &str) -> Result<(), String> {
    // 1. Require at least one CUDA function-space qualifier.
    let cuda_markers = ["__global__", "__device__", "__host__"];
    if !cuda_markers.iter().any(|m| cuda_source.contains(m)) {
        return Err(format!(
            "CUDA source for operation '{}' does not contain valid CUDA kernel markers (__global__, __device__, or __host__)",
            op_name
        ));
    }

    // 2. Brace balance. Counts only; braces inside comments or string
    // literals are not distinguished, which is acceptable for a heuristic.
    let brackets_open = cuda_source.chars().filter(|&c| c == '{').count();
    let brackets_close = cuda_source.chars().filter(|&c| c == '}').count();
    if brackets_open != brackets_close {
        return Err(format!(
            "CUDA source for operation '{}' has mismatched braces ({{ and }})",
            op_name
        ));
    }

    // 3. Per-line termination heuristic. The previous check accepted only
    // `{`, `}` and `;` endings, which falsely rejected multi-line
    // signatures, labels, continuations and bare `else` — those endings
    // are now allowed.
    for (i, line) in cuda_source.lines().enumerate() {
        let trimmed = line.trim();
        let skip = trimmed.is_empty()
            || trimmed.starts_with("//")
            || trimmed.starts_with("/*")
            || trimmed.starts_with('*') // interior/close of a block comment
            || trimmed.starts_with('#'); // preprocessor directive
        if skip {
            continue;
        }

        let acceptable_ending = trimmed.ends_with('{')
            || trimmed.ends_with('}')
            || trimmed.ends_with(';')
            || trimmed.ends_with(',') // multi-line argument/parameter lists
            || trimmed.ends_with('(') // call/signature continued on next line
            || trimmed.ends_with(')') // condition/signature before a brace line
            || trimmed.ends_with(':') // labels, `case`, access specifiers
            || trimmed.ends_with('\\') // line continuation
            || trimmed.ends_with("*/") // block comment close
            || trimmed == "else"
            || trimmed == "do";

        if !acceptable_ending {
            return Err(format!(
                "CUDA source for operation '{}' line {} may be missing semicolon: '{}'",
                op_name,
                i + 1,
                trimmed
            ));
        }
    }

    Ok(())
}
1209
1210/// Compile CUDA kernel at runtime
1211pub fn compile_cuda_kernel_runtime(
1212    kernel_source: &str,
1213    kernel_name: &str,
1214    options: &CudaKernelCompilationOptions,
1215) -> Result<RuntimeCudaKernel, String> {
1216    // Validate CUDA availability
1217    if !cuda_is_available() {
1218        return Err("CUDA is not available for runtime compilation".to_string());
1219    }
1220
1221    // Validate kernel source
1222    validate_cuda_kernel_syntax(kernel_source, kernel_name)?;
1223
1224    // In a real implementation, this would:
1225    // 1. Use CUDA Driver API to compile PTX from source
1226    // 2. Load the compiled module
1227    // 3. Get kernel function handle
1228    // 4. Configure launch parameters
1229
1230    // Generate PTX source (mock)
1231    let ptx_source = format!(
1232        r#"
1233.version 8.0
1234.target sm_80
1235.address_size 64
1236
1237.visible .entry {}(
1238    .param .u64 param_0
1239)
1240{{
1241    // Generated PTX code would go here
1242    ret;
1243}}
1244"#,
1245        kernel_name
1246    );
1247
1248    // Mock launch configuration
1249    let launch_config = CudaLaunchConfig {
1250        grid_size: [1, 1, 1],
1251        block_size: [256, 1, 1],
1252        shared_memory_size: 0,
1253        stream: None,
1254    };
1255
1256    Ok(RuntimeCudaKernel {
1257        name: kernel_name.to_string(),
1258        ptx_source,
1259        module_handle: Some(1),   // Mock handle
1260        function_handle: Some(1), // Mock handle
1261        compilation_options: options.clone(),
1262        launch_config,
1263    })
1264}
1265
1266/// Launch a runtime-compiled CUDA kernel
1267pub fn launch_cuda_kernel(
1268    kernel: &RuntimeCudaKernel,
1269    args: &[*mut std::ffi::c_void],
1270) -> Result<(), String> {
1271    // In a real implementation, this would:
1272    // 1. Validate kernel is loaded
1273    // 2. Set kernel parameters
1274    // 3. Launch kernel with configured grid/block dimensions
1275    // 4. Handle synchronization if needed
1276
1277    if kernel.module_handle.is_none() || kernel.function_handle.is_none() {
1278        return Err(format!("Kernel '{}' is not properly loaded", kernel.name));
1279    }
1280
1281    // Validate launch configuration
1282    if kernel.launch_config.grid_size[0] == 0 || kernel.launch_config.block_size[0] == 0 {
1283        return Err(format!(
1284            "Invalid launch configuration for kernel '{}'",
1285            kernel.name
1286        ));
1287    }
1288
1289    // Mock kernel launch validation
1290    println!(
1291        "Launching CUDA kernel '{}' with grid {:?} and block {:?}",
1292        kernel.name, kernel.launch_config.grid_size, kernel.launch_config.block_size
1293    );
1294
1295    // Validate argument count (basic check)
1296    if args.is_empty() {
1297        return Err(format!(
1298            "No arguments provided for kernel '{}'",
1299            kernel.name
1300        ));
1301    }
1302
1303    Ok(())
1304}
1305
1306/// Auto-tune CUDA kernel launch parameters
1307pub fn auto_tune_cuda_kernel(
1308    kernel: &mut RuntimeCudaKernel,
1309    input_sizes: &[usize],
1310) -> Result<CudaLaunchConfig, String> {
1311    // Query device properties for optimal configuration
1312    let devices = query_cuda_devices()?;
1313    let device = devices
1314        .first()
1315        .ok_or("No CUDA devices available for auto-tuning")?;
1316
1317    // Calculate optimal block size based on kernel complexity and device properties
1318    let optimal_block_size = if input_sizes.iter().any(|&size| size > 10000) {
1319        // Large inputs: use larger blocks for better memory coalescing
1320        device.max_threads_per_block.min(512)
1321    } else {
1322        // Small inputs: use smaller blocks to avoid warp underutilization
1323        device.max_threads_per_block.min(256)
1324    };
1325
1326    // Calculate grid size based on input size and block size
1327    let total_elements = input_sizes.iter().max().copied().unwrap_or(1);
1328    let optimal_grid_size =
1329        (total_elements + optimal_block_size as usize - 1) / optimal_block_size as usize;
1330
1331    // Limit grid size to device maximum
1332    let clamped_grid_size = (optimal_grid_size as u32).min(device.max_grid_size[0]);
1333
1334    let optimized_config = CudaLaunchConfig {
1335        grid_size: [clamped_grid_size, 1, 1],
1336        block_size: [optimal_block_size, 1, 1],
1337        shared_memory_size: 0, // Auto-tune shared memory based on kernel requirements
1338        stream: kernel.launch_config.stream,
1339    };
1340
1341    // Update kernel configuration
1342    kernel.launch_config = optimized_config.clone();
1343
1344    Ok(optimized_config)
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Builder methods on `CppExtensionConfig` should accumulate the name,
    /// sources, include dirs, libraries and compile args they are given.
    #[test]
    fn test_cpp_extension_config() {
        let config = CppExtensionConfig::new("test_ext", vec![PathBuf::from("test.cpp")])
            .include_dir("/usr/include")
            .library("torsh")
            .extra_compile_arg("-std=c++17");

        assert_eq!(config.name, "test_ext");
        assert_eq!(config.sources.len(), 1);
        assert_eq!(config.include_dirs.len(), 1);
        assert_eq!(config.libraries.len(), 1);
    }

    /// `generate_extension_template` should create the header, source and
    /// build-script files in the target directory.
    #[test]
    fn test_generate_template() {
        let temp_dir = env::temp_dir().join("torsh_test_template");
        let result = generate_extension_template("test_op", &temp_dir);

        assert!(result.is_ok());
        assert!(temp_dir.join("test_op.h").exists());
        assert!(temp_dir.join("test_op.cpp").exists());
        assert!(temp_dir.join("build.rs").exists());

        // Cleanup
        let _ = fs::remove_dir_all(temp_dir);
    }

    /// CUDA detection should not panic either way; the arch list is only
    /// asserted when a CUDA toolchain is actually present.
    #[test]
    fn test_cuda_detection() {
        // CUDA may legitimately be absent; we only assert when present.
        let available = cuda_is_available();
        println!("CUDA available: {}", available);

        if available {
            let archs = get_cuda_arch_list();
            assert!(!archs.is_empty());
        }
    }

    /// Default and struct-update construction of
    /// `CudaKernelCompilationOptions` should carry the expected values.
    #[test]
    fn test_cuda_kernel_compilation_options() {
        let default_options = CudaKernelCompilationOptions::default();
        assert_eq!(default_options.optimization_level, 2);
        assert!(!default_options.fast_math);
        assert!(default_options.use_cache);
        assert!(!default_options.debug_info);

        let custom_options = CudaKernelCompilationOptions {
            optimization_level: 3,
            fast_math: true,
            max_registers: Some(64),
            debug_info: true,
            target_arch: Some("sm_80".to_string()),
            ..Default::default()
        };

        assert_eq!(custom_options.optimization_level, 3);
        assert!(custom_options.fast_math);
        assert_eq!(custom_options.max_registers, Some(64));
        assert!(custom_options.debug_info);
    }

    /// The textual validator should accept a well-formed kernel and reject
    /// sources missing CUDA markers or with unbalanced braces.
    /// NOTE: the fixtures below are whitespace-sensitive inputs to the
    /// validator's per-line check — keep them byte-exact.
    #[test]
    fn test_cuda_kernel_syntax_validation() {
        // Valid CUDA kernel
        let valid_kernel = r#"
        __global__ void test_kernel(float* input, float* output) {
            int idx = blockIdx.x * blockDim.x + threadIdx.x;
            output[idx] = input[idx] * 2.0f;
        }
        "#;

        assert!(validate_cuda_kernel_syntax(valid_kernel, "test_kernel").is_ok());

        // Invalid kernel (missing __global__)
        let invalid_kernel = r#"
        void test_kernel(float* input, float* output) {
            int idx = blockIdx.x * blockDim.x + threadIdx.x;
            output[idx] = input[idx] * 2.0f;
        }
        "#;

        assert!(validate_cuda_kernel_syntax(invalid_kernel, "test_kernel").is_err());

        // Invalid kernel (mismatched braces)
        let invalid_braces = r#"
        __global__ void test_kernel(float* input, float* output) {
            int idx = blockIdx.x * blockDim.x + threadIdx.x;
            output[idx] = input[idx] * 2.0f;
        // Missing closing brace
        "#;

        assert!(validate_cuda_kernel_syntax(invalid_braces, "test_kernel").is_err());
    }

    /// Runtime compilation (mocked) should produce a named kernel with PTX
    /// and handles populated — only exercised when CUDA is detected.
    #[test]
    fn test_runtime_cuda_kernel_compilation() {
        let kernel_source = r#"
        __global__ void vector_add(float* a, float* b, float* c, int n) {
            int idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (idx < n) {
                c[idx] = a[idx] + b[idx];
            }
        }
        "#;

        let options = CudaKernelCompilationOptions::default();

        // Compilation is only attempted when CUDA is detected; the result
        // is a mock (no real driver compilation happens).
        if cuda_is_available() {
            let result = compile_cuda_kernel_runtime(kernel_source, "vector_add", &options);
            if let Ok(kernel) = result {
                assert_eq!(kernel.name, "vector_add");
                assert!(!kernel.ptx_source.is_empty());
                assert!(kernel.module_handle.is_some());
                assert!(kernel.function_handle.is_some());
            }
        }
    }

    /// Auto-tuning should yield positive grid/block sizes within the
    /// (mock) device's per-block thread limit.
    #[test]
    fn test_cuda_launch_config_auto_tuning() {
        if cuda_is_available() {
            let kernel_source = r#"
            __global__ void simple_kernel(float* data) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                data[idx] *= 2.0f;
            }
            "#;

            let options = CudaKernelCompilationOptions::default();
            let result = compile_cuda_kernel_runtime(kernel_source, "simple_kernel", &options);

            if let Ok(mut kernel) = result {
                let input_sizes = vec![1024, 2048, 4096];
                let tuned_config = auto_tune_cuda_kernel(&mut kernel, &input_sizes);

                if let Ok(config) = tuned_config {
                    assert!(config.grid_size[0] > 0);
                    assert!(config.block_size[0] > 0);
                    assert!(config.block_size[0] <= 1024); // Max threads per block
                }
            }
        }
    }

    /// A fully specified custom op should register on the config when CUDA
    /// JIT is enabled, and remain retrievable by name.
    #[test]
    fn test_custom_op_with_cuda_jit() {
        let custom_op = CustomOpDefinition {
            name: "custom_relu".to_string(),
            op_type: CustomOpType::Forward,
            input_shapes: vec![None], // Dynamic shape
            output_shapes: vec![None],
            cpu_source: Some("return torch::relu(inputs[0]);".to_string()),
            cuda_source: Some(
                r#"
            __global__ void relu_kernel(float* input, float* output, int size) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                if (idx < size) {
                    output[idx] = fmaxf(0.0f, input[idx]);
                }
            }
            "#
                .to_string(),
            ),
            compile_flags: vec!["-O3".to_string()],
            schema: OpSchema {
                input_types: vec![TensorType {
                    dtype: "float32".to_string(),
                    min_dims: 1,
                    max_dims: None,
                    supports_sparse: false,
                }],
                output_types: vec![TensorType {
                    dtype: "float32".to_string(),
                    min_dims: 1,
                    max_dims: None,
                    supports_sparse: false,
                }],
                is_elementwise: true,
                is_deterministic: true,
                memory_requirement: MemoryRequirement::Linear,
            },
        };

        let config = CppExtensionConfig::new("custom_relu_ext", vec![])
            .enable_cuda_jit()
            .custom_op(custom_op);

        assert!(config.jit_config.cuda_jit);
        assert_eq!(config.custom_ops.len(), 1);
        assert_eq!(config.custom_ops[0].name, "custom_relu");
    }
}