// walrus_model/local/provider.rs

//! Model trait implementation for the Local provider.

use super::Local;
use anyhow::Result;
use async_stream::try_stream;
use compact_str::CompactString;
use futures_core::Stream;
use std::collections::HashMap;
use wcore::model::{
    Choice, CompletionMeta, Delta, FunctionCall, Model, Response, Role, StreamChunk, ToolCall,
    Usage,
};

14impl Model for Local {
15    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
16        let mr_request = build_request(request);
17        let resp = self.model.send_chat_request(mr_request).await?;
18        Ok(to_response(resp))
19    }
20
21    fn stream(
22        &self,
23        request: wcore::model::Request,
24    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
25        let model = self.model.clone();
26        try_stream! {
27            let mr_request = build_request(&request);
28            let mut stream = model.stream_chat_request(mr_request).await?;
29            while let Some(resp) = stream.next().await {
30                match resp {
31                    mistralrs::Response::Chunk(chunk) => {
32                        yield to_stream_chunk(chunk);
33                    }
34                    mistralrs::Response::Done(_) => break,
35                    mistralrs::Response::InternalError(e)
36                    | mistralrs::Response::ValidationError(e) => {
37                        Err(anyhow::anyhow!("{e}"))?;
38                    }
39                    mistralrs::Response::ModelError(msg, _) => {
40                        Err(anyhow::anyhow!("model error: {msg}"))?;
41                    }
42                    _ => {}
43                }
44            }
45        }
46    }
47
48    fn context_limit(&self, model: &str) -> usize {
49        self.context_length(model)
50            .unwrap_or_else(|| wcore::model::default_context_limit(model))
51    }
52
53    fn active_model(&self) -> CompactString {
54        CompactString::from("local")
55    }
56}
57
58/// Build a mistralrs `RequestBuilder` from a walrus `Request`.
59fn build_request(request: &wcore::model::Request) -> mistralrs::RequestBuilder {
60    let mut builder = mistralrs::RequestBuilder::new();
61
62    for msg in &request.messages {
63        match msg.role {
64            Role::System => {
65                builder = builder.add_message(mistralrs::TextMessageRole::System, &msg.content);
66            }
67            Role::User => {
68                builder = builder.add_message(mistralrs::TextMessageRole::User, &msg.content);
69            }
70            Role::Assistant => {
71                if msg.tool_calls.is_empty() {
72                    builder =
73                        builder.add_message(mistralrs::TextMessageRole::Assistant, &msg.content);
74                } else {
75                    let tool_calls = msg
76                        .tool_calls
77                        .iter()
78                        .map(|tc| mistralrs::ToolCallResponse {
79                            id: tc.id.to_string(),
80                            tp: mistralrs::ToolCallType::Function,
81                            function: mistralrs::CalledFunction {
82                                name: tc.function.name.to_string(),
83                                arguments: tc.function.arguments.clone(),
84                            },
85                            index: tc.index as usize,
86                        })
87                        .collect();
88                    builder = builder.add_message_with_tool_call(
89                        mistralrs::TextMessageRole::Assistant,
90                        &msg.content,
91                        tool_calls,
92                    );
93                }
94            }
95            Role::Tool => {
96                builder = builder.add_tool_message(&msg.content, &msg.tool_call_id);
97            }
98        }
99    }
100
101    if let Some(tools) = &request.tools {
102        let mr_tools = tools
103            .iter()
104            .map(|t| {
105                let params: HashMap<String, serde_json::Value> =
106                    serde_json::from_value(serde_json::to_value(&t.parameters).unwrap_or_default())
107                        .unwrap_or_default();
108                mistralrs::Tool {
109                    tp: mistralrs::ToolType::Function,
110                    function: mistralrs::Function {
111                        description: Some(t.description.clone()),
112                        name: t.name.to_string(),
113                        parameters: Some(params),
114                    },
115                }
116            })
117            .collect();
118        builder = builder.set_tools(mr_tools);
119    }
120
121    if let Some(tool_choice) = &request.tool_choice {
122        let mr_choice = match tool_choice {
123            wcore::model::ToolChoice::None => mistralrs::ToolChoice::None,
124            wcore::model::ToolChoice::Auto | wcore::model::ToolChoice::Required => {
125                mistralrs::ToolChoice::Auto
126            }
127            wcore::model::ToolChoice::Function(name) => {
128                mistralrs::ToolChoice::Tool(mistralrs::Tool {
129                    tp: mistralrs::ToolType::Function,
130                    function: mistralrs::Function {
131                        description: None,
132                        name: name.to_string(),
133                        parameters: None,
134                    },
135                })
136            }
137        };
138        builder = builder.set_tool_choice(mr_choice);
139    }
140
141    builder
142}
143
144/// Convert a mistralrs `ChatCompletionResponse` to a walrus `Response`.
145fn to_response(resp: mistralrs::ChatCompletionResponse) -> Response {
146    let choices = resp
147        .choices
148        .into_iter()
149        .map(|c| Choice {
150            index: c.index as u32,
151            delta: Delta {
152                role: Some(Role::Assistant),
153                content: c.message.content,
154                reasoning_content: c.message.reasoning_content,
155                tool_calls: c
156                    .message
157                    .tool_calls
158                    .map(|tcs| tcs.into_iter().map(convert_tool_call).collect()),
159            },
160            finish_reason: parse_finish_reason(&c.finish_reason),
161            logprobs: None,
162        })
163        .collect();
164
165    Response {
166        meta: CompletionMeta {
167            id: CompactString::from(&resp.id),
168            object: CompactString::from(&resp.object),
169            created: resp.created,
170            model: CompactString::from(&resp.model),
171            system_fingerprint: Some(CompactString::from(&resp.system_fingerprint)),
172        },
173        choices,
174        usage: convert_usage(&resp.usage),
175    }
176}
177
178/// Convert a mistralrs `ChatCompletionChunkResponse` to a walrus `StreamChunk`.
179fn to_stream_chunk(chunk: mistralrs::ChatCompletionChunkResponse) -> StreamChunk {
180    let choices = chunk
181        .choices
182        .into_iter()
183        .map(|c| Choice {
184            index: c.index as u32,
185            delta: Delta {
186                role: Some(Role::Assistant),
187                content: c.delta.content,
188                reasoning_content: c.delta.reasoning_content,
189                tool_calls: c
190                    .delta
191                    .tool_calls
192                    .map(|tcs| tcs.into_iter().map(convert_tool_call).collect()),
193            },
194            finish_reason: c
195                .finish_reason
196                .as_ref()
197                .and_then(|r| parse_finish_reason(r)),
198            logprobs: None,
199        })
200        .collect();
201
202    StreamChunk {
203        meta: CompletionMeta {
204            id: CompactString::from(&chunk.id),
205            object: CompactString::from(&chunk.object),
206            created: chunk.created as u64,
207            model: CompactString::from(&chunk.model),
208            system_fingerprint: Some(CompactString::from(&chunk.system_fingerprint)),
209        },
210        choices,
211        usage: chunk.usage.as_ref().map(convert_usage),
212    }
213}
214
215/// Convert a mistralrs `ToolCallResponse` to a walrus `ToolCall`.
216fn convert_tool_call(tc: mistralrs::ToolCallResponse) -> ToolCall {
217    ToolCall {
218        id: CompactString::from(&tc.id),
219        index: tc.index as u32,
220        call_type: CompactString::from("function"),
221        function: FunctionCall {
222            name: CompactString::from(&tc.function.name),
223            arguments: tc.function.arguments,
224        },
225    }
226}
227
228/// Convert a mistralrs `Usage` to a walrus `Usage`.
229fn convert_usage(u: &mistralrs::Usage) -> Usage {
230    Usage {
231        prompt_tokens: u.prompt_tokens as u32,
232        completion_tokens: u.completion_tokens as u32,
233        total_tokens: u.total_tokens as u32,
234        prompt_cache_hit_tokens: None,
235        prompt_cache_miss_tokens: None,
236        completion_tokens_details: None,
237    }
238}
239
240/// Parse a finish reason string into a walrus `FinishReason`.
241fn parse_finish_reason(reason: &str) -> Option<wcore::model::FinishReason> {
242    match reason {
243        "stop" => Some(wcore::model::FinishReason::Stop),
244        "length" => Some(wcore::model::FinishReason::Length),
245        "content_filter" => Some(wcore::model::FinishReason::ContentFilter),
246        "tool_calls" => Some(wcore::model::FinishReason::ToolCalls),
247        _ => None,
248    }
249}