oasis_rofl_client/
lib.rs

// rofl-client/rs/src/lib.rs
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4
/// Path of the `rofl-appd` Unix domain socket used by `RoflClient::new`.
const DEFAULT_SOCKET: &str = "/run/rofl-appd.sock";
6
/// Client for the `rofl-appd` HTTP API, spoken over a Unix domain socket.
#[derive(Clone, Debug)]
pub struct RoflClient {
    /// Filesystem path of the daemon's Unix socket.
    socket_path: String,
}
11
12impl RoflClient {
13    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
14        Self::with_socket_path(DEFAULT_SOCKET)
15    }
16
17    pub fn with_socket_path<P: AsRef<Path>>(
18        socket_path: P,
19    ) -> Result<Self, Box<dyn std::error::Error>> {
20        let socket_path = socket_path.as_ref().to_string_lossy().to_string();
21        if !Path::new(&socket_path).exists() {
22            return Err(format!("Socket not found at: {socket_path}").into());
23        }
24        Ok(Self { socket_path })
25    }
26
27    // GET /rofl/v1/app/id
28    pub async fn get_app_id(&self) -> Result<String, Box<dyn std::error::Error>> {
29        let sock = self.socket_path.clone();
30        let res = tokio::task::spawn_blocking(move || -> std::io::Result<String> {
31            let body = http_unix_request(&sock, "GET", "/rofl/v1/app/id", None, None)?;
32            let s = String::from_utf8(body)
33                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
34            Ok(s.trim().to_string())
35        })
36        .await
37        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
38        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
39        Ok(res)
40    }
41
42    // POST /rofl/v1/keys/generate
43    pub async fn generate_key(
44        &self,
45        key_id: &str,
46        kind: KeyKind,
47    ) -> Result<String, Box<dyn std::error::Error>> {
48        let sock = self.socket_path.clone();
49        let req = serde_json::to_vec(&KeyGenerationRequest {
50            key_id: key_id.to_string(),
51            kind: kind.to_string(),
52        })?;
53        let res = tokio::task::spawn_blocking(move || -> std::io::Result<String> {
54            let body = http_unix_request(
55                &sock,
56                "POST",
57                "/rofl/v1/keys/generate",
58                Some(&req),
59                Some("application/json"),
60            )?;
61            let resp: KeyGenerationResponse = serde_json::from_slice(&body)
62                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
63            Ok(resp.key)
64        })
65        .await
66        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
67        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
68        Ok(res)
69    }
70
71    // POST /rofl/v1/tx/sign-submit
72    pub async fn sign_submit(
73        &self,
74        tx: Tx,
75        encrypt: Option<bool>,
76    ) -> Result<String, Box<dyn std::error::Error>> {
77        let sock = self.socket_path.clone();
78        let req = serde_json::to_vec(&SignSubmitRequest { tx, encrypt })?;
79        let res = tokio::task::spawn_blocking(move || -> std::io::Result<String> {
80            let body = http_unix_request(
81                &sock,
82                "POST",
83                "/rofl/v1/tx/sign-submit",
84                Some(&req),
85                Some("application/json"),
86            )?;
87            let resp: SignSubmitResponse = serde_json::from_slice(&body)
88                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
89            Ok(resp.data)
90        })
91        .await
92        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
93        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
94        Ok(res)
95    }
96
97    // GET /rofl/v1/metadata
98    pub async fn get_metadata(
99        &self,
100    ) -> Result<std::collections::HashMap<String, String>, Box<dyn std::error::Error>> {
101        let sock = self.socket_path.clone();
102        let res = tokio::task::spawn_blocking(
103            move || -> std::io::Result<std::collections::HashMap<String, String>> {
104                let body = http_unix_request(&sock, "GET", "/rofl/v1/metadata", None, None)?;
105                let resp: std::collections::HashMap<String, String> = serde_json::from_slice(&body)
106                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
107                Ok(resp)
108            },
109        )
110        .await
111        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
112        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
113        Ok(res)
114    }
115
116    // POST /rofl/v1/metadata
117    pub async fn set_metadata(
118        &self,
119        metadata: &std::collections::HashMap<String, String>,
120    ) -> Result<(), Box<dyn std::error::Error>> {
121        let sock = self.socket_path.clone();
122        let req = serde_json::to_vec(metadata)?;
123        tokio::task::spawn_blocking(move || -> std::io::Result<()> {
124            let _body = http_unix_request(
125                &sock,
126                "POST",
127                "/rofl/v1/metadata",
128                Some(&req),
129                Some("application/json"),
130            )?;
131            Ok(())
132        })
133        .await
134        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
135        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
136        Ok(())
137    }
138
139    // POST /rofl/v1/query
140    pub async fn query(
141        &self,
142        method: &str,
143        args: &[u8],
144    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
145        let sock = self.socket_path.clone();
146        let payload = serde_json::json!({
147            "method": method,
148            "args": hex::encode(args),
149        });
150        let req = serde_json::to_vec(&payload)?;
151        let res = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
152            let body = http_unix_request(
153                &sock,
154                "POST",
155                "/rofl/v1/query",
156                Some(&req),
157                Some("application/json"),
158            )?;
159            let resp: serde_json::Value = serde_json::from_slice(&body)
160                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
161            let data_hex = resp.get("data").and_then(|v| v.as_str()).ok_or_else(|| {
162                std::io::Error::new(std::io::ErrorKind::InvalidData, "Missing 'data' field")
163            })?;
164            let data = hex::decode(data_hex)
165                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
166            Ok(data)
167        })
168        .await
169        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?
170        .map_err(|e| -> Box<dyn std::error::Error> { Box::new(e) })?;
171        Ok(res)
172    }
173
174    /// Convenience helper for ETH-style call
175    pub async fn sign_submit_eth(
176        &self,
177        gas_limit: u64,
178        to: &str,
179        value: &str,
180        data_hex: &str,
181        encrypt: Option<bool>,
182    ) -> Result<String, Box<dyn std::error::Error>> {
183        let eth = EthCall {
184            gas_limit,
185            to: to.to_string(),
186            value: value.to_string(),
187            data: data_hex.to_string(),
188        };
189        self.sign_submit(Tx::Eth(eth), encrypt).await
190    }
191}
192
/// Performs a blocking HTTP/1.1 request over the Unix domain socket at `socket_path`.
///
/// The request carries `Connection: close`, so the whole response is read until the peer
/// closes the stream. The body is then framed from the response headers:
///
/// - `Transfer-Encoding: chunked` bodies are de-chunked;
/// - otherwise a `Content-Length`, if present, bounds the body.
///
/// # Errors
///
/// The function fails in three cases:
///
/// - the socket cannot be connected, written or read;
/// - the response is not valid HTTP (missing header terminator, unparsable status line,
///   malformed chunked body);
/// - the status code is outside `200..300`. The error message then includes the body text.
fn http_unix_request(
    socket_path: &str,
    method: &str,
    path: &str,
    body: Option<&[u8]>,
    content_type: Option<&str>,
) -> std::io::Result<Vec<u8>> {
    use std::{
        io::{Read, Write},
        os::unix::net::UnixStream,
    };

    let mut stream = UnixStream::connect(socket_path)?;
    stream.write_all(&encode_http_request(method, path, body, content_type))?;
    stream.flush()?;

    // `read_to_end` reads until EOF and transparently retries on `ErrorKind::Interrupted`.
    let mut resp = Vec::new();
    stream.read_to_end(&mut resp)?;

    decode_http_response(&resp)
}

/// Serializes the request line, headers and optional body of an HTTP/1.1 request.
fn encode_http_request(
    method: &str,
    path: &str,
    body: Option<&[u8]>,
    content_type: Option<&str>,
) -> Vec<u8> {
    let mut head = format!(
        "{method} {path} HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n"
    );
    if let Some(ct) = content_type {
        head.push_str(&format!("Content-Type: {ct}\r\n"));
    }
    if let Some(b) = body {
        head.push_str(&format!("Content-Length: {}\r\n", b.len()));
    }
    head.push_str("\r\n");

    let mut req = head.into_bytes();
    if let Some(b) = body {
        req.extend_from_slice(b);
    }
    req
}

/// Parses a complete HTTP/1.1 response and returns its body for 2xx status codes.
fn decode_http_response(resp: &[u8]) -> std::io::Result<Vec<u8>> {
    use std::io::{Error, ErrorKind};

    let header_end = resp
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .ok_or_else(|| {
            Error::new(
                ErrorKind::InvalidData,
                "Invalid HTTP response: no header/body delimiter",
            )
        })?;
    let (head, raw_body) = resp.split_at(header_end + 4);

    let mut lines = head
        .split(|&b| b == b'\n')
        .map(|line| line.strip_suffix(b"\r").unwrap_or(line));
    let status_line = lines
        .next()
        .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Invalid HTTP response: empty"))?;
    let status_str =
        std::str::from_utf8(status_line).map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
    let code: u16 = status_str
        .split_whitespace()
        .nth(1)
        .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Invalid HTTP status line"))?
        .parse()
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

    // Only the framing headers matter here. Header names are case-insensitive.
    let mut chunked = false;
    let mut content_length: Option<usize> = None;
    for line in lines {
        let line = String::from_utf8_lossy(line);
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        let (name, value) = (name.trim(), value.trim());
        if name.eq_ignore_ascii_case("transfer-encoding") {
            // `chunked` must be the final transfer coding when present.
            chunked = value
                .rsplit(',')
                .next()
                .is_some_and(|coding| coding.trim().eq_ignore_ascii_case("chunked"));
        } else if name.eq_ignore_ascii_case("content-length") {
            content_length = value.parse().ok();
        }
    }

    // Transfer-Encoding takes precedence over Content-Length (RFC 9112 §6.3).
    let body = if chunked {
        decode_chunked_body(raw_body)
    } else {
        let len = content_length.map_or(raw_body.len(), |n| n.min(raw_body.len()));
        Ok(raw_body[..len].to_vec())
    };

    if !(200..300).contains(&code) {
        // Report the HTTP status even if the error body itself is malformed.
        let msg = String::from_utf8_lossy(body.as_deref().unwrap_or(raw_body));
        return Err(Error::other(format!("HTTP {code} error: {msg}")));
    }

    body
}

/// Decodes a `Transfer-Encoding: chunked` body (RFC 9112 §7.1).
///
/// Chunk extensions and trailer fields are discarded.
fn decode_chunked_body(mut data: &[u8]) -> std::io::Result<Vec<u8>> {
    use std::io::{Error, ErrorKind};

    let invalid = |what: &str| {
        Error::new(
            ErrorKind::InvalidData,
            format!("Invalid chunked HTTP body: {what}"),
        )
    };

    let mut out = Vec::new();
    loop {
        let line_end = data
            .windows(2)
            .position(|w| w == b"\r\n")
            .ok_or_else(|| invalid("missing chunk-size line"))?;
        let size_line = std::str::from_utf8(&data[..line_end])
            .map_err(|_| invalid("non-UTF-8 chunk-size line"))?;
        // Anything after ';' is a chunk extension, which is not used here.
        let size_hex = size_line.split(';').next().unwrap_or_default().trim();
        let size =
            usize::from_str_radix(size_hex, 16).map_err(|_| invalid("malformed chunk size"))?;
        data = &data[line_end + 2..];

        if size == 0 {
            // Last chunk: any trailer fields that follow are ignored.
            return Ok(out);
        }

        let chunk = data.get(..size).ok_or_else(|| invalid("truncated chunk"))?;
        out.extend_from_slice(chunk);
        data = data[size..]
            .strip_prefix(b"\r\n")
            .ok_or_else(|| invalid("missing CRLF after chunk data"))?;
    }
}
267
268// See https://github.com/oasisprotocol/oasis-sdk/blob/1ae8882b05d10a44398e52b5b8c56ab2086f81b3/rofl-appd/src/services/kms.rs#L59-L74
269#[derive(Debug, Clone, Serialize, Deserialize)]
270#[serde(rename_all = "kebab-case")]
271pub enum KeyKind {
272    Raw256,
273    Raw384,
274    Ed25519,
275    Secp256k1,
276}
277
278impl std::fmt::Display for KeyKind {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        match self {
281            KeyKind::Raw256 => write!(f, "raw-256"),
282            KeyKind::Raw384 => write!(f, "raw-384"),
283            KeyKind::Ed25519 => write!(f, "ed25519"),
284            KeyKind::Secp256k1 => write!(f, "secp256k1"),
285        }
286    }
287}
288
289#[derive(Debug, Serialize)]
290struct KeyGenerationRequest {
291    key_id: String,
292    kind: String,
293}
294
295#[derive(Debug, Deserialize)]
296struct KeyGenerationResponse {
297    key: String,
298}
299
// -------------------- sign-submit types --------------------
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
303#[serde(tag = "kind", content = "data")]
304pub enum Tx {
305    #[serde(rename = "eth")]
306    Eth(EthCall),
307    #[serde(rename = "std")]
308    Std(String), // CBOR-serialized hex-encoded Transaction
309}
310
311#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct EthCall {
313    pub gas_limit: u64,
314    pub to: String,
315    pub value: String,
316    pub data: String, // hex string without 0x prefix
317}
318
319#[derive(Debug, Serialize)]
320struct SignSubmitRequest {
321    pub tx: Tx,
322    #[serde(skip_serializing_if = "Option::is_none")]
323    pub encrypt: Option<bool>,
324}
325
326#[derive(Debug, Deserialize)]
327struct SignSubmitResponse {
328    pub data: String, // CBOR-serialized hex-encoded call result
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io::{Read, Write},
        os::unix::net::UnixListener,
        thread,
    };
    use tempfile::TempDir;

    /// Starts a mock `rofl-appd` on `test.sock` inside a fresh temporary directory.
    ///
    /// For each `(expected_path, response)` pair, in order, the server:
    ///
    /// - accepts one connection;
    /// - asserts that the request text contains `expected_path`;
    /// - replies `200 OK` with `response` as the body.
    ///
    /// Keep the returned `TempDir` alive for as long as the socket is needed.
    fn setup_mock_server(responses: Vec<(String, String)>) -> (TempDir, String) {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        let socket_path_str = socket_path.to_string_lossy().to_string();

        // `bind` leaves the socket listening before it returns. Clients can connect at once,
        // and pending connections wait in the backlog until `accept`, so no start-up sleep
        // is needed.
        let listener = UnixListener::bind(&socket_path).unwrap();

        thread::spawn(move || {
            for (expected_path, response) in responses {
                if let Ok((mut stream, _)) = listener.accept() {
                    let mut buf = vec![0u8; 4096];
                    let n = stream.read(&mut buf).unwrap();
                    let request = String::from_utf8_lossy(&buf[..n]);

                    // Check if the request contains the expected path
                    assert!(request.contains(&expected_path));

                    let http_response = format!(
                        "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
                        response.len(),
                        response
                    );
                    stream.write_all(http_response.as_bytes()).unwrap();
                }
            }
        });

        (temp_dir, socket_path_str)
    }

    #[tokio::test]
    async fn test_get_app_id() {
        let (_temp_dir, socket_path) = setup_mock_server(vec![(
            "/rofl/v1/app/id".to_string(),
            "oasis1qr677rv0dcnh7ys4yanlynysvnjtk9gnsyhvm5wj".to_string(),
        )]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let app_id = client.get_app_id().await.unwrap();

        assert_eq!(app_id, "oasis1qr677rv0dcnh7ys4yanlynysvnjtk9gnsyhvm5wj");
    }

    #[tokio::test]
    async fn test_generate_key() {
        let response = r#"{"key":"0x123456789abcdef"}"#;
        let (_temp_dir, socket_path) = setup_mock_server(vec![(
            "/rofl/v1/keys/generate".to_string(),
            response.to_string(),
        )]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let key = client
            .generate_key("test-key-id", KeyKind::Secp256k1)
            .await
            .unwrap();

        assert_eq!(key, "0x123456789abcdef");
    }

    #[tokio::test]
    async fn test_get_metadata() {
        let response = r#"{"key1":"value1","key2":"value2"}"#;
        let (_temp_dir, socket_path) = setup_mock_server(vec![(
            "/rofl/v1/metadata".to_string(),
            response.to_string(),
        )]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let metadata = client.get_metadata().await.unwrap();

        assert_eq!(metadata.get("key1").unwrap(), "value1");
        assert_eq!(metadata.get("key2").unwrap(), "value2");
    }

    #[tokio::test]
    async fn test_set_metadata() {
        let (_temp_dir, socket_path) =
            setup_mock_server(vec![("/rofl/v1/metadata".to_string(), "".to_string())]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("new_key".to_string(), "new_value".to_string());

        client.set_metadata(&metadata).await.unwrap();
    }

    #[tokio::test]
    async fn test_query() {
        let response = r#"{"data":"48656c6c6f"}"#;
        let (_temp_dir, socket_path) =
            setup_mock_server(vec![("/rofl/v1/query".to_string(), response.to_string())]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let args = b"\xa1\x64test\x65value";
        let result = client.query("test.Method", args).await.unwrap();

        assert_eq!(result, b"Hello");
    }

    #[tokio::test]
    async fn test_sign_submit_eth() {
        let (_temp_dir, socket_path) = setup_mock_server(vec![(
            "/rofl/v1/tx/sign-submit".to_string(),
            r#"{"data":"a1"}"#.to_string(),
        )]);

        let client = RoflClient::with_socket_path(&socket_path).unwrap();
        let data = client
            .sign_submit_eth(21_000, "1234", "0", "abcd", Some(false))
            .await
            .unwrap();

        assert_eq!(data, "a1");
    }

    #[test]
    fn test_sign_submit_request_json() {
        let req = SignSubmitRequest {
            tx: Tx::Eth(EthCall {
                gas_limit: 21_000,
                to: "1234".to_string(),
                value: "0".to_string(),
                data: "abcd".to_string(),
            }),
            encrypt: None,
        };

        // `encrypt: None` must be omitted; `Tx` is adjacently tagged.
        assert_eq!(
            serde_json::to_value(&req).unwrap(),
            serde_json::json!({
                "tx": {
                    "kind": "eth",
                    "data": {"gas_limit": 21000, "to": "1234", "value": "0", "data": "abcd"},
                },
            })
        );
    }

    #[test]
    fn test_key_kind_display() {
        assert_eq!(KeyKind::Raw256.to_string(), "raw-256");
        assert_eq!(KeyKind::Raw384.to_string(), "raw-384");
        assert_eq!(KeyKind::Ed25519.to_string(), "ed25519");
        assert_eq!(KeyKind::Secp256k1.to_string(), "secp256k1");
    }

    #[test]
    fn test_bad_socket_path() {
        let result = RoflClient::with_socket_path("/non/existent/socket.sock");
        assert!(result.is_err());
    }
}