// wesichain_core/tool.rs

1use crate::Value;
2use thiserror::Error;
3
4pub use tokio_util::sync::CancellationToken;
5
6#[derive(Debug, Error)]
7pub enum ToolError {
8    #[error("invalid input: {0}")]
9    InvalidInput(String),
10    #[error("execution failed: {0}")]
11    ExecutionFailed(String),
12    #[error("json error: {0}")]
13    Json(#[from] serde_json::Error),
14    #[error("io error: {0}")]
15    Io(#[from] std::io::Error),
16}
17
18#[async_trait::async_trait]
19pub trait Tool: Send + Sync {
20    fn name(&self) -> &str;
21    fn description(&self) -> &str;
22    fn schema(&self) -> Value;
23    async fn invoke(&self, args: Value) -> Result<Value, ToolError>;
24}
25
26#[derive(Clone, Debug)]
27pub struct ToolContext {
28    pub correlation_id: String,
29    pub step_id: u32,
30    pub cancellation: CancellationToken,
31    /// Optional channel for streaming tool output to the agent event loop in real time.
32    ///
33    /// Tools that produce incremental output (e.g. `BashExecTool`) should send
34    /// `StreamEvent::ContentChunk` items here so the host can display progress
35    /// without waiting for the full result.
36    ///
37    /// Set to `None` if no streaming consumer is attached.
38    pub stream_tx: Option<tokio::sync::mpsc::UnboundedSender<crate::StreamEvent>>,
39}
40
41impl ToolContext {
42    /// Convenience constructor that leaves `stream_tx` disconnected.
43    pub fn new(
44        correlation_id: impl Into<String>,
45        step_id: u32,
46        cancellation: CancellationToken,
47    ) -> Self {
48        Self {
49            correlation_id: correlation_id.into(),
50            step_id,
51            cancellation,
52            stream_tx: None,
53        }
54    }
55}
56
57#[async_trait::async_trait]
58pub trait TypedTool: Send + Sync {
59    type Args: serde::de::DeserializeOwned + schemars::JsonSchema + Send;
60    type Output: serde::Serialize + schemars::JsonSchema + Send;
61    const NAME: &'static str;
62    async fn run(&self, args: Self::Args, ctx: ToolContext) -> Result<Self::Output, ToolError>;
63}