Add proxy control

This commit is contained in:
Erik 2023-03-14 17:38:49 +01:00
parent b9d59761bf
commit fcfbd4cbec
2 changed files with 101 additions and 25 deletions

View File

@ -1,8 +1,12 @@
use axum::extract::State;
use axum::{http::StatusCode, Json}; use axum::{http::StatusCode, Json};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::{IpAddr, SocketAddr}; use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::{Arc, Mutex};
use tokio::io::{self, AsyncWriteExt}; use tokio::io::{self, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch::{self, Receiver, Sender};
use uuid::Uuid; use uuid::Uuid;
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
@ -16,11 +20,12 @@ enum Command {
New { New {
incoming_port: u16, incoming_port: u16,
destination_port: u16, destination_port: u16,
destination_ip: IpAddr, destination_ip: Ipv4Addr,
id: Uuid, id: Uuid,
}, },
Modify { Modify {
destionation_ip: IpAddr, destination_port: u16,
destination_ip: Ipv4Addr,
id: Uuid, id: Uuid,
}, },
Delete { Delete {
@ -33,11 +38,31 @@ pub struct ProxyResponse {
message: String, message: String,
} }
#[derive(Debug)]
pub struct GlobalState {
proxies: Mutex<HashMap<Uuid, ProxyState>>,
}
impl GlobalState {
pub fn new() -> Self {
Self {
proxies: Mutex::new(HashMap::new()),
}
}
}
#[derive(Debug)]
struct ProxyState {
destination: SocketAddrV4,
control: Sender<ProxyControlMessage>,
}
pub async fn root() -> &'static str { pub async fn root() -> &'static str {
"Hello, World!" "Hello, World!"
} }
pub async fn process_command( pub async fn process_command(
State(state): State<Arc<GlobalState>>,
Json(payload): Json<ProxyCommand>, Json(payload): Json<ProxyCommand>,
) -> (StatusCode, Json<ProxyResponse>) { ) -> (StatusCode, Json<ProxyResponse>) {
tracing::error!("Received payload: {:?}", payload); tracing::error!("Received payload: {:?}", payload);
@ -49,19 +74,38 @@ pub async fn process_command(
destination_ip, destination_ip,
id, id,
} => { } => {
// TODO: add id to global proxy map let addr = SocketAddrV4::new(destination_ip, destination_port);
add_proxy( let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
incoming_port, state.proxies.lock().unwrap().insert(
SocketAddr::new(destination_ip, destination_port), id,
) ProxyState {
.await destination: addr,
.unwrap(); // TODO: error propagation?? control: tx,
},
);
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
} }
Command::Modify { Command::Modify {
destionation_ip, destination_port,
destination_ip,
id, id,
} => todo!(), } => {
Command::Delete { id } => todo!(), if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
proxy.destination.set_port(destination_port);
proxy.destination.set_ip(destination_ip);
proxy
.control
.send(ProxyControlMessage::Open {
destination: proxy.destination,
})
.unwrap();
}
}
Command::Delete { id } => {
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
proxy.control.send(ProxyControlMessage::Close).unwrap();
}
}
} }
( (
StatusCode::CREATED, StatusCode::CREATED,
@ -71,24 +115,54 @@ pub async fn process_command(
) )
} }
async fn add_proxy(in_port: u16, destination: SocketAddr) -> anyhow::Result<()> { #[derive(Debug)]
enum ProxyControlMessage {
Open { destination: SocketAddrV4 }, // Reroute { new: SocketAddr },
Close,
}
async fn add_proxy(in_port: u16, control: Receiver<ProxyControlMessage>) -> anyhow::Result<()> {
let listener = TcpListener::bind(("127.0.0.1", in_port)).await?; let listener = TcpListener::bind(("127.0.0.1", in_port)).await?;
tracing::info!("proxying port {in_port} to {destination}"); tracing::info!("proxying port {in_port} to {:?}", *control.borrow());
tokio::spawn(proxy(listener, destination)); tokio::spawn(proxy(listener, control));
Ok(()) Ok(())
} }
async fn proxy(listener: TcpListener, destination: SocketAddr) { async fn proxy(listener: TcpListener, mut control: Receiver<ProxyControlMessage>) {
while let Ok((inbound, _)) = listener.accept().await { let mut current_destination =
let transfer = transfer(inbound, destination); if let ProxyControlMessage::Open { destination } = *control.borrow() {
Some(destination)
} else {
None
};
loop {
tokio::select! {
l = listener.accept()=> {
if let Ok((inbound, _)) = l {
let transfer = transfer(inbound, current_destination.unwrap());
tokio::spawn(transfer); tokio::spawn(transfer);
} }
} }
_ = control.changed() => {
match *control.borrow() {
ProxyControlMessage::Open { destination } => {
tracing::info!("destination for proxy port {} changed to {}", listener.local_addr().unwrap(), destination);
current_destination=Some(destination);
},
ProxyControlMessage::Close => {
tracing::info!("destination for proxy port {} closed", listener.local_addr().unwrap());
return;
},
}
}
}
}
}
async fn transfer(mut inbound: TcpStream, destination: SocketAddr) -> anyhow::Result<()> { async fn transfer(mut inbound: TcpStream, destination: SocketAddrV4) -> anyhow::Result<()> {
let mut outbound = TcpStream::connect(destination).await?; let mut outbound = TcpStream::connect(destination).await?;
let (mut ri, mut wi) = inbound.split(); let (mut ri, mut wi) = inbound.split();
@ -123,7 +197,7 @@ mod tests {
command: Command::New { command: Command::New {
incoming_port: 5555, incoming_port: 5555,
destination_port: 6666, destination_port: 6666,
destination_ip: std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), destination_ip: Ipv4Addr::new(127, 0, 0, 1),
id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"), id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"),
}, },
signature: [0u8; 32], signature: [0u8; 32],

View File

@ -2,8 +2,8 @@ use axum::{
routing::{get, post}, routing::{get, post},
Router, Router,
}; };
use proxima_centauri::{process_command, root}; use proxima_centauri::{process_command, root, GlobalState};
use std::net::SocketAddr; use std::{net::SocketAddr, sync::Arc};
use tracing::Level; use tracing::Level;
#[tokio::main] #[tokio::main]
@ -15,12 +15,14 @@ async fn main() {
tracing::subscriber::set_global_default(subscriber).unwrap(); tracing::subscriber::set_global_default(subscriber).unwrap();
let shared_state = Arc::new(GlobalState::new());
// build our application with a route // build our application with a route
let app = Router::new() let app = Router::new()
// `GET /` goes to `root` // `GET /` goes to `root`
.route("/", get(root)) .route("/", get(root))
// `POST /command` goes to `process_command` // `POST /command` goes to `process_command`
.route("/command", post(process_command)); .route("/command", post(process_command))
.with_state(shared_state);
// run our app with hyper // run our app with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000));