1use std::future::Future;
2use std::io::{Error as IoError, ErrorKind};
3use std::time::Duration;
4
5use bytes::Bytes;
6use futures::stream::SplitSink;
7use futures::{SinkExt, StreamExt};
8use tokio::net::TcpStream;
9use tokio_util::codec::Framed;
10
11use futu_codec::FutuCodec;
12use futu_codec::frame::FutuFrame;
13use futu_core::error::FutuError;
14
/// Upper bound on how long [`Connection::connect`] waits for the TCP connect
/// attempt to complete before giving up with a timed-out error.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
21
22pub struct Connection {
27 sink: SplitSink<Framed<TcpStream, FutuCodec>, FutuFrame>,
28 stream: futures::stream::SplitStream<Framed<TcpStream, FutuCodec>>,
29}
30
31impl Connection {
32 pub async fn connect(addr: &str) -> Result<Self, FutuError> {
34 let stream = connect_with_timeout(addr, CONNECT_TIMEOUT, TcpStream::connect(addr)).await?;
35 configure_connected_stream(&stream)?;
36 tracing::info!(addr = addr, "TCP connected");
37
38 let framed = Framed::new(stream, FutuCodec);
39 let (sink, stream) = framed.split();
40
41 Ok(Self { sink, stream })
42 }
43
44 pub async fn send(&mut self, frame: FutuFrame) -> Result<(), FutuError> {
46 self.sink.send(frame).await
47 }
48
49 pub async fn recv(&mut self) -> Result<Option<FutuFrame>, FutuError> {
53 match self.stream.next().await {
54 Some(Ok(frame)) => Ok(Some(frame)),
55 Some(Err(e)) => Err(e),
56 None => Ok(None),
57 }
58 }
59
60 pub fn build_frame(proto_id: u32, serial_no: u32, body: Vec<u8>) -> FutuFrame {
62 FutuFrame::new(proto_id, serial_no, Bytes::from(body))
63 }
64}
65
66fn configure_connected_stream(stream: &TcpStream) -> Result<(), FutuError> {
67 stream.set_nodelay(true)?;
68 socket2::SockRef::from(stream).set_keepalive(true)?;
69 Ok(())
70}
71
72async fn connect_with_timeout<T, F>(
73 addr: &str,
74 timeout: Duration,
75 connect: F,
76) -> Result<T, FutuError>
77where
78 F: Future<Output = std::io::Result<T>>,
79{
80 match tokio::time::timeout(timeout, connect).await {
81 Ok(Ok(stream)) => Ok(stream),
82 Ok(Err(err)) => Err(FutuError::Network(err)),
83 Err(_) => Err(FutuError::Network(IoError::new(
84 ErrorKind::TimedOut,
85 format!("connect to {addr} timed out after {}s", timeout.as_secs()),
86 ))),
87 }
88}
89
// Unit tests are declared as a separate `tests` module and are compiled only
// for test builds.
#[cfg(test)]
mod tests;