use async_trait::async_trait;
use derive_builder::Builder;
use swiftide_core::{indexing::IndexingStream, indexing::Node, ChunkerTransformer};
use text_splitter::{Characters, MarkdownSplitter};
#[derive(Debug, Builder)]
#[builder(pattern = "owned", setter(strip_option))]
pub struct ChunkMarkdown {
chunker: MarkdownSplitter<Characters>,
#[builder(default)]
concurrency: Option<usize>,
#[builder(default)]
range: Option<std::ops::Range<usize>>,
}
impl ChunkMarkdown {
pub fn from_max_characters(max_characters: usize) -> Self {
Self {
chunker: MarkdownSplitter::new(max_characters),
concurrency: None,
range: None,
}
}
pub fn from_chunk_range(range: std::ops::Range<usize>) -> Self {
Self {
chunker: MarkdownSplitter::new(range.clone()),
concurrency: None,
range: Some(range),
}
}
pub fn builder() -> ChunkMarkdownBuilder {
ChunkMarkdownBuilder::default()
}
#[must_use]
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
self.concurrency = Some(concurrency);
self
}
fn min_size(&self) -> usize {
self.range.as_ref().map_or(0, |r| r.start)
}
}
#[async_trait]
impl ChunkerTransformer for ChunkMarkdown {
#[tracing::instrument(skip_all, name = "transformers.chunk_markdown")]
async fn transform_node(&self, node: Node) -> IndexingStream {
let chunks = self
.chunker
.chunks(&node.chunk)
.filter_map(|chunk| {
let trim = chunk.trim();
if trim.is_empty() || trim.len() < self.min_size() {
None
} else {
Some(chunk.to_string())
}
})
.collect::<Vec<String>>();
IndexingStream::iter(chunks.into_iter().map(move |chunk| {
Ok(Node {
chunk,
..node.clone()
})
}))
}
fn concurrency(&self) -> Option<usize> {
self.concurrency
}
}
#[cfg(test)]
mod test {
use super::*;
use futures_util::stream::TryStreamExt;
const MARKDOWN: &str = r"
# Hello, world!
This is a test markdown document.
## Section 1
This is a paragraph.
## Section 2
This is another paragraph.
";
#[tokio::test]
async fn test_transforming_with_max_characters_and_trimming() {
let chunker = ChunkMarkdown::from_max_characters(40);
let node = Node {
chunk: MARKDOWN.to_string(),
..Node::default()
};
let nodes: Vec<Node> = chunker
.transform_node(node)
.await
.try_collect()
.await
.unwrap();
for line in MARKDOWN.lines().filter(|line| !line.trim().is_empty()) {
assert!(nodes.iter().any(|node| node.chunk == line.trim()));
}
assert_eq!(nodes.len(), 6);
}
#[tokio::test]
async fn test_always_within_range() {
let ranges = vec![(10..15), (20..25), (30..35), (40..45), (50..55)];
for range in ranges {
let chunker = ChunkMarkdown::from_chunk_range(range.clone());
let node = Node {
chunk: MARKDOWN.to_string(),
..Node::default()
};
let nodes: Vec<Node> = chunker
.transform_node(node)
.await
.try_collect()
.await
.unwrap();
assert!(
nodes.iter().all(|node| {
let len = node.chunk.len();
range.contains(&len)
}),
"{:?}, {:?}",
range,
nodes.iter().filter(|node| {
let len = node.chunk.len();
!range.contains(&len)
})
);
}
}
#[test]
fn test_builder() {
ChunkMarkdown::builder()
.chunker(MarkdownSplitter::new(40))
.concurrency(10)
.range(10..20)
.build()
.unwrap();
}
}