1use std::collections::BTreeMap;
21
22use async_trait::async_trait;
23use eventsource_stream::Eventsource;
24use futures::StreamExt;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27
28use crate::error::ProviderError;
29use crate::message::{Content, Message, Role, StopReason, Usage};
30use crate::provider::{LlmProvider, Request, Response};
31use crate::stream::{ProviderEventStream, StreamEvent};
32
/// Base URL used by [`OpenAICompatible::new`] until `with_base_url` replaces it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
34
35pub struct OpenAICompatible {
36 api_key: String,
37 base_url: String,
38 client: reqwest::Client,
39}
40
41impl OpenAICompatible {
42 pub fn new(api_key: impl Into<String>) -> Self {
43 Self {
44 api_key: api_key.into(),
45 base_url: DEFAULT_BASE_URL.to_string(),
46 client: reqwest::Client::new(),
47 }
48 }
49
50 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
57 self.base_url = base_url.into();
58 self
59 }
60
61 pub fn from_env() -> Self {
63 let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY env var is required");
64 Self::new(api_key)
65 }
66}
67
68#[async_trait]
69impl LlmProvider for OpenAICompatible {
70 async fn stream(&self, request: Request) -> Result<ProviderEventStream, ProviderError> {
71 let mut body = build_request_body(&request);
72 body.stream = true;
73 body.stream_options = Some(StreamOptions {
74 include_usage: true,
75 });
76 let url = format!("{}/chat/completions", self.base_url);
77
78 let response = self
79 .client
80 .post(&url)
81 .bearer_auth(&self.api_key)
82 .header("content-type", "application/json")
83 .header("accept", "text/event-stream")
84 .json(&body)
85 .send()
86 .await?;
87
88 let status = response.status().as_u16();
89
90 if status >= 400 {
91 let retry_after_ms = parse_retry_after(response.headers());
92 let text = response.text().await.unwrap_or_default();
93 return Err(classify_error(status, text, retry_after_ms));
94 }
95
96 let event_stream = response.bytes_stream().eventsource();
97 Ok(Box::pin(openai_event_stream(event_stream)))
98 }
99
100 async fn complete(&self, request: Request) -> Result<Response, ProviderError> {
101 let body = build_request_body(&request);
102 let url = format!("{}/chat/completions", self.base_url);
103
104 let response = self
105 .client
106 .post(&url)
107 .bearer_auth(&self.api_key)
108 .header("content-type", "application/json")
109 .json(&body)
110 .send()
111 .await?;
112
113 let status = response.status().as_u16();
114
115 if status >= 400 {
116 let retry_after_ms = parse_retry_after(response.headers());
117 let text = response.text().await.unwrap_or_default();
118 return Err(classify_error(status, text, retry_after_ms));
119 }
120
121 let body = response.text().await?;
126 let api_response: ApiResponse = serde_json::from_str(&body)?;
127 convert_response(api_response)
128 }
129}
130
131fn classify_error(status: u16, message: String, retry_after_ms: Option<u64>) -> ProviderError {
132 match status {
133 429 => ProviderError::RateLimit { retry_after_ms },
134 503 => ProviderError::Overloaded { retry_after_ms },
135 500 | 502 | 504 => ProviderError::Api {
136 status,
137 message,
138 retryable: true,
139 },
140 s => ProviderError::Api {
141 status: s,
142 message,
143 retryable: (500..600).contains(&s),
144 },
145 }
146}
147
148fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<u64> {
149 let raw = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?;
150 raw.trim().parse::<u64>().ok().map(|s| s * 1_000)
151}
152
153#[derive(Serialize)]
156struct ApiRequest {
157 model: String,
158 messages: Vec<ApiMessage>,
159 #[serde(skip_serializing_if = "Vec::is_empty")]
160 tools: Vec<ApiTool>,
161 #[serde(skip_serializing_if = "Option::is_none")]
162 max_tokens: Option<u32>,
163 #[serde(skip_serializing_if = "Option::is_none")]
164 temperature: Option<f32>,
165 #[serde(skip_serializing_if = "std::ops::Not::not")]
167 stream: bool,
168 #[serde(skip_serializing_if = "Option::is_none")]
172 stream_options: Option<StreamOptions>,
173}
174
175#[derive(Serialize)]
176struct StreamOptions {
177 include_usage: bool,
178}
179
180#[derive(Serialize)]
181#[serde(untagged)]
182enum ApiMessage {
183 Simple { role: &'static str, content: String },
185 Assistant {
187 role: &'static str,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 content: Option<String>,
190 #[serde(skip_serializing_if = "Vec::is_empty")]
191 tool_calls: Vec<ApiToolCallOut>,
192 },
193 Tool {
195 role: &'static str,
196 tool_call_id: String,
197 content: String,
198 },
199}
200
201#[derive(Serialize)]
202struct ApiToolCallOut {
203 id: String,
204 #[serde(rename = "type")]
205 kind: &'static str,
206 function: ApiFunctionOut,
207}
208
209#[derive(Serialize)]
210struct ApiFunctionOut {
211 name: String,
212 arguments: String,
214}
215
216#[derive(Serialize)]
217struct ApiTool {
218 #[serde(rename = "type")]
219 kind: &'static str,
220 function: ApiFunctionDef,
221}
222
223#[derive(Serialize)]
224struct ApiFunctionDef {
225 name: String,
226 description: String,
227 parameters: Value,
228}
229
230#[derive(Deserialize)]
231struct ApiResponse {
232 choices: Vec<ApiChoice>,
233 #[serde(default)]
234 usage: Option<ApiUsage>,
235}
236
237#[derive(Deserialize)]
238struct ApiChoice {
239 message: ApiResponseMessage,
240 #[serde(default)]
241 finish_reason: Option<String>,
242}
243
244#[derive(Deserialize)]
245struct ApiResponseMessage {
246 #[serde(default)]
247 content: Option<String>,
248 #[serde(default)]
249 tool_calls: Vec<ApiToolCallIn>,
250}
251
252#[derive(Deserialize)]
253struct ApiToolCallIn {
254 id: String,
255 #[serde(default)]
256 function: ApiFunctionIn,
257}
258
259#[derive(Deserialize, Default)]
260struct ApiFunctionIn {
261 #[serde(default)]
262 name: String,
263 #[serde(default)]
264 arguments: String,
265}
266
267#[derive(Deserialize)]
268struct ApiUsage {
269 #[serde(default)]
270 prompt_tokens: u32,
271 #[serde(default)]
272 completion_tokens: u32,
273}
274
275fn build_request_body(request: &Request) -> ApiRequest {
278 let mut messages: Vec<ApiMessage> = Vec::new();
279
280 if let Some(blocks) = request.system.as_ref() {
285 if !blocks.is_empty() {
286 let joined = blocks
287 .iter()
288 .map(|b| b.text.as_str())
289 .collect::<Vec<_>>()
290 .join("\n\n");
291 messages.push(ApiMessage::Simple {
292 role: "system",
293 content: joined,
294 });
295 }
296 }
297
298 for msg in &request.messages {
299 extend_with_message(&mut messages, msg);
300 }
301
302 let tools = request
303 .tools
304 .iter()
305 .map(|t| ApiTool {
306 kind: "function",
307 function: ApiFunctionDef {
308 name: t.name.clone(),
309 description: t.description.clone(),
310 parameters: t.input_schema.clone(),
311 },
312 })
313 .collect();
314
315 ApiRequest {
316 model: request.model.clone(),
317 messages,
318 tools,
319 max_tokens: Some(request.max_tokens),
320 temperature: request.temperature,
321 stream: false,
322 stream_options: None,
323 }
324}
325
326fn extend_with_message(out: &mut Vec<ApiMessage>, msg: &Message) {
333 match msg.role {
334 Role::User => {
335 let mut text_buf = String::new();
336 for c in &msg.content {
337 match c {
338 Content::Text { text, .. } => {
339 if !text_buf.is_empty() {
340 text_buf.push('\n');
341 }
342 text_buf.push_str(text);
343 }
344 Content::ToolResult {
345 tool_use_id,
346 content,
347 is_error,
348 ..
349 } => {
350 if !text_buf.is_empty() {
352 out.push(ApiMessage::Simple {
353 role: "user",
354 content: std::mem::take(&mut text_buf),
355 });
356 }
357 let wire_content = if *is_error {
364 format!("[error] {content}")
365 } else {
366 content.clone()
367 };
368 out.push(ApiMessage::Tool {
369 role: "tool",
370 tool_call_id: tool_use_id.clone(),
371 content: wire_content,
372 });
373 }
374 Content::ToolUse { .. } | Content::Thinking { .. } => {
375 }
377 }
378 }
379 if !text_buf.is_empty() {
380 out.push(ApiMessage::Simple {
381 role: "user",
382 content: text_buf,
383 });
384 }
385 }
386 Role::Assistant => {
387 let mut text_parts: Vec<String> = Vec::new();
388 let mut tool_calls: Vec<ApiToolCallOut> = Vec::new();
389 for c in &msg.content {
390 match c {
391 Content::Text { text, .. } => text_parts.push(text.clone()),
392 Content::ToolUse { id, name, input } => {
393 tool_calls.push(ApiToolCallOut {
394 id: id.clone(),
395 kind: "function",
396 function: ApiFunctionOut {
397 name: name.clone(),
398 arguments: serde_json::to_string(input)
400 .unwrap_or_else(|_| "{}".to_string()),
401 },
402 });
403 }
404 Content::Thinking { .. } | Content::ToolResult { .. } => {
405 }
408 }
409 }
410 if text_parts.is_empty() && tool_calls.is_empty() {
417 return;
418 }
419 out.push(ApiMessage::Assistant {
420 role: "assistant",
421 content: if text_parts.is_empty() {
422 None
423 } else {
424 Some(text_parts.join(""))
425 },
426 tool_calls,
427 });
428 }
429 }
430}
431
432fn convert_response(api: ApiResponse) -> Result<Response, ProviderError> {
433 let choice = api
434 .choices
435 .into_iter()
436 .next()
437 .ok_or_else(|| ProviderError::Other("response had no choices".into()))?;
438
439 let mut content: Vec<Content> = Vec::new();
440 if let Some(text) = choice.message.content {
441 if !text.is_empty() {
442 content.push(Content::text(text));
443 }
444 }
445 for tc in choice.message.tool_calls {
446 let input = if tc.function.arguments.trim().is_empty() {
450 Value::Object(Default::default())
451 } else {
452 serde_json::from_str(&tc.function.arguments)
453 .unwrap_or(Value::Object(Default::default()))
454 };
455 content.push(Content::ToolUse {
456 id: tc.id,
457 name: tc.function.name,
458 input,
459 });
460 }
461
462 let has_tool_use = content.iter().any(|c| matches!(c, Content::ToolUse { .. }));
463
464 let stop_reason = match choice.finish_reason.as_deref() {
465 Some("stop") => StopReason::EndTurn,
466 Some("tool_calls") | Some("function_call") => StopReason::ToolUse,
467 Some("length") => StopReason::MaxTokens,
468 Some("content_filter") => StopReason::EndTurn,
469 Some("stop_sequence") => StopReason::StopSequence,
472 _ if has_tool_use => StopReason::ToolUse,
477 _ => StopReason::EndTurn,
478 };
479
480 let usage = api
481 .usage
482 .map(|u| Usage {
483 input_tokens: u.prompt_tokens,
484 output_tokens: u.completion_tokens,
485 cache_creation_input_tokens: 0,
486 cache_read_input_tokens: 0,
487 })
488 .unwrap_or_default();
489
490 Ok(Response {
491 content,
492 stop_reason,
493 usage,
494 })
495}
496
497#[derive(Deserialize)]
514struct ChatChunk {
515 #[serde(default)]
516 choices: Vec<ChatChoice>,
517 #[serde(default)]
518 usage: Option<ChunkUsage>,
519}
520
521#[derive(Deserialize)]
522struct ChatChoice {
523 #[serde(default)]
524 delta: ChatDelta,
525 #[serde(default)]
526 finish_reason: Option<String>,
527}
528
529#[derive(Deserialize, Default)]
530struct ChatDelta {
531 #[serde(default)]
532 content: Option<String>,
533 #[serde(default)]
534 tool_calls: Vec<ToolCallChunk>,
535}
536
537#[derive(Deserialize)]
538struct ToolCallChunk {
539 #[serde(default)]
542 index: Option<usize>,
543 #[serde(default)]
544 id: Option<String>,
545 #[serde(default)]
546 function: Option<ToolCallFunctionChunk>,
547}
548
549#[derive(Deserialize, Default)]
550struct ToolCallFunctionChunk {
551 #[serde(default)]
552 name: Option<String>,
553 #[serde(default)]
554 arguments: Option<String>,
555}
556
557#[derive(Deserialize)]
558struct ChunkUsage {
559 #[serde(default)]
560 prompt_tokens: u32,
561 #[serde(default)]
562 completion_tokens: u32,
563}
564
565#[derive(Default)]
566struct ToolSlot {
567 id: String,
568 name: String,
569 args_buf: String,
570}
571
572struct StreamState<S> {
575 sse: S,
576 slots: BTreeMap<usize, ToolSlot>,
579 pending_stop: Option<StopReason>,
581 buffer: std::collections::VecDeque<Result<StreamEvent, ProviderError>>,
583 emitted_done: bool,
586}
587
588fn openai_event_stream<S>(sse: S) -> impl futures::Stream<Item = Result<StreamEvent, ProviderError>>
589where
590 S: futures::Stream<
591 Item = Result<
592 eventsource_stream::Event,
593 eventsource_stream::EventStreamError<reqwest::Error>,
594 >,
595 > + Send
596 + Unpin
597 + 'static,
598{
599 use std::collections::VecDeque;
600
601 let initial = StreamState {
602 sse,
603 slots: BTreeMap::new(),
604 pending_stop: None,
605 buffer: VecDeque::new(),
606 emitted_done: false,
607 };
608
609 futures::stream::unfold(initial, |mut state| async move {
610 loop {
611 if let Some(ev) = state.buffer.pop_front() {
612 return Some((ev, state));
613 }
614
615 if state.emitted_done {
616 return None;
617 }
618
619 let next = state.sse.next().await;
620 let event = match next {
621 None => {
622 flush_terminal(&mut state.slots, &mut state.pending_stop, &mut state.buffer);
625 if state.buffer.is_empty() {
626 return None;
627 }
628 state.emitted_done = true;
629 continue;
630 }
631 Some(Ok(ev)) => ev,
632 Some(Err(e)) => {
633 let err = ProviderError::Other(format!("SSE read error: {e}"));
634 return Some((Err(err), state));
635 }
636 };
637
638 let data = event.data.trim();
639 if data == "[DONE]" {
640 flush_terminal(&mut state.slots, &mut state.pending_stop, &mut state.buffer);
641 state.emitted_done = true;
642 continue;
643 }
644 if data.is_empty() {
645 continue;
646 }
647
648 let chunk: ChatChunk = match serde_json::from_str(data) {
649 Ok(c) => c,
650 Err(_) => continue,
651 };
652
653 process_chunk(
654 chunk,
655 &mut state.slots,
656 &mut state.pending_stop,
657 &mut state.buffer,
658 );
659 }
660 })
661}
662
663fn process_chunk(
664 chunk: ChatChunk,
665 slots: &mut BTreeMap<usize, ToolSlot>,
666 pending_stop: &mut Option<StopReason>,
667 buffer: &mut std::collections::VecDeque<Result<StreamEvent, ProviderError>>,
668) {
669 if let Some(choice) = chunk.choices.into_iter().next() {
670 if let Some(text) = choice.delta.content {
671 if !text.is_empty() {
672 buffer.push_back(Ok(StreamEvent::ContentDelta(text)));
673 }
674 }
675 for tc in choice.delta.tool_calls {
676 let idx = tc.index.unwrap_or(slots.len());
677 let slot = slots.entry(idx).or_default();
678 if let Some(id) = tc.id {
679 slot.id = id;
680 }
681 if let Some(f) = tc.function {
682 if let Some(name) = f.name {
683 slot.name = name;
684 }
685 if let Some(args) = f.arguments {
686 slot.args_buf.push_str(&args);
687 }
688 }
689 }
690 if let Some(reason) = choice.finish_reason {
691 *pending_stop = Some(map_finish_reason(&reason));
692 }
693 }
694
695 if let Some(usage) = chunk.usage {
696 buffer.push_back(Ok(StreamEvent::Usage(Usage {
697 input_tokens: usage.prompt_tokens,
698 output_tokens: usage.completion_tokens,
699 cache_creation_input_tokens: 0,
700 cache_read_input_tokens: 0,
701 })));
702 }
703}
704
705fn flush_terminal(
708 slots: &mut BTreeMap<usize, ToolSlot>,
709 pending_stop: &mut Option<StopReason>,
710 buffer: &mut std::collections::VecDeque<Result<StreamEvent, ProviderError>>,
711) {
712 for (_, slot) in std::mem::take(slots) {
714 if slot.id.is_empty() && slot.name.is_empty() {
715 continue;
716 }
717 let input: Value = if slot.args_buf.trim().is_empty() {
718 Value::Object(Default::default())
719 } else {
720 serde_json::from_str(&slot.args_buf).unwrap_or(Value::Object(Default::default()))
721 };
722 buffer.push_back(Ok(StreamEvent::ToolUse {
723 id: slot.id,
724 name: slot.name,
725 input,
726 }));
727 }
728 if let Some(stop) = pending_stop.take() {
729 buffer.push_back(Ok(StreamEvent::MessageDelta { stop_reason: stop }));
730 }
731 buffer.push_back(Ok(StreamEvent::Done));
732}
733
734fn map_finish_reason(reason: &str) -> StopReason {
735 match reason {
736 "stop" => StopReason::EndTurn,
737 "tool_calls" | "function_call" => StopReason::ToolUse,
738 "length" => StopReason::MaxTokens,
739 "content_filter" => StopReason::EndTurn,
740 "stop_sequence" => StopReason::StopSequence,
741 _ => StopReason::EndTurn,
742 }
743}
744
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::CacheControl;
    use crate::provider::SystemBlock;

    /// System text and user text land in separate messages, with scalar
    /// request fields passed through.
    #[test]
    fn request_maps_system_and_user_text() {
        let req = Request {
            model: "gpt-4".into(),
            system: Some(vec![SystemBlock::text("be brief")]),
            messages: vec![Message::user_text("hi")],
            tools: vec![],
            max_tokens: 100,
            temperature: Some(0.5),
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["model"], "gpt-4");
        assert_eq!(json["messages"][0]["role"], "system");
        assert_eq!(json["messages"][0]["content"], "be brief");
        assert_eq!(json["messages"][1]["role"], "user");
        assert_eq!(json["messages"][1]["content"], "hi");
        assert_eq!(json["temperature"], 0.5);
        assert_eq!(json["max_tokens"], 100);
    }

    /// Several system blocks collapse into one message joined by blank lines.
    #[test]
    fn multiple_system_blocks_concatenate_with_double_newline() {
        let req = Request {
            model: "gpt-4".into(),
            system: Some(vec![
                SystemBlock::text("base instructions"),
                SystemBlock::cached("long stable context"),
                SystemBlock::text("final tail"),
            ]),
            messages: vec![Message::user_text("hi")],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["messages"][0]["role"], "system");
        assert_eq!(
            json["messages"][0]["content"],
            "base instructions\n\nlong stable context\n\nfinal tail"
        );
    }

    /// `Some(vec![])` must not produce an empty system message.
    #[test]
    fn empty_system_vec_emits_no_system_message() {
        let req = Request {
            model: "gpt-4".into(),
            system: Some(vec![]),
            messages: vec![Message::user_text("hi")],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["messages"][0]["role"], "user");
    }

    /// Cache-control hints on tool definitions never reach the wire.
    #[test]
    fn tool_definition_cache_control_is_ignored_silently() {
        use crate::provider::ToolDefinition;
        let req = Request {
            model: "gpt-4".into(),
            system: None,
            messages: vec![Message::user_text("hi")],
            tools: vec![ToolDefinition {
                name: "bash".into(),
                description: "run a shell command".into(),
                input_schema: serde_json::json!({"type": "object"}),
                cache_control: Some(CacheControl::ephemeral()),
            }],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let tool = &json["tools"][0];
        assert!(tool.get("cache_control").is_none());
        assert_eq!(tool["function"]["name"], "bash");
    }

    /// Each tool result in a user message becomes its own `tool` message.
    #[test]
    fn request_fans_out_tool_results_to_separate_tool_messages() {
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![Message::user(vec![
                Content::tool_result("call_1", "ok", false),
                Content::tool_result("call_2", "bad", true),
            ])],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let msgs = json["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0]["role"], "tool");
        assert_eq!(msgs[0]["tool_call_id"], "call_1");
        assert_eq!(msgs[1]["role"], "tool");
        assert_eq!(msgs[1]["tool_call_id"], "call_2");
    }

    /// An assistant message holding only thinking content is dropped entirely.
    #[test]
    fn request_skips_thinking_only_messages() {
        use crate::message::{ThinkingMetadata, ThinkingProvider};
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![
                Message::assistant(vec![Content::thinking(
                    "hidden",
                    ThinkingProvider::OpenAIResponses,
                    ThinkingMetadata::openai_responses(Some("rs_1".into()), None, 0, None),
                )]),
                Message::user_text("next"),
            ],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let msgs = json["messages"].as_array().unwrap();

        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0]["role"], "user");
        assert_eq!(msgs[0]["content"], "next");
    }

    /// Removing a thinking block between two text blocks adds no separator.
    #[test]
    fn request_drops_thinking_without_inserting_text_separator() {
        use crate::message::{ThinkingMetadata, ThinkingProvider};
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![Message::assistant(vec![
                Content::text("Hello"),
                Content::thinking(
                    "hidden",
                    ThinkingProvider::OpenAIResponses,
                    ThinkingMetadata::openai_responses(Some("rs_1".into()), None, 0, None),
                ),
                Content::text("world"),
            ])],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();

        assert_eq!(json["messages"][0]["role"], "assistant");
        assert_eq!(json["messages"][0]["content"], "Helloworld");
    }

    /// Assistant tool use is sent as `tool_calls` with string-encoded arguments.
    #[test]
    fn request_encodes_assistant_tool_use_as_tool_calls_with_string_arguments() {
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![Message::assistant(vec![
                Content::text("let me check"),
                Content::ToolUse {
                    id: "call_x".into(),
                    name: "bash".into(),
                    input: serde_json::json!({"command": "ls"}),
                },
            ])],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let msg = &json["messages"][0];
        assert_eq!(msg["role"], "assistant");
        assert_eq!(msg["content"], "let me check");
        let tc = &msg["tool_calls"][0];
        assert_eq!(tc["id"], "call_x");
        assert_eq!(tc["type"], "function");
        assert_eq!(tc["function"]["name"], "bash");
        // `arguments` must be a JSON *string* that itself parses as JSON.
        let args_str = tc["function"]["arguments"].as_str().unwrap();
        let parsed: Value = serde_json::from_str(args_str).unwrap();
        assert_eq!(parsed["command"], "ls");
    }

    /// Vendor-specific reasoning fields on a reply are not surfaced as content.
    #[test]
    fn response_ignores_non_standard_reasoning_fields() {
        let raw = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "visible",
                    "reasoning": "not a standardized chat-completions field",
                    "reasoning_content": "not safe to expose by default",
                    "thinking": "also ignored by default"
                },
                "finish_reason": "stop"
            }],
            "usage": { "prompt_tokens": 2, "completion_tokens": 3 }
        });
        let api: ApiResponse = serde_json::from_value(raw).unwrap();
        let resp = convert_response(api).unwrap();

        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert_eq!(resp.content.len(), 1);
        assert!(matches!(
            &resp.content[0],
            Content::Text { text, .. } if text == "visible"
        ));
    }

    /// Vendor-specific reasoning fields on a chunk produce no extra events.
    #[test]
    fn streaming_ignores_non_standard_reasoning_fields() {
        use std::collections::{BTreeMap, VecDeque};

        let chunk: ChatChunk = serde_json::from_value(serde_json::json!({
            "choices": [{
                "delta": {
                    "content": "visible",
                    "reasoning": "not standardized",
                    "reasoning_content": "not safe",
                    "thinking": "not safe"
                }
            }]
        }))
        .unwrap();
        let mut slots = BTreeMap::new();
        let mut pending_stop = None;
        let mut buffer = VecDeque::new();
        process_chunk(chunk, &mut slots, &mut pending_stop, &mut buffer);

        assert!(matches!(
            buffer.pop_front().unwrap().unwrap(),
            StreamEvent::ContentDelta(text) if text == "visible"
        ));
        assert!(buffer.is_empty());
    }

    /// A chunk carrying only reasoning content changes no state at all.
    #[test]
    fn streaming_reasoning_only_chunk_emits_no_thinking() {
        use std::collections::{BTreeMap, VecDeque};

        let chunk: ChatChunk = serde_json::from_value(serde_json::json!({
            "choices": [{
                "delta": {
                    "reasoning_content": "not safe to expose by default"
                }
            }]
        }))
        .unwrap();
        let mut slots = BTreeMap::new();
        let mut pending_stop = None;
        let mut buffer = VecDeque::new();
        process_chunk(chunk, &mut slots, &mut pending_stop, &mut buffer);

        assert!(buffer.is_empty());
        assert!(pending_stop.is_none());
        assert!(slots.is_empty());
    }

    /// Text and tool calls decode into ordered content blocks with usage.
    #[test]
    fn response_decodes_text_and_tool_calls() {
        let raw = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "calling a tool",
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "bash",
                            "arguments": "{\"command\":\"echo hi\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": { "prompt_tokens": 10, "completion_tokens": 3 }
        });
        let api: ApiResponse = serde_json::from_value(raw).unwrap();
        let resp = convert_response(api).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
        assert_eq!(resp.usage.input_tokens, 10);
        assert_eq!(resp.usage.output_tokens, 3);
        match &resp.content[0] {
            Content::Text { text, .. } => assert_eq!(text, "calling a tool"),
            _ => panic!("expected text"),
        }
        match &resp.content[1] {
            Content::ToolUse { id, name, input } => {
                assert_eq!(id, "call_1");
                assert_eq!(name, "bash");
                assert_eq!(input["command"], "echo hi");
            }
            _ => panic!("expected tool_use"),
        }
    }

    /// Every recognised `finish_reason`, plus an unknown one, maps as expected.
    #[test]
    fn response_maps_finish_reasons() {
        fn stop_for(reason: &str) -> StopReason {
            let raw = serde_json::json!({
                "choices": [{
                    "message": {"role": "assistant", "content": ""},
                    "finish_reason": reason
                }]
            });
            let api: ApiResponse = serde_json::from_value(raw).unwrap();
            convert_response(api).unwrap().stop_reason
        }
        assert_eq!(stop_for("stop"), StopReason::EndTurn);
        assert_eq!(stop_for("length"), StopReason::MaxTokens);
        assert_eq!(stop_for("tool_calls"), StopReason::ToolUse);
        assert_eq!(stop_for("function_call"), StopReason::ToolUse);
        assert_eq!(stop_for("content_filter"), StopReason::EndTurn);
        assert_eq!(stop_for("stop_sequence"), StopReason::StopSequence);
        // Unknown reasons with no tool calls fall back to a normal end of turn.
        assert_eq!(stop_for("something_new"), StopReason::EndTurn);
    }

    /// Status codes map to the right error variant and retryability.
    #[test]
    fn classify_maps_retryable_status_codes() {
        assert!(matches!(
            classify_error(429, "".into(), Some(1000)),
            ProviderError::RateLimit {
                retry_after_ms: Some(1000)
            }
        ));
        assert!(matches!(
            classify_error(503, "".into(), None),
            ProviderError::Overloaded {
                retry_after_ms: None
            }
        ));
        assert!(matches!(
            classify_error(500, "oops".into(), None),
            ProviderError::Api {
                retryable: true,
                ..
            }
        ));
        assert!(matches!(
            classify_error(400, "bad".into(), None),
            ProviderError::Api {
                retryable: false,
                ..
            }
        ));
    }

    /// With no `finish_reason`, the presence of tool calls implies `ToolUse`.
    #[test]
    fn response_infers_tool_use_when_finish_reason_missing() {
        let raw = serde_json::json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "bash", "arguments": "{}"}
                    }]
                }
            }]
        });
        let api: ApiResponse = serde_json::from_value(raw).unwrap();
        let resp = convert_response(api).unwrap();
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    /// Error tool results are prefixed with `[error] `; successes are verbatim.
    #[test]
    fn request_marks_error_tool_results_with_prefix() {
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![Message::user(vec![
                Content::tool_result("call_ok", "all good", false),
                Content::tool_result("call_bad", "something broke", true),
            ])],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let msgs = json["messages"].as_array().unwrap();
        assert_eq!(msgs[0]["content"], "all good");
        assert_eq!(msgs[1]["content"], "[error] something broke");
    }

    /// An assistant message with no content blocks is not sent at all.
    #[test]
    fn request_skips_empty_assistant_messages() {
        let req = Request {
            model: "m".into(),
            system: None,
            messages: vec![
                Message::user_text("hi"),
                Message::assistant(vec![]),
                Message::user_text("still there?"),
            ],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: None,
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        let msgs = json["messages"].as_array().unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0]["role"], "user");
        assert_eq!(msgs[1]["role"], "user");
    }

    /// A thinking configuration on the request adds nothing to the wire body.
    #[test]
    fn request_thinking_is_ignored_silently() {
        use crate::provider::{ThinkingConfig, ThinkingEffort};
        let req = Request {
            model: "gpt-4o".into(),
            system: None,
            messages: vec![Message::user_text("hi")],
            tools: vec![],
            max_tokens: 10,
            temperature: None,
            thinking: Some(ThinkingConfig::Effort(ThinkingEffort::High)),
        };
        let body = build_request_body(&req);
        let json = serde_json::to_value(&body).unwrap();
        assert!(
            json.get("reasoning").is_none() && json.get("thinking").is_none(),
            "OpenAICompatible must not emit reasoning/thinking; got {json}"
        );
    }

    /// The terminal flush emits tool use, then the stop reason, then `Done`,
    /// and leaves the accumulators empty.
    #[test]
    fn flush_terminal_emits_tool_use_stop_and_done_in_order() {
        use std::collections::{BTreeMap, VecDeque};

        let mut slots = BTreeMap::new();
        slots.insert(
            0,
            ToolSlot {
                id: "call_1".into(),
                name: "bash".into(),
                args_buf: "{\"command\":\"ls\"}".into(),
            },
        );
        let mut pending_stop = Some(StopReason::ToolUse);
        let mut buffer = VecDeque::new();
        flush_terminal(&mut slots, &mut pending_stop, &mut buffer);

        assert!(slots.is_empty());
        assert!(pending_stop.is_none());
        assert!(matches!(
            buffer.pop_front().unwrap().unwrap(),
            StreamEvent::ToolUse { id, name, input }
                if id == "call_1" && name == "bash" && input["command"] == "ls"
        ));
        assert!(matches!(
            buffer.pop_front().unwrap().unwrap(),
            StreamEvent::MessageDelta {
                stop_reason: StopReason::ToolUse
            }
        ));
        assert!(matches!(
            buffer.pop_front().unwrap().unwrap(),
            StreamEvent::Done
        ));
        assert!(buffer.is_empty());
    }
}