Final Server Code

The final code looks like this:

use std::{
    collections::hash_map::{Entry, HashMap},
    future::Future,
};

use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
    net::{tcp::OwnedWriteHalf, TcpListener, TcpStream, ToSocketAddrs},
    sync::{mpsc, oneshot, watch},
    task,
};

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
type Sender<T> = mpsc::UnboundedSender<T>;
type Receiver<T> = mpsc::UnboundedReceiver<T>;

/// Runs the chat server on `127.0.0.1:8080` until Ctrl-C is pressed.
#[tokio::main]
pub(crate) async fn main() -> Result<()> {
    accept_loop("127.0.0.1:8080").await
}

/// Accepts TCP connections on `addr` and spawns a [`connection_loop`] task for each one.
///
/// On Ctrl-C the loop stops accepting, signals every connection task to finish, and then
/// waits for the broker task to complete before returning.
///
/// # Errors
/// Returns an error if binding `addr` fails or if awaiting the broker task fails
/// (e.g. it panicked).
async fn accept_loop(addr: impl ToSocketAddrs) -> Result<()> {
    let listener = TcpListener::bind(addr).await?;

    let (broker_sender, broker_receiver) = mpsc::unbounded_channel();
    let broker = task::spawn(broker_loop(broker_receiver));

    // A watch channel retains the shutdown state. A connection task that is not currently
    // waiting on it still sees the change the next time it checks. `Notify::notify_waiters`
    // only wakes tasks that are already waiting.
    let (shutdown_sender, shutdown_receiver) = watch::channel(false);

    // Create the Ctrl-C future once. A fresh future per iteration could miss a signal that
    // arrives while the loop body is running.
    let ctrl_c = tokio::signal::ctrl_c();
    tokio::pin!(ctrl_c);

    loop {
        tokio::select! {
            // Match the whole `Result`. A refutable `Ok(..)` pattern would disable this branch
            // after a single accept error and stall the loop until Ctrl-C.
            accepted = listener.accept() => match accepted {
                Ok((stream, socket_addr)) => {
                    // Use the address returned by `accept`. A `peer_addr()?` call here would
                    // abort the whole accept loop if the lookup failed.
                    println!("Accepting from: {}", socket_addr);
                    spawn_and_log_error(connection_loop(
                        broker_sender.clone(),
                        stream,
                        shutdown_receiver.clone(),
                    ));
                }
                Err(e) => eprintln!("failed to accept connection: {}", e),
            },
            _ = &mut ctrl_c => break,
        }
    }

    println!("Shutting down!");
    // Ignored: `send` can only fail when every receiver is gone, i.e. nobody is left to notify.
    let _ = shutdown_sender.send(true);
    // The broker exits once every `Sender<Event>` is dropped: this one and those held by the
    // connection tasks, which return after observing the shutdown signal.
    drop(broker_sender);
    broker.await?;
    Ok(())
}

/// Handles one client connection.
///
/// The first line the client sends is its name. That name and the socket's write half are
/// registered with the broker. Every later line of the form `dest1, dest2: message` is
/// forwarded as an [`Event::Message`]. Lines without a `:` are ignored.
///
/// Returns when the client closes the connection or when `shutdown` changes. Dropping
/// `shutdown_sender` on return tells this peer's writer task to stop.
///
/// # Errors
/// Returns an error if reading from the socket fails, if the client disconnects before
/// sending its name, or if the broker task has gone away.
async fn connection_loop(
    broker: Sender<Event>,
    stream: TcpStream,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    let (reader, writer) = stream.into_split();
    let mut lines = BufReader::new(reader).lines();
    let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();

    // Race the name against shutdown. Otherwise a client that never sends a name would keep
    // this task (and its broker sender) alive and block server shutdown.
    let name = tokio::select! {
        line = lines.next_line() => match line? {
            Some(line) => line,
            None => return Err("peer disconnected immediately".into()),
        },
        _ = shutdown.changed() => return Ok(()),
    };

    println!("user {} connected", name);
    broker
        .send(Event::NewPeer {
            name: name.clone(),
            stream: writer,
            shutdown: shutdown_receiver,
        })
        .map_err(|_| "broker has shut down")?;

    loop {
        tokio::select! {
            // Match the whole result. With a refutable `Ok(Some(..))` pattern, EOF or a read
            // error would disable this branch and leave the task waiting for shutdown forever.
            line = lines.next_line() => {
                let line = match line? {
                    Some(line) => line,
                    None => break,
                };
                let (dest, msg) = match line.split_once(':') {
                    None => continue,
                    Some((dest, msg)) => (dest, msg.trim()),
                };
                let dest: Vec<String> = dest
                    .split(',')
                    .map(|recipient| recipient.trim().to_string())
                    .collect();
                broker
                    .send(Event::Message {
                        from: name.clone(),
                        to: dest,
                        msg: msg.to_string(),
                    })
                    .map_err(|_| "broker has shut down")?;
            },
            _ = shutdown.changed() => break,
        }
    }

    println!("Closing connection loop!");
    drop(shutdown_sender);
    Ok(())
}

/// Writes every message queued for one peer to that peer's socket.
///
/// Stops when the message channel closes (all senders, including the one in the broker's
/// peer map, are dropped) or when the paired `connection_loop` drops its end of `shutdown`.
///
/// # Errors
/// Returns an error if writing to the socket fails.
async fn connection_writer_loop(
    messages: &mut Receiver<String>,
    stream: &mut OwnedWriteHalf,
    mut shutdown: oneshot::Receiver<()>,
) -> Result<()> {
    loop {
        tokio::select! {
            msg = messages.recv() => match msg {
                Some(msg) => stream.write_all(msg.as_bytes()).await?,
                None => break,
            },
            _ = &mut shutdown => break,
        }
    }
    println!("Closing connection_writer loop!");
    Ok(())
}

/// Messages sent from connection tasks to the broker.
#[derive(Debug)]
enum Event {
    /// A client announced its name. This carries the socket's write half and the receiver
    /// that fires when the client's connection loop ends.
    NewPeer {
        name: String,
        stream: OwnedWriteHalf,
        shutdown: oneshot::Receiver<()>,
    },
    /// `from` wants `msg` delivered to every peer named in `to`.
    Message {
        from: String,
        to: Vec<String>,
        msg: String,
    },
}

/// Routes events between connected peers.
///
/// Keeps a map from peer name to that peer's outgoing queue and spawns a writer task per new
/// peer. A peer is removed once its writer task reports that it has finished. Returns after
/// `events` closes and every writer task has reported back.
async fn broker_loop(mut events: Receiver<Event>) {
    // Writer tasks hand back their name and message receiver when they finish. The receiver
    // stays alive until the peer is removed from `peers`, so `peer.send` below cannot fail.
    let (disconnect_sender, mut disconnect_receiver) =
        mpsc::unbounded_channel::<(String, Receiver<String>)>();
    let mut peers: HashMap<String, Sender<String>> = HashMap::new();

    loop {
        let event = tokio::select! {
            event = events.recv() => match event {
                None => break,
                Some(event) => event,
            },
            disconnect = disconnect_receiver.recv() => {
                // Cannot be `None`: this task holds `disconnect_sender` until after the loop.
                let (name, _pending_messages) = disconnect.unwrap();
                assert!(peers.remove(&name).is_some());
                println!("user {} disconnected", name);
                continue;
            },
        };
        match event {
            Event::Message { from, to, msg } => {
                // Every recipient gets the same line, so format it once.
                let line = format!("from {}: {}\n", from, msg);
                for addr in to {
                    if let Some(peer) = peers.get(&addr) {
                        peer.send(line.clone()).expect(
                            "writer task keeps the receiver alive while the peer is registered",
                        );
                    }
                }
            }
            Event::NewPeer {
                name,
                mut stream,
                shutdown,
            } => match peers.entry(name.clone()) {
                // Name already taken: this connection's write half and shutdown receiver are
                // dropped here, so it is never sent any messages.
                Entry::Occupied(..) => (),
                Entry::Vacant(entry) => {
                    let (client_sender, mut client_receiver) = mpsc::unbounded_channel();
                    entry.insert(client_sender);
                    let disconnect_sender = disconnect_sender.clone();
                    spawn_and_log_error(async move {
                        let res =
                            connection_writer_loop(&mut client_receiver, &mut stream, shutdown)
                                .await;
                        // The broker logs the disconnect when it removes the peer.
                        disconnect_sender.send((name, client_receiver)).unwrap();
                        res
                    });
                }
            },
        }
    }

    // Dropping every peer queue ends each writer task. Then wait for all of them to report back.
    drop(peers);
    drop(disconnect_sender);
    while let Some((_name, _pending_messages)) = disconnect_receiver.recv().await {}
}

/// Spawns `fut` on the Tokio runtime and prints its error, if any, to stderr.
fn spawn_and_log_error<F>(fut: F) -> task::JoinHandle<()>
where
    F: Future<Output = Result<()>> + Send + 'static,
{
    task::spawn(async move {
        if let Err(e) = fut.await {
            eprintln!("{}", e)
        }
    })
}