Skip to main content

simple_agents_ffi/
lib.rs

1//! C-compatible FFI bindings for SimpleAgents.
2
3use serde::Serialize;
4use serde_json::Value as JsonValue;
5use simple_agent_type::coercion::{CoercionFlag, CoercionResult};
6use simple_agent_type::message::{Message, Role};
7use simple_agent_type::prelude::{ApiKey, CompletionRequest, Provider, Result, SimpleAgentsError};
8use simple_agent_type::response::{CompletionResponse, FinishReason, Usage};
9use simple_agent_type::tool::{ToolCall, ToolType};
10use simple_agents_core::{
11    CompletionMode, CompletionOptions, CompletionOutcome, HealedJsonResponse, HealedSchemaResponse,
12    SimpleAgentsClient, SimpleAgentsClientBuilder,
13};
14use simple_agents_healing::schema::{Field as SchemaField, ObjectSchema, Schema, StreamAnnotation};
15use simple_agents_providers::anthropic::AnthropicProvider;
16use simple_agents_providers::openai::OpenAIProvider;
17use simple_agents_providers::openrouter::OpenRouterProvider;
18use std::cell::RefCell;
19use std::ffi::{CStr, CString};
20use std::os::raw::c_char;
21use std::panic::{catch_unwind, AssertUnwindSafe};
22use std::sync::{Arc, Mutex};
23
// Keep runtime ownership in the FFI layer so each client is self-contained.
// Every `SAClient` built by `sa_client_new_from_env` gets its own Tokio runtime (see
// `build_runtime`), which the FFI entry points use to `block_on` the async core client
// from synchronous C callers.
type Runtime = tokio::runtime::Runtime;
26
/// Per-client state behind an `SAClient` handle: a dedicated Tokio runtime plus the core client.
///
/// The runtime sits behind a `Mutex` that `sa_complete` and `sa_complete_messages_json` hold
/// for the whole `block_on` call, so concurrent calls on the same client are serialized.
struct FfiClient {
    // Locked around every `block_on`; a panic while locked poisons it and later calls fail
    // with "runtime lock poisoned".
    runtime: Mutex<Runtime>,
    // Provider-backed client built by `build_client`.
    client: SimpleAgentsClient,
}
31
/// Opaque client handle handed to C.
///
/// C code only ever holds a `*mut SAClient` obtained from `sa_client_new_from_env` and
/// releases it with `sa_client_free`; the field is private and never accessed from C.
// NOTE(review): `#[repr(C)]` gives C no usable layout here because `FfiClient` is not
// `repr(C)`; bindings should treat this type as opaque — confirm how headers are generated.
#[repr(C)]
pub struct SAClient {
    inner: FfiClient,
}
36
/// One chat message supplied by C, as an element of the array passed to
/// `sa_complete_messages_json`.
///
/// Every field is a pointer to a null-terminated UTF-8 string (or null where optional).
/// Field order is part of the C ABI.
#[repr(C)]
pub struct SAMessage {
    /// Required, non-empty; one of `user`, `assistant`, `system`, `tool` (ASCII case-insensitive).
    pub role: *const c_char,
    /// Required, non-empty message text.
    pub content: *const c_char,
    /// Optional participant name; null or empty means "no name".
    pub name: *const c_char,
    /// Required when `role` is `tool`; ignored for other roles. Null or empty means absent.
    pub tool_call_id: *const c_char,
}
44
/// JSON shape of the function portion of a tool call in the completion payload.
#[derive(Serialize)]
struct FfiToolCallFunction {
    /// Function name, copied from the core tool call.
    name: String,
    /// Argument string, copied verbatim from the core tool call.
    // NOTE(review): presumably JSON-encoded by the provider — confirm before documenting for C.
    arguments: String,
}
50
/// JSON shape of a single tool call requested by the model.
#[derive(Serialize)]
struct FfiToolCall {
    /// Tool call id, copied from the core `ToolCall`.
    id: String,
    /// Lowercase tool kind; `tool_type_to_string` currently only produces `"function"`.
    tool_type: String,
    /// Function name and arguments for this call.
    function: FfiToolCallFunction,
}
57
/// Token usage counts for a completion, copied field-for-field from the core `Usage`.
#[derive(Serialize)]
struct FfiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
64
/// Serializable copy of a `CoercionResult<JsonValue>` (healed JSON or schema-coerced value).
#[derive(Serialize)]
struct FfiHealingData {
    /// The resulting JSON value.
    value: JsonValue,
    /// Coercion flags copied from the `CoercionResult`.
    flags: Vec<CoercionFlag>,
    /// Confidence copied from the `CoercionResult`.
    // NOTE(review): presumably within 0.0..=1.0 — confirm against the healing crate.
    confidence: f32,
}
71
/// JSON payload returned (as a string) by `sa_complete_messages_json`.
///
/// Field declaration order is the key order of the serialized JSON object.
#[derive(Serialize)]
struct FfiCompletionResult {
    /// Response id, copied from the `CompletionResponse`.
    id: String,
    /// Model name, copied from the `CompletionResponse`.
    model: String,
    /// Role of the first choice's message; `"assistant"` when there are no choices.
    role: String,
    /// Response text content, if any.
    content: Option<String>,
    /// Tool calls from the first choice's message, if any.
    tool_calls: Option<Vec<FfiToolCall>>,
    /// First choice's finish reason: `stop`, `length`, `content_filter`, or `tool_calls`.
    finish_reason: Option<String>,
    /// Token usage for the request.
    usage: FfiUsage,
    /// Raw response text; currently populated with the same value as `content`.
    raw: Option<String>,
    /// Healed JSON parse; set in `healed_json` and `schema` modes.
    healed: Option<FfiHealingData>,
    /// Schema-coerced value; set only in `schema` mode.
    coerced: Option<FfiHealingData>,
}
85
thread_local! {
    /// Most recent error message for the calling thread: set when an FFI call fails,
    /// cleared when one succeeds, and consumed by `sa_last_error_message`.
    static LAST_ERROR: RefCell<Option<String>> = const { RefCell::new(None) };
}
89
90fn set_last_error(message: impl Into<String>) {
91    LAST_ERROR.with(|slot| {
92        *slot.borrow_mut() = Some(message.into());
93    });
94}
95
96fn clear_last_error() {
97    LAST_ERROR.with(|slot| {
98        *slot.borrow_mut() = None;
99    });
100}
101
102fn take_last_error() -> Option<String> {
103    LAST_ERROR.with(|slot| slot.borrow_mut().take())
104}
105
106fn build_runtime() -> Result<Runtime> {
107    Runtime::new().map_err(|e| SimpleAgentsError::Config(format!("Failed to build runtime: {e}")))
108}
109
110fn provider_from_env(provider_name: &str) -> Result<Arc<dyn Provider>> {
111    match provider_name {
112        "openai" => Ok(Arc::new(OpenAIProvider::from_env()?)),
113        "anthropic" => Ok(Arc::new(AnthropicProvider::from_env()?)),
114        "openrouter" => Ok(Arc::new(openrouter_from_env()?)),
115        _ => Err(SimpleAgentsError::Config(format!(
116            "Unknown provider '{provider_name}'"
117        ))),
118    }
119}
120
121fn openrouter_from_env() -> Result<OpenRouterProvider> {
122    let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
123        SimpleAgentsError::Config("OPENROUTER_API_KEY environment variable is required".to_string())
124    })?;
125    let api_key = ApiKey::new(api_key)?;
126    let base_url = std::env::var("OPENROUTER_API_BASE")
127        .unwrap_or_else(|_| OpenRouterProvider::DEFAULT_BASE_URL.to_string());
128    OpenRouterProvider::with_base_url(api_key, base_url)
129}
130
131unsafe fn cstr_to_string(ptr: *const c_char, field: &str) -> Result<String> {
132    if ptr.is_null() {
133        return Err(SimpleAgentsError::Config(format!("{field} cannot be null")));
134    }
135
136    let c_str = CStr::from_ptr(ptr);
137    let value = c_str
138        .to_str()
139        .map_err(|_| SimpleAgentsError::Config(format!("{field} must be valid UTF-8")))?;
140    if value.is_empty() {
141        return Err(SimpleAgentsError::Config(format!(
142            "{field} cannot be empty"
143        )));
144    }
145
146    Ok(value.to_string())
147}
148
149unsafe fn cstr_to_optional_string(ptr: *const c_char, field: &str) -> Result<Option<String>> {
150    if ptr.is_null() {
151        return Ok(None);
152    }
153    let c_str = CStr::from_ptr(ptr);
154    let value = c_str
155        .to_str()
156        .map_err(|_| SimpleAgentsError::Config(format!("{field} must be valid UTF-8")))?;
157    if value.is_empty() {
158        return Ok(None);
159    }
160    Ok(Some(value.to_string()))
161}
162
163fn build_client(provider: Arc<dyn Provider>) -> Result<SimpleAgentsClient> {
164    SimpleAgentsClientBuilder::new()
165        .with_provider(provider)
166        .build()
167}
168
169fn build_request_from_messages(
170    model: &str,
171    messages: Vec<Message>,
172    max_tokens: i32,
173    temperature: f32,
174    top_p: f32,
175) -> Result<CompletionRequest> {
176    let mut builder = CompletionRequest::builder().model(model).messages(messages);
177
178    if max_tokens > 0 {
179        builder = builder.max_tokens(max_tokens as u32);
180    }
181
182    if temperature >= 0.0 {
183        builder = builder.temperature(temperature);
184    }
185
186    if top_p >= 0.0 {
187        builder = builder.top_p(top_p);
188    }
189
190    builder.build()
191}
192
193fn build_request(
194    model: &str,
195    prompt: &str,
196    max_tokens: i32,
197    temperature: f32,
198) -> Result<CompletionRequest> {
199    build_request_from_messages(
200        model,
201        vec![Message::user(prompt)],
202        max_tokens,
203        temperature,
204        -1.0,
205    )
206}
207
208fn schema_aliases(value: Option<&JsonValue>) -> Vec<String> {
209    value
210        .and_then(JsonValue::as_array)
211        .map(|arr| {
212            arr.iter()
213                .filter_map(|v| v.as_str().map(str::to_string))
214                .collect()
215        })
216        .unwrap_or_default()
217}
218
219fn parse_schema_field(value: &JsonValue) -> Result<SchemaField> {
220    let name = value
221        .get("name")
222        .and_then(JsonValue::as_str)
223        .ok_or_else(|| SimpleAgentsError::Config("schema field missing `name`".to_string()))?;
224    let schema_value = value.get("schema").ok_or_else(|| {
225        SimpleAgentsError::Config(format!("schema field `{name}` missing `schema`"))
226    })?;
227
228    Ok(SchemaField {
229        name: name.to_string(),
230        schema: parse_schema(schema_value)?,
231        required: value
232            .get("required")
233            .and_then(JsonValue::as_bool)
234            .unwrap_or(true),
235        aliases: schema_aliases(value.get("aliases")),
236        default: None,
237        description: None,
238        stream_annotation: StreamAnnotation::Normal,
239    })
240}
241
242fn parse_schema(value: &JsonValue) -> Result<Schema> {
243    let kind = value
244        .get("kind")
245        .and_then(JsonValue::as_str)
246        .ok_or_else(|| SimpleAgentsError::Config("schema requires `kind`".to_string()))?
247        .to_lowercase();
248
249    match kind.as_str() {
250        "string" => Ok(Schema::String),
251        "int" => Ok(Schema::Int),
252        "uint" => Ok(Schema::UInt),
253        "float" => Ok(Schema::Float),
254        "bool" => Ok(Schema::Bool),
255        "null" => Ok(Schema::Null),
256        "any" => Ok(Schema::Any),
257        "array" => {
258            let elements = value.get("elements").ok_or_else(|| {
259                SimpleAgentsError::Config("array schema requires `elements`".to_string())
260            })?;
261            Ok(Schema::array(parse_schema(elements)?))
262        }
263        "union" => {
264            let variants = value
265                .get("variants")
266                .and_then(JsonValue::as_array)
267                .ok_or_else(|| {
268                    SimpleAgentsError::Config("union schema requires `variants` array".to_string())
269                })?;
270            let schemas = variants
271                .iter()
272                .map(parse_schema)
273                .collect::<Result<Vec<_>>>()?;
274            Ok(Schema::union(schemas))
275        }
276        "object" => {
277            let fields = value
278                .get("fields")
279                .and_then(JsonValue::as_array)
280                .ok_or_else(|| {
281                    SimpleAgentsError::Config("object schema requires `fields` array".to_string())
282                })?;
283            let converted = fields
284                .iter()
285                .map(parse_schema_field)
286                .collect::<Result<Vec<_>>>()?;
287            Ok(Schema::Object(ObjectSchema {
288                fields: converted,
289                allow_additional_fields: value
290                    .get("allow_additional_fields")
291                    .and_then(JsonValue::as_bool)
292                    .unwrap_or(false),
293            }))
294        }
295        other => Err(SimpleAgentsError::Config(format!(
296            "unsupported schema kind `{other}`"
297        ))),
298    }
299}
300
301fn completion_options(mode: Option<&str>, schema_json: Option<&str>) -> Result<CompletionOptions> {
302    let mode = match mode.map(|m| m.to_ascii_lowercase()) {
303        None => CompletionMode::Standard,
304        Some(m) if m.is_empty() || m == "standard" => CompletionMode::Standard,
305        Some(m) if m == "healed_json" => CompletionMode::HealedJson,
306        Some(m) if m == "schema" => {
307            let raw_schema = schema_json.ok_or_else(|| {
308                SimpleAgentsError::Config("mode `schema` requires `schema_json`".to_string())
309            })?;
310            let value: JsonValue = serde_json::from_str(raw_schema)
311                .map_err(|e| SimpleAgentsError::Config(format!("invalid `schema_json`: {e}")))?;
312            CompletionMode::CoercedSchema(parse_schema(&value)?)
313        }
314        Some(other) => {
315            return Err(SimpleAgentsError::Config(format!(
316                "unknown mode `{other}` (expected standard|healed_json|schema)"
317            )))
318        }
319    };
320
321    Ok(CompletionOptions { mode })
322}
323
324fn role_to_string(role: Role) -> String {
325    match role {
326        Role::User => "user".to_string(),
327        Role::Assistant => "assistant".to_string(),
328        Role::System => "system".to_string(),
329        Role::Tool => "tool".to_string(),
330    }
331}
332
333fn finish_reason_to_string(finish_reason: FinishReason) -> String {
334    match finish_reason {
335        FinishReason::Stop => "stop".to_string(),
336        FinishReason::Length => "length".to_string(),
337        FinishReason::ContentFilter => "content_filter".to_string(),
338        FinishReason::ToolCalls => "tool_calls".to_string(),
339    }
340}
341
342fn tool_type_to_string(tool_type: ToolType) -> String {
343    match tool_type {
344        ToolType::Function => "function".to_string(),
345    }
346}
347
348fn usage_to_ffi(usage: Usage) -> FfiUsage {
349    FfiUsage {
350        prompt_tokens: usage.prompt_tokens,
351        completion_tokens: usage.completion_tokens,
352        total_tokens: usage.total_tokens,
353    }
354}
355
356fn map_tool_calls(tool_calls: Option<Vec<ToolCall>>) -> Option<Vec<FfiToolCall>> {
357    tool_calls.map(|calls| {
358        calls
359            .into_iter()
360            .map(|call| FfiToolCall {
361                id: call.id,
362                tool_type: tool_type_to_string(call.tool_type),
363                function: FfiToolCallFunction {
364                    name: call.function.name,
365                    arguments: call.function.arguments,
366                },
367            })
368            .collect()
369    })
370}
371
372fn healing_data_from(result: CoercionResult<JsonValue>) -> FfiHealingData {
373    FfiHealingData {
374        value: result.value,
375        flags: result.flags,
376        confidence: result.confidence,
377    }
378}
379
380fn completion_result_from_response(
381    response: CompletionResponse,
382    healed: Option<FfiHealingData>,
383    coerced: Option<FfiHealingData>,
384) -> FfiCompletionResult {
385    let content = response.content().map(str::to_string);
386    let choice = response.choices.first();
387    let role = choice
388        .map(|c| role_to_string(c.message.role))
389        .unwrap_or_else(|| "assistant".to_string());
390    let finish_reason = choice.map(|c| finish_reason_to_string(c.finish_reason));
391    let tool_calls = choice.and_then(|c| c.message.tool_calls.clone());
392    let usage = response.usage;
393
394    FfiCompletionResult {
395        id: response.id.clone(),
396        model: response.model.clone(),
397        role: role.clone(),
398        content: content.clone(),
399        tool_calls: map_tool_calls(tool_calls),
400        finish_reason,
401        usage: usage_to_ffi(usage),
402        raw: content,
403        healed,
404        coerced,
405    }
406}
407
408fn parse_messages(messages: *const SAMessage, messages_len: usize) -> Result<Vec<Message>> {
409    if messages.is_null() {
410        return Err(SimpleAgentsError::Config(
411            "messages cannot be null".to_string(),
412        ));
413    }
414    if messages_len == 0 {
415        return Err(SimpleAgentsError::Config(
416            "messages cannot be empty".to_string(),
417        ));
418    }
419
420    let input = unsafe { std::slice::from_raw_parts(messages, messages_len) };
421    input
422        .iter()
423        .enumerate()
424        .map(|(idx, msg)| {
425            let role = unsafe { cstr_to_string(msg.role, &format!("messages[{idx}].role"))? }
426                .to_ascii_lowercase();
427            let content =
428                unsafe { cstr_to_string(msg.content, &format!("messages[{idx}].content"))? };
429            let name =
430                unsafe { cstr_to_optional_string(msg.name, &format!("messages[{idx}].name"))? };
431            let tool_call_id = unsafe {
432                cstr_to_optional_string(msg.tool_call_id, &format!("messages[{idx}].tool_call_id"))?
433            };
434
435            let parsed = match role.as_str() {
436                "user" => Message::user(content),
437                "assistant" => Message::assistant(content),
438                "system" => Message::system(content),
439                "tool" => {
440                    let call_id = tool_call_id.ok_or_else(|| {
441                        SimpleAgentsError::Config(format!(
442                            "messages[{idx}].tool_call_id is required for tool role"
443                        ))
444                    })?;
445                    Message::tool(content, call_id)
446                }
447                _ => {
448                    return Err(SimpleAgentsError::Config(format!(
449                        "messages[{idx}].role must be one of user|assistant|system|tool"
450                    )))
451                }
452            };
453
454            Ok(match name {
455                Some(name) => parsed.with_name(name),
456                None => parsed,
457            })
458        })
459        .collect()
460}
461
462fn ffi_result_string(result: Result<String>) -> *mut c_char {
463    match result {
464        Ok(value) => match CString::new(value) {
465            Ok(c_string) => {
466                clear_last_error();
467                c_string.into_raw()
468            }
469            Err(_) => {
470                set_last_error("Response contained an interior null byte".to_string());
471                std::ptr::null_mut()
472            }
473        },
474        Err(error) => {
475            set_last_error(error.to_string());
476            std::ptr::null_mut()
477        }
478    }
479}
480
481fn ffi_guard<T>(action: impl FnOnce() -> Result<T>) -> *mut c_char
482where
483    T: Into<String>,
484{
485    let result = catch_unwind(AssertUnwindSafe(action));
486    match result {
487        Ok(inner) => ffi_result_string(inner.map(Into::into)),
488        Err(_) => {
489            set_last_error("Panic occurred in FFI call".to_string());
490            std::ptr::null_mut()
491        }
492    }
493}
494
495/// Create a client from environment variables for a provider.
496///
497/// `provider_name` must be one of: "openai", "anthropic", "openrouter".
498///
499/// # Safety
500///
501/// The `provider_name` pointer must be a valid null-terminated C string or null.
502/// The returned pointer must be freed with `sa_client_free`.
503#[no_mangle]
504pub unsafe extern "C" fn sa_client_new_from_env(provider_name: *const c_char) -> *mut SAClient {
505    let result = catch_unwind(AssertUnwindSafe(|| -> Result<Box<SAClient>> {
506        let provider = cstr_to_string(provider_name, "provider_name")?;
507        let provider = provider_from_env(&provider)?;
508        let client = build_client(provider)?;
509        let runtime = build_runtime()?;
510
511        Ok(Box::new(SAClient {
512            inner: FfiClient {
513                runtime: Mutex::new(runtime),
514                client,
515            },
516        }))
517    }));
518
519    match result {
520        Ok(Ok(client)) => {
521            clear_last_error();
522            Box::into_raw(client)
523        }
524        Ok(Err(error)) => {
525            set_last_error(error.to_string());
526            std::ptr::null_mut()
527        }
528        Err(_) => {
529            set_last_error("Panic occurred in sa_client_new_from_env".to_string());
530            std::ptr::null_mut()
531        }
532    }
533}
534
535/// Free a client created by `sa_client_new_from_env`.
536///
537/// # Safety
538///
539/// The `client` pointer must be null or a valid pointer returned by `sa_client_new_from_env`.
540/// After calling this function, the pointer is no longer valid and must not be used.
541#[no_mangle]
542pub unsafe extern "C" fn sa_client_free(client: *mut SAClient) {
543    if client.is_null() {
544        return;
545    }
546
547    drop(Box::from_raw(client));
548}
549
550/// Execute a completion request with a single user prompt.
551///
552/// Use `max_tokens <= 0` to omit, and `temperature < 0.0` to omit.
553///
554/// # Safety
555///
556/// The `client` pointer must be a valid pointer returned by `sa_client_new_from_env`.
557/// The `model` and `prompt` pointers must be valid null-terminated C strings.
558/// The returned pointer must be freed with `sa_string_free`.
559#[no_mangle]
560pub unsafe extern "C" fn sa_complete(
561    client: *mut SAClient,
562    model: *const c_char,
563    prompt: *const c_char,
564    max_tokens: i32,
565    temperature: f32,
566) -> *mut c_char {
567    if client.is_null() {
568        set_last_error("client cannot be null".to_string());
569        return std::ptr::null_mut();
570    }
571
572    ffi_guard(|| {
573        let model = cstr_to_string(model, "model")?;
574        let prompt = cstr_to_string(prompt, "prompt")?;
575        let request = build_request(&model, &prompt, max_tokens, temperature)?;
576
577        let client = &(*client).inner;
578        let runtime = client
579            .runtime
580            .lock()
581            .map_err(|_| SimpleAgentsError::Config("runtime lock poisoned".to_string()))?;
582        let outcome = runtime.block_on(
583            client
584                .client
585                .complete(&request, CompletionOptions::default()),
586        )?;
587        let response = match outcome {
588            CompletionOutcome::Response(response) => response,
589            CompletionOutcome::Stream(_) => {
590                return Err(SimpleAgentsError::Config(
591                    "streaming response returned from complete".to_string(),
592                ))
593            }
594            CompletionOutcome::HealedJson(_) => {
595                return Err(SimpleAgentsError::Config(
596                    "healed json response returned from complete".to_string(),
597                ))
598            }
599            CompletionOutcome::CoercedSchema(_) => {
600                return Err(SimpleAgentsError::Config(
601                    "schema response returned from complete".to_string(),
602                ))
603            }
604        };
605
606        Ok(response.content().unwrap_or_default().to_string())
607    })
608}
609
610/// Execute a completion request with full message input and return a structured JSON payload.
611///
612/// Use `max_tokens <= 0`, `temperature < 0.0`, or `top_p < 0.0` to omit those options.
613/// `mode` supports `standard`, `healed_json`, and `schema`; when mode is `schema`, `schema_json`
614/// must be a JSON object with the internal schema shape.
615///
616/// # Safety
617///
618/// - `client` must be a pointer returned by `sa_client_new_from_env`.
619/// - `model` must be a valid null-terminated C string.
620/// - `messages` must point to `messages_len` valid `SAMessage` values.
621/// - Returned string must be freed with `sa_string_free`.
622#[no_mangle]
623pub unsafe extern "C" fn sa_complete_messages_json(
624    client: *mut SAClient,
625    model: *const c_char,
626    messages: *const SAMessage,
627    messages_len: usize,
628    max_tokens: i32,
629    temperature: f32,
630    top_p: f32,
631    mode: *const c_char,
632    schema_json: *const c_char,
633) -> *mut c_char {
634    if client.is_null() {
635        set_last_error("client cannot be null".to_string());
636        return std::ptr::null_mut();
637    }
638
639    ffi_guard(|| {
640        let model = cstr_to_string(model, "model")?;
641        let messages = parse_messages(messages, messages_len)?;
642        let request =
643            build_request_from_messages(&model, messages, max_tokens, temperature, top_p)?;
644
645        let mode = cstr_to_optional_string(mode, "mode")?;
646        let schema_json = cstr_to_optional_string(schema_json, "schema_json")?;
647        let options = completion_options(mode.as_deref(), schema_json.as_deref())?;
648
649        let client = &(*client).inner;
650        let runtime = client
651            .runtime
652            .lock()
653            .map_err(|_| SimpleAgentsError::Config("runtime lock poisoned".to_string()))?;
654        let outcome = runtime.block_on(client.client.complete(&request, options))?;
655
656        let payload = match outcome {
657            CompletionOutcome::Response(response) => {
658                completion_result_from_response(response, None, None)
659            }
660            CompletionOutcome::HealedJson(HealedJsonResponse { response, parsed }) => {
661                completion_result_from_response(response, Some(healing_data_from(parsed)), None)
662            }
663            CompletionOutcome::CoercedSchema(HealedSchemaResponse {
664                response,
665                parsed,
666                coerced,
667            }) => completion_result_from_response(
668                response,
669                Some(healing_data_from(parsed)),
670                Some(healing_data_from(coerced)),
671            ),
672            CompletionOutcome::Stream(_) => {
673                return Err(SimpleAgentsError::Config(
674                    "streaming mode is not supported via sa_complete_messages_json".to_string(),
675                ))
676            }
677        };
678
679        serde_json::to_string(&payload)
680            .map_err(|e| SimpleAgentsError::Config(format!("failed to serialize result: {e}")))
681    })
682}
683
684/// Get the last error message for the current thread.
685///
686/// Returns null if there is no error. Caller must free the string.
687#[no_mangle]
688pub extern "C" fn sa_last_error_message() -> *mut c_char {
689    match take_last_error() {
690        Some(message) => match CString::new(message) {
691            Ok(c_string) => c_string.into_raw(),
692            Err(_) => std::ptr::null_mut(),
693        },
694        None => std::ptr::null_mut(),
695    }
696}
697
/// Free a string returned by SimpleAgents FFI.
///
/// Passing null is a no-op.
///
/// # Safety
///
/// The `value` pointer must be null or a valid pointer returned by a SimpleAgents FFI function.
/// After calling this function, the pointer is no longer valid and must not be used.
#[no_mangle]
pub unsafe extern "C" fn sa_string_free(value: *mut c_char) {
    if !value.is_null() {
        // SAFETY: the strings this library hands out are created with `CString::into_raw`,
        // so `CString::from_raw` reclaims the original allocation, which is then dropped.
        let _reclaimed = CString::from_raw(value);
    }
}