1pub mod structs;
58pub mod tutorials;
59
60use crate::structs::*;
61use anyhow::{Error, Result};
62use dotenvy::dotenv;
63use fxhash::FxHashMap;
64use indicatif::ProgressBar;
65use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC as NON_ALNUM};
66use reqwest::{self as request, header};
67
68#[cfg(test)]
69mod tests;
70
71fn encode(s: &str) -> String {
72 utf8_percent_encode(s, NON_ALNUM).to_string()
73}
74
/// Optional query-string parameters (plus the path ID) for a Semantic Scholar request.
///
/// Populate it with the chainable setters on `QueryParams`, then call `build()` to turn
/// the set fields into a `?key=value&...` string (empty when nothing is set); `None`
/// fields are omitted. `paper_id` is never part of that string — it goes in the URL path.
#[derive(Clone, Debug, Default)]
pub struct QueryParams {
    /// ID placed in the URL path (not the query string) by `SemanticScholar::get_url`.
    pub paper_id: String,
    /// Free-text search query, percent-encoded into `query=`.
    pub query_text: Option<String>,
    /// Fields to return; each is percent-encoded and comma-joined into `fields=`.
    pub fields: Option<Vec<PaperField>>,
    /// Publication-type filter; percent-encoded and comma-joined into `publicationTypes=`.
    pub publication_types: Option<Vec<PublicationTypes>>,
    /// Set to `Some(true)` to add the value-less `openAccessPdf` flag to the query.
    pub open_access_pdf: Option<bool>,
    /// Sent as `minCitationCount=`.
    pub min_citation_count: Option<u32>,
    /// Sent verbatim (not percent-encoded) as `publicationDateOrYear=`.
    pub publication_date_or_year: Option<String>,
    /// Sent verbatim (not percent-encoded) as `year=`.
    pub year: Option<String>,
    /// Venue filter; each entry is percent-encoded and comma-joined into `venue=`.
    pub venue: Option<Vec<String>>,
    /// Fields-of-study filter; percent-encoded and comma-joined into `fieldsOfStudy=`.
    pub fields_of_study: Option<Vec<FieldsOfStudy>>,
    /// Sent as `offset=`.
    pub offset: Option<u64>,
    /// Sent as `limit=`.
    pub limit: Option<u64>,
    /// Sent verbatim as `token=`; presumably a continuation token from a previous
    /// response — TODO confirm against the API docs.
    pub token: Option<String>,
    /// Sent verbatim as `sort=`; the accepted format is not validated here — confirm
    /// against the API docs.
    pub sort: Option<String>,
}
92
93impl QueryParams {
94 pub fn paper_id(&mut self, paper_id: &str) -> &mut Self {
95 self.paper_id = paper_id.to_string();
96 self
97 }
98 pub fn query_text(&mut self, query_text: &str) -> &mut Self {
99 self.query_text = Some(query_text.to_string());
100 self
101 }
102 pub fn fields(&mut self, fields: Vec<PaperField>) -> &mut Self {
103 self.fields = Some(fields);
104 self
105 }
106 pub fn publication_types(&mut self, publication_types: Vec<PublicationTypes>) -> &mut Self {
107 self.publication_types = Some(publication_types);
108 self
109 }
110 pub fn open_access_pdf(&mut self, open_access_pdf: bool) -> &mut Self {
111 self.open_access_pdf = Some(open_access_pdf);
112 self
113 }
114 pub fn min_citation_count(&mut self, min_citation_count: u32) -> &mut Self {
115 self.min_citation_count = Some(min_citation_count);
116 self
117 }
118 pub fn publication_date_or_year(&mut self, publication_date_or_year: &str) -> &mut Self {
119 self.publication_date_or_year = Some(publication_date_or_year.to_string());
120 self
121 }
122 pub fn year(&mut self, year: &str) -> &mut Self {
123 self.year = Some(year.to_string());
124 self
125 }
126 pub fn venue(&mut self, venue: Vec<&str>) -> &mut Self {
127 let venue: Vec<String> = venue.iter().map(|v| v.to_string()).collect();
128 self.venue = Some(venue);
129 self
130 }
131 pub fn fields_of_study(&mut self, fields_of_study: Vec<FieldsOfStudy>) -> &mut Self {
132 self.fields_of_study = Some(fields_of_study);
133 self
134 }
135 pub fn offset(&mut self, offset: u64) -> &mut Self {
136 self.offset = Some(offset);
137 self
138 }
139 pub fn limit(&mut self, limit: u64) -> &mut Self {
140 self.limit = Some(limit);
141 self
142 }
143 pub fn token(&mut self, token: &str) -> &mut Self {
144 self.token = Some(token.to_string());
145 self
146 }
147 pub fn sort(&mut self, sort: &str) -> &mut Self {
148 self.sort = Some(sort.to_string());
149 self
150 }
151
152 fn fields2string(&self, fields: Vec<PaperField>) -> String {
153 fields
154 .iter()
155 .map(|field| encode(&field.to_string()))
156 .collect::<Vec<String>>()
157 .join(",")
158 }
159
160 fn publication_types2string(&self, publication_types: Vec<PublicationTypes>) -> String {
161 publication_types
162 .iter()
163 .map(|publication_type| encode(&publication_type.to_string()))
164 .collect::<Vec<String>>()
165 .join(",")
166 }
167
168 fn fields_of_study2string(&self, fields_of_study: Vec<FieldsOfStudy>) -> String {
169 fields_of_study
170 .iter()
171 .map(|field| encode(&field.to_string()))
172 .collect::<Vec<String>>()
173 .join(",")
174 }
175
176 pub fn build(&self) -> String {
177 let mut query_params = Vec::new();
178
179 if let Some(query_text) = &self.query_text {
180 query_params.push(format!("query={}", encode(query_text)));
181 }
182 if let Some(fields) = &self.fields {
183 let fields = self.fields2string(fields.clone());
184 query_params.push(format!("fields={}", fields));
185 }
186 if let Some(publication_types) = &self.publication_types {
187 let publication_types = self.publication_types2string(publication_types.clone());
188 query_params.push(format!("publicationTypes={}", publication_types));
189 }
190 if self.open_access_pdf.is_some() {
191 query_params.push("openAccessPdf".to_string());
192 }
193 if let Some(min_citation_count) = &self.min_citation_count {
194 query_params.push(format!("minCitationCount={}", min_citation_count));
195 }
196 if let Some(publication_date_or_year) = &self.publication_date_or_year {
197 query_params.push(format!(
198 "publicationDateOrYear={}",
199 publication_date_or_year
200 ));
201 }
202 if let Some(year) = &self.year {
203 query_params.push(format!("year={}", year));
204 }
205 if let Some(venue) = &self.venue {
206 let venue = venue
207 .iter()
208 .map(|v| encode(v))
209 .collect::<Vec<String>>()
210 .join(",");
211 query_params.push(format!("venue={}", venue));
212 }
213 if let Some(fields_of_study) = &self.fields_of_study {
214 let fields_of_study = self.fields_of_study2string(fields_of_study.clone());
215 query_params.push(format!("fieldsOfStudy={}", fields_of_study));
216 }
217 if let Some(offset) = &self.offset {
218 query_params.push(format!("offset={}", offset));
219 }
220 if let Some(limit) = &self.limit {
221 query_params.push(format!("limit={}", limit));
222 }
223 if let Some(token) = &self.token {
224 query_params.push(format!("token={}", token));
225 }
226 if let Some(sort) = &self.sort {
227 query_params.push(format!("sort={}", sort));
228 }
229
230 if query_params.is_empty() {
231 return "".to_string();
232 } else {
233 let query_params = query_params.join("&");
234 return format!("?{}", query_params);
235 }
236 }
237}
238
/// Client for the Semantic Scholar Graph API (`https://api.semanticscholar.org/graph/v1`).
///
/// `SemanticScholar::new()` fills `api_key` from the `SEMANTIC_SCHOLAR_API_KEY`
/// environment variable; `Default` yields an empty key without reading the environment.
#[derive(Clone, Debug, Default)]
pub struct SemanticScholar {
    /// Sent as the `x-api-key` header on each request when non-empty; when empty,
    /// requests go out without that header.
    // NOTE(review): the derived `Debug` prints this key verbatim — avoid logging `{:?}`
    // of this struct.
    pub api_key: String,
}
243
244impl SemanticScholar {
245 pub fn new() -> Self {
246 dotenv().ok();
247 let vars = FxHashMap::from_iter(std::env::vars().into_iter().map(|(k, v)| (k, v)));
248 let api_key = vars
249 .get("SEMANTIC_SCHOLAR_API_KEY")
250 .unwrap_or(&"".to_string())
251 .to_string();
252 Self { api_key: api_key }
253 }
254
255 fn get_url(&self, endpoint: Endpoint, query_params: &mut QueryParams) -> String {
256 let paper_id = query_params.paper_id.clone();
257 let query_params = query_params.build();
258 match endpoint {
259 Endpoint::GetMultiplePpaerDetails => {
260 return format!(
261 "https://api.semanticscholar.org/graph/v1/paper/batch{}",
262 query_params
263 );
264 }
265 Endpoint::GetAPaperByTitle => {
266 let url = format!(
267 "https://api.semanticscholar.org/graph/v1/paper/search/match{}",
268 query_params
269 );
270 return url;
271 }
272 Endpoint::GetPapersByTitle => {
273 let url = format!(
274 "https://api.semanticscholar.org/graph/v1/paper/search{}",
275 query_params
276 );
277 return url;
278 }
279 Endpoint::GetPaperDetails => {
280 let url = format!(
281 "https://api.semanticscholar.org/graph/v1/paper/{}{}",
282 paper_id, query_params
283 );
284 return url;
285 }
286 Endpoint::GetAuthorDetails => {
287 let url = format!(
288 "https://api.semanticscholar.org/graph/v1/author/{}{}",
289 paper_id, query_params
290 );
291 return url;
292 }
293 Endpoint::GetReferencesOfAPaper => {
294 let url = format!(
295 "https://api.semanticscholar.org/graph/v1/paper/{}/references{}",
296 paper_id, query_params
297 );
298 return url;
299 }
300 Endpoint::GetCitationsOfAPaper => {
301 let url = format!(
302 "https://api.semanticscholar.org/graph/v1/paper/{}/citations{}",
303 paper_id, query_params
304 );
305 return url;
306 }
307 }
308 }
309
310 fn sleep(&self, seconds: u64, message: &str) {
311 let pb = ProgressBar::new(seconds);
312 pb.set_style(
313 indicatif::ProgressStyle::default_bar()
314 .template(
315 "{spinner:.green} [{elapsed_precise}] [{bar:40.green/cyan}] {pos}s/{len}s {msg}",
316 )
317 .unwrap()
318 .progress_chars("█▓▒░"),
319 );
320 if message.is_empty() {
321 pb.set_message("Waiting for the next request...");
322 } else {
323 pb.set_message(message.to_string());
324 }
325 for _ in 0..seconds {
326 pb.inc(1);
327 std::thread::sleep(std::time::Duration::from_secs(1));
328 }
329 pb.finish_and_clear();
330 }
331
332 pub async fn bulk_query_by_ids(
366 &mut self,
367 paper_ids: Vec<&str>,
368 fields: Vec<PaperField>,
369 max_retry_count: u64,
370 wait_time: u64,
371 ) -> Result<Vec<Paper>> {
372 let mut max_retry_count = max_retry_count.clone();
373
374 let mut headers = header::HeaderMap::new();
375 headers.insert("Content-Type", "application/json".parse().unwrap());
376 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
377 if !self.api_key.is_empty() {
378 headers.insert("x-api-key", self.api_key.parse().unwrap());
379 }
380 let client = request::Client::builder()
381 .default_headers(headers)
382 .build()
383 .unwrap();
384
385 let mut query_params = QueryParams::default();
386 query_params.fields(fields.clone());
387 let url = self.get_url(Endpoint::GetMultiplePpaerDetails, &mut query_params);
388 let body = format!(
389 "{{\"ids\":[{}]}}",
390 paper_ids
391 .iter()
392 .map(|v| format!("\"{}\"", v))
393 .collect::<Vec<String>>()
394 .join(",")
395 );
396
397 loop {
398 if max_retry_count == 0 {
399 return Err(Error::msg("Failed to get papers"));
400 }
401 let body = client
402 .post(url.clone())
403 .body(body.clone())
404 .send()
405 .await?
406 .text()
407 .await?;
408
409 match serde_json::from_str::<Vec<Paper>>(&body) {
410 Ok(response) => {
411 return Ok(response);
412 }
413 Err(e) => {
414 max_retry_count -= 1;
415 self.sleep(
416 wait_time,
417 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
418 );
419 continue;
420 }
421 }
422 }
423 }
424
425 pub async fn query_papers_by_title(
455 &mut self,
456 query_params: QueryParams,
457 max_retry_count: u64,
458 wait_time: u64,
459 ) -> Result<Vec<Paper>> {
460 let mut query_params = query_params.clone();
461 let mut max_retry_count = max_retry_count.clone();
462
463 let mut headers = header::HeaderMap::new();
464 headers.insert("Content-Type", "application/json".parse().unwrap());
465 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
466 if !self.api_key.is_empty() {
467 headers.insert("x-api-key", self.api_key.parse().unwrap());
468 }
469 let client = request::Client::builder()
470 .default_headers(headers)
471 .build()
472 .unwrap();
473
474 let url = self.get_url(Endpoint::GetPapersByTitle, &mut query_params);
475
476 loop {
477 if max_retry_count == 0 {
478 return Err(Error::msg(format!(
479 "Failed to get paper id for: {}",
480 query_params.query_text.unwrap().clone()
481 )));
482 }
483
484 let body = client.get(url.clone()).send().await?.text().await?;
485 match serde_json::from_str::<PaperIds>(&body) {
486 Ok(response) => {
487 if response.data.is_empty() || response.total == 0 {
488 max_retry_count -= 1;
489 self.sleep(
490 wait_time,
491 format!("Error: Response is empty. Body: {}", &body).as_str(),
492 );
493 continue;
494 }
495 return Ok(response.data);
496 }
497 Err(e) => {
498 max_retry_count -= 1;
499 self.sleep(
500 wait_time,
501 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
502 );
503 continue;
504 }
505 }
506 }
507 }
508
509 pub async fn query_a_paper_by_title(
538 &mut self,
539 query_params: QueryParams,
540 max_retry_count: u64,
541 wait_time: u64,
542 ) -> Result<Paper> {
543 let mut query_params = query_params.clone();
544 let mut max_retry_count = max_retry_count.clone();
545
546 let mut headers = header::HeaderMap::new();
547 headers.insert("Content-Type", "application/json".parse().unwrap());
548 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
549 if !self.api_key.is_empty() {
550 headers.insert("x-api-key", self.api_key.parse()?);
551 }
552 let client = request::Client::builder()
553 .default_headers(headers)
554 .build()?;
555
556 let url = self.get_url(Endpoint::GetAPaperByTitle, &mut query_params);
557 loop {
558 if max_retry_count == 0 {
559 return Err(Error::msg(format!(
560 "Failed to get paper id for: {}",
561 query_params.query_text.unwrap()
562 )));
563 }
564
565 let body = client.get(url.clone()).send().await?.text().await?;
566 match serde_json::from_str::<PaperIds>(&body) {
567 Ok(response) => {
568 if response.data.len() < 1 {
569 max_retry_count -= 1;
570 self.sleep(
571 wait_time,
572 format!("Error: Response is empty. Body: {}", &body).as_str(),
573 );
574 continue;
575 }
576 let paper = response.data.first().unwrap().clone();
577 return Ok(paper);
578 }
579 Err(e) => {
580 max_retry_count -= 1;
581 self.sleep(
582 wait_time,
583 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
584 );
585 continue;
586 }
587 }
588 }
589 }
590
591 pub async fn query_paper_details(
621 &mut self,
622 query_params: QueryParams,
623 max_retry_count: u64,
624 wait_time: u64,
625 ) -> Result<Paper> {
626 let mut query_params = query_params.clone();
627 let mut max_retry_count = max_retry_count.clone();
628
629 let mut fields = query_params.fields.clone().unwrap_or_default();
630 if !fields.contains(&PaperField::PaperId) {
631 fields.push(PaperField::PaperId);
632 query_params.fields = Some(fields);
633 }
634
635 let mut headers = header::HeaderMap::new();
636 headers.insert("Content-Type", "application/json".parse().unwrap());
637 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
638 if !self.api_key.is_empty() {
639 headers.insert("x-api-key", self.api_key.parse().unwrap());
640 }
641 let client = request::Client::builder()
642 .default_headers(headers)
643 .build()
644 .unwrap();
645
646 let url = self.get_url(Endpoint::GetPaperDetails, &mut query_params);
647 loop {
648 if max_retry_count == 0 {
649 return Err(Error::msg(format!(
650 "Failed to get paper details: {}",
651 query_params.paper_id
652 )));
653 }
654 let body = client.get(url.clone()).send().await?.text().await?;
655 match serde_json::from_str::<Paper>(&body) {
656 Ok(response) => {
657 return Ok(response);
658 }
659 Err(e) => {
660 max_retry_count -= 1;
661 self.sleep(
662 wait_time,
663 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
664 );
665 continue;
666 }
667 }
668 }
669 }
670
671 pub async fn query_paper_citations(
672 &mut self,
673 query_params: QueryParams,
674 max_retry_count: u64,
675 wait_time: u64,
676 ) -> Result<ResponsePapers> {
677 let mut query_params = query_params.clone();
678 let mut max_retry_count = max_retry_count.clone();
679
680 let mut fields = query_params.fields.clone().unwrap_or_default();
681 if !fields.contains(&PaperField::PaperId) {
682 fields.push(PaperField::PaperId);
683 query_params.fields = Some(fields);
684 }
685
686 let mut headers = header::HeaderMap::new();
687 headers.insert("Content-Type", "application/json".parse().unwrap());
688 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
689 if !self.api_key.is_empty() {
690 headers.insert("x-api-key", self.api_key.parse().unwrap());
691 }
692 let client = request::Client::builder()
693 .default_headers(headers)
694 .build()
695 .unwrap();
696
697 let url = self.get_url(Endpoint::GetCitationsOfAPaper, &mut query_params);
698
699 loop {
700 if max_retry_count == 0 {
701 return Err(Error::msg(format!(
702 "Failed to get paper citations: {}",
703 query_params.paper_id
704 )));
705 }
706 match client.get(url.clone()).send().await {
707 Ok(response) => {
708 let body = response.text().await?;
709 match serde_json::from_str::<ResponsePapers>(&body) {
710 Ok(response) => {
711 return Ok(response);
712 }
713 Err(e) => {
714 max_retry_count -= 1;
715 self.sleep(
716 wait_time,
717 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
718 );
719 continue;
720 }
721 }
722 }
723 Err(e) => {
724 max_retry_count -= 1;
725 self.sleep(wait_time, &e.to_string());
726 continue;
727 }
728 }
729 }
730 }
731
732 pub async fn query_paper_references(
733 &mut self,
734 query_params: QueryParams,
735 max_retry_count: u64,
736 wait_time: u64,
737 ) -> Result<ResponsePapers> {
738 let mut query_params = query_params.clone();
739 let mut max_retry_count = max_retry_count.clone();
740
741 let mut fields = query_params.fields.clone().unwrap_or_default();
742 if !fields.contains(&PaperField::PaperId) {
743 fields.push(PaperField::PaperId);
744 query_params.fields = Some(fields);
745 }
746
747 let mut headers = header::HeaderMap::new();
748 headers.insert("Content-Type", "application/json".parse().unwrap());
749 headers.insert("user-agent", "ss-tools/0.1".parse().unwrap());
750 if !self.api_key.is_empty() {
751 headers.insert("x-api-key", self.api_key.parse().unwrap());
752 }
753 let client = request::Client::builder()
754 .default_headers(headers)
755 .build()
756 .unwrap();
757
758 let url = self.get_url(Endpoint::GetReferencesOfAPaper, &mut query_params);
759 loop {
760 if max_retry_count == 0 {
761 return Err(Error::msg(format!(
762 "Failed to get paper references: {}",
763 query_params.paper_id
764 )));
765 }
766
767 match client.get(url.clone()).send().await {
768 Ok(response) => {
769 let body = response.text().await?;
770 match serde_json::from_str::<ResponsePapers>(&body) {
771 Ok(response) => {
772 return Ok(response);
773 }
774 Err(e) => {
775 max_retry_count -= 1;
776 self.sleep(
777 wait_time,
778 format!("Error: {} Body: {}", &e.to_string(), &body).as_str(),
779 );
780 continue;
781 }
782 }
783 }
784 Err(e) => {
785 max_retry_count -= 1;
786 self.sleep(wait_time, &e.to_string());
787 continue;
788 }
789 }
790 }
791 }
792}