Skip to main content

tiy_core/transform/
messages.rs

1//! Message transformation for cross-provider compatibility.
2
3use crate::types::*;
4
5/// Transform messages for cross-provider compatibility.
6///
7/// This handles:
8/// - Thinking block conversion between providers
9/// - ToolCall ID normalization
10/// - Orphan tool call handling
11pub fn transform_messages(
12    messages: &[Message],
13    target_model: &Model,
14    normalize_tool_call_id: Option<fn(&str) -> String>,
15) -> Vec<Message> {
16    let mut result = Vec::new();
17
18    for msg in messages {
19        match msg {
20            Message::User(user_msg) => {
21                result.push(Message::User(user_msg.clone()));
22            }
23            Message::Assistant(assistant_msg) => {
24                // Skip error/aborted messages
25                if assistant_msg.stop_reason == StopReason::Error
26                    || assistant_msg.stop_reason == StopReason::Aborted
27                {
28                    continue;
29                }
30
31                let transformed = transform_assistant_message(
32                    assistant_msg,
33                    target_model,
34                    normalize_tool_call_id,
35                );
36                result.push(Message::Assistant(transformed));
37            }
38            Message::ToolResult(tool_result) => {
39                let mut result_msg = tool_result.clone();
40                if let Some(normalize) = normalize_tool_call_id {
41                    result_msg.tool_call_id = normalize(&result_msg.tool_call_id);
42                }
43                result.push(Message::ToolResult(result_msg));
44            }
45        }
46    }
47
48    // Handle orphan tool calls
49    handle_orphan_tool_calls(&mut result);
50
51    result
52}
53
54fn transform_assistant_message(
55    msg: &AssistantMessage,
56    target_model: &Model,
57    normalize_tool_call_id: Option<fn(&str) -> String>,
58) -> AssistantMessage {
59    let mut new_msg = msg.clone();
60
61    // Transform content blocks
62    new_msg.content = msg
63        .content
64        .iter()
65        .map(|block| match block {
66            ContentBlock::Thinking(thinking) => {
67                // Convert thinking blocks based on target provider
68                transform_thinking_block(thinking, &msg.provider, &msg.model, target_model)
69            }
70            ContentBlock::ToolCall(tc) => {
71                let mut new_tc = tc.clone();
72                if let Some(normalize) = normalize_tool_call_id {
73                    new_tc.id = normalize(&new_tc.id);
74                }
75                ContentBlock::ToolCall(new_tc)
76            }
77            _ => block.clone(),
78        })
79        .collect();
80
81    if let Some(api) = &target_model.api {
82        new_msg.api = api.clone();
83    }
84    new_msg.provider = target_model.provider.clone();
85    new_msg.model = target_model.id.clone();
86
87    new_msg
88}
89
90fn transform_thinking_block(
91    thinking: &ThinkingContent,
92    source_provider: &Provider,
93    source_model: &str,
94    target_model: &Model,
95) -> ContentBlock {
96    // If same provider and model, keep thinking block
97    if *source_provider == target_model.provider && source_model == target_model.id {
98        return ContentBlock::Thinking(thinking.clone());
99    }
100
101    // For different providers, convert to text
102    // This avoids having the model mimic thinking tags
103    if thinking.thinking.trim().is_empty() {
104        ContentBlock::Text(TextContent::new(""))
105    } else {
106        ContentBlock::Text(TextContent::new(format!(
107            "[Reasoning]\n{}\n[/Reasoning]",
108            thinking.thinking
109        )))
110    }
111}
112
113fn handle_orphan_tool_calls(messages: &mut Vec<Message>) {
114    // Find tool calls without corresponding results
115    let mut tool_call_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
116    let mut result_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
117
118    for msg in messages.iter() {
119        match msg {
120            Message::Assistant(assistant) => {
121                for block in &assistant.content {
122                    if let ContentBlock::ToolCall(tc) = block {
123                        tool_call_ids.insert(tc.id.clone());
124                    }
125                }
126            }
127            Message::ToolResult(result) => {
128                result_ids.insert(result.tool_call_id.clone());
129            }
130            _ => {}
131        }
132    }
133
134    // Find orphan IDs
135    let orphan_ids: std::collections::HashSet<String> =
136        tool_call_ids.difference(&result_ids).cloned().collect();
137
138    if orphan_ids.is_empty() {
139        return;
140    }
141
142    // Insert synthetic error results for orphan tool calls
143    let mut new_messages = Vec::new();
144    for msg in messages.iter() {
145        new_messages.push(msg.clone());
146        if let Message::Assistant(assistant) = msg {
147            let orphan_calls: Vec<_> = assistant
148                .content
149                .iter()
150                .filter_map(|b| b.as_tool_call())
151                .filter(|tc| orphan_ids.contains(&tc.id))
152                .collect();
153
154            for tc in orphan_calls {
155                let error_result = ToolResultMessage::error(
156                    tc.id.clone(),
157                    tc.name.clone(),
158                    "Tool call was not executed (orphaned)",
159                );
160                new_messages.push(Message::ToolResult(error_result));
161            }
162        }
163    }
164
165    *messages = new_messages;
166}