1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
//! Domain separation context helpers.
use std::sync::Mutex;

use once_cell::sync::Lazy;

use oasis_core_runtime::common::{crypto::hash::Hash, namespace::Namespace};

/// Byte string placed between the caller-supplied base context and the chain context.
const CHAIN_CONTEXT_SEPARATOR: &[u8] = b" for chain ";

/// Process-wide chain domain separation context.
///
/// Holds `None` until it is configured via `set_chain_context`; in test builds
/// `test_using_chain_context` resets it back to `None`.
static CHAIN_CONTEXT: Lazy<Mutex<Option<Vec<u8>>>> = Lazy::new(|| Mutex::new(None));

/// Return the globally configured chain domain separation context.
///
/// The returned domain separation context is computed as:
///
/// ```plain
/// <base> || " for chain " || <chain-context>
/// ```
///
/// # Panics
///
/// This function will panic in case the global chain domain separation context was not previously
/// set using [`set_chain_context`].
///
pub fn get_chain_context_for(base: &[u8]) -> Vec<u8> {
    let guard = CHAIN_CONTEXT.lock().unwrap();
    let chain_context = match guard.as_deref() {
        Some(cc) => cc,
        None => {
            // Release the lock before panicking so the global mutex does not get poisoned.
            drop(guard);
            panic!("chain domain separation context must be configured");
        }
    };

    // `concat` sizes the output once from the parts' total length and copies each part in order.
    [base, CHAIN_CONTEXT_SEPARATOR, chain_context].concat()
}

/// Configure the global chain domain separation context.
///
/// The domain separation context is computed as:
///
/// ```plain
/// Base-16(H(<runtime-id> || <consensus-chain-context>))
/// ```
///
/// # Panics
///
/// This function will panic in case the global chain domain separation context was already set.
/// In test builds (`cfg(test)` or the `test` feature), setting the exact same context again is
/// accepted as a no-op; only a conflicting value panics.
///
pub fn set_chain_context(runtime_id: Namespace, consensus_chain_context: &str) {
    // Compute the context before taking the lock to keep the critical section short.
    let ctx = hex::encode(Hash::digest_bytes_list(&[
        runtime_id.as_ref(),
        consensus_chain_context.as_bytes(),
    ]));
    let mut guard = CHAIN_CONTEXT.lock().unwrap();
    if let Some(ref existing) = *guard {
        // Tests may configure the same context repeatedly; only a conflicting value is an error.
        if cfg!(any(test, feature = "test")) && existing == ctx.as_bytes() {
            return;
        }
        // Build the message without anything that can panic while the lock is held (the stored
        // value is hex, so the lossy conversion yields the exact same text without cloning).
        let ex = String::from_utf8_lossy(existing).into_owned();
        // Release the lock before panicking so the global mutex does not get poisoned.
        drop(guard);
        panic!("chain domain separation context already set: {ex}");
    }
    *guard = Some(ctx.into_bytes());
}

/// Test helper to serialize unit tests using the global chain context. The chain context is reset
/// when this method is called.
///
/// Hold the returned guard for the duration of the test. If a test panics while holding it (for
/// example a failing assertion or a `#[should_panic]` test), later callers still acquire the lock
/// instead of failing on a poisoned mutex.
///
/// # Example
///
/// ```rust
/// # use oasis_runtime_sdk::crypto::signature::context::test_using_chain_context;
/// let _guard = test_using_chain_context();
/// // ... rest of the test code follows ...
/// ```
#[cfg(any(test, feature = "test"))]
pub fn test_using_chain_context() -> std::sync::MutexGuard<'static, ()> {
    static TEST_USING_CHAIN_CONTEXT: Lazy<Mutex<()>> = Lazy::new(Default::default);
    // The mutex protects no data (`()`), so poisoning carries no meaning here; recover the guard
    // rather than cascading one test's panic into every subsequent test using this helper.
    let guard = TEST_USING_CHAIN_CONTEXT
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *CHAIN_CONTEXT.lock().unwrap() = None;

    guard
}

#[cfg(test)]
mod test {
    use super::*;

    /// Runtime identifier shared by the fixtures below.
    const RUNTIME_ID: &str = "8000000000000000000000000000000000000000000000000000000000000000";
    /// A runtime identifier differing from `RUNTIME_ID`, used to provoke a conflict.
    const CONFLICTING_RUNTIME_ID: &str =
        "8000000000000000000000000000000000000000000000000000000000000001";
    /// Consensus chain context shared by the fixtures below.
    const CONSENSUS_CHAIN_CONTEXT: &str =
        "643fb06848be7e970af3b5b2d772eb8cfb30499c8162bc18ac03df2f5e22520e";

    /// Configure the global chain context from the default fixtures.
    fn configure_default_chain_context() {
        set_chain_context(RUNTIME_ID.into(), CONSENSUS_CHAIN_CONTEXT);
    }

    #[test]
    fn test_chain_context() {
        let _guard = test_using_chain_context();
        configure_default_chain_context();

        let expected = "oasis-runtime-sdk/tx: v0 for chain ca4842870b97a6d5c0d025adce0b6a0dec94d2ba192ede70f96349cfbe3628b9";
        let actual = String::from_utf8(get_chain_context_for(b"oasis-runtime-sdk/tx: v0"))
            .expect("chain context must be valid UTF-8");
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_chain_context_not_configured() {
        let _guard = test_using_chain_context();

        let outcome = std::panic::catch_unwind(|| get_chain_context_for(b"test"));
        assert!(outcome.is_err(), "unconfigured chain context must panic");
    }

    #[test]
    fn test_chain_context_already_configured() {
        let _guard = test_using_chain_context();
        configure_default_chain_context();

        let outcome = std::panic::catch_unwind(|| {
            set_chain_context(CONFLICTING_RUNTIME_ID.into(), CONSENSUS_CHAIN_CONTEXT)
        });
        assert!(outcome.is_err(), "conflicting chain context must panic");
    }
}