// temporal_agent_rs/state.rs

//! Serializable agent state stored in workflow history.
//!
//! These types deliberately mirror — rather than re-export — AutoAgents'
//! `ChatMessage` shape so the workflow's persisted history stays decoupled
//! from upstream version churn. Conversions happen at the activity boundary
//! in [`crate::llm`].

use serde::{Deserialize, Serialize};

/// Initial input handed to a new `AgentWorkflow` run.
///
/// The model is configured on the [`LLMProvider`] at worker-build time, not
/// per workflow run; that's why there's no `model` field here.
///
/// [`LLMProvider`]: autoagents_llm::LLMProvider
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentInput {
    /// System prompt; seeded as the first history message (see `AgentState::new`).
    pub system_prompt: String,
    /// Opening user message; seeded as the second history message.
    pub user_message: String,
    /// Hard cap on reasoning turns before the workflow returns.
    pub max_turns: u32,
}

/// Reason the agent stopped looping.
///
/// NOTE(review): unlike [`Role`] and [`LlmResponse`], this enum has no
/// `#[serde(rename_all = "snake_case")]`, so variants serialize as
/// `"FinalAnswer"` / `"MaxTurnsReached"`. Adding the attribute now would
/// break decoding of already-persisted histories — confirm the asymmetry
/// is intentional.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum StopReason {
    /// The model emitted a final answer.
    FinalAnswer,
    /// Reached `max_turns` before the model finalized.
    MaxTurnsReached,
}

/// Final result of an `AgentWorkflow` run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentOutput {
    /// Answer text returned to the caller.
    pub final_answer: String,
    /// Why the loop exited (final answer vs. turn cap).
    pub stop_reason: StopReason,
    /// Number of reasoning turns consumed.
    pub turns_used: u32,
    /// Total tool invocations executed across the run.
    pub tool_calls: u32,
}

/// Role of a message in the conversation history.
///
/// Serialized in snake_case (`"system"`, `"user"`, `"assistant"`, `"tool"`)
/// per the `rename_all` attribute below.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// System prompt.
    System,
    /// End-user input.
    User,
    /// Model output — text and/or tool-call requests.
    Assistant,
    /// Output of an executed tool, fed back to the model.
    Tool,
}

/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Call id; correlated back via [`ToolResult::call_id`].
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw JSON arguments as produced by the model.
    pub args: serde_json::Value,
}

/// Result of executing a single tool call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolResult {
    /// Correlates with the originating [`ToolCall::id`].
    pub call_id: String,
    /// Raw JSON value the tool returned; rendered to text when folded into
    /// history (see `Message::tool_result`).
    pub output: serde_json::Value,
    /// Populated when the tool itself errored. The agent observes this and
    /// can recover; Temporal does NOT retry tool errors automatically.
    pub error: Option<String>,
}

/// A single entry in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
    pub role: Role,
    /// Text body. Defaults to empty on deserialization so tool-call-only
    /// assistant messages need not carry content on the wire.
    #[serde(default)]
    pub content: String,
    /// Populated on `Role::Assistant` messages that requested tool calls.
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Populated on `Role::Tool` messages — correlates with a prior `ToolCall::id`.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

84impl Message {
85    pub fn system(content: impl Into<String>) -> Self {
86        Self {
87            role: Role::System,
88            content: content.into(),
89            tool_calls: vec![],
90            tool_call_id: None,
91        }
92    }
93
94    pub fn user(content: impl Into<String>) -> Self {
95        Self {
96            role: Role::User,
97            content: content.into(),
98            tool_calls: vec![],
99            tool_call_id: None,
100        }
101    }
102
103    pub fn assistant_text(content: impl Into<String>) -> Self {
104        Self {
105            role: Role::Assistant,
106            content: content.into(),
107            tool_calls: vec![],
108            tool_call_id: None,
109        }
110    }
111
112    pub fn assistant_with_tools(calls: Vec<ToolCall>) -> Self {
113        Self {
114            role: Role::Assistant,
115            content: String::new(),
116            tool_calls: calls,
117            tool_call_id: None,
118        }
119    }
120
121    pub fn tool_result(result: &ToolResult) -> Self {
122        let content = match &result.error {
123            Some(err) => format!("ERROR: {err}"),
124            None => result.output.to_string(),
125        };
126        Self {
127            role: Role::Tool,
128            content,
129            tool_calls: vec![],
130            tool_call_id: Some(result.call_id.clone()),
131        }
132    }
133}
134
/// Live state of an in-flight agent run. Lives inside the workflow.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentState {
    /// The input that seeded this run (see [`AgentState::new`] / [`compact`]).
    pub input: AgentInput,
    /// Full conversation so far, oldest first.
    pub history: Vec<Message>,
    /// Reasoning turns completed so far.
    pub turn: u32,
    /// Tool invocations executed so far.
    pub tool_calls_executed: u32,
    /// Messages enqueued via the `add_user_message` signal that haven't been
    /// folded into history yet. Useful for user-initiated mid-conversation
    /// nudges; tools that need to *block* waiting on the user should
    /// implement their own answer channel rather than relying on this.
    #[serde(default)]
    pub pending_user_messages: Vec<String>,
}

150impl AgentState {
151    pub fn new(input: AgentInput) -> Self {
152        let history = vec![
153            Message::system(&input.system_prompt),
154            Message::user(&input.user_message),
155        ];
156        Self {
157            input,
158            history,
159            turn: 0,
160            tool_calls_executed: 0,
161            pending_user_messages: vec![],
162        }
163    }
164}
165
/// The LLM's response on a single turn.
///
/// Serialized internally tagged: `{"kind": "final", …}` /
/// `{"kind": "use_tools", …}` per the `tag`/`rename_all` attributes below.
///
/// Human-in-the-loop intentionally does NOT have its own variant — that is
/// handled by user-registered tools whose `execute()` blocks waiting on an
/// external answer mechanism. The library treats every tool uniformly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum LlmResponse {
    /// Model produced a final natural-language answer; the loop should exit.
    Final { answer: String },
    /// Model wants to invoke one or more tools before reasoning further.
    UseTools { calls: Vec<ToolCall> },
}

/// Input passed to the `llm_chat` activity each turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmChatInput {
    /// Conversation history to send, oldest first.
    pub messages: Vec<Message>,
    /// Schemas of every tool the model may call this turn.
    pub tools: Vec<ToolSchema>,
}

/// Description of a tool sent to the LLM so it knows what it can call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    pub name: String,
    pub description: String,
    /// JSON description of the tool's arguments — presumably a JSON Schema
    /// object; confirm against the provider's expected format.
    pub args_schema: serde_json::Value,
}

195/// Compact long histories so `continue_as_new` doesn't grow the event history
196/// unbounded. Keeps the system prompt, a synthetic summary marker, and the
197/// most recent `keep_recent` messages.
198pub fn compact(state: &AgentState, keep_recent: usize) -> AgentInput {
199    let mut summary_lines = Vec::new();
200    let total = state.history.len();
201    let drop_until = total.saturating_sub(keep_recent);
202
203    for msg in state.history.iter().take(drop_until) {
204        let line = match msg.role {
205            Role::System if summary_lines.is_empty() => continue,
206            Role::User => format!("user: {}", truncate(&msg.content, 200)),
207            Role::Assistant if !msg.tool_calls.is_empty() => {
208                let names: Vec<&str> = msg.tool_calls.iter().map(|c| c.name.as_str()).collect();
209                format!("assistant: called tools [{}]", names.join(", "))
210            }
211            Role::Assistant => format!("assistant: {}", truncate(&msg.content, 200)),
212            Role::Tool => format!("tool: {}", truncate(&msg.content, 120)),
213            Role::System => continue,
214        };
215        summary_lines.push(line);
216    }
217
218    let summary = if summary_lines.is_empty() {
219        String::new()
220    } else {
221        format!(
222            "\n\n[Prior conversation summary, {} messages dropped]\n{}",
223            drop_until,
224            summary_lines.join("\n")
225        )
226    };
227
228    let recent_user = state
229        .history
230        .iter()
231        .rev()
232        .find(|m| m.role == Role::User)
233        .map(|m| m.content.clone())
234        .unwrap_or_default();
235
236    AgentInput {
237        system_prompt: format!("{}{}", state.input.system_prompt, summary),
238        user_message: recent_user,
239        max_turns: state.input.max_turns,
240    }
241}
242
/// Cut `s` down to at most `max` bytes, backing up to the nearest UTF-8
/// character boundary, and append an ellipsis when anything was removed.
/// Strings already within the limit are returned unchanged.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Byte 0 is always a char boundary, so this scan always succeeds.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    format!("{}…", &s[..cut])
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn agent_state_seeds_system_and_user() {
        let s = AgentState::new(AgentInput {
            system_prompt: "be helpful".into(),
            user_message: "hi".into(),
            max_turns: 5,
        });
        assert_eq!(s.history.len(), 2);
        assert_eq!(s.history[0].role, Role::System);
        assert_eq!(s.history[0].content, "be helpful");
        assert_eq!(s.history[1].role, Role::User);
        assert_eq!(s.history[1].content, "hi");
        assert_eq!(s.turn, 0);
        assert_eq!(s.tool_calls_executed, 0);
        assert!(s.pending_user_messages.is_empty());
    }

    #[test]
    fn compact_keeps_system_and_recent() {
        let mut state = AgentState::new(AgentInput {
            system_prompt: "sys".into(),
            user_message: "u0".into(),
            max_turns: 50,
        });
        for i in 1..30 {
            state.history.push(Message::user(format!("u{i}")));
            state.history.push(Message::assistant_text(format!("a{i}")));
        }
        let compacted = compact(&state, 10);
        assert!(compacted.system_prompt.starts_with("sys"));
        assert!(compacted.system_prompt.contains("Prior conversation summary"));
        // 60 messages total, keep 10 → 50 dropped from history.
        assert!(compacted.system_prompt.contains("50 messages dropped"));
        // The most recent user message becomes the continued run's input.
        assert_eq!(compacted.user_message, "u29");
        assert_eq!(compacted.max_turns, state.input.max_turns);
    }

    #[test]
    fn compact_short_history_adds_no_summary() {
        let state = AgentState::new(AgentInput {
            system_prompt: "sys".into(),
            user_message: "hello".into(),
            max_turns: 3,
        });
        // Nothing dropped → prompt and user message pass through untouched.
        let compacted = compact(&state, 10);
        assert_eq!(compacted.system_prompt, "sys");
        assert_eq!(compacted.user_message, "hello");
    }

    #[test]
    fn truncate_respects_utf8_char_boundary() {
        // 'é' is 2 bytes; slicing at byte 2 would split it and panic.
        let t = truncate("héllo world", 2);
        assert_eq!(t, "h…");
        // Already short — returned unchanged.
        assert_eq!(truncate("hi", 10), "hi");
        // Emoji boundary (4 bytes for 🦀).
        assert_eq!(truncate("🦀rust", 2), "…");
    }

    #[test]
    fn message_roundtrips_through_json() {
        let m = Message::assistant_with_tools(vec![ToolCall {
            id: "c1".into(),
            name: "add".into(),
            args: serde_json::json!({"a": 1, "b": 2}),
        }]);
        let s = serde_json::to_string(&m).unwrap();
        // `tool_call_id: None` is skipped entirely on the wire.
        assert!(!s.contains("tool_call_id"));
        let back: Message = serde_json::from_str(&s).unwrap();
        assert_eq!(m, back);
    }
}