1use std::collections::HashMap;
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
/// Settings controlling just-in-time compilation of extension code.
#[derive(Debug, Clone, Default)]
pub struct JitCompilationConfig {
    /// Master switch for JIT compilation.
    pub enabled: bool,
    /// Persist JIT artifacts to an on-disk cache between builds.
    pub cache_enabled: bool,
    /// Cache location; defaults to `<build_dir>/jit_cache` when `None`.
    pub cache_dir: Option<PathBuf>,
    /// Optimization level; values above 3 are rejected when CUDA JIT is on.
    pub optimization_level: u8,
    /// Enable runtime compilation of CUDA kernels.
    pub cuda_jit: bool,
    /// CUDA PTX cache size — converted to bytes via `* 1024 * 1024`, so
    /// presumably specified in MiB; must be non-zero when `cuda_jit` is set.
    pub cuda_cache_size: usize,
    /// Optional cap on registers per CUDA thread — not consumed in this file;
    /// TODO confirm where it is applied.
    pub cuda_max_registers: Option<u32>,
}
30
/// Declarative description of a user-defined tensor operation to code-generate.
#[derive(Debug, Clone)]
pub struct CustomOpDefinition {
    /// Operation name; used for generated C++ symbols and registration keys.
    pub name: String,
    /// Which pass(es) (forward/backward/both) to generate.
    pub op_type: CustomOpType,
    /// Expected input shapes; a `None` entry presumably means "any shape" —
    /// not consumed in this file, TODO confirm.
    pub input_shapes: Vec<Option<Vec<usize>>>,
    /// Expected output shapes; same convention as `input_shapes`.
    pub output_shapes: Vec<Option<Vec<usize>>>,
    /// C++ body of the CPU implementation, if any.
    pub cpu_source: Option<String>,
    /// CUDA kernel source, if any; syntax-checked before CUDA JIT setup.
    pub cuda_source: Option<String>,
    /// Extra compiler flags for this op — not consumed in this file;
    /// TODO confirm the consumer.
    pub compile_flags: Vec<String>,
    /// Type/shape/behavior metadata for the op.
    pub schema: OpSchema,
}
51
/// Which pass(es) a custom operation implements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CustomOpType {
    /// Forward pass only.
    Forward,
    /// Backward (gradient) pass only.
    Backward,
    /// Both passes, generated together into one source file.
    ForwardBackward,
}
59
/// Metadata describing a custom op's typing and behavioral contract.
#[derive(Debug, Clone, Default)]
pub struct OpSchema {
    /// Per-input tensor type constraints.
    pub input_types: Vec<TensorType>,
    /// Per-output tensor type constraints.
    pub output_types: Vec<TensorType>,
    /// True when the op acts independently on each element.
    pub is_elementwise: bool,
    /// True when identical inputs always produce identical outputs.
    pub is_deterministic: bool,
    /// Coarse memory-usage class of the op.
    pub memory_requirement: MemoryRequirement,
}
74
/// Constraints on a single tensor argument or result.
#[derive(Debug, Clone)]
pub struct TensorType {
    /// Element type name, e.g. "float32".
    pub dtype: String,
    /// Minimum allowed rank.
    pub min_dims: usize,
    /// Maximum allowed rank; `None` means unbounded.
    pub max_dims: Option<usize>,
    /// Whether sparse layouts are accepted.
    pub supports_sparse: bool,
}
87
/// Coarse classification of an op's memory usage relative to input size.
#[derive(Debug, Clone, Default)]
pub enum MemoryRequirement {
    /// Not specified.
    #[default]
    Unknown,
    /// O(1) extra memory.
    Constant,
    /// O(n) in the input size.
    Linear,
    /// O(n^2) in the input size.
    Quadratic,
    /// Free-form description for anything else.
    Custom(String),
}
102
/// Multi-platform build settings for an extension.
#[derive(Debug, Clone, Default)]
pub struct CrossPlatformConfig {
    /// Platforms to build for; empty means "build for the host only".
    pub target_platforms: Vec<TargetPlatform>,
    /// Windows-specific toolchain options.
    pub windows: WindowsConfig,
    /// macOS-specific toolchain options.
    pub macos: MacOsConfig,
    /// Linux-specific toolchain options.
    pub linux: LinuxConfig,
    /// Allow cross-compiling for non-host targets — not consumed in this
    /// file; TODO confirm consumer.
    pub cross_compile: bool,
    /// Build inside Docker containers — not consumed in this file;
    /// TODO confirm consumer.
    pub use_docker: bool,
}
119
/// Supported (OS, architecture) build targets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetPlatform {
    WindowsX64,
    WindowsX86,
    MacOsX64,
    MacOsArm64,
    LinuxX64,
    /// NOTE(review): `LinuxArm64` and `LinuxAarch64` both exist and are
    /// mapped from different `env::consts::ARCH` strings ("arm64" vs
    /// "aarch64") — confirm both variants are intentional.
    LinuxArm64,
    LinuxAarch64,
}
131
/// Windows toolchain options.
#[derive(Debug, Clone, Default)]
pub struct WindowsConfig {
    /// Visual Studio version — not consumed in this file; TODO confirm.
    pub vs_version: Option<String>,
    /// Windows SDK version — not consumed in this file; TODO confirm.
    pub sdk_version: Option<String>,
    /// Compile with clang++ (MSVC target triple) instead of cl.exe.
    pub use_clang: bool,
    /// Enable SIMD code generation — not consumed in this file; TODO confirm.
    pub enable_simd: bool,
}
144
/// macOS toolchain options.
///
/// NOTE(review): none of these fields are consumed in this file yet —
/// confirm the intended consumers before relying on them.
#[derive(Debug, Clone, Default)]
pub struct MacOsConfig {
    /// Minimum supported macOS version.
    pub min_version: Option<String>,
    /// Xcode version to use.
    pub xcode_version: Option<String>,
    /// Enable Metal Performance Shaders support.
    pub enable_mps: bool,
    /// Build a universal (x86_64 + arm64) binary.
    pub universal_binary: bool,
}
157
/// Linux toolchain options.
#[derive(Debug, Clone, Default)]
pub struct LinuxConfig {
    /// Which C++ compiler to invoke for Linux targets.
    pub compiler_preference: CompilerPreference,
    /// Enable Intel MKL — not consumed in this file; TODO confirm.
    pub enable_mkl: bool,
    /// Enable OpenMP — not consumed in this file; TODO confirm.
    pub enable_openmp: bool,
    /// Distro packages required at build time — not consumed in this file;
    /// TODO confirm.
    pub distro_packages: Vec<String>,
}
170
/// Preferred C++ compiler for Linux targets.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CompilerPreference {
    /// Use `$CXX` when set, otherwise fall back to g++.
    #[default]
    Auto,
    /// GNU g++.
    Gcc,
    /// LLVM clang++.
    Clang,
    /// Intel icpc.
    Intel,
}
180
/// Full description of one C++ extension build: inputs, toolchain flags,
/// CUDA settings, JIT behavior, custom ops and cross-platform targets.
#[derive(Debug, Clone)]
pub struct CppExtensionConfig {
    /// Extension name; used for the output library file name.
    pub name: String,
    /// C++ source files to compile.
    pub sources: Vec<PathBuf>,
    /// Additional include directories (`-I`).
    pub include_dirs: Vec<PathBuf>,
    /// Additional library search directories (`-L` / `/LIBPATH`).
    pub library_dirs: Vec<PathBuf>,
    /// Libraries to link against (`-l<lib>` / `<lib>.lib`).
    pub libraries: Vec<String>,
    /// Extra arguments appended to the compile command line.
    pub extra_compile_args: Vec<String>,
    /// Extra arguments appended to the link portion of the command line.
    pub extra_link_args: Vec<String>,
    /// Compile with CUDA (host builds then use nvcc).
    pub with_cuda: bool,
    /// CUDA architectures to generate code for (e.g. "sm_80").
    pub cuda_archs: Vec<String>,
    /// Debug build (no optimization, debug symbols).
    pub debug: bool,
    /// Directory receiving build artifacts and caches.
    pub build_dir: PathBuf,
    /// JIT compilation settings.
    pub jit_config: JitCompilationConfig,
    /// Custom operations to code-generate and compile in.
    pub custom_ops: Vec<CustomOpDefinition>,
    /// Cross-platform build settings.
    pub cross_platform: CrossPlatformConfig,
}
213
214impl CppExtensionConfig {
215 pub fn new(name: impl Into<String>, sources: Vec<PathBuf>) -> Self {
217 let name = name.into();
218 let build_dir = env::temp_dir().join("torsh_cpp_extensions").join(&name);
219
220 Self {
221 name,
222 sources,
223 include_dirs: vec![],
224 library_dirs: vec![],
225 libraries: vec![],
226 extra_compile_args: vec![],
227 extra_link_args: vec![],
228 with_cuda: false,
229 cuda_archs: vec![
230 "sm_70".to_string(),
231 "sm_75".to_string(),
232 "sm_80".to_string(),
233 "sm_86".to_string(),
234 "sm_89".to_string(),
235 ],
236 debug: false,
237 build_dir,
238 jit_config: JitCompilationConfig::default(),
239 custom_ops: vec![],
240 cross_platform: CrossPlatformConfig::default(),
241 }
242 }
243
244 pub fn include_dir(mut self, dir: impl AsRef<Path>) -> Self {
246 self.include_dirs.push(dir.as_ref().to_path_buf());
247 self
248 }
249
250 pub fn library_dir(mut self, dir: impl AsRef<Path>) -> Self {
252 self.library_dirs.push(dir.as_ref().to_path_buf());
253 self
254 }
255
256 pub fn library(mut self, lib: impl Into<String>) -> Self {
258 self.libraries.push(lib.into());
259 self
260 }
261
262 pub fn extra_compile_arg(mut self, arg: impl Into<String>) -> Self {
264 self.extra_compile_args.push(arg.into());
265 self
266 }
267
268 pub fn extra_link_arg(mut self, arg: impl Into<String>) -> Self {
270 self.extra_link_args.push(arg.into());
271 self
272 }
273
274 pub fn cuda(mut self, cuda_archs: Vec<String>) -> Self {
276 self.with_cuda = true;
277 self.cuda_archs = cuda_archs;
278 self
279 }
280
281 pub fn debug(mut self) -> Self {
283 self.debug = true;
284 self
285 }
286
287 pub fn build_dir(mut self, dir: impl AsRef<Path>) -> Self {
289 self.build_dir = dir.as_ref().to_path_buf();
290 self
291 }
292
293 pub fn jit(mut self, config: JitCompilationConfig) -> Self {
295 self.jit_config = config;
296 self
297 }
298
299 pub fn custom_op(mut self, op: CustomOpDefinition) -> Self {
301 self.custom_ops.push(op);
302 self
303 }
304
305 pub fn cross_platform(mut self, config: CrossPlatformConfig) -> Self {
307 self.cross_platform = config;
308 self
309 }
310
311 pub fn enable_jit(mut self) -> Self {
313 self.jit_config.enabled = true;
314 self.jit_config.cache_enabled = true;
315 self.jit_config.optimization_level = 2;
316 self
317 }
318
319 pub fn enable_cuda_jit(mut self) -> Self {
321 self.jit_config.cuda_jit = true;
322 self.jit_config.cuda_cache_size = 256; self
324 }
325}
326
/// Artifacts and metadata produced by [`build_cpp_extension`].
#[derive(Debug)]
pub struct BuildResult {
    /// Artifact for the host platform (or the first built, as a fallback).
    pub library_path: PathBuf,
    /// Include directories copied from the config for downstream consumers.
    pub include_dirs: Vec<PathBuf>,
    /// JIT setup summary, present only when JIT was enabled.
    pub jit_info: Option<JitBuildInfo>,
    /// Names of the custom ops that had sources generated.
    pub compiled_ops: Vec<String>,
    /// One artifact path per platform that was built.
    pub platform_artifacts: HashMap<TargetPlatform, PathBuf>,
}
341
/// Summary of JIT compilation state prepared for a build.
#[derive(Debug)]
pub struct JitBuildInfo {
    /// On-disk JIT cache directory (created if missing).
    pub cache_dir: PathBuf,
    /// Number of custom ops registered for JIT handling.
    pub kernel_count: usize,
    /// CUDA-specific JIT state, present only when CUDA JIT was set up.
    pub cuda_info: Option<CudaJitInfo>,
}
352
/// CUDA JIT cache statistics and configuration snapshot.
#[derive(Debug)]
pub struct CudaJitInfo {
    /// PTX cache size in bytes (configured size scaled by 1024*1024).
    pub ptx_cache_size: usize,
    /// Number of custom ops that carry CUDA sources.
    pub kernel_count: usize,
    /// Architectures (sm_*) derived from the visible devices.
    pub compute_capability: Vec<String>,
    /// Cache hit counter (initialized to 0).
    pub cache_hits: usize,
    /// Cache miss counter (initialized to 0).
    pub cache_misses: usize,
    /// Accumulated compilation time in milliseconds (initialized to 0).
    pub compilation_time_ms: f64,
}
369
/// Properties of one CUDA device, as reported by device enumeration.
#[derive(Debug, Clone)]
pub struct CudaDeviceInfo {
    /// Zero-based device index.
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Compute capability as "major.minor", e.g. "8.0".
    pub compute_capability: String,
    /// Total device memory in bytes.
    pub total_memory: usize,
    /// Hardware limit on threads per block.
    pub max_threads_per_block: u32,
    /// Maximum grid dimensions (x, y, z).
    pub max_grid_size: [u32; 3],
    /// Maximum block dimensions (x, y, z).
    pub max_block_size: [u32; 3],
    /// Threads per warp.
    pub warp_size: u32,
    /// Number of streaming multiprocessors.
    pub multiprocessor_count: u32,
    /// Shared memory available per block, in bytes.
    pub shared_memory_per_block: usize,
}
394
/// Options for runtime (JIT) compilation of a single CUDA kernel.
#[derive(Debug, Clone)]
pub struct CudaKernelCompilationOptions {
    /// Optimization level (default 2).
    pub optimization_level: u8,
    /// Enable fast-math style optimizations.
    pub fast_math: bool,
    /// Optional cap on registers per thread.
    pub max_registers: Option<u32>,
    /// Reuse cached compilation results when possible.
    pub use_cache: bool,
    /// Emit debug information.
    pub debug_info: bool,
    /// Pin compilation to a specific architecture, e.g. "sm_80".
    pub target_arch: Option<String>,
    /// Additional free-form compiler flags.
    pub custom_flags: Vec<String>,
}
413
414impl Default for CudaKernelCompilationOptions {
415 fn default() -> Self {
416 Self {
417 optimization_level: 2,
418 fast_math: false,
419 max_registers: None,
420 use_cache: true,
421 debug_info: false,
422 target_arch: None,
423 custom_flags: vec![],
424 }
425 }
426}
427
/// A CUDA kernel "compiled" at runtime, plus its launch parameters.
#[derive(Debug)]
pub struct RuntimeCudaKernel {
    /// Kernel entry-point name.
    pub name: String,
    /// Generated PTX text for the kernel.
    pub ptx_source: String,
    /// Opaque loaded-module handle; `None` means not loaded.
    /// NOTE(review): currently populated with a placeholder value (1).
    pub module_handle: Option<usize>,
    /// Opaque kernel-function handle; `None` means not resolved.
    /// NOTE(review): currently populated with a placeholder value (1).
    pub function_handle: Option<usize>,
    /// Options the kernel was compiled with.
    pub compilation_options: CudaKernelCompilationOptions,
    /// Grid/block/shared-memory/stream configuration for launches.
    pub launch_config: CudaLaunchConfig,
}
444
/// Launch geometry and resources for one CUDA kernel invocation.
#[derive(Debug, Clone)]
pub struct CudaLaunchConfig {
    /// Grid dimensions (blocks in x, y, z).
    pub grid_size: [u32; 3],
    /// Block dimensions (threads in x, y, z).
    pub block_size: [u32; 3],
    /// Dynamic shared memory per block, in bytes.
    pub shared_memory_size: usize,
    /// Opaque stream handle; `None` presumably means the default stream —
    /// TODO confirm.
    pub stream: Option<usize>,
}
457
458pub fn build_cpp_extension(config: &CppExtensionConfig) -> Result<BuildResult, String> {
460 fs::create_dir_all(&config.build_dir)
462 .map_err(|e| format!("Failed to create build directory: {}", e))?;
463
464 let jit_info = if config.jit_config.enabled {
466 Some(setup_jit_compilation(config)?)
467 } else {
468 None
469 };
470
471 let mut generated_sources = vec![];
473 let mut compiled_ops = vec![];
474
475 for custom_op in &config.custom_ops {
476 let generated_source = generate_custom_op_source(custom_op)?;
477 generated_sources.push(generated_source);
478 compiled_ops.push(custom_op.name.clone());
479 }
480
481 let mut platform_artifacts = HashMap::new();
483
484 if config.cross_platform.target_platforms.is_empty() {
485 let artifact = build_for_platform(config, None, &generated_sources, &jit_info)?;
487 platform_artifacts.insert(detect_current_platform(), artifact);
488 } else {
489 for platform in &config.cross_platform.target_platforms {
491 let artifact =
492 build_for_platform(config, Some(platform), &generated_sources, &jit_info)?;
493 platform_artifacts.insert(platform.clone(), artifact);
494 }
495 }
496
497 let main_artifact = platform_artifacts
499 .get(&detect_current_platform())
500 .or_else(|| platform_artifacts.values().next())
501 .ok_or("No artifacts built")?
502 .clone();
503
504 Ok(BuildResult {
505 library_path: main_artifact,
506 include_dirs: config.include_dirs.clone(),
507 jit_info,
508 compiled_ops,
509 platform_artifacts,
510 })
511}
512
513fn setup_jit_compilation(config: &CppExtensionConfig) -> Result<JitBuildInfo, String> {
515 let cache_dir = config
516 .jit_config
517 .cache_dir
518 .clone()
519 .unwrap_or_else(|| config.build_dir.join("jit_cache"));
520
521 fs::create_dir_all(&cache_dir)
522 .map_err(|e| format!("Failed to create JIT cache directory: {}", e))?;
523
524 let cuda_info = if config.jit_config.cuda_jit && config.with_cuda {
525 Some(setup_cuda_jit(config, &cache_dir)?)
526 } else {
527 None
528 };
529
530 Ok(JitBuildInfo {
531 cache_dir,
532 kernel_count: config.custom_ops.len(),
533 cuda_info,
534 })
535}
536
537fn setup_cuda_jit(config: &CppExtensionConfig, cache_dir: &Path) -> Result<CudaJitInfo, String> {
539 let cuda_cache_dir = cache_dir.join("cuda");
540 fs::create_dir_all(&cuda_cache_dir)
541 .map_err(|e| format!("Failed to create CUDA cache directory: {}", e))?;
542
543 let device_info = query_cuda_devices()?;
545 let available_archs = device_info
546 .iter()
547 .map(|dev| format!("sm_{}", dev.compute_capability.replace(".", "")))
548 .collect::<Vec<_>>();
549
550 let ptx_cache_dir = cuda_cache_dir.join("ptx");
552 let cubin_cache_dir = cuda_cache_dir.join("cubin");
553 fs::create_dir_all(&ptx_cache_dir)
554 .map_err(|e| format!("Failed to create PTX cache directory: {}", e))?;
555 fs::create_dir_all(&cubin_cache_dir)
556 .map_err(|e| format!("Failed to create CUBIN cache directory: {}", e))?;
557
558 configure_cuda_jit_options(config)?;
560
561 for op in &config.custom_ops {
563 if let Some(cuda_source) = &op.cuda_source {
564 validate_cuda_kernel_syntax(cuda_source, &op.name)?;
565 }
566 }
567
568 Ok(CudaJitInfo {
569 ptx_cache_size: config.jit_config.cuda_cache_size * 1024 * 1024, kernel_count: config
571 .custom_ops
572 .iter()
573 .filter(|op| op.cuda_source.is_some())
574 .count(),
575 compute_capability: available_archs,
576 cache_hits: 0,
577 cache_misses: 0,
578 compilation_time_ms: 0.0,
579 })
580}
581
582fn generate_custom_op_source(op: &CustomOpDefinition) -> Result<PathBuf, String> {
584 let source_content = match &op.op_type {
586 CustomOpType::Forward => generate_forward_op(&op.name, &op.cpu_source, &op.cuda_source)?,
587 CustomOpType::Backward => generate_backward_op(&op.name, &op.cpu_source, &op.cuda_source)?,
588 CustomOpType::ForwardBackward => {
589 generate_forward_backward_op(&op.name, &op.cpu_source, &op.cuda_source)?
590 }
591 };
592
593 let temp_file = env::temp_dir().join(format!("{}_custom_op.cpp", op.name));
595 fs::write(&temp_file, source_content)
596 .map_err(|e| format!("Failed to write custom op source: {}", e))?;
597
598 Ok(temp_file)
599}
600
/// Generate C++ source implementing (and registering) a forward custom op.
///
/// The emitted dispatcher only references implementations that were actually
/// generated. (The previous version unconditionally emitted calls to
/// `<name>_cpu_forward` / `<name>_cuda_forward`, producing uncompilable C++
/// whenever an implementation was missing.)
///
/// # Errors
/// Returns `Err` when both `cpu_source` and `cuda_source` are absent.
fn generate_forward_op(
    name: &str,
    cpu_source: &Option<String>,
    cuda_source: &Option<String>,
) -> Result<String, String> {
    if cpu_source.is_none() && cuda_source.is_none() {
        return Err(format!(
            "Custom operation '{}' has neither a CPU nor a CUDA implementation",
            name
        ));
    }

    let mut source = format!(
        r#"// Generated custom operation: {}
#include <torsh/tensor.h>
#include <torsh/autograd.h>
"#,
        name
    );

    // CUDA-only ops throw at runtime on the CPU path, so they need <stdexcept>.
    if cpu_source.is_none() {
        source.push_str("#include <stdexcept>\n");
    }

    source.push_str("\nnamespace torsh {\nnamespace ops {\n\n");

    if let Some(cpu_impl) = cpu_source {
        source.push_str(&format!(
            r#"
// CPU implementation
Tensor {}_cpu_forward(const std::vector<Tensor>& inputs) {{
    {}
}}
"#,
            name, cpu_impl
        ));
    }

    if let Some(cuda_impl) = cuda_source {
        source.push_str(&format!(
            r#"
#ifdef TORSH_USE_CUDA
// CUDA implementation
Tensor {}_cuda_forward(const std::vector<Tensor>& inputs) {{
    {}
}}
#endif
"#,
            name, cuda_impl
        ));
    }

    // Dispatcher: route to CUDA when available, otherwise the CPU fallback.
    source.push_str(&format!(
        "\n// Operation dispatcher\nTensor {}_forward(const std::vector<Tensor>& inputs) {{\n",
        name
    ));
    if cuda_source.is_some() {
        source.push_str(&format!(
            "#ifdef TORSH_USE_CUDA\n    if (inputs[0].is_cuda()) {{\n        return {}_cuda_forward(inputs);\n    }}\n#endif\n",
            name
        ));
    }
    if cpu_source.is_some() {
        source.push_str(&format!("    return {}_cpu_forward(inputs);\n", name));
    } else {
        // CUDA-only op: fail clearly when the CUDA path was not taken.
        source.push_str("    throw std::runtime_error(\"no CPU implementation available\");\n");
    }
    source.push_str(&format!(
        "}}\n\n// Register operation\nTORSH_REGISTER_OP(\"{}\", {}_forward);\n\n}} // namespace ops\n}} // namespace torsh\n",
        name, name
    ));

    Ok(source)
}
671
/// Generate C++ source implementing (and registering) a backward custom op.
///
/// Like [`generate_forward_op`], the dispatcher only references
/// implementations that were actually generated, instead of emitting calls
/// to undefined functions when one implementation is missing.
///
/// # Errors
/// Returns `Err` when both `cpu_source` and `cuda_source` are absent.
fn generate_backward_op(
    name: &str,
    cpu_source: &Option<String>,
    cuda_source: &Option<String>,
) -> Result<String, String> {
    if cpu_source.is_none() && cuda_source.is_none() {
        return Err(format!(
            "Custom backward operation '{}' has neither a CPU nor a CUDA implementation",
            name
        ));
    }

    let mut source = format!(
        r#"// Generated custom backward operation: {}
#include <torsh/tensor.h>
#include <torsh/autograd.h>
"#,
        name
    );

    // CUDA-only ops throw at runtime on the CPU path, so they need <stdexcept>.
    if cpu_source.is_none() {
        source.push_str("#include <stdexcept>\n");
    }

    source.push_str("\nnamespace torsh {\nnamespace ops {\n");

    if let Some(cpu_impl) = cpu_source {
        source.push_str(&format!(
            r#"
std::vector<Tensor> {}_cpu_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{
    {}
}}
"#,
            name, cpu_impl
        ));
    }

    if let Some(cuda_impl) = cuda_source {
        source.push_str(&format!(
            r#"
#ifdef TORSH_USE_CUDA
std::vector<Tensor> {}_cuda_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{
    {}
}}
#endif
"#,
            name, cuda_impl
        ));
    }

    // Dispatcher: CUDA when available, otherwise the CPU fallback.
    source.push_str(&format!(
        "\nstd::vector<Tensor> {}_backward(const std::vector<Tensor>& grad_outputs, const std::vector<Tensor>& inputs) {{\n",
        name
    ));
    if cuda_source.is_some() {
        source.push_str(&format!(
            "#ifdef TORSH_USE_CUDA\n    if (inputs[0].is_cuda()) {{\n        return {}_cuda_backward(grad_outputs, inputs);\n    }}\n#endif\n",
            name
        ));
    }
    if cpu_source.is_some() {
        source.push_str(&format!(
            "    return {}_cpu_backward(grad_outputs, inputs);\n",
            name
        ));
    } else {
        source.push_str("    throw std::runtime_error(\"no CPU implementation available\");\n");
    }
    source.push_str(&format!(
        "}}\n\nTORSH_REGISTER_BACKWARD_OP(\"{}\", {}_backward);\n\n}} // namespace ops\n}} // namespace torsh\n",
        name, name
    ));

    Ok(source)
}
735
736fn generate_forward_backward_op(
738 name: &str,
739 cpu_source: &Option<String>,
740 cuda_source: &Option<String>,
741) -> Result<String, String> {
742 let forward_source = generate_forward_op(name, cpu_source, cuda_source)?;
744 let backward_source =
745 generate_backward_op(&format!("{}_backward", name), cpu_source, cuda_source)?;
746
747 Ok(format!("{}\n\n{}", forward_source, backward_source))
748}
749
750fn build_for_platform(
752 config: &CppExtensionConfig,
753 target_platform: Option<&TargetPlatform>,
754 generated_sources: &[PathBuf],
755 _jit_info: &Option<JitBuildInfo>,
756) -> Result<PathBuf, String> {
757 let (compiler, extra_flags) =
759 match target_platform {
760 Some(TargetPlatform::WindowsX64) | Some(TargetPlatform::WindowsX86) => {
761 if config.cross_platform.windows.use_clang {
762 (
763 "clang++".to_string(),
764 vec![
765 "-target".to_string(),
766 get_windows_target(target_platform.expect(
767 "target_platform should be Some for Windows platform branch",
768 )),
769 ],
770 )
771 } else {
772 ("cl.exe".to_string(), vec!["/std:c++17".to_string()])
773 }
774 }
775 Some(TargetPlatform::MacOsX64) | Some(TargetPlatform::MacOsArm64) => {
776 let target = match target_platform
777 .expect("target_platform should be Some for macOS platform branch")
778 {
779 TargetPlatform::MacOsX64 => "x86_64-apple-darwin",
780 TargetPlatform::MacOsArm64 => "arm64-apple-darwin",
781 _ => unreachable!(),
782 };
783 (
784 "clang++".to_string(),
785 vec!["-target".to_string(), target.to_string()],
786 )
787 }
788 Some(TargetPlatform::LinuxX64)
789 | Some(TargetPlatform::LinuxArm64)
790 | Some(TargetPlatform::LinuxAarch64) => {
791 match config.cross_platform.linux.compiler_preference {
792 CompilerPreference::Clang => ("clang++".to_string(), vec![]),
793 CompilerPreference::Gcc => ("g++".to_string(), vec![]),
794 CompilerPreference::Intel => ("icpc".to_string(), vec![]),
795 CompilerPreference::Auto => (
796 env::var("CXX").unwrap_or_else(|_| "g++".to_string()),
797 vec![],
798 ),
799 }
800 }
801 None => {
802 if config.with_cuda {
804 ("nvcc".to_string(), vec![])
805 } else {
806 (
807 env::var("CXX").unwrap_or_else(|_| "c++".to_string()),
808 vec![],
809 )
810 }
811 }
812 };
813
814 let mut cmd = Command::new(&compiler);
816
817 cmd.args(&extra_flags);
819
820 for include_dir in &config.include_dirs {
822 cmd.arg(format!("-I{}", include_dir.display()));
823 }
824
825 if let Ok(torsh_include) = env::var("TORSH_INCLUDE_DIR") {
827 cmd.arg(format!("-I{}", torsh_include));
828 }
829
830 if compiler.contains("cl.exe") {
832 cmd.arg("/std:c++17");
834 if !config.debug {
835 cmd.arg("/O2");
836 cmd.arg("/DNDEBUG");
837 } else {
838 cmd.arg("/Od");
839 cmd.arg("/Zi");
840 }
841 } else {
842 cmd.arg("-std=c++17");
844 cmd.arg("-fPIC");
845 if !config.debug {
846 cmd.arg("-O3");
847 cmd.arg("-DNDEBUG");
848 } else {
849 cmd.arg("-g");
850 cmd.arg("-O0");
851 }
852 }
853
854 if config.with_cuda && compiler.contains("nvcc") {
856 for arch in &config.cuda_archs {
857 cmd.arg(format!(
858 "-gencode=arch=compute_{},code={}",
859 &arch[3..],
860 arch
861 ));
862 }
863 cmd.arg("-x").arg("cu");
864 }
865
866 for arg in &config.extra_compile_args {
868 cmd.arg(arg);
869 }
870
871 for source in &config.sources {
873 cmd.arg(source);
874 }
875 for source in generated_sources {
876 cmd.arg(source);
877 }
878
879 let platform_suffix = target_platform
881 .map(|p| format!("_{:?}", p))
882 .unwrap_or_default();
883 let output_file = config
884 .build_dir
885 .join(format!("lib{}{}.so", config.name, platform_suffix));
886
887 if compiler.contains("cl.exe") {
888 cmd.arg("/Fe:").arg(&output_file);
889 cmd.arg("/LD"); } else {
891 cmd.arg("-shared");
892 cmd.arg("-o").arg(&output_file);
893 }
894
895 for lib_dir in &config.library_dirs {
897 if compiler.contains("cl.exe") {
898 cmd.arg(format!("/LIBPATH:{}", lib_dir.display()));
899 } else {
900 cmd.arg(format!("-L{}", lib_dir.display()));
901 }
902 }
903
904 for lib in &config.libraries {
906 if compiler.contains("cl.exe") {
907 cmd.arg(format!("{}.lib", lib));
908 } else {
909 cmd.arg(format!("-l{}", lib));
910 }
911 }
912
913 for arg in &config.extra_link_args {
915 cmd.arg(arg);
916 }
917
918 let output = cmd
920 .output()
921 .map_err(|e| format!("Failed to execute compiler {}: {}", compiler, e))?;
922
923 if !output.status.success() {
924 let stderr = String::from_utf8_lossy(&output.stderr);
925 return Err(format!(
926 "Compilation failed for platform {:?}:\n{}",
927 target_platform, stderr
928 ));
929 }
930
931 Ok(output_file)
932}
933
934fn detect_current_platform() -> TargetPlatform {
936 match env::consts::OS {
937 "windows" => match env::consts::ARCH {
938 "x86_64" => TargetPlatform::WindowsX64,
939 "x86" => TargetPlatform::WindowsX86,
940 _ => TargetPlatform::WindowsX64, },
942 "macos" => match env::consts::ARCH {
943 "aarch64" => TargetPlatform::MacOsArm64,
944 _ => TargetPlatform::MacOsX64,
945 },
946 "linux" => match env::consts::ARCH {
947 "aarch64" => TargetPlatform::LinuxAarch64,
948 "arm64" => TargetPlatform::LinuxArm64,
949 _ => TargetPlatform::LinuxX64,
950 },
951 _ => TargetPlatform::LinuxX64, }
953}
954
955fn get_windows_target(platform: &TargetPlatform) -> String {
957 match platform {
958 TargetPlatform::WindowsX64 => "x86_64-pc-windows-msvc".to_string(),
959 TargetPlatform::WindowsX86 => "i686-pc-windows-msvc".to_string(),
960 _ => "x86_64-pc-windows-msvc".to_string(), }
962}
963
/// Verify that a compiled extension library exists on disk.
///
/// NOTE(review): this is currently only an existence check — no dynamic
/// loading of the library is performed here.
pub fn load_cpp_extension(library_path: &Path) -> Result<(), String> {
    if library_path.exists() {
        Ok(())
    } else {
        Err(format!("Library not found: {}", library_path.display()))
    }
}
979
/// Scaffold a minimal extension: a header, a C++ source, and a build script.
pub fn generate_extension_template(name: &str, output_dir: &Path) -> Result<(), String> {
    fs::create_dir_all(output_dir)
        .map_err(|e| format!("Failed to create output directory: {}", e))?;

    // Small helper so each artifact is written (and reported) the same way.
    let write_file = |path: PathBuf, contents: String, what: &str| -> Result<(), String> {
        fs::write(&path, contents).map_err(|e| format!("Failed to write {}: {}", what, e))
    };

    let header_content = format!(
        r#"#pragma once

#include <torsh/tensor.h>
#include <torsh/module.h>

namespace torsh {{
namespace ops {{

// Example custom operation
Tensor {}_forward(const Tensor& input);

}} // namespace ops
}} // namespace torsh
"#,
        name
    );
    write_file(
        output_dir.join(format!("{}.h", name)),
        header_content,
        "header file",
    )?;

    let source_content = format!(
        r#"#include "{}.h"
#include <torsh/autograd.h>
#include <iostream>

namespace torsh {{
namespace ops {{

Tensor {}_forward(const Tensor& input) {{
    // Example implementation
    auto output = input.clone();

    // Perform custom operation
    // This is where you would implement your custom logic

    return output;
}}

// Register the operation
TORSH_LIBRARY(TORCH_EXTENSION_NAME, m) {{
    m.def("{}_forward", &{}_forward);
}}

}} // namespace ops
}} // namespace torsh
"#,
        name, name, name, name
    );
    write_file(
        output_dir.join(format!("{}.cpp", name)),
        source_content,
        "source file",
    )?;

    let setup_content = format!(
        r#"use torsh_utils::cpp_extension::{{CppExtensionConfig, build_cpp_extension}};
use std::path::PathBuf;

fn main() {{
    let config = CppExtensionConfig::new("{}", vec![
        PathBuf::from("{}.cpp"),
    ])
    .include_dir(".")
    .extra_compile_arg("-Wall")
    .extra_compile_arg("-Wextra");

    match build_cpp_extension(&config) {{
        Ok(result) => {{
            println!("Extension built successfully!");
            println!("Library: {{:?}}", result.library_path);
        }}
        Err(e) => {{
            eprintln!("Build failed: {{}}", e);
            std::process::exit(1);
        }}
    }}
}}
"#,
        name, name
    );
    write_file(output_dir.join("build.rs"), setup_content, "setup script")?;

    Ok(())
}
1076
/// Returns true when an `nvcc` binary can be executed successfully
/// (used as a cheap proxy for "a CUDA toolkit is installed").
pub fn cuda_is_available() -> bool {
    matches!(
        Command::new("nvcc").arg("--version").output(),
        Ok(out) if out.status.success()
    )
}
1085
/// Default set of supported CUDA architectures (sm_70 through sm_89).
pub fn get_cuda_arch_list() -> Vec<String> {
    ["sm_70", "sm_75", "sm_80", "sm_86", "sm_89"]
        .iter()
        .map(|arch| arch.to_string())
        .collect()
}
1098
1099fn query_cuda_devices() -> Result<Vec<CudaDeviceInfo>, String> {
1101 if !cuda_is_available() {
1104 return Err("CUDA is not available on this system".to_string());
1105 }
1106
1107 let mock_device = CudaDeviceInfo {
1109 device_id: 0,
1110 name: "NVIDIA GPU".to_string(),
1111 compute_capability: "8.0".to_string(),
1112 total_memory: 8 * 1024 * 1024 * 1024, max_threads_per_block: 1024,
1114 max_grid_size: [65535, 65535, 65535],
1115 max_block_size: [1024, 1024, 64],
1116 warp_size: 32,
1117 multiprocessor_count: 80,
1118 shared_memory_per_block: 48 * 1024, };
1120
1121 Ok(vec![mock_device])
1122}
1123
1124fn configure_cuda_jit_options(config: &CppExtensionConfig) -> Result<(), String> {
1126 if config.jit_config.cuda_jit {
1134 if config.jit_config.cuda_cache_size == 0 {
1136 return Err("CUDA JIT cache size must be greater than 0".to_string());
1137 }
1138
1139 if config.jit_config.optimization_level > 3 {
1140 return Err("CUDA JIT optimization level must be 0-3".to_string());
1141 }
1142
1143 }
1150
1151 Ok(())
1152}
1153
/// Cheap, line-based sanity check of a CUDA kernel source.
///
/// Checks for: (1) at least one CUDA function qualifier, (2) balanced braces,
/// (3) a heuristic "missing semicolon" scan. The heuristic now also accepts
/// lines ending in ',', '(', ')', "*/", ':' plus bare "else"/"do" and block
/// comment continuation lines, so legal multi-line C++/CUDA constructs
/// (split signatures, stand-alone `if (...)` conditions) are no longer
/// rejected as false positives.
///
/// # Errors
/// Returns a descriptive `Err` naming `op_name` on the first violation.
fn validate_cuda_kernel_syntax(cuda_source: &str, op_name: &str) -> Result<(), String> {
    // A kernel source must mention at least one CUDA function qualifier.
    let required_patterns = ["__global__", "__device__", "__host__"];
    let has_cuda_pattern = required_patterns
        .iter()
        .any(|pattern| cuda_source.contains(pattern));
    if !has_cuda_pattern {
        return Err(format!(
            "CUDA source for operation '{}' does not contain valid CUDA kernel markers (__global__, __device__, or __host__)",
            op_name
        ));
    }

    // Brace balance. NOTE(review): this also counts braces inside comments
    // and string literals — a tokenizing check would be more precise.
    let brackets_open = cuda_source.chars().filter(|&c| c == '{').count();
    let brackets_close = cuda_source.chars().filter(|&c| c == '}').count();
    if brackets_open != brackets_close {
        return Err(format!(
            "CUDA source for operation '{}' has mismatched braces ({{ and }})",
            op_name
        ));
    }

    // Heuristic missing-semicolon scan (still line-local, still best-effort).
    for (i, line) in cuda_source.lines().enumerate() {
        let trimmed = line.trim();
        let acceptable = trimmed.is_empty()
            || trimmed.starts_with("//")
            || trimmed.starts_with("/*")
            || trimmed.starts_with('*') // block-comment continuation
            || trimmed.starts_with('#') // preprocessor
            || trimmed.ends_with('{')
            || trimmed.ends_with('}')
            || trimmed.ends_with(';')
            || trimmed.ends_with(',') // multi-line argument lists
            || trimmed.ends_with('(') // split signatures / calls
            || trimmed.ends_with(')') // stand-alone conditions
            || trimmed.ends_with("*/") // end of block comment
            || trimmed.ends_with(':') // labels / access specifiers
            || trimmed == "else"
            || trimmed == "do";
        if !acceptable {
            return Err(format!(
                "CUDA source for operation '{}' line {} may be missing semicolon: '{}'",
                op_name,
                i + 1,
                trimmed
            ));
        }
    }

    Ok(())
}
1209
1210pub fn compile_cuda_kernel_runtime(
1212 kernel_source: &str,
1213 kernel_name: &str,
1214 options: &CudaKernelCompilationOptions,
1215) -> Result<RuntimeCudaKernel, String> {
1216 if !cuda_is_available() {
1218 return Err("CUDA is not available for runtime compilation".to_string());
1219 }
1220
1221 validate_cuda_kernel_syntax(kernel_source, kernel_name)?;
1223
1224 let ptx_source = format!(
1232 r#"
1233.version 8.0
1234.target sm_80
1235.address_size 64
1236
1237.visible .entry {}(
1238 .param .u64 param_0
1239)
1240{{
1241 // Generated PTX code would go here
1242 ret;
1243}}
1244"#,
1245 kernel_name
1246 );
1247
1248 let launch_config = CudaLaunchConfig {
1250 grid_size: [1, 1, 1],
1251 block_size: [256, 1, 1],
1252 shared_memory_size: 0,
1253 stream: None,
1254 };
1255
1256 Ok(RuntimeCudaKernel {
1257 name: kernel_name.to_string(),
1258 ptx_source,
1259 module_handle: Some(1), function_handle: Some(1), compilation_options: options.clone(),
1262 launch_config,
1263 })
1264}
1265
1266pub fn launch_cuda_kernel(
1268 kernel: &RuntimeCudaKernel,
1269 args: &[*mut std::ffi::c_void],
1270) -> Result<(), String> {
1271 if kernel.module_handle.is_none() || kernel.function_handle.is_none() {
1278 return Err(format!("Kernel '{}' is not properly loaded", kernel.name));
1279 }
1280
1281 if kernel.launch_config.grid_size[0] == 0 || kernel.launch_config.block_size[0] == 0 {
1283 return Err(format!(
1284 "Invalid launch configuration for kernel '{}'",
1285 kernel.name
1286 ));
1287 }
1288
1289 println!(
1291 "Launching CUDA kernel '{}' with grid {:?} and block {:?}",
1292 kernel.name, kernel.launch_config.grid_size, kernel.launch_config.block_size
1293 );
1294
1295 if args.is_empty() {
1297 return Err(format!(
1298 "No arguments provided for kernel '{}'",
1299 kernel.name
1300 ));
1301 }
1302
1303 Ok(())
1304}
1305
1306pub fn auto_tune_cuda_kernel(
1308 kernel: &mut RuntimeCudaKernel,
1309 input_sizes: &[usize],
1310) -> Result<CudaLaunchConfig, String> {
1311 let devices = query_cuda_devices()?;
1313 let device = devices
1314 .first()
1315 .ok_or("No CUDA devices available for auto-tuning")?;
1316
1317 let optimal_block_size = if input_sizes.iter().any(|&size| size > 10000) {
1319 device.max_threads_per_block.min(512)
1321 } else {
1322 device.max_threads_per_block.min(256)
1324 };
1325
1326 let total_elements = input_sizes.iter().max().copied().unwrap_or(1);
1328 let optimal_grid_size =
1329 (total_elements + optimal_block_size as usize - 1) / optimal_block_size as usize;
1330
1331 let clamped_grid_size = (optimal_grid_size as u32).min(device.max_grid_size[0]);
1333
1334 let optimized_config = CudaLaunchConfig {
1335 grid_size: [clamped_grid_size, 1, 1],
1336 block_size: [optimal_block_size, 1, 1],
1337 shared_memory_size: 0, stream: kernel.launch_config.stream,
1339 };
1340
1341 kernel.launch_config = optimized_config.clone();
1343
1344 Ok(optimized_config)
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    // Builder methods should accumulate into the corresponding config fields.
    #[test]
    fn test_cpp_extension_config() {
        let config = CppExtensionConfig::new("test_ext", vec![PathBuf::from("test.cpp")])
            .include_dir("/usr/include")
            .library("torsh")
            .extra_compile_arg("-std=c++17");

        assert_eq!(config.name, "test_ext");
        assert_eq!(config.sources.len(), 1);
        assert_eq!(config.include_dirs.len(), 1);
        assert_eq!(config.libraries.len(), 1);
    }

    // Template generation should create the header, source and build script.
    // NOTE(review): uses a fixed temp-dir name, so parallel test runs of this
    // module could collide on the same directory.
    #[test]
    fn test_generate_template() {
        let temp_dir = env::temp_dir().join("torsh_test_template");
        let result = generate_extension_template("test_op", &temp_dir);

        assert!(result.is_ok());
        assert!(temp_dir.join("test_op.h").exists());
        assert!(temp_dir.join("test_op.cpp").exists());
        assert!(temp_dir.join("build.rs").exists());

        let _ = fs::remove_dir_all(temp_dir);
    }

    // Smoke test: detection must not panic; the arch list is only checked
    // when CUDA is actually present on the machine running the tests.
    #[test]
    fn test_cuda_detection() {
        let available = cuda_is_available();
        println!("CUDA available: {}", available);

        if available {
            let archs = get_cuda_arch_list();
            assert!(!archs.is_empty());
        }
    }

    // Defaults plus struct-update syntax for the compilation options.
    #[test]
    fn test_cuda_kernel_compilation_options() {
        let default_options = CudaKernelCompilationOptions::default();
        assert_eq!(default_options.optimization_level, 2);
        assert!(!default_options.fast_math);
        assert!(default_options.use_cache);
        assert!(!default_options.debug_info);

        let custom_options = CudaKernelCompilationOptions {
            optimization_level: 3,
            fast_math: true,
            max_registers: Some(64),
            debug_info: true,
            target_arch: Some("sm_80".to_string()),
            ..Default::default()
        };

        assert_eq!(custom_options.optimization_level, 3);
        assert!(custom_options.fast_math);
        assert_eq!(custom_options.max_registers, Some(64));
        assert!(custom_options.debug_info);
    }

    // The validator accepts a well-formed kernel and rejects sources that
    // lack CUDA qualifiers or have unbalanced braces.
    #[test]
    fn test_cuda_kernel_syntax_validation() {
        let valid_kernel = r#"
            __global__ void test_kernel(float* input, float* output) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                output[idx] = input[idx] * 2.0f;
            }
        "#;

        assert!(validate_cuda_kernel_syntax(valid_kernel, "test_kernel").is_ok());

        // No __global__/__device__/__host__ marker -> rejected.
        let invalid_kernel = r#"
            void test_kernel(float* input, float* output) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                output[idx] = input[idx] * 2.0f;
            }
        "#;

        assert!(validate_cuda_kernel_syntax(invalid_kernel, "test_kernel").is_err());

        // Unbalanced braces -> rejected.
        let invalid_braces = r#"
            __global__ void test_kernel(float* input, float* output) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                output[idx] = input[idx] * 2.0f;
            // Missing closing brace
        "#;

        assert!(validate_cuda_kernel_syntax(invalid_braces, "test_kernel").is_err());
    }

    // Runtime "compilation" is only exercised when nvcc is present; the
    // returned kernel carries PTX and (placeholder) handles.
    #[test]
    fn test_runtime_cuda_kernel_compilation() {
        let kernel_source = r#"
            __global__ void vector_add(float* a, float* b, float* c, int n) {
                int idx = blockIdx.x * blockDim.x + threadIdx.x;
                if (idx < n) {
                    c[idx] = a[idx] + b[idx];
                }
            }
        "#;

        let options = CudaKernelCompilationOptions::default();

        if cuda_is_available() {
            let result = compile_cuda_kernel_runtime(kernel_source, "vector_add", &options);
            if let Ok(kernel) = result {
                assert_eq!(kernel.name, "vector_add");
                assert!(!kernel.ptx_source.is_empty());
                assert!(kernel.module_handle.is_some());
                assert!(kernel.function_handle.is_some());
            }
        }
    }

    // Auto-tuning must produce non-zero grid/block sizes within device limits.
    #[test]
    fn test_cuda_launch_config_auto_tuning() {
        if cuda_is_available() {
            let kernel_source = r#"
                __global__ void simple_kernel(float* data) {
                    int idx = blockIdx.x * blockDim.x + threadIdx.x;
                    data[idx] *= 2.0f;
                }
            "#;

            let options = CudaKernelCompilationOptions::default();
            let result = compile_cuda_kernel_runtime(kernel_source, "simple_kernel", &options);

            if let Ok(mut kernel) = result {
                let input_sizes = vec![1024, 2048, 4096];
                let tuned_config = auto_tune_cuda_kernel(&mut kernel, &input_sizes);

                if let Ok(config) = tuned_config {
                    assert!(config.grid_size[0] > 0);
                    assert!(config.block_size[0] > 0);
                    // Block size may never exceed the device thread limit.
                    assert!(config.block_size[0] <= 1024);
                }
            }
        }
    }

    // End-to-end config wiring for a custom op with a CUDA kernel source.
    #[test]
    fn test_custom_op_with_cuda_jit() {
        let custom_op = CustomOpDefinition {
            name: "custom_relu".to_string(),
            op_type: CustomOpType::Forward,
            input_shapes: vec![None],
            output_shapes: vec![None],
            cpu_source: Some("return torch::relu(inputs[0]);".to_string()),
            cuda_source: Some(
                r#"
                __global__ void relu_kernel(float* input, float* output, int size) {
                    int idx = blockIdx.x * blockDim.x + threadIdx.x;
                    if (idx < size) {
                        output[idx] = fmaxf(0.0f, input[idx]);
                    }
                }
                "#
                .to_string(),
            ),
            compile_flags: vec!["-O3".to_string()],
            schema: OpSchema {
                input_types: vec![TensorType {
                    dtype: "float32".to_string(),
                    min_dims: 1,
                    max_dims: None,
                    supports_sparse: false,
                }],
                output_types: vec![TensorType {
                    dtype: "float32".to_string(),
                    min_dims: 1,
                    max_dims: None,
                    supports_sparse: false,
                }],
                is_elementwise: true,
                is_deterministic: true,
                memory_requirement: MemoryRequirement::Linear,
            },
        };

        let config = CppExtensionConfig::new("custom_relu_ext", vec![])
            .enable_cuda_jit()
            .custom_op(custom_op);

        assert!(config.jit_config.cuda_jit);
        assert_eq!(config.custom_ops.len(), 1);
        assert_eq!(config.custom_ops[0].name, "custom_relu");
    }
}