oasis_runtime_sdk/modules/rewards/types.rs

1//! Rewards module types.
2use std::collections::BTreeMap;
3
4use thiserror::Error;
5
6use crate::{
7    core::consensus::beacon,
8    types::{address::Address, token},
9};
10
11/// One of the time periods in the reward schedule.
12#[derive(Clone, Debug, Default, cbor::Encode, cbor::Decode)]
13pub struct RewardStep {
14    pub until: beacon::EpochTime,
15    pub amount: token::BaseUnits,
16}
17
18/// A reward schedule.
19#[derive(Clone, Debug, Default, cbor::Encode, cbor::Decode)]
20pub struct RewardSchedule {
21    pub steps: Vec<RewardStep>,
22}
23
24/// Errors emitted during reward schedule validation.
25#[derive(Error, Debug)]
26pub enum RewardScheduleError {
27    #[error("steps not sorted correctly")]
28    StepsNotSorted,
29}
30
31impl RewardSchedule {
32    /// Perform basic reward schedule validation.
33    pub fn validate_basic(&self) -> Result<(), RewardScheduleError> {
34        let mut last_epoch = Default::default();
35        for step in &self.steps {
36            if step.until <= last_epoch {
37                return Err(RewardScheduleError::StepsNotSorted);
38            }
39            last_epoch = step.until;
40        }
41        Ok(())
42    }
43
44    /// Compute the per-entity reward amount for the given epoch based on the schedule.
45    pub fn for_epoch(&self, epoch: beacon::EpochTime) -> token::BaseUnits {
46        for step in &self.steps {
47            if epoch < step.until {
48                return step.amount.clone();
49            }
50        }
51
52        // End of the schedule, default to no rewards.
53        Default::default()
54    }
55}
56
/// Action that should be taken for a given address when disbursing rewards.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RewardAction {
    /// The address may be rewarded; holds the accumulated reward counter
    /// (see [`RewardAction::increment`]).
    Reward(u64),
    /// The address must not be rewarded; the counter no longer accumulates.
    NoReward,
}

impl RewardAction {
    /// Increment the reward counter associated with the reward.
    ///
    /// The counter saturates at `u64::MAX` instead of overflowing. In case the action is
    /// `NoReward` nothing is changed.
    pub fn increment(&mut self) {
        match self {
            // A plain `+= 1` would panic in debug builds and wrap to zero in release builds
            // (without overflow checks), silently wiping out the accumulated counter.
            RewardAction::Reward(v) => *v = v.saturating_add(1),
            RewardAction::NoReward => {
                // Do not change state as the entity has been penalized for the epoch.
            }
        }
    }

    /// Forbids any rewards from accumulating.
    ///
    /// After this call the action is `NoReward` and further increments have no effect.
    pub fn forbid(&mut self) {
        *self = RewardAction::NoReward;
    }

    /// Value of the reward counter (`0` for `NoReward`).
    pub fn value(&self) -> u64 {
        match self {
            RewardAction::Reward(v) => *v,
            RewardAction::NoReward => 0,
        }
    }
}

impl Default for RewardAction {
    /// An action that is eligible for rewards with a zero counter.
    fn default() -> Self {
        RewardAction::Reward(0)
    }
}
96
97impl cbor::Encode for RewardAction {
98    fn into_cbor_value(self) -> cbor::Value {
99        match self {
100            Self::Reward(r) => cbor::Value::Unsigned(r),
101            Self::NoReward => cbor::Value::Simple(cbor::SimpleValue::NullValue),
102        }
103    }
104}
105
106impl cbor::Decode for RewardAction {
107    fn try_default() -> Result<Self, cbor::DecodeError> {
108        Ok(Self::NoReward)
109    }
110
111    fn try_from_cbor_value(value: cbor::Value) -> Result<Self, cbor::DecodeError> {
112        match value {
113            cbor::Value::Unsigned(v) => Ok(Self::Reward(v)),
114            cbor::Value::Simple(cbor::SimpleValue::NullValue) => Ok(Self::NoReward),
115            _ => Err(cbor::DecodeError::UnexpectedType),
116        }
117    }
118}
119
120/// Rewards for the epoch.
121#[derive(Clone, Debug, Default, cbor::Encode, cbor::Decode)]
122pub struct EpochRewards {
123    pub pending: BTreeMap<Address, RewardAction>,
124}
125
126impl EpochRewards {
127    /// Returns an iterator over addresses that should be rewarded.
128    pub fn for_disbursement(
129        &self,
130        threshold_numerator: u64,
131        threshold_denominator: u64,
132    ) -> impl Iterator<Item = Address> + '_ {
133        let max_v = self
134            .pending
135            .iter()
136            .fold(0, |acc, (_, action)| std::cmp::max(acc, action.value()));
137
138        let (_, overflow) = threshold_numerator.overflowing_mul(max_v);
139        let threshold = if overflow {
140            max_v
141                .checked_div(threshold_denominator)
142                .unwrap_or(0)
143                .saturating_mul(threshold_numerator)
144        } else {
145            threshold_numerator
146                .saturating_mul(max_v)
147                .checked_div(threshold_denominator)
148                .unwrap_or(0)
149        };
150
151        self.pending
152            .iter()
153            .filter_map(move |(address, action)| match action {
154                RewardAction::Reward(v) => {
155                    if *v < threshold {
156                        None
157                    } else {
158                        Some(*address)
159                    }
160                }
161                RewardAction::NoReward => None,
162            })
163    }
164}
165
#[cfg(test)]
mod test {
    use crate::testing::keys;

    use super::*;

    #[test]
    fn test_reward_action() {
        let mut act = RewardAction::default();
        act.increment();
        act.increment();
        act.increment();

        assert_eq!(act, RewardAction::Reward(3));

        act.forbid();

        act.increment();
        act.increment();

        // Once forbidden, increments must not resurrect the reward.
        assert_eq!(act, RewardAction::NoReward);
    }

    #[test]
    fn test_reward_action_serialization() {
        let actions = vec![
            RewardAction::Reward(0),
            RewardAction::Reward(42),
            RewardAction::NoReward,
        ];
        for act in actions {
            let encoded = &cbor::to_vec(act.clone());
            let round_trip: RewardAction =
                cbor::from_slice(encoded).expect("round-trip should succeed");
            assert_eq!(round_trip, act, "reward actions should round-trip");
        }
    }

    #[test]
    fn test_reward_schedule_validation_fail_1() {
        let schedule = RewardSchedule {
            steps: vec![
                RewardStep {
                    until: 10,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 10,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 15,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
            ],
        };
        schedule
            .validate_basic()
            .expect_err("validation with duplicate steps should fail");
    }

    #[test]
    fn test_reward_schedule_validation_fail_2() {
        let schedule = RewardSchedule {
            steps: vec![
                RewardStep {
                    until: 10,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 5,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 15,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
            ],
        };
        schedule
            .validate_basic()
            .expect_err("validation with unsorted steps should fail");
    }

    #[test]
    fn test_reward_schedule_validation_ok() {
        let schedule = RewardSchedule {
            steps: vec![
                RewardStep {
                    until: 5,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 10,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 15,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
            ],
        };
        schedule
            .validate_basic()
            .expect("validation of correct schedule should not fail");
    }

    #[test]
    fn test_reward_schedule() {
        let schedule = RewardSchedule {
            steps: vec![
                RewardStep {
                    until: 5,
                    amount: token::BaseUnits::new(3000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 10,
                    amount: token::BaseUnits::new(2000, token::Denomination::NATIVE),
                },
                RewardStep {
                    until: 15,
                    amount: token::BaseUnits::new(1000, token::Denomination::NATIVE),
                },
            ],
        };

        assert_eq!(schedule.for_epoch(1).amount(), 3000);
        assert_eq!(schedule.for_epoch(3).amount(), 3000);
        assert_eq!(schedule.for_epoch(5).amount(), 2000);
        assert_eq!(schedule.for_epoch(6).amount(), 2000);
        assert_eq!(schedule.for_epoch(9).amount(), 2000);
        assert_eq!(schedule.for_epoch(10).amount(), 1000);
        assert_eq!(schedule.for_epoch(14).amount(), 1000);
        assert_eq!(schedule.for_epoch(15).amount(), 0);
        assert_eq!(schedule.for_epoch(20).amount(), 0);
        assert_eq!(schedule.for_epoch(100).amount(), 0);
    }

    #[test]
    fn test_epoch_rewards() {
        let epoch_rewards = EpochRewards {
            pending: {
                let mut pending = BTreeMap::new();
                pending.insert(keys::alice::address(), RewardAction::Reward(10));
                pending.insert(keys::bob::address(), RewardAction::NoReward);
                pending.insert(keys::charlie::address(), RewardAction::Reward(5));
                pending
            },
        };

        // Alice and Charlie have >= 0 (zero denominator yields a zero threshold).
        let rewards: Vec<_> = epoch_rewards.for_disbursement(0, 0).collect();
        assert_eq!(
            rewards,
            vec![keys::charlie::address(), keys::alice::address()]
        );
        // Alice and Charlie have >= 0 (zero numerator with a non-zero denominator).
        let rewards: Vec<_> = epoch_rewards.for_disbursement(0, 100).collect();
        assert_eq!(
            rewards,
            vec![keys::charlie::address(), keys::alice::address()]
        );
        // Only Alice has >= 7.5.
        let rewards: Vec<_> = epoch_rewards.for_disbursement(3, 4).collect();
        assert_eq!(rewards, vec![keys::alice::address()]);
    }

    #[test]
    fn test_epoch_rewards_overflow() {
        let epoch_rewards = EpochRewards {
            pending: {
                let mut pending = BTreeMap::new();
                pending.insert(keys::alice::address(), RewardAction::Reward(u64::MAX));
                pending.insert(keys::charlie::address(), RewardAction::Reward(u64::MAX / 2));
                pending
            },
        };

        // Alice and Charlie have >= 0.
        let rewards: Vec<_> = epoch_rewards.for_disbursement(0, 0).collect();
        assert_eq!(
            rewards,
            vec![keys::charlie::address(), keys::alice::address()]
        );
        // Alice and Charlie have >= 1/2.
        let rewards: Vec<_> = epoch_rewards.for_disbursement(1, 2).collect();
        assert_eq!(
            rewards,
            vec![keys::charlie::address(), keys::alice::address()]
        );
        // Only Alice has >= 3/4; the threshold product (3 * u64::MAX) overflows 64 bits and
        // must not cause Charlie to be counted.
        let rewards: Vec<_> = epoch_rewards.for_disbursement(3, 4).collect();
        assert_eq!(rewards, vec![keys::alice::address()]);
    }
}