1use std::{
2 collections::{btree_map, BTreeMap, HashSet},
3 iter::Peekable,
4};
5
6use anyhow::{Error, Result};
7
8use crate::{
9 common::{crypto::hash::Hash, namespace::Namespace},
10 storage::mkvs::{self, tree::Key, Proof},
11};
12
13pub struct OverlayTree<T: mkvs::FallibleMKVS> {
19 inner: T,
20 overlay: BTreeMap<Vec<u8>, Vec<u8>>,
21 dirty: HashSet<Vec<u8>>,
22}
23
24impl<T: mkvs::FallibleMKVS> OverlayTree<T> {
25 pub fn new(inner: T) -> Self {
27 Self {
28 inner,
29 overlay: BTreeMap::new(),
30 dirty: HashSet::new(),
31 }
32 }
33
34 pub fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
36 if self.dirty.contains(key) {
38 return Ok(self.overlay.get(key).cloned());
39 }
40
41 self.inner.get(key)
43 }
44
45 pub fn get_proof(&self, key: &[u8]) -> Result<Option<Proof>> {
46 if !self.dirty.is_empty() {
47 Err(Error::msg(
48 "overlay tree proofs are not supported when there are dirty values",
49 ))?;
50 }
51
52 self.inner.get_proof(key)
53 }
54
55 pub fn insert(&mut self, key: &[u8], value: &[u8]) -> Result<Option<Vec<u8>>> {
57 let previous = self.get(key)?;
58
59 self.overlay.insert(key.to_owned(), value.to_owned());
60 self.dirty.insert(key.to_owned());
61
62 Ok(previous)
63 }
64
65 pub fn remove(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
68 if self.dirty.contains(key) {
70 return Ok(self.overlay.remove(key));
71 }
72
73 let value = self.inner.get(key)?;
74
75 if value.is_some() {
77 self.dirty.insert(key.to_owned());
78 }
79 Ok(value)
80 }
81
82 pub fn iter(&self) -> OverlayTreeIterator<T> {
84 OverlayTreeIterator::new(self)
85 }
86
87 pub fn commit(&mut self) -> Result<mkvs::WriteLog> {
89 let mut log: mkvs::WriteLog = Vec::new();
90
91 for (key, value) in &self.overlay {
93 self.inner.insert(key, value)?;
94 self.dirty.remove(key);
95
96 log.push(mkvs::LogEntry {
97 key: key.clone(),
98 value: Some(value.clone()),
99 });
100 }
101 self.overlay.clear();
102
103 for key in &self.dirty {
105 self.inner.remove(key)?;
106
107 log.push(mkvs::LogEntry {
108 key: key.clone(),
109 value: None,
110 });
111 }
112 self.dirty.clear();
113
114 Ok(log)
115 }
116
117 pub fn commit_both(
120 &mut self,
121 namespace: Namespace,
122 version: u64,
123 ) -> Result<(mkvs::WriteLog, Hash)> {
124 let write_log = self.commit()?;
126 let root_hash = self.inner.commit(namespace, version)?;
128
129 Ok((write_log, root_hash))
130 }
131}
132
133pub struct OverlayTreeIterator<'tree, T: mkvs::FallibleMKVS> {
135 tree: &'tree OverlayTree<T>,
136
137 inner: Box<dyn mkvs::Iterator + 'tree>,
138 overlay: Peekable<btree_map::Range<'tree, Vec<u8>, Vec<u8>>>,
139 overlay_valid: bool,
140
141 key: Option<Vec<u8>>,
142 value: Option<Vec<u8>>,
143}
144
145impl<'tree, T: mkvs::FallibleMKVS> OverlayTreeIterator<'tree, T> {
146 fn new(tree: &'tree OverlayTree<T>) -> Self {
147 Self {
148 tree,
149 inner: tree.inner.iter(),
150 overlay: tree.overlay.range(vec![]..).peekable(),
151 overlay_valid: true,
152 key: None,
153 value: None,
154 }
155 }
156
157 fn update_iterator_position(&mut self) {
158 loop {
160 if !self.inner.is_valid()
161 || !self
162 .tree
163 .dirty
164 .contains(self.inner.get_key().as_ref().expect("inner.is_valid"))
165 {
166 break;
167 }
168 self.inner.next();
169 }
170
171 let i_key = self.inner.get_key();
172 let o_item = self.overlay.peek();
173 self.overlay_valid = o_item.is_some();
174
175 if self.inner.is_valid()
176 && (!self.overlay_valid
177 || i_key.as_ref().expect("inner.is_valid") < o_item.expect("overlay_valid").0)
178 {
179 self.key = i_key.clone();
181 self.value = self.inner.get_value().clone();
182 } else if self.overlay_valid {
183 let (o_key, o_value) = o_item.expect("overlay_valid");
185 self.key = Some(o_key.to_vec());
186 self.value = Some(o_value.to_vec());
187 } else {
188 self.key = None;
190 self.value = None;
191 }
192 }
193
194 fn next(&mut self) {
195 if !self.overlay_valid
196 || (self.inner.is_valid()
197 && self.inner.get_key().as_ref().expect("inner.is_valid")
198 <= self.overlay.peek().expect("overlay_valid").0)
199 {
200 self.inner.next();
202 } else {
203 self.overlay.next();
205 }
206
207 self.update_iterator_position();
208 }
209}
210
211impl<'tree, T: mkvs::FallibleMKVS> Iterator for OverlayTreeIterator<'tree, T> {
212 type Item = (Vec<u8>, Vec<u8>);
213
214 fn next(&mut self) -> Option<Self::Item> {
215 use mkvs::Iterator;
216
217 if !self.is_valid() {
218 return None;
219 }
220
221 let key = self.key.as_ref().expect("iterator is valid").clone();
222 let value = self.value.as_ref().expect("iterator is valid").clone();
223 OverlayTreeIterator::next(self);
224
225 Some((key, value))
226 }
227}
228
229impl<'tree, T: mkvs::FallibleMKVS> mkvs::Iterator for OverlayTreeIterator<'tree, T> {
230 fn set_prefetch(&mut self, prefetch: usize) {
231 self.inner.set_prefetch(prefetch)
232 }
233
234 fn is_valid(&self) -> bool {
235 self.inner.is_valid() || self.overlay_valid
237 }
238
239 fn error(&self) -> &Option<Error> {
240 self.inner.error()
241 }
242
243 fn rewind(&mut self) {
244 self.seek(&[]);
245 }
246
247 fn seek(&mut self, key: &[u8]) {
248 self.inner.seek(key);
249 self.overlay = self.tree.overlay.range(key.to_vec()..).peekable();
250
251 self.update_iterator_position();
252 }
253
254 fn get_key(&self) -> &Option<Key> {
255 &self.key
256 }
257
258 fn get_value(&self) -> &Option<Vec<u8>> {
259 &self.value
260 }
261
262 fn next(&mut self) {
263 OverlayTreeIterator::next(self)
264 }
265}
266
267impl<T: mkvs::FallibleMKVS> mkvs::MKVS for OverlayTree<T> {
268 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
269 self.get(key).unwrap()
270 }
271
272 fn get_proof(&self, key: &[u8]) -> Option<Proof> {
273 self.get_proof(key).unwrap()
274 }
275
276 fn cache_contains_key(&self, key: &[u8]) -> bool {
277 if self.dirty.contains(key) {
279 return self.overlay.contains_key(key);
280 }
281 self.inner.cache_contains_key(key)
282 }
283
284 fn insert(&mut self, key: &[u8], value: &[u8]) -> Option<Vec<u8>> {
285 self.insert(key, value).unwrap()
286 }
287
288 fn remove(&mut self, key: &[u8]) -> Option<Vec<u8>> {
289 self.remove(key).unwrap()
290 }
291
292 fn prefetch_prefixes(&self, prefixes: &[mkvs::Prefix], limit: u16) {
293 self.inner.prefetch_prefixes(prefixes, limit).unwrap()
294 }
295
296 fn iter(&self) -> Box<dyn mkvs::Iterator + '_> {
297 Box::new(self.iter())
298 }
299
300 fn commit(&mut self, namespace: Namespace, version: u64) -> Result<(mkvs::WriteLog, Hash)> {
301 self.commit_both(namespace, version)
302 }
303}
304
#[cfg(test)]
mod test {
    use super::*;
    use crate::storage::mkvs::{
        sync::NoopReadSyncer, tree::iterator::test::test_iterator_with, RootType, Tree,
    };

    /// Exercises lookups, iteration, removal and commit of an overlay on top of a tree.
    #[test]
    fn test_overlay() {
        let mut tree = Tree::builder()
            .with_root_type(RootType::State)
            .build(Box::new(NoopReadSyncer));

        // Entries used to populate first the overlay and then the tree itself.
        let items = vec![
            (b"key".to_vec(), b"first".to_vec()),
            (b"key 1".to_vec(), b"one".to_vec()),
            (b"key 2".to_vec(), b"two".to_vec()),
            (b"key 5".to_vec(), b"five".to_vec()),
            (b"key 8".to_vec(), b"eight".to_vec()),
            (b"key 9".to_vec(), b"nine".to_vec()),
        ];

        // Seek cases handed to `test_iterator_with`; presumably (seek key, index into `items`
        // of the expected entry, -1 for none) -- confirm against that helper.
        let tests = vec![
            (b"k".to_vec(), 0),
            (b"key 1".to_vec(), 1),
            (b"key 3".to_vec(), 3),
            (b"key 4".to_vec(), 3),
            (b"key 5".to_vec(), 3),
            (b"key 6".to_vec(), 4),
            (b"key 7".to_vec(), 4),
            (b"key 8".to_vec(), 4),
            (b"key 9".to_vec(), 5),
            (b"key A".to_vec(), -1),
        ];

        // An overlay over an empty tree: all entries live in the overlay only.
        let mut overlay = OverlayTree::new(&mut tree);
        for (key, value) in &items {
            overlay.insert(key, value).unwrap();
        }

        // Iteration must work when only the overlay has entries.
        test_iterator_with(&items, overlay.iter(), &tests);

        // Now put the same entries into the underlying tree. The first overlay is never
        // committed, so its pending entries are simply discarded.
        for (key, value) in &items {
            tree.insert(key, value).unwrap();
        }

        // Raw pointer used further down to look at the tree while the overlay borrows it.
        let raw_tree = &tree as *const Tree;

        // A fresh overlay without any pending modifications.
        let mut overlay = OverlayTree::new(&mut tree);

        // Every entry of the tree must be visible through the overlay.
        for (key, expected) in &items {
            let value = overlay.get(key).unwrap();
            assert_eq!(value.as_ref(), Some(expected));
        }

        // Iteration must work when only the inner tree has entries.
        test_iterator_with(&items, overlay.iter(), &tests);

        // Modify the overlay: one removal, one new key and one key removed and re-inserted.
        overlay.remove(b"key 2").unwrap();
        overlay.insert(b"key 7", b"seven").unwrap();
        overlay.remove(b"key 5").unwrap();
        overlay.insert(b"key 5", b"fivey").unwrap();

        // The inner tree must not have been touched by the modifications above.
        //
        // NOTE(review): this reads the tree through a raw pointer while `overlay` still holds
        // the `&mut tree` borrow, i.e. it relies on aliasing the compiler cannot check --
        // confirm this is acceptable for the test.
        unsafe {
            let inner_tree = &*raw_tree;

            let value = inner_tree.get(b"key 2").unwrap();
            assert_eq!(
                value,
                Some(b"two".to_vec()),
                "value in inner tree should be unchanged"
            );
            let value = inner_tree.get(b"key 7").unwrap();
            assert_eq!(value, None, "value should not exist in inner tree");
        }

        // Expected contents after the modifications.
        let items = vec![
            (b"key".to_vec(), b"first".to_vec()),
            (b"key 1".to_vec(), b"one".to_vec()),
            (b"key 5".to_vec(), b"fivey".to_vec()),
            (b"key 7".to_vec(), b"seven".to_vec()),
            (b"key 8".to_vec(), b"eight".to_vec()),
            (b"key 9".to_vec(), b"nine".to_vec()),
        ];

        let tests = vec![
            (b"k".to_vec(), 0),
            (b"key 1".to_vec(), 1),
            (b"key 3".to_vec(), 2),
            (b"key 4".to_vec(), 2),
            (b"key 5".to_vec(), 2),
            (b"key 6".to_vec(), 3),
            (b"key 7".to_vec(), 3),
            (b"key 8".to_vec(), 4),
            (b"key 9".to_vec(), 5),
            (b"key A".to_vec(), -1),
        ];

        // The overlay must reflect the modified contents.
        for (key, expected) in &items {
            let value = overlay.get(key).unwrap();
            assert_eq!(value.as_ref(), Some(expected));
        }

        // Iteration must merge the overlay with the inner tree.
        test_iterator_with(&items, overlay.iter(), &tests);

        // Push the pending modifications into the inner tree.
        overlay.commit().unwrap();

        // After the commit the inner tree itself must hold the modified contents.
        for (key, expected) in &items {
            let value = tree.get(key).unwrap();
            assert_eq!(value.as_ref(), Some(expected));
        }

        test_iterator_with(&items, tree.iter(), &tests);
    }
}
440}