From fcfbd4cbecd64128a52880f6a979b8389b238e70 Mon Sep 17 00:00:00 2001 From: Erik Date: Tue, 14 Mar 2023 17:38:49 +0100 Subject: [PATCH] Add proxy control --- src/lib.rs | 118 ++++++++++++++++++++++++++++++++++++++++++---------- src/main.rs | 8 ++-- 2 files changed, 101 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 09febde..ead51a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,12 @@ +use axum::extract::State; use axum::{http::StatusCode, Json}; use serde::{Deserialize, Serialize}; -use std::net::{IpAddr, SocketAddr}; +use std::collections::HashMap; +use std::net::{Ipv4Addr, SocketAddrV4}; +use std::sync::{Arc, Mutex}; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch::{self, Receiver, Sender}; use uuid::Uuid; #[derive(Deserialize, Serialize, Debug)] @@ -16,11 +20,12 @@ enum Command { New { incoming_port: u16, destination_port: u16, - destination_ip: IpAddr, + destination_ip: Ipv4Addr, id: Uuid, }, Modify { - destionation_ip: IpAddr, + destination_port: u16, + destination_ip: Ipv4Addr, id: Uuid, }, Delete { @@ -33,11 +38,31 @@ pub struct ProxyResponse { message: String, } +#[derive(Debug)] +pub struct GlobalState { + proxies: Mutex>, +} + +impl GlobalState { + pub fn new() -> Self { + Self { + proxies: Mutex::new(HashMap::new()), + } + } +} + +#[derive(Debug)] +struct ProxyState { + destination: SocketAddrV4, + control: Sender, +} + pub async fn root() -> &'static str { "Hello, World!" } pub async fn process_command( + State(state): State>, Json(payload): Json, ) -> (StatusCode, Json) { tracing::error!("Received payload: {:?}", payload); @@ -49,19 +74,38 @@ pub async fn process_command( destination_ip, id, } => { - // TODO: add id to global proxy map - add_proxy( - incoming_port, - SocketAddr::new(destination_ip, destination_port), - ) - .await - .unwrap(); // TODO: error propagation?? + let addr = SocketAddrV4::new(destination_ip, destination_port); + let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr }); + state.proxies.lock().unwrap().insert( + id, + ProxyState { + destination: addr, + control: tx, + }, + ); + add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation?? } Command::Modify { - destionation_ip, + destination_port, + destination_ip, id, - } => todo!(), - Command::Delete { id } => todo!(), + } => { + if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) { + proxy.destination.set_port(destination_port); + proxy.destination.set_ip(destination_ip); + proxy + .control + .send(ProxyControlMessage::Open { + destination: proxy.destination, + }) + .unwrap(); + } + } + Command::Delete { id } => { + if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) { + proxy.control.send(ProxyControlMessage::Close).unwrap(); + } + } } ( StatusCode::CREATED, @@ -71,24 +115,54 @@ pub async fn process_command( ) } -async fn add_proxy(in_port: u16, destination: SocketAddr) -> anyhow::Result<()> { +#[derive(Debug)] +enum ProxyControlMessage { + Open { destination: SocketAddrV4 }, // Reroute { new: SocketAddr }, + Close, +} + +async fn add_proxy(in_port: u16, control: Receiver) -> anyhow::Result<()> { let listener = TcpListener::bind(("127.0.0.1", in_port)).await?; - tracing::info!("proxying port {in_port} to {destination}"); + tracing::info!("proxying port {in_port} to {:?}", *control.borrow()); - tokio::spawn(proxy(listener, destination)); + tokio::spawn(proxy(listener, control)); Ok(()) } -async fn proxy(listener: TcpListener, destination: SocketAddr) { - while let Ok((inbound, _)) = listener.accept().await { - let transfer = transfer(inbound, destination); +async fn proxy(listener: TcpListener, mut control: Receiver) { + let mut current_destination = + if let ProxyControlMessage::Open { destination } = *control.borrow() { + Some(destination) + } else { + None + }; + loop { + tokio::select! { + l = listener.accept()=> { + if let Ok((inbound, _)) = l { + let transfer = transfer(inbound, current_destination.unwrap()); - tokio::spawn(transfer); + tokio::spawn(transfer); + } + } + _ = control.changed() => { + match *control.borrow() { + ProxyControlMessage::Open { destination } => { + tracing::info!("destination for proxy port {} changed to {}", listener.local_addr().unwrap(), destination); + current_destination=Some(destination); + }, + ProxyControlMessage::Close => { + tracing::info!("destination for proxy port {} closed", listener.local_addr().unwrap()); + return; + }, + } + } + } } } -async fn transfer(mut inbound: TcpStream, destination: SocketAddr) -> anyhow::Result<()> { +async fn transfer(mut inbound: TcpStream, destination: SocketAddrV4) -> anyhow::Result<()> { let mut outbound = TcpStream::connect(destination).await?; let (mut ri, mut wi) = inbound.split(); @@ -123,7 +197,7 @@ mod tests { command: Command::New { incoming_port: 5555, destination_port: 6666, - destination_ip: std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + destination_ip: Ipv4Addr::new(127, 0, 0, 1), id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"), }, signature: [0u8; 32], diff --git a/src/main.rs b/src/main.rs index 132cf2f..9bb451f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,8 +2,8 @@ use axum::{ routing::{get, post}, Router, }; -use proxima_centauri::{process_command, root}; -use std::net::SocketAddr; +use proxima_centauri::{process_command, root, GlobalState}; +use std::{net::SocketAddr, sync::Arc}; use tracing::Level; #[tokio::main] @@ -15,12 +15,14 @@ async fn main() { tracing::subscriber::set_global_default(subscriber).unwrap(); + let shared_state = Arc::new(GlobalState::new()); // build our application with a route let app = Router::new() // `GET /` goes to `root` .route("/", get(root)) // `POST /command` goes to `process_command` - .route("/command", post(process_command)); + .route("/command", post(process_command)) + .with_state(shared_state); // run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000));