oasis_runtime_sdk_contracts/abi/
gas.rs1use std::collections::BTreeMap;
3
4use walrus::{ir::*, FunctionBuilder, GlobalId, LocalFunction, Module};
5
6use crate::Error;
7
/// Export name of the mutable `i64` global that holds the gas still available to the instance.
pub const EXPORT_GAS_LIMIT: &'static str = "gas_limit";
/// Export name of the mutable `i64` global that records the gas amount requested by the metering
/// check that failed because not enough gas was left.
pub const EXPORT_GAS_LIMIT_EXHAUSTED: &'static str = "gas_limit_exhausted";
12
13pub fn set_gas_limit<C>(
15 instance: &wasm3::Instance<'_, '_, C>,
16 gas_limit: u64,
17) -> Result<(), Error> {
18 instance
19 .set_global(EXPORT_GAS_LIMIT, gas_limit)
20 .map_err(|err| Error::ExecutionFailed(err.into()))
21}
22
23pub fn get_remaining_gas<C>(instance: &wasm3::Instance<'_, '_, C>) -> u64 {
25 instance.get_global(EXPORT_GAS_LIMIT).unwrap_or_default()
26}
27
28pub fn get_exhausted_amount<C>(instance: &wasm3::Instance<'_, '_, C>) -> u64 {
30 instance
31 .get_global(EXPORT_GAS_LIMIT_EXHAUSTED)
32 .unwrap_or_default()
33}
34
35pub fn use_gas<C>(instance: &wasm3::Instance<'_, '_, C>, amount: u64) -> Result<(), wasm3::Trap> {
37 let gas_limit: u64 = instance
38 .get_global(EXPORT_GAS_LIMIT)
39 .map_err(|_| wasm3::Trap::Abort)?;
40 if gas_limit < amount {
41 let _ = instance.set_global(EXPORT_GAS_LIMIT_EXHAUSTED, amount);
42 return Err(wasm3::Trap::Abort);
43 }
44 instance
45 .set_global(EXPORT_GAS_LIMIT, gas_limit - amount)
46 .map_err(|_| wasm3::Trap::Abort)?;
47 Ok(())
48}
49
50pub fn transform(module: &mut Module) {
52 let gas_limit_global = module.globals.add_local(
53 walrus::ValType::I64,
54 true,
55 walrus::InitExpr::Value(Value::I64(0)),
56 );
57 let gas_limit_exhausted_global = module.globals.add_local(
58 walrus::ValType::I64,
59 true,
60 walrus::InitExpr::Value(Value::I64(0)),
61 );
62 module.exports.add(EXPORT_GAS_LIMIT, gas_limit_global);
63 module
64 .exports
65 .add(EXPORT_GAS_LIMIT_EXHAUSTED, gas_limit_exhausted_global);
66
67 for (_, func) in module.funcs.iter_local_mut() {
68 transform_function(func, gas_limit_global, gas_limit_exhausted_global);
69 }
70}
71
72fn instruction_cost(instr: &Instr) -> u64 {
74 match instr {
75 Instr::Loop(_) | Instr::Block(_) => 2,
76
77 Instr::LocalGet(_)
78 | Instr::LocalSet(_)
79 | Instr::LocalTee(_)
80 | Instr::GlobalGet(_)
81 | Instr::GlobalSet(_)
82 | Instr::Const(_) => 1,
83
84 Instr::Call(_) => 15,
85
86 Instr::CallIndirect(_) => 20,
87
88 Instr::Br(_) | Instr::BrIf(_) => 3,
89 Instr::BrTable(_) => 4,
90
91 Instr::Binop(op) => match op.op {
92 BinaryOp::I32Eq
93 | BinaryOp::I32Ne
94 | BinaryOp::I32LtS
95 | BinaryOp::I32LtU
96 | BinaryOp::I32GtS
97 | BinaryOp::I32GtU
98 | BinaryOp::I32LeS
99 | BinaryOp::I32LeU
100 | BinaryOp::I32GeS
101 | BinaryOp::I32GeU
102 | BinaryOp::I64Eq
103 | BinaryOp::I64Ne
104 | BinaryOp::I64LtS
105 | BinaryOp::I64LtU
106 | BinaryOp::I64GtS
107 | BinaryOp::I64GtU
108 | BinaryOp::I64LeS
109 | BinaryOp::I64LeU
110 | BinaryOp::I64GeS
111 | BinaryOp::I64GeU
112 | BinaryOp::I32Add
113 | BinaryOp::I32Sub
114 | BinaryOp::I32Mul
115 | BinaryOp::I32And
116 | BinaryOp::I32Or
117 | BinaryOp::I32Xor
118 | BinaryOp::I32Shl
119 | BinaryOp::I32ShrS
120 | BinaryOp::I32ShrU
121 | BinaryOp::I32Rotl
122 | BinaryOp::I32Rotr
123 | BinaryOp::I64Add
124 | BinaryOp::I64Sub
125 | BinaryOp::I64Mul
126 | BinaryOp::I64And
127 | BinaryOp::I64Or
128 | BinaryOp::I64Xor
129 | BinaryOp::I64Shl
130 | BinaryOp::I64ShrS
131 | BinaryOp::I64ShrU
132 | BinaryOp::I64Rotl
133 | BinaryOp::I64Rotr => 1,
134 BinaryOp::I32DivS
135 | BinaryOp::I32DivU
136 | BinaryOp::I32RemS
137 | BinaryOp::I32RemU
138 | BinaryOp::I64DivS
139 | BinaryOp::I64DivU
140 | BinaryOp::I64RemS
141 | BinaryOp::I64RemU => 4,
142 _ => 3,
143 },
144
145 Instr::Unop(op) => match op.op {
146 UnaryOp::I32Eqz
147 | UnaryOp::I32Clz
148 | UnaryOp::I32Ctz
149 | UnaryOp::I32Popcnt
150 | UnaryOp::I64Eqz
151 | UnaryOp::I64Clz
152 | UnaryOp::I64Ctz
153 | UnaryOp::I64Popcnt => 1,
154
155 _ => 3,
156 },
157
158 _ => 10,
159 }
160}
161
162#[derive(Debug)]
164struct MeteredBlock {
165 seq_id: InstrSeqId,
167 start_index: usize,
170 cost: u64,
172 merge_index: Option<usize>,
175}
176
177impl MeteredBlock {
178 fn new(seq_id: InstrSeqId, start_index: usize) -> Self {
179 Self {
180 seq_id,
181 start_index,
182 cost: 0,
183 merge_index: None,
184 }
185 }
186
187 fn mergable(&self, start_index: usize) -> Self {
189 Self {
190 seq_id: self.seq_id,
191 start_index,
192 cost: 0,
193 merge_index: Some(self.start_index),
194 }
195 }
196}
197
198#[derive(Default)]
200struct MeteredBlocks {
201 blocks: BTreeMap<InstrSeqId, Vec<MeteredBlock>>,
202}
203
204impl MeteredBlocks {
205 fn finalize(&mut self, block: MeteredBlock) {
208 if block.cost > 0 {
209 self.blocks.entry(block.seq_id).or_default().push(block);
210 }
211 }
212}
213
214fn determine_metered_blocks(func: &LocalFunction) -> BTreeMap<InstrSeqId, Vec<MeteredBlock>> {
215 let mut blocks = MeteredBlocks::default();
218 let mut stack: Vec<(InstrSeqId, usize, MeteredBlock)> = vec![(
219 func.entry_block(), 0, MeteredBlock::new(func.entry_block(), 0), )];
223
224 'traversing_blocks: while let Some((seq_id, index, mut metered_block)) = stack.pop() {
225 let seq = func.block(seq_id);
226
227 'traversing_instrs: for (index, (instr, _)) in seq.instrs.iter().enumerate().skip(index) {
228 metered_block.cost += instruction_cost(instr);
230
231 match instr {
233 Instr::Block(Block { seq }) => {
234 stack.push((seq_id, index + 1, metered_block.mergable(index + 1)));
238 stack.push((*seq, 0, metered_block));
239 continue 'traversing_blocks;
240 }
241
242 Instr::Loop(Loop { seq }) => {
243 blocks.finalize(metered_block);
245 stack.push((seq_id, index + 1, MeteredBlock::new(seq_id, index + 1)));
247 stack.push((*seq, 0, MeteredBlock::new(*seq, 0)));
249 continue 'traversing_blocks;
250 }
251
252 Instr::IfElse(IfElse {
253 consequent,
254 alternative,
255 }) => {
256 blocks.finalize(metered_block);
258
259 stack.push((seq_id, index + 1, MeteredBlock::new(seq_id, index + 1)));
261 stack.push((*alternative, 0, MeteredBlock::new(*alternative, 0)));
263 stack.push((*consequent, 0, MeteredBlock::new(*consequent, 0)));
264 continue 'traversing_blocks;
265 }
266
267 Instr::Call(_)
268 | Instr::CallIndirect(_)
269 | Instr::Br(_)
270 | Instr::BrIf(_)
271 | Instr::BrTable(_)
272 | Instr::Return(_) => {
273 blocks.finalize(std::mem::replace(
275 &mut metered_block,
276 MeteredBlock::new(seq_id, index + 1),
277 ));
278 continue 'traversing_instrs;
279 }
280
281 _ => continue 'traversing_instrs,
282 }
283 }
284
285 if let Some((_, _, upper)) = stack.last_mut() {
287 match upper.merge_index {
288 Some(index)
289 if upper.seq_id == metered_block.seq_id
290 && index == metered_block.start_index =>
291 {
292 *upper = metered_block;
294 continue 'traversing_blocks;
295 }
296 _ => {
297 }
299 }
300 }
301
302 blocks.finalize(metered_block);
303 }
304
305 blocks.blocks
306}
307
308fn transform_function(
309 func: &mut LocalFunction,
310 gas_limit_global: GlobalId,
311 gas_limit_exhausted_global: GlobalId,
312) {
313 let blocks = determine_metered_blocks(func);
315
316 let builder = func.builder_mut();
318 for (seq_id, blocks) in blocks {
319 let mut seq = builder.instr_seq(seq_id);
320 let instrs = seq.instrs_mut();
321
322 let original_instrs = std::mem::take(instrs);
323 let new_instrs_len = instrs.len() + METERING_INSTRUCTION_COUNT * blocks.len();
324 let mut new_instrs = Vec::with_capacity(new_instrs_len);
325
326 let mut block_iter = blocks.into_iter().peekable();
327 for (index, (instr, loc)) in original_instrs.into_iter().enumerate() {
328 match block_iter.peek() {
329 Some(block) if block.start_index == index => {
330 inject_metering(
331 builder,
332 &mut new_instrs,
333 block_iter.next().unwrap(),
334 gas_limit_global,
335 gas_limit_exhausted_global,
336 );
337 }
338 _ => {}
339 }
340
341 new_instrs.push((instr, loc));
343 }
344
345 let mut seq = builder.instr_seq(seq_id);
346 let instrs = seq.instrs_mut();
347 *instrs = new_instrs;
348 }
349}
350
351const METERING_INSTRUCTION_COUNT: usize = 8;
353
354fn inject_metering(
355 builder: &mut FunctionBuilder,
356 instrs: &mut Vec<(Instr, InstrLocId)>,
357 block: MeteredBlock,
358 gas_limit_global: GlobalId,
359 gas_limit_exhausted_global: GlobalId,
360) {
361 let mut builder = builder.dangling_instr_seq(None);
362 let seq = builder
363 .global_get(gas_limit_global)
365 .i64_const(block.cost as i64)
366 .binop(BinaryOp::I64LtU)
367 .if_else(
368 None,
369 |then| {
370 then.i64_const(block.cost as i64)
371 .global_set(gas_limit_exhausted_global)
372 .unreachable();
373 },
374 |_else| {},
375 )
376 .global_get(gas_limit_global)
378 .i64_const(block.cost as i64)
379 .binop(BinaryOp::I64Sub)
380 .global_set(gas_limit_global);
381
382 instrs.append(seq.instrs_mut());
383}
384
// Tests for the gas metering transformation. Each case compares the disassembly of a transformed
// source module against a hand-written expected module.
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    // Defines a `#[test]` named `$name`. The test:
    // - parses the WAT in `$src` with walrus and runs `super::transform` on it;
    // - parses the WAT in `$expected` with walrus as-is;
    // - asserts that the `wasmprinter` disassembly of both emitted binaries is equal.
    // Both sides are round-tripped through walrus with the producers section disabled, so only
    // differences introduced by `transform` show up.
    macro_rules! test_transform {
        (name = $name:ident, source = $src:expr, expected = $expected:expr) => {
            #[test]
            fn $name() {
                let src = wat::parse_str($src).unwrap();
                let expected = wat::parse_str($expected).unwrap();

                let mut result_module = walrus::ModuleConfig::new()
                    .generate_producers_section(false)
                    .parse(&src)
                    .unwrap();

                super::transform(&mut result_module);

                let mut expected_module = walrus::ModuleConfig::new()
                    .generate_producers_section(false)
                    .parse(&expected)
                    .unwrap();

                let result_wasm = result_module.emit_wasm();
                let expected_wasm = expected_module.emit_wasm();
                let result = wasmprinter::print_bytes(&result_wasm).unwrap();
                let expected = wasmprinter::print_bytes(&expected_wasm).unwrap();

                assert_eq!(result, expected);
            }
        };
    }

    // A single `i32.const` (cost 1) is charged once at function entry. Globals 0/1 are the
    // gas-limit / gas-limit-exhausted globals added and exported by `transform`.
    test_transform! {
        name = simple,
        source = r#"
        (module
            (func (result i32)
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    // Nested blocks without control transfers are merged into one metered block charged at entry:
    // 3 x block (2) + i32.const (1) + drop (10) + i32.const (1) = 18.
    test_transform! {
        name = nested_blocks,
        source = r#"
        (module
            (func (result i32)
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 18))
                    (then
                        (global.set 1
                            (i64.const 18))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 18)))
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    // A `loop` ends the entry metered block:
    // - entry: 3 x block (6) + i32.const (1) + drop (10) + loop (2) = 19;
    // - loop body: 2 x (i32.const + drop) (22) + br (3) = 25, charged at the loop start;
    // - the trailing i32.const (1) is charged separately, because the entry block was already
    //   finalized and can no longer absorb it.
    test_transform! {
        name = nested_blocks_with_loop,
        source = r#"
        (module
            (func (result i32)
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))
                        (loop
                            (i32.const 1)
                            (drop)
                            (i32.const 1)
                            (drop)
                            (br 0))))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 19))
                    (then
                        (global.set 1
                            (i64.const 19))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 19)))
                (block
                    (block
                        (block
                            (i32.const 1)
                            (drop))
                        (loop
                            (if
                                (i64.lt_u
                                    (global.get 0)
                                    (i64.const 25))
                                (then
                                    (global.set 1
                                        (i64.const 25))
                                    (unreachable)))
                            (global.set 0
                                (i64.sub
                                    (global.get 0)
                                    (i64.const 25)))
                            (i32.const 1)
                            (drop)
                            (i32.const 1)
                            (drop)
                            (br 0))))
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }

    // `if`/`else` splits metering into four blocks:
    // - entry: i32.const (1) + if (10, the default cost) = 11;
    // - then-branch: 2 x (i32.const + drop) = 22;
    // - else-branch: i32.const + drop = 11;
    // - after the if: i32.const = 1.
    test_transform! {
        name = if_else,
        source = r#"
        (module
            (func (result i32)
                (i32.const 1)
                (if
                    (then
                        (i32.const 1)
                        (drop)
                        (i32.const 1)
                        (drop))
                    (else
                        (i32.const 1)
                        (drop)))
                (i32.const 1)))
        "#,
        expected = r#"
        (module
            (func (result i32)
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 11))
                    (then
                        (global.set 1
                            (i64.const 11))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 11)))
                (i32.const 1)
                (if
                    (then
                        (if
                            (i64.lt_u
                                (global.get 0)
                                (i64.const 22))
                            (then
                                (global.set 1
                                    (i64.const 22))
                                (unreachable)))
                        (global.set 0
                            (i64.sub
                                (global.get 0)
                                (i64.const 22)))
                        (i32.const 1)
                        (drop)
                        (i32.const 1)
                        (drop)
                    )
                    (else
                        (if
                            (i64.lt_u
                                (global.get 0)
                                (i64.const 11))
                            (then
                                (global.set 1
                                    (i64.const 11))
                                (unreachable)))
                        (global.set 0
                            (i64.sub
                                (global.get 0)
                                (i64.const 11)))
                        (i32.const 1)
                        (drop)
                    )
                )
                (if
                    (i64.lt_u
                        (global.get 0)
                        (i64.const 1))
                    (then
                        (global.set 1
                            (i64.const 1))
                        (unreachable)))
                (global.set 0
                    (i64.sub
                        (global.get 0)
                        (i64.const 1)))
                (i32.const 1))
            (global (;0;) (mut i64) (i64.const 0))
            (global (;1;) (mut i64) (i64.const 0))
            (export "gas_limit" (global 0))
            (export "gas_limit_exhausted" (global 1)))
        "#
    }
}