1use std::collections::HashMap;
4use std::future::Future;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9
10use crate::conn::IncomingRequest;
11
12#[async_trait]
14pub trait RequestHandler: Send + Sync + 'static {
15 async fn handle(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>>;
18}
19
20pub struct FnHandler<F>(pub F);
22
23#[async_trait]
24impl<F, Fut> RequestHandler for FnHandler<F>
25where
26 F: Fn(u64, bytes::Bytes) -> Fut + Send + Sync + 'static,
27 Fut: Future<Output = Option<Vec<u8>>> + Send + 'static,
28{
29 async fn handle(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>> {
30 (self.0)(conn_id, request.body.clone()).await
31 }
32}
33
34pub struct RequestRouter {
36 handlers: RwLock<HashMap<u32, Arc<dyn RequestHandler>>>,
37}
38
39impl RequestRouter {
40 pub fn new() -> Self {
41 Self {
42 handlers: RwLock::new(HashMap::new()),
43 }
44 }
45
46 pub fn register(&self, proto_id: u32, handler: Arc<dyn RequestHandler>) {
48 self.handlers.write().insert(proto_id, handler);
49 }
50
51 pub async fn dispatch(&self, conn_id: u64, request: &IncomingRequest) -> Option<Vec<u8>> {
53 let handler = {
54 let handlers = self.handlers.read();
55 handlers.get(&request.proto_id).cloned()
56 };
57
58 match handler {
59 Some(h) => h.handle(conn_id, request).await,
60 None => {
61 tracing::warn!(
62 proto_id = request.proto_id,
63 conn_id = conn_id,
64 "no handler registered"
65 );
66 Some(make_error_response(-1, "unknown protocol"))
68 }
69 }
70 }
71}
72
73impl Default for RequestRouter {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79fn make_error_response(ret_type: i32, msg: &str) -> Vec<u8> {
81 let resp = futu_proto::init_connect::Response {
82 ret_type,
83 ret_msg: Some(msg.to_string()),
84 err_code: None,
85 s2c: None,
86 };
87 prost::Message::encode_to_vec(&resp)
88}