// tycode_core/ai/provider.rs

use std::collections::HashSet;
use std::pin::Pin;

use tokio_stream::Stream;

use crate::ai::tweaks::ModelTweaks;
use crate::ai::{error::AiError, model::Model, types::*};
8
9#[async_trait::async_trait]
10pub trait AiProvider: Send + Sync {
11    fn name(&self) -> &'static str;
12
13    fn supported_models(&self) -> HashSet<Model>;
14
15    async fn converse(&self, request: ConversationRequest)
16        -> Result<ConversationResponse, AiError>;
17
18    fn get_cost(&self, model: &Model) -> Cost;
19
20    async fn converse_stream(
21        &self,
22        request: ConversationRequest,
23    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AiError>> + Send>>, AiError> {
24        let response = self.converse(request).await?;
25        Ok(Box::pin(tokio_stream::once(Ok(
26            StreamEvent::MessageComplete { response },
27        ))))
28    }
29
30    fn tweaks(&self) -> ModelTweaks {
31        ModelTweaks::default()
32    }
33}