tensorlogic_cli/
ffi.rs

//! FFI (Foreign Function Interface) bindings for tensorlogic-cli
//!
//! This module provides C-compatible bindings for the tensorlogic-cli library,
//! enabling integration with C/C++ projects and other languages via FFI.
//!
//! # Memory Management
//!
//! - Strings returned directly by a function (e.g. `tl_version`) are allocated as `CString`
//!   and must be freed with `tl_free_string`.
//! - Every result struct is allocated by the library and must be freed with its matching
//!   function: `tl_free_graph_result`, `tl_free_execution_result`,
//!   `tl_free_optimization_result`, or `tl_free_benchmark_result`. These also free the
//!   strings stored inside the struct (`graph_data`, `output_data`, `error_message`), so do
//!   not pass those fields to `tl_free_string` yourself.
//!
//! # Example (C)
//!
//! ```c
//! // Compile an expression
//! TLGraphResult* result = tl_compile_expr("friend(alice, bob)");
//! if (result->error_message != NULL) {
//!     fprintf(stderr, "Error: %s\n", result->error_message);
//!     tl_free_graph_result(result);
//!     return 1;
//! }
//!
//! printf("Graph: %s\n", result->graph_data);
//! tl_free_graph_result(result);
//! ```
26
27use std::ffi::{CStr, CString};
28use std::os::raw::c_char;
29use std::ptr;
30
31use tensorlogic_compiler::CompilerContext;
32
33use crate::executor::{Backend, ExecutionConfig};
34use crate::optimize::OptimizationLevel;
35use crate::parser::parse_expression;
36
/// C-layout outcome of compiling one expression into a tensor graph.
///
/// On success `graph_data` is populated and `error_message` is NULL; on failure the reverse
/// holds and both counters stay 0. Release the whole struct (including its strings) with
/// `tl_free_graph_result`.
#[repr(C)]
pub struct TLGraphResult {
    /// JSON text of the compiled graph, or NULL if compilation failed.
    pub graph_data: *mut c_char,
    /// Description of what went wrong, or NULL if compilation succeeded.
    pub error_message: *mut c_char,
    /// How many tensors the compiled graph declares (0 on failure).
    pub tensor_count: usize,
    /// How many nodes the compiled graph contains (0 on failure).
    pub node_count: usize,
}
49
/// C-layout outcome of running a compiled graph on a backend.
///
/// Exactly as for the other result structs, a successful call fills `output_data` and leaves
/// `error_message` NULL, while a failed call does the opposite. Free it with
/// `tl_free_execution_result`.
#[repr(C)]
pub struct TLExecutionResult {
    /// JSON text of the execution output, or NULL if execution failed.
    pub output_data: *mut c_char,
    /// Description of what went wrong, or NULL if execution succeeded.
    pub error_message: *mut c_char,
    /// Wall-clock execution time, expressed in microseconds.
    pub execution_time_us: u64,
}
60
/// C-layout outcome of optimizing a serialized graph.
///
/// `graph_data` carries the optimized graph on success; `error_message` is set instead on
/// failure. Release with `tl_free_optimization_result`.
#[repr(C)]
pub struct TLOptimizationResult {
    /// JSON text of the optimized graph, or NULL if optimization failed.
    pub graph_data: *mut c_char,
    /// Description of what went wrong, or NULL if optimization succeeded.
    pub error_message: *mut c_char,
    /// How many tensors the optimizer eliminated.
    pub tensors_removed: usize,
    /// How many nodes the optimizer eliminated.
    pub nodes_removed: usize,
}
73
/// C-layout summary of a compilation benchmark.
///
/// When `error_message` is non-NULL the statistics are not meaningful. Release with
/// `tl_free_benchmark_result`.
#[repr(C)]
pub struct TLBenchmarkResult {
    /// Description of what went wrong, or NULL if the benchmark succeeded.
    pub error_message: *mut c_char,
    /// Average time per iteration, in microseconds.
    pub mean_us: f64,
    /// Spread of the per-iteration times (standard deviation), in microseconds.
    pub std_dev_us: f64,
    /// Fastest observed iteration, in microseconds.
    pub min_us: u64,
    /// Slowest observed iteration, in microseconds.
    pub max_us: u64,
    /// How many iterations were measured.
    pub iterations: usize,
}
90
/// Convert a Rust string into a heap-allocated C string whose ownership passes to the caller.
///
/// The returned pointer must be released with `tl_free_string`, or by the `tl_free_*_result`
/// function of the result struct that stores it.
///
/// C strings cannot contain interior NUL bytes. Returning NULL for such input would make a
/// failed call's `error_message` look like success to C callers, so any NUL bytes are removed
/// instead and the returned pointer is never NULL.
fn to_c_string(s: String) -> *mut c_char {
    let c_string = CString::new(s).unwrap_or_else(|err| {
        // Recover the original bytes from the error and drop the offending NULs.
        let mut bytes = err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were just removed")
    });
    c_string.into_raw()
}
98
/// Copy a NUL-terminated C string into an owned Rust `String`.
///
/// Yields `Err` with a human-readable message when `s` is NULL or its bytes are not valid
/// UTF-8.
///
/// # Safety
/// If `s` is non-NULL it must point to a NUL-terminated buffer that stays alive and unmodified
/// for the duration of this call.
unsafe fn from_c_string(s: *const c_char) -> Result<String, String> {
    if s.is_null() {
        return Err(String::from("NULL pointer passed"));
    }

    // SAFETY: `s` is non-NULL (checked above) and the caller guarantees it points to a live,
    // NUL-terminated buffer.
    let borrowed: &CStr = unsafe { CStr::from_ptr(s) };
    match borrowed.to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(utf8_err) => Err(format!("Invalid UTF-8 string: {}", utf8_err)),
    }
}
110
111/// Compile a logical expression to a tensor graph
112///
113/// # Parameters
114/// - `expr`: The logical expression as a C string
115///
116/// # Returns
117/// A pointer to `TLGraphResult` that must be freed with `tl_free_graph_result`
118///
119/// # Safety
120/// The caller must ensure that `expr` is a valid null-terminated string.
121#[no_mangle]
122pub unsafe extern "C" fn tl_compile_expr(expr: *const c_char) -> *mut TLGraphResult {
123    let result = Box::new(TLGraphResult {
124        graph_data: ptr::null_mut(),
125        error_message: ptr::null_mut(),
126        tensor_count: 0,
127        node_count: 0,
128    });
129
130    // Parse input
131    let expr_str = match from_c_string(expr) {
132        Ok(s) => s,
133        Err(e) => {
134            let mut result = result;
135            result.error_message = to_c_string(format!("Invalid expression: {}", e));
136            return Box::into_raw(result);
137        }
138    };
139
140    // Parse expression
141    let tlexpr = match parse_expression(&expr_str) {
142        Ok(e) => e,
143        Err(e) => {
144            let mut result = result;
145            result.error_message = to_c_string(format!("Parse error: {}", e));
146            return Box::into_raw(result);
147        }
148    };
149
150    // Compile with default context
151    let mut context = CompilerContext::new();
152
153    let graph = match tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context) {
154        Ok(g) => g,
155        Err(e) => {
156            let mut result = result;
157            result.error_message = to_c_string(format!("Compilation error: {:?}", e));
158            return Box::into_raw(result);
159        }
160    };
161
162    // Serialize to JSON
163    let json = match serde_json::to_string_pretty(&graph) {
164        Ok(j) => j,
165        Err(e) => {
166            let mut result = result;
167            result.error_message = to_c_string(format!("Serialization error: {}", e));
168            return Box::into_raw(result);
169        }
170    };
171
172    let mut result = result;
173    result.graph_data = to_c_string(json);
174    result.tensor_count = graph.tensors.len();
175    result.node_count = graph.nodes.len();
176
177    Box::into_raw(result)
178}
179
180/// Execute a compiled graph
181///
182/// # Parameters
183/// - `graph_json`: The graph as JSON string
184/// - `backend`: The backend name (e.g., "cpu", "parallel")
185///
186/// # Returns
187/// A pointer to `TLExecutionResult` that must be freed with `tl_free_execution_result`
188///
189/// # Safety
190/// The caller must ensure that `graph_json` and `backend` are valid null-terminated strings.
191#[no_mangle]
192pub unsafe extern "C" fn tl_execute_graph(
193    graph_json: *const c_char,
194    backend: *const c_char,
195) -> *mut TLExecutionResult {
196    let result = Box::new(TLExecutionResult {
197        output_data: ptr::null_mut(),
198        error_message: ptr::null_mut(),
199        execution_time_us: 0,
200    });
201
202    // Parse inputs
203    let json_str = match from_c_string(graph_json) {
204        Ok(s) => s,
205        Err(e) => {
206            let mut result = result;
207            result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
208            return Box::into_raw(result);
209        }
210    };
211
212    let backend_str = match from_c_string(backend) {
213        Ok(s) => s,
214        Err(e) => {
215            let mut result = result;
216            result.error_message = to_c_string(format!("Invalid backend: {}", e));
217            return Box::into_raw(result);
218        }
219    };
220
221    // Deserialize graph
222    let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
223        Ok(g) => g,
224        Err(e) => {
225            let mut result = result;
226            result.error_message = to_c_string(format!("JSON parse error: {}", e));
227            return Box::into_raw(result);
228        }
229    };
230
231    // Parse backend
232    let backend_enum = match Backend::from_str(&backend_str) {
233        Ok(b) => b,
234        Err(e) => {
235            let mut result = result;
236            result.error_message = to_c_string(format!("Unknown backend: {}", e));
237            return Box::into_raw(result);
238        }
239    };
240
241    // Execute
242    let config = ExecutionConfig {
243        backend: backend_enum,
244        device: tensorlogic_scirs_backend::DeviceType::Cpu,
245        show_metrics: false,
246        show_intermediates: false,
247        validate_shapes: true,
248        trace: false,
249    };
250
251    use crate::executor::CliExecutor;
252    let executor = match CliExecutor::new(config) {
253        Ok(e) => e,
254        Err(e) => {
255            let mut result = result;
256            result.error_message = to_c_string(format!("Executor creation error: {}", e));
257            return Box::into_raw(result);
258        }
259    };
260
261    let exec_result = match executor.execute(&graph) {
262        Ok(r) => r,
263        Err(e) => {
264            let mut result = result;
265            result.error_message = to_c_string(format!("Execution error: {}", e));
266            return Box::into_raw(result);
267        }
268    };
269
270    // Serialize result
271    let output_json = match serde_json::to_string_pretty(&exec_result.output) {
272        Ok(j) => j,
273        Err(e) => {
274            let mut result = result;
275            result.error_message = to_c_string(format!("Serialization error: {}", e));
276            return Box::into_raw(result);
277        }
278    };
279
280    let mut result = result;
281    result.output_data = to_c_string(output_json);
282    result.execution_time_us = (exec_result.execution_time_ms * 1000.0) as u64;
283
284    Box::into_raw(result)
285}
286
287/// Optimize a compiled graph
288///
289/// # Parameters
290/// - `graph_json`: The graph as JSON string
291/// - `level`: Optimization level (0=none, 1=basic, 2=standard, 3=aggressive)
292///
293/// # Returns
294/// A pointer to `TLOptimizationResult` that must be freed with `tl_free_optimization_result`
295///
296/// # Safety
297/// The caller must ensure that `graph_json` is a valid null-terminated string.
298#[no_mangle]
299pub unsafe extern "C" fn tl_optimize_graph(
300    graph_json: *const c_char,
301    level: i32,
302) -> *mut TLOptimizationResult {
303    let result = Box::new(TLOptimizationResult {
304        graph_data: ptr::null_mut(),
305        error_message: ptr::null_mut(),
306        tensors_removed: 0,
307        nodes_removed: 0,
308    });
309
310    // Parse input
311    let json_str = match from_c_string(graph_json) {
312        Ok(s) => s,
313        Err(e) => {
314            let mut result = result;
315            result.error_message = to_c_string(format!("Invalid graph JSON: {}", e));
316            return Box::into_raw(result);
317        }
318    };
319
320    // Deserialize graph
321    let graph: tensorlogic_ir::EinsumGraph = match serde_json::from_str(&json_str) {
322        Ok(g) => g,
323        Err(e) => {
324            let mut result = result;
325            result.error_message = to_c_string(format!("JSON parse error: {}", e));
326            return Box::into_raw(result);
327        }
328    };
329
330    // Parse optimization level
331    let opt_level = match level {
332        0 => OptimizationLevel::None,
333        1 => OptimizationLevel::Basic,
334        2 => OptimizationLevel::Standard,
335        3 => OptimizationLevel::Aggressive,
336        _ => {
337            let mut result = result;
338            result.error_message = to_c_string(format!("Invalid optimization level: {}", level));
339            return Box::into_raw(result);
340        }
341    };
342
343    // Optimize
344    use crate::optimize::OptimizationConfig;
345    let config = OptimizationConfig {
346        level: opt_level,
347        enable_dce: true,
348        enable_cse: true,
349        enable_identity: true,
350        show_stats: false,
351        verbose: false,
352    };
353
354    let initial_nodes = graph.nodes.len();
355    let initial_tensors = graph.tensors.len();
356
357    let (optimized, _stats) = match crate::optimize::optimize_einsum_graph(graph, &config) {
358        Ok(r) => r,
359        Err(e) => {
360            let mut result = result;
361            result.error_message = to_c_string(format!("Optimization error: {}", e));
362            return Box::into_raw(result);
363        }
364    };
365
366    // Serialize result
367    let output_json = match serde_json::to_string_pretty(&optimized) {
368        Ok(j) => j,
369        Err(e) => {
370            let mut result = result;
371            result.error_message = to_c_string(format!("Serialization error: {}", e));
372            return Box::into_raw(result);
373        }
374    };
375
376    let mut result = result;
377    result.graph_data = to_c_string(output_json);
378    result.tensors_removed = initial_tensors.saturating_sub(optimized.tensors.len());
379    result.nodes_removed = initial_nodes.saturating_sub(optimized.nodes.len());
380
381    Box::into_raw(result)
382}
383
384/// Benchmark compilation of an expression
385///
386/// # Parameters
387/// - `expr`: The logical expression as a C string
388/// - `iterations`: Number of iterations to run
389///
390/// # Returns
391/// A pointer to `TLBenchmarkResult` that must be freed with `tl_free_benchmark_result`
392///
393/// # Safety
394/// The caller must ensure that `expr` is a valid null-terminated string.
395#[no_mangle]
396pub unsafe extern "C" fn tl_benchmark_compilation(
397    expr: *const c_char,
398    iterations: usize,
399) -> *mut TLBenchmarkResult {
400    let result = Box::new(TLBenchmarkResult {
401        error_message: ptr::null_mut(),
402        mean_us: 0.0,
403        std_dev_us: 0.0,
404        min_us: 0,
405        max_us: 0,
406        iterations: 0,
407    });
408
409    // Parse input
410    let expr_str = match from_c_string(expr) {
411        Ok(s) => s,
412        Err(e) => {
413            let mut result = result;
414            result.error_message = to_c_string(format!("Invalid expression: {}", e));
415            return Box::into_raw(result);
416        }
417    };
418
419    // Parse expression
420    let tlexpr = match parse_expression(&expr_str) {
421        Ok(e) => e,
422        Err(e) => {
423            let mut result = result;
424            result.error_message = to_c_string(format!("Parse error: {}", e));
425            return Box::into_raw(result);
426        }
427    };
428
429    // Run benchmark
430    let mut timings = Vec::with_capacity(iterations);
431    for _ in 0..iterations {
432        let mut context = CompilerContext::new();
433        let start = std::time::Instant::now();
434        if tensorlogic_compiler::compile_to_einsum_with_context(&tlexpr, &mut context).is_ok() {
435            timings.push(start.elapsed());
436        } else {
437            let mut result = result;
438            result.error_message = to_c_string("Compilation failed during benchmark".to_string());
439            return Box::into_raw(result);
440        }
441    }
442
443    // Calculate statistics
444    let mut sum_us = 0u64;
445    let mut min_us = u64::MAX;
446    let mut max_us = 0u64;
447
448    for timing in &timings {
449        let us = timing.as_micros() as u64;
450        sum_us += us;
451        min_us = min_us.min(us);
452        max_us = max_us.max(us);
453    }
454
455    let mean_us = sum_us as f64 / iterations as f64;
456
457    // Calculate standard deviation
458    let mut variance_sum = 0.0;
459    for timing in &timings {
460        let us = timing.as_micros() as f64;
461        let diff = us - mean_us;
462        variance_sum += diff * diff;
463    }
464    let std_dev_us = (variance_sum / iterations as f64).sqrt();
465
466    let mut result = result;
467    result.mean_us = mean_us;
468    result.std_dev_us = std_dev_us;
469    result.min_us = min_us;
470    result.max_us = max_us;
471    result.iterations = iterations;
472
473    Box::into_raw(result)
474}
475
/// Release a string previously handed out by this library.
///
/// Passing NULL is a harmless no-op.
///
/// # Safety
/// `s` must be NULL or a pointer obtained from this library (for example from `tl_version`)
/// that has not already been released, and it must not be used afterwards. Strings stored
/// inside result structs are released by the matching `tl_free_*_result` function; do not
/// also free them here.
#[no_mangle]
pub unsafe extern "C" fn tl_free_string(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // SAFETY: per the contract above, `s` came from `CString::into_raw` in this library and is
    // reclaimed exactly once.
    let owned = unsafe { CString::from_raw(s) };
    drop(owned);
}
486
487/// Free a graph result
488///
489/// # Safety
490/// The caller must ensure that the pointer was allocated by `tl_compile_expr` and is not used after freeing.
491#[no_mangle]
492pub unsafe extern "C" fn tl_free_graph_result(result: *mut TLGraphResult) {
493    if !result.is_null() {
494        let result = Box::from_raw(result);
495        if !result.graph_data.is_null() {
496            tl_free_string(result.graph_data);
497        }
498        if !result.error_message.is_null() {
499            tl_free_string(result.error_message);
500        }
501    }
502}
503
504/// Free an execution result
505///
506/// # Safety
507/// The caller must ensure that the pointer was allocated by `tl_execute_graph` and is not used after freeing.
508#[no_mangle]
509pub unsafe extern "C" fn tl_free_execution_result(result: *mut TLExecutionResult) {
510    if !result.is_null() {
511        let result = Box::from_raw(result);
512        if !result.output_data.is_null() {
513            tl_free_string(result.output_data);
514        }
515        if !result.error_message.is_null() {
516            tl_free_string(result.error_message);
517        }
518    }
519}
520
521/// Free an optimization result
522///
523/// # Safety
524/// The caller must ensure that the pointer was allocated by `tl_optimize_graph` and is not used after freeing.
525#[no_mangle]
526pub unsafe extern "C" fn tl_free_optimization_result(result: *mut TLOptimizationResult) {
527    if !result.is_null() {
528        let result = Box::from_raw(result);
529        if !result.graph_data.is_null() {
530            tl_free_string(result.graph_data);
531        }
532        if !result.error_message.is_null() {
533            tl_free_string(result.error_message);
534        }
535    }
536}
537
538/// Free a benchmark result
539///
540/// # Safety
541/// The caller must ensure that the pointer was allocated by `tl_benchmark_compilation` and is not used after freeing.
542#[no_mangle]
543pub unsafe extern "C" fn tl_free_benchmark_result(result: *mut TLBenchmarkResult) {
544    if !result.is_null() {
545        let result = Box::from_raw(result);
546        if !result.error_message.is_null() {
547            tl_free_string(result.error_message);
548        }
549    }
550}
551
552/// Get the version string
553///
554/// # Returns
555/// A pointer to a C string that must be freed with `tl_free_string`
556#[no_mangle]
557pub extern "C" fn tl_version() -> *mut c_char {
558    to_c_string(env!("CARGO_PKG_VERSION").to_string())
559}
560
561/// Check if a backend is available
562///
563/// # Parameters
564/// - `backend`: The backend name (e.g., "cpu", "parallel", "simd", "gpu")
565///
566/// # Returns
567/// 1 if available, 0 if not
568///
569/// # Safety
570/// The caller must ensure that `backend` is a valid null-terminated string.
571#[no_mangle]
572pub unsafe extern "C" fn tl_is_backend_available(backend: *const c_char) -> i32 {
573    let backend_str = match from_c_string(backend) {
574        Ok(s) => s,
575        Err(_) => return 0,
576    };
577
578    match Backend::from_str(&backend_str) {
579        Ok(b) => {
580            if b.is_available() {
581                1
582            } else {
583                0
584            }
585        }
586        Err(_) => 0,
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    /// Copy a possibly-NULL C string out of a result struct (NULL maps to `None`).
    unsafe fn read_c_str(p: *const std::os::raw::c_char) -> Option<String> {
        if p.is_null() {
            None
        } else {
            Some(CStr::from_ptr(p).to_string_lossy().into_owned())
        }
    }

    /// Check the success/failure invariant of a `TLGraphResult`, then free it.
    ///
    /// Whether the parser accepts the input is deliberately not asserted. What matters is that
    /// the FFI reports exactly one of "graph" or "error", and that the counts are zero on error.
    unsafe fn check_and_free_graph_result(raw: *mut TLGraphResult) {
        assert!(!raw.is_null());
        let result = &*raw;
        assert_ne!(
            result.graph_data.is_null(),
            result.error_message.is_null(),
            "exactly one of graph_data / error_message must be set"
        );
        if !result.error_message.is_null() {
            assert_eq!(result.tensor_count, 0);
            assert_eq!(result.node_count, 0);
        }
        tl_free_graph_result(raw);
    }

    #[test]
    fn test_compile_expr_success() {
        // Use variables in the expression
        let expr = CString::new("AND(pred1(x), pred2(x, y))").unwrap();

        unsafe {
            let raw = tl_compile_expr(expr.as_ptr());
            assert!(!raw.is_null());
            let result = &*raw;

            // Print debug info to make failures easier to diagnose
            if let Some(err) = read_c_str(result.error_message) {
                println!("Compilation error: {}", err);
            }
            let graph = read_c_str(result.graph_data);
            if let Some(graph) = &graph {
                // Truncate by characters, not bytes, so a multi-byte char cannot cause a panic.
                let preview: String = graph.chars().take(200).collect();
                println!("Graph: {}", preview);
                println!(
                    "Tensors: {}, Nodes: {}",
                    result.tensor_count, result.node_count
                );
            }

            assert!(result.error_message.is_null(), "Compilation should succeed");
            let graph = graph.expect("graph_data should be set on success");
            assert!(result.tensor_count > 0, "Should have at least one tensor");
            // Note: node_count might be 0 for simple expressions, so we only check tensor_count
            assert!(
                serde_json::from_str::<serde_json::Value>(&graph).is_ok(),
                "graph_data should be valid JSON"
            );

            // Release through the public API, exactly as a C caller would
            tl_free_graph_result(raw);
        }
    }

    #[test]
    fn test_compile_expr_invalid_syntax() {
        // A trailing comma with a missing second argument
        let expr = CString::new("AND(pred1(x), )").unwrap();
        unsafe {
            check_and_free_graph_result(tl_compile_expr(expr.as_ptr()));
        }
    }

    #[test]
    fn test_compile_expr_with_error() {
        // An unterminated quoted string
        let expr = CString::new("\"unclosed_string").unwrap();
        unsafe {
            check_and_free_graph_result(tl_compile_expr(expr.as_ptr()));
        }
    }

    #[test]
    fn test_compile_expr_null_pointer() {
        unsafe {
            let raw = tl_compile_expr(std::ptr::null());
            assert!(!raw.is_null());
            assert!((*raw).graph_data.is_null());
            assert_eq!(
                read_c_str((*raw).error_message).as_deref(),
                Some("Invalid expression: NULL pointer passed")
            );
            tl_free_graph_result(raw);
        }
    }

    #[test]
    fn test_execute_graph_rejects_null_inputs() {
        let cpu = CString::new("cpu").unwrap();
        let empty = CString::new("{}").unwrap();
        unsafe {
            let raw = tl_execute_graph(std::ptr::null(), cpu.as_ptr());
            assert!(!raw.is_null());
            assert!((*raw).output_data.is_null());
            assert_eq!(
                read_c_str((*raw).error_message).as_deref(),
                Some("Invalid graph JSON: NULL pointer passed")
            );
            tl_free_execution_result(raw);

            // The backend string is decoded before the graph JSON is parsed.
            let raw = tl_execute_graph(empty.as_ptr(), std::ptr::null());
            assert!(!raw.is_null());
            assert!((*raw).output_data.is_null());
            assert_eq!(
                read_c_str((*raw).error_message).as_deref(),
                Some("Invalid backend: NULL pointer passed")
            );
            tl_free_execution_result(raw);
        }
    }

    #[test]
    fn test_optimize_graph_rejects_bad_json() {
        let bad = CString::new("not json").unwrap();
        unsafe {
            let raw = tl_optimize_graph(bad.as_ptr(), 1);
            assert!(!raw.is_null());
            assert!((*raw).graph_data.is_null());
            let err = read_c_str((*raw).error_message).expect("error_message should be set");
            assert!(err.starts_with("JSON parse error: "), "unexpected error: {}", err);
            tl_free_optimization_result(raw);
        }
    }

    #[test]
    fn test_benchmark_compilation() {
        let expr = CString::new("AND(pred1(x), pred2(x, y))").unwrap();
        unsafe {
            let raw = tl_benchmark_compilation(expr.as_ptr(), 3);
            assert!(!raw.is_null());
            let result = &*raw;
            assert!(result.error_message.is_null());
            assert_eq!(result.iterations, 3);
            assert!(result.min_us <= result.max_us);
            assert!(result.std_dev_us >= 0.0);
            tl_free_benchmark_result(raw);

            let raw = tl_benchmark_compilation(std::ptr::null(), 3);
            assert!(!raw.is_null());
            assert_eq!(
                read_c_str((*raw).error_message).as_deref(),
                Some("Invalid expression: NULL pointer passed")
            );
            tl_free_benchmark_result(raw);
        }
    }

    #[test]
    fn test_free_functions_accept_null() {
        unsafe {
            tl_free_string(std::ptr::null_mut());
            tl_free_graph_result(std::ptr::null_mut());
            tl_free_execution_result(std::ptr::null_mut());
            tl_free_optimization_result(std::ptr::null_mut());
            tl_free_benchmark_result(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_version() {
        unsafe {
            let version = tl_version();
            assert!(!version.is_null());
            let version_str = CStr::from_ptr(version).to_str().unwrap();
            assert!(!version_str.is_empty());
            tl_free_string(version);
        }
    }

    #[test]
    fn test_backend_availability() {
        let cpu = CString::new("cpu").unwrap();
        unsafe {
            assert_eq!(tl_is_backend_available(cpu.as_ptr()), 1);
        }

        let invalid = CString::new("invalid_backend").unwrap();
        unsafe {
            assert_eq!(tl_is_backend_available(invalid.as_ptr()), 0);
            assert_eq!(tl_is_backend_available(std::ptr::null()), 0);
        }
    }
}