oasis_runtime_sdk_contracts/abi/
gas.rs

1//! Gas metering instrumentation.
2use std::collections::BTreeMap;
3
4use walrus::{ir::*, FunctionBuilder, GlobalId, LocalFunction, Module};
5
6use crate::Error;
7
/// Name of the exported global that holds the remaining gas.
///
/// The host initializes it with the gas limit (see `set_gas_limit`) and the injected metering code
/// decrements it as instructions are executed.
pub const EXPORT_GAS_LIMIT: &str = "gas_limit";
/// Name of the exported global that, once gas runs out, holds the amount of gas that was requested
/// but could not be paid for. It stays zero while execution remains within the limit.
pub const EXPORT_GAS_LIMIT_EXHAUSTED: &str = "gas_limit_exhausted";
12
13/// Configures the gas limit on the given instance.
14pub fn set_gas_limit<C>(
15    instance: &wasm3::Instance<'_, '_, C>,
16    gas_limit: u64,
17) -> Result<(), Error> {
18    instance
19        .set_global(EXPORT_GAS_LIMIT, gas_limit)
20        .map_err(|err| Error::ExecutionFailed(err.into()))
21}
22
23/// Returns the remaining gas.
24pub fn get_remaining_gas<C>(instance: &wasm3::Instance<'_, '_, C>) -> u64 {
25    instance.get_global(EXPORT_GAS_LIMIT).unwrap_or_default()
26}
27
28/// Returns the amount of gas requested that was over the limit.
29pub fn get_exhausted_amount<C>(instance: &wasm3::Instance<'_, '_, C>) -> u64 {
30    instance
31        .get_global(EXPORT_GAS_LIMIT_EXHAUSTED)
32        .unwrap_or_default()
33}
34
35/// Attempts to use the given amount of gas.
36pub fn use_gas<C>(instance: &wasm3::Instance<'_, '_, C>, amount: u64) -> Result<(), wasm3::Trap> {
37    let gas_limit: u64 = instance
38        .get_global(EXPORT_GAS_LIMIT)
39        .map_err(|_| wasm3::Trap::Abort)?;
40    if gas_limit < amount {
41        let _ = instance.set_global(EXPORT_GAS_LIMIT_EXHAUSTED, amount);
42        return Err(wasm3::Trap::Abort);
43    }
44    instance
45        .set_global(EXPORT_GAS_LIMIT, gas_limit - amount)
46        .map_err(|_| wasm3::Trap::Abort)?;
47    Ok(())
48}
49
50/// Inject gas metering instrumentation into the module.
51pub fn transform(module: &mut Module) {
52    let gas_limit_global = module.globals.add_local(
53        walrus::ValType::I64,
54        true,
55        walrus::InitExpr::Value(Value::I64(0)),
56    );
57    let gas_limit_exhausted_global = module.globals.add_local(
58        walrus::ValType::I64,
59        true,
60        walrus::InitExpr::Value(Value::I64(0)),
61    );
62    module.exports.add(EXPORT_GAS_LIMIT, gas_limit_global);
63    module
64        .exports
65        .add(EXPORT_GAS_LIMIT_EXHAUSTED, gas_limit_exhausted_global);
66
67    for (_, func) in module.funcs.iter_local_mut() {
68        transform_function(func, gas_limit_global, gas_limit_exhausted_global);
69    }
70}
71
72/// Instruction cost function.
73fn instruction_cost(instr: &Instr) -> u64 {
74    match instr {
75        Instr::Loop(_) | Instr::Block(_) => 2,
76
77        Instr::LocalGet(_)
78        | Instr::LocalSet(_)
79        | Instr::LocalTee(_)
80        | Instr::GlobalGet(_)
81        | Instr::GlobalSet(_)
82        | Instr::Const(_) => 1,
83
84        Instr::Call(_) => 15,
85
86        Instr::CallIndirect(_) => 20,
87
88        Instr::Br(_) | Instr::BrIf(_) => 3,
89        Instr::BrTable(_) => 4,
90
91        Instr::Binop(op) => match op.op {
92            BinaryOp::I32Eq
93            | BinaryOp::I32Ne
94            | BinaryOp::I32LtS
95            | BinaryOp::I32LtU
96            | BinaryOp::I32GtS
97            | BinaryOp::I32GtU
98            | BinaryOp::I32LeS
99            | BinaryOp::I32LeU
100            | BinaryOp::I32GeS
101            | BinaryOp::I32GeU
102            | BinaryOp::I64Eq
103            | BinaryOp::I64Ne
104            | BinaryOp::I64LtS
105            | BinaryOp::I64LtU
106            | BinaryOp::I64GtS
107            | BinaryOp::I64GtU
108            | BinaryOp::I64LeS
109            | BinaryOp::I64LeU
110            | BinaryOp::I64GeS
111            | BinaryOp::I64GeU
112            | BinaryOp::I32Add
113            | BinaryOp::I32Sub
114            | BinaryOp::I32Mul
115            | BinaryOp::I32And
116            | BinaryOp::I32Or
117            | BinaryOp::I32Xor
118            | BinaryOp::I32Shl
119            | BinaryOp::I32ShrS
120            | BinaryOp::I32ShrU
121            | BinaryOp::I32Rotl
122            | BinaryOp::I32Rotr
123            | BinaryOp::I64Add
124            | BinaryOp::I64Sub
125            | BinaryOp::I64Mul
126            | BinaryOp::I64And
127            | BinaryOp::I64Or
128            | BinaryOp::I64Xor
129            | BinaryOp::I64Shl
130            | BinaryOp::I64ShrS
131            | BinaryOp::I64ShrU
132            | BinaryOp::I64Rotl
133            | BinaryOp::I64Rotr => 1,
134            BinaryOp::I32DivS
135            | BinaryOp::I32DivU
136            | BinaryOp::I32RemS
137            | BinaryOp::I32RemU
138            | BinaryOp::I64DivS
139            | BinaryOp::I64DivU
140            | BinaryOp::I64RemS
141            | BinaryOp::I64RemU => 4,
142            _ => 3,
143        },
144
145        Instr::Unop(op) => match op.op {
146            UnaryOp::I32Eqz
147            | UnaryOp::I32Clz
148            | UnaryOp::I32Ctz
149            | UnaryOp::I32Popcnt
150            | UnaryOp::I64Eqz
151            | UnaryOp::I64Clz
152            | UnaryOp::I64Ctz
153            | UnaryOp::I64Popcnt => 1,
154
155            _ => 3,
156        },
157
158        _ => 10,
159    }
160}
161
162/// A block of instructions which is metered.
163#[derive(Debug)]
164struct MeteredBlock {
165    /// Instruction sequence where metering code should be injected.
166    seq_id: InstrSeqId,
167    /// Start index of instruction within the instruction sequence before which the metering code
168    /// should be injected.
169    start_index: usize,
170    /// Instruction cost.
171    cost: u64,
172    /// Indication of whether the metered block can be merged in case instruction sequence and start
173    /// index match. In case the block cannot be merged this contains the index
174    merge_index: Option<usize>,
175}
176
177impl MeteredBlock {
178    fn new(seq_id: InstrSeqId, start_index: usize) -> Self {
179        Self {
180            seq_id,
181            start_index,
182            cost: 0,
183            merge_index: None,
184        }
185    }
186
187    /// Create a mergable version of this metered block with the given start index.
188    fn mergable(&self, start_index: usize) -> Self {
189        Self {
190            seq_id: self.seq_id,
191            start_index,
192            cost: 0,
193            merge_index: Some(self.start_index),
194        }
195    }
196}
197
198/// A map of finalized metered blocks.
199#[derive(Default)]
200struct MeteredBlocks {
201    blocks: BTreeMap<InstrSeqId, Vec<MeteredBlock>>,
202}
203
204impl MeteredBlocks {
205    /// Finalize the given metered block. This means that the cost associated with the block cannot
206    /// change anymore.
207    fn finalize(&mut self, block: MeteredBlock) {
208        if block.cost > 0 {
209            self.blocks.entry(block.seq_id).or_default().push(block);
210        }
211    }
212}
213
214fn determine_metered_blocks(func: &LocalFunction) -> BTreeMap<InstrSeqId, Vec<MeteredBlock>> {
215    // NOTE: This is based on walrus::ir::dfs_in_order but we need more information.
216
217    let mut blocks = MeteredBlocks::default();
218    let mut stack: Vec<(InstrSeqId, usize, MeteredBlock)> = vec![(
219        func.entry_block(),                       // Initial instruction sequence to visit.
220        0,                                        // Instruction offset within the sequence.
221        MeteredBlock::new(func.entry_block(), 0), // Initial metered block.
222    )];
223
224    'traversing_blocks: while let Some((seq_id, index, mut metered_block)) = stack.pop() {
225        let seq = func.block(seq_id);
226
227        'traversing_instrs: for (index, (instr, _)) in seq.instrs.iter().enumerate().skip(index) {
228            // NOTE: Current instruction is always included in the current metered block.
229            metered_block.cost += instruction_cost(instr);
230
231            // Determine whether we need to end/start a metered block.
232            match instr {
233                Instr::Block(Block { seq }) => {
234                    // Do not start a new metered block as blocks are unconditional and metered
235                    // blocks can encompass many of them to avoid injecting unnecessary
236                    // instructions.
237                    stack.push((seq_id, index + 1, metered_block.mergable(index + 1)));
238                    stack.push((*seq, 0, metered_block));
239                    continue 'traversing_blocks;
240                }
241
242                Instr::Loop(Loop { seq }) => {
243                    // Finalize current metered block.
244                    blocks.finalize(metered_block);
245                    // Start a new metered block for remainder of block.
246                    stack.push((seq_id, index + 1, MeteredBlock::new(seq_id, index + 1)));
247                    // Start a new metered block for loop body.
248                    stack.push((*seq, 0, MeteredBlock::new(*seq, 0)));
249                    continue 'traversing_blocks;
250                }
251
252                Instr::IfElse(IfElse {
253                    consequent,
254                    alternative,
255                }) => {
256                    // Finalize current metered block.
257                    blocks.finalize(metered_block);
258
259                    // Start a new metered block for remainder of block.
260                    stack.push((seq_id, index + 1, MeteredBlock::new(seq_id, index + 1)));
261                    // Start new metered blocks for alternative and consequent blocks.
262                    stack.push((*alternative, 0, MeteredBlock::new(*alternative, 0)));
263                    stack.push((*consequent, 0, MeteredBlock::new(*consequent, 0)));
264                    continue 'traversing_blocks;
265                }
266
267                Instr::Call(_)
268                | Instr::CallIndirect(_)
269                | Instr::Br(_)
270                | Instr::BrIf(_)
271                | Instr::BrTable(_)
272                | Instr::Return(_) => {
273                    // Finalize current metered block and start a new one for the remainder.
274                    blocks.finalize(std::mem::replace(
275                        &mut metered_block,
276                        MeteredBlock::new(seq_id, index + 1),
277                    ));
278                    continue 'traversing_instrs;
279                }
280
281                _ => continue 'traversing_instrs,
282            }
283        }
284
285        // Check if we can merge the blocks.
286        if let Some((_, _, upper)) = stack.last_mut() {
287            match upper.merge_index {
288                Some(index)
289                    if upper.seq_id == metered_block.seq_id
290                        && index == metered_block.start_index =>
291                {
292                    // Blocks can be merged, so overwrite upper.
293                    *upper = metered_block;
294                    continue 'traversing_blocks;
295                }
296                _ => {
297                    // Blocks cannot be merged so treat as new block.
298                }
299            }
300        }
301
302        blocks.finalize(metered_block);
303    }
304
305    blocks.blocks
306}
307
308fn transform_function(
309    func: &mut LocalFunction,
310    gas_limit_global: GlobalId,
311    gas_limit_exhausted_global: GlobalId,
312) {
313    // First pass: determine where metering instructions should be injected.
314    let blocks = determine_metered_blocks(func);
315
316    // Second pass: actually emit metering instructions in correct positions.
317    let builder = func.builder_mut();
318    for (seq_id, blocks) in blocks {
319        let mut seq = builder.instr_seq(seq_id);
320        let instrs = seq.instrs_mut();
321
322        let original_instrs = std::mem::take(instrs);
323        let new_instrs_len = instrs.len() + METERING_INSTRUCTION_COUNT * blocks.len();
324        let mut new_instrs = Vec::with_capacity(new_instrs_len);
325
326        let mut block_iter = blocks.into_iter().peekable();
327        for (index, (instr, loc)) in original_instrs.into_iter().enumerate() {
328            match block_iter.peek() {
329                Some(block) if block.start_index == index => {
330                    inject_metering(
331                        builder,
332                        &mut new_instrs,
333                        block_iter.next().unwrap(),
334                        gas_limit_global,
335                        gas_limit_exhausted_global,
336                    );
337                }
338                _ => {}
339            }
340
341            // Push original instruction.
342            new_instrs.push((instr, loc));
343        }
344
345        let mut seq = builder.instr_seq(seq_id);
346        let instrs = seq.instrs_mut();
347        *instrs = new_instrs;
348    }
349}
350
351/// Number of injected metering instructions (needed to calculate final instruction size).
352const METERING_INSTRUCTION_COUNT: usize = 8;
353
354fn inject_metering(
355    builder: &mut FunctionBuilder,
356    instrs: &mut Vec<(Instr, InstrLocId)>,
357    block: MeteredBlock,
358    gas_limit_global: GlobalId,
359    gas_limit_exhausted_global: GlobalId,
360) {
361    let mut builder = builder.dangling_instr_seq(None);
362    let seq = builder
363        // if unsigned(globals[gas_limit]) < unsigned(block.cost) { throw(); }
364        .global_get(gas_limit_global)
365        .i64_const(block.cost as i64)
366        .binop(BinaryOp::I64LtU)
367        .if_else(
368            None,
369            |then| {
370                then.i64_const(block.cost as i64)
371                    .global_set(gas_limit_exhausted_global)
372                    .unreachable();
373            },
374            |_else| {},
375        )
376        // globals[gas_limit] -= block.cost;
377        .global_get(gas_limit_global)
378        .i64_const(block.cost as i64)
379        .binop(BinaryOp::I64Sub)
380        .global_set(gas_limit_global);
381
382    instrs.append(seq.instrs_mut());
383}
384
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    /// Parse WAT text into a walrus module without a producers section.
    fn parse_module(wat_src: &str) -> walrus::Module {
        let wasm = wat::parse_str(wat_src).unwrap();
        walrus::ModuleConfig::new()
            .generate_producers_section(false)
            .parse(&wasm)
            .unwrap()
    }

    /// Emit a walrus module and print it back as canonical WAT text.
    fn print_module(module: &mut walrus::Module) -> String {
        let wasm = module.emit_wasm();
        wasmprinter::print_bytes(&wasm).unwrap()
    }

    /// Assert that instrumenting `src` produces the same module as `expected`.
    fn assert_transform(src: &str, expected: &str) {
        let mut result_module = parse_module(src);
        super::transform(&mut result_module);

        let mut expected_module = parse_module(expected);

        assert_eq!(
            print_module(&mut result_module),
            print_module(&mut expected_module)
        );
    }

    macro_rules! test_transform {
        (name = $name:ident, source = $src:expr, expected = $expected:expr) => {
            #[test]
            fn $name() {
                assert_transform($src, $expected);
            }
        };
    }

    test_transform! {
        name = simple,
        source = r#"
        (module
            (func (result i32)
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    test_transform! {
        name = nested_blocks,
        source = r#"
        (module
            (func (result i32)
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 18))
                    (then
                        (global.set 1
                            (i64.const 18))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 18)))
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    test_transform! {
        name = nested_blocks_with_loop,
        source = r#"
        (module
            (func (result i32)
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))
                        (loop
                            (i32.const 1)
                            (drop)
                            (i32.const 1)
                            (drop)
                            (br 0))))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 19))
                    (then
                        (global.set 1
                            (i64.const 19))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 19)))
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))
                        (loop
                            (if
                                (i64.lt_u
                                    (global.get 0)
                                    (i64.const 25))
                                (then
                                    (global.set 1
                                        (i64.const 25))
                                    (unreachable)))
                            (global.set 0
                                (i64.sub
                                    (global.get 0)
                                    (i64.const 25)))
                            (i32.const 1)
                            (drop)
                            (i32.const 1)
                            (drop)
                            (br 0))))
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    test_transform! {
        name = if_else,
        source = r#"
        (module
            (func (result i32)
                (i32.const 1)
                (if
                    (then
                        (i32.const 1)
                        (drop)
                        (i32.const 1)
                        (drop))
                    (else
                        (i32.const 1)
                        (drop)))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 11))
                    (then
                        (global.set 1
                            (i64.const 11))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 11)))
                (i32.const 1)
                (if
                    (then
                        (if
                            (i64.lt_u
                                (global.get 0)
                                (i64.const 22))
                            (then
                                (global.set 1
                                    (i64.const 22))
                                (unreachable)))
                        (global.set 0
                            (i64.sub
                                (global.get 0)
                                (i64.const 22)))
                        (i32.const 1)
                        (drop)
                        (i32.const 1)
                        (drop)
                    )
                    (else
                        (if
                            (i64.lt_u
                                (global.get 0)
                                (i64.const 11))
                            (then
                                (global.set 1
                                    (i64.const 11))
                                (unreachable)))
                        (global.set 0
                            (i64.sub
                                (global.get 0)
                                (i64.const 11)))
                        (i32.const 1)
                        (drop)
                    )
                )
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }
}