// tokio_openai/ext.rs

1use std::{pin::pin, vec::IntoIter};
2
3use anyhow::Error;
4use futures_util::{
5    stream,
6    stream::{FlatMap, Iter},
7    Stream, StreamExt,
8};
9use tokio_stream::wrappers::ReceiverStream;
10
/// Breaks a value down into a `Vec` of its component parts of type `Part`.
///
/// NOTE(review): this private trait is not referenced anywhere else in this
/// file, so rustc will likely flag it as dead code. Consider removing it (or
/// using it from `OpenAiStreamExt::chars`) once you confirm no other module
/// relies on it.
trait Decompose<Part> {
    /// Consumes `self` and returns its parts.
    fn decompose(self) -> Vec<Part>;
}

/// A `String` decomposes into its `char`s (Unicode scalar values), in order.
impl Decompose<char> for String {
    fn decompose(self) -> Vec<char> {
        self.chars().collect::<Vec<_>>()
    }
}
21pub type CharStream<T> = FlatMap<
22    T,
23    Iter<IntoIter<Result<char, Error>>>,
24    fn(anyhow::Result<String>) -> Iter<IntoIter<Result<char, Error>>>,
25>;
26pub type LineStream = ReceiverStream<anyhow::Result<String>>;
27
28pub trait OpenAiStreamExt: Stream<Item = anyhow::Result<String>> + Sized {
29    fn chars(self) -> CharStream<Self> {
30        self.flat_map(|elem| {
31            let result = match elem {
32                Err(res) => vec![Err(res)],
33                Ok(res) => res.chars().map(Ok).collect(),
34            };
35            stream::iter(result)
36        })
37    }
38
39    /// Outputs a stream of Strings where each item is a line
40    fn lines(self) -> LineStream
41    where
42        Self: Send + 'static,
43    {
44        let (tx, rx) = tokio::sync::mpsc::channel(1);
45
46        tokio::spawn(async move {
47            let mut chars = pin!(self.chars());
48
49            let mut s = String::new();
50
51            while let Some(char) = chars.next().await {
52                let char = match char {
53                    Ok(char) => char,
54                    Err(e) => {
55                        let _ = tx.send(Err(e)).await;
56                        return;
57                    }
58                };
59
60                if char == '\n' {
61                    let s = core::mem::take(&mut s);
62                    if tx.send(Ok(s)).await.is_err() {
63                        return;
64                    }
65                } else {
66                    s.push(char);
67                }
68            }
69
70            if !s.is_empty() {
71                let _ = tx.send(Ok(s)).await;
72            }
73        });
74
75        ReceiverStream::new(rx)
76    }
77}
78
79impl<T> OpenAiStreamExt for T
80where
81    T: Stream<Item = anyhow::Result<String>>,
82    T: Sized,
83{
84}