// sgr_agent/llm.rs
//! Llm — provider-agnostic LLM client.
//!
//! Public API: `LlmConfig` + `Llm`. No provider-specific types leak.
//!
//! Backend selection:
//! - oxide (openai-oxide): primary, Responses API, works with OpenAI + OpenRouter + compatible
//! - genai (optional): fallback for Vertex AI (project_id set)
//!
//! ```no_run
//! use sgr_agent::{Llm, LlmConfig};
//!
//! let llm = Llm::new(&LlmConfig::auto("gpt-5.4"));
//! let llm = Llm::new(&LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"));
//! ```

16use crate::client::LlmClient;
17use crate::retry::RetryClient;
18use crate::schema::response_schema_for;
19use crate::tool::ToolDef;
20use crate::types::{LlmConfig, Message, SgrError, ToolCall};
21use schemars::JsonSchema;
22use serde::de::DeserializeOwned;
23use serde_json::Value;
24
/// Backend dispatch — resolved once at construction time in `Llm::new`.
/// All network backends are wrapped in `RetryClient` for automatic retry on
/// transient errors; the CLI variant runs a local subprocess and is not wrapped.
enum Backend {
    /// Primary backend: Responses API (OpenAI, OpenRouter, and compatibles).
    Oxide(RetryClient<crate::oxide_client::OxideClient>),
    /// Chat Completions mode for compat endpoints (selected via `use_chat_api`).
    OxideChat(RetryClient<crate::oxide_chat_client::OxideChatClient>),
    /// genai fallback (explicit `use_genai` or Vertex AI); needs the `genai` feature.
    #[cfg(feature = "genai")]
    Genai(crate::genai_client::GenaiClient),
    /// CLI subprocess (claude -p / gemini -p / codex exec).
    Cli(crate::cli_client::CliClient),
}
35
/// Provider-agnostic LLM client. Construct via `Llm::new(&LlmConfig)`.
/// Thin wrapper around the selected `Backend`; every call dispatches through it.
pub struct Llm {
    // Chosen once in `Llm::new`; never reassigned afterwards.
    inner: Backend,
}
40
impl Llm {
    /// Create from config. Backend auto-selected — first matching rule wins:
    /// - cli when `use_cli` is set (claude -p / gemini -p / codex exec)
    /// - genai when explicitly requested (`use_genai`) or for Vertex AI (project_id set)
    /// - oxide-chat for Chat Completions compat endpoints
    /// - oxide for all other models (primary)
    ///
    /// # Panics
    /// Panics when `OxideClient::from_config` fails and the `genai` feature is
    /// not compiled in — there is no remaining backend to fall back to.
    pub fn new(config: &LlmConfig) -> Self {
        // CLI subprocess backend (claude -p / gemini -p / codex exec)
        if config.use_cli {
            // Unrecognized model names default to the Claude CLI binary.
            let backend = crate::cli_client::CliBackend::from_model(&config.model)
                .unwrap_or(crate::cli_client::CliBackend::Claude);
            let client = crate::cli_client::CliClient::new(backend).with_model(&config.model);
            tracing::debug!(model = %config.model, backend = "cli", "Llm backend selected");
            return Self {
                inner: Backend::Cli(client),
            };
        }

        // Explicit genai backend (e.g. Anthropic native API)
        #[cfg(feature = "genai")]
        if config.use_genai {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (explicit)");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // Vertex AI needs genai (gcloud ADC auth)
        // NOTE(review): without the `genai` feature, `use_genai`/`project_id` are
        // silently ignored and selection falls through to oxide — confirm intended.
        #[cfg(feature = "genai")]
        if config.project_id.is_some() {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // Chat Completions mode for compat endpoints (Cloudflare, OpenRouter compat, etc.)
        // If OxideChatClient construction fails, we fall through to the oxide
        // Responses backend below — presumably intentional best-effort; verify.
        if config.use_chat_api
            && let Ok(client) = crate::oxide_chat_client::OxideChatClient::from_config(config)
        {
            tracing::debug!(model = %config.model, backend = "oxide-chat", "Llm backend selected (Chat Completions)");
            return Self {
                inner: Backend::OxideChat(RetryClient::new(client)),
            };
        }

        if let Ok(client) = crate::oxide_client::OxideClient::from_config(config) {
            tracing::debug!(model = %config.model, backend = "oxide", "Llm backend selected");
            Self {
                inner: Backend::Oxide(RetryClient::new(client)),
            }
        } else {
            // oxide construction failed: use genai if it was compiled in,
            // otherwise there is nothing left to try (see # Panics above).
            #[cfg(feature = "genai")]
            {
                tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (oxide fallback)");
                return Self {
                    inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
                };
            }
            #[cfg(not(feature = "genai"))]
            panic!("OxideClient::from_config failed and genai feature not enabled");
        }
    }

    /// Get a reference to the inner LlmClient trait object for dispatch.
    fn client(&self) -> &dyn LlmClient {
        match &self.inner {
            Backend::Oxide(c) => c,
            Backend::OxideChat(c) => c,
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c,
            Backend::Cli(c) => c,
        }
    }

    /// Upgrade to WebSocket mode for lower latency (oxide backend only, and
    /// only when built with the `oxide-ws` feature).
    /// Returns `Ok(())` without doing anything for every other backend.
    pub async fn connect_ws(&self) -> Result<(), SgrError> {
        #[cfg(feature = "oxide-ws")]
        if let Backend::Oxide(c) = &self.inner {
            // `inner()` reaches through the RetryClient wrapper to the raw client.
            return c.inner().connect_ws().await;
        }
        Ok(())
    }

    /// Stream text completion, calling `on_token` for each chunk.
    ///
    /// Only the genai backend streams natively; all other backends generate
    /// the full text and then invoke `on_token` exactly once with it.
    pub async fn stream_complete<F>(
        &self,
        messages: &[Message],
        mut on_token: F,
    ) -> Result<String, SgrError>
    where
        F: FnMut(&str),
    {
        match &self.inner {
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c.stream_complete(messages, on_token).await,
            Backend::Oxide(_) | Backend::OxideChat(_) | Backend::Cli(_) => {
                // Non-streaming backends — generate full text,
                // then invoke on_token so callers (e.g. TTS, TUI) get the content.
                let text = self.generate(messages).await?;
                on_token(&text);
                Ok(text)
            }
        }
    }

    /// Non-streaming text completion.
    pub async fn generate(&self, messages: &[Message]) -> Result<String, SgrError> {
        self.client().complete(messages).await
    }

    /// Function calling with stateful session support (Responses API).
    /// Delegates to the trait method — each backend implements its own version.
    /// Returns the tool calls plus an optional response id to thread into the
    /// next call as `previous_response_id`.
    pub async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        self.client()
            .tools_call_stateful(messages, tools, previous_response_id)
            .await
    }

    /// Structured output — generates a JSON schema from `T`, requests
    /// schema-constrained output, and deserializes the result into `T`.
    ///
    /// # Errors
    /// - `SgrError::Schema` when the model's JSON fails to deserialize into `T`
    ///   (the raw response text is included in the message for debugging)
    /// - `SgrError::EmptyResponse` when the backend returned no parsed value
    pub async fn structured<T: JsonSchema + DeserializeOwned>(
        &self,
        messages: &[Message],
    ) -> Result<T, SgrError> {
        let schema = response_schema_for::<T>();
        let (parsed, _tool_calls, raw_text) =
            self.client().structured_call(messages, &schema).await?;
        match parsed {
            Some(value) => serde_json::from_value::<T>(value)
                .map_err(|e| SgrError::Schema(format!("Parse error: {e}\nRaw: {raw_text}"))),
            None => Err(SgrError::EmptyResponse),
        }
    }

    /// Which backend is active: "oxide", "oxide-chat", "genai", or "cli".
    pub fn backend_name(&self) -> &'static str {
        match &self.inner {
            Backend::Oxide(_) => "oxide",
            Backend::OxideChat(_) => "oxide-chat",
            #[cfg(feature = "genai")]
            Backend::Genai(_) => "genai",
            Backend::Cli(_) => "cli",
        }
    }
}
191
192#[async_trait::async_trait]
193impl LlmClient for Llm {
194    async fn structured_call(
195        &self,
196        messages: &[Message],
197        schema: &Value,
198    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
199        self.client().structured_call(messages, schema).await
200    }
201
202    async fn tools_call(
203        &self,
204        messages: &[Message],
205        tools: &[ToolDef],
206    ) -> Result<Vec<ToolCall>, SgrError> {
207        self.client().tools_call(messages, tools).await
208    }
209
210    async fn tools_call_stateful(
211        &self,
212        messages: &[Message],
213        tools: &[ToolDef],
214        previous_response_id: Option<&str>,
215    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
216        self.client()
217            .tools_call_stateful(messages, tools, previous_response_id)
218            .await
219    }
220
221    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
222        self.client().complete(messages).await
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn llm_from_auto_config() {
        // OxideClient::from_config requires an API key — supply one through the config.
        let cfg = LlmConfig::endpoint("sk-test-dummy", "https://api.openai.com/v1", "gpt-5.4");
        assert_eq!(Llm::new(&cfg).backend_name(), "oxide");
    }

    #[test]
    fn llm_custom_endpoint_uses_oxide() {
        // A custom base URL still routes through the primary oxide backend.
        let cfg = LlmConfig::endpoint("sk-test", "https://openrouter.ai/api/v1", "gpt-5.4");
        assert_eq!(Llm::new(&cfg).backend_name(), "oxide");
    }

    #[test]
    fn llm_config_serde_roundtrip() {
        // Serialize a fully-populated config and verify every field survives.
        let original = LlmConfig::endpoint("key", "https://example.com/v1", "model")
            .temperature(0.9)
            .max_tokens(1000);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LlmConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model, "model");
        assert_eq!(decoded.api_key.as_deref(), Some("key"));
        assert_eq!(decoded.base_url.as_deref(), Some("https://example.com/v1"));
        assert_eq!(decoded.temp, 0.9);
        assert_eq!(decoded.max_tokens, Some(1000));
    }

    #[test]
    fn llm_config_auto_minimal_json() {
        // Only `model` is required; the remaining fields take their defaults.
        let config: LlmConfig = serde_json::from_str(r#"{"model": "gpt-4o"}"#).unwrap();
        assert_eq!(config.model, "gpt-4o");
        assert!(config.api_key.is_none());
        assert_eq!(config.temp, 0.7);
    }
}