1use std::{
6 fmt,
7 io::{Read, Write},
8 net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream},
9 sync::{Arc, OnceLock},
10};
11
12use rustls::{ClientConfig, ClientConnection, StreamOwned};
13use rustls_pki_types::ServerName;
14use ureq::{
15 http::Uri,
16 resolver::Resolver,
17 transport::{
18 time::NextTimeout, Buffers, ChainedConnector, ConnectionDetails, Connector, LazyBuffers,
19 Transport, TransportAdapter,
20 },
21 Agent, AgentConfig,
22};
23
24pub fn agent() -> Agent {
29 let cfg = AgentConfig {
30 https_only: true, user_agent: "rofl-utils/0.1.0".to_string(),
32 ..Default::default()
33 };
34 agent_with_config(cfg)
35}
36
37pub fn agent_with_config(cfg: AgentConfig) -> Agent {
39 Agent::with_parts(
40 cfg,
41 ChainedConnector::new([SgxConnector.boxed(), RustlsConnector::default().boxed()]),
42 SgxResolver,
43 )
44}
45
46#[derive(Debug)]
47struct SgxConnector;
48
49impl Connector for SgxConnector {
50 fn connect(
51 &self,
52 details: &ConnectionDetails,
53 _chained: Option<Box<dyn Transport>>,
54 ) -> Result<Option<Box<dyn Transport>>, ureq::Error> {
55 let config = &details.config;
56 let scheme = details.uri.scheme().unwrap();
58 let authority = details.uri.authority().unwrap();
59
60 let host_port = ureq::resolver::DefaultResolver::host_and_port(scheme, authority)
61 .ok_or(ureq::Error::HostNotFound)?;
62 let stream = TcpStream::connect(host_port)?;
63
64 let buffers = LazyBuffers::new(config.input_buffer_size, config.output_buffer_size);
65 let transport = TcpTransport::new(stream, buffers);
66
67 Ok(Some(Box::new(transport)))
68 }
69}
70
71struct TcpTransport {
72 stream: TcpStream,
73 buffers: LazyBuffers,
74}
75
76impl TcpTransport {
77 fn new(stream: TcpStream, buffers: LazyBuffers) -> TcpTransport {
78 TcpTransport { stream, buffers }
79 }
80}
81
82impl Transport for TcpTransport {
83 fn buffers(&mut self) -> &mut dyn Buffers {
84 &mut self.buffers
85 }
86
87 fn transmit_output(&mut self, amount: usize, _timeout: NextTimeout) -> Result<(), ureq::Error> {
88 let output = &self.buffers.output()[..amount];
89 self.stream.write_all(output)?;
90
91 Ok(())
92 }
93
94 fn await_input(&mut self, _timeout: NextTimeout) -> Result<bool, ureq::Error> {
95 if self.buffers.can_use_input() {
96 return Ok(true);
97 }
98
99 let input = self.buffers.input_mut();
100 let amount = self.stream.read(input)?;
101 self.buffers.add_filled(amount);
102
103 Ok(amount > 0)
104 }
105
106 fn is_open(&mut self) -> bool {
107 true
109 }
110}
111
112impl fmt::Debug for TcpTransport {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 f.debug_struct("TcpTransport")
115 .field("addr", &self.stream.peer_addr().ok())
116 .finish()
117 }
118}
119
120#[derive(Debug)]
121struct SgxResolver;
122
123impl Resolver for SgxResolver {
124 fn resolve(
125 &self,
126 _uri: &Uri,
127 _config: &AgentConfig,
128 _timeout: NextTimeout,
129 ) -> Result<ureq::resolver::ResolvedSocketAddrs, ureq::Error> {
130 Ok(vec![SocketAddr::V4(SocketAddrV4::new(
133 Ipv4Addr::new(0, 0, 0, 0),
134 0,
135 ))]
136 .into())
137 }
138}
139
140#[derive(Default)]
141struct RustlsConnector {
142 config: OnceLock<Arc<ClientConfig>>,
143}
144
145impl Connector for RustlsConnector {
146 fn connect(
147 &self,
148 details: &ConnectionDetails,
149 chained: Option<Box<dyn Transport>>,
150 ) -> Result<Option<Box<dyn Transport>>, ureq::Error> {
151 let Some(transport) = chained else {
152 panic!("RustlsConnector requires a chained transport");
153 };
154
155 if !details.needs_tls() || transport.is_tls() {
158 return Ok(Some(transport));
159 }
160
161 let config_ref = self.config.get_or_init(build_config);
163 let config = config_ref.clone();
164
165 let name_borrowed: ServerName<'_> = details
166 .uri
167 .authority()
168 .ok_or(ureq::Error::HostNotFound)?
169 .host()
170 .try_into()
171 .map_err(|_| ureq::Error::HostNotFound)?;
172
173 let name = name_borrowed.to_owned();
174
175 let conn =
176 ClientConnection::new(config, name).map_err(|_| ureq::Error::ConnectionFailed)?;
177 let stream = StreamOwned {
178 conn,
179 sock: TransportAdapter::new(transport),
180 };
181
182 let buffers = LazyBuffers::new(
183 details.config.input_buffer_size,
184 details.config.output_buffer_size,
185 );
186
187 let transport = Box::new(RustlsTransport { buffers, stream });
188
189 Ok(Some(transport))
190 }
191}
192
193fn build_config() -> Arc<ClientConfig> {
194 let provider = Arc::new(rustls_mbedcrypto_provider::mbedtls_crypto_provider());
195
196 let builder = ClientConfig::builder_with_provider(provider)
197 .with_safe_default_protocol_versions()
198 .unwrap();
199
200 let builder = builder
201 .dangerous()
202 .with_custom_certificate_verifier(Arc::new(
203 rustls_mbedpki_provider::MbedTlsServerCertVerifier::new(
204 webpki_root_certs::TLS_SERVER_ROOT_CERTS,
205 )
206 .unwrap(),
207 ));
208
209 let config = builder.with_no_client_auth();
210
211 Arc::new(config)
212}
213
214struct RustlsTransport {
215 buffers: LazyBuffers,
216 stream: StreamOwned<ClientConnection, TransportAdapter>,
217}
218
219impl Transport for RustlsTransport {
220 fn buffers(&mut self) -> &mut dyn Buffers {
221 &mut self.buffers
222 }
223
224 fn transmit_output(&mut self, amount: usize, _timeout: NextTimeout) -> Result<(), ureq::Error> {
225 let output = &self.buffers.output()[..amount];
226 self.stream.write_all(output)?;
227
228 Ok(())
229 }
230
231 fn await_input(&mut self, _timeout: NextTimeout) -> Result<bool, ureq::Error> {
232 if self.buffers.can_use_input() {
233 return Ok(true);
234 }
235
236 let input = self.buffers.input_mut();
237 let amount = self.stream.read(input)?;
238 self.buffers.add_filled(amount);
239
240 Ok(amount > 0)
241 }
242
243 fn is_open(&mut self) -> bool {
244 self.stream.get_mut().get_mut().is_open()
245 }
246
247 fn is_tls(&self) -> bool {
248 true
249 }
250}
251
252impl fmt::Debug for RustlsConnector {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 f.debug_struct("RustlsConnector").finish()
255 }
256}
257
258impl fmt::Debug for RustlsTransport {
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 f.debug_struct("RustlsTransport").finish()
261 }
262}
263
#[cfg(test)]
mod test {
    use mockito::{mock, server_url};

    use super::{agent, agent_with_config};

    /// Plain-HTTP GET against a local mockito server.
    #[test]
    fn test_get_request() {
        // The mock stays registered for as long as this guard is alive.
        let _server_mock = mock("GET", "/test")
            .with_status(200)
            .with_header("content-type", "text/plain")
            .with_body("Hello, world!")
            .create();

        let url = format!("{}/test", server_url());
        // Unlike `agent()`, the default config does not set `https_only`,
        // which lets the plain-HTTP mock server be reached.
        let http_agent = agent_with_config(Default::default());
        let mut response = http_agent.get(&url).call().unwrap();

        assert_eq!(response.status(), 200);
        let body = response.body_mut().read_to_string().unwrap();
        assert_eq!(body, "Hello, world!");
    }

    /// Plain-HTTP POST with a JSON body against a local mockito server.
    #[test]
    fn test_post_request() {
        let _server_mock = mock("POST", "/submit")
            .with_status(201)
            .with_header("content-type", "application/json")
            .with_body(r#"{"success":true}"#)
            .create();

        let url = format!("{}/submit", server_url());
        let http_agent = agent_with_config(Default::default());
        let mut response = http_agent
            .post(&url)
            .content_type("application/json")
            .send(r#"{"key":"value"}"#)
            .unwrap();

        assert_eq!(response.status(), 201);
        let body = response.body_mut().read_to_string().unwrap();
        assert_eq!(body, r#"{"success":true}"#);
    }

    /// Real HTTPS request through the rustls/mbedtls stack (requires network).
    /// NOTE(review): the charset Google returns may vary by region, so the
    /// exact content-type assertion could be brittle — confirm.
    #[test]
    fn test_get_remote_https() {
        let response = agent().get("https://www.google.com/").call().unwrap();

        let content_type = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap()
            .replace("; ", ";");
        assert_eq!("text/html;charset=ISO-8859-1", content_type);
        assert_eq!(response.body().mime_type(), Some("text/html"));
    }
}