Skip to main content

rust_genai_types/
grounding.rs

1use crate::base64_serde;
2use serde::{Deserialize, Serialize};
3
4/// 引用日期。
5#[derive(Debug, Clone, Serialize, Deserialize)]
6#[serde(rename_all = "camelCase")]
7pub struct GoogleTypeDate {
8    #[serde(skip_serializing_if = "Option::is_none")]
9    pub day: Option<i32>,
10    #[serde(skip_serializing_if = "Option::is_none")]
11    pub month: Option<i32>,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub year: Option<i32>,
14}
15
16/// 引用信息。
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "camelCase")]
19pub struct Citation {
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub end_index: Option<i32>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub license: Option<String>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub publication_date: Option<GoogleTypeDate>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub start_index: Option<i32>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub title: Option<String>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub uri: Option<String>,
32}
33
34/// 引用元数据。
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(rename_all = "camelCase")]
37pub struct CitationMetadata {
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub citations: Option<Vec<Citation>>,
40}
41
42/// Author attribution for Maps.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct PlaceAnswerSourcesAuthorAttribution {
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub display_name: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub photo_uri: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub uri: Option<String>,
52}
53
54/// Review snippet.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct PlaceAnswerSourcesReviewSnippet {
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub author_attribution: Option<PlaceAnswerSourcesAuthorAttribution>,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub flag_content_uri: Option<String>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub google_maps_uri: Option<String>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub relative_publish_time_description: Option<String>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub review: Option<String>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub review_id: Option<String>,
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub title: Option<String>,
72}
73
74/// Sources used to generate the place answer.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76#[serde(rename_all = "camelCase")]
77pub struct PlaceAnswerSources {
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub flag_content_uri: Option<String>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub review_snippets: Option<Vec<PlaceAnswerSourcesReviewSnippet>>,
82}
83
84/// Grounding chunk from Google Maps.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(rename_all = "camelCase")]
87pub struct MapsChunk {
88    pub uri: String,
89    pub title: String,
90    pub text: String,
91    pub place_id: String,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub place_answer_sources: Option<PlaceAnswerSources>,
94}
95
96/// Rag chunk page span.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(rename_all = "camelCase")]
99pub struct RagChunkPageSpan {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub first_page: Option<i32>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub last_page: Option<i32>,
104}
105
106/// RAG chunk.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108#[serde(rename_all = "camelCase")]
109pub struct RagChunk {
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub page_span: Option<RagChunkPageSpan>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub text: Option<String>,
114}
115
116/// Grounding chunk from retrieved context.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct RetrievedContextChunk {
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub document_name: Option<String>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub rag_chunk: Option<RagChunk>,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub text: Option<String>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub title: Option<String>,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub uri: Option<String>,
130}
131
132/// Grounding chunk from web.
133#[derive(Debug, Clone, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct WebChunk {
136    pub uri: String,
137    pub title: String,
138}
139
140/// Grounding chunk union.
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase", untagged)]
143pub enum GroundingChunk {
144    Web {
145        web: WebChunk,
146    },
147    RetrievedContext {
148        retrieved_context: RetrievedContextChunk,
149    },
150    Maps {
151        maps: MapsChunk,
152    },
153}
154
155impl GroundingChunk {
156    /// 获取来源 URI(如果存在)。
157    #[must_use]
158    pub fn uri(&self) -> Option<&str> {
159        match self {
160            Self::Web { web } => Some(web.uri.as_str()),
161            Self::Maps { maps } => Some(maps.uri.as_str()),
162            Self::RetrievedContext { retrieved_context } => retrieved_context.uri.as_deref(),
163        }
164    }
165
166    /// 获取标题(如果存在)。
167    #[must_use]
168    pub fn title(&self) -> Option<&str> {
169        match self {
170            Self::Web { web } => Some(web.title.as_str()),
171            Self::Maps { maps } => Some(maps.title.as_str()),
172            Self::RetrievedContext { retrieved_context } => retrieved_context.title.as_deref(),
173        }
174    }
175}
176
177/// Segment of the content.
178#[derive(Debug, Clone, Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct Segment {
181    pub part_index: i32,
182    pub start_index: i32,
183    pub end_index: i32,
184    pub text: String,
185}
186
187/// Grounding support.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct GroundingSupport {
191    #[serde(default)]
192    pub grounding_chunk_indices: Vec<i32>,
193    #[serde(default)]
194    pub confidence_scores: Vec<f64>,
195    pub segment: Segment,
196}
197
198/// Retrieval metadata.
199#[derive(Debug, Clone, Serialize, Deserialize)]
200#[serde(rename_all = "camelCase")]
201pub struct RetrievalMetadata {
202    #[serde(skip_serializing_if = "Option::is_none")]
203    pub google_search_dynamic_retrieval_score: Option<f32>,
204}
205
206/// Google search entry point.
207#[derive(Debug, Clone, Serialize, Deserialize)]
208#[serde(rename_all = "camelCase")]
209pub struct SearchEntryPoint {
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub rendered_content: Option<String>,
212    #[serde(
213        default,
214        skip_serializing_if = "Option::is_none",
215        with = "base64_serde::option"
216    )]
217    pub sdk_blob: Option<Vec<u8>>,
218}
219
220/// Source flagging URI.
221#[derive(Debug, Clone, Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct GroundingMetadataSourceFlaggingUri {
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub flag_content_uri: Option<String>,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub source_id: Option<String>,
228}
229
230/// Grounding 元数据。
231#[derive(Debug, Clone, Serialize, Deserialize, Default)]
232#[serde(rename_all = "camelCase")]
233pub struct GroundingMetadata {
234    #[serde(default)]
235    pub grounding_chunks: Vec<GroundingChunk>,
236    #[serde(default)]
237    pub grounding_supports: Vec<GroundingSupport>,
238    #[serde(default)]
239    pub web_search_queries: Vec<String>,
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub search_entry_point: Option<SearchEntryPoint>,
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub retrieval_metadata: Option<RetrievalMetadata>,
244    #[serde(skip_serializing_if = "Option::is_none")]
245    pub google_maps_widget_context_token: Option<String>,
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub retrieval_queries: Option<Vec<String>>,
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub source_flagging_uris: Option<Vec<GroundingMetadataSourceFlaggingUri>>,
250}
251
252impl GroundingMetadata {
253    /// 生成带内联引用的文本(使用 `grounding_supports` 的 `segment.end_index` 位置插入引用序号)。
254    #[must_use]
255    pub fn add_citations(&self, text: &str) -> String {
256        if self.grounding_supports.is_empty() {
257            return text.to_string();
258        }
259
260        let mut positions =
261            std::collections::BTreeMap::<usize, std::collections::BTreeSet<i32>>::new();
262
263        for support in &self.grounding_supports {
264            let Ok(end_index) = usize::try_from(support.segment.end_index) else {
265                continue;
266            };
267            let Some(byte_end) = char_index_to_byte(text, end_index) else {
268                continue;
269            };
270
271            let entry = positions.entry(byte_end).or_default();
272            for idx in &support.grounding_chunk_indices {
273                if let Some(one_based) = idx.checked_add(1) {
274                    if one_based > 0 {
275                        entry.insert(one_based);
276                    }
277                }
278            }
279        }
280
281        if positions.is_empty() {
282            return text.to_string();
283        }
284
285        let mut output = text.to_string();
286        for (pos, indices) in positions.into_iter().rev() {
287            if pos > output.len() {
288                continue;
289            }
290            let label = indices
291                .into_iter()
292                .map(|value| value.to_string())
293                .collect::<Vec<_>>()
294                .join(",");
295            output.insert_str(pos, &format!(" [{label}]"));
296        }
297
298        output
299    }
300
301    /// 提取引用链接(按照 `grounding_chunks` 顺序去重)。
302    #[must_use]
303    pub fn citation_uris(&self) -> Vec<String> {
304        let mut seen = std::collections::HashSet::new();
305        let mut uris = Vec::new();
306
307        for chunk in &self.grounding_chunks {
308            if let Some(uri) = chunk.uri() {
309                if seen.insert(uri.to_string()) {
310                    uris.push(uri.to_string());
311                }
312            }
313        }
314
315        uris
316    }
317}
318
/// Converts a character index into `text` to the matching byte offset.
///
/// `index` counts `char`s (Unicode scalar values), not bytes. Returns the byte offset at
/// which the character with that index starts, `Some(text.len())` when `index` equals the
/// number of characters in `text`, and `None` when `index` lies past the end.
fn char_index_to_byte(text: &str, index: usize) -> Option<usize> {
    text.char_indices()
        .map(|(byte_idx, _)| byte_idx)
        // One extra position: the end of the string is a valid insertion point.
        .chain(std::iter::once(text.len()))
        .nth(index)
}
336
337// 兼容旧名称(避免外部依赖受影响)。
338pub type GroundingChunkMapsPlaceAnswerSources = PlaceAnswerSources;
339pub type GroundingChunkMapsPlaceAnswerSourcesReviewSnippet = PlaceAnswerSourcesReviewSnippet;
340pub type GroundingChunkMapsPlaceAnswerSourcesAuthorAttribution =
341    PlaceAnswerSourcesAuthorAttribution;
342pub type GroundingChunkMaps = MapsChunk;
343pub type GroundingChunkRetrievedContext = RetrievedContextChunk;
344pub type GroundingChunkWeb = WebChunk;
345
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `uri()` / `title()` read the wrapped chunk for every variant.
    #[test]
    fn grounding_chunk_uri_and_title() {
        let web = GroundingChunk::Web {
            web: WebChunk {
                uri: "https://example.com".to_string(),
                title: "Example".to_string(),
            },
        };
        let maps = GroundingChunk::Maps {
            maps: MapsChunk {
                uri: "https://maps.example.com".to_string(),
                title: "Map".to_string(),
                text: "info".to_string(),
                place_id: "place-1".to_string(),
                place_answer_sources: None,
            },
        };
        let retrieved = GroundingChunk::RetrievedContext {
            retrieved_context: RetrievedContextChunk {
                document_name: None,
                rag_chunk: None,
                text: None,
                title: Some("Doc".to_string()),
                uri: Some("https://doc.example.com".to_string()),
            },
        };

        assert_eq!(web.uri(), Some("https://example.com"));
        assert_eq!(maps.title(), Some("Map"));
        assert_eq!(retrieved.uri(), Some("https://doc.example.com"));
        assert_eq!(retrieved.title(), Some("Doc"));
    }

    /// `sdk_blob` is written as a base64 string and read back to the same bytes.
    #[test]
    fn search_entry_point_base64_roundtrip() {
        let entry = SearchEntryPoint {
            rendered_content: Some("rendered".to_string()),
            sdk_blob: Some(vec![1, 2, 3]),
        };
        let value = serde_json::to_value(&entry).unwrap();
        assert_eq!(
            value,
            json!({
                "renderedContent": "rendered",
                "sdkBlob": "AQID"
            })
        );

        let decoded: SearchEntryPoint = serde_json::from_value(value).unwrap();
        assert_eq!(decoded.sdk_blob, Some(vec![1, 2, 3]));
    }

    /// Markers land at `segment.end_index`, and URIs come back in chunk order.
    #[test]
    fn grounding_metadata_add_citations_and_uris() {
        let metadata = GroundingMetadata {
            grounding_chunks: vec![
                GroundingChunk::Web {
                    web: WebChunk {
                        uri: "https://a.example".to_string(),
                        title: "A".to_string(),
                    },
                },
                GroundingChunk::Web {
                    web: WebChunk {
                        uri: "https://b.example".to_string(),
                        title: "B".to_string(),
                    },
                },
            ],
            grounding_supports: vec![GroundingSupport {
                grounding_chunk_indices: vec![0, 1],
                confidence_scores: vec![0.9],
                segment: Segment {
                    part_index: 0,
                    start_index: 0,
                    end_index: 2,
                    text: "hi".to_string(),
                },
            }],
            ..Default::default()
        };

        let cited = metadata.add_citations("hi!");
        assert_eq!(cited, "hi [1,2]!");
        let uris = metadata.citation_uris();
        assert_eq!(uris, vec!["https://a.example", "https://b.example"]);
    }
}