Skip to main content

systemprompt_generator/rss/
default_provider.rs

1use anyhow::{Context, Result, anyhow};
2use async_trait::async_trait;
3use systemprompt_content::ContentRepository;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::SourceId;
6use systemprompt_models::{AppPaths, Config, ContentConfigRaw, WebConfig};
7use systemprompt_provider_contracts::{
8    RssFeedContext, RssFeedItem, RssFeedMetadata, RssFeedProvider, RssFeedSpec,
9};
10use tokio::fs;
11
12use crate::templates::load_web_config;
13
14const DEFAULT_MAX_ITEMS: i64 = 20;
15
16pub struct DefaultRssFeedProvider {
17    db_pool: DbPool,
18    content_config: ContentConfigRaw,
19    web_config: WebConfig,
20}
21
22impl std::fmt::Debug for DefaultRssFeedProvider {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("DefaultRssFeedProvider")
25            .field(
26                "content_sources",
27                &self.content_config.content_sources.keys(),
28            )
29            .finish_non_exhaustive()
30    }
31}
32
33impl DefaultRssFeedProvider {
34    pub async fn new(db_pool: DbPool) -> Result<Self> {
35        let content_config = load_content_config().await?;
36        let web_config = load_web_config()
37            .await
38            .map_err(|e| anyhow!("Failed to load web config: {}", e))?;
39
40        Ok(Self {
41            db_pool,
42            content_config,
43            web_config,
44        })
45    }
46
47    fn get_source_branding(&self, source_name: &str) -> (String, String) {
48        let default_title = &self.web_config.branding.title;
49        let default_description = &self.web_config.branding.description;
50
51        self.content_config
52            .content_sources
53            .get(source_name)
54            .and_then(|source| source.branding.as_ref())
55            .map_or_else(
56                || (default_title.clone(), default_description.clone()),
57                |branding| {
58                    (
59                        branding
60                            .name
61                            .clone()
62                            .unwrap_or_else(|| default_title.clone()),
63                        branding
64                            .description
65                            .clone()
66                            .unwrap_or_else(|| default_description.clone()),
67                    )
68                },
69            )
70    }
71}
72
73async fn load_content_config() -> Result<ContentConfigRaw> {
74    let paths = AppPaths::get().map_err(|e| anyhow!("{}", e))?;
75    let config_path = paths.system().content_config();
76
77    let yaml_content = fs::read_to_string(&config_path)
78        .await
79        .context("Failed to read content config")?;
80
81    serde_yaml::from_str(&yaml_content).context("Failed to parse content config")
82}
83
84#[async_trait]
85impl RssFeedProvider for DefaultRssFeedProvider {
86    fn provider_id(&self) -> &'static str {
87        "default-rss"
88    }
89
90    fn feed_specs(&self) -> Vec<RssFeedSpec> {
91        self.content_config
92            .content_sources
93            .iter()
94            .filter(|(_, source)| source.enabled)
95            .filter(|(_, source)| source.sitemap.as_ref().is_some_and(|s| s.enabled))
96            .map(|(name, source)| RssFeedSpec {
97                source_id: source.source_id.clone(),
98                max_items: DEFAULT_MAX_ITEMS,
99                output_filename: format!("{}.xml", name),
100            })
101            .collect()
102    }
103
104    async fn feed_metadata(&self, ctx: &RssFeedContext<'_>) -> Result<RssFeedMetadata> {
105        let (title, description) = self.get_source_branding(ctx.source_name);
106        let global_config = Config::get()?;
107
108        let language = if self.content_config.metadata.language.is_empty() {
109            "en".to_string()
110        } else {
111            self.content_config.metadata.language.clone()
112        };
113
114        Ok(RssFeedMetadata {
115            title,
116            link: global_config.api_external_url.clone(),
117            description,
118            language: Some(language),
119        })
120    }
121
122    async fn fetch_items(&self, ctx: &RssFeedContext<'_>, limit: i64) -> Result<Vec<RssFeedItem>> {
123        let source_config = self
124            .content_config
125            .content_sources
126            .values()
127            .find(|s| s.source_id.as_str() == ctx.source_name)
128            .ok_or_else(|| anyhow!("Source not found: {}", ctx.source_name))?;
129
130        let url_pattern = source_config
131            .sitemap
132            .as_ref()
133            .map_or("/{slug}", |s| s.url_pattern.as_str());
134
135        let repo = ContentRepository::new(&self.db_pool)
136            .map_err(|e| anyhow!("Failed to create content repository: {}", e))?;
137
138        let source_id = SourceId::new(ctx.source_name);
139        let content_items = repo
140            .list_by_source_limited(&source_id, limit)
141            .await
142            .context("Failed to fetch content for RSS feed")?;
143
144        let items = content_items
145            .into_iter()
146            .map(|content| {
147                let relative_url = url_pattern.replace("{slug}", &content.slug);
148                let link = format!("{}{}", ctx.base_url, relative_url);
149                RssFeedItem {
150                    title: content.title,
151                    link: link.clone(),
152                    description: content.description,
153                    pub_date: content.published_at,
154                    guid: link,
155                    author: Some(content.author),
156                }
157            })
158            .collect();
159
160        Ok(items)
161    }
162}