use std::collections::BTreeMap;
use crate::types::{address::Address, token};
/// Tracks fees charged while executing transactions and accumulates them per block.
#[derive(Clone, Debug, Default)]
pub struct FeeManager {
    /// Fee state of the transaction currently being executed; `None` until a fee is recorded.
    tx_fee: Option<TransactionFee>,
    /// Net fees of all committed transactions in this block, keyed by denomination.
    block_fees: BTreeMap<token::Denomination, u128>,
}
/// Fee state recorded for a single transaction.
#[derive(Clone, Debug, Default)]
pub struct TransactionFee {
    /// Address that pays the fee.
    payer: Address,
    /// Denomination shared by every charge and refund of this transaction.
    denomination: token::Denomination,
    /// Total amount charged so far.
    charged: u128,
    /// Total amount refunded so far.
    refunded: u128,
}
impl TransactionFee {
    /// Returns (a copy of) the denomination the fee is charged in.
    pub fn denomination(&self) -> token::Denomination {
        token::Denomination::clone(&self.denomination)
    }

    /// Returns the effective fee: the amount charged minus refunds, floored at zero.
    pub fn amount(&self) -> u128 {
        u128::saturating_sub(self.charged, self.refunded)
    }

    /// Returns the address that pays the fee.
    pub fn payer(&self) -> Address {
        let payer = self.payer;
        payer
    }
}
impl FeeManager {
    /// Creates an empty fee manager with no transaction fee and no block fees.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the fee state of the current transaction, or `None` if no fee was recorded yet.
    pub fn tx_fee(&self) -> Option<&TransactionFee> {
        self.tx_fee.as_ref()
    }

    /// Records that `payer` was charged `amount` for the current transaction.
    ///
    /// The first call in a transaction fixes the payer and denomination; subsequent calls add
    /// to the charged total.
    ///
    /// # Panics
    ///
    /// Panics if `payer` or the denomination of `amount` differ from the ones fixed by an
    /// earlier call in the same transaction, or if the charged total overflows `u128`.
    pub fn record_fee(&mut self, payer: Address, amount: &token::BaseUnits) {
        let tx_fee = self.tx_fee.get_or_insert_with(|| TransactionFee {
            payer,
            denomination: amount.denomination().clone(),
            ..Default::default()
        });

        assert!(payer == tx_fee.payer, "transaction fee payer cannot change");
        assert!(
            amount.denomination() == &tx_fee.denomination,
            "transaction fee denomination cannot change"
        );

        tx_fee.charged = tx_fee
            .charged
            .checked_add(amount.amount())
            .expect("should never overflow");
    }

    /// Records a refund of `amount` base units for the current transaction.
    ///
    /// A zero refund, or a refund recorded before any fee was charged, is ignored. The total
    /// refund is capped at the amount charged so far, so the effective fee never goes negative.
    pub fn record_refund(&mut self, amount: u128) {
        if amount == 0 {
            return;
        }
        if let Some(tx_fee) = self.tx_fee.as_mut() {
            tx_fee.refunded = tx_fee.refunded.saturating_add(amount).min(tx_fee.charged);
        }
    }

    /// Finalizes the current transaction and returns the updates the caller must apply.
    ///
    /// The transaction's effective fee (if non-zero) is added to the block total for its
    /// denomination, and the transaction state is reset. When no fee was recorded, the returned
    /// updates carry a default payer and a zero refund in the default denomination.
    ///
    /// # Panics
    ///
    /// Panics if the block total for the denomination overflows `u128`.
    #[must_use = "fee updates should be applied after calling commit"]
    pub fn commit_tx(&mut self) -> FeeUpdates {
        let tx_fee = self.tx_fee.take().unwrap_or_default();
        let amount = tx_fee.amount();
        if amount > 0 {
            let block_fee = self
                .block_fees
                .entry(tx_fee.denomination.clone())
                .or_default();
            *block_fee = block_fee
                .checked_add(amount)
                .expect("should never overflow");
        }

        FeeUpdates {
            payer: tx_fee.payer,
            refund: token::BaseUnits::new(tx_fee.refunded, tx_fee.denomination),
        }
    }

    /// Consumes the manager and returns the per-denomination fees of all committed transactions.
    ///
    /// Any fee recorded for a transaction that was not committed via [`FeeManager::commit_tx`]
    /// is discarded.
    #[must_use = "accumulated fees should be applied after calling commit"]
    pub fn commit_block(self) -> BTreeMap<token::Denomination, u128> {
        self.block_fees
    }
}
/// Updates produced by committing a transaction that the caller must apply.
///
/// If no fee was recorded for the transaction, `payer` is the default address and `refund`
/// is zero in the default denomination.
#[derive(Debug)]
pub struct FeeUpdates {
    /// Address that paid the transaction fee.
    pub payer: Address,
    /// Amount to refund to `payer`, in the transaction fee's denomination.
    pub refund: token::BaseUnits,
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        testing::keys,
        types::token::{self, Denomination},
    };

    /// Parses the non-native `TEST` denomination used by the multi-denomination checks.
    fn test_denomination() -> Denomination {
        "TEST".parse().unwrap()
    }

    /// Builds `amount` base units in the native denomination.
    fn native(amount: u128) -> token::BaseUnits {
        token::BaseUnits::new(amount, Denomination::NATIVE)
    }

    #[test]
    fn test_basic_refund() {
        let mut mgr = FeeManager::new();
        assert!(mgr.tx_fee().is_none());

        // Alice is charged a fee and later gets part of it refunded.
        let fee = native(1_000_000);
        mgr.record_fee(keys::alice::address(), &fee);

        let current = mgr.tx_fee().expect("tx_fee should be set");
        assert_eq!(current.payer(), keys::alice::address());
        assert_eq!(&current.denomination(), fee.denomination());
        assert_eq!(current.amount(), fee.amount());

        mgr.record_refund(400_000);

        let current = mgr.tx_fee().expect("tx_fee should be set");
        assert_eq!(current.payer(), keys::alice::address());
        assert_eq!(&current.denomination(), fee.denomination());
        assert_eq!(current.amount(), 600_000, "should take refund into account");

        let updates = mgr.commit_tx();
        assert_eq!(updates.payer, keys::alice::address());
        assert_eq!(updates.refund, native(400_000));
        assert!(mgr.tx_fee().is_none());

        // Bob is charged a fee with no refund.
        mgr.record_fee(keys::bob::address(), &native(50_000));
        let updates = mgr.commit_tx();
        assert_eq!(updates.payer, keys::bob::address());
        assert_eq!(updates.refund, native(0));

        // Dave is charged twice in a non-native denomination within one transaction.
        for &amount in &[25_000, 5_000] {
            mgr.record_fee(
                keys::dave::address(),
                &token::BaseUnits::new(amount, test_denomination()),
            );
        }
        let updates = mgr.commit_tx();
        assert_eq!(updates.payer, keys::dave::address());
        assert_eq!(
            updates.refund,
            token::BaseUnits::new(0, test_denomination())
        );

        // Block totals: 600_000 + 50_000 native, 25_000 + 5_000 TEST.
        let block_fees = mgr.commit_block();
        assert_eq!(block_fees.len(), 2);
        assert_eq!(block_fees[&Denomination::NATIVE], 650_000);
        assert_eq!(block_fees[&test_denomination()], 30_000);
    }

    #[test]
    fn test_refund_without_charge() {
        let mut mgr = FeeManager::new();

        // A refund with nothing charged must be ignored.
        mgr.record_refund(1_000);
        assert!(
            mgr.tx_fee().is_none(),
            "refund should not be recorded if no charge"
        );

        let updates = mgr.commit_tx();
        assert_eq!(updates.payer, Default::default());
        assert_eq!(updates.refund, token::BaseUnits::new(0, Default::default()));

        let block_fees = mgr.commit_block();
        assert!(block_fees.is_empty(), "there should be no recorded fees");
    }

    #[test]
    #[should_panic(expected = "transaction fee payer cannot change")]
    fn test_fail_payer_change() {
        let fee = native(1_000_000);
        let mut mgr = FeeManager::new();
        mgr.record_fee(keys::alice::address(), &fee);
        // A different payer within the same transaction must panic.
        mgr.record_fee(keys::bob::address(), &fee);
    }

    #[test]
    #[should_panic(expected = "transaction fee denomination cannot change")]
    fn test_fail_denomination_change() {
        let mut mgr = FeeManager::new();
        mgr.record_fee(keys::alice::address(), &native(1_000_000));
        // Same payer but a different denomination within the same transaction must panic.
        mgr.record_fee(
            keys::alice::address(),
            &token::BaseUnits::new(1_000_000, test_denomination()),
        );
    }
}