1use crate::client::LlmClient;
8use crate::tool::ToolDef;
9use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
10use openai_oxide::OpenAI;
11use openai_oxide::config::ClientConfig;
12use openai_oxide::types::responses::*;
13use serde_json::Value;
14
#[cfg(feature = "telemetry")]
fn record_otel_usage(response: &Response, model: &str, messages: &[Message]) {
    // Token counts live in an optional `usage` payload; missing fields are
    // reported as zero rather than skipping the span.
    let usage = response.usage.as_ref();
    let prompt_tokens = usage.and_then(|u| u.input_tokens).unwrap_or(0);
    let completion_tokens = usage.and_then(|u| u.output_tokens).unwrap_or(0);
    let cached_tokens = usage
        .and_then(|u| u.input_tokens_details.as_ref())
        .and_then(|d| d.cached_tokens)
        .unwrap_or(0);

    // Capture truncated previews of the prompt and output, plus every tool
    // invocation (name + serialized arguments), for the telemetry span.
    let input = last_user_content(messages, 500);
    let full_output = response.output_text();
    let output = truncate_str(&full_output, 500);
    let tool_calls: Vec<(String, String)> = response
        .function_calls()
        .iter()
        .map(|call| (call.name.clone(), call.arguments.to_string()))
        .collect();

    crate::telemetry::record_llm_span(
        "oxide.responses.api",
        model,
        &input,
        &output,
        &tool_calls,
        &crate::telemetry::LlmUsage {
            prompt_tokens,
            completion_tokens,
            cached_tokens,
            response_model: response.model.clone(),
        },
    );
}
58
/// No-op replacement used when the `telemetry` feature is disabled, so
/// call sites need no `cfg` guards of their own.
#[cfg(not(feature = "telemetry"))]
fn record_otel_usage(_response: &Response, _model: &str, _messages: &[Message]) {}
61
#[cfg(feature = "telemetry")]
/// Return the content of the most recent user or tool message, truncated
/// to `max_len` bytes; empty string when no such message exists.
fn last_user_content(messages: &[Message], max_len: usize) -> String {
    for msg in messages.iter().rev() {
        if matches!(msg.role, Role::User | Role::Tool) {
            return truncate_str(&msg.content, max_len);
        }
    }
    String::new()
}
71
72#[cfg(feature = "telemetry")]
/// Truncate `s` to at most `max_len` bytes, appending "..." when cut.
///
/// The cut point is backed up to the nearest UTF-8 char boundary, so the
/// slice can never panic on multi-byte characters (the previous
/// `&s[..max_len]` panicked whenever `max_len` landed mid-character).
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // Walk back until `end` sits on a char boundary (0 always qualifies,
    // so this terminates).
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
80
/// LLM client backed by the `openai_oxide` Responses API.
pub struct OxideClient {
    // Underlying HTTP client for the OpenAI-compatible endpoint.
    client: OpenAI,
    pub(crate) model: String,
    // Sampling temperature; a value of exactly 1.0 is treated as the
    // provider default and omitted from requests (see `build_request`).
    pub(crate) temperature: Option<f64>,
    pub(crate) max_tokens: Option<u32>,
    // Lazily-established WebSocket session, created on first request after
    // `connect_ws` has been called.
    #[cfg(feature = "oxide-ws")]
    ws: tokio::sync::Mutex<Option<openai_oxide::websocket::WsSession>>,
    // Whether WS transport should be attempted; flipped off on connect failure.
    #[cfg(feature = "oxide-ws")]
    ws_enabled: std::sync::atomic::AtomicBool,
}
97
98impl OxideClient {
99 pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
101 let api_key = config
102 .api_key
103 .clone()
104 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
105 .unwrap_or_else(|| {
106 if config.base_url.is_some() {
107 "dummy_key".into()
108 } else {
109 "".into()
110 }
111 });
112
113 if api_key.is_empty() {
114 return Err(SgrError::Schema("No API key for oxide client".into()));
115 }
116
117 let mut client_config = ClientConfig::new(&api_key);
118 if let Some(ref url) = config.base_url {
119 client_config = client_config.base_url(url.clone());
120 }
121 config.apply_headers(&mut client_config);
122
123 Ok(Self {
124 client: OpenAI::with_config(client_config),
125 model: config.model.clone(),
126 temperature: Some(config.temp),
127 max_tokens: config.max_tokens,
128 #[cfg(feature = "oxide-ws")]
129 ws: tokio::sync::Mutex::new(None),
130 #[cfg(feature = "oxide-ws")]
131 ws_enabled: std::sync::atomic::AtomicBool::new(false),
132 })
133 }
134
135 #[cfg(feature = "oxide-ws")]
143 pub async fn connect_ws(&self) -> Result<(), SgrError> {
144 self.ws_enabled
145 .store(true, std::sync::atomic::Ordering::Relaxed);
146 tracing::info!(model = %self.model, "oxide WebSocket enabled (lazy connect)");
147 Ok(())
148 }
149
150 async fn send_request_auto(
152 &self,
153 request: ResponseCreateRequest,
154 ) -> Result<Response, SgrError> {
155 #[cfg(feature = "oxide-ws")]
156 if self.ws_enabled.load(std::sync::atomic::Ordering::Relaxed) {
157 let mut ws_guard = self.ws.lock().await;
158
159 if ws_guard.is_none() {
161 match self.client.ws_session().await {
162 Ok(session) => {
163 tracing::info!(model = %self.model, "oxide WS connected (lazy)");
164 *ws_guard = Some(session);
165 }
166 Err(e) => {
167 tracing::warn!("oxide WS connect failed, using HTTP: {e}");
168 self.ws_enabled
169 .store(false, std::sync::atomic::Ordering::Relaxed);
170 }
171 }
172 }
173
174 if let Some(ref mut session) = *ws_guard {
175 match session.send(request.clone()).await {
176 Ok(response) => return Ok(response),
177 Err(e) => {
178 tracing::warn!("oxide WS send failed, falling back to HTTP: {e}");
179 *ws_guard = None;
180 }
181 }
182 }
183 }
184
185 self.client
187 .responses()
188 .create(request)
189 .await
190 .map_err(|e| SgrError::Api {
191 status: 0,
192 body: e.to_string(),
193 })
194 }
195
196 pub(crate) fn build_request(
203 &self,
204 messages: &[Message],
205 schema: Option<&Value>,
206 previous_response_id: Option<&str>,
207 ) -> ResponseCreateRequest {
208 if previous_response_id.is_some() {
209 return self.build_request_items(messages, previous_response_id);
213 }
214
215 let mut input_items = Vec::new();
217
218 for msg in messages {
219 match msg.role {
220 Role::System => {
221 input_items.push(ResponseInputItem {
222 role: openai_oxide::types::common::Role::System,
223 content: Value::String(msg.content.clone()),
224 });
225 }
226 Role::User => {
227 input_items.push(ResponseInputItem {
228 role: openai_oxide::types::common::Role::User,
229 content: Value::String(msg.content.clone()),
230 });
231 }
232 Role::Assistant => {
233 let mut content = msg.content.clone();
236 if !msg.tool_calls.is_empty() {
237 for tc in &msg.tool_calls {
238 let args = tc.arguments.to_string();
239 let preview = if args.len() > 200 {
240 &args[..200]
241 } else {
242 &args
243 };
244 content.push_str(&format!("\n→ {}({})", tc.name, preview));
245 }
246 }
247 input_items.push(ResponseInputItem {
248 role: openai_oxide::types::common::Role::Assistant,
249 content: Value::String(content),
250 });
251 }
252 Role::Tool => {
253 input_items.push(ResponseInputItem {
256 role: openai_oxide::types::common::Role::User,
257 content: Value::String(msg.content.clone()),
258 });
259 }
260 }
261 }
262
263 let mut req = ResponseCreateRequest::new(&self.model);
264
265 if input_items.len() == 1 && input_items[0].role == openai_oxide::types::common::Role::User
267 {
268 if let Some(text) = input_items[0].content.as_str() {
269 req = req.input(text);
270 } else {
271 req.input = Some(ResponseInput::Messages(input_items));
272 }
273 } else if !input_items.is_empty() {
274 req.input = Some(ResponseInput::Messages(input_items));
275 }
276
277 if let Some(temp) = self.temperature
279 && (temp - 1.0).abs() > f64::EPSILON
280 {
281 req = req.temperature(temp);
282 }
283
284 if let Some(max) = self.max_tokens {
286 req = req.max_output_tokens(max as i64);
287 }
288
289 if let Some(schema_val) = schema {
291 req = req.text(ResponseTextConfig {
292 format: Some(ResponseTextFormat::JsonSchema {
293 name: "sgr_response".into(),
294 description: None,
295 schema: Some(schema_val.clone()),
296 strict: Some(true),
297 }),
298 verbosity: None,
299 });
300 }
301
302 req
303 }
304
305 fn build_request_items(
307 &self,
308 messages: &[Message],
309 previous_response_id: Option<&str>,
310 ) -> ResponseCreateRequest {
311 use openai_oxide::types::responses::ResponseInput;
312
313 let mut items: Vec<Value> = Vec::new();
314
315 for msg in messages {
316 match msg.role {
317 Role::Tool => {
318 if let Some(ref call_id) = msg.tool_call_id {
319 items.push(serde_json::json!({
320 "type": "function_call_output",
321 "call_id": call_id,
322 "output": msg.content
323 }));
324 }
325 }
326 Role::System => {
327 items.push(serde_json::json!({
328 "type": "message",
329 "role": "system",
330 "content": msg.content
331 }));
332 }
333 Role::User => {
334 items.push(serde_json::json!({
335 "type": "message",
336 "role": "user",
337 "content": msg.content
338 }));
339 }
340 Role::Assistant => {
341 items.push(serde_json::json!({
342 "type": "message",
343 "role": "assistant",
344 "content": msg.content
345 }));
346 }
347 }
348 }
349
350 let mut req = ResponseCreateRequest::new(&self.model);
351 if !items.is_empty() {
352 req.input = Some(ResponseInput::Items(items));
353 }
354
355 if let Some(temp) = self.temperature
357 && (temp - 1.0).abs() > f64::EPSILON
358 {
359 req = req.temperature(temp);
360 }
361 if let Some(max) = self.max_tokens {
362 req = req.max_output_tokens(max as i64);
363 }
364
365 if let Some(prev_id) = previous_response_id {
366 req = req.previous_response_id(prev_id);
367 }
368
369 req
370 }
371
372 async fn tools_call_stateful_impl(
384 &self,
385 messages: &[Message],
386 tools: &[ToolDef],
387 previous_response_id: Option<&str>,
388 ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
389 let mut req = self.build_request(messages, None, previous_response_id);
390 req = req.store(true);
392
393 let response_tools: Vec<ResponseTool> = tools
398 .iter()
399 .map(|t| {
400 let mut params = t.parameters.clone();
401 openai_oxide::parsing::ensure_strict(&mut params);
402 ResponseTool::Function {
403 name: t.name.clone(),
404 description: if t.description.is_empty() {
405 None
406 } else {
407 Some(t.description.clone())
408 },
409 parameters: Some(params),
410 strict: Some(true),
411 }
412 })
413 .collect();
414 req = req.tools(response_tools);
415
416 let response = self.send_request_auto(req).await?;
417
418 let response_id = response.id.clone();
419 record_otel_usage(&response, &self.model, messages);
421
422 let input_tokens = response
423 .usage
424 .as_ref()
425 .and_then(|u| u.input_tokens)
426 .unwrap_or(0);
427 let cached_tokens = response
428 .usage
429 .as_ref()
430 .and_then(|u| u.input_tokens_details.as_ref())
431 .and_then(|d| d.cached_tokens)
432 .unwrap_or(0);
433
434 let chained = previous_response_id.is_some();
435 let cache_pct = if input_tokens > 0 {
436 (cached_tokens * 100) / input_tokens
437 } else {
438 0
439 };
440
441 tracing::info!(
442 model = %response.model,
443 response_id = %response_id,
444 input_tokens,
445 cached_tokens,
446 cache_pct,
447 chained,
448 "oxide.tools_call_stateful"
449 );
450
451 if cached_tokens > 0 {
452 eprintln!(
453 " 💰 {}in/{}out (cached: {}, {}%)",
454 input_tokens,
455 response
456 .usage
457 .as_ref()
458 .and_then(|u| u.output_tokens)
459 .unwrap_or(0),
460 cached_tokens,
461 cache_pct
462 );
463 } else {
464 eprintln!(
465 " 💰 {}in/{}out",
466 input_tokens,
467 response
468 .usage
469 .as_ref()
470 .and_then(|u| u.output_tokens)
471 .unwrap_or(0)
472 );
473 }
474
475 Self::check_truncation(&response)?;
476 Ok((Self::extract_tool_calls(&response), Some(response_id)))
477 }
478
479 fn check_truncation(response: &Response) -> Result<(), SgrError> {
482 let is_incomplete = response
483 .status
484 .as_deref()
485 .is_some_and(|s| s == "incomplete");
486 let is_max_tokens = response
487 .incomplete_details
488 .as_ref()
489 .and_then(|d| d.reason.as_deref())
490 .is_some_and(|r| r == "max_output_tokens");
491
492 if is_incomplete && is_max_tokens {
493 return Err(SgrError::MaxOutputTokens {
494 partial_content: response.output_text(),
495 });
496 }
497 Ok(())
498 }
499
500 fn extract_tool_calls(response: &Response) -> Vec<ToolCall> {
502 response
503 .function_calls()
504 .into_iter()
505 .map(|fc| ToolCall {
506 id: fc.call_id,
507 name: fc.name,
508 arguments: fc.arguments,
509 })
510 .collect()
511 }
512}
513
514#[async_trait::async_trait]
515impl LlmClient for OxideClient {
516 async fn structured_call(
517 &self,
518 messages: &[Message],
519 schema: &Value,
520 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
521 let strict_schema =
524 if schema.get("additionalProperties").and_then(|v| v.as_bool()) == Some(false) {
525 schema.clone()
527 } else {
528 let mut s = schema.clone();
529 openai_oxide::parsing::ensure_strict(&mut s);
530 s
531 };
532
533 let mut req = self.build_request(messages, Some(&strict_schema), None);
536 req = req.store(true);
537
538 let span = tracing::info_span!(
539 "oxide.responses.create",
540 model = %self.model,
541 method = "structured_call",
542 );
543 let _enter = span.enter();
544
545 if std::env::var("SGR_DEBUG_SCHEMA").is_ok()
547 && let Some(ref text_cfg) = req.text
548 {
549 eprintln!(
550 "[sgr] Schema: {}",
551 serde_json::to_string(text_cfg).unwrap_or_default()
552 );
553 }
554
555 let response = self.send_request_auto(req).await?;
556
557 record_otel_usage(&response, &self.model, messages);
559
560 Self::check_truncation(&response)?;
561
562 let raw_text = response.output_text();
563 if std::env::var("SGR_DEBUG").is_ok() {
564 eprintln!(
565 "[sgr] Raw response: {}",
566 &raw_text[..raw_text.len().min(500)]
567 );
568 }
569 let tool_calls = Self::extract_tool_calls(&response);
570 let parsed = serde_json::from_str::<Value>(&raw_text).ok();
571
572 let input_tokens = response
573 .usage
574 .as_ref()
575 .and_then(|u| u.input_tokens)
576 .unwrap_or(0);
577 let cached_tokens = response
578 .usage
579 .as_ref()
580 .and_then(|u| u.input_tokens_details.as_ref())
581 .and_then(|d| d.cached_tokens)
582 .unwrap_or(0);
583 let cache_pct = if input_tokens > 0 {
584 (cached_tokens * 100) / input_tokens
585 } else {
586 0
587 };
588
589 {
590 let output_tokens = response
591 .usage
592 .as_ref()
593 .and_then(|u| u.output_tokens)
594 .unwrap_or(0);
595 if cached_tokens > 0 {
596 eprintln!(
597 " 💰 {}in/{}out (cached: {}, {}%)",
598 input_tokens, output_tokens, cached_tokens, cache_pct
599 );
600 } else {
601 eprintln!(" 💰 {}in/{}out", input_tokens, output_tokens);
602 }
603 }
604
605 Ok((parsed, tool_calls, raw_text))
606 }
607
608 async fn tools_call(
609 &self,
610 messages: &[Message],
611 tools: &[ToolDef],
612 ) -> Result<Vec<ToolCall>, SgrError> {
613 let mut req = self.build_request(messages, None, None);
617 req = req.store(true);
618
619 let response_tools: Vec<ResponseTool> = tools
621 .iter()
622 .map(|t| ResponseTool::Function {
623 name: t.name.clone(),
624 description: if t.description.is_empty() {
625 None
626 } else {
627 Some(t.description.clone())
628 },
629 parameters: Some(t.parameters.clone()),
630 strict: None,
631 })
632 .collect();
633 req = req.tools(response_tools);
634
635 req = req.tool_choice(openai_oxide::types::responses::ResponseToolChoice::Mode(
638 "required".into(),
639 ));
640
641 let response = self.send_request_auto(req).await?;
642
643 record_otel_usage(&response, &self.model, messages);
644 Self::check_truncation(&response)?;
645
646 let input_tokens = response
647 .usage
648 .as_ref()
649 .and_then(|u| u.input_tokens)
650 .unwrap_or(0);
651 let cached_tokens = response
652 .usage
653 .as_ref()
654 .and_then(|u| u.input_tokens_details.as_ref())
655 .and_then(|d| d.cached_tokens)
656 .unwrap_or(0);
657 let cache_pct = if input_tokens > 0 {
658 (cached_tokens * 100) / input_tokens
659 } else {
660 0
661 };
662
663 if cached_tokens > 0 {
664 eprintln!(
665 " 💰 {}in/{}out (cached: {}, {}%)",
666 input_tokens,
667 response
668 .usage
669 .as_ref()
670 .and_then(|u| u.output_tokens)
671 .unwrap_or(0),
672 cached_tokens,
673 cache_pct
674 );
675 } else {
676 eprintln!(
677 " 💰 {}in/{}out",
678 input_tokens,
679 response
680 .usage
681 .as_ref()
682 .and_then(|u| u.output_tokens)
683 .unwrap_or(0)
684 );
685 }
686
687 let calls = Self::extract_tool_calls(&response);
688 Ok(calls)
689 }
690
691 async fn tools_call_stateful(
692 &self,
693 messages: &[Message],
694 tools: &[ToolDef],
695 previous_response_id: Option<&str>,
696 ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
697 self.tools_call_stateful_impl(messages, tools, previous_response_id)
698 .await
699 }
700
701 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
702 let mut req = self.build_request(messages, None, None);
703 req = req.store(true);
704
705 let response = self.send_request_auto(req).await?;
706
707 record_otel_usage(&response, &self.model, messages);
708 Self::check_truncation(&response)?;
709
710 let text = response.output_text();
711 if text.is_empty() {
712 return Err(SgrError::EmptyResponse);
713 }
714
715 tracing::info!(
716 model = %response.model,
717 response_id = %response.id,
718 input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
719 output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
720 "oxide.complete"
721 );
722
723 Ok(text)
724 }
725}
726
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a client for `cfg`, panicking on failure (test-only helper).
    fn client_for(cfg: &LlmConfig) -> OxideClient {
        OxideClient::from_config(cfg).unwrap()
    }

    #[test]
    fn oxide_client_from_config() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        assert_eq!(client_for(&cfg).model, "gpt-5.4");
    }

    #[test]
    fn build_request_simple() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4").temperature(0.5);
        let msgs = vec![Message::system("Be helpful."), Message::user("Hello")];
        let req = client_for(&cfg).build_request(&msgs, None, None);
        assert_eq!(req.model, "gpt-5.4");
        assert!(req.instructions.is_none());
        assert!(req.input.is_some());
        assert_eq!(req.temperature, Some(0.5));
    }

    #[test]
    fn build_request_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"]
        });
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let req = client_for(&cfg).build_request(&[Message::user("Hi")], Some(&schema), None);
        assert!(req.text.is_some());
    }

    #[test]
    fn build_request_stateless_no_previous_response_id() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let req = client_for(&cfg).build_request(&[Message::user("Hi")], None, None);
        assert!(
            req.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }

    #[test]
    fn build_request_explicit_chaining() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let req = client_for(&cfg).build_request(&[Message::user("Hi")], None, Some("resp_xyz"));
        assert_eq!(
            req.previous_response_id.as_deref(),
            Some("resp_xyz"),
            "build_request should chain with explicit previous_response_id"
        );
    }

    #[test]
    fn build_request_tool_outputs_chaining() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = client_for(&cfg);
        let messages = vec![Message::tool("call_1", "result data")];

        // With an explicit id, the request chains onto the prior response.
        let req = client.build_request(&messages, None, Some("resp_123"));
        assert_eq!(req.previous_response_id.as_deref(), Some("resp_123"));

        // Without one, it stays stateless.
        let req = client.build_request(&messages, None, None);
        assert!(
            req.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }
}