Skip to main content

sgr_agent/
oxide_client.rs

1//! OxideClient — LlmClient adapter for `openai-oxide` crate.
2//!
3//! Uses the **Responses API** (`POST /responses`) instead of Chat Completions.
4//! With `oxide-ws` feature: persistent WebSocket connection for -20-25% latency.
5//! Supports: structured output (json_schema), function calling, multi-turn (previous_response_id).
6
7use crate::client::LlmClient;
8use crate::tool::ToolDef;
9use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
10use openai_oxide::OpenAI;
11use openai_oxide::config::ClientConfig;
12use openai_oxide::types::responses::*;
13use serde_json::Value;
14
/// Record OTEL attributes on the current span for Phoenix/OpenInference.
///
/// Emits a short-lived `oxide.responses.api` span carrying the model name and
/// token usage in both OpenInference (Phoenix) and GenAI (LangSmith)
/// attribute conventions. Missing usage fields default to 0.
#[cfg(feature = "telemetry")]
fn record_otel_usage(response: &Response, model: &str) {
    use opentelemetry::trace::{Span, Tracer, TracerProvider};

    let provider = opentelemetry::global::tracer_provider();
    let tracer = provider.tracer("sgr-agent");
    let mut otel_span = tracer.start("oxide.responses.api");

    // Token counts — absent usage (some proxies omit it) reads as 0.
    let pt = response
        .usage
        .as_ref()
        .and_then(|u| u.input_tokens)
        .unwrap_or(0);
    let ct = response
        .usage
        .as_ref()
        .and_then(|u| u.output_tokens)
        .unwrap_or(0);

    let cached = response
        .usage
        .as_ref()
        .and_then(|u| u.input_tokens_details.as_ref())
        .and_then(|d| d.cached_tokens)
        .unwrap_or(0);

    // OpenInference conventions (Phoenix)
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "openinference.span.kind",
        "LLM",
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.model_name",
        model.to_string(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new("llm.token_count.prompt", pt));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.token_count.completion",
        ct,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.token_count.total",
        pt + ct,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.token_count.cached",
        cached,
    ));

    // GenAI conventions (LangSmith)
    otel_span.set_attribute(opentelemetry::KeyValue::new("langsmith.span.kind", "LLM"));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.request.model",
        model.to_string(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.response.model",
        response.model.clone(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.usage.prompt_tokens",
        pt,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.usage.completion_tokens",
        ct,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.usage.cached_tokens",
        cached,
    ));

    // Output text, capped at ~4000 bytes. BUGFIX: the cut point must land on a
    // char boundary — `&output[..4000]` panics when byte 4000 falls inside a
    // multi-byte UTF-8 sequence (e.g. non-ASCII model output).
    let output = response.output_text();
    if !output.is_empty() {
        let content = if output.len() > 4000 {
            let mut end = 4000;
            while !output.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &output[..end])
        } else {
            output
        };
        otel_span.set_attribute(opentelemetry::KeyValue::new(
            "gen_ai.completion.0.content",
            content,
        ));
    }

    otel_span.end();
}
103
/// No-op stub when the `telemetry` feature is disabled — keeps call sites
/// unconditional (no `#[cfg]` needed at each `record_otel_usage` call).
#[cfg(not(feature = "telemetry"))]
fn record_otel_usage(_response: &Response, _model: &str) {}
106
/// LlmClient backed by openai-oxide (Responses API).
///
/// With `oxide-ws` feature: call `connect_ws()` to upgrade to WebSocket mode.
/// All subsequent calls go over persistent wss:// connection (-20-25% latency).
pub struct OxideClient {
    /// Underlying openai-oxide client (HTTP; also spawns WS sessions).
    client: OpenAI,
    /// Model name sent with every request.
    pub(crate) model: String,
    /// Sampling temperature; `build_request` omits it when it equals 1.0.
    pub(crate) temperature: Option<f64>,
    /// Cap on generated tokens (`max_output_tokens` in the Responses API).
    pub(crate) max_tokens: Option<u32>,
    /// WebSocket session (when oxide-ws feature is enabled and connected).
    #[cfg(feature = "oxide-ws")]
    ws: tokio::sync::Mutex<Option<openai_oxide::websocket::WsSession>>,
    /// Lazy WS: true = connect on first request, false = HTTP only.
    #[cfg(feature = "oxide-ws")]
    ws_enabled: std::sync::atomic::AtomicBool,
}
123
124impl OxideClient {
125    /// Create from LlmConfig.
126    pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
127        let api_key = config
128            .api_key
129            .clone()
130            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
131            .unwrap_or_else(|| {
132                if config.base_url.is_some() {
133                    "dummy_key".into()
134                } else {
135                    "".into()
136                }
137            });
138
139        if api_key.is_empty() {
140            return Err(SgrError::Schema("No API key for oxide client".into()));
141        }
142
143        let mut client_config = ClientConfig::new(&api_key);
144        if let Some(ref url) = config.base_url {
145            client_config = client_config.base_url(url.clone());
146        }
147
148        Ok(Self {
149            client: OpenAI::with_config(client_config),
150            model: config.model.clone(),
151            temperature: Some(config.temp),
152            max_tokens: config.max_tokens,
153            #[cfg(feature = "oxide-ws")]
154            ws: tokio::sync::Mutex::new(None),
155            #[cfg(feature = "oxide-ws")]
156            ws_enabled: std::sync::atomic::AtomicBool::new(false),
157        })
158    }
159
160    /// Enable WebSocket mode — lazy connect on first request.
161    ///
162    /// Does NOT open a connection immediately. The WS connection is established
163    /// on the first `send_request_auto()` call, eliminating idle timeout issues.
164    /// Falls back to HTTP automatically if WS fails.
165    ///
166    /// Requires `oxide-ws` feature.
167    #[cfg(feature = "oxide-ws")]
168    pub async fn connect_ws(&self) -> Result<(), SgrError> {
169        self.ws_enabled
170            .store(true, std::sync::atomic::Ordering::Relaxed);
171        tracing::info!(model = %self.model, "oxide WebSocket enabled (lazy connect)");
172        Ok(())
173    }
174
175    /// Send request — lazy WS connect + send, falls back to HTTP on any WS error.
176    async fn send_request_auto(
177        &self,
178        request: ResponseCreateRequest,
179    ) -> Result<Response, SgrError> {
180        #[cfg(feature = "oxide-ws")]
181        if self.ws_enabled.load(std::sync::atomic::Ordering::Relaxed) {
182            let mut ws_guard = self.ws.lock().await;
183
184            // Lazy connect
185            if ws_guard.is_none() {
186                match self.client.ws_session().await {
187                    Ok(session) => {
188                        tracing::info!(model = %self.model, "oxide WS connected (lazy)");
189                        *ws_guard = Some(session);
190                    }
191                    Err(e) => {
192                        tracing::warn!("oxide WS connect failed, using HTTP: {e}");
193                        self.ws_enabled
194                            .store(false, std::sync::atomic::Ordering::Relaxed);
195                    }
196                }
197            }
198
199            if let Some(ref mut session) = *ws_guard {
200                match session.send(request.clone()).await {
201                    Ok(response) => return Ok(response),
202                    Err(e) => {
203                        tracing::warn!("oxide WS send failed, falling back to HTTP: {e}");
204                        *ws_guard = None;
205                    }
206                }
207            }
208        }
209
210        // HTTP fallback
211        self.client
212            .responses()
213            .create(request)
214            .await
215            .map_err(|e| SgrError::Api {
216                status: 0,
217                body: e.to_string(),
218            })
219    }
220
221    /// Build a ResponseCreateRequest from messages + optional schema + optional chaining.
222    ///
223    /// - `previous_response_id` is None: full history as Messages format
224    /// - `previous_response_id` is Some: Items format with function_call_output
225    ///   (required for chaining after tool calls via Responses API)
226    /// - `schema`: optional structured output json_schema config
227    pub(crate) fn build_request(
228        &self,
229        messages: &[Message],
230        schema: Option<&Value>,
231        previous_response_id: Option<&str>,
232    ) -> ResponseCreateRequest {
233        if previous_response_id.is_some() {
234            // Items format: messages + function_call_output items.
235            // HTTP API accepts Messages format (without type), but WS API requires it.
236            // Using Items consistently ensures both HTTP and WS work.
237            return self.build_request_items(messages, previous_response_id);
238        }
239
240        // Messages format: standard request with optional structured output
241        let mut input_items = Vec::new();
242
243        for msg in messages {
244            match msg.role {
245                Role::System => {
246                    input_items.push(ResponseInputItem {
247                        role: openai_oxide::types::common::Role::System,
248                        content: Value::String(msg.content.clone()),
249                    });
250                }
251                Role::User => {
252                    input_items.push(ResponseInputItem {
253                        role: openai_oxide::types::common::Role::User,
254                        content: Value::String(msg.content.clone()),
255                    });
256                }
257                Role::Assistant => {
258                    // Include tool call info so structured_call context shows
259                    // what action was taken.
260                    let mut content = msg.content.clone();
261                    if !msg.tool_calls.is_empty() {
262                        for tc in &msg.tool_calls {
263                            let args = tc.arguments.to_string();
264                            let preview = if args.len() > 200 {
265                                &args[..200]
266                            } else {
267                                &args
268                            };
269                            content.push_str(&format!("\n→ {}({})", tc.name, preview));
270                        }
271                    }
272                    input_items.push(ResponseInputItem {
273                        role: openai_oxide::types::common::Role::Assistant,
274                        content: Value::String(content),
275                    });
276                }
277                Role::Tool => {
278                    // Clean format — no "[Tool result for ...]" prefix.
279                    // The assistant message above already has the action name.
280                    input_items.push(ResponseInputItem {
281                        role: openai_oxide::types::common::Role::User,
282                        content: Value::String(msg.content.clone()),
283                    });
284                }
285            }
286        }
287
288        let mut req = ResponseCreateRequest::new(&self.model);
289
290        // Set input — prefer simple text when single user message (fewer tokens)
291        if input_items.len() == 1 && input_items[0].role == openai_oxide::types::common::Role::User
292        {
293            if let Some(text) = input_items[0].content.as_str() {
294                req = req.input(text);
295            } else {
296                req.input = Some(ResponseInput::Messages(input_items));
297            }
298        } else if !input_items.is_empty() {
299            req.input = Some(ResponseInput::Messages(input_items));
300        }
301
302        // Temperature — skip default to reduce payload
303        if let Some(temp) = self.temperature
304            && (temp - 1.0).abs() > f64::EPSILON
305        {
306            req = req.temperature(temp);
307        }
308
309        // Max tokens
310        if let Some(max) = self.max_tokens {
311            req = req.max_output_tokens(max as i64);
312        }
313
314        // Structured output via json_schema
315        if let Some(schema_val) = schema {
316            req = req.text(ResponseTextConfig {
317                format: Some(ResponseTextFormat::JsonSchema {
318                    name: "sgr_response".into(),
319                    description: None,
320                    schema: Some(schema_val.clone()),
321                    strict: Some(true),
322                }),
323                verbosity: None,
324            });
325        }
326
327        req
328    }
329
330    /// Build Items-format request for stateful chaining with previous_response_id.
331    fn build_request_items(
332        &self,
333        messages: &[Message],
334        previous_response_id: Option<&str>,
335    ) -> ResponseCreateRequest {
336        use openai_oxide::types::responses::ResponseInput;
337
338        let mut items: Vec<Value> = Vec::new();
339
340        for msg in messages {
341            match msg.role {
342                Role::Tool => {
343                    if let Some(ref call_id) = msg.tool_call_id {
344                        items.push(serde_json::json!({
345                            "type": "function_call_output",
346                            "call_id": call_id,
347                            "output": msg.content
348                        }));
349                    }
350                }
351                Role::System => {
352                    items.push(serde_json::json!({
353                        "type": "message",
354                        "role": "system",
355                        "content": msg.content
356                    }));
357                }
358                Role::User => {
359                    items.push(serde_json::json!({
360                        "type": "message",
361                        "role": "user",
362                        "content": msg.content
363                    }));
364                }
365                Role::Assistant => {
366                    items.push(serde_json::json!({
367                        "type": "message",
368                        "role": "assistant",
369                        "content": msg.content
370                    }));
371                }
372            }
373        }
374
375        let mut req = ResponseCreateRequest::new(&self.model);
376        if !items.is_empty() {
377            req.input = Some(ResponseInput::Items(items));
378        }
379
380        // Temperature
381        if let Some(temp) = self.temperature
382            && (temp - 1.0).abs() > f64::EPSILON
383        {
384            req = req.temperature(temp);
385        }
386        if let Some(max) = self.max_tokens {
387            req = req.max_output_tokens(max as i64);
388        }
389
390        if let Some(prev_id) = previous_response_id {
391            req = req.previous_response_id(prev_id);
392        }
393
394        req
395    }
396
397    /// Function calling with explicit previous_response_id.
398    /// Returns tool calls + new response_id for chaining.
399    ///
400    /// Always sets `store(true)` so responses can be referenced by subsequent calls.
401    /// When `previous_response_id` is provided, only delta messages need to be sent
402    /// (server has full history from previous stored response).
403    ///
404    /// Tool messages (role=Tool with tool_call_id) are converted to Responses API
405    /// `function_call_output` items — required for chaining with previous_response_id.
406    ///
407    /// This method does NOT use the Mutex — all state is explicit via parameters/return.
408    async fn tools_call_stateful_impl(
409        &self,
410        messages: &[Message],
411        tools: &[ToolDef],
412        previous_response_id: Option<&str>,
413    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
414        let mut req = self.build_request(messages, None, previous_response_id);
415        // Always store so next call can chain via previous_response_id
416        req = req.store(true);
417
418        // Convert ToolDefs to ResponseTools with strict mode.
419        // strict: true guarantees LLM output matches schema exactly (no parse errors).
420        // oxide ensure_strict() handles: additionalProperties, all-required,
421        // nullable→anyOf, allOf inlining, oneOf→anyOf.
422        let response_tools: Vec<ResponseTool> = tools
423            .iter()
424            .map(|t| {
425                let mut params = t.parameters.clone();
426                openai_oxide::parsing::ensure_strict(&mut params);
427                ResponseTool::Function {
428                    name: t.name.clone(),
429                    description: if t.description.is_empty() {
430                        None
431                    } else {
432                        Some(t.description.clone())
433                    },
434                    parameters: Some(params),
435                    strict: Some(true),
436                }
437            })
438            .collect();
439        req = req.tools(response_tools);
440
441        let response = self.send_request_auto(req).await?;
442
443        let response_id = response.id.clone();
444        // No Mutex save — caller owns the response_id
445        record_otel_usage(&response, &self.model);
446
447        let input_tokens = response
448            .usage
449            .as_ref()
450            .and_then(|u| u.input_tokens)
451            .unwrap_or(0);
452        let cached_tokens = response
453            .usage
454            .as_ref()
455            .and_then(|u| u.input_tokens_details.as_ref())
456            .and_then(|d| d.cached_tokens)
457            .unwrap_or(0);
458
459        let chained = previous_response_id.is_some();
460        let cache_pct = if input_tokens > 0 {
461            (cached_tokens * 100) / input_tokens
462        } else {
463            0
464        };
465
466        tracing::info!(
467            model = %response.model,
468            response_id = %response_id,
469            input_tokens,
470            cached_tokens,
471            cache_pct,
472            chained,
473            "oxide.tools_call_stateful"
474        );
475
476        if std::env::var("SGR_DEBUG").is_ok() {
477            eprintln!(
478                "[sgr] stateful: input={} cached={} ({}%) chained={}",
479                input_tokens, cached_tokens, cache_pct, chained
480            );
481        }
482
483        Ok((Self::extract_tool_calls(&response), Some(response_id)))
484    }
485
486    /// Extract tool calls from Responses API output items.
487    fn extract_tool_calls(response: &Response) -> Vec<ToolCall> {
488        response
489            .function_calls()
490            .into_iter()
491            .map(|fc| ToolCall {
492                id: fc.call_id,
493                name: fc.name,
494                arguments: fc.arguments,
495            })
496            .collect()
497    }
498}
499
500#[async_trait::async_trait]
501impl LlmClient for OxideClient {
502    async fn structured_call(
503        &self,
504        messages: &[Message],
505        schema: &Value,
506    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
507        // Make schema OpenAI-strict — UNLESS it's already strict
508        // (build_action_schema produces pre-strict schemas that ensure_strict would break)
509        let strict_schema =
510            if schema.get("additionalProperties").and_then(|v| v.as_bool()) == Some(false) {
511                // Already strict-compatible (e.g., from build_action_schema)
512                schema.clone()
513            } else {
514                let mut s = schema.clone();
515                openai_oxide::parsing::ensure_strict(&mut s);
516                s
517            };
518
519        // Stateless — build request with full message history, no chaining
520        let req = self.build_request(messages, Some(&strict_schema), None);
521
522        let span = tracing::info_span!(
523            "oxide.responses.create",
524            model = %self.model,
525            method = "structured_call",
526        );
527        let _enter = span.enter();
528
529        // Debug: dump schema on first call
530        if std::env::var("SGR_DEBUG_SCHEMA").is_ok()
531            && let Some(ref text_cfg) = req.text
532        {
533            eprintln!(
534                "[sgr] Schema: {}",
535                serde_json::to_string(text_cfg).unwrap_or_default()
536            );
537        }
538
539        let response = self.send_request_auto(req).await?;
540
541        // No Mutex save — structured_call is stateless
542        record_otel_usage(&response, &self.model);
543
544        let raw_text = response.output_text();
545        if std::env::var("SGR_DEBUG").is_ok() {
546            eprintln!(
547                "[sgr] Raw response: {}",
548                &raw_text[..raw_text.len().min(500)]
549            );
550        }
551        let tool_calls = Self::extract_tool_calls(&response);
552        let parsed = serde_json::from_str::<Value>(&raw_text).ok();
553
554        tracing::info!(
555            model = %response.model,
556            response_id = %response.id,
557            input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
558            output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
559            "oxide.structured_call"
560        );
561
562        Ok((parsed, tool_calls, raw_text))
563    }
564
565    async fn tools_call(
566        &self,
567        messages: &[Message],
568        tools: &[ToolDef],
569    ) -> Result<Vec<ToolCall>, SgrError> {
570        // Stateless — no previous_response_id, full message history
571        let mut req = self.build_request(messages, None, None);
572
573        // Convert ToolDefs to ResponseTools — no strict mode (faster server-side)
574        let response_tools: Vec<ResponseTool> = tools
575            .iter()
576            .map(|t| ResponseTool::Function {
577                name: t.name.clone(),
578                description: if t.description.is_empty() {
579                    None
580                } else {
581                    Some(t.description.clone())
582                },
583                parameters: Some(t.parameters.clone()),
584                strict: None,
585            })
586            .collect();
587        req = req.tools(response_tools);
588
589        // Force model to always call a tool — prevents text-only responses
590        // that lose answer content (tools_call only returns Vec<ToolCall>).
591        req = req.tool_choice(openai_oxide::types::responses::ResponseToolChoice::Mode(
592            "required".into(),
593        ));
594
595        let response = self.send_request_auto(req).await?;
596
597        // No Mutex save — tools_call is stateless
598        record_otel_usage(&response, &self.model);
599
600        tracing::info!(
601            model = %response.model,
602            response_id = %response.id,
603            "oxide.tools_call"
604        );
605
606        let calls = Self::extract_tool_calls(&response);
607        Ok(calls)
608    }
609
610    async fn tools_call_stateful(
611        &self,
612        messages: &[Message],
613        tools: &[ToolDef],
614        previous_response_id: Option<&str>,
615    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
616        self.tools_call_stateful_impl(messages, tools, previous_response_id)
617            .await
618    }
619
620    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
621        let req = self.build_request(messages, None, None);
622
623        let response = self.send_request_auto(req).await?;
624
625        // No Mutex save — complete is stateless
626        record_otel_usage(&response, &self.model);
627
628        let text = response.output_text();
629        if text.is_empty() {
630            return Err(SgrError::EmptyResponse);
631        }
632
633        tracing::info!(
634            model = %response.model,
635            response_id = %response.id,
636            input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
637            output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
638            "oxide.complete"
639        );
640
641        Ok(text)
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn oxide_client_from_config() {
        // Construction from a key + model must succeed and retain the model name.
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        assert_eq!(OxideClient::from_config(&cfg).unwrap().model, "gpt-5.4");
    }

    #[test]
    fn build_request_simple() {
        // A system + user pair goes out via `input` (not `instructions`),
        // carrying the configured temperature.
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4").temperature(0.5);
        let client = OxideClient::from_config(&cfg).unwrap();
        let req = client.build_request(
            &[Message::system("Be helpful."), Message::user("Hello")],
            None,
            None,
        );
        assert_eq!(req.model, "gpt-5.4");
        assert!(req.instructions.is_none());
        assert!(req.input.is_some());
        assert_eq!(req.temperature, Some(0.5));
    }

    #[test]
    fn build_request_with_schema() {
        // Supplying a JSON schema populates the structured-output text config.
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&cfg).unwrap();
        let schema = serde_json::json!({
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"]
        });
        let req = client.build_request(&[Message::user("Hi")], Some(&schema), None);
        assert!(req.text.is_some());
    }

    #[test]
    fn build_request_stateless_no_previous_response_id() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&cfg).unwrap();

        let req = client.build_request(&[Message::user("Hi")], None, None);
        assert!(
            req.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }

    #[test]
    fn build_request_explicit_chaining() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&cfg).unwrap();

        // An explicit previous_response_id switches to Items format + chaining.
        let req = client.build_request(&[Message::user("Hi")], None, Some("resp_xyz"));
        assert_eq!(
            req.previous_response_id.as_deref(),
            Some("resp_xyz"),
            "build_request should chain with explicit previous_response_id"
        );
    }

    #[test]
    fn build_request_tool_outputs_chaining() {
        let cfg = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&cfg).unwrap();
        let messages = [Message::tool("call_1", "result data")];

        // Chained: tool outputs travel as function_call_output items.
        let chained = client.build_request(&messages, None, Some("resp_123"));
        assert_eq!(chained.previous_response_id.as_deref(), Some("resp_123"));

        // Stateless: no chaining ID gets attached.
        let stateless = client.build_request(&messages, None, None);
        assert!(
            stateless.previous_response_id.is_none(),
            "build_request must be stateless when no explicit ID"
        );
    }
}