1use std::ffi::{CStr, CString};
28use std::os::raw::c_char;
29use std::ptr;
30
31use tensorlogic_compiler::CompilerContext;
32
33use crate::executor::{Backend, ExecutionConfig};
34use crate::optimize::OptimizationLevel;
35use crate::parser::parse_expression;
36
/// Outcome of `tl_compile_expr`, laid out for consumption from C.
///
/// As filled in by `tl_compile_expr`, a successful compilation sets
/// `graph_data` and leaves `error_message` null; a failure sets
/// `error_message` and leaves `graph_data` null with both counts at zero.
/// Release the whole structure with `tl_free_graph_result`.
#[repr(C)]
#[derive(Debug)]
pub struct TLGraphResult {
    /// NUL-terminated, pretty-printed JSON of the compiled graph, or null on failure.
    pub graph_data: *mut c_char,
    /// NUL-terminated description of what went wrong, or null on success.
    pub error_message: *mut c_char,
    /// Number of tensors in the compiled graph (0 on failure).
    pub tensor_count: usize,
    /// Number of nodes in the compiled graph (0 on failure).
    pub node_count: usize,
}
49
/// Outcome of `tl_execute_graph`, laid out for consumption from C.
///
/// As filled in by `tl_execute_graph`, success sets `output_data` and leaves
/// `error_message` null; failure sets `error_message` and leaves
/// `output_data` null. Release with `tl_free_execution_result`.
#[repr(C)]
#[derive(Debug)]
pub struct TLExecutionResult {
    /// NUL-terminated, pretty-printed JSON of the execution output, or null on failure.
    pub output_data: *mut c_char,
    /// NUL-terminated description of what went wrong, or null on success.
    pub error_message: *mut c_char,
    /// Execution time in microseconds (0 on failure).
    pub execution_time_us: u64,
}
60
/// Outcome of `tl_optimize_graph`, laid out for consumption from C.
///
/// As filled in by `tl_optimize_graph`, success sets `graph_data` and leaves
/// `error_message` null; failure sets `error_message` and leaves
/// `graph_data` null with both counters at zero. Release with
/// `tl_free_optimization_result`.
#[repr(C)]
#[derive(Debug)]
pub struct TLOptimizationResult {
    /// NUL-terminated, pretty-printed JSON of the optimized graph, or null on failure.
    pub graph_data: *mut c_char,
    /// NUL-terminated description of what went wrong, or null on success.
    pub error_message: *mut c_char,
    /// How many fewer tensors the optimized graph has than the input (never negative).
    pub tensors_removed: usize,
    /// How many fewer nodes the optimized graph has than the input (never negative).
    pub nodes_removed: usize,
}
73
/// Outcome of `tl_benchmark_compilation`, laid out for consumption from C.
///
/// `error_message` is null when the benchmark completed; otherwise it holds
/// the failure description and the numeric fields keep their zero defaults.
/// Release with `tl_free_benchmark_result`.
#[repr(C)]
#[derive(Debug)]
pub struct TLBenchmarkResult {
    /// NUL-terminated description of what went wrong, or null on success.
    pub error_message: *mut c_char,
    /// Mean compilation time per iteration, in microseconds.
    pub mean_us: f64,
    /// Standard deviation of the per-iteration times, in microseconds.
    pub std_dev_us: f64,
    /// Fastest observed iteration, in microseconds.
    pub min_us: u64,
    /// Slowest observed iteration, in microseconds.
    pub max_us: u64,
    /// Number of timed iterations the statistics cover (0 on failure).
    pub iterations: usize,
}
90
/// Converts an owned Rust string into a heap-allocated, NUL-terminated C
/// string and hands ownership of the allocation to the caller.
///
/// The returned pointer comes from `CString::into_raw`, so it must be given
/// back to Rust (via `tl_free_string`) to be released.
///
/// A C string cannot represent interior NUL bytes. Rather than returning null
/// for such input — which would silently discard an error message and leave
/// the caller with neither data nor a diagnostic — the NUL bytes are removed
/// and the remaining text is returned. The result is therefore never null.
fn to_c_string(s: String) -> *mut c_char {
    let c_string = match CString::new(s) {
        Ok(c_string) => c_string,
        Err(nul_error) => {
            // Recover the original bytes and drop every NUL so the rest of the
            // message survives the trip across the FFI boundary.
            let mut bytes = nul_error.into_vec();
            bytes.retain(|&byte| byte != 0);
            // No NUL remains, so this cannot fail; the empty-string default
            // only exists to avoid a panic path in FFI code.
            CString::new(bytes).unwrap_or_default()
        }
    };
    c_string.into_raw()
}
98
/// Copies a borrowed C string into an owned Rust `String`.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when `s` is null or when the
/// bytes it points to are not valid UTF-8.
///
/// # Safety
///
/// If `s` is non-null it must point to a valid NUL-terminated string that
/// stays readable for the duration of the call.
unsafe fn from_c_string(s: *const c_char) -> Result<String, String> {
    if s.is_null() {
        return Err("NULL pointer passed".to_string());
    }

    match CStr::from_ptr(s).to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(utf8_error) => Err(format!("Invalid UTF-8 string: {}", utf8_error)),
    }
}
110
111#[no_mangle]
122pub unsafe extern "C" fn tl_compile_expr(expr: *const c_char) -> *mut TLGraphResult {
123 let result = Box::new(TLGraphResult {
124 graph_data: ptr::null_mut(),
125 error_message: ptr::null_mut(),
126 tensor_count: 0,
127 node_count: 0,
128 });
129
130 let expr_str = match from_c_string(expr) {
132 Ok(s) => s,
133 Err(e) => {
134 let mut result = result;
135 result.error_message = to_c_string(format!("Invalid expression: {}", e));
136 return Box::into_raw(result);
137 }
138 };
139
140 let tlexpr = match parse_expression(&expr_str) {
142 Ok(e) => e,
143 Err(e) => {
144 let mut result = result;
145 result.error_message = to_c_string(format!("Parse error: {}", e));
146 return Box::into_raw(result);
147 }
148 };
149
150 let mut context = CompilerContext::new();
152
153 let graph = match tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context) {
154 Ok(g) => g,
155 Err(e) => {
156 let mut result = result;
157 result.error_message = to_c_string(format!("Compilation error: {:?}", e));
158 return Box::into_raw(result);
159 }
160 };
161
162 let json = match serde_json::to_string_pretty(&graph) {
164 Ok(j) => j,
165 Err(e) => {
166 let mut result = result;
167 result.error_message = to_c_string(format!("Serialization error: {}", e));
168 return Box::into_raw(result);
169 }
170 };
171
172 let mut result = result;
173 result.graph_data = to_c_string(json);
174 result.tensor_count = graph.tensors.len();
175 result.node_count = graph.nodes.len();
176
177 Box::into_raw(result)
178}
179
180#[no_mangle]
192pub unsafe extern "C" fn tl_execute_graph(
193 graph_json: *const c_char,
194 backend: *const c_char,
195) -> *mut TLExecutionResult {
196 let result = Box::new(TLExecutionResult {
197 output_data: ptr::null_mut(),
198 error_message: ptr::null_mut(),
199 execution_time_us: 0,
200 });
201
202 let json_str = match from_c_string(graph_json) {
204 Ok(s) => s,
205 Err(e) => {
206 let mut result = result;
207 result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
208 return Box::into_raw(result);
209 }
210 };
211
212 let backend_str = match from_c_string(backend) {
213 Ok(s) => s,
214 Err(e) => {
215 let mut result = result;
216 result.error_message = to_c_string(format!("Invalid backend: {}", e));
217 return Box::into_raw(result);
218 }
219 };
220
221 let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
223 Ok(g) => g,
224 Err(e) => {
225 let mut result = result;
226 result.error_message = to_c_string(format!("JSON parse error: {}", e));
227 return Box::into_raw(result);
228 }
229 };
230
231 let backend_enum = match Backend::from_str(&backend_str) {
233 Ok(b) => b,
234 Err(e) => {
235 let mut result = result;
236 result.error_message = to_c_string(format!("Unknown backend: {}", e));
237 return Box::into_raw(result);
238 }
239 };
240
241 let config = ExecutionConfig {
243 backend: backend_enum,
244 device: tensorlogic_scirs_backend::DeviceType::Cpu,
245 show_metrics: false,
246 show_intermediates: false,
247 validate_shapes: true,
248 trace: false,
249 };
250
251 use crate::executor::CliExecutor;
252 let executor = match CliExecutor::new(config) {
253 Ok(e) => e,
254 Err(e) => {
255 let mut result = result;
256 result.error_message = to_c_string(format!("Executor creation error: {}", e));
257 return Box::into_raw(result);
258 }
259 };
260
261 let exec_result = match executor.execute(&graph) {
262 Ok(r) => r,
263 Err(e) => {
264 let mut result = result;
265 result.error_message = to_c_string(format!("Execution error: {}", e));
266 return Box::into_raw(result);
267 }
268 };
269
270 let output_json = match serde_json::to_string_pretty(&exec_result.output) {
272 Ok(j) => j,
273 Err(e) => {
274 let mut result = result;
275 result.error_message = to_c_string(format!("Serialization error: {}", e));
276 return Box::into_raw(result);
277 }
278 };
279
280 let mut result = result;
281 result.output_data = to_c_string(output_json);
282 result.execution_time_us = (exec_result.execution_time_ms * 1000.0) as u64;
283
284 Box::into_raw(result)
285}
286
287#[no_mangle]
299pub unsafe extern "C" fn tl_optimize_graph(
300 graph_json: *const c_char,
301 level: i32,
302) -> *mut TLOptimizationResult {
303 let result = Box::new(TLOptimizationResult {
304 graph_data: ptr::null_mut(),
305 error_message: ptr::null_mut(),
306 tensors_removed: 0,
307 nodes_removed: 0,
308 });
309
310 let json_str = match from_c_string(graph_json) {
312 Ok(s) => s,
313 Err(e) => {
314 let mut result = result;
315 result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
316 return Box::into_raw(result);
317 }
318 };
319
320 let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
322 Ok(g) => g,
323 Err(e) => {
324 let mut result = result;
325 result.error_message = to_c_string(format!("JSON parse error: {}", e));
326 return Box::into_raw(result);
327 }
328 };
329
330 let opt_level = match level {
332 0 => OptimizationLevel::None,
333 1 => OptimizationLevel::Basic,
334 2 => OptimizationLevel::Standard,
335 3 => OptimizationLevel::Aggressive,
336 _ => {
337 let mut result = result;
338 result.error_message = to_c_string(format!("Invalid optimization level: {}", level));
339 return Box::into_raw(result);
340 }
341 };
342
343 use crate::optimize::OptimizationConfig;
345 let config = OptimizationConfig {
346 level: opt_level,
347 enable_dce: true,
348 enable_cse: true,
349 enable_identity: true,
350 show_stats: false,
351 verbose: false,
352 };
353
354 let initial_nodes = graph.nodes.len();
355 let initial_tensors = graph.tensors.len();
356
357 let (optimized, _stats) = match crate::optimize::optimize_einsum_graph(graph, &config) {
358 Ok(r) => r,
359 Err(e) => {
360 let mut result = result;
361 result.error_message = to_c_string(format!("Optimization error: {}", e));
362 return Box::into_raw(result);
363 }
364 };
365
366 let output_json = match serde_json::to_string_pretty(&optimized) {
368 Ok(j) => j,
369 Err(e) => {
370 let mut result = result;
371 result.error_message = to_c_string(format!("Serialization error: {}", e));
372 return Box::into_raw(result);
373 }
374 };
375
376 let mut result = result;
377 result.graph_data = to_c_string(output_json);
378 result.tensors_removed = initial_tensors.saturating_sub(optimized.tensors.len());
379 result.nodes_removed = initial_nodes.saturating_sub(optimized.nodes.len());
380
381 Box::into_raw(result)
382}
383
384#[no_mangle]
396pub unsafe extern "C" fn tl_benchmark_compilation(
397 expr: *const c_char,
398 iterations: usize,
399) -> *mut TLBenchmarkResult {
400 let result = Box::new(TLBenchmarkResult {
401 error_message: ptr::null_mut(),
402 mean_us: 0.0,
403 std_dev_us: 0.0,
404 min_us: 0,
405 max_us: 0,
406 iterations: 0,
407 });
408
409 let expr_str = match from_c_string(expr) {
411 Ok(s) => s,
412 Err(e) => {
413 let mut result = result;
414 result.error_message = to_c_string(format!("Invalid expression: {}", e));
415 return Box::into_raw(result);
416 }
417 };
418
419 let tlexpr = match parse_expression(&expr_str) {
421 Ok(e) => e,
422 Err(e) => {
423 let mut result = result;
424 result.error_message = to_c_string(format!("Parse error: {}", e));
425 return Box::into_raw(result);
426 }
427 };
428
429 let mut timings = Vec::with_capacity(iterations);
431 for _ in 0..iterations {
432 let mut context = CompilerContext::new();
433 let start = std::time::Instant::now();
434 if tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context).is_ok() {
435 timings.push(start.elapsed());
436 } else {
437 let mut result = result;
438 result.error_message = to_c_string("Compilation failed during benchmark".to_string());
439 return Box::into_raw(result);
440 }
441 }
442
443 let mut sum_us = 0u64;
445 let mut min_us = u64::MAX;
446 let mut max_us = 0u64;
447
448 for timing in &timings {
449 let us = timing.as_micros() as u64;
450 sum_us += us;
451 min_us = min_us.min(us);
452 max_us = max_us.max(us);
453 }
454
455 let mean_us = sum_us as f64 / iterations as f64;
456
457 let mut variance_sum = 0.0;
459 for timing in &timings {
460 let us = timing.as_micros() as f64;
461 let diff = us - mean_us;
462 variance_sum += diff * diff;
463 }
464 let std_dev_us = (variance_sum / iterations as f64).sqrt();
465
466 let mut result = result;
467 result.mean_us = mean_us;
468 result.std_dev_us = std_dev_us;
469 result.min_us = min_us;
470 result.max_us = max_us;
471 result.iterations = iterations;
472
473 Box::into_raw(result)
474}
475
/// Releases a C string previously handed out by this library.
///
/// Passing null is allowed and does nothing.
///
/// # Safety
///
/// A non-null `s` must be a pointer obtained from `CString::into_raw` (which
/// is how this library produces its strings) that has not been freed yet.
#[no_mangle]
pub unsafe extern "C" fn tl_free_string(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // Re-own the allocation; dropping the CString frees it.
    let reclaimed = CString::from_raw(s);
    drop(reclaimed);
}
486
487#[no_mangle]
492pub unsafe extern "C" fn tl_free_graph_result(result: *mut TLGraphResult) {
493 if !result.is_null() {
494 let result = Box::from_raw(result);
495 if !result.graph_data.is_null() {
496 tl_free_string(result.graph_data);
497 }
498 if !result.error_message.is_null() {
499 tl_free_string(result.error_message);
500 }
501 }
502}
503
504#[no_mangle]
509pub unsafe extern "C" fn tl_free_execution_result(result: *mut TLExecutionResult) {
510 if !result.is_null() {
511 let result = Box::from_raw(result);
512 if !result.output_data.is_null() {
513 tl_free_string(result.output_data);
514 }
515 if !result.error_message.is_null() {
516 tl_free_string(result.error_message);
517 }
518 }
519}
520
521#[no_mangle]
526pub unsafe extern "C" fn tl_free_optimization_result(result: *mut TLOptimizationResult) {
527 if !result.is_null() {
528 let result = Box::from_raw(result);
529 if !result.graph_data.is_null() {
530 tl_free_string(result.graph_data);
531 }
532 if !result.error_message.is_null() {
533 tl_free_string(result.error_message);
534 }
535 }
536}
537
538#[no_mangle]
543pub unsafe extern "C" fn tl_free_benchmark_result(result: *mut TLBenchmarkResult) {
544 if !result.is_null() {
545 let result = Box::from_raw(result);
546 if !result.error_message.is_null() {
547 tl_free_string(result.error_message);
548 }
549 }
550}
551
552#[no_mangle]
557pub extern "C" fn tl_version() -> *mut c_char {
558 to_c_string(env!("CARGO_PKG_VERSION").to_string())
559}
560
561#[no_mangle]
572pub unsafe extern "C" fn tl_is_backend_available(backend: *const c_char) -> i32 {
573 let backend_str = match from_c_string(backend) {
574 Ok(s) => s,
575 Err(_) => return 0,
576 };
577
578 match Backend::from_str(&backend_str) {
579 Ok(b) => {
580 if b.is_available() {
581 1
582 } else {
583 0
584 }
585 }
586 Err(_) => 0,
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    /// Every `tl_compile_expr` outcome must carry exactly one payload: either
    /// the graph JSON or an error message, never both and never neither.
    fn assert_exactly_one_payload(result: &TLGraphResult) {
        assert_ne!(
            result.graph_data.is_null(),
            result.error_message.is_null(),
            "result must hold either graph data or an error message, not both/neither"
        );
    }

    #[test]
    fn test_compile_expr_success() {
        let expr = CString::new("AND(pred1(x), pred2(x, y))").unwrap();

        // SAFETY: `expr` is a valid NUL-terminated string that outlives the
        // call; `raw` is only dereferenced before it is handed to
        // `tl_free_graph_result`.
        unsafe {
            let raw = tl_compile_expr(expr.as_ptr());
            assert!(!raw.is_null());
            let result = &*raw;

            // Print diagnostics first so a failing assertion below is easy to debug.
            if !result.error_message.is_null() {
                let err = CStr::from_ptr(result.error_message).to_str().unwrap();
                println!("Compilation error: {}", err);
            }
            if !result.graph_data.is_null() {
                let graph = CStr::from_ptr(result.graph_data).to_str().unwrap();
                // Truncate by characters: a byte slice could split a
                // multi-byte character and panic.
                let preview: String = graph.chars().take(200).collect();
                println!("Graph: {}", preview);
                println!(
                    "Tensors: {}, Nodes: {}",
                    result.tensor_count, result.node_count
                );
            }

            // Capture what is asserted on, then free before asserting so a
            // failure does not leak the result.
            let succeeded = result.error_message.is_null();
            let has_graph = !result.graph_data.is_null();
            let tensor_count = result.tensor_count;
            tl_free_graph_result(raw);

            assert!(succeeded, "Compilation should succeed");
            assert!(has_graph);
            assert!(tensor_count > 0, "Should have at least one tensor");
        }
    }

    #[test]
    fn test_compile_expr_invalid_syntax() {
        let expr = CString::new("AND(pred1(x), )").unwrap();

        // Whether the parser rejects this input is not assumed here; the
        // result must be well-formed and freeable either way.
        // SAFETY: same contract as in `test_compile_expr_success`.
        unsafe {
            let raw = tl_compile_expr(expr.as_ptr());
            assert!(!raw.is_null());
            assert_exactly_one_payload(&*raw);
            tl_free_graph_result(raw);
        }
    }

    #[test]
    fn test_compile_expr_with_error() {
        let expr = CString::new("\"unclosed_string").unwrap();

        // Whether the parser rejects this input is not assumed here; the
        // result must be well-formed and freeable either way.
        // SAFETY: same contract as in `test_compile_expr_success`.
        unsafe {
            let raw = tl_compile_expr(expr.as_ptr());
            assert!(!raw.is_null());
            assert_exactly_one_payload(&*raw);
            tl_free_graph_result(raw);
        }
    }

    #[test]
    fn test_compile_expr_null_input() {
        // SAFETY: a null `expr` is explicitly supported; `raw` is only
        // dereferenced before it is freed.
        unsafe {
            let raw = tl_compile_expr(std::ptr::null());
            assert!(!raw.is_null());
            let result = &*raw;
            assert!(result.graph_data.is_null());
            assert!(!result.error_message.is_null());
            let message = CStr::from_ptr(result.error_message)
                .to_str()
                .unwrap()
                .to_owned();
            let (tensors, nodes) = (result.tensor_count, result.node_count);
            tl_free_graph_result(raw);

            assert_eq!(message, "Invalid expression: NULL pointer passed");
            assert_eq!((tensors, nodes), (0, 0));
        }
    }

    #[test]
    fn test_free_functions_accept_null() {
        // SAFETY: every free function documents null as a no-op.
        unsafe {
            tl_free_string(std::ptr::null_mut());
            tl_free_graph_result(std::ptr::null_mut());
            tl_free_execution_result(std::ptr::null_mut());
            tl_free_optimization_result(std::ptr::null_mut());
            tl_free_benchmark_result(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_version() {
        // SAFETY: `version` comes straight from `tl_version` and is freed once.
        unsafe {
            let version = tl_version();
            assert!(!version.is_null());
            let version_str = CStr::from_ptr(version).to_str().unwrap();
            assert!(!version_str.is_empty());
            tl_free_string(version);
        }
    }

    #[test]
    fn test_backend_availability() {
        let cpu = CString::new("cpu").unwrap();
        // SAFETY: `cpu` is a valid NUL-terminated string that outlives the call.
        unsafe {
            assert_eq!(tl_is_backend_available(cpu.as_ptr()), 1);
        }

        let invalid = CString::new("invalid_backend").unwrap();
        // SAFETY: `invalid` is a valid NUL-terminated string that outlives the call.
        unsafe {
            assert_eq!(tl_is_backend_available(invalid.as_ptr()), 0);
        }

        // SAFETY: a null `backend` is explicitly supported and reported as unavailable.
        unsafe {
            assert_eq!(tl_is_backend_available(std::ptr::null()), 0);
        }
    }
}