tirea_contract/runtime/inference/
transform.rs1use crate::runtime::tool_call::ToolDescriptor;
8use crate::thread::Message;
9use std::sync::Arc;
10
11pub struct InferenceTransformOutput {
13 pub messages: Vec<Message>,
15 pub enable_prompt_cache: bool,
17}
18
19pub trait InferenceRequestTransform: Send + Sync {
27 fn transform(
30 &self,
31 messages: Vec<Message>,
32 tool_descriptors: &[ToolDescriptor],
33 ) -> InferenceTransformOutput;
34}
35
36pub fn apply_request_transforms(
40 mut messages: Vec<Message>,
41 tool_descriptors: &[ToolDescriptor],
42 transforms: &[Arc<dyn InferenceRequestTransform>],
43) -> InferenceTransformOutput {
44 let mut enable_prompt_cache = false;
45 for transform in transforms {
46 let output = transform.transform(messages, tool_descriptors);
47 messages = output.messages;
48 enable_prompt_cache |= output.enable_prompt_cache;
49 }
50 InferenceTransformOutput {
51 messages,
52 enable_prompt_cache,
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::Message;

    /// Test transform: inserts a system message at index 0.
    struct PrependSystem(String);

    impl InferenceRequestTransform for PrependSystem {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            messages.insert(0, Message::system(&self.0));
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    /// Test transform: leaves messages untouched but requests prompt caching.
    struct EnableCache;

    impl InferenceRequestTransform for EnableCache {
        fn transform(
            &self,
            messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: true,
            }
        }
    }

    /// Test transform: keeps only the first `self.0` messages.
    struct LimitMessages(usize);

    impl InferenceRequestTransform for LimitMessages {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            // In-place; equivalent to `into_iter().take(n).collect()` without a new Vec.
            messages.truncate(self.0);
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    #[test]
    fn empty_transforms_is_passthrough() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi")];
        let output = apply_request_transforms(messages, &[], &[]);
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "Hello");
        assert_eq!(output.messages[1].content, "Hi");
        assert!(!output.enable_prompt_cache);
    }

    #[test]
    fn single_transform_applied() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> =
            vec![Arc::new(PrependSystem("System prompt".into()))];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "System prompt");
        assert_eq!(output.messages[1].content, "Hello");
    }

    #[test]
    fn transforms_chain_pipes_output_to_next() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("First".into())),
            Arc::new(PrependSystem("Second".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert_eq!(output.messages.len(), 3);
        // Each step inserts at index 0 of its predecessor's output, so the step
        // that ran last ends up first.
        assert_eq!(output.messages[0].content, "Second");
        assert_eq!(output.messages[1].content, "First");
        assert_eq!(output.messages[2].content, "Hello");
    }

    #[test]
    fn enable_prompt_cache_or_aggregated() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())),
            Arc::new(EnableCache),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(
            output.enable_prompt_cache,
            "should be true via OR aggregation"
        );
    }

    #[test]
    fn enable_prompt_cache_survives_later_non_caching_transform() {
        // The caching step runs first and a later step reports `false`. Only OR
        // aggregation (not "last step wins") keeps the flag set here, which the
        // test above cannot distinguish because its caching step runs last.
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(EnableCache),
            Arc::new(PrependSystem("sys".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(
            output.enable_prompt_cache,
            "a later `false` must not clear the flag"
        );
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "sys");
        assert_eq!(output.messages[1].content, "Hello");
    }

    #[test]
    fn enable_prompt_cache_stays_false_when_none_request() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("a".into())),
            Arc::new(PrependSystem("b".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(!output.enable_prompt_cache);
    }

    #[test]
    fn chain_with_limiting_transform() {
        let messages = vec![Message::user("1"), Message::user("2"), Message::user("3")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())),
            Arc::new(LimitMessages(2)),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "sys");
        assert_eq!(output.messages[1].content, "1");
    }
}