Skip to main content

tirea_contract/runtime/inference/
transform.rs

1//! Generic inference request transformation.
2//!
3//! Plugins register [`InferenceRequestTransform`] implementations to modify
4//! the message list and set request-level hints before inference. The core
5//! loop applies all registered transforms without knowing their semantics.
6
7use crate::runtime::tool_call::ToolDescriptor;
8use crate::thread::Message;
9use std::sync::Arc;
10
11/// Output of an inference request transform.
12pub struct InferenceTransformOutput {
13    /// Transformed message list (system + history).
14    pub messages: Vec<Message>,
15    /// Whether to enable prompt caching for this request.
16    pub enable_prompt_cache: bool,
17}
18
19/// Trait for plugins that transform the inference request before it is sent
20/// to the LLM. The core loop calls all registered transforms in order,
21/// piping the output messages of one into the input of the next.
22///
23/// Implementing this trait is how plugins inject behaviors like context
24/// window truncation, message summarization, or content filtering — without
25/// the core loop needing domain-specific knowledge.
/// Trait for plugins that transform the inference request before it is sent
/// to the LLM. The core loop calls all registered transforms in order,
/// piping the output messages of one into the input of the next.
///
/// Implementing this trait is how plugins inject behaviors like context
/// window truncation, message summarization, or content filtering — without
/// the core loop needing domain-specific knowledge.
///
/// Implementations must be `Send + Sync` because they are shared across the
/// runtime behind `Arc` (see [`apply_request_transforms`]).
pub trait InferenceRequestTransform: Send + Sync {
    /// Transform the message list. `tool_descriptors` are provided so
    /// transforms can account for tool token overhead.
    ///
    /// Takes ownership of `messages` and returns the (possibly modified)
    /// list in [`InferenceTransformOutput`], along with a per-transform
    /// prompt-cache request that the caller OR-aggregates.
    fn transform(
        &self,
        messages: Vec<Message>,
        tool_descriptors: &[ToolDescriptor],
    ) -> InferenceTransformOutput;
}
35
36/// Apply a chain of transforms, piping messages through each in order.
37///
38/// Returns the final messages and whether any transform requested prompt caching.
39pub fn apply_request_transforms(
40    mut messages: Vec<Message>,
41    tool_descriptors: &[ToolDescriptor],
42    transforms: &[Arc<dyn InferenceRequestTransform>],
43) -> InferenceTransformOutput {
44    let mut enable_prompt_cache = false;
45    for transform in transforms {
46        let output = transform.transform(messages, tool_descriptors);
47        messages = output.messages;
48        enable_prompt_cache |= output.enable_prompt_cache;
49    }
50    InferenceTransformOutput {
51        messages,
52        enable_prompt_cache,
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::Message;

    /// Test transform: inserts a system message at the head of the list.
    struct PrependSystem(String);

    impl InferenceRequestTransform for PrependSystem {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            messages.insert(0, Message::system(&self.0));
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    /// Test transform: requests prompt caching, leaves messages untouched.
    struct EnableCache;

    impl InferenceRequestTransform for EnableCache {
        fn transform(
            &self,
            messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: true,
            }
        }
    }

    /// Test transform: keeps at most N messages, dropping the tail.
    struct LimitMessages(usize);

    impl InferenceRequestTransform for LimitMessages {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            // truncate() retains the first `self.0` elements in place —
            // same result as into_iter().take(n).collect().
            messages.truncate(self.0);
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    #[test]
    fn empty_transforms_is_passthrough() {
        let input = vec![Message::user("Hello"), Message::assistant("Hi")];
        let result = apply_request_transforms(input.clone(), &[], &[]);
        assert!(!result.enable_prompt_cache);
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].content, "Hello");
        assert_eq!(result.messages[1].content, "Hi");
    }

    #[test]
    fn single_transform_applied() {
        let chain: Vec<Arc<dyn InferenceRequestTransform>> =
            vec![Arc::new(PrependSystem("System prompt".into()))];
        let result = apply_request_transforms(vec![Message::user("Hello")], &[], &chain);
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].content, "System prompt");
        assert_eq!(result.messages[1].content, "Hello");
    }

    #[test]
    fn transforms_chain_pipes_output_to_next() {
        let chain: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("First".into())),
            Arc::new(PrependSystem("Second".into())),
        ];
        let result = apply_request_transforms(vec![Message::user("Hello")], &[], &chain);
        // After transform 1: [System("First"), User("Hello")]
        // After transform 2: [System("Second"), System("First"), User("Hello")]
        assert_eq!(result.messages.len(), 3);
        assert_eq!(result.messages[0].content, "Second");
        assert_eq!(result.messages[1].content, "First");
        assert_eq!(result.messages[2].content, "Hello");
    }

    #[test]
    fn enable_prompt_cache_or_aggregated() {
        // First transform requests no caching, the second does; the chain
        // result must OR the flags together.
        let chain: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())),
            Arc::new(EnableCache),
        ];
        let result = apply_request_transforms(vec![Message::user("Hello")], &[], &chain);
        assert!(
            result.enable_prompt_cache,
            "should be true via OR aggregation"
        );
    }

    #[test]
    fn enable_prompt_cache_stays_false_when_none_request() {
        let chain: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("a".into())),
            Arc::new(PrependSystem("b".into())),
        ];
        let result = apply_request_transforms(vec![Message::user("Hello")], &[], &chain);
        assert!(!result.enable_prompt_cache);
    }

    #[test]
    fn chain_with_limiting_transform() {
        let input = vec![Message::user("1"), Message::user("2"), Message::user("3")];
        let chain: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())), // grows the list to 4
            Arc::new(LimitMessages(2)),            // then keeps the first 2
        ];
        let result = apply_request_transforms(input, &[], &chain);
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].content, "sys");
        assert_eq!(result.messages[1].content, "1");
    }
}
186}