1use std::{collections::HashSet, time::Duration};
15
16use crate::{
17 config::{ToolMode, ToolsConfig},
18 openai::chat::{
19 ChatCompletionRequest, ChatRequestError, ChatToolChoice, ChatToolDefinition,
20 NormalizedChatMessage,
21 },
22 vllm_tool_parser::{
23 Glm47MoeToolParser, HermesToolParser, Qwen3XmlToolParser, Result as ToolParserResult, Tool,
24 ToolCallDelta, ToolParseResult, ToolParser,
25 },
26};
27use serde_json::{Map, Value};
28use thiserror::Error;
29use tracing::warn;
30
/// Opening marker that wraps every emulated tool call in the model's output.
const TOOL_CALL_START: &str = "<tool_call>";
/// Closing marker for a tool-call block. `LenientToolParser::finish` also pushes it
/// to close a call the model left unterminated.
const TOOL_CALL_END: &str = "</tool_call>";
34
35pub fn generate_tool_call_id() -> String {
37 format!("call_{}", uuid::Uuid::new_v4().simple())
38}
39
/// Maximum number of bytes of a rejected assistant output echoed back in a correction
/// prompt (4 KiB); longer outputs are cut at a char boundary before being quoted.
const CORRECTION_INVALID_OUTPUT_MAX_BYTES: usize = 4 * 1024;
44
/// Per-request state for emulating tool calling through prompting.
///
/// Built by [`ToolEmulationContext::from_request`]. It produces the controller and
/// correction prompts, and it classifies and validates the model's responses.
#[derive(Debug, Clone)]
pub struct ToolEmulationContext {
    /// Copy of the `ToolsConfig` passed to `from_request`.
    config: ToolsConfig,
    /// Tools offered to the model; narrowed to a single tool when `tool_choice`
    /// names a specific function.
    tools: Vec<ChatToolDefinition>,
    /// `tools` serialized as JSON, embedded verbatim in every prompt.
    tool_schemas_json: String,
    /// When true, a response with no tool call is classified as invalid.
    require_tool_call: bool,
    /// Tool-call syntax the model is instructed to use, chosen from the model name.
    prompt_format: ToolPromptFormat,
}
54
/// Tool-call syntax taught to the model and expected back from it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolPromptFormat {
    /// `<tool_call>` blocks containing a `{"name", "arguments"}` JSON object.
    HermesJson,
    /// `<tool_call>` blocks with a tool name followed by `<arg_key>`/`<arg_value>` pairs.
    GlmXml,
    /// Qwen variant: JSON object inside `<tool_call>`, newline right after the marker.
    QwenXml,
}

impl ToolPromptFormat {
    /// Picks the prompt format from a model name, matched case-insensitively.
    ///
    /// Names containing `glm` win over names containing `qwen`. Anything else
    /// falls back to the Hermes JSON format.
    fn for_model(model: &str) -> Self {
        // Checked in order; the first family marker found in the name decides.
        const FAMILIES: [(&str, ToolPromptFormat); 2] = [
            ("glm", ToolPromptFormat::GlmXml),
            ("qwen", ToolPromptFormat::QwenXml),
        ];

        let lowered = model.to_ascii_lowercase();
        FAMILIES
            .iter()
            .find(|(marker, _)| lowered.contains(*marker))
            .map_or(Self::HermesJson, |&(_, format)| format)
    }
}
80
impl ToolEmulationContext {
    /// Builds the emulation context for one chat request.
    ///
    /// Returns `Ok(None)` (no emulation) in these cases:
    /// - tools are disabled in `config`, or `config.mode` is `ToolMode::None`;
    /// - the request's `tool_choice` is `none`;
    /// - the request has no tools and `tool_choice` is `auto`.
    ///
    /// With `tool_choice` naming a function, only that tool is offered and a call is
    /// required. With `required`, all tools are offered and a call is required.
    ///
    /// # Errors
    ///
    /// Returns `ChatRequestError::invalid_field` when:
    /// - `tool_choice` is `required` or a named function but no tools were sent;
    /// - two tools share a name;
    /// - a parameters schema fails `validate_schema_shape`, when
    ///   `config.validate_json_schema` is set;
    /// - the named function is not among the tools;
    /// - the selected tools cannot be serialized to JSON.
    pub fn from_request(
        config: &ToolsConfig,
        request: &ChatCompletionRequest,
    ) -> Result<Option<Self>, ChatRequestError> {
        if !config.enabled || config.mode == ToolMode::None {
            return Ok(None);
        }

        if matches!(request.tool_choice, ChatToolChoice::None) {
            return Ok(None);
        }

        if request.tools.is_empty() {
            // A forced tool choice cannot be honoured without any tool to call.
            if matches!(
                request.tool_choice,
                ChatToolChoice::Required | ChatToolChoice::Function { .. }
            ) {
                return Err(ChatRequestError::invalid_field(
                    "tool_choice",
                    "tool_choice requires at least one function tool",
                ));
            }
            return Ok(None);
        }

        // Tool names must be unique because calls are resolved back to tools by name.
        let mut seen_names = HashSet::new();
        for tool in &request.tools {
            if !seen_names.insert(tool.name()) {
                return Err(ChatRequestError::invalid_field(
                    "tools",
                    format!("duplicate function tool name {:?}", tool.name()),
                ));
            }

            if config.validate_json_schema
                && let Some(schema) = tool.parameters_schema()
            {
                validate_schema_shape(schema).map_err(|message| {
                    ChatRequestError::invalid_field(
                        "tools",
                        format!(
                            "tool {:?} has an unsupported or invalid parameters schema: {message}",
                            tool.name()
                        ),
                    )
                })?;
            }
        }

        let (tools, require_tool_call) = match &request.tool_choice {
            ChatToolChoice::Auto => (request.tools.clone(), false),
            ChatToolChoice::Required => (request.tools.clone(), true),
            ChatToolChoice::Function { name } => {
                let selected = request
                    .tools
                    .iter()
                    .find(|tool| tool.name() == name)
                    .cloned()
                    .ok_or_else(|| {
                        ChatRequestError::invalid_field(
                            "tool_choice",
                            format!("requested function tool {name:?} is not present in tools"),
                        )
                    })?;
                (vec![selected], true)
            }
            ChatToolChoice::None => unreachable!("tool_choice none returned above"),
        };

        let tool_schemas_json = serde_json::to_string(&tools).map_err(|source| {
            ChatRequestError::invalid_field(
                "tools",
                format!("tool schemas could not be serialized for the controller prompt: {source}"),
            )
        })?;

        Ok(Some(Self {
            config: config.clone(),
            tools,
            tool_schemas_json,
            require_tool_call,
            prompt_format: ToolPromptFormat::for_model(&request.model),
        }))
    }

    /// Returns the tools configuration captured when the context was built.
    pub fn config(&self) -> &ToolsConfig {
        &self.config
    }

    /// Returns `config.max_retries`.
    ///
    /// NOTE(review): presumably the number of correction rounds the caller allows —
    /// confirm against the caller.
    pub fn max_retries(&self) -> u32 {
        self.config.max_retries
    }

    /// Returns `config.tool_call_marker_timeout`.
    pub fn marker_timeout(&self) -> Duration {
        self.config.tool_call_marker_timeout
    }

    /// Creates a new parser for the prompt format chosen for this request's model.
    ///
    /// # Errors
    ///
    /// Returns a `ToolCallValidationError` if the underlying parser cannot be created.
    pub fn create_parser(&self) -> Result<Box<dyn ToolParser>, ToolCallValidationError> {
        self.create_parser_for_format(self.prompt_format)
    }

    /// Creates a new parser for an explicit `format`.
    ///
    /// Only the GLM parser is given the tool list. The Hermes parser (wrapped in
    /// `LenientToolParser`) and the Qwen parser are created with an empty list.
    fn create_parser_for_format(
        &self,
        format: ToolPromptFormat,
    ) -> Result<Box<dyn ToolParser>, ToolCallValidationError> {
        let parser = match format {
            ToolPromptFormat::HermesJson => LenientToolParser::create(&[]),
            ToolPromptFormat::GlmXml => Glm47MoeToolParser::create(&self.vllm_tools()),
            ToolPromptFormat::QwenXml => Qwen3XmlToolParser::create(&[]),
        };
        parser.map_err(|error| {
            ToolCallValidationError::new(format!("tool parser could not be created: {error}"))
        })
    }

    /// Converts the offered tools into the parser crate's `Tool` type.
    ///
    /// A tool without a parameters schema gets an empty JSON object; `strict` is
    /// always left unset.
    fn vllm_tools(&self) -> Vec<Tool> {
        self.tools
            .iter()
            .map(|tool| {
                let function = tool.function();
                Tool {
                    name: function.name.clone(),
                    description: function.description.clone(),
                    parameters: function
                        .parameters
                        .as_ref()
                        .map(|schema| Value::Object(schema.as_map().clone()))
                        .unwrap_or_else(|| Value::Object(Map::new())),
                    strict: None,
                }
            })
            .collect()
    }

    /// Builds the instruction message that teaches the model the tool-call syntax.
    ///
    /// The wording depends on whether a tool call is mandatory. When it is optional,
    /// an extra rule tells the model to answer normally without the start marker.
    /// The message is created with role `user`.
    ///
    /// NOTE(review): correction messages use role `system`. The reason for the
    /// difference is not visible here — confirm it is intentional.
    pub fn controller_message(&self) -> NormalizedChatMessage {
        let requirement = if self.require_tool_call {
            "You must call at least one tool. Do not answer the user directly. Output each tool call using this format and nothing else:"
        } else {
            "If tools are required, do not answer the user directly. Output each tool call using this format and nothing else:"
        };
        let optional_rule = if self.require_tool_call {
            String::new()
        } else {
            format!("\n- If no tool is needed, answer normally and do not use {TOOL_CALL_START}.")
        };

        let content = match self.prompt_format {
            ToolPromptFormat::HermesJson => {
                self.hermes_controller_content(requirement, &optional_rule)
            }
            ToolPromptFormat::GlmXml => {
                self.glm_xml_controller_content(requirement, &optional_rule)
            }
            ToolPromptFormat::QwenXml => {
                self.qwen_xml_controller_content(requirement, &optional_rule)
            }
        };

        NormalizedChatMessage::new("user", content)
    }

    /// Builds a `system` message that reports why the previous output was rejected
    /// and restates the required format.
    ///
    /// `invalid_output` is truncated to `CORRECTION_INVALID_OUTPUT_MAX_BYTES`
    /// (at a char boundary) before being quoted back to the model.
    pub fn correction_message(
        &self,
        validation_error: &str,
        invalid_output: &str,
    ) -> NormalizedChatMessage {
        let invalid_output =
            truncate_at_char_boundary(invalid_output, CORRECTION_INVALID_OUTPUT_MAX_BYTES);
        let content = match self.prompt_format {
            ToolPromptFormat::HermesJson => {
                self.hermes_correction_content(validation_error, &invalid_output)
            }
            ToolPromptFormat::GlmXml => {
                self.glm_xml_correction_content(validation_error, &invalid_output)
            }
            ToolPromptFormat::QwenXml => {
                self.qwen_xml_correction_content(validation_error, &invalid_output)
            }
        };
        NormalizedChatMessage::new("system", content)
    }

    /// Controller prompt text for the Hermes JSON format.
    fn hermes_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
        format!(
            "You have access to tools.\n\n{requirement}\n\nRequired tool-call format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid single-call example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nValid multi-call example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n- {{\"tool\":\"TOOL_NAME\",\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not put arguments directly after the tool name.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be valid JSON and must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
            r#"{"name":"TOOL_NAME_1","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
            r#"{"name":"TOOL_NAME_2","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
            self.tool_schemas_json,
        )
    }

    /// Controller prompt text for the Qwen XML-wrapped JSON format.
    fn qwen_xml_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
        format!(
            "You have access to tools.\n\n{requirement}\n\nRequired Qwen XML-wrapped JSON tool-call format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nThere MUST be a newline immediately after {TOOL_CALL_START}. Inside each block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be valid JSON and must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
            self.tool_schemas_json,
        )
    }

    /// Controller prompt text for the GLM `<arg_key>`/`<arg_value>` format.
    fn glm_xml_controller_content(&self, requirement: &str, optional_rule: &str) -> String {
        format!(
            "You have access to tools.\n\n{requirement}\n\nRequired GLM XML tool-call format:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block:\n- Start with the exact tool name as plain text.\n- Then output one <arg_key>/<arg_value> pair for each argument.\n- Put only the raw argument name inside <arg_key>.\n- Put only the raw argument value inside <arg_value>.\n- If an argument value is an object or array, put compact valid JSON inside <arg_value>.\n\nValid single-call example:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nValid multi-call example:\n\n{TOOL_CALL_START}TOOL_NAME_1\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n{TOOL_CALL_START}TOOL_NAME_2\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- {TOOL_CALL_START}{{\"name\":\"TOOL_NAME\",\"arguments\":{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one available tool name.\n- Do not output JSON inside {TOOL_CALL_START} except for object/array values inside <arg_value>.\n- Do not use function-call syntax like TOOL_NAME(...).\n- Do not use the Hermes JSON format with \"name\" and \"arguments\" keys.\n- Argument names and values must satisfy the tool schema.\n- Emit one marker block per tool call.\n- Do not include markdown fences.\n- Do not include explanations.{optional_rule}\n\nAvailable tools:\n{}",
            self.tool_schemas_json,
        )
    }

    /// Correction prompt text for the Hermes JSON format.
    fn hermes_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
        format!(
            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block, output ONLY one valid JSON object with exactly these top-level keys:\n- \"name\": the tool name as a JSON string.\n- \"arguments\": a JSON object containing the tool arguments.\n\nValid example:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one of the available tools.\n- Always put the tool name in the JSON \"name\" field.\n- Always put tool arguments inside the JSON \"arguments\" object.\n- Do not put arguments directly after the tool name.\n- Do not use function-call syntax like TOOL_NAME(...).\n- arguments must be a JSON object.\n- arguments must satisfy the tool schema.\n- Do not include markdown fences.\n- Do not include explanations.\n- Do not answer the user directly.\n\nAvailable tools:\n{}",
            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
            r#"{"name":"TOOL_NAME","arguments":{"ARGUMENT_NAME":"ARGUMENT_VALUE"}}"#,
            self.tool_schemas_json,
        )
    }

    /// Correction prompt text for the Qwen XML-wrapped JSON format.
    fn qwen_xml_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
        format!(
            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact Qwen XML-wrapped JSON format:\n\n{TOOL_CALL_START}\n{}\n{TOOL_CALL_END}\n\nThere MUST be a newline immediately after {TOOL_CALL_START}. Inside each block, output ONLY one valid JSON object with \"name\" and \"arguments\" top-level keys.\n\nAvailable tools:\n{}",
            r#"{"name":"TOOL_NAME","arguments":{...}}"#,
            self.tool_schemas_json,
        )
    }

    /// Correction prompt text for the GLM `<arg_key>`/`<arg_value>` format.
    fn glm_xml_correction_content(&self, validation_error: &str, invalid_output: &str) -> String {
        format!(
            "Your previous response attempted a tool call, but it was invalid.\n\nValidation error:\n{validation_error}\n\nInvalid output:\n{invalid_output}\n\nYou must now return only valid tool calls and nothing else.\n\nUse this exact GLM XML format:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInside each {TOOL_CALL_START} block:\n- Start with the exact tool name as plain text.\n- Then output one <arg_key>/<arg_value> pair for each argument.\n- Put only the raw argument name inside <arg_key>.\n- Put only the raw argument value inside <arg_value>.\n\nValid example:\n\n{TOOL_CALL_START}TOOL_NAME\n<arg_key>ARGUMENT_NAME</arg_key>\n<arg_value>ARGUMENT_VALUE</arg_value>\n{TOOL_CALL_END}\n\nInvalid formats. NEVER use these:\n- {TOOL_CALL_START}TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}){TOOL_CALL_END}\n- {TOOL_CALL_START}TOOL_NAME{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}{TOOL_CALL_END}\n- {TOOL_CALL_START}{{\"name\":\"TOOL_NAME\",\"arguments\":{{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}}}}{TOOL_CALL_END}\n- TOOL_NAME({{\"ARGUMENT_NAME\":\"ARGUMENT_VALUE\"}})\n\nRules:\n- TOOL_NAME must exactly match one of the available tools.\n- Do not output JSON inside {TOOL_CALL_START} except for object/array values inside <arg_value>.\n- Do not use function-call syntax like TOOL_NAME(...).\n- Do not use the Hermes JSON format with \"name\" and \"arguments\" keys.\n- Argument names and values must satisfy the tool schema.\n- Do not include markdown fences.\n- Do not include explanations.\n- Do not answer the user directly.\n\nAvailable tools:\n{}",
            self.tool_schemas_json,
        )
    }

    /// Classifies a complete assistant response as normal text, valid tool calls,
    /// or an invalid tool call.
    ///
    /// A response is invalid when:
    /// - it exceeds `config.tool_call_max_bytes`;
    /// - parsing or validation fails;
    /// - it contains no tool call although one is required.
    ///
    /// In every invalid case the full, untruncated output is returned alongside the
    /// error.
    pub fn classify_assistant_output(&self, output: &str) -> ToolOutputClassification {
        // Size guard runs before parsing so oversized output is never parsed.
        if output.len() > self.config.tool_call_max_bytes {
            return ToolOutputClassification::InvalidToolCall {
                error: ToolCallValidationError::new(format!(
                    "assistant output exceeded the tool call max size of {} bytes",
                    self.config.tool_call_max_bytes
                )),
                invalid_output: output.to_owned(),
            };
        }

        let result = self.parse_tool_calls(output);

        match result {
            Ok(tool_calls) if tool_calls.is_empty() => {
                if self.require_tool_call {
                    ToolOutputClassification::InvalidToolCall {
                        error: ToolCallValidationError::new(
                            "expected the assistant response to include a tool call",
                        ),
                        invalid_output: output.to_owned(),
                    }
                } else {
                    ToolOutputClassification::NormalText
                }
            }
            Ok(tool_calls) => ToolOutputClassification::ToolCalls(tool_calls),
            Err(error) => ToolOutputClassification::InvalidToolCall {
                error,
                invalid_output: output.to_owned(),
            },
        }
    }

    /// Parses and validates tool calls using the request's prompt format.
    ///
    /// For the GLM and Qwen formats, a Hermes JSON parse is tried as a fallback in
    /// two cases:
    /// - the native parse fails;
    /// - the native parse finds no calls although the output contains the start marker.
    ///
    /// The fallback result is used only if it yields at least one valid call;
    /// otherwise the native result (empty list or error) is returned unchanged.
    fn parse_tool_calls(
        &self,
        output: &str,
    ) -> Result<Vec<ValidatedToolCall>, ToolCallValidationError> {
        let result = self.parse_tool_calls_with_format(self.prompt_format, output);
        if self.prompt_format == ToolPromptFormat::HermesJson {
            return result;
        }

        match result {
            Ok(tool_calls) if tool_calls.is_empty() && output.contains(TOOL_CALL_START) => {
                if let Some(fallback_calls) = self.hermes_fallback_tool_calls(output) {
                    Ok(fallback_calls)
                } else {
                    Ok(tool_calls)
                }
            }
            Err(error) => {
                if let Some(fallback_calls) = self.hermes_fallback_tool_calls(output) {
                    Ok(fallback_calls)
                } else {
                    Err(error)
                }
            }
            result => result,
        }
    }

    /// Re-parses `output` as Hermes JSON; returns the calls only if the parse and
    /// validation succeed and produce at least one call.
    fn hermes_fallback_tool_calls(&self, output: &str) -> Option<Vec<ValidatedToolCall>> {
        match self.parse_tool_calls_with_format(ToolPromptFormat::HermesJson, output) {
            Ok(tool_calls) if !tool_calls.is_empty() => Some(tool_calls),
            _ => None,
        }
    }

    /// Parses the whole of `output` with a fresh parser for `format`, then validates
    /// every call. The first validation failure aborts the whole parse.
    fn parse_tool_calls_with_format(
        &self,
        format: ToolPromptFormat,
        output: &str,
    ) -> Result<Vec<ValidatedToolCall>, ToolCallValidationError> {
        self.create_parser_for_format(format)
            .and_then(|mut parser| {
                parser.parse_complete(output).map_err(|error| {
                    ToolCallValidationError::new(format!("tool call parsing failed: {error}"))
                })
            })
            .and_then(|result| {
                result
                    .calls
                    .iter()
                    .map(|call| self.validate_tool_call(call))
                    .collect::<Result<Vec<_>, _>>()
            })
    }

    /// Validates one parsed call against the offered tools.
    ///
    /// Checks, in order:
    /// 1. the name is non-blank and exactly matches a tool;
    /// 2. the arguments parse as a JSON object;
    /// 3. the arguments satisfy the tool schema, when `config.validate_json_schema` is set.
    ///
    /// On success the arguments are re-serialized, and the call gets a freshly
    /// generated id.
    fn validate_tool_call(
        &self,
        call: &ToolCallDelta,
    ) -> Result<ValidatedToolCall, ToolCallValidationError> {
        let name = call.name.as_deref().unwrap_or_default();
        if name.trim().is_empty() {
            return Err(ToolCallValidationError::new(
                "tool call name must not be empty",
            ));
        }
        let tool = self
            .tools
            .iter()
            .find(|tool| tool.name() == name)
            .ok_or_else(|| ToolCallValidationError::new(format!("unknown tool name {name:?}")))?;

        let arguments: Value = serde_json::from_str(&call.arguments).map_err(|source| {
            ToolCallValidationError::new(format!("tool call arguments JSON is invalid: {source}"))
        })?;
        if !arguments.is_object() {
            return Err(ToolCallValidationError::new(
                "tool call arguments must be a JSON object",
            ));
        }

        if self.config.validate_json_schema
            && let Some(schema) = tool.parameters_schema()
        {
            validate_value_against_schema(&arguments, schema, "arguments").map_err(|message| {
                ToolCallValidationError::new(format!(
                    "tool call arguments do not satisfy schema: {message}"
                ))
            })?;
        }

        // Re-serialize so the emitted arguments string is normalized JSON rather than
        // the model's raw text.
        let arguments_json = serde_json::to_string(&arguments).map_err(|source| {
            ToolCallValidationError::new(format!(
                "tool call arguments could not be serialized as JSON: {source}"
            ))
        })?;

        Ok(ValidatedToolCall {
            id: generate_tool_call_id(),
            name: name.to_owned(),
            arguments_json,
        })
    }
}
478
/// Limits `text` to at most `max_bytes` bytes without splitting a UTF-8 character.
///
/// Text that already fits is returned borrowed. Otherwise the longest prefix ending
/// on a char boundary within `max_bytes` is kept and `" [output truncated]"` is
/// appended, so the result can be longer than `max_bytes`.
fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    const MARKER: &str = " [output truncated]";

    if max_bytes >= text.len() {
        return Cow::Borrowed(text);
    }

    // Byte 0 is always a char boundary, so the backwards search always succeeds.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&index| text.is_char_boundary(index))
        .unwrap_or(0);

    let mut truncated = String::with_capacity(cut + MARKER.len());
    truncated.push_str(&text[..cut]);
    truncated.push_str(MARKER);
    Cow::Owned(truncated)
}
490
/// Hermes tool parser wrapper that tolerates trailing output after a complete call.
///
/// An inner parse error is swallowed once the current call's arguments form a
/// balanced JSON value; all later input is then ignored. At `finish`, a missing
/// closing marker is supplied once before giving up.
struct LenientToolParser {
    /// The Hermes parser that does the actual parsing.
    inner: Box<dyn ToolParser>,
    /// Tracks whether the streamed arguments of the current call are complete.
    args_scanner: ArgsCompletenessScanner,
    /// Set after trailing output was ignored; further pushes and `finish` return
    /// empty results.
    drained: bool,
}
509
510impl ToolParser for LenientToolParser {
511 fn create(tools: &[Tool]) -> ToolParserResult<Box<dyn ToolParser>> {
513 Ok(Box::new(Self {
514 inner: HermesToolParser::create(tools)?,
515 args_scanner: ArgsCompletenessScanner::default(),
516 drained: false,
517 }))
518 }
519
520 fn parse_into(&mut self, chunk: &str, output: &mut ToolParseResult) -> ToolParserResult<()> {
522 output.append(self.push(chunk)?);
523 Ok(())
524 }
525
526 fn push(&mut self, chunk: &str) -> ToolParserResult<ToolParseResult> {
528 let mut merged = ToolParseResult::default();
529 if self.drained {
530 return Ok(merged);
531 }
532 for piece in split_before_tag_starts(chunk) {
533 match self.inner.push(piece) {
534 Ok(result) => {
535 self.args_scanner.track(&result);
536 merged.normal_text.push_str(&result.normal_text);
537 merged.calls.extend(result.calls);
538 }
539 Err(error) => {
540 if !self.args_scanner.complete() {
541 return Err(error);
542 }
543 warn!(%error, "ignoring trailing output after a complete tool call");
547 self.drained = true;
548 break;
549 }
550 }
551 }
552 Ok(merged)
553 }
554
555 fn finish(&mut self) -> ToolParserResult<ToolParseResult> {
557 if self.drained {
558 return Ok(ToolParseResult::default());
559 }
560 let error = match self.inner.finish() {
561 Ok(result) => return Ok(result),
562 Err(error) => error,
563 };
564 let Ok(mut recovered) = self.inner.push(TOOL_CALL_END) else {
567 return Err(error);
568 };
569 let Ok(finished) = self.inner.finish() else {
570 return Err(error);
571 };
572 recovered.normal_text.push_str(&finished.normal_text);
573 recovered.calls.extend(finished.calls);
574 Ok(recovered)
575 }
576
577 fn reset(&mut self) -> String {
579 self.args_scanner = ArgsCompletenessScanner::default();
580 self.drained = false;
581 self.inner.reset()
582 }
583}
584
/// Splits `text` immediately before every `<`, dropping empty pieces.
///
/// Every piece except possibly the first begins with `<`. Concatenating the pieces
/// reproduces `text` exactly.
fn split_before_tag_starts(text: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut remaining = text;
    while let Some(first) = remaining.chars().next() {
        // The next cut is the next '<' strictly after the piece's first character.
        let skip = first.len_utf8();
        let cut = remaining[skip..]
            .find('<')
            .map_or(remaining.len(), |offset| offset + skip);
        let (piece, tail) = remaining.split_at(cut);
        pieces.push(piece);
        remaining = tail;
    }
    pieces
}
601
/// Incremental scanner that decides whether streamed tool-call arguments form a
/// complete JSON value: brackets balanced and no string left open.
#[derive(Debug, Default)]
struct ArgsCompletenessScanner {
    /// Set once any non-whitespace argument text (other than a closing bracket) is seen.
    started: bool,
    /// Current `{`/`[` nesting depth outside strings (never goes below zero).
    depth: u32,
    /// True while inside a JSON string literal.
    in_string: bool,
    /// True when the previous character inside a string was a backslash.
    escaped: bool,
}
612
613impl ArgsCompletenessScanner {
614 fn track(&mut self, result: &ToolParseResult) {
616 for call in &result.calls {
617 if call.name.is_some() {
618 *self = Self::default();
619 }
620 self.feed(&call.arguments);
621 }
622 }
623
624 fn feed(&mut self, fragment: &str) {
626 for ch in fragment.chars() {
627 if self.escaped {
628 self.escaped = false;
629 continue;
630 }
631 if self.in_string {
632 match ch {
633 '\\' => self.escaped = true,
634 '"' => self.in_string = false,
635 _ => {}
636 }
637 continue;
638 }
639 match ch {
640 '"' => {
641 self.in_string = true;
642 self.started = true;
643 }
644 '{' | '[' => {
645 self.depth += 1;
646 self.started = true;
647 }
648 '}' | ']' => self.depth = self.depth.saturating_sub(1),
649 ch if !ch.is_whitespace() => self.started = true,
650 _ => {}
651 }
652 }
653 }
654
655 fn complete(&self) -> bool {
657 self.started && self.depth == 0 && !self.in_string
658 }
659}
660
/// Result of inspecting a complete assistant response for emulated tool calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolOutputClassification {
    /// No tool call was found and none was required; the response is ordinary text.
    NormalText,
    /// One or more tool calls that parsed and passed validation.
    ToolCalls(Vec<ValidatedToolCall>),
    /// The response was rejected.
    InvalidToolCall {
        /// Why the response was rejected.
        error: ToolCallValidationError,
        /// The full assistant output that was rejected (not truncated).
        invalid_output: String,
    },
}
671
/// A tool call that named an offered tool and whose arguments passed validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedToolCall {
    /// Generated call id (`call_` followed by a UUID); not taken from model output.
    pub id: String,
    /// Name of the requested tool; equals the name of one offered tool.
    pub name: String,
    /// Arguments re-serialized as JSON object text.
    pub arguments_json: String,
}
679
680impl ValidatedToolCall {
681 pub fn to_openai_value(&self) -> Value {
683 serde_json::json!({
684 "id": self.id,
685 "type": "function",
686 "function": {
687 "name": self.name,
688 "arguments": self.arguments_json,
689 },
690 })
691 }
692}
693
/// Reason a model response was rejected as a tool call.
///
/// Its `Display` output is exactly the stored message.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("{message}")]
pub struct ToolCallValidationError {
    /// Human-readable description of the failure.
    message: String,
}
700
701impl ToolCallValidationError {
702 fn new(message: impl Into<String>) -> Self {
704 Self {
705 message: message.into(),
706 }
707 }
708
709 pub fn message(&self) -> &str {
711 &self.message
712 }
713}
714
715fn validate_schema_shape(schema: &Map<String, Value>) -> Result<(), String> {
717 validate_schema_object_shape(schema, "schema")
718}
719
720fn validate_schema_object_shape(object: &Map<String, Value>, path: &str) -> Result<(), String> {
722 if let Some(kind) = object.get("type") {
723 validate_schema_type_shape(kind, &format!("{path}.type"))?;
724 }
725 if let Some(required) = object.get("required") {
726 let required = required
727 .as_array()
728 .ok_or_else(|| format!("{path}.required must be an array"))?;
729 if required.iter().any(|value| !value.is_string()) {
730 return Err(format!("{path}.required must contain only strings"));
731 }
732 }
733 if let Some(properties) = object.get("properties") {
734 let properties = properties
735 .as_object()
736 .ok_or_else(|| format!("{path}.properties must be an object"))?;
737 for (name, schema) in properties {
738 let schema = schema
739 .as_object()
740 .ok_or_else(|| format!("{path}.properties.{name} must be an object"))?;
741 validate_schema_object_shape(schema, &format!("{path}.properties.{name}"))?;
742 }
743 }
744 if let Some(items) = object.get("items") {
745 let items = items
746 .as_object()
747 .ok_or_else(|| format!("{path}.items must be an object"))?;
748 validate_schema_object_shape(items, &format!("{path}.items"))?;
749 }
750 if let Some(additional) = object.get("additionalProperties") {
751 match additional {
752 Value::Bool(_) => {}
753 Value::Object(additional) => {
754 validate_schema_object_shape(additional, &format!("{path}.additionalProperties"))?
755 }
756 _ => {
757 return Err(format!(
758 "{path}.additionalProperties must be a boolean or object"
759 ));
760 }
761 }
762 }
763 if let Some(enum_values) = object.get("enum")
764 && !enum_values.is_array()
765 {
766 return Err(format!("{path}.enum must be an array"));
767 }
768 Ok(())
769}
770
771fn validate_schema_type_shape(value: &Value, path: &str) -> Result<(), String> {
773 match value {
774 Value::String(kind) => validate_schema_type_name(kind, path),
775 Value::Array(kinds) => {
776 if kinds.is_empty() {
777 return Err(format!("{path} must not be an empty array"));
778 }
779 for kind in kinds {
780 let kind = kind
781 .as_str()
782 .ok_or_else(|| format!("{path} array must contain only strings"))?;
783 validate_schema_type_name(kind, path)?;
784 }
785 Ok(())
786 }
787 _ => Err(format!("{path} must be a string or array of strings")),
788 }
789}
790
/// Accepts only the seven JSON Schema primitive type names.
fn validate_schema_type_name(kind: &str, path: &str) -> Result<(), String> {
    const SUPPORTED: [&str; 7] = [
        "object", "array", "string", "integer", "number", "boolean", "null",
    ];
    if SUPPORTED.contains(&kind) {
        Ok(())
    } else {
        Err(format!(
            "{path} contains unsupported JSON schema type {kind:?}"
        ))
    }
}
800
801fn validate_value_against_schema(
803 value: &Value,
804 schema: &Map<String, Value>,
805 path: &str,
806) -> Result<(), String> {
807 if let Some(enum_values) = schema.get("enum").and_then(Value::as_array)
808 && !enum_values.iter().any(|enum_value| enum_value == value)
809 {
810 return Err(format!("{path} is not one of the allowed enum values"));
811 }
812
813 if let Some(kind) = schema.get("type")
814 && !schema_type_matches(value, kind)
815 {
816 return Err(format!(
817 "{path} expected type {}, got {}",
818 schema_type_description(kind),
819 value_kind(value)
820 ));
821 }
822
823 if schema_implies_object(schema) {
824 let object = value
825 .as_object()
826 .ok_or_else(|| format!("{path} expected object, got {}", value_kind(value)))?;
827 if let Some(required) = schema.get("required").and_then(Value::as_array) {
828 for field in required.iter().filter_map(Value::as_str) {
829 if !object.contains_key(field) {
830 return Err(format!("{path}.{field} is required"));
831 }
832 }
833 }
834 let properties = schema.get("properties").and_then(Value::as_object);
835 if let Some(properties) = properties {
836 for (field, property_schema) in properties {
837 if let Some(property_value) = object.get(field) {
838 let property_path = format!("{path}.{field}");
839 let property_schema = schema_value_as_object(property_schema, &property_path)?;
840 validate_value_against_schema(property_value, property_schema, &property_path)?;
841 }
842 }
843 }
844 if let Some(additional) = schema.get("additionalProperties") {
845 match additional {
846 Value::Bool(false) => {
847 for field in object.keys() {
848 if properties.is_none_or(|properties| !properties.contains_key(field)) {
849 return Err(format!("{path}.{field} is not allowed by schema"));
850 }
851 }
852 }
853 Value::Object(additional_schema) => {
854 for (field, additional_value) in object {
855 if properties.is_none_or(|properties| !properties.contains_key(field)) {
856 validate_value_against_schema(
857 additional_value,
858 additional_schema,
859 &format!("{path}.{field}"),
860 )?;
861 }
862 }
863 }
864 _ => {}
865 }
866 }
867 }
868
869 if schema_implies_array(schema) {
870 let array = value
871 .as_array()
872 .ok_or_else(|| format!("{path} expected array, got {}", value_kind(value)))?;
873 if let Some(items_schema) = schema.get("items") {
874 for (index, item) in array.iter().enumerate() {
875 let item_path = format!("{path}[{index}]");
876 let items_schema = schema_value_as_object(items_schema, &item_path)?;
877 validate_value_against_schema(item, items_schema, &item_path)?;
878 }
879 }
880 }
881
882 Ok(())
883}
884
885fn schema_value_as_object<'a>(
887 schema: &'a Value,
888 path: &str,
889) -> Result<&'a Map<String, Value>, String> {
890 schema
891 .as_object()
892 .ok_or_else(|| format!("{path} schema must be an object"))
893}
894
895fn schema_implies_object(schema: &Map<String, Value>) -> bool {
897 schema
898 .get("type")
899 .is_some_and(|kind| schema_type_includes(kind, "object"))
900 || schema.contains_key("properties")
901 || schema.contains_key("required")
902 || schema.contains_key("additionalProperties")
903}
904
905fn schema_implies_array(schema: &Map<String, Value>) -> bool {
907 schema
908 .get("type")
909 .is_some_and(|kind| schema_type_includes(kind, "array"))
910 || schema.contains_key("items")
911}
912
913fn schema_type_matches(value: &Value, kind: &Value) -> bool {
915 match kind {
916 Value::String(kind) => value_matches_schema_type(value, kind),
917 Value::Array(kinds) => kinds
918 .iter()
919 .filter_map(Value::as_str)
920 .any(|kind| value_matches_schema_type(value, kind)),
921 _ => true,
922 }
923}
924
925fn schema_type_includes(kind: &Value, expected: &str) -> bool {
927 match kind {
928 Value::String(kind) => kind == expected,
929 Value::Array(kinds) => kinds
930 .iter()
931 .filter_map(Value::as_str)
932 .any(|kind| kind == expected),
933 _ => false,
934 }
935}
936
937fn value_matches_schema_type(value: &Value, kind: &str) -> bool {
939 match kind {
940 "object" => value.is_object(),
941 "array" => value.is_array(),
942 "string" => value.is_string(),
943 "integer" => value.as_i64().is_some() || value.as_u64().is_some(),
944 "number" => value.is_number(),
945 "boolean" => value.is_boolean(),
946 "null" => value.is_null(),
947 _ => true,
948 }
949}
950
951fn schema_type_description(kind: &Value) -> String {
953 match kind {
954 Value::String(kind) => kind.clone(),
955 Value::Array(kinds) => kinds
956 .iter()
957 .filter_map(Value::as_str)
958 .collect::<Vec<_>>()
959 .join(" or "),
960 _ => "unknown".to_owned(),
961 }
962}
963
964fn value_kind(value: &Value) -> &'static str {
966 match value {
967 Value::Null => "null",
968 Value::Bool(_) => "boolean",
969 Value::Number(_) => "number",
970 Value::String(_) => "string",
971 Value::Array(_) => "array",
972 Value::Object(_) => "object",
973 }
974}
975
#[cfg(test)]
mod tests {
    // Unit tests for tool-call emulation. They cover:
    // - prompt-format selection by model name;
    // - classification of assistant output (valid calls, invalid calls, plain text);
    // - argument schema validation and output size limits;
    // - controller and correction prompt construction.
    use serde_json::json;

    use super::*;
    use crate::config::ToolsConfig;

    /// Builds a request for the `e2ee-test` model with a single `search_web` tool whose
    /// `parameters` schema is `arguments_schema`.
    ///
    /// The model name contains neither "glm" nor "qwen", so `ToolPromptFormat::for_model`
    /// selects the Hermes JSON format.
    fn request_with_tool(arguments_schema: Value) -> ChatCompletionRequest {
        request_with_tool_for_model("e2ee-test", arguments_schema)
    }

    /// Builds a request for `model` with one user message and a single `search_web`
    /// function tool whose `parameters` schema is `arguments_schema`.
    fn request_with_tool_for_model(model: &str, arguments_schema: Value) -> ChatCompletionRequest {
        ChatCompletionRequest::parse(&json!({
            "model": model,
            "messages": [{"role":"user", "content":"hi"}],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "search_web",
                    "description": "Search the web",
                    "parameters": arguments_schema
                }
            }]
        }))
        .expect("request should parse")
    }

    /// Builds a tool-emulation context for `request` using `ToolsConfig::default()`.
    ///
    /// Panics if the request is rejected or if tool emulation does not activate.
    fn context_for_request(request: &ChatCompletionRequest) -> ToolEmulationContext {
        // NOTE(review): relies on the default config having tools enabled with a mode other
        // than `ToolMode::None`; otherwise `from_request` returns `Ok(None)`.
        ToolEmulationContext::from_request(&ToolsConfig::default(), request)
            .expect("tool context should build")
            .expect("tools should activate")
    }

    #[test]
    fn classifies_valid_hermes_tool_call() {
        // A strict schema: `query` is required and no other properties are allowed.
        let request = request_with_tool(json!({
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
            "additionalProperties": false
        }));
        let context = context_for_request(&request);

        // Surrounding newlines around the marker block must not prevent recognition.
        let classification = context.classify_assistant_output(
            "\n<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"Venice\"}}\n</tool_call>\n",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        // IDs come from `generate_tool_call_id`, which prefixes "call_".
        assert!(tool_calls[0].id.starts_with("call_"));
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
    }

    #[test]
    fn classifies_glm_xml_tool_call_for_glm_models() {
        // "glm" in the model name selects the GLM XML prompt format.
        let request = request_with_tool_for_model(
            "e2ee-glm-5-1",
            json!({
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
                "additionalProperties": false
            }),
        );
        let context = context_for_request(&request);

        // GLM style: the tool name follows the opening marker, and arguments are
        // <arg_key>/<arg_value> pairs rather than JSON.
        let classification = context.classify_assistant_output(
            "<tool_call>search_web\n<arg_key>query</arg_key><arg_value>Venice</arg_value></tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid GLM XML tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        // The XML arguments are normalised into a JSON object string.
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
        // The controller prompt should teach the GLM argument syntax.
        assert!(context.controller_message().content.contains("<arg_key>"));
    }

    #[test]
    fn classifies_qwen_xml_tool_call_for_qwen_models() {
        // "qwen" in the model name selects the Qwen prompt format.
        let request = request_with_tool_for_model(
            "e2ee-qwen3-30b-a3b-p",
            json!({
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
                "additionalProperties": false
            }),
        );
        let context = context_for_request(&request);

        // Qwen style: a JSON payload wrapped in <tool_call> markers.
        let classification = context.classify_assistant_output(
            "<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"Venice\"}}\n</tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected valid Qwen XML-wrapped JSON tool call");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"Venice\"}");
        assert!(
            context
                .controller_message()
                .content
                .contains("Qwen XML-wrapped JSON")
        );
    }

    #[test]
    fn rejects_invalid_json_unknown_tool_and_schema_mismatch() {
        let request = request_with_tool(json!({
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
            "additionalProperties": false
        }));
        let context = context_for_request(&request);

        // Case 1: trailing comma inside `arguments` makes the payload invalid JSON.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"x\",}}</tool_call>",
            )
        else {
            panic!("expected invalid JSON to be rejected");
        };
        assert!(error.message().contains("JSON is invalid"));

        // Case 2: a tool name that was not offered in the request.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"unknown\",\"arguments\":{\"query\":\"x\"}}</tool_call>",
            )
        else {
            panic!("expected unknown tool to be rejected");
        };
        assert!(error.message().contains("unknown tool name"));

        // Case 3: well-formed JSON that omits the required `query` property.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"q\":\"x\"}}</tool_call>",
            )
        else {
            panic!("expected schema mismatch to be rejected");
        };
        assert!(error.message().contains("arguments.query is required"));
    }

    #[test]
    fn recovers_tool_call_with_truncated_closing_marker() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // Complete JSON payload, but the closing </tool_call> marker is missing.
        let classification = context.classify_assistant_output(
            "<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}\n",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected truncated closing marker to be recovered, got {classification:?}");
        };
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].name, "search_web");
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
    }

    #[test]
    fn ignores_trailing_garbage_after_complete_tool_call() {
        // GLM model, but the payloads below are JSON-style. The assertions require that
        // they still classify as tool calls.
        let request = request_with_tool_for_model("e2ee-glm-5-1", json!({"type": "object"}));
        let context = context_for_request(&request);

        // A stray </arg_value> after a complete JSON payload, with and without the
        // closing </tool_call> marker.
        for output in [
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</arg_value>",
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</arg_value></tool_call>",
        ] {
            let classification = context.classify_assistant_output(output);
            let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
                panic!(
                    "expected trailing garbage to be ignored for {output:?}, got {classification:?}"
                );
            };
            assert_eq!(tool_calls.len(), 1);
            assert_eq!(tool_calls[0].name, "search_web");
            assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
        }
    }

    #[test]
    fn classifies_output_truncated_mid_json_as_invalid_tool_call() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // Unlike the truncated-marker case, the JSON itself is cut off, so there is
        // nothing to recover.
        let classification =
            context.classify_assistant_output("<tool_call>{\"name\":\"search_web\",\"argu");

        let ToolOutputClassification::InvalidToolCall { error, .. } = classification else {
            panic!("expected mid-JSON truncation to be invalid, got {classification:?}");
        };
        assert!(error.message().contains("tool call parsing failed"));
    }

    #[test]
    fn classifies_plain_text_and_enforces_required_tool_call() {
        // Without an explicit `tool_choice`, plain text is an acceptable answer.
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);
        assert_eq!(
            context.classify_assistant_output("Hello, world!"),
            ToolOutputClassification::NormalText
        );

        // With `tool_choice: "required"`, the same plain text must be rejected.
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": "required",
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        let ToolOutputClassification::InvalidToolCall { error, .. } =
            context.classify_assistant_output("Hello, world!")
        else {
            panic!("expected missing required tool call to be invalid");
        };
        assert!(error.message().contains("expected the assistant response"));
    }

    #[test]
    fn classifies_mixed_text_and_tool_call_as_tool_calls() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // Leading prose before the marker block does not stop tool-call classification.
        let classification = context.classify_assistant_output(
            "Let me check.\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</tool_call>",
        );

        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected mixed output to classify as tool calls");
        };
        assert_eq!(tool_calls.len(), 1);
    }

    #[test]
    fn classifies_multiple_tool_calls_regardless_of_parallel_tool_calls() {
        // `parallel_tool_calls: false` is set, yet the assertions below expect both calls
        // to be returned.
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "parallel_tool_calls": false,
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        let classification = context.classify_assistant_output(
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"a\"}}</tool_call>\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"b\"}}</tool_call>",
        );
        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("expected two valid tool calls");
        };
        assert_eq!(tool_calls.len(), 2);
        // Order is preserved, and each call gets its own ID.
        assert_eq!(tool_calls[0].arguments_json, "{\"query\":\"a\"}");
        assert_eq!(tool_calls[1].arguments_json, "{\"query\":\"b\"}");
        assert_ne!(tool_calls[0].id, tool_calls[1].id);
    }

    #[test]
    fn rejects_oversized_assistant_output() {
        let request = request_with_tool(json!({"type": "object"}));
        // Shrink the limit so a 33-byte output is one byte over.
        let config = ToolsConfig {
            tool_call_max_bytes: 32,
            ..ToolsConfig::default()
        };
        let context = ToolEmulationContext::from_request(&config, &request)
            .expect("tool context should build")
            .expect("tools should activate");

        let ToolOutputClassification::InvalidToolCall { error, .. } =
            context.classify_assistant_output(&"x".repeat(33))
        else {
            panic!("expected oversized output to be invalid");
        };
        assert!(error.message().contains("max size of 32 bytes"));
    }

    #[test]
    fn can_disable_schema_validation_explicitly() {
        // The schema requires `query`, but the tool call below sends empty arguments.
        let request = request_with_tool(json!({
            "type": "object",
            "required": ["query"]
        }));
        let config = ToolsConfig {
            validate_json_schema: false,
            ..ToolsConfig::default()
        };
        let context = ToolEmulationContext::from_request(&config, &request)
            .expect("tool context should build")
            .expect("tools should activate");

        let classification = context.classify_assistant_output(
            "<tool_call>{\"name\":\"search_web\",\"arguments\":{}}</tool_call>",
        );
        let ToolOutputClassification::ToolCalls(tool_calls) = classification else {
            panic!("schema mismatch should be allowed when validation is disabled");
        };
        assert_eq!(tool_calls[0].arguments_json, "{}");
    }

    #[test]
    fn rejects_non_object_arguments() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // Through the full classifier, array arguments fail at the parsing stage.
        let ToolOutputClassification::InvalidToolCall { error, .. } = context
            .classify_assistant_output(
                "<tool_call>{\"name\":\"search_web\",\"arguments\":[]}</tool_call>",
            )
        else {
            panic!("expected non-object arguments to be rejected");
        };
        assert!(error.message().contains("tool call parsing failed"));

        // Calling the validator directly with array arguments yields the
        // object-specific error.
        let error = context
            .validate_tool_call(&ToolCallDelta {
                tool_index: 0,
                name: Some("search_web".to_owned()),
                arguments: "[]".to_owned(),
            })
            .unwrap_err();
        assert!(error.message().contains("arguments must be a JSON object"));
    }

    #[test]
    fn builds_controller_and_retry_prompts() {
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": "required",
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        // Controller prompt for `tool_choice: "required"`.
        let controller = context.controller_message();
        assert_eq!(controller.role, "user");
        assert!(
            controller
                .content
                .contains("You must call at least one tool")
        );
        assert!(
            controller
                .content
                .contains("Emit one marker block per tool call")
        );
        assert!(controller.content.contains("<tool_call>"));
        assert!(controller.content.contains("search_web"));

        // The correction prompt embeds both the validation error and the rejected output.
        let correction = context.correction_message("bad name", "<tool_call>{}</tool_call>");
        assert_eq!(correction.role, "system");
        assert!(correction.content.contains("Validation error:\nbad name"));
        assert!(
            correction
                .content
                .contains("Invalid output:\n<tool_call>{}</tool_call>")
        );
        assert!(
            correction
                .content
                .contains("You must now return only valid tool calls")
        );

        // Without `tool_choice: "required"`, the controller allows a normal answer.
        let optional_request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tools": [{"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}}]
        }))
        .expect("request should parse");
        let optional = context_for_request(&optional_request);
        assert!(
            optional
                .controller_message()
                .content
                .contains("If no tool is needed, answer normally")
        );
    }

    #[test]
    fn correction_prompt_truncates_oversized_invalid_output() {
        let request = request_with_tool(json!({"type": "object"}));
        let context = context_for_request(&request);

        // One byte over CORRECTION_INVALID_OUTPUT_MAX_BYTES, so the output must be truncated.
        let oversized = "x".repeat(CORRECTION_INVALID_OUTPUT_MAX_BYTES + 1);
        let correction = context.correction_message("error", &oversized);
        assert!(correction.content.contains("[output truncated]"));
        assert!(!correction.content.contains(&oversized));

        // Short outputs are embedded without the truncation notice.
        let short = context.correction_message("error", "<tool_call>{}</tool_call>");
        assert!(!short.content.contains("[output truncated]"));
    }

    #[test]
    fn specific_tool_choice_filters_available_tools() {
        // Forcing a specific function should hide every other tool from the prompt.
        let request = ChatCompletionRequest::parse(&json!({
            "model": "e2ee-test",
            "messages": [{"role":"user", "content":"hi"}],
            "tool_choice": {"type":"function", "function":{"name":"search_web"}},
            "tools": [
                {"type":"function", "function":{"name":"search_web", "parameters":{"type":"object"}}},
                {"type":"function", "function":{"name":"other", "parameters":{"type":"object"}}}
            ]
        }))
        .expect("request should parse");
        let context = context_for_request(&request);

        assert!(context.controller_message().content.contains("search_web"));
        // NOTE(review): this negative check assumes the controller prompt wording never
        // contains the substring "other"; a prompt text change could break it.
        assert!(!context.controller_message().content.contains("other"));
    }
}