1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
//! Generates a summary of a node and adds it to the node's metadata.
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use derive_builder::Builder;
use swiftide_core::{indexing::Node, prompt::PromptTemplate, SimplePrompt, Transformer};

/// Metadata key under which [`MetadataSummary`] stores the generated summary.
pub const NAME: &str = "Summary";

/// Generates a summary of the text chunk in a [`Node`] and stores it in the node's
/// metadata under the [`NAME`] key.
///
/// A templated prompt, rendered with the node as context, is sent to a client that
/// implements the [`SimplePrompt`] trait (e.g. an `OpenAI` client). The client's response
/// is stored verbatim as the summary.
///
/// Construct it with [`MetadataSummary::builder`], [`MetadataSummary::from_client`] or
/// [`MetadataSummary::new`].
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct MetadataSummary {
    /// Client that answers the summary prompt. Set through the custom
    /// [`MetadataSummaryBuilder::client`] setter, which wraps it in an [`Arc`].
    #[builder(setter(custom))]
    client: Arc<dyn SimplePrompt>,
    /// The prompt template used. Can be overwritten via the builder. Has the `node` available
    /// as context.
    #[builder(default = "default_prompt()")]
    prompt_template: PromptTemplate,
    /// Optional concurrency limit, returned as-is from [`Transformer::concurrency`].
    /// Defaults to `None`.
    #[builder(default)]
    concurrency: Option<usize>,
}

impl MetadataSummary {
    /// Returns a fresh [`MetadataSummaryBuilder`] with nothing set.
    ///
    /// The client field has no builder default, so set it with
    /// [`MetadataSummaryBuilder::client`] before calling `build()`.
    pub fn builder() -> MetadataSummaryBuilder {
        MetadataSummaryBuilder::default()
    }

    /// Returns a [`MetadataSummaryBuilder`] with `client` already set. The prompt
    /// template and concurrency are left for the caller to configure or default.
    pub fn from_client(client: impl SimplePrompt + 'static) -> MetadataSummaryBuilder {
        let mut builder = MetadataSummaryBuilder::default();
        builder.client(client);
        builder
    }

    /// Creates a `MetadataSummary` that uses `client`, the default prompt template and an
    /// unset (`None`) concurrency.
    ///
    /// # Arguments
    ///
    /// * `client` - An implementation of the `SimplePrompt` trait.
    pub fn new(client: impl SimplePrompt + 'static) -> Self {
        let client: Arc<dyn SimplePrompt> = Arc::new(client);
        Self {
            client,
            prompt_template: default_prompt(),
            concurrency: None,
        }
    }

    /// Returns this transformer with its concurrency set to `Some(concurrency)`.
    #[must_use]
    pub fn with_concurrency(self, concurrency: usize) -> Self {
        Self {
            concurrency: Some(concurrency),
            ..self
        }
    }
}

/// Builds the default [`PromptTemplate`] used by [`MetadataSummary`].
///
/// The template text is read from `prompts/metadata_summary.prompt.md`, relative to this
/// source file. It is embedded at compile time via `include_str!`.
fn default_prompt() -> PromptTemplate {
    let template_source: &'static str = include_str!("prompts/metadata_summary.prompt.md");
    template_source.into()
}

impl MetadataSummaryBuilder {
    /// Sets the client that will answer the summary prompt.
    ///
    /// The client is wrapped in an [`Arc`] because `MetadataSummary` stores it as
    /// `Arc<dyn SimplePrompt>`.
    pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
        let shared: Arc<dyn SimplePrompt> = Arc::new(client);
        self.client = Some(shared);
        self
    }
}

#[async_trait]
impl Transformer for MetadataSummary {
    /// Asks the client for a summary of `node` and stores the response in the node's
    /// metadata under [`NAME`].
    ///
    /// The prompt is built from the configured prompt template with `node` as context.
    ///
    /// # Errors
    ///
    /// Propagates the client's error if it fails to produce a response for the prompt.
    #[tracing::instrument(skip_all, name = "transformers.metadata_summary")]
    async fn transform_node(&self, mut node: Node) -> Result<Node> {
        let request = self.prompt_template.to_prompt().with_node(&node);
        let summary = self.client.prompt(request).await?;

        node.metadata.insert(NAME, summary);
        Ok(node)
    }

    /// Returns the concurrency configured on this transformer, if any.
    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }
}

#[cfg(test)]
mod test {
    use swiftide_core::MockSimplePrompt;

    use super::*;

    /// Snapshots the default prompt rendered against a minimal node.
    #[tokio::test]
    async fn test_template() {
        let prompt = default_prompt().to_prompt().with_node(&Node::new("test"));
        insta::assert_snapshot!(prompt.render().await.unwrap());
    }

    /// Checks that the client's response ends up in the node's metadata.
    #[tokio::test]
    async fn test_metadata_summary() {
        let mut client = MockSimplePrompt::new();
        client
            .expect_prompt()
            .returning(|_| Ok("A Summary".to_string()));

        let transformer = MetadataSummary::builder().client(client).build().unwrap();

        let result = transformer
            .transform_node(Node::new("Some text"))
            .await
            .unwrap();

        assert_eq!(result.metadata.get("Summary").unwrap(), "A Summary");
    }
}