ssr_algorithms/fsrs/
mod.rs

1use fsrs::{FSRSItem, FSRS};
2use s_text_input_f::{BlocksWithAnswer, ParagraphItem};
3use serde::{Deserialize, Serialize};
4use ssr_core::task::level::TaskLevel;
5use std::time::SystemTime;
6
7use s_text_input_f as stif;
8
9mod level;
10use level::{Level, Quality, RepetitionContext};
11
12#[derive(Serialize, Deserialize, Debug, Clone)]
13pub struct Task {
14    level: Option<Level>,
15    input_blocks: s_text_input_f::Blocks,
16    correct_answer: s_text_input_f::Response,
17    #[serde(default)]
18    other_answers: Vec<s_text_input_f::Response>,
19}
20
21#[derive(Serialize, Deserialize, Debug)]
22pub struct Shared {
23    weights: [f32; 19],
24}
25impl Default for Shared {
26    fn default() -> Self {
27        Self {
28            weights: fsrs::DEFAULT_PARAMETERS,
29        }
30    }
31}
32impl ssr_core::task::SharedState<'_> for Shared {}
33use itertools::Itertools;
34fn extract_first_long_term_reviews<'a>(
35    items: impl IntoIterator<Item = &'a FSRSItem>,
36) -> Vec<FSRSItem> {
37    items
38        .into_iter()
39        .filter_map(|i| {
40            let a = i
41                .reviews
42                .iter()
43                .take_while_inclusive(|r| r.delta_t < 1)
44                .copied()
45                .collect_vec();
46            if a.last()?.delta_t < 1 || a.len() == i.reviews.len() {
47                return None;
48            }
49            Some(FSRSItem { reviews: a })
50        })
51        .collect()
52}
53
54impl ssr_core::task::SharedStateExt<'_, Task> for Shared {
55    fn optimize<'b>(
56        &mut self,
57        tasks: impl IntoIterator<Item = &'b Task>,
58    ) -> Result<(), Box<dyn std::error::Error>>
59    where
60        Task: 'b,
61    {
62        let mut tasks = tasks
63            .into_iter()
64            .filter_map(|t| t.level.as_ref())
65            .map(|x| x.history.clone())
66            .filter(|x| x.reviews.iter().any(|r| r.delta_t != 0))
67            .collect::<Vec<_>>();
68        tasks.extend(extract_first_long_term_reviews(&tasks));
69        let fsrs = FSRS::new(None)?;
70        let best_params: [f32; 19] = fsrs
71            .compute_parameters(tasks, None)?
72            .try_into()
73            .expect("fsrs library should return exactly '19' weights");
74        self.weights = best_params;
75        Ok(())
76    }
77}
78
79impl ssr_core::task::Task<'_> for Task {
80    type SharedState = Shared;
81
82    fn next_repetition(
83        &self,
84        shared_state: &Self::SharedState,
85        retrievability_goal: f64,
86    ) -> SystemTime {
87        if let Some(ref level) = self.level {
88            level.next_repetition(shared_state, retrievability_goal)
89        } else {
90            SystemTime::UNIX_EPOCH
91        }
92    }
93
94    fn complete(
95        &mut self,
96        shared_state: &mut Self::SharedState,
97        desired_retention: f64,
98        interaction: &mut impl FnMut(
99            s_text_input_f::Blocks,
100        ) -> std::io::Result<s_text_input_f::Response>,
101    ) -> std::io::Result<()> {
102        let review_time = chrono::Local::now();
103        let user_answer = interaction(self.input_blocks.clone())?;
104        let quality =
105            self.complete_inner(user_answer, shared_state, desired_retention, interaction)?;
106        if let Some(ref mut level) = self.level {
107            level.update(
108                shared_state,
109                RepetitionContext {
110                    quality,
111                    review_time,
112                },
113            );
114        } else {
115            self.level = Some(Level::new(quality, review_time));
116        }
117        Ok(())
118    }
119
120    fn new(input: s_text_input_f::BlocksWithAnswer) -> Self {
121        Self {
122            level: None,
123            input_blocks: input.blocks,
124            correct_answer: input.answer,
125            other_answers: Vec::new(),
126        }
127    }
128
129    fn get_blocks(&self) -> s_text_input_f::BlocksWithAnswer {
130        BlocksWithAnswer {
131            blocks: self.input_blocks.clone(),
132            answer: self.correct_answer.clone(),
133        }
134    }
135}
136
/// Result of comparing a user's answer with the answers a task accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Correctness {
    /// The answer matches none of the accepted answers.
    Wrong,
    /// The answer matches the task's primary correct answer.
    DefaultCorrect,
    /// The answer matches an alternative accepted answer.
    ///
    /// `index` is the position of that answer in the task's list of
    /// alternative answers.
    OtherCorrect { index: usize },
}

impl Correctness {
    /// Returns `true` for every variant except [`Correctness::Wrong`].
    #[must_use]
    pub fn is_correct(&self) -> bool {
        !matches!(self, Correctness::Wrong)
    }
}
151
152impl Task {
153    pub fn new(
154        input_blocks: s_text_input_f::Blocks,
155        correct_answer: s_text_input_f::Response,
156    ) -> Self {
157        Self {
158            level: Default::default(),
159            input_blocks,
160            correct_answer,
161            other_answers: Vec::new(),
162        }
163    }
164    fn gen_feedback_form(
165        &mut self,
166        user_answer: Vec<Vec<String>>,
167        directive: String,
168        qualities_strings: Vec<String>,
169    ) -> Vec<s_text_input_f::Block> {
170        let correct_answer = match self.correctness(&user_answer) {
171            Correctness::Wrong | Correctness::DefaultCorrect => self.correct_answer.clone(),
172            Correctness::OtherCorrect { index } => self.other_answers[index].clone(),
173        };
174        let mut feedback =
175            s_text_input_f::to_answered(self.input_blocks.clone(), user_answer, correct_answer)
176                .into_iter()
177                .map(s_text_input_f::Block::Answered)
178                .collect::<Vec<_>>();
179        feedback.push(s_text_input_f::Block::Paragraph(vec![]));
180        feedback.push(s_text_input_f::Block::Paragraph(vec![ParagraphItem::Text(
181            directive,
182        )]));
183        feedback.push(s_text_input_f::Block::OneOf(qualities_strings));
184        feedback
185    }
186
187    fn get_feedback<T: Copy>(
188        &mut self,
189        user_answer: Vec<Vec<String>>,
190        directive: String,
191        qualities_strings: Vec<String>,
192        interaction: &mut impl FnMut(
193            Vec<s_text_input_f::Block>,
194        ) -> Result<Vec<Vec<String>>, std::io::Error>,
195        qualities: Vec<T>,
196    ) -> Result<T, std::io::Error> {
197        let feedback = self.gen_feedback_form(user_answer, directive, qualities_strings);
198        let user_feedback = interaction(feedback)?;
199        let i = s_text_input_f::response_as_one_of(user_feedback.last().unwrap().to_owned())
200            .unwrap()
201            .unwrap();
202        let quality = qualities[i];
203        Ok(quality)
204    }
205
206    fn complete_inner(
207        &mut self,
208        user_answer: Vec<Vec<String>>,
209        shared_state: &Shared,
210        retrievability_goal: f64,
211        interaction: &mut impl FnMut(s_text_input_f::Blocks) -> std::io::Result<Vec<Vec<String>>>,
212    ) -> std::io::Result<Quality> {
213        let next_states = self.next_states(shared_state, retrievability_goal);
214        Ok(match self.correctness(&user_answer).is_correct() {
215            true => self.feedback_correct(user_answer, next_states, interaction)?,
216            false => self.feedback_wrong(user_answer, next_states, interaction)?,
217        })
218    }
219    fn correctness(&mut self, user_answer: &Vec<Vec<String>>) -> Correctness {
220        if stif::eq_response(&self.correct_answer, user_answer, true, false) {
221            return Correctness::DefaultCorrect;
222        }
223        for (index, ans) in self.other_answers.iter().enumerate() {
224            if stif::eq_response(ans, user_answer, true, false) {
225                return Correctness::OtherCorrect { index };
226            }
227        }
228        Correctness::Wrong
229    }
230
231    fn next_states(&self, shared: &Shared, retrievability_goal: f64) -> fsrs::NextStates {
232        let fsrs = level::fsrs(shared);
233        let now = chrono::Local::now();
234        fsrs.next_states(
235            self.level.as_ref().map(|l| l.memory_state(&fsrs)),
236            retrievability_goal as f32,
237            level::sleeps_between(self.level.as_ref().map_or(now, |l| l.last_review), now)
238                .try_into()
239                .unwrap(),
240        )
241        .unwrap()
242    }
243
244    fn feedback_correct(
245        &mut self,
246        user_answer: Vec<Vec<String>>,
247        next_states: fsrs::NextStates,
248        interaction: &mut impl FnMut(s_text_input_f::Blocks) -> std::io::Result<Vec<Vec<String>>>,
249    ) -> std::io::Result<Quality> {
250        let qualities = vec![Quality::Hard, Quality::Good, Quality::Easy];
251        let qualities_strings = vec![
252            format!("Hard {}d", next_states.hard.interval),
253            format!("Good {}d", next_states.good.interval),
254            format!("Easy {}d", next_states.easy.interval),
255        ];
256        let directive = "All answers correct! Choose difficulty:".to_string();
257        self.get_feedback(
258            user_answer,
259            directive,
260            qualities_strings,
261            interaction,
262            qualities,
263        )
264    }
265
266    fn feedback_wrong(
267        &mut self,
268        user_answer: Vec<Vec<String>>,
269        next_states: fsrs::NextStates,
270        interaction: &mut impl FnMut(s_text_input_f::Blocks) -> std::io::Result<Vec<Vec<String>>>,
271    ) -> std::io::Result<Quality> {
272        #[derive(Clone, Copy)]
273        enum Feedback {
274            Wrong,
275            ActuallyCorrect,
276        }
277        let result = self.get_feedback(
278            user_answer.clone(),
279            "Your answer is wrong.".into(),
280            vec![
281                format!("OK {}h", next_states.again.interval * 24.),
282                "It is actually correct".into(),
283            ],
284            interaction,
285            vec![Feedback::Wrong, Feedback::ActuallyCorrect],
286        )?;
287        match result {
288            Feedback::Wrong => Ok(Quality::Again),
289            Feedback::ActuallyCorrect => {
290                self.other_answers.push(user_answer.clone());
291                self.feedback_correct(user_answer, next_states, interaction)
292            }
293        }
294    }
295}