Add proxy control
This commit is contained in:
		
							
								
								
									
										118
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										118
									
								
								src/lib.rs
									
									
									
									
									
								
							@@ -1,8 +1,12 @@
 | 
			
		||||
use axum::extract::State;
 | 
			
		||||
use axum::{http::StatusCode, Json};
 | 
			
		||||
use serde::{Deserialize, Serialize};
 | 
			
		||||
use std::net::{IpAddr, SocketAddr};
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
use std::net::{Ipv4Addr, SocketAddrV4};
 | 
			
		||||
use std::sync::{Arc, Mutex};
 | 
			
		||||
use tokio::io::{self, AsyncWriteExt};
 | 
			
		||||
use tokio::net::{TcpListener, TcpStream};
 | 
			
		||||
use tokio::sync::watch::{self, Receiver, Sender};
 | 
			
		||||
use uuid::Uuid;
 | 
			
		||||
 | 
			
		||||
#[derive(Deserialize, Serialize, Debug)]
 | 
			
		||||
@@ -16,11 +20,12 @@ enum Command {
 | 
			
		||||
    New {
 | 
			
		||||
        incoming_port: u16,
 | 
			
		||||
        destination_port: u16,
 | 
			
		||||
        destination_ip: IpAddr,
 | 
			
		||||
        destination_ip: Ipv4Addr,
 | 
			
		||||
        id: Uuid,
 | 
			
		||||
    },
 | 
			
		||||
    Modify {
 | 
			
		||||
        destionation_ip: IpAddr,
 | 
			
		||||
        destination_port: u16,
 | 
			
		||||
        destination_ip: Ipv4Addr,
 | 
			
		||||
        id: Uuid,
 | 
			
		||||
    },
 | 
			
		||||
    Delete {
 | 
			
		||||
@@ -33,11 +38,31 @@ pub struct ProxyResponse {
 | 
			
		||||
    message: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct GlobalState {
 | 
			
		||||
    proxies: Mutex<HashMap<Uuid, ProxyState>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl GlobalState {
 | 
			
		||||
    pub fn new() -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            proxies: Mutex::new(HashMap::new()),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
struct ProxyState {
 | 
			
		||||
    destination: SocketAddrV4,
 | 
			
		||||
    control: Sender<ProxyControlMessage>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn root() -> &'static str {
 | 
			
		||||
    "Hello, World!"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub async fn process_command(
 | 
			
		||||
    State(state): State<Arc<GlobalState>>,
 | 
			
		||||
    Json(payload): Json<ProxyCommand>,
 | 
			
		||||
) -> (StatusCode, Json<ProxyResponse>) {
 | 
			
		||||
    tracing::error!("Received payload: {:?}", payload);
 | 
			
		||||
@@ -49,19 +74,38 @@ pub async fn process_command(
 | 
			
		||||
            destination_ip,
 | 
			
		||||
            id,
 | 
			
		||||
        } => {
 | 
			
		||||
            // TODO: add id to global proxy map
 | 
			
		||||
            add_proxy(
 | 
			
		||||
                incoming_port,
 | 
			
		||||
                SocketAddr::new(destination_ip, destination_port),
 | 
			
		||||
            )
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap(); // TODO: error propagation??
 | 
			
		||||
            let addr = SocketAddrV4::new(destination_ip, destination_port);
 | 
			
		||||
            let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
 | 
			
		||||
            state.proxies.lock().unwrap().insert(
 | 
			
		||||
                id,
 | 
			
		||||
                ProxyState {
 | 
			
		||||
                    destination: addr,
 | 
			
		||||
                    control: tx,
 | 
			
		||||
                },
 | 
			
		||||
            );
 | 
			
		||||
            add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
 | 
			
		||||
        }
 | 
			
		||||
        Command::Modify {
 | 
			
		||||
            destionation_ip,
 | 
			
		||||
            destination_port,
 | 
			
		||||
            destination_ip,
 | 
			
		||||
            id,
 | 
			
		||||
        } => todo!(),
 | 
			
		||||
        Command::Delete { id } => todo!(),
 | 
			
		||||
        } => {
 | 
			
		||||
            if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
 | 
			
		||||
                proxy.destination.set_port(destination_port);
 | 
			
		||||
                proxy.destination.set_ip(destination_ip);
 | 
			
		||||
                proxy
 | 
			
		||||
                    .control
 | 
			
		||||
                    .send(ProxyControlMessage::Open {
 | 
			
		||||
                        destination: proxy.destination,
 | 
			
		||||
                    })
 | 
			
		||||
                    .unwrap();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        Command::Delete { id } => {
 | 
			
		||||
            if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
 | 
			
		||||
                proxy.control.send(ProxyControlMessage::Close).unwrap();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    (
 | 
			
		||||
        StatusCode::CREATED,
 | 
			
		||||
@@ -71,24 +115,54 @@ pub async fn process_command(
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn add_proxy(in_port: u16, destination: SocketAddr) -> anyhow::Result<()> {
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
enum ProxyControlMessage {
 | 
			
		||||
    Open { destination: SocketAddrV4 }, // Reroute { new: SocketAddr },
 | 
			
		||||
    Close,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn add_proxy(in_port: u16, control: Receiver<ProxyControlMessage>) -> anyhow::Result<()> {
 | 
			
		||||
    let listener = TcpListener::bind(("127.0.0.1", in_port)).await?;
 | 
			
		||||
 | 
			
		||||
    tracing::info!("proxying port {in_port} to {destination}");
 | 
			
		||||
    tracing::info!("proxying port {in_port} to {:?}", *control.borrow());
 | 
			
		||||
 | 
			
		||||
    tokio::spawn(proxy(listener, destination));
 | 
			
		||||
    tokio::spawn(proxy(listener, control));
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn proxy(listener: TcpListener, destination: SocketAddr) {
 | 
			
		||||
    while let Ok((inbound, _)) = listener.accept().await {
 | 
			
		||||
        let transfer = transfer(inbound, destination);
 | 
			
		||||
async fn proxy(listener: TcpListener, mut control: Receiver<ProxyControlMessage>) {
 | 
			
		||||
    let mut current_destination =
 | 
			
		||||
        if let ProxyControlMessage::Open { destination } = *control.borrow() {
 | 
			
		||||
            Some(destination)
 | 
			
		||||
        } else {
 | 
			
		||||
            None
 | 
			
		||||
        };
 | 
			
		||||
    loop {
 | 
			
		||||
        tokio::select! {
 | 
			
		||||
            l = listener.accept()=> {
 | 
			
		||||
                if let Ok((inbound, _)) = l {
 | 
			
		||||
                    let transfer = transfer(inbound, current_destination.unwrap());
 | 
			
		||||
 | 
			
		||||
        tokio::spawn(transfer);
 | 
			
		||||
                    tokio::spawn(transfer);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            _ = control.changed() => {
 | 
			
		||||
                match *control.borrow() {
 | 
			
		||||
                    ProxyControlMessage::Open { destination } => {
 | 
			
		||||
                        tracing::info!("destination for proxy port {} changed to {}", listener.local_addr().unwrap(), destination);
 | 
			
		||||
                        current_destination=Some(destination);
 | 
			
		||||
                    },
 | 
			
		||||
                    ProxyControlMessage::Close => {
 | 
			
		||||
                        tracing::info!("destination for proxy port {} closed", listener.local_addr().unwrap());
 | 
			
		||||
                        return;
 | 
			
		||||
                    },
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn transfer(mut inbound: TcpStream, destination: SocketAddr) -> anyhow::Result<()> {
 | 
			
		||||
async fn transfer(mut inbound: TcpStream, destination: SocketAddrV4) -> anyhow::Result<()> {
 | 
			
		||||
    let mut outbound = TcpStream::connect(destination).await?;
 | 
			
		||||
 | 
			
		||||
    let (mut ri, mut wi) = inbound.split();
 | 
			
		||||
@@ -123,7 +197,7 @@ mod tests {
 | 
			
		||||
            command: Command::New {
 | 
			
		||||
                incoming_port: 5555,
 | 
			
		||||
                destination_port: 6666,
 | 
			
		||||
                destination_ip: std::net::IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
 | 
			
		||||
                destination_ip: Ipv4Addr::new(127, 0, 0, 1),
 | 
			
		||||
                id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"),
 | 
			
		||||
            },
 | 
			
		||||
            signature: [0u8; 32],
 | 
			
		||||
 
 | 
			
		||||
@@ -2,8 +2,8 @@ use axum::{
 | 
			
		||||
    routing::{get, post},
 | 
			
		||||
    Router,
 | 
			
		||||
};
 | 
			
		||||
use proxima_centauri::{process_command, root};
 | 
			
		||||
use std::net::SocketAddr;
 | 
			
		||||
use proxima_centauri::{process_command, root, GlobalState};
 | 
			
		||||
use std::{net::SocketAddr, sync::Arc};
 | 
			
		||||
use tracing::Level;
 | 
			
		||||
 | 
			
		||||
#[tokio::main]
 | 
			
		||||
@@ -15,12 +15,14 @@ async fn main() {
 | 
			
		||||
 | 
			
		||||
    tracing::subscriber::set_global_default(subscriber).unwrap();
 | 
			
		||||
 | 
			
		||||
    let shared_state = Arc::new(GlobalState::new());
 | 
			
		||||
    // build our application with a route
 | 
			
		||||
    let app = Router::new()
 | 
			
		||||
        // `GET /` goes to `root`
 | 
			
		||||
        .route("/", get(root))
 | 
			
		||||
        // `POST /command` goes to `process_command`
 | 
			
		||||
        .route("/command", post(process_command));
 | 
			
		||||
        .route("/command", post(process_command))
 | 
			
		||||
        .with_state(shared_state);
 | 
			
		||||
 | 
			
		||||
    // run our app with hyper
 | 
			
		||||
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user