1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use std::{pin::pin, vec::IntoIter};

use anyhow::Error;
use futures_util::{
    stream,
    stream::{FlatMap, Iter},
    Stream, StreamExt,
};
use tokio_stream::wrappers::ReceiverStream;

/// Breaks a value down into a `Vec` of its constituent parts of type `Part`.
///
/// NOTE(review): this trait is not referenced by the other code in this module; confirm it is
/// still needed.
trait Decompose<Part> {
    /// Consumes `self` and returns its parts.
    fn decompose(self) -> Vec<Part>;
}

/// Splits a `String` into its `char`s (Unicode scalar values), preserving order.
impl Decompose<char> for String {
    fn decompose(self) -> Vec<char> {
        let mut scalars = Vec::new();
        scalars.extend(self.chars());
        scalars
    }
}

/// Stream type returned by [`OpenAiStreamExt::chars`].
///
/// Wraps the inner chunk stream `T`. Each `Ok(String)` chunk is flattened into one `Ok(char)`
/// item per character, and each `Err` is passed through as a single item. The mapping function
/// is a plain `fn` pointer (not a closure type) so that this type can be named.
pub type CharStream<T> = FlatMap<
    T,
    Iter<IntoIter<anyhow::Result<char>>>,
    fn(anyhow::Result<String>) -> Iter<IntoIter<anyhow::Result<char>>>,
>;
pub type LineStream = ReceiverStream<anyhow::Result<String>>;

/// Extension adaptors for streams whose items are `anyhow::Result<String>` text chunks.
pub trait OpenAiStreamExt: Stream<Item = anyhow::Result<String>> + Sized {
    /// Flattens the stream of string chunks into a stream of individual `char`s.
    ///
    /// Each `Ok` chunk is expanded into one `Ok(char)` per character. Each `Err` is forwarded
    /// unchanged as a single item, and the stream keeps going after it.
    fn chars(self) -> CharStream<Self> {
        // The closure captures nothing, so it coerces to the `fn` pointer named in `CharStream`.
        self.flat_map(|chunk| {
            let items: Vec<_> = match chunk {
                Ok(text) => text.chars().map(Ok).collect(),
                Err(e) => vec![Err(e)],
            };
            stream::iter(items)
        })
    }

    /// Outputs a stream of Strings where each item is a line.
    ///
    /// Lines are split on `'\n'` only. The separator is not included, and a trailing `'\r'` is
    /// *not* stripped. A final line with no terminating `'\n'` is emitted only if it is
    /// non-empty. If the upstream yields an `Err`, that error is forwarded, any partially
    /// accumulated line is discarded, and the output stream ends.
    ///
    /// The splitting runs on a spawned task. That task stops pulling from the upstream once the
    /// returned stream has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    fn lines(self) -> LineStream
    where
        Self: Send + 'static,
    {
        let (tx, rx) = tokio::sync::mpsc::channel(1);

        tokio::spawn(async move {
            let mut chunks = pin!(self);
            // Text seen since the last '\n'; a line may span several upstream chunks.
            let mut pending = String::new();

            // Stop driving the upstream as soon as nobody is listening any more.
            while !tx.is_closed() {
                let Some(chunk) = chunks.next().await else {
                    break;
                };

                let chunk = match chunk {
                    Ok(chunk) => chunk,
                    Err(e) => {
                        // Best effort: the receiver may already be gone.
                        let _ = tx.send(Err(e)).await;
                        return;
                    }
                };

                // Emit every complete line in this chunk in one pass instead of char by char;
                // whatever follows the last '\n' stays pending for the next chunk.
                let mut rest = chunk.as_str();
                while let Some(idx) = rest.find('\n') {
                    pending.push_str(&rest[..idx]);
                    if tx.send(Ok(std::mem::take(&mut pending))).await.is_err() {
                        return;
                    }
                    // '\n' is one byte, so idx + 1 is always a char boundary.
                    rest = &rest[idx + 1..];
                }
                pending.push_str(rest);
            }

            // Flush an unterminated final line.
            if !pending.is_empty() {
                let _ = tx.send(Ok(pending)).await;
            }
        });

        ReceiverStream::new(rx)
    }
}

/// Every stream of `anyhow::Result<String>` items gets the [`OpenAiStreamExt`] adaptors.
/// (Generic parameters are implicitly `Sized`, which satisfies the trait's `Sized` supertrait.)
impl<S: Stream<Item = anyhow::Result<String>>> OpenAiStreamExt for S {}