tfhe_hpu_backend/interface/
variable.rs1use super::*;
5use crate::asm::iop::VarMode;
6use crate::entities::{HpuLweCiphertextOwned, HpuParameters};
7use crate::ffi;
8use std::sync::{mpsc, Arc, Mutex};
9
/// Tracks which side (host/CPU buffers or HPU device) currently holds valid data
/// for an `HpuVar`.
#[derive(Debug)]
enum SyncState {
    /// No data has been written yet; `try_cpu_sync` reports `UninitData` in this state.
    None,
    /// Only the host-side copy is valid (e.g. right after `new_from` wrote the ciphertexts).
    CpuSync,
    /// Only the device-side copy is valid (set when an HPU operation completes).
    HpuSync,
    /// Host and device copies hold the same data.
    BothSync,
}
17
/// Mutable state of an HPU variable, shared behind `Arc<Mutex<_>>` by `HpuVarWrapped`.
pub(crate) struct HpuVar {
    /// Ciphertext slots backing this variable; each slot exposes its memory zones via `mz`.
    bundle: memory::CiphertextBundle,
    /// Which copy (host/device/both/none) is currently valid.
    state: SyncState,
    /// Number of in-flight operations registered with `operation_pending` that have not yet
    /// been matched by `operation_done`.
    pending: usize,
}
23
24impl std::fmt::Debug for HpuVar {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(
27 f,
28 "HpuVar<{{state: {:?}, bundle: {:?}}}",
29 self.state, self.bundle
30 )
31 }
32}
33
34impl HpuVar {
36 pub fn try_cpu_sync(&mut self) -> Result<(), HpuInternalError> {
37 if self.pending > 0 {
38 Err(HpuInternalError::OperationPending)
39 } else {
40 match self.state {
41 SyncState::CpuSync | SyncState::BothSync => Ok(()),
42 SyncState::HpuSync => {
43 for slot in self.bundle.iter_mut() {
44 slot.mz
45 .iter_mut()
46 .for_each(|mz| mz.sync(ffi::SyncMode::Device2Host));
47 }
48 self.state = SyncState::BothSync;
49 Ok(())
50 }
51 SyncState::None => Err(HpuInternalError::UninitData),
52 }
53 }
54 }
55
56 pub(crate) fn try_hpu_sync(&mut self) -> Result<(), HpuInternalError> {
57 match self.state {
60 SyncState::None => {
61 if self.pending > 0 {
62 Ok(()) } else {
64 Err(HpuInternalError::UninitData)
65 }
66 }
67 SyncState::HpuSync | SyncState::BothSync => Ok(()),
68 SyncState::CpuSync => {
69 for slot in self.bundle.iter_mut() {
70 slot.mz
71 .iter_mut()
72 .for_each(|mz| mz.sync(ffi::SyncMode::Host2Device));
73 }
74 self.state = if self.pending > 0 {
75 SyncState::HpuSync
76 } else {
77 SyncState::BothSync
78 };
79 Ok(())
80 }
81 }
82 }
83}
84
85impl HpuVar {
86 pub(crate) fn operation_pending(&mut self) {
87 self.pending += 1;
88 }
89 pub(crate) fn operation_done(&mut self) {
90 if self.pending > 0 {
91 self.pending -= 1;
92 self.state = SyncState::HpuSync;
93 } else {
94 panic!("`operation_done` called on variable without pending operations");
95 }
96 }
97}
98
/// User-facing handle to an HPU variable.
///
/// Cloning the handle clones the `Arc` around `inner`, so all clones refer to the same
/// underlying `HpuVar`.
#[derive(Clone)]
pub struct HpuVarWrapped {
    /// Shared, lock-protected variable state (slots, sync state, pending count).
    pub(crate) inner: Arc<Mutex<HpuVar>>,
    /// Identifier of the slot bundle backing this variable (copied from the bundle at creation).
    pub(crate) id: memory::ciphertext::SlotId,
    /// Ciphertext memory pool the bundle was allocated from; reused by `fork`.
    pub(crate) pool: memory::CiphertextMemory,
    /// Command channel to the HPU backend; presumably used elsewhere to submit operations
    /// on this variable — NOTE(review): not used directly in this file.
    pub(crate) cmd_api: mpsc::Sender<cmd::HpuCmd>,
    /// HPU parameters, used e.g. to allocate output ciphertexts in `try_into`.
    pub(crate) params: Arc<HpuParameters>,
    /// Number of ciphertext slots in the bundle.
    pub(crate) width: usize,
    /// How the slots are interpreted (`Native`, `Half` or `Bool`).
    pub(crate) mode: VarMode,
}
111
112impl std::fmt::Debug for HpuVarWrapped {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "HpuVarWrapped{{ {:?} }}", self.id)
115 }
116}
117
118impl HpuVarWrapped {
120 fn new_in(
121 pool: memory::CiphertextMemory,
122 cmd_api: mpsc::Sender<cmd::HpuCmd>,
123 params: Arc<HpuParameters>,
124 width: usize,
125 mode: VarMode,
126 ) -> Self {
127 let bundle = pool.get_bundle(width);
128
129 Self {
130 id: *bundle.id(),
131 pool,
132 cmd_api,
133 params,
134 width,
135 mode,
136 inner: Arc::new(Mutex::new(HpuVar {
137 bundle,
138 state: SyncState::None,
139 pending: 0,
140 })),
141 }
142 }
143
144 pub(crate) fn new_from(
145 pool: memory::CiphertextMemory,
146 cmd_api: mpsc::Sender<cmd::HpuCmd>,
147 params: Arc<HpuParameters>,
148 ct: Vec<HpuLweCiphertextOwned<u64>>,
149 mode: VarMode,
150 ) -> Self {
151 let var = Self::new_in(pool, cmd_api, params, ct.len(), mode);
152
153 {
157 let mut inner = var.inner.lock().unwrap();
158
159 for (slot, ct) in std::iter::zip(inner.bundle.iter_mut(), ct.into_iter()) {
160 #[cfg(feature = "io-dump")]
161 let params = ct.params().clone();
162 for (id, cut) in ct.into_container().iter().enumerate() {
163 slot.mz[id].write(0, cut);
164 #[cfg(feature = "io-dump")]
165 io_dump::dump(
166 &cut.as_slice(),
167 ¶ms,
168 io_dump::DumpKind::BlweIn,
169 io_dump::DumpId::Slot(slot.id, id),
170 );
171 }
172 }
173 inner.state = SyncState::CpuSync;
174 }
175 var
176 }
177
178 pub(crate) fn fork(&self, trgt_mode: VarMode) -> Self {
181 let Self {
182 pool,
183 cmd_api,
184 params,
185 width,
186 mode,
187 ..
188 } = self.clone();
189
190 let width = match (&mode, &trgt_mode) {
191 (_, VarMode::Bool) => 1,
192 (VarMode::Native, VarMode::Native) => width,
193 (VarMode::Native, VarMode::Half) => width / 2,
194 (VarMode::Half, VarMode::Native) => 2 * width,
195 (VarMode::Half, VarMode::Half) => width,
196 _ => panic!("Unsupported mode, couldn't use a Boolean to build a bigger variable"),
197 };
198 Self::new_in(pool, cmd_api, params, width, trgt_mode)
199 }
200
201 pub fn try_into(self) -> Result<Vec<HpuLweCiphertextOwned<u64>>, HpuError> {
202 let mut inner = self.inner.lock().unwrap();
204 match inner.try_cpu_sync() {
205 Ok(_) => {}
206 Err(err) => {
207 drop(inner);
208 match err {
209 HpuInternalError::OperationPending => return Err(HpuError::SyncPending(self)),
210 HpuInternalError::UninitData => {
211 panic!("Encounter unrecoverable HpuInternalError: {err:?}")
212 }
213 }
214 }
215 }
216
217 let mut ct = Vec::new();
218
219 for slot in inner.bundle.iter() {
220 let mut hpu_lwe = HpuLweCiphertextOwned::<u64>::new(0, (*self.params).clone());
223 let mut hw_slice = hpu_lwe.as_mut_view().into_container();
224
225 #[allow(unused_variables)]
227 std::iter::zip(slot.mz.iter(), hw_slice.iter_mut())
228 .enumerate()
229 .for_each(|(id, (mz, cut))| {
230 mz.read(0, cut);
231 #[cfg(feature = "io-dump")]
232 io_dump::dump(
233 &cut.as_ref(),
234 &self.params,
235 io_dump::DumpKind::BlweOut,
236 io_dump::DumpId::Slot(slot.id, id),
237 );
238 });
239 ct.push(hpu_lwe);
240 }
241
242 Ok(ct)
243 }
244
245 pub fn into_ct(self) -> Vec<HpuLweCiphertextOwned<u64>> {
248 let mut var = self;
250 loop {
251 var = match var.try_into() {
252 Ok(ct) => break ct,
253 Err(err) => match err {
254 HpuError::SyncPending(v) => v,
255 },
256 }
257 }
258 }
259
260 pub fn wait(&self) {
263 loop {
264 match self.inner.lock().unwrap().try_cpu_sync() {
265 Ok(_) => break,
266 Err(err) => match err {
267 HpuInternalError::OperationPending => {}
268 HpuInternalError::UninitData => {
269 panic!("Encounter unrecoverable HpuInternalError: {err:?}")
270 }
271 },
272 }
273 }
274 }
275
276 pub fn is_boolean(&self) -> bool {
278 self.mode == VarMode::Bool
279 }
280}