1use crate::client::LlmClient;
17use crate::retry::RetryClient;
18use crate::schema::response_schema_for;
19use crate::tool::ToolDef;
20use crate::types::{LlmConfig, Message, SgrError, ToolCall};
21use schemars::JsonSchema;
22use serde::de::DeserializeOwned;
23use serde_json::Value;
24
/// Concrete transport the [`Llm`] facade dispatches to.
///
/// The oxide variants are wrapped in [`RetryClient`]; the genai variant is
/// only compiled in when the `genai` feature is enabled.
enum Backend {
    /// Default oxide HTTP client, behind the retry wrapper.
    Oxide(RetryClient<crate::oxide_client::OxideClient>),
    /// Oxide Chat Completions client, behind the retry wrapper.
    OxideChat(RetryClient<crate::oxide_chat_client::OxideChatClient>),
    #[cfg(feature = "genai")]
    Genai(crate::genai_client::GenaiClient),
    /// Shells out to a local CLI tool; not retry-wrapped (see `Llm::new`).
    Cli(crate::cli_client::CliClient),
}
35
/// Unified LLM entry point: picks a [`Backend`] once at construction time
/// (see `Llm::new` for the selection rules) and forwards every call to it.
pub struct Llm {
    // The selected backend; fixed for the lifetime of this value.
    inner: Backend,
}
40
impl Llm {
    /// Construct an `Llm`, choosing a backend from `config`.
    ///
    /// Selection order (first match wins):
    /// 1. `use_cli`          → CLI backend (model name picks the CLI flavor,
    ///    defaulting to `CliBackend::Claude`);
    /// 2. `use_genai`        → genai backend (only with feature `genai`);
    /// 3. `project_id` set   → genai backend (only with feature `genai`);
    /// 4. `use_chat_api`     → oxide Chat Completions client, if it builds;
    /// 5. otherwise          → oxide client, falling back to genai when that
    ///    feature is compiled in.
    ///
    /// # Panics
    /// Panics if `OxideClient::from_config` fails and the crate was built
    /// without the `genai` feature.
    pub fn new(config: &LlmConfig) -> Self {
        if config.use_cli {
            // Map the model name onto a CLI flavor; unknown names default to Claude.
            let backend = crate::cli_client::CliBackend::from_model(&config.model)
                .unwrap_or(crate::cli_client::CliBackend::Claude);
            let client = crate::cli_client::CliClient::new(backend).with_model(&config.model);
            tracing::debug!(model = %config.model, backend = "cli", "Llm backend selected");
            return Self {
                inner: Backend::Cli(client),
            };
        }

        // Explicit opt-in to the genai backend.
        #[cfg(feature = "genai")]
        if config.use_genai {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (explicit)");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // A project id also implies the genai backend.
        // NOTE(review): presumably a cloud-project-style credential — confirm
        // against GenaiClient::from_config.
        #[cfg(feature = "genai")]
        if config.project_id.is_some() {
            tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected");
            return Self {
                inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
            };
        }

        // Chat Completions API, but only when the client actually builds;
        // a construction error here silently falls through to the default
        // oxide path below.
        if config.use_chat_api
            && let Ok(client) = crate::oxide_chat_client::OxideChatClient::from_config(config)
        {
            tracing::debug!(model = %config.model, backend = "oxide-chat", "Llm backend selected (Chat Completions)");
            return Self {
                inner: Backend::OxideChat(RetryClient::new(client)),
            };
        }

        // Default path: plain oxide client wrapped in the retry layer.
        if let Ok(client) = crate::oxide_client::OxideClient::from_config(config) {
            tracing::debug!(model = %config.model, backend = "oxide", "Llm backend selected");
            Self {
                inner: Backend::Oxide(RetryClient::new(client)),
            }
        } else {
            // Last resort: genai when compiled in, otherwise nothing left to try.
            #[cfg(feature = "genai")]
            {
                tracing::debug!(model = %config.model, backend = "genai", "Llm backend selected (oxide fallback)");
                return Self {
                    inner: Backend::Genai(crate::genai_client::GenaiClient::from_config(config)),
                };
            }
            #[cfg(not(feature = "genai"))]
            panic!("OxideClient::from_config failed and genai feature not enabled");
        }
    }

    /// Borrow the active backend as a trait object for uniform dispatch.
    fn client(&self) -> &dyn LlmClient {
        match &self.inner {
            Backend::Oxide(c) => c,
            Backend::OxideChat(c) => c,
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c,
            Backend::Cli(c) => c,
        }
    }

    /// Open the websocket connection on the oxide backend.
    ///
    /// Only does work when the `oxide-ws` feature is enabled AND the active
    /// backend is `Backend::Oxide`; in every other case it is a no-op that
    /// returns `Ok(())`.
    pub async fn connect_ws(&self) -> Result<(), SgrError> {
        #[cfg(feature = "oxide-ws")]
        if let Backend::Oxide(c) = &self.inner {
            return c.inner().connect_ws().await;
        }
        Ok(())
    }

    /// Stream a completion, invoking `on_token` as text arrives.
    ///
    /// Only the genai backend streams incrementally; every other backend
    /// falls back to one non-streaming `generate` call and then invokes
    /// `on_token` exactly once with the full response text.
    pub async fn stream_complete<F>(
        &self,
        messages: &[Message],
        mut on_token: F,
    ) -> Result<String, SgrError>
    where
        F: FnMut(&str),
    {
        match &self.inner {
            #[cfg(feature = "genai")]
            Backend::Genai(c) => c.stream_complete(messages, on_token).await,
            Backend::Oxide(_) | Backend::OxideChat(_) | Backend::Cli(_) => {
                // No streaming support: emit the whole response as one "token".
                let text = self.generate(messages).await?;
                on_token(&text);
                Ok(text)
            }
        }
    }

    /// Plain text completion via the active backend.
    pub async fn generate(&self, messages: &[Message]) -> Result<String, SgrError> {
        self.client().complete(messages).await
    }

    /// Tool-calling round trip that threads `previous_response_id` through
    /// to the backend (presumably server-side conversation state — exact
    /// semantics are defined by each backend's `tools_call_stateful`).
    pub async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        self.client()
            .tools_call_stateful(messages, tools, previous_response_id)
            .await
    }

    /// Request a response constrained to the JSON schema of `T` and
    /// deserialize it.
    ///
    /// # Errors
    /// - `SgrError::EmptyResponse` when the backend produced no parsed value;
    /// - `SgrError::Schema` when the value fails to deserialize into `T`
    ///   (the raw response text is embedded in the message for debugging).
    pub async fn structured<T: JsonSchema + DeserializeOwned>(
        &self,
        messages: &[Message],
    ) -> Result<T, SgrError> {
        let schema = response_schema_for::<T>();
        let (parsed, _tool_calls, raw_text) =
            self.client().structured_call(messages, &schema).await?;
        match parsed {
            Some(value) => serde_json::from_value::<T>(value)
                .map_err(|e| SgrError::Schema(format!("Parse error: {e}\nRaw: {raw_text}"))),
            None => Err(SgrError::EmptyResponse),
        }
    }

    /// Static label of the active backend ("oxide", "oxide-chat", "genai",
    /// or "cli"); handy for logging and tests.
    pub fn backend_name(&self) -> &'static str {
        match &self.inner {
            Backend::Oxide(_) => "oxide",
            Backend::OxideChat(_) => "oxide-chat",
            #[cfg(feature = "genai")]
            Backend::Genai(_) => "genai",
            Backend::Cli(_) => "cli",
        }
    }
}
191
// Implementing `LlmClient` for `Llm` itself lets the facade be passed
// anywhere a concrete client is expected; every method is a straight
// delegation to the active backend selected in `Llm::new`.
#[async_trait::async_trait]
impl LlmClient for Llm {
    /// Delegates to the active backend's `structured_call`.
    async fn structured_call(
        &self,
        messages: &[Message],
        schema: &Value,
    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
        self.client().structured_call(messages, schema).await
    }

    /// Delegates to the active backend's `tools_call`.
    async fn tools_call(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
    ) -> Result<Vec<ToolCall>, SgrError> {
        self.client().tools_call(messages, tools).await
    }

    /// Delegates to the active backend's `tools_call_stateful`.
    async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        self.client()
            .tools_call_stateful(messages, tools, previous_response_id)
            .await
    }

    /// Delegates to the active backend's `complete`.
    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
        self.client().complete(messages).await
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// An endpoint-style config (API key + base URL) selects the oxide backend.
    #[test]
    fn llm_from_auto_config() {
        let cfg = LlmConfig::endpoint("sk-test-dummy", "https://api.openai.com/v1", "gpt-5.4");
        let client = Llm::new(&cfg);
        assert_eq!(client.backend_name(), "oxide");
    }

    /// Non-OpenAI base URLs also go through the oxide backend.
    #[test]
    fn llm_custom_endpoint_uses_oxide() {
        let cfg = LlmConfig::endpoint("sk-test", "https://openrouter.ai/api/v1", "gpt-5.4");
        assert_eq!(Llm::new(&cfg).backend_name(), "oxide");
    }

    /// A config serialized to JSON and parsed back preserves every field we set.
    #[test]
    fn llm_config_serde_roundtrip() {
        let original = LlmConfig::endpoint("key", "https://example.com/v1", "model")
            .temperature(0.9)
            .max_tokens(1000);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LlmConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model, "model");
        assert_eq!(decoded.api_key.as_deref(), Some("key"));
        assert_eq!(decoded.base_url.as_deref(), Some("https://example.com/v1"));
        assert_eq!(decoded.temp, 0.9);
        assert_eq!(decoded.max_tokens, Some(1000));
    }

    /// A minimal JSON config carrying only a model name gets the documented defaults.
    #[test]
    fn llm_config_auto_minimal_json() {
        let parsed: LlmConfig = serde_json::from_str(r#"{"model": "gpt-4o"}"#).unwrap();
        assert_eq!(parsed.model, "gpt-4o");
        assert!(parsed.api_key.is_none());
        assert_eq!(parsed.temp, 0.7);
    }
}
267}