tool_parser/parsers/
llama.rs1use async_trait::async_trait;
2use openai_protocol::common::Tool;
3use serde_json::Value;
4
5use crate::{
6 errors::{ParserError, ParserResult},
7 parsers::helpers,
8 partial_json::PartialJson,
9 traits::ToolParser,
10 types::{FunctionCall, StreamingParseResult, ToolCall},
11};
12
13pub struct LlamaParser {
20 partial_json: PartialJson,
22
23 buffer: String,
25
26 prev_tool_call_arr: Vec<Value>,
28
29 current_tool_id: i32,
31
32 current_tool_name_sent: bool,
34
35 streamed_args_for_tool: Vec<String>,
37
38 bot_token: &'static str,
40 tool_call_separator: &'static str,
41}
42
43impl LlamaParser {
44 pub fn new() -> Self {
46 Self {
47 partial_json: PartialJson::default(),
48 buffer: String::new(),
49 prev_tool_call_arr: Vec::new(),
50 current_tool_id: -1,
51 current_tool_name_sent: false,
52 streamed_args_for_tool: Vec::new(),
53 bot_token: "<|python_tag|>",
54 tool_call_separator: ";",
55 }
56 }
57
58 fn extract_content_after_python_tag(&self, text: &str) -> Option<(String, String)> {
60 const PYTHON_TAG: &str = "<|python_tag|>";
61
62 if let Some(tag_pos) = text.find(PYTHON_TAG) {
63 let normal_text = text[..tag_pos].to_string();
64 let json_content = text[tag_pos + PYTHON_TAG.len()..].to_string();
65 Some((normal_text, json_content))
66 } else {
67 None
68 }
69 }
70
71 fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
73 let name = obj.get("name").and_then(|v| v.as_str());
75
76 if let Some(name) = name {
77 let empty_obj = Value::Object(serde_json::Map::new());
79 let parameters = obj.get("parameters").unwrap_or(&empty_obj);
80
81 let arguments = serde_json::to_string(parameters)
83 .map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
84
85 Ok(Some(ToolCall {
86 function: FunctionCall {
87 name: name.to_string(),
88 arguments,
89 },
90 }))
91 } else {
92 Ok(None)
93 }
94 }
95
96 fn parse_semicolon_separated(&self, content: &str) -> ParserResult<Vec<ToolCall>> {
98 let mut all_tools = Vec::new();
99
100 for part in content.split(';') {
102 let trimmed = part.trim();
103 if trimmed.is_empty() {
104 continue;
105 }
106
107 match serde_json::from_str::<Value>(trimmed) {
109 Ok(value) => {
110 if let Some(tool) = self.parse_single_object(&value)? {
111 all_tools.push(tool);
112 }
113 }
114 Err(e) => {
115 tracing::debug!("Failed to parse tool call: {}", e);
117 }
118 }
119 }
120
121 Ok(all_tools)
122 }
123}
124
125impl Default for LlamaParser {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131#[async_trait]
132impl ToolParser for LlamaParser {
133 async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
134 let (normal_text, json_content) =
136 if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
137 (normal, json)
138 } else if text.trim_start().starts_with('{') {
139 (String::new(), text.to_string())
140 } else {
141 return Ok((text.to_string(), vec![]));
143 };
144
145 let tools = if json_content.contains(';') {
147 self.parse_semicolon_separated(&json_content)?
148 } else {
149 let parsed = serde_json::from_str::<Value>(json_content.trim())
151 .map_err(|e| ParserError::ParsingFailed(e.to_string()))
152 .and_then(|v| {
153 self.parse_single_object(&v)
154 .map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
155 });
156
157 parsed.unwrap_or_else(|e| {
158 tracing::debug!("Failed to parse tool call: {:?}", e);
159 vec![]
160 })
161 };
162
163 if tools.is_empty() {
165 return Ok((text.to_string(), vec![]));
166 }
167
168 Ok((normal_text, tools))
169 }
170
171 async fn parse_incremental(
172 &mut self,
173 chunk: &str,
174 tools: &[Tool],
175 ) -> ParserResult<StreamingParseResult> {
176 self.buffer.push_str(chunk);
178 let current_text = &self.buffer.clone();
179
180 let has_tool_start = self.has_tool_markers(current_text)
182 || (self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator));
183
184 if !has_tool_start {
185 if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
187 let normal_text = self.buffer.clone();
188 self.buffer.clear();
189
190 return Ok(StreamingParseResult {
191 normal_text,
192 calls: vec![],
193 });
194 } else {
195 return Ok(StreamingParseResult::default());
197 }
198 }
199
200 let tool_indices = helpers::get_tool_indices(tools);
202
203 let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
205 pos + self.bot_token.len()
206 } else if self.current_tool_id > 0 && current_text.starts_with(self.tool_call_separator) {
207 self.tool_call_separator.len()
208 } else {
209 0
210 };
211
212 helpers::handle_json_tool_streaming(
213 current_text,
214 start_idx,
215 &mut self.partial_json,
216 &tool_indices,
217 &mut self.buffer,
218 &mut self.current_tool_id,
219 &mut self.current_tool_name_sent,
220 &mut self.streamed_args_for_tool,
221 &mut self.prev_tool_call_arr,
222 )
223 }
224
225 fn has_tool_markers(&self, text: &str) -> bool {
226 text.contains("<|python_tag|>") || text.trim_start().starts_with('{')
228 }
229
230 fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::types::ToolCallItem>> {
231 helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
232 }
233
234 fn reset(&mut self) {
235 helpers::reset_parser_state(
236 &mut self.buffer,
237 &mut self.prev_tool_call_arr,
238 &mut self.current_tool_id,
239 &mut self.current_tool_name_sent,
240 &mut self.streamed_args_for_tool,
241 );
242 }
243}