use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
pub use llm_chain::parsing::{find_yaml, ExtractionError};
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use tokio::sync::RwLock;
/// Describes a single named field of a tool's input or output [`Format`].
///
/// `Format`'s `Serialize` impl renders each field as the map entry
/// `name: "<type> description"`, appending ` (optional)` when `optional` is set.
#[derive(Debug, Clone)]
pub struct FieldFormat {
    /// Key under which the field appears in the rendered map.
    pub name: String,
    /// Free-form type label, rendered between angle brackets (`<type>`).
    pub r#type: String,
    /// Whether the field may be omitted; adds a ` (optional)` suffix when rendered.
    pub optional: bool,
    /// Human-readable explanation of the field.
    pub description: String,
}
/// Types that can describe their own field layout as a [`Format`].
///
/// NOTE(review): no implementors are visible in this file; presumably this is
/// implemented (or derived) for tool input/output types — confirm.
pub trait Describe {
    /// Returns the field-level format of the implementing type.
    fn describe() -> Format;
}
/// An ordered list of field descriptions for a tool's input or output.
///
/// It is not serialized as a struct. The hand-written `Serialize` impl emits a
/// map from each field's name to a rendered `"<type> description"` string.
#[derive(Debug, Clone)]
pub struct Format {
    /// The fields, in the order they are serialized.
    pub fields: Vec<FieldFormat>,
}
impl Serialize for Format {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let n = self.fields.len();
let mut map = serializer.serialize_map(Some(n))?;
for field in &self.fields {
let description = if field.optional {
format!("<{}> {} (optional)", field.r#type, field.description)
} else {
format!("<{}> {}", field.r#type, field.description)
};
map.serialize_entry(&field.name, &description)?;
}
map.end()
}
}
impl From<Vec<FieldFormat>> for Format {
fn from(fields: Vec<FieldFormat>) -> Self {
Format { fields }
}
}
/// What a caller needs to know about a tool: its name, what it does, and the
/// shape of its input and output.
///
/// The derived `Serialize` emits fields in declaration order. The two
/// [`Format`]s serialize through their custom map representation.
#[derive(Debug, Serialize, Clone)]
pub struct ToolDescription {
    /// Name the tool is registered and invoked under (the `Toolbox` map key).
    pub name: String,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// Expected shape of the YAML `input` value.
    pub input_format: Format,
    /// Shape of the YAML value the tool returns.
    pub output_format: Format,
}
impl ToolDescription {
pub fn new(name: &str, description: &str, input_format: Format, output_format: Format) -> Self {
ToolDescription {
name: name.to_string(),
description: description.to_string(),
input_format,
output_format,
}
}
}
/// Errors produced while locating, parsing, or running a tool.
#[derive(Debug, thiserror::Error)]
pub enum ToolUseError {
    /// No tool is registered under the given name. Produced by
    /// `invoke_from_toolbox` / `invoke_simple_from_toolbox`.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
    /// The invocation was rejected or the tool itself failed. `invoke_tool`
    /// uses it when no action is found or an action has extra fields.
    #[error("Tool invocation failed: {0}")]
    ToolInvocationFailed(String),
    /// A YAML (de)serialization error, convertible via `?` from `serde_yaml::Error`.
    #[error("Invalid YAML: {0}")]
    InvalidYaml(#[from] serde_yaml::Error),
    /// The YAML blocks could not be extracted from the input text (`find_yaml` failed).
    #[error("Invalid input: {0}")]
    InvalidInput(#[from] ExtractionError),
}
/// The YAML document that requests a tool call.
///
/// `command` names the tool and `input` is handed to it as-is.
#[derive(Serialize, Deserialize, Debug)]
struct ToolInvocationInput {
    /// Name of the tool to invoke.
    command: String,
    /// Arbitrary YAML passed to the tool's `invoke`.
    input: serde_yaml::Value,
    /// Every other top-level key, captured via `flatten`. `invoke_tool` uses it
    /// to reject extra fields instead of silently ignoring them. Omitted when
    /// serializing if empty.
    #[serde(skip_serializing_if = "HashMap::is_empty", flatten)]
    junk: HashMap<String, serde_yaml::Value>,
}
/// Description half of a tool.
///
/// A `Sync + Send` type implementing both this and [`ProtoToolInvoke`] becomes
/// a [`Tool`] through the blanket impl.
pub trait ProtoToolDescribe {
    /// Returns the tool's description; its `name` is used as the registration key.
    fn description(&self) -> ToolDescription;
}
/// Invocation half of a tool.
///
/// See [`ProtoToolDescribe`] for how the two halves combine into a [`Tool`].
#[async_trait::async_trait]
pub trait ProtoToolInvoke {
    /// Runs the tool on a YAML `input` and returns its YAML output.
    async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError>;
}
/// A thread-safe tool that can be stored as `Box<dyn Tool>` in a [`Toolbox`].
#[async_trait::async_trait]
pub trait Tool: Sync + Send {
    /// Returns the tool's description; `Toolbox::add_tool` keys the tool by its `name`.
    fn description(&self) -> ToolDescription;
    /// Runs the tool on a YAML `input` and returns its YAML output.
    async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError>;
}
#[async_trait::async_trait]
impl<T: Sync + Send> Tool for T
where
T: ProtoToolDescribe + ProtoToolInvoke,
{
fn description(&self) -> ToolDescription {
self.description()
}
async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, ToolUseError> {
self.invoke(input).await
}
}
/// Message a [`TerminalTool`] hands back from `take_done`.
///
/// `Toolbox::termination_messages` collects these.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminationMessage {
    /// NOTE(review): presumably the final answer produced by the terminal tool — confirm.
    pub conclusion: String,
    /// NOTE(review): presumably the question the run was started with — confirm.
    pub original_question: String,
}
/// A [`Tool`] that can report that it has finished and hand back a
/// [`TerminationMessage`].
///
/// Both methods have defaults meaning "never done, no message".
#[async_trait::async_trait]
pub trait TerminalTool: Tool + Sync + Send {
    /// Whether the tool has signalled completion. Defaults to `false`.
    async fn is_done(&self) -> bool {
        false
    }
    /// Returns the pending termination message, if any. Defaults to `None`.
    ///
    /// NOTE(review): the name suggests implementors clear the message once it
    /// has been taken; nothing here enforces that.
    async fn take_done(&self) -> Option<TerminationMessage> {
        None
    }
}
/// A [`Tool`] that is also handed the [`Toolbox`] it was invoked from.
///
/// `invoke_from_toolbox` checks advanced tools before any other kind and calls
/// `invoke_with_toolbox` with the toolbox it received. Toolbox clones share the
/// same maps. NOTE(review): presumably this lets the tool dispatch to other
/// tools — confirm.
///
/// Caution: `invoke_from_toolbox` holds a read guard on the advanced-tool map
/// for the entire call. Calling `add_advanced_tool` on the same toolbox from
/// inside `invoke_with_toolbox` therefore waits on that guard forever.
#[async_trait::async_trait]
pub trait AdvancedTool: Tool {
    /// Runs the tool with access to `toolbox` and a YAML `input`.
    async fn invoke_with_toolbox(
        &self,
        toolbox: Toolbox,
        input: serde_yaml::Value,
    ) -> Result<serde_yaml::Value, ToolUseError>;
}
/// A shared registry of tools, grouped by kind and keyed by tool name.
///
/// Each map sits behind `Arc<RwLock<..>>`, so cloning is cheap and every clone
/// sees the same maps. Tools registered through one clone are visible through
/// all of them.
#[derive(Default, Clone)]
pub struct Toolbox {
    /// Tools that can report termination (see [`TerminalTool`]).
    terminal_tools: Arc<RwLock<HashMap<String, Box<dyn TerminalTool>>>>,
    /// Plain tools.
    tools: Arc<RwLock<HashMap<String, Box<dyn Tool>>>>,
    /// Tools invoked with access to the toolbox (see [`AdvancedTool`]).
    advanced_tools: Arc<RwLock<HashMap<String, Box<dyn AdvancedTool>>>>,
}
impl Debug for Toolbox {
    /// Prints just `Toolbox`. The maps sit behind async locks, which cannot be
    /// awaited from a synchronous formatter, so their contents are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Toolbox")
    }
}
impl Toolbox {
    /// Collects the result of `take_done` from every registered terminal tool,
    /// keeping only the `Some` values.
    ///
    /// Order follows `HashMap` iteration and is therefore unspecified.
    pub async fn termination_messages(&self) -> Vec<TerminationMessage> {
        let mut messages = Vec::new();
        for tool in self.terminal_tools.read().await.values() {
            if let Some(message) = tool.take_done().await {
                messages.push(message);
            }
        }
        messages
    }

    /// Registers a terminal tool under `tool.description().name`, replacing any
    /// terminal tool already registered under that name.
    pub async fn add_terminal_tool(&mut self, tool: impl TerminalTool + 'static) {
        let name = tool.description().name;
        self.terminal_tools
            .write()
            .await
            .insert(name, Box::new(tool));
    }

    /// Registers a plain tool under `tool.description().name`, replacing any
    /// plain tool already registered under that name.
    pub async fn add_tool(&mut self, tool: impl Tool + 'static) {
        let name = tool.description().name;
        self.tools.write().await.insert(name, Box::new(tool));
    }

    /// Registers an advanced tool under `tool.description().name`, replacing any
    /// advanced tool already registered under that name.
    pub async fn add_advanced_tool(&mut self, tool: impl AdvancedTool + 'static) {
        let name = tool.description().name;
        self.advanced_tools
            .write()
            .await
            .insert(name, Box::new(tool));
    }

    /// Returns the description of every registered tool, keyed by name.
    ///
    /// A name may be registered in more than one category. In that case the
    /// entry describes the tool that `invoke_from_toolbox` actually dispatches to:
    /// advanced first, then terminal, then plain.
    pub async fn describe(&self) -> HashMap<String, ToolDescription> {
        let mut descriptions = HashMap::new();
        // Insert from lowest to highest dispatch priority so that, on a name
        // collision, the later (higher-priority) insert overwrites the earlier one.
        for (name, tool) in self.tools.read().await.iter() {
            descriptions.insert(name.clone(), tool.description());
        }
        for (name, tool) in self.terminal_tools.read().await.iter() {
            descriptions.insert(name.clone(), tool.description());
        }
        for (name, tool) in self.advanced_tools.read().await.iter() {
            descriptions.insert(name.clone(), tool.description());
        }
        descriptions
    }
}
pub async fn invoke_from_toolbox(
toolbox: Toolbox,
name: &str,
input: serde_yaml::Value,
) -> Result<serde_yaml::Value, ToolUseError> {
if let Some(tool) = toolbox.clone().advanced_tools.read().await.get(name) {
return tool.invoke_with_toolbox(toolbox, input).await;
}
{
let guard = toolbox.terminal_tools.read().await;
if let Some(tool) = guard.get(name) {
return tool.invoke(input).await;
}
}
let guard = toolbox.tools.read().await;
let tool = guard
.get(name)
.ok_or(ToolUseError::ToolNotFound(name.to_string()))?;
tool.invoke(input).await
}
/// Invokes the terminal or plain tool registered as `name`.
///
/// Advanced tools are not considered. Terminal tools take precedence over
/// plain tools with the same name.
///
/// # Errors
///
/// [`ToolUseError::ToolNotFound`] if neither map contains `name`, otherwise
/// whatever the tool returns.
pub async fn invoke_simple_from_toolbox(
    toolbox: Toolbox,
    name: &str,
    input: serde_yaml::Value,
) -> Result<serde_yaml::Value, ToolUseError> {
    {
        // Scoped so the terminal-tool guard is released before the plain-tool lookup.
        let terminal_tools = toolbox.terminal_tools.read().await;
        if let Some(tool) = terminal_tools.get(name) {
            return tool.invoke(input).await;
        }
    }
    let tools = toolbox.tools.read().await;
    let tool = tools
        .get(name)
        // Lazily build the error so successful lookups do not allocate.
        .ok_or_else(|| ToolUseError::ToolNotFound(name.to_string()))?;
    tool.invoke(input).await
}
/// Extracts a tool invocation from free-form text `data` and runs it against `toolbox`.
///
/// `data` is scanned for YAML blocks shaped like [`ToolInvocationInput`].
///
/// # Returns
///
/// A pair of:
/// * the invoked tool's name, or `"unknown"` when no valid invocation could be
///   identified;
/// * the tool's output re-serialized as a YAML string, or the error.
///
/// # Errors (in the second element)
///
/// * [`ToolUseError::InvalidInput`] if YAML extraction fails.
/// * [`ToolUseError::ToolInvocationFailed`] if no block is found, or if any
///   block has keys other than `command` and `input`.
/// * Errors from [`invoke_from_toolbox`], e.g. [`ToolUseError::ToolNotFound`].
/// * [`ToolUseError::InvalidYaml`] if the tool's output cannot be serialized.
#[tracing::instrument]
pub async fn invoke_tool(toolbox: Toolbox, data: &str) -> (String, Result<String, ToolUseError>) {
    let tool_invocations = match find_yaml::<ToolInvocationInput>(data) {
        Ok(tool_invocations) => tool_invocations,
        Err(e) => return ("unknown".to_string(), Err(ToolUseError::InvalidInput(e))),
    };

    // `test_extraction_of_three_yaml_with_output` shows that `find_yaml` yields
    // blocks in reverse document order. So `last()` is the first block in `data`.
    let invocation = match tool_invocations.last() {
        Some(invocation) => invocation,
        None => {
            return (
                "unknown".to_string(),
                Err(ToolUseError::ToolInvocationFailed(
                    "No Action found".to_string(),
                )),
            )
        }
    };

    // An extra top-level key in *any* block rejects the whole response, not just
    // the block that would have been executed.
    if let Some(offending) = tool_invocations.iter().find(|inv| !inv.junk.is_empty()) {
        // Sort so the message does not depend on `HashMap` iteration order.
        let mut junk_keys: Vec<&str> = offending.junk.keys().map(String::as_str).collect();
        junk_keys.sort_unstable();
        return (
            "unknown".to_string(),
            Err(ToolUseError::ToolInvocationFailed(format!(
                "The Action cannot have fields: {}. Only `command` and `input` are allowed.",
                junk_keys.join(", ")
            ))),
        );
    }

    let tool_name = invocation.command.clone();
    let result = invoke_from_toolbox(toolbox, &tool_name, invocation.input.clone())
        .await
        // Report serialization failures as an error instead of panicking.
        .and_then(|output| serde_yaml::to_string(&output).map_err(ToolUseError::InvalidYaml));
    (tool_name, result)
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use indoc::indoc;
    use insta::assert_display_snapshot;
    use serde::{Deserialize, Serialize};
    use serde_yaml::Number;

    // The fixtures are YAML, so nesting is significant. `input`'s keys and the
    // body of the `|` block scalar must be indented under their parents.
    // `indoc!` removes the common leading indentation of the lines after the first.

    /// A single fenced YAML block yields exactly one invocation.
    #[test]
    fn test_extraction_of_one_yaml() {
        let data = indoc! {r#"# Some text
            ```yaml
            command: Search
            input:
              q: Marcel Deneuve
              excluded_terms: Resident Evil
              num_results: 10
            ```
            Some other text
        "#};
        let tool_invocations = super::find_yaml::<super::ToolInvocationInput>(data).unwrap();
        assert_eq!(tool_invocations.len(), 1);
        let invocation = &tool_invocations[0];
        assert_eq!(invocation.command, "Search");
    }

    /// `input` is parsed as a nested map, and any other top-level key (here
    /// `output`) is captured in `junk`.
    #[test]
    fn test_extraction_of_one_yaml_with_output() {
        let data = indoc! {r#"# Some text
            ```yaml
            command: Search
            input:
              q: Marcel Deneuve
              excluded_terms: Resident Evil
              num_results: 10
            output:
              something: |
                Marcel Deneuve is a character in the Resident Evil film series, playing a minor role in Resident Evil: Apocalypse and a much larger role in Resident Evil: Extinction. Explore historical records and family tree profiles about Marcel Deneuve on MyHeritage, the world's largest family network.
            ```
            Some other text
        "#};
        let tool_invocations = super::find_yaml::<super::ToolInvocationInput>(data).unwrap();
        assert_eq!(tool_invocations.len(), 1);
        let invocation = &tool_invocations[0];
        assert_eq!(invocation.command, "Search");
        assert_eq!(invocation.input.get("q").unwrap(), "Marcel Deneuve");
        assert_eq!(
            invocation.input.get("excluded_terms").unwrap(),
            "Resident Evil"
        );
        assert_eq!(
            invocation.input.get("num_results").unwrap(),
            &serde_yaml::Value::Number(Number::from(10))
        );
        assert!(!invocation.junk.is_empty());
        assert!(invocation.junk.get("output").is_some());
    }

    /// Several blocks are all extracted. `find_yaml` returns them in reverse
    /// document order.
    #[test]
    fn test_extraction_of_three_yaml_with_output() {
        let data = indoc! {r#"# Some text
            ```yaml
            command: Search
            input:
              q: Marcel Deneuve
              excluded_terms: Resident Evil
              num_results: 10
            output:
              something: |
                Marcel Deneuve is a character in the Resident Evil film series, playing a minor role in Resident Evil: Apocalypse and a much larger role in Resident Evil: Extinction. Explore historical records and family tree profiles about Marcel Deneuve on MyHeritage, the world's largest family network.
            ```
            Some other text
            ```yaml
            command: Erf
            input:
              q: Marcel Prouse
              excluded_terms: La Recherche du Temps Perdu
              num_results: 10
            ```
            Some other other text
            ```yaml
            command: Plaff
            input:
              q: Marcel et son Orchestre
              excluded_terms: Les Vaches
              num_results: 10
            ```
            That's all folks!
        "#};
        let tool_invocations = super::find_yaml::<super::ToolInvocationInput>(data).unwrap();
        assert_eq!(tool_invocations.len(), 3);
        let invocation = &tool_invocations[0];
        assert_eq!(invocation.command, "Plaff");
        let invocation = &tool_invocations[1];
        assert_eq!(invocation.command, "Erf");
        let invocation = &tool_invocations[2];
        assert_eq!(invocation.command, "Search");
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FakeToolInput {
        q: String,
        excluded_terms: Option<String>,
        num_results: Option<u32>,
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct FakeToolOutput {
        items: Vec<String>,
    }

    /// `junk` entries are flattened back into the top level when serializing.
    #[test]
    fn test_serializing_tool_invocation() {
        let input = FakeToolInput {
            q: "Marcel Deneuve".to_string(),
            excluded_terms: Some("Resident Evil".to_string()),
            num_results: Some(10),
        };
        let output = FakeToolOutput {
            items: vec![
                "Marcel Deneuve is a character in the Resident Evil film series,".to_string(),
                "playing a minor role in Resident Evil: Apocalypse and a much larger".to_string(),
                " role in Resident Evil: Extinction. Explore historical records and ".to_string(),
                "family tree profiles about Marcel Deneuve on MyHeritage, the world's largest family network.".to_string(),
            ],
        };
        let junk = vec![("output".to_string(), serde_yaml::to_value(output).unwrap())];
        let invocation = super::ToolInvocationInput {
            command: "Search".to_string(),
            input: serde_yaml::to_value(input).unwrap(),
            junk: HashMap::from_iter(junk),
        };
        let serialized = serde_yaml::to_string(&invocation).unwrap();
        assert_display_snapshot!(serialized);
    }
}