From 49a35cc8fef66196b432579e4ad382fac6999555 Mon Sep 17 00:00:00 2001 From: Erik Date: Wed, 15 Mar 2023 15:55:02 +0100 Subject: [PATCH] Add redirection for existing connections --- examples/listen_server.rs | 76 +++++++++++++++++++++++++++++++++++ src/lib.rs | 83 ++++++++++++++++++++++++++++----------- 2 files changed, 135 insertions(+), 24 deletions(-) create mode 100644 examples/listen_server.rs diff --git a/examples/listen_server.rs b/examples/listen_server.rs new file mode 100644 index 0000000..5547543 --- /dev/null +++ b/examples/listen_server.rs @@ -0,0 +1,76 @@ +//! A "hello world" echo server with Tokio +//! +//! This server will create a TCP listener, accept connections in a loop, and +//! write back everything that's read off of each TCP connection. +//! +//! Because the Tokio runtime uses a thread pool, each TCP connection is +//! processed concurrently with all other TCP connections across multiple +//! threads. +//! +//! To see this server in action, you can run this in one terminal: +//! +//! cargo run --example echo +//! +//! and in another terminal you can run: +//! +//! cargo run --example connect 127.0.0.1:8080 +//! +//! Each line you type in to the `connect` terminal should be echo'd back to +//! you! If you open up multiple terminals running the `connect` example you +//! should be able to see them all make progress simultaneously. + +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncReadExt}; +use tokio::net::TcpListener; + +use std::env; +use std::error::Error; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Allow passing an address to listen on as the first argument of this + // program, but otherwise we'll just set up our TCP listener on + // 127.0.0.1:8080 for connections. + let addr = env::args() + .nth(1) + .unwrap_or_else(|| "127.0.0.1:8080".to_string()); + + // Next up we create a TCP listener which will listen for incoming + // connections. This TCP listener is bound to the address we determined + // above and must be associated with an event loop. + let listener = TcpListener::bind(&addr).await?; + println!("Listening on: {addr}"); + + loop { + // Asynchronously wait for an inbound socket. + let (mut socket, _) = listener.accept().await?; + + // And this is where much of the magic of this server happens. We + // crucially want all clients to make progress concurrently, rather than + // blocking one on completion of another. To achieve this we use the + // `tokio::spawn` function to execute the work in the background. + // + // Essentially here we're executing a new task to run concurrently, + // which will allow all of our clients to be processed concurrently. + + tokio::spawn(async move { + let mut buf = vec![0; 1024]; + + // In a loop, read data from the socket and write the data back. + loop { + let n = socket + .read(&mut buf) + .await + .expect("failed to read data from socket"); + + if n == 0 { + return; + } + + use std::str; + println!("{}", str::from_utf8(&buf[0..n]).unwrap()); + } + }); + } +} diff --git a/src/lib.rs b/src/lib.rs index ead51a3..0b653fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ pub async fn process_command( State(state): State>, Json(payload): Json, ) -> (StatusCode, Json) { - tracing::error!("Received payload: {:?}", payload); + tracing::info!("Received payload: {:?}", payload); // TODO: verify signature match payload.command { Command::New { @@ -131,17 +131,11 @@ async fn add_proxy(in_port: u16, control: Receiver) -> anyh } async fn proxy(listener: TcpListener, mut control: Receiver) { - let mut current_destination = - if let ProxyControlMessage::Open { destination } = *control.borrow() { - Some(destination) - } else { - None - }; loop { tokio::select! { l = listener.accept()=> { if let Ok((inbound, _)) = l { - let transfer = transfer(inbound, current_destination.unwrap()); + let transfer = transfer(inbound, control.clone()); tokio::spawn(transfer); } @@ -150,7 +144,6 @@ async fn proxy(listener: TcpListener, mut control: Receiver match *control.borrow() { ProxyControlMessage::Open { destination } => { tracing::info!("destination for proxy port {} changed to {}", listener.local_addr().unwrap(), destination); - current_destination=Some(destination); }, ProxyControlMessage::Close => { tracing::info!("destination for proxy port {} closed", listener.local_addr().unwrap()); @@ -162,25 +155,68 @@ async fn proxy(listener: TcpListener, mut control: Receiver } } -async fn transfer(mut inbound: TcpStream, destination: SocketAddrV4) -> anyhow::Result<()> { - let mut outbound = TcpStream::connect(destination).await?; +async fn transfer( + mut inbound: TcpStream, + mut control: Receiver, +) -> anyhow::Result<()> { + loop { + let current_destination = + if let ProxyControlMessage::Open { destination } = *control.borrow() { + Some(destination) + } else { + break Ok(()); + }; + let mut outbound = TcpStream::connect(current_destination.unwrap()).await?; - let (mut ri, mut wi) = inbound.split(); - let (mut ro, mut wo) = outbound.split(); + let (mut ri, mut wi) = inbound.split(); + let (mut ro, mut wo) = outbound.split(); - let client_to_server = async { - io::copy(&mut ri, &mut wo).await?; - wo.shutdown().await - }; + let client_to_server = async { + io::copy(&mut ri, &mut wo).await?; + wo.shutdown().await + }; - let server_to_client = async { - io::copy(&mut ro, &mut wi).await?; - wi.shutdown().await - }; + let server_to_client = async { + io::copy(&mut ro, &mut wi).await?; + wi.shutdown().await + }; - tokio::try_join!(client_to_server, server_to_client)?; + // Select between the copy tasks and watch channel + tokio::select! { + // Join the two copy streams and wait for the connection to clone + result = async move { tokio::join!(client_to_server, server_to_client) } => { + match result { + (Ok(_), Ok(_)) => { + break Ok(()); + } + (r1, r2) => { + if r1.is_err() { + tracing::error!("error closing client->server of {:?}: {:?}", inbound, &r1); + } + if r2.is_err() { + tracing::error!("error closing server->client of {:?}: {:?}", inbound, &r2); + } + r1?; + r2?; + }, + } + } + _ = control.changed() => { + match *control.borrow() { + ProxyControlMessage::Open { destination } => { + eprintln!("Switching to new destination: {destination}"); + // Disconnect the current outbound connection and restart the loop + drop(outbound); + continue; + }, + ProxyControlMessage::Close => { + break Ok(()); + }, + } - Ok(()) + } + } + } } #[cfg(test)] @@ -188,7 +224,6 @@ mod tests { use std::net::Ipv4Addr; use crate::{Command, ProxyCommand}; - use serde::{Deserialize, Serialize}; use uuid::uuid; #[test]