rofl_utils/
https.rs

1//! A very simple HTTPS client that can be used inside ROFL apps.
2//!
3//! This simple client is needed because Fortanix EDP does not yet have support for mio/Tokio
4//! networking and so the usual `hyper` and `reqwest` cannot be used without patches.
5use std::{
6    fmt,
7    io::{Read, Write},
8    net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream},
9    sync::{Arc, OnceLock},
10};
11
12use rustls::{ClientConfig, ClientConnection, StreamOwned};
13use rustls_pki_types::ServerName;
14use ureq::{
15    http::Uri,
16    resolver::Resolver,
17    transport::{
18        time::NextTimeout, Buffers, ChainedConnector, ConnectionDetails, Connector, LazyBuffers,
19        Transport, TransportAdapter,
20    },
21    Agent, AgentConfig,
22};
23
24/// An `ureq::Agent` that can be used to perform blocking HTTPS requests.
25///
26/// Note that this forbids non-HTTPS requests. If you need to perform plain HTTP requests consider
27/// using `agent_with_config` and pass a suitable config.
28pub fn agent() -> Agent {
29    let cfg = AgentConfig {
30        https_only: true, // Not using HTTPS is unsafe unless careful.
31        user_agent: "rofl-utils/0.1.0".to_string(),
32        ..Default::default()
33    };
34    agent_with_config(cfg)
35}
36
37/// An `ureq::Agent` with given configuration that can be used to perform blocking HTTPS requests.
38pub fn agent_with_config(cfg: AgentConfig) -> Agent {
39    Agent::with_parts(
40        cfg,
41        ChainedConnector::new([SgxConnector.boxed(), RustlsConnector::default().boxed()]),
42        SgxResolver,
43    )
44}
45
46#[derive(Debug)]
47struct SgxConnector;
48
49impl Connector for SgxConnector {
50    fn connect(
51        &self,
52        details: &ConnectionDetails,
53        _chained: Option<Box<dyn Transport>>,
54    ) -> Result<Option<Box<dyn Transport>>, ureq::Error> {
55        let config = &details.config;
56        // Has already been validated.
57        let scheme = details.uri.scheme().unwrap();
58        let authority = details.uri.authority().unwrap();
59
60        let host_port = ureq::resolver::DefaultResolver::host_and_port(scheme, authority)
61            .ok_or(ureq::Error::HostNotFound)?;
62        let stream = TcpStream::connect(host_port)?;
63
64        let buffers = LazyBuffers::new(config.input_buffer_size, config.output_buffer_size);
65        let transport = TcpTransport::new(stream, buffers);
66
67        Ok(Some(Box::new(transport)))
68    }
69}
70
71struct TcpTransport {
72    stream: TcpStream,
73    buffers: LazyBuffers,
74}
75
76impl TcpTransport {
77    fn new(stream: TcpStream, buffers: LazyBuffers) -> TcpTransport {
78        TcpTransport { stream, buffers }
79    }
80}
81
82impl Transport for TcpTransport {
83    fn buffers(&mut self) -> &mut dyn Buffers {
84        &mut self.buffers
85    }
86
87    fn transmit_output(&mut self, amount: usize, _timeout: NextTimeout) -> Result<(), ureq::Error> {
88        let output = &self.buffers.output()[..amount];
89        self.stream.write_all(output)?;
90
91        Ok(())
92    }
93
94    fn await_input(&mut self, _timeout: NextTimeout) -> Result<bool, ureq::Error> {
95        if self.buffers.can_use_input() {
96            return Ok(true);
97        }
98
99        let input = self.buffers.input_mut();
100        let amount = self.stream.read(input)?;
101        self.buffers.add_filled(amount);
102
103        Ok(amount > 0)
104    }
105
106    fn is_open(&mut self) -> bool {
107        // No way to detect on SGX.
108        true
109    }
110}
111
112impl fmt::Debug for TcpTransport {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        f.debug_struct("TcpTransport")
115            .field("addr", &self.stream.peer_addr().ok())
116            .finish()
117    }
118}
119
120#[derive(Debug)]
121struct SgxResolver;
122
123impl Resolver for SgxResolver {
124    fn resolve(
125        &self,
126        _uri: &Uri,
127        _config: &AgentConfig,
128        _timeout: NextTimeout,
129    ) -> Result<ureq::resolver::ResolvedSocketAddrs, ureq::Error> {
130        // Do not resolve anything as SGX does not support resolution and the endpoint must be
131        // passed as a string. We need to return a dummy address.
132        Ok(vec![SocketAddr::V4(SocketAddrV4::new(
133            Ipv4Addr::new(0, 0, 0, 0),
134            0,
135        ))]
136        .into())
137    }
138}
139
140#[derive(Default)]
141struct RustlsConnector {
142    config: OnceLock<Arc<ClientConfig>>,
143}
144
145impl Connector for RustlsConnector {
146    fn connect(
147        &self,
148        details: &ConnectionDetails,
149        chained: Option<Box<dyn Transport>>,
150    ) -> Result<Option<Box<dyn Transport>>, ureq::Error> {
151        let Some(transport) = chained else {
152            panic!("RustlsConnector requires a chained transport");
153        };
154
155        // Only add TLS if we are connecting via HTTPS and the transport isn't TLS
156        // already, otherwise use chained transport as is.
157        if !details.needs_tls() || transport.is_tls() {
158            return Ok(Some(transport));
159        }
160
161        // Initialize the config on first run.
162        let config_ref = self.config.get_or_init(build_config);
163        let config = config_ref.clone();
164
165        let name_borrowed: ServerName<'_> = details
166            .uri
167            .authority()
168            .ok_or(ureq::Error::HostNotFound)?
169            .host()
170            .try_into()
171            .map_err(|_| ureq::Error::HostNotFound)?;
172
173        let name = name_borrowed.to_owned();
174
175        let conn =
176            ClientConnection::new(config, name).map_err(|_| ureq::Error::ConnectionFailed)?;
177        let stream = StreamOwned {
178            conn,
179            sock: TransportAdapter::new(transport),
180        };
181
182        let buffers = LazyBuffers::new(
183            details.config.input_buffer_size,
184            details.config.output_buffer_size,
185        );
186
187        let transport = Box::new(RustlsTransport { buffers, stream });
188
189        Ok(Some(transport))
190    }
191}
192
193fn build_config() -> Arc<ClientConfig> {
194    let provider = Arc::new(rustls_mbedcrypto_provider::mbedtls_crypto_provider());
195
196    let builder = ClientConfig::builder_with_provider(provider)
197        .with_safe_default_protocol_versions()
198        .unwrap();
199
200    let builder = builder
201        .dangerous()
202        .with_custom_certificate_verifier(Arc::new(
203            rustls_mbedpki_provider::MbedTlsServerCertVerifier::new(
204                webpki_root_certs::TLS_SERVER_ROOT_CERTS,
205            )
206            .unwrap(),
207        ));
208
209    let config = builder.with_no_client_auth();
210
211    Arc::new(config)
212}
213
214struct RustlsTransport {
215    buffers: LazyBuffers,
216    stream: StreamOwned<ClientConnection, TransportAdapter>,
217}
218
219impl Transport for RustlsTransport {
220    fn buffers(&mut self) -> &mut dyn Buffers {
221        &mut self.buffers
222    }
223
224    fn transmit_output(&mut self, amount: usize, _timeout: NextTimeout) -> Result<(), ureq::Error> {
225        let output = &self.buffers.output()[..amount];
226        self.stream.write_all(output)?;
227
228        Ok(())
229    }
230
231    fn await_input(&mut self, _timeout: NextTimeout) -> Result<bool, ureq::Error> {
232        if self.buffers.can_use_input() {
233            return Ok(true);
234        }
235
236        let input = self.buffers.input_mut();
237        let amount = self.stream.read(input)?;
238        self.buffers.add_filled(amount);
239
240        Ok(amount > 0)
241    }
242
243    fn is_open(&mut self) -> bool {
244        self.stream.get_mut().get_mut().is_open()
245    }
246
247    fn is_tls(&self) -> bool {
248        true
249    }
250}
251
252impl fmt::Debug for RustlsConnector {
253    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254        f.debug_struct("RustlsConnector").finish()
255    }
256}
257
258impl fmt::Debug for RustlsTransport {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        f.debug_struct("RustlsTransport").finish()
261    }
262}
263
#[cfg(test)]
mod test {
    use mockito::{mock, server_url};

    use super::{agent, agent_with_config};

    /// GET against mockito's local server through an agent built from a default config.
    #[test]
    fn test_get_request() {
        let _mock = mock("GET", "/test")
            .with_status(200)
            .with_header("content-type", "text/plain")
            .with_body("Hello, world!")
            .create();

        let client = agent_with_config(Default::default());
        let url = format!("{}/test", server_url());

        let mut response = client.get(&url).call().unwrap();

        assert_eq!(response.status(), 200);
        let body = response.body_mut().read_to_string().unwrap();
        assert_eq!(body, "Hello, world!");
    }

    /// POST with a JSON body against mockito's local server.
    #[test]
    fn test_post_request() {
        let _mock = mock("POST", "/submit")
            .with_status(201)
            .with_header("content-type", "application/json")
            .with_body(r#"{"success":true}"#)
            .create();

        let client = agent_with_config(Default::default());
        let url = format!("{}/submit", server_url());

        let mut response = client
            .post(&url)
            .content_type("application/json")
            .send(r#"{"key":"value"}"#)
            .unwrap();

        assert_eq!(response.status(), 201);
        let body = response.body_mut().read_to_string().unwrap();
        assert_eq!(body, r#"{"success":true}"#);
    }

    /// HTTPS GET through the HTTPS-only agent.
    ///
    /// Requires outbound network access to www.google.com.
    #[test]
    fn test_get_remote_https() {
        let response = agent().get("https://www.google.com/").call().unwrap();

        // Strip the space after ';' so the comparison does not depend on header spacing.
        let content_type = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap()
            .replace("; ", ";");
        assert_eq!("text/html;charset=ISO-8859-1", content_type);
        assert_eq!(response.body().mime_type(), Some("text/html"));
    }
}