pub mod context;
pub mod prompt;
pub mod tools;
pub mod models;
pub mod chains;
use std::fmt::Debug;
use std::str::FromStr;
use std::sync::{Arc, Weak};
use clap::builder::PossibleValue;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::chains::{Chain, Message, MultiStepOODAChain, SingleStepOODAChain};
use crate::context::{ChatEntry, ContextDump};
use crate::models::openai::OpenAI;
use crate::models::{ModelRef, ModelResponse, Role, Usage};
use crate::tools::invocation::InvocationError;
use crate::tools::toolbox::{InvokeResult, Toolbox};
use crate::tools::{TerminationMessage, ToolUseError};
/// Errors returned while setting up or running a task.
///
/// The `#[from]` variants give `From` conversions, so `?` turns errors from the `context`,
/// `models` and `chains` modules into this type automatically.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Wraps a [`context::Error`] raised while adding to the chat history.
    #[error("Failed to add to the chat history: {0}")]
    ChatHistoryError(#[from] context::Error),
    /// Wraps a [`models::Error`] raised during model evaluation.
    #[error("Model evaluation error: {0}")]
    ModelEvaluationError(#[from] models::Error),
    /// The maximal number of steps was reached.
    ///
    /// NOTE(review): not constructed in this file; presumably raised by the chains when
    /// `SapiensConfig::max_steps` is exceeded — confirm.
    #[error("Maximal number of steps reached")]
    MaxStepsReached,
    /// The response is too long; the payload is interpolated into the error message.
    ///
    /// NOTE(review): not constructed in this file — confirm what the string holds at the
    /// construction site.
    #[error("The response is too long: {0}")]
    ActionResponseTooLong(String),
    /// Wraps a [`chains::Error`] raised by the task chain.
    #[error("Chain error: {0}")]
    ChainError(#[from] chains::Error),
}
/// Selects which chain implementation drives a task (dispatched in `TaskState::with_observer`).
///
/// As strings (`FromStr` and the `clap` `ValueEnum` impl) the variants are spelled
/// `single-step-ooda` / `multi-step-ooda`.
/// NOTE(review): the serde derives carry no `rename` attributes, so the serialized form uses the
/// Rust variant names (`SingleStepOODA`, `MultiStepOODA`) instead — confirm this is intended.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChainType {
    /// The default; built as a [`SingleStepOODAChain`].
    #[default]
    SingleStepOODA,
    /// Built as a [`MultiStepOODAChain`].
    MultiStepOODA,
}
impl FromStr for ChainType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"single-step-ooda" => Ok(ChainType::SingleStepOODA),
"multi-step-ooda" => Ok(ChainType::MultiStepOODA),
_ => Err(format!("Unknown chain type: {}", s)),
}
}
}
// Exposes the same kebab-case names as the `FromStr` impl to clap's argument parser.
// NOTE(review): `PossibleValue` is imported unconditionally at the top of the file while this
// impl is feature-gated; if `clap` is an optional dependency, that import likely needs the same
// `cfg` — confirm.
#[cfg(feature = "clap")]
impl clap::ValueEnum for ChainType {
    fn value_variants<'a>() -> &'a [Self] {
        &[Self::SingleStepOODA, Self::MultiStepOODA]
    }

    fn to_possible_value(&self) -> Option<PossibleValue> {
        let name = match self {
            Self::SingleStepOODA => "single-step-ooda",
            Self::MultiStepOODA => "multi-step-ooda",
        };
        Some(PossibleValue::new(name))
    }
}
/// Configuration for running a task.
///
/// `Clone` is derived; `Debug` is implemented by hand and does not print `model`.
#[derive(Clone)]
pub struct SapiensConfig {
    /// The model used by the chain (default: an [`OpenAI`] model built via its `Default`).
    pub model: ModelRef,
    /// Step limit (default: 10).
    /// NOTE(review): not enforced in this file; presumably checked by the chains, which receive
    /// the whole config — confirm.
    pub max_steps: usize,
    /// Which chain implementation to build (default: [`ChainType::SingleStepOODA`]).
    pub chain_type: ChainType,
    /// Default: 256.
    /// NOTE(review): consumed outside this file; presumably the minimum number of tokens to
    /// leave for the model's completion — confirm.
    pub min_tokens_for_completion: usize,
    /// Optional token cap (default: `None`).
    /// NOTE(review): consumed outside this file — confirm exact semantics.
    pub max_tokens: Option<usize>,
}
/// Hand-written `Debug` for [`SapiensConfig`].
///
/// The `model` field is skipped (presumably because `ModelRef` is not `Debug` — TODO confirm).
/// The struct is printed under its real type name, and `finish_non_exhaustive` renders a
/// trailing `..` so readers can see that a field was omitted.
impl Debug for SapiensConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SapiensConfig")
            .field("max_steps", &self.max_steps)
            .field("chain_type", &self.chain_type)
            .field("min_tokens_for_completion", &self.min_tokens_for_completion)
            .field("max_tokens", &self.max_tokens)
            .finish_non_exhaustive()
    }
}
impl Default for SapiensConfig {
fn default() -> Self {
Self {
model: Arc::new(Box::<OpenAI>::default()),
max_steps: 10,
chain_type: ChainType::SingleStepOODA,
min_tokens_for_completion: 256,
max_tokens: None,
}
}
}
/// Observer notification carrying a model response.
#[derive(Debug, Clone)]
pub struct ModelNotification {
    /// The response message as a chat entry (assistant role when built via
    /// `From<ModelResponse>`).
    pub chat_entry: ChatEntry,
    /// Token usage reported with the response, if any.
    pub usage: Option<Usage>,
}
impl From<ModelResponse> for ModelNotification {
fn from(res: ModelResponse) -> Self {
Self {
chat_entry: ChatEntry {
role: Role::Assistant,
msg: res.msg,
},
usage: res.usage,
}
}
}
/// Observer notification wrapping a chain [`Message`].
#[derive(Debug, Clone)]
pub struct MessageNotification {
    /// The message being reported.
    pub message: Message,
}
impl From<Message> for MessageNotification {
fn from(message: Message) -> Self {
Self { message }
}
}
/// Observer notification describing the outcome of a tool invocation attempt.
///
/// Built from the toolbox's [`InvokeResult`] via `From`.
pub enum InvocationResultNotification {
    /// A tool was invoked and returned a result.
    InvocationSuccess(InvocationSuccessNotification),
    /// A tool was invoked and returned an error.
    InvocationFailure(InvocationFailureNotification),
    /// No (valid) invocation could be extracted.
    InvalidInvocation(InvalidInvocationNotification),
}
/// Maps the toolbox's [`InvokeResult`] onto the observer-facing notification.
///
/// Both "nothing usable found" outcomes become `InvalidInvocation`. When no invocation was
/// found at all, the count is reported as `0`.
impl From<InvokeResult> for InvocationResultNotification {
    fn from(outcome: InvokeResult) -> Self {
        match outcome {
            InvokeResult::Success {
                invocation_count,
                tool_name,
                extracted_input,
                result,
            } => Self::InvocationSuccess(InvocationSuccessNotification {
                invocation_count,
                tool_name,
                extracted_input,
                result,
            }),
            InvokeResult::Error {
                invocation_count,
                tool_name,
                extracted_input,
                e,
            } => Self::InvocationFailure(InvocationFailureNotification {
                invocation_count,
                tool_name,
                extracted_input,
                e,
            }),
            InvokeResult::NoValidInvocationsFound {
                e,
                invocation_count,
            } => Self::InvalidInvocation(InvalidInvocationNotification {
                e,
                invocation_count,
            }),
            InvokeResult::NoInvocationsFound { e } => {
                let invocation_count = 0;
                Self::InvalidInvocation(InvalidInvocationNotification {
                    e,
                    invocation_count,
                })
            }
        }
    }
}
/// Details of a successful tool invocation.
pub struct InvocationSuccessNotification {
    /// Number of invocations reported by the toolbox.
    pub invocation_count: usize,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// The input extracted for the tool.
    pub extracted_input: String,
    /// The tool's result.
    pub result: String,
}
/// Details of a tool invocation that returned an error.
pub struct InvocationFailureNotification {
    /// Number of invocations reported by the toolbox.
    pub invocation_count: usize,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// The input extracted for the tool.
    pub extracted_input: String,
    /// The error returned by the tool.
    pub e: ToolUseError,
}
/// Details of an attempt where no valid tool invocation could be extracted.
pub struct InvalidInvocationNotification {
    /// Why the invocation was rejected or not found.
    pub e: InvocationError,
    /// Number of invocations seen (`0` when none were found at all).
    pub invocation_count: usize,
}
/// Observer notification sent when a step ends the task.
pub struct TerminationNotification {
    /// The (non-empty) termination messages returned by the chain's final step.
    pub messages: Vec<TerminationMessage>,
}
/// Receives progress notifications from a running task.
///
/// Every callback defaults to a no-op, so implementors override only what they need. In this
/// file, the runtime upgrades a [`WeakRuntimeObserver`] and holds the observer's lock while a
/// callback runs. Notifications are skipped if the observer has been dropped.
#[async_trait::async_trait]
pub trait RuntimeObserver: Send {
    /// Called with the task text before the chain is built.
    async fn on_task(&mut self, _task: &str) {}
    /// Called with the chain's context dump once the chain has been built.
    async fn on_start(&mut self, _context: ContextDump) {}
    /// Model output notification.
    /// NOTE(review): not emitted in this file; presumably sent by the chains, which receive the
    /// observer — confirm.
    async fn on_model_update(&mut self, _event: ModelNotification) {}
    /// Chain message notification.
    /// NOTE(review): emitted outside this file (presumably by the chains).
    async fn on_message(&mut self, _event: MessageNotification) {}
    /// Tool invocation outcome.
    /// NOTE(review): emitted outside this file (presumably by the chains).
    async fn on_invocation_result(&mut self, _event: InvocationResultNotification) {}
    /// Called with the termination messages when a step ends the task.
    async fn on_termination(&mut self, _event: TerminationNotification) {}
}
/// Wraps `observer` in an `Arc<Mutex<_>>` so it can be shared with a running task.
///
/// The runtime only stores weak handles ([`WeakRuntimeObserver`], obtained with
/// `Arc::downgrade`). The caller must therefore keep the returned strong handle alive for as
/// long as it wants to receive notifications.
pub fn wrap_observer<O: RuntimeObserver + 'static>(observer: O) -> StrongRuntimeObserver<O> {
    let guarded = Mutex::new(observer);
    Arc::new(guarded)
}
/// Owning handle to an observer, as returned by `wrap_observer` (uses `tokio::sync::Mutex`).
pub type StrongRuntimeObserver<O> = Arc<Mutex<O>>;
/// Type-erased weak handle the runtime keeps. It must be upgraded before each notification,
/// and the notification is skipped when the upgrade fails.
pub type WeakRuntimeObserver = Weak<Mutex<dyn RuntimeObserver>>;
/// Observer that ignores every notification (all callbacks use the trait's no-op defaults).
pub struct VoidTaskProgressUpdateObserver;
/// Test helper: a strongly-held observer whose callbacks all do nothing.
#[cfg(test)]
pub(crate) fn void_observer() -> StrongRuntimeObserver<VoidTaskProgressUpdateObserver> {
    let silent = VoidTaskProgressUpdateObserver;
    wrap_observer(silent)
}
// Relies entirely on the trait's default (no-op) callbacks.
#[async_trait::async_trait]
impl RuntimeObserver for VoidTaskProgressUpdateObserver {}
/// A task that has not finished yet; advance it via [`TaskState::step`] or [`TaskState::run`].
pub struct Step {
    /// The chain that performs each step.
    task_chain: Box<dyn Chain>,
    /// Weak observer handle; notifications are skipped once it can no longer be upgraded.
    observer: WeakRuntimeObserver,
}
impl Step {
    /// Runs one step of the chain.
    ///
    /// If the chain reports no termination messages, this step is returned for the next
    /// iteration. Otherwise `on_termination` is sent (when the observer is still alive) and the
    /// task moves to [`TaskState::Stop`] with those messages.
    async fn step(mut self) -> Result<TaskState, Error> {
        let termination_messages = self.task_chain.step().await?;

        // Guard clause: nothing terminated, keep going.
        if termination_messages.is_empty() {
            return Ok(TaskState::Step { step: self });
        }

        if let Some(listener) = self.observer.upgrade() {
            // The observer needs its own copy; the original vector goes into `Stop`.
            let notification = TerminationNotification {
                messages: termination_messages.clone(),
            };
            listener.lock().await.on_termination(notification).await;
        }

        let stop = Stop {
            termination_messages,
        };
        Ok(TaskState::Stop { stop })
    }
}
/// A finished task.
pub struct Stop {
    /// Termination messages returned by the chain's final step (non-empty when produced by
    /// stepping a task).
    pub termination_messages: Vec<TerminationMessage>,
}
/// State machine for a task: either more steps remain or the task has finished.
pub enum TaskState {
    /// The task can be advanced further.
    Step {
        step: Step,
    },
    /// The task has terminated.
    Stop {
        stop: Stop,
    },
}
impl TaskState {
pub async fn new(config: SapiensConfig, toolbox: Toolbox, task: String) -> Result<Self, Error> {
let observer = wrap_observer(VoidTaskProgressUpdateObserver {});
let observer = Arc::downgrade(&observer);
TaskState::with_observer(config, toolbox, task, observer).await
}
pub async fn with_observer(
config: SapiensConfig,
toolbox: Toolbox,
task: String,
observer: WeakRuntimeObserver,
) -> Result<Self, Error> {
if let Some(observer) = observer.upgrade() {
observer.lock().await.on_task(&task).await;
}
let task_chain = match config.chain_type {
ChainType::SingleStepOODA => {
let chain = SingleStepOODAChain::new(config, toolbox, observer.clone())
.await?
.with_task(task);
Box::new(chain) as Box<dyn Chain>
}
ChainType::MultiStepOODA => {
let chain = MultiStepOODAChain::new(config, toolbox, observer.clone())
.await?
.with_task(task);
Box::new(chain) as Box<dyn Chain>
}
};
if let Some(observer) = observer.upgrade() {
observer.lock().await.on_start(task_chain.dump()).await;
}
Ok(TaskState::Step {
step: Step {
task_chain,
observer,
},
})
}
pub async fn run(mut self) -> Result<Stop, Error> {
loop {
match self {
TaskState::Step { step } => {
self = step.step().await?;
}
TaskState::Stop { stop } => {
return Ok(stop);
}
}
}
}
pub async fn step(self) -> Result<Self, Error> {
match self {
TaskState::Step { step } => step.step().await,
TaskState::Stop { stop } => Ok(TaskState::Stop { stop }),
}
}
pub fn is_done(&self) -> Option<Vec<TerminationMessage>> {
match self {
TaskState::Step { step: _ } => None,
TaskState::Stop { stop } => Some(stop.termination_messages.clone()),
}
}
}
/// Sets up `task` with the given observer and drives it until it stops, returning the
/// termination messages.
///
/// # Errors
/// Propagates any [`Error`] from setup ([`TaskState::with_observer`]) or from a step.
#[tracing::instrument(skip(toolbox, observer, config))]
pub async fn run_to_the_end(
    config: SapiensConfig,
    toolbox: Toolbox,
    task: String,
    observer: WeakRuntimeObserver,
) -> Result<Vec<TerminationMessage>, Error> {
    TaskState::with_observer(config, toolbox, task, observer)
        .await?
        .run()
        .await
        .map(|stop| stop.termination_messages)
}