1use super::Local;
4use anyhow::Result;
5use async_stream::try_stream;
6use compact_str::CompactString;
7use futures_core::Stream;
8use std::collections::HashMap;
9use wcore::model::{
10 Choice, CompletionMeta, Delta, FunctionCall, Model, Response, Role, StreamChunk, ToolCall,
11 Usage,
12};
13
/// [`Model`] implementation backed by a locally-hosted mistral.rs pipeline.
impl Model for Local {
    /// Sends a single (non-streaming) chat request and converts the
    /// mistral.rs completion into the crate-internal [`Response`] type.
    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
        let mr_request = build_request(request);
        let resp = self.model.send_chat_request(mr_request).await?;
        Ok(to_response(resp))
    }

    /// Streams chat completions as [`StreamChunk`]s.
    ///
    /// The model handle is cloned up front so the returned stream owns it and
    /// can outlive `&self`. Chunk responses are yielded as they arrive; a
    /// `Done` response terminates the stream, and internal/validation/model
    /// errors are surfaced to the consumer as `Err` stream items.
    fn stream(
        &self,
        request: wcore::model::Request,
    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
        let model = self.model.clone();
        try_stream! {
            let mr_request = build_request(&request);
            let mut stream = model.stream_chat_request(mr_request).await?;
            while let Some(resp) = stream.next().await {
                match resp {
                    mistralrs::Response::Chunk(chunk) => {
                        yield to_stream_chunk(chunk);
                    }
                    mistralrs::Response::Done(_) => break,
                    mistralrs::Response::InternalError(e)
                    | mistralrs::Response::ValidationError(e) => {
                        // Inside try_stream!, `Err(..)?` emits the error as a
                        // stream item and ends the stream.
                        Err(anyhow::anyhow!("{e}"))?;
                    }
                    mistralrs::Response::ModelError(msg, _) => {
                        Err(anyhow::anyhow!("model error: {msg}"))?;
                    }
                    // Other response variants (e.g. raw/completion-style) are
                    // not expected on this chat path and are skipped.
                    _ => {}
                }
            }
        }
    }

    /// Returns the context window for `model`, falling back to the crate-wide
    /// default when the local pipeline does not report a length.
    fn context_limit(&self, model: &str) -> usize {
        self.context_length(model)
            .unwrap_or_else(|| wcore::model::default_context_limit(model))
    }

    /// Local models are always addressed under the fixed id "local".
    fn active_model(&self) -> CompactString {
        CompactString::from("local")
    }
}
57
58fn build_request(request: &wcore::model::Request) -> mistralrs::RequestBuilder {
60 let mut builder = mistralrs::RequestBuilder::new();
61
62 for msg in &request.messages {
63 match msg.role {
64 Role::System => {
65 builder = builder.add_message(mistralrs::TextMessageRole::System, &msg.content);
66 }
67 Role::User => {
68 builder = builder.add_message(mistralrs::TextMessageRole::User, &msg.content);
69 }
70 Role::Assistant => {
71 if msg.tool_calls.is_empty() {
72 builder =
73 builder.add_message(mistralrs::TextMessageRole::Assistant, &msg.content);
74 } else {
75 let tool_calls = msg
76 .tool_calls
77 .iter()
78 .map(|tc| mistralrs::ToolCallResponse {
79 id: tc.id.to_string(),
80 tp: mistralrs::ToolCallType::Function,
81 function: mistralrs::CalledFunction {
82 name: tc.function.name.to_string(),
83 arguments: tc.function.arguments.clone(),
84 },
85 index: tc.index as usize,
86 })
87 .collect();
88 builder = builder.add_message_with_tool_call(
89 mistralrs::TextMessageRole::Assistant,
90 &msg.content,
91 tool_calls,
92 );
93 }
94 }
95 Role::Tool => {
96 builder = builder.add_tool_message(&msg.content, &msg.tool_call_id);
97 }
98 }
99 }
100
101 if let Some(tools) = &request.tools {
102 let mr_tools = tools
103 .iter()
104 .map(|t| {
105 let params: HashMap<String, serde_json::Value> =
106 serde_json::from_value(serde_json::to_value(&t.parameters).unwrap_or_default())
107 .unwrap_or_default();
108 mistralrs::Tool {
109 tp: mistralrs::ToolType::Function,
110 function: mistralrs::Function {
111 description: Some(t.description.clone()),
112 name: t.name.to_string(),
113 parameters: Some(params),
114 },
115 }
116 })
117 .collect();
118 builder = builder.set_tools(mr_tools);
119 }
120
121 if let Some(tool_choice) = &request.tool_choice {
122 let mr_choice = match tool_choice {
123 wcore::model::ToolChoice::None => mistralrs::ToolChoice::None,
124 wcore::model::ToolChoice::Auto | wcore::model::ToolChoice::Required => {
125 mistralrs::ToolChoice::Auto
126 }
127 wcore::model::ToolChoice::Function(name) => {
128 mistralrs::ToolChoice::Tool(mistralrs::Tool {
129 tp: mistralrs::ToolType::Function,
130 function: mistralrs::Function {
131 description: None,
132 name: name.to_string(),
133 parameters: None,
134 },
135 })
136 }
137 };
138 builder = builder.set_tool_choice(mr_choice);
139 }
140
141 builder
142}
143
144fn to_response(resp: mistralrs::ChatCompletionResponse) -> Response {
146 let choices = resp
147 .choices
148 .into_iter()
149 .map(|c| Choice {
150 index: c.index as u32,
151 delta: Delta {
152 role: Some(Role::Assistant),
153 content: c.message.content,
154 reasoning_content: c.message.reasoning_content,
155 tool_calls: c
156 .message
157 .tool_calls
158 .map(|tcs| tcs.into_iter().map(convert_tool_call).collect()),
159 },
160 finish_reason: parse_finish_reason(&c.finish_reason),
161 logprobs: None,
162 })
163 .collect();
164
165 Response {
166 meta: CompletionMeta {
167 id: CompactString::from(&resp.id),
168 object: CompactString::from(&resp.object),
169 created: resp.created,
170 model: CompactString::from(&resp.model),
171 system_fingerprint: Some(CompactString::from(&resp.system_fingerprint)),
172 },
173 choices,
174 usage: convert_usage(&resp.usage),
175 }
176}
177
178fn to_stream_chunk(chunk: mistralrs::ChatCompletionChunkResponse) -> StreamChunk {
180 let choices = chunk
181 .choices
182 .into_iter()
183 .map(|c| Choice {
184 index: c.index as u32,
185 delta: Delta {
186 role: Some(Role::Assistant),
187 content: c.delta.content,
188 reasoning_content: c.delta.reasoning_content,
189 tool_calls: c
190 .delta
191 .tool_calls
192 .map(|tcs| tcs.into_iter().map(convert_tool_call).collect()),
193 },
194 finish_reason: c
195 .finish_reason
196 .as_ref()
197 .and_then(|r| parse_finish_reason(r)),
198 logprobs: None,
199 })
200 .collect();
201
202 StreamChunk {
203 meta: CompletionMeta {
204 id: CompactString::from(&chunk.id),
205 object: CompactString::from(&chunk.object),
206 created: chunk.created as u64,
207 model: CompactString::from(&chunk.model),
208 system_fingerprint: Some(CompactString::from(&chunk.system_fingerprint)),
209 },
210 choices,
211 usage: chunk.usage.as_ref().map(convert_usage),
212 }
213}
214
215fn convert_tool_call(tc: mistralrs::ToolCallResponse) -> ToolCall {
217 ToolCall {
218 id: CompactString::from(&tc.id),
219 index: tc.index as u32,
220 call_type: CompactString::from("function"),
221 function: FunctionCall {
222 name: CompactString::from(&tc.function.name),
223 arguments: tc.function.arguments,
224 },
225 }
226}
227
228fn convert_usage(u: &mistralrs::Usage) -> Usage {
230 Usage {
231 prompt_tokens: u.prompt_tokens as u32,
232 completion_tokens: u.completion_tokens as u32,
233 total_tokens: u.total_tokens as u32,
234 prompt_cache_hit_tokens: None,
235 prompt_cache_miss_tokens: None,
236 completion_tokens_details: None,
237 }
238}
239
240fn parse_finish_reason(reason: &str) -> Option<wcore::model::FinishReason> {
242 match reason {
243 "stop" => Some(wcore::model::FinishReason::Stop),
244 "length" => Some(wcore::model::FinishReason::Length),
245 "content_filter" => Some(wcore::model::FinishReason::ContentFilter),
246 "tool_calls" => Some(wcore::model::FinishReason::ToolCalls),
247 _ => None,
248 }
249}