1#![allow(non_snake_case)]
2
3use std::{
49 collections::HashMap,
50 sync::{
51 Arc,
52 atomic::{AtomicBool, AtomicU64, Ordering},
53 },
54};
55
56use anyhow::Result;
57use futures_util::{SinkExt, StreamExt, stream::SplitSink};
58use serde_json::Value;
59use tokio::{
60 net::{TcpListener, TcpStream},
61 sync::{Mutex, oneshot},
62};
63use tokio_tungstenite::{
64 MaybeTlsStream,
65 WebSocketStream,
66 accept_async,
67 connect_async,
68 tungstenite::{Message, Utf8Bytes},
69};
70
71#[derive(Clone)]
73pub struct SharedSecret(pub [u8; 32]);
74
75impl SharedSecret {
76 pub fn random() -> Self {
77 Self(rand::random::<[u8; 32]>())
81 }
82
83 pub fn as_hex(&self) -> String { hex::encode(self.0) }
84
85 pub fn from_hex(Hex:&str) -> Result<Self> {
86 let Bytes = hex::decode(Hex)?;
87 if Bytes.len() != 32 {
88 anyhow::bail!("shared secret must be 32 bytes (got {})", Bytes.len());
89 }
90 let mut Out = [0u8; 32];
91 Out.copy_from_slice(&Bytes);
92 Ok(Self(Out))
93 }
94}
95
/// Type-erased async method handler stored in a [`HandlerRegistry`].
///
/// It is called with the request's `params` value. It resolves to `Ok(result)`,
/// which the server sends back as the reply's `result` field, or to
/// `Err(message)`, which is sent back as the reply's `error` field.
pub type HandlerFn =
	Arc<dyn Fn(Value) -> futures_util::future::BoxFuture<'static, Result<Value, String>> + Send + Sync>;
100
101#[derive(Default)]
103pub struct HandlerRegistry {
104 Handlers:Mutex<HashMap<String, HandlerFn>>,
105}
106
107impl HandlerRegistry {
108 pub fn new() -> Arc<Self> { Arc::new(Self::default()) }
109
110 pub async fn Register(&self, Method:String, Handler:HandlerFn) {
111 self.Handlers.lock().await.insert(Method, Handler);
112 }
113
114 pub async fn Lookup(&self, Method:&str) -> Option<HandlerFn> { self.Handlers.lock().await.get(Method).cloned() }
115}
116
117pub async fn ServeLocal(Port:u16, Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
124 let Address = format!("127.0.0.1:{}", Port);
125 let Listener = TcpListener::bind(&Address).await?;
126 tracing::info!(target: "Mist::WebSocket", "server listening on {}", Address);
127
128 let PortStr = format!("{}", Port);
132 CommonLibrary::Telemetry::CaptureEvent::Fn(
133 "land:mist:server:start",
134 Some(vec![("address", Address.as_str()), ("port", PortStr.as_str())]),
135 );
136
137 loop {
138 let (Stream, Peer) = match Listener.accept().await {
139 Ok(P) => P,
140 Err(Error) => {
141 tracing::warn!(target: "Mist::WebSocket", "accept error: {}", Error);
142 continue;
143 },
144 };
145 let SecretClone = Secret.clone();
146 let RegistryClone = Registry.clone();
147 tokio::spawn(async move {
148 if let Err(Error) = HandleConnection(Stream, SecretClone, RegistryClone).await {
149 tracing::warn!(target: "Mist::WebSocket", "connection from {} closed with error: {}", Peer, Error);
150 }
151 });
152 }
153}
154
155async fn HandleConnection(Stream:TcpStream, _Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
156 let WebSocketStream = accept_async(Stream).await?;
163 let (mut Sink, mut Source) = WebSocketStream.split();
164
165 while let Some(MessageResult) = Source.next().await {
166 let Message = match MessageResult {
167 Ok(M) => M,
168 Err(Error) => {
169 tracing::debug!(target: "Mist::WebSocket", "frame read error: {}", Error);
170 break;
171 },
172 };
173
174 match Message {
175 Message::Text(Text) => {
176 let Envelope:Value = match serde_json::from_str(&Text) {
177 Ok(V) => V,
178 Err(Error) => {
179 tracing::debug!(target: "Mist::WebSocket", "bad text frame: {}", Error);
180 continue;
181 },
182 };
183 let Method = Envelope.get("method").and_then(|V| V.as_str()).unwrap_or("");
184 let Identifier = Envelope.get("id").cloned().unwrap_or(Value::Null);
185 let Params = Envelope.get("params").cloned().unwrap_or(Value::Array(vec![]));
186
187 if Method.is_empty() {
188 continue;
189 }
190
191 let Handler = Registry.Lookup(Method).await;
192 let Response = match Handler {
193 Some(H) => {
194 match H(Params).await {
195 Ok(Value) => serde_json::json!({ "id": Identifier, "result": Value }),
196 Err(ErrorMessage) => serde_json::json!({ "id": Identifier, "error": ErrorMessage }),
197 }
198 },
199 None => {
200 serde_json::json!({
201 "id": Identifier,
202 "error": format!("Unknown method: {}", Method),
203 })
204 },
205 };
206
207 if Identifier.is_null() {
208 continue;
210 }
211
212 if let Err(Error) = Sink.send(Message::Text(Utf8Bytes::from(Response.to_string()))).await {
213 tracing::debug!(target: "Mist::WebSocket", "send error: {}", Error);
214 break;
215 }
216 },
217 Message::Binary(Bytes) => {
218 tracing::trace!(target: "Mist::WebSocket", "binary frame ({} bytes) ignored - reserved for phase 2", Bytes.len());
219 },
220 Message::Close(_) => break,
221 _ => {},
222 }
223 }
224 Ok(())
225}
226
/// Outstanding `Client::invoke` calls, keyed by request id.
///
/// Each sender delivers that call's outcome: the reply's `result`, or an error
/// string.
type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>>;
229
/// Client side of the Mist WebSocket protocol.
///
/// Sends `{ "id", "method", "params" }` envelopes and matches the replies that
/// arrive on a background reader task to their callers by `id`.
pub struct Client {
	/// Write half of the WebSocket, shared by `invoke` and `notify`.
	Sink:Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
	/// Requests still waiting for a reply, keyed by request id.
	Pending:PendingMap,
	/// Next request id to hand out (starts at 1).
	NextIdentifier:AtomicU64,
	/// Set by the reader task once it has stopped reading frames.
	Closed:AtomicBool,
}
241
242impl Client {
243 pub async fn connect(Address:&str) -> Result<Arc<Self>> {
248 let (Stream, _Response) = connect_async(Address).await?;
249 let (Sink, mut Source) = Stream.split();
250 let Sink = Arc::new(Mutex::new(Sink));
251 let Pending:PendingMap = Arc::new(Mutex::new(HashMap::new()));
252 let SelfReference = Arc::new(Self {
253 Sink,
254 Pending:Pending.clone(),
255 NextIdentifier:AtomicU64::new(1),
256 Closed:AtomicBool::new(false),
257 });
258
259 let SelfForReader = SelfReference.clone();
262 tokio::spawn(async move {
263 while let Some(MessageResult) = Source.next().await {
264 let Frame = match MessageResult {
265 Ok(M) => M,
266 Err(_) => break,
267 };
268 match Frame {
269 Message::Text(Text) => {
270 if let Ok(Envelope) = serde_json::from_str::<Value>(&Text) {
271 let Identifier = Envelope.get("id").and_then(|V| V.as_u64());
272 if let Some(Identifier) = Identifier {
273 let Sender = SelfForReader.Pending.lock().await.remove(&Identifier);
274 if let Some(Sender) = Sender {
275 let Result = if let Some(ErrorValue) = Envelope.get("error") {
276 Err(ErrorValue.to_string())
277 } else {
278 Ok(Envelope.get("result").cloned().unwrap_or(Value::Null))
279 };
280 let _ = Sender.send(Result);
281 }
282 }
283 }
284 },
285 Message::Close(_) => break,
286 _ => {},
287 }
288 }
289 SelfForReader.Closed.store(true, Ordering::Relaxed);
290 let mut Pending = SelfForReader.Pending.lock().await;
292 for (_, Sender) in Pending.drain() {
293 let _ = Sender.send(Err("connection closed".into()));
294 }
295 });
296
297 Ok(SelfReference)
298 }
299
300 pub async fn invoke(&self, Method:&str, Params:Value) -> Result<Value, String> {
304 if self.Closed.load(Ordering::Relaxed) {
305 return Err("connection closed".into());
306 }
307 let Identifier = self.NextIdentifier.fetch_add(1, Ordering::Relaxed);
308 let (Tx, Rx) = oneshot::channel();
309 self.Pending.lock().await.insert(Identifier, Tx);
310 let Envelope = serde_json::json!({ "id": Identifier, "method": Method, "params": Params });
311 let Text = Envelope.to_string();
312 let SendResult = self.Sink.lock().await.send(Message::Text(Utf8Bytes::from(Text))).await;
313 if SendResult.is_err() {
314 self.Pending.lock().await.remove(&Identifier);
315 return Err("send failed".into());
316 }
317 Rx.await.map_err(|_| "request cancelled".to_string())?
318 }
319
320 pub async fn notify(&self, Method:&str, Params:Value) -> Result<(), String> {
322 if self.Closed.load(Ordering::Relaxed) {
323 return Err("connection closed".into());
324 }
325 let Envelope = serde_json::json!({ "id": Value::Null, "method": Method, "params": Params });
326 let Text = Envelope.to_string();
327 self.Sink
328 .lock()
329 .await
330 .send(Message::Text(Utf8Bytes::from(Text)))
331 .await
332 .map_err(|Error| Error.to_string())
333 }
334
335 pub fn is_closed(&self) -> bool { self.Closed.load(Ordering::Relaxed) }
336}