use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
use openai_oxide::OpenAI;
use openai_oxide::config::ClientConfig;
use openai_oxide::types::responses::*;
use serde_json::Value;

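/// Records an OTel span for a Responses API call: token usage (including
/// cached input tokens), truncated input/output previews, and any tool calls.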
#[cfg(feature = "telemetry")]
fn record_otel_usage(response: &Response, model: &str, messages: &[Message]) {
    let pt = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0);
    let ct = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0);
    let cached = response
        .usage
        .as_ref()
        .and_then(|u| u.input_tokens_details.as_ref())
        .and_then(|d| d.cached_tokens)
        .unwrap_or(0);

    let input = last_user_content(messages, 500);
    let output_text = response.output_text();
    let output = truncate_str(&output_text, 500);
    let tool_calls: Vec<(String, String)> = response
        .function_calls()
        .iter()
        .map(|fc| (fc.name.clone(), fc.arguments.to_string()))
        .collect();

    crate::telemetry::record_llm_span(
        "oxide.responses.api",
        model,
        &input,
        &output,
        &tool_calls,
        &crate::telemetry::LlmUsage {
            prompt_tokens: pt,
            completion_tokens: ct,
            cached_tokens: cached,
            response_model: response.model.clone(),
        },
    );
}

#[cfg(not(feature = "telemetry"))]
fn record_otel_usage(_response: &Response, _model: &str, _messages: &[Message]) {}

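/// Returns the content of the most recent user or tool message, truncated for
/// use as a span attribute.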
#[cfg(feature = "telemetry")]
fn last_user_content(messages: &[Message], max_len: usize) -> String {
    messages
        .iter()
        .rev()
        .find(|m| matches!(m.role, Role::User | Role::Tool))
        .map(|m| truncate_str(&m.content, max_len))
        .unwrap_or_default()
}

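/// Truncates `s` to at most `max_len` bytes, appending `...` when cut, backing
/// off to a char boundary so multi-byte UTF-8 never causes a slice panic.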
#[cfg(feature = "telemetry")]
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // Back off to a char boundary so slicing multi-byte UTF-8 never panics.
    let end = (0..=max_len).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0);
    format!("{}...", &s[..end])
}

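/// LLM client backed by the OpenAI Responses API via `openai-oxide`.
/// Requests go over HTTP by default; with the `oxide-ws` feature enabled, a
/// WebSocket session is opened lazily and HTTP remains the fallback transport.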
pub struct OxideClient {
    client: OpenAI,
    pub(crate) model: String,
    pub(crate) temperature: Option<f64>,
    pub(crate) max_tokens: Option<u32>,
    #[cfg(feature = "oxide-ws")]
    ws: tokio::sync::Mutex<Option<openai_oxide::websocket::WsSession>>,
    #[cfg(feature = "oxide-ws")]
    ws_enabled: std::sync::atomic::AtomicBool,
}

impl OxideClient {
    pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
        let api_key = config
            .api_key
            .clone()
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .unwrap_or_else(|| {
                // Custom base URLs (e.g. local endpoints) often accept any key.
                if config.base_url.is_some() {
                    "dummy_key".into()
                } else {
                    "".into()
                }
            });

        if api_key.is_empty() {
            return Err(SgrError::Schema("No API key for oxide client".into()));
        }

        let mut client_config = ClientConfig::new(&api_key);
        if let Some(ref url) = config.base_url {
            client_config = client_config.base_url(url.clone());
        }
        config.apply_headers(&mut client_config);

        Ok(Self {
            client: OpenAI::with_config(client_config),
            model: config.model.clone(),
            temperature: Some(config.temp),
            max_tokens: config.max_tokens,
            #[cfg(feature = "oxide-ws")]
            ws: tokio::sync::Mutex::new(None),
            #[cfg(feature = "oxide-ws")]
            ws_enabled: std::sync::atomic::AtomicBool::new(false),
        })
    }

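    /// Enables the WebSocket transport. The connection itself is established
    /// lazily on the first request; if it fails, the client falls back to HTTP.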
    #[cfg(feature = "oxide-ws")]
    pub async fn connect_ws(&self) -> Result<(), SgrError> {
        self.ws_enabled
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!(model = %self.model, "oxide WebSocket enabled (lazy connect)");
        Ok(())
    }

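    /// Sends the request over the WebSocket session when enabled (connecting
    /// lazily on first use); any WS failure downgrades to the HTTP path.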
    async fn send_request_auto(
        &self,
        request: ResponseCreateRequest,
    ) -> Result<Response, SgrError> {
        #[cfg(feature = "oxide-ws")]
        if self.ws_enabled.load(std::sync::atomic::Ordering::Relaxed) {
            let mut ws_guard = self.ws.lock().await;

            // Connect lazily on first use; on failure, disable WS so later
            // calls go straight to HTTP.
            if ws_guard.is_none() {
                match self.client.ws_session().await {
                    Ok(session) => {
                        tracing::info!(model = %self.model, "oxide WS connected (lazy)");
                        *ws_guard = Some(session);
                    }
                    Err(e) => {
                        tracing::warn!("oxide WS connect failed, using HTTP: {e}");
                        self.ws_enabled
                            .store(false, std::sync::atomic::Ordering::Relaxed);
                    }
                }
            }

            if let Some(ref mut session) = *ws_guard {
                match session.send(request.clone()).await {
                    Ok(response) => return Ok(response),
                    Err(e) => {
                        tracing::warn!("oxide WS send failed, falling back to HTTP: {e}");
                        // Drop the broken session; the next call will reconnect.
                        *ws_guard = None;
                    }
                }
            }
        }

        // HTTP path: the default transport and the fallback for WS failures.
        self.client
            .responses()
            .create(request)
            .await
            .map_err(|e| SgrError::Api {
                status: 0,
                body: e.to_string(),
            })
    }

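    /// Builds a Responses API request from chat history. With a
    /// `previous_response_id` it delegates to the item-based builder for
    /// stateful chaining; otherwise it flattens the messages into role-tagged
    /// input items. An optional JSON schema enables strict structured output.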
    pub(crate) fn build_request(
        &self,
        messages: &[Message],
        schema: Option<&Value>,
        previous_response_id: Option<&str>,
    ) -> ResponseCreateRequest {
        // Chained turns use the item-based form so tool outputs can reference
        // their originating call IDs.
        if previous_response_id.is_some() {
            return self.build_request_items(messages, previous_response_id);
        }

        let mut input_items = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::System,
                        content: Value::String(msg.content.clone()),
                    });
                }
                Role::User => {
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::User,
                        content: Value::String(msg.content.clone()),
                    });
                }
                Role::Assistant => {
                    // Inline a short preview of each tool call so the model can
                    // see what the assistant previously invoked.
                    let mut content = msg.content.clone();
                    if !msg.tool_calls.is_empty() {
                        for tc in &msg.tool_calls {
                            let args = tc.arguments.to_string();
                            // Preview at most 200 bytes, backing off to a char
                            // boundary so slicing multi-byte UTF-8 cannot panic.
                            let mut end = args.len().min(200);
                            while !args.is_char_boundary(end) {
                                end -= 1;
                            }
                            content.push_str(&format!("\n→ {}({})", tc.name, &args[..end]));
                        }
                    }
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::Assistant,
                        content: Value::String(content),
                    });
                }
                Role::Tool => {
                    // Without chaining there is no call_id to reference, so
                    // tool results are echoed back as user content.
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::User,
                        content: Value::String(msg.content.clone()),
                    });
                }
            }
        }

        let mut req = ResponseCreateRequest::new(&self.model);

        // A lone user message can be sent as plain text input.
        if input_items.len() == 1 && input_items[0].role == openai_oxide::types::common::Role::User
        {
            if let Some(text) = input_items[0].content.as_str() {
                req = req.input(text);
            } else {
                req.input = Some(ResponseInput::Messages(input_items));
            }
        } else if !input_items.is_empty() {
            req.input = Some(ResponseInput::Messages(input_items));
        }

        // Only send temperature when it differs from the default of 1.0.
        if let Some(temp) = self.temperature
            && (temp - 1.0).abs() > f64::EPSILON
        {
            req = req.temperature(temp);
        }

        if let Some(max) = self.max_tokens {
            req = req.max_output_tokens(max as i64);
        }

        // A schema switches the response into strict structured-output mode.
        if let Some(schema_val) = schema {
            req = req.text(ResponseTextConfig {
                format: Some(ResponseTextFormat::JsonSchema {
                    name: "sgr_response".into(),
                    description: None,
                    schema: Some(schema_val.clone()),
                    strict: Some(true),
                }),
                verbosity: None,
            });
        }

        req
    }

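    /// Builds a request from raw input items for stateful chaining: tool
    /// results become `function_call_output` items keyed by `call_id`, and
    /// other messages become typed `message` items.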
    fn build_request_items(
        &self,
        messages: &[Message],
        previous_response_id: Option<&str>,
    ) -> ResponseCreateRequest {
        let mut items: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::Tool => {
                    // Tool results reference the originating call via call_id.
                    if let Some(ref call_id) = msg.tool_call_id {
                        items.push(serde_json::json!({
                            "type": "function_call_output",
                            "call_id": call_id,
                            "output": msg.content
                        }));
                    }
                }
                Role::System => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "system",
                        "content": msg.content
                    }));
                }
                Role::User => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "user",
                        "content": msg.content
                    }));
                }
                Role::Assistant => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "assistant",
                        "content": msg.content
                    }));
                }
            }
        }

        let mut req = ResponseCreateRequest::new(&self.model);
        if !items.is_empty() {
            req.input = Some(ResponseInput::Items(items));
        }

        if let Some(temp) = self.temperature
            && (temp - 1.0).abs() > f64::EPSILON
        {
            req = req.temperature(temp);
        }
        if let Some(max) = self.max_tokens {
            req = req.max_output_tokens(max as i64);
        }

        if let Some(prev_id) = previous_response_id {
            req = req.previous_response_id(prev_id);
        }

        req
    }

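    /// One stateful tool-calling turn: sends the tools with strict schemas,
    /// stores the response server-side, logs usage and cache-hit stats, and
    /// returns the extracted tool calls plus the response ID for chaining.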
    async fn tools_call_stateful_impl(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        let mut req = self.build_request(messages, None, previous_response_id);
        // Store server-side so the returned response ID can be chained next turn.
        req = req.store(true);

        // Tighten each tool schema so strict function calling validates arguments.
        let response_tools: Vec<ResponseTool> = tools
            .iter()
            .map(|t| {
                let mut params = t.parameters.clone();
                openai_oxide::parsing::ensure_strict(&mut params);
                ResponseTool::Function {
                    name: t.name.clone(),
                    description: if t.description.is_empty() {
                        None
                    } else {
                        Some(t.description.clone())
                    },
                    parameters: Some(params),
                    strict: Some(true),
                }
            })
            .collect();
        req = req.tools(response_tools);

        let response = self.send_request_auto(req).await?;

        let response_id = response.id.clone();
        record_otel_usage(&response, &self.model, messages);

        let input_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens)
            .unwrap_or(0);
        let cached_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens_details.as_ref())
            .and_then(|d| d.cached_tokens)
            .unwrap_or(0);

        let chained = previous_response_id.is_some();
        let cache_pct = if input_tokens > 0 {
            (cached_tokens * 100) / input_tokens
        } else {
            0
        };

        tracing::info!(
            model = %response.model,
            response_id = %response_id,
            input_tokens,
            cached_tokens,
            cache_pct,
            chained,
            "oxide.tools_call_stateful"
        );

        if cached_tokens > 0 {
            eprintln!(
                " 💰 {}in/{}out (cached: {}, {}%)",
                input_tokens,
                response
                    .usage
                    .as_ref()
                    .and_then(|u| u.output_tokens)
                    .unwrap_or(0),
                cached_tokens,
                cache_pct
            );
        } else {
            eprintln!(
                " 💰 {}in/{}out",
                input_tokens,
                response
                    .usage
                    .as_ref()
                    .and_then(|u| u.output_tokens)
                    .unwrap_or(0)
            );
        }

        Self::check_truncation(&response)?;
        Ok((Self::extract_tool_calls(&response), Some(response_id)))
    }

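    /// Fails with `SgrError::MaxOutputTokens` (carrying the partial text) when
    /// the response was cut off by the output-token limit.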
    fn check_truncation(response: &Response) -> Result<(), SgrError> {
        let is_incomplete = response
            .status
            .as_deref()
            .is_some_and(|s| s == "incomplete");
        let is_max_tokens = response
            .incomplete_details
            .as_ref()
            .and_then(|d| d.reason.as_deref())
            .is_some_and(|r| r == "max_output_tokens");

        if is_incomplete && is_max_tokens {
            return Err(SgrError::MaxOutputTokens {
                partial_content: response.output_text(),
            });
        }
        Ok(())
    }

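    /// Maps the response's function calls into this crate's `ToolCall` type.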
    fn extract_tool_calls(response: &Response) -> Vec<ToolCall> {
        response
            .function_calls()
            .into_iter()
            .map(|fc| ToolCall {
                id: fc.call_id,
                name: fc.name,
                arguments: fc.arguments,
            })
            .collect()
    }
}

#[async_trait::async_trait]
impl LlmClient for OxideClient {
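    /// Strict structured-output call: returns the parsed JSON (when the raw
    /// text parses), any tool calls, and the raw response text.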
    async fn structured_call(
        &self,
        messages: &[Message],
        schema: &Value,
    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
        // Schemas already marked additionalProperties:false pass through as-is;
        // anything else is tightened to satisfy strict structured output.
        let strict_schema =
            if schema.get("additionalProperties").and_then(|v| v.as_bool()) == Some(false) {
                schema.clone()
            } else {
                let mut s = schema.clone();
                openai_oxide::parsing::ensure_strict(&mut s);
                s
            };

        let mut req = self.build_request(messages, Some(&strict_schema), None);
        // Store server-side so later turns can chain via previous_response_id.
        req = req.store(true);

        let span = tracing::info_span!(
            "oxide.responses.create",
            model = %self.model,
            method = "structured_call",
        );
        let _enter = span.enter();

        if std::env::var("SGR_DEBUG_SCHEMA").is_ok()
            && let Some(ref text_cfg) = req.text
        {
            eprintln!(
                "[sgr] Schema: {}",
                serde_json::to_string(text_cfg).unwrap_or_default()
            );
        }

        let response = self.send_request_auto(req).await?;

        record_otel_usage(&response, &self.model, messages);

        Self::check_truncation(&response)?;

        let raw_text = response.output_text();
        if std::env::var("SGR_DEBUG").is_ok() {
            // Preview at most 500 bytes, backing off to a char boundary so
            // slicing multi-byte UTF-8 cannot panic.
            let mut end = raw_text.len().min(500);
            while !raw_text.is_char_boundary(end) {
                end -= 1;
            }
            eprintln!("[sgr] Raw response: {}", &raw_text[..end]);
        }
        let tool_calls = Self::extract_tool_calls(&response);
        let parsed = serde_json::from_str::<Value>(&raw_text).ok();

        let input_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens)
            .unwrap_or(0);
        let cached_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens_details.as_ref())
            .and_then(|d| d.cached_tokens)
            .unwrap_or(0);
        let cache_pct = if input_tokens > 0 {
            (cached_tokens * 100) / input_tokens
        } else {
            0
        };

        {
            let output_tokens = response
                .usage
                .as_ref()
                .and_then(|u| u.output_tokens)
                .unwrap_or(0);
            if cached_tokens > 0 {
                eprintln!(
                    " 💰 {}in/{}out (cached: {}, {}%)",
                    input_tokens, output_tokens, cached_tokens, cache_pct
                );
            } else {
                eprintln!(" 💰 {}in/{}out", input_tokens, output_tokens);
            }
        }

        Ok((parsed, tool_calls, raw_text))
    }

    async fn tools_call(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
    ) -> Result<Vec<ToolCall>, SgrError> {
        let mut req = self.build_request(messages, None, None);
        req = req.store(true);

        let response_tools: Vec<ResponseTool> = tools
            .iter()
            .map(|t| ResponseTool::Function {
                name: t.name.clone(),
                description: if t.description.is_empty() {
                    None
                } else {
                    Some(t.description.clone())
                },
                parameters: Some(t.parameters.clone()),
                strict: None,
            })
            .collect();
        req = req.tools(response_tools);

        // Force the model to call a tool rather than answer in plain text.
        req = req.tool_choice(openai_oxide::types::responses::ResponseToolChoice::Mode(
            "required".into(),
        ));

        let response = self.send_request_auto(req).await?;

        record_otel_usage(&response, &self.model, messages);
        Self::check_truncation(&response)?;

        let input_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens)
            .unwrap_or(0);
        let cached_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens_details.as_ref())
            .and_then(|d| d.cached_tokens)
            .unwrap_or(0);
        let cache_pct = if input_tokens > 0 {
            (cached_tokens * 100) / input_tokens
        } else {
            0
        };

        if cached_tokens > 0 {
            eprintln!(
                " 💰 {}in/{}out (cached: {}, {}%)",
                input_tokens,
                response
                    .usage
                    .as_ref()
                    .and_then(|u| u.output_tokens)
                    .unwrap_or(0),
                cached_tokens,
                cache_pct
            );
        } else {
            eprintln!(
                " 💰 {}in/{}out",
                input_tokens,
                response
                    .usage
                    .as_ref()
                    .and_then(|u| u.output_tokens)
                    .unwrap_or(0)
            );
        }

        let calls = Self::extract_tool_calls(&response);
        Ok(calls)
    }

    async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        self.tools_call_stateful_impl(messages, tools, previous_response_id)
            .await
    }

    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
        let mut req = self.build_request(messages, None, None);
        req = req.store(true);

        let response = self.send_request_auto(req).await?;

        record_otel_usage(&response, &self.model, messages);
        Self::check_truncation(&response)?;

        let text = response.output_text();
        if text.is_empty() {
            return Err(SgrError::EmptyResponse);
        }

        tracing::info!(
            model = %response.model,
            response_id = %response.id,
            input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
            output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
            "oxide.complete"
        );

        Ok(text)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn oxide_client_from_config() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();
        assert_eq!(client.model, "gpt-5.4");
    }

    #[test]
    fn build_request_simple() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4").temperature(0.5);
        let client = OxideClient::from_config(&config).unwrap();
        let messages = vec![Message::system("Be helpful."), Message::user("Hello")];
        let req = client.build_request(&messages, None, None);
        assert_eq!(req.model, "gpt-5.4");
        assert!(req.instructions.is_none());
        assert!(req.input.is_some());
        assert_eq!(req.temperature, Some(0.5));
    }

    #[test]
    fn build_request_with_schema() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();
        let schema = serde_json::json!({
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"]
        });
        let req = client.build_request(&[Message::user("Hi")], Some(&schema), None);
        assert!(req.text.is_some());
    }

    #[test]
    fn build_request_stateless_no_previous_response_id() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();

        let req = client.build_request(&[Message::user("Hi")], None, None);
        assert!(
            req.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }

    #[test]
    fn build_request_explicit_chaining() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();

        let req = client.build_request(&[Message::user("Hi")], None, Some("resp_xyz"));
        assert_eq!(
            req.previous_response_id.as_deref(),
            Some("resp_xyz"),
            "build_request should chain with explicit previous_response_id"
        );
    }

    #[test]
    fn build_request_tool_outputs_chaining() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();

        let messages = vec![Message::tool("call_1", "result data")];
        let req = client.build_request(&messages, None, Some("resp_123"));
        assert_eq!(req.previous_response_id.as_deref(), Some("resp_123"));

        let req = client.build_request(&messages, None, None);
        assert!(
            req.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }
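
    // A sketch exercising the multi-message path of build_request: with more
    // than one message, the input should be role-tagged messages rather than
    // plain text (assumes ResponseInput::Messages is the variant used above).
    #[test]
    fn build_request_multi_message_uses_message_input() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();
        let messages = vec![Message::system("Be helpful."), Message::user("Hello")];
        let req = client.build_request(&messages, None, None);
        assert!(matches!(&req.input, Some(ResponseInput::Messages(_))));
    }

    // Guards the UTF-8-safe truncation used for telemetry previews: cutting
    // inside a multi-byte character must back off rather than panic.
    #[cfg(feature = "telemetry")]
    #[test]
    fn truncate_str_backs_off_to_char_boundary() {
        let s = "é".repeat(10); // each 'é' is two bytes in UTF-8
        let out = truncate_str(&s, 3); // byte 3 falls mid-character
        assert_eq!(out, "é...");
    }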
}