tycode_core/ai/
provider.rs1use std::collections::HashSet;
2use std::pin::Pin;
3
4use tokio_stream::Stream;
5
6use crate::ai::tweaks::ModelTweaks;
7use crate::ai::{error::AiError, model::Model, types::*};
8
9#[async_trait::async_trait]
10pub trait AiProvider: Send + Sync {
11 fn name(&self) -> &'static str;
12
13 fn supported_models(&self) -> HashSet<Model>;
14
15 async fn converse(&self, request: ConversationRequest)
16 -> Result<ConversationResponse, AiError>;
17
18 fn get_cost(&self, model: &Model) -> Cost;
19
20 async fn converse_stream(
21 &self,
22 request: ConversationRequest,
23 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AiError>> + Send>>, AiError> {
24 let response = self.converse(request).await?;
25 Ok(Box::pin(tokio_stream::once(Ok(
26 StreamEvent::MessageComplete { response },
27 ))))
28 }
29
30 fn tweaks(&self) -> ModelTweaks {
31 ModelTweaks::default()
32 }
33}