1use simple_agent_type::message::Message;
4use simple_agent_type::prelude::{ApiKey, CompletionRequest, Provider, Result, SimpleAgentsError};
5use simple_agents_core::{
6 CompletionOptions, CompletionOutcome, SimpleAgentsClient, SimpleAgentsClientBuilder,
7};
8use simple_agents_providers::anthropic::AnthropicProvider;
9use simple_agents_providers::openai::OpenAIProvider;
10use simple_agents_providers::openrouter::OpenRouterProvider;
11use std::cell::RefCell;
12use std::ffi::{CStr, CString};
13use std::os::raw::c_char;
14use std::panic::{catch_unwind, AssertUnwindSafe};
15use std::sync::{Arc, Mutex};
16
/// Shorthand for the tokio runtime type used to drive async client calls
/// from the synchronous FFI entry points.
type Runtime = tokio::runtime::Runtime;
18
/// Rust-side state behind an [`SAClient`] handle.
struct FfiClient {
    /// Runtime used to block on the client's futures. `sa_complete` holds the
    /// lock for the whole `block_on` call, so completions issued through one
    /// handle run one at a time.
    runtime: Mutex<Runtime>,
    /// The client that actually performs completion requests.
    client: SimpleAgentsClient,
}
23
/// Handle type exposed across the C ABI.
///
/// C callers only ever receive a `*mut SAClient` (from
/// `sa_client_new_from_env`) and hand it back to `sa_complete` /
/// `sa_client_free`; the single field is private to this module.
#[repr(C)]
pub struct SAClient {
    /// The Rust-side client state owned by this handle.
    inner: FfiClient,
}
28
thread_local! {
    /// Most recent error message recorded on the current thread, or `None`
    /// when no error is pending. Written by `set_last_error` /
    /// `clear_last_error` and consumed by `take_last_error`.
    static LAST_ERROR: RefCell<Option<String>> = const { RefCell::new(None) };
}
32
33fn set_last_error(message: impl Into<String>) {
34 LAST_ERROR.with(|slot| {
35 *slot.borrow_mut() = Some(message.into());
36 });
37}
38
39fn clear_last_error() {
40 LAST_ERROR.with(|slot| {
41 *slot.borrow_mut() = None;
42 });
43}
44
45fn take_last_error() -> Option<String> {
46 LAST_ERROR.with(|slot| slot.borrow_mut().take())
47}
48
49fn build_runtime() -> Result<Runtime> {
50 Runtime::new().map_err(|e| SimpleAgentsError::Config(format!("Failed to build runtime: {e}")))
51}
52
53fn provider_from_env(provider_name: &str) -> Result<Arc<dyn Provider>> {
54 match provider_name {
55 "openai" => Ok(Arc::new(OpenAIProvider::from_env()?)),
56 "anthropic" => Ok(Arc::new(AnthropicProvider::from_env()?)),
57 "openrouter" => Ok(Arc::new(openrouter_from_env()?)),
58 _ => Err(SimpleAgentsError::Config(format!(
59 "Unknown provider '{provider_name}'"
60 ))),
61 }
62}
63
64fn openrouter_from_env() -> Result<OpenRouterProvider> {
65 let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
66 SimpleAgentsError::Config("OPENROUTER_API_KEY environment variable is required".to_string())
67 })?;
68 let api_key = ApiKey::new(api_key)?;
69 let base_url = std::env::var("OPENROUTER_API_BASE")
70 .unwrap_or_else(|_| OpenRouterProvider::DEFAULT_BASE_URL.to_string());
71 OpenRouterProvider::with_base_url(api_key, base_url)
72}
73
74unsafe fn cstr_to_string(ptr: *const c_char, field: &str) -> Result<String> {
75 if ptr.is_null() {
76 return Err(SimpleAgentsError::Config(format!("{field} cannot be null")));
77 }
78
79 let c_str = CStr::from_ptr(ptr);
80 let value = c_str
81 .to_str()
82 .map_err(|_| SimpleAgentsError::Config(format!("{field} must be valid UTF-8")))?;
83 if value.is_empty() {
84 return Err(SimpleAgentsError::Config(format!(
85 "{field} cannot be empty"
86 )));
87 }
88
89 Ok(value.to_string())
90}
91
92fn build_client(provider: Arc<dyn Provider>) -> Result<SimpleAgentsClient> {
93 SimpleAgentsClientBuilder::new()
94 .with_provider(provider)
95 .build()
96}
97
98fn build_request(
99 model: &str,
100 prompt: &str,
101 max_tokens: i32,
102 temperature: f32,
103) -> Result<CompletionRequest> {
104 let mut builder = CompletionRequest::builder()
105 .model(model)
106 .message(Message::user(prompt));
107
108 if max_tokens > 0 {
109 builder = builder.max_tokens(max_tokens as u32);
110 }
111
112 if temperature >= 0.0 {
113 builder = builder.temperature(temperature);
114 }
115
116 builder.build()
117}
118
119fn ffi_result_string(result: Result<String>) -> *mut c_char {
120 match result {
121 Ok(value) => match CString::new(value) {
122 Ok(c_string) => {
123 clear_last_error();
124 c_string.into_raw()
125 }
126 Err(_) => {
127 set_last_error("Response contained an interior null byte".to_string());
128 std::ptr::null_mut()
129 }
130 },
131 Err(error) => {
132 set_last_error(error.to_string());
133 std::ptr::null_mut()
134 }
135 }
136}
137
138fn ffi_guard<T>(action: impl FnOnce() -> Result<T>) -> *mut c_char
139where
140 T: Into<String>,
141{
142 let result = catch_unwind(AssertUnwindSafe(action));
143 match result {
144 Ok(inner) => ffi_result_string(inner.map(Into::into)),
145 Err(_) => {
146 set_last_error("Panic occurred in FFI call".to_string());
147 std::ptr::null_mut()
148 }
149 }
150}
151
152#[no_mangle]
161pub unsafe extern "C" fn sa_client_new_from_env(provider_name: *const c_char) -> *mut SAClient {
162 let result = catch_unwind(AssertUnwindSafe(|| -> Result<Box<SAClient>> {
163 let provider = cstr_to_string(provider_name, "provider_name")?;
164 let provider = provider_from_env(&provider)?;
165 let client = build_client(provider)?;
166 let runtime = build_runtime()?;
167
168 Ok(Box::new(SAClient {
169 inner: FfiClient {
170 runtime: Mutex::new(runtime),
171 client,
172 },
173 }))
174 }));
175
176 match result {
177 Ok(Ok(client)) => {
178 clear_last_error();
179 Box::into_raw(client)
180 }
181 Ok(Err(error)) => {
182 set_last_error(error.to_string());
183 std::ptr::null_mut()
184 }
185 Err(_) => {
186 set_last_error("Panic occurred in sa_client_new_from_env".to_string());
187 std::ptr::null_mut()
188 }
189 }
190}
191
192#[no_mangle]
199pub unsafe extern "C" fn sa_client_free(client: *mut SAClient) {
200 if client.is_null() {
201 return;
202 }
203
204 drop(Box::from_raw(client));
205}
206
207#[no_mangle]
217pub unsafe extern "C" fn sa_complete(
218 client: *mut SAClient,
219 model: *const c_char,
220 prompt: *const c_char,
221 max_tokens: i32,
222 temperature: f32,
223) -> *mut c_char {
224 if client.is_null() {
225 set_last_error("client cannot be null".to_string());
226 return std::ptr::null_mut();
227 }
228
229 ffi_guard(|| {
230 let model = cstr_to_string(model, "model")?;
231 let prompt = cstr_to_string(prompt, "prompt")?;
232 let request = build_request(&model, &prompt, max_tokens, temperature)?;
233
234 let client = &(*client).inner;
235 let runtime = client
236 .runtime
237 .lock()
238 .map_err(|_| SimpleAgentsError::Config("runtime lock poisoned".to_string()))?;
239 let outcome = runtime.block_on(
240 client
241 .client
242 .complete(&request, CompletionOptions::default()),
243 )?;
244 let response = match outcome {
245 CompletionOutcome::Response(response) => response,
246 CompletionOutcome::Stream(_) => {
247 return Err(SimpleAgentsError::Config(
248 "streaming response returned from complete".to_string(),
249 ))
250 }
251 CompletionOutcome::HealedJson(_) => {
252 return Err(SimpleAgentsError::Config(
253 "healed json response returned from complete".to_string(),
254 ))
255 }
256 CompletionOutcome::CoercedSchema(_) => {
257 return Err(SimpleAgentsError::Config(
258 "schema response returned from complete".to_string(),
259 ))
260 }
261 };
262
263 Ok(response.content().unwrap_or_default().to_string())
264 })
265}
266
267#[no_mangle]
271pub extern "C" fn sa_last_error_message() -> *mut c_char {
272 match take_last_error() {
273 Some(message) => match CString::new(message) {
274 Ok(c_string) => c_string.into_raw(),
275 Err(_) => std::ptr::null_mut(),
276 },
277 None => std::ptr::null_mut(),
278 }
279}
280
/// Releases a string previously returned by this library. Passing null is a
/// no-op.
///
/// # Safety
///
/// A non-null `value` must have been produced by `CString::into_raw` inside
/// this library (for example by `sa_complete` or `sa_last_error_message`),
/// must not have been freed already, and must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn sa_string_free(value: *mut c_char) {
    if !value.is_null() {
        // SAFETY: per the contract above, `value` came from
        // `CString::into_raw`, so reclaiming it here is sound.
        let owned = CString::from_raw(value);
        drop(owned);
    }
}