From f3dd96e3a64128adb0eead2c025b5d958c6d089f Mon Sep 17 00:00:00 2001 From: Erik Date: Tue, 25 Apr 2023 14:56:49 +0200 Subject: [PATCH] Add checks on incoming command --- src/lib.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index af84172..1f34250 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,10 @@ use axum::{http::StatusCode, Json}; use p384::ecdsa::signature::Verifier; use p384::ecdsa::{Signature, VerifyingKey}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::watch::{self, Receiver, Sender}; @@ -59,6 +59,7 @@ pub struct ProxyResponse { #[derive(Debug)] pub struct GlobalState { proxies: Mutex>, + ports: RwLock>, verifying_key: Option, } @@ -66,6 +67,7 @@ impl GlobalState { pub fn new>(verifying_key: Option) -> Self { Self { proxies: Mutex::new(HashMap::new()), + ports: RwLock::new(HashSet::new()), verifying_key: verifying_key .map(|key| VerifyingKey::from_str(key.as_ref()).ok()) .flatten(), @@ -75,6 +77,7 @@ impl GlobalState { #[derive(Debug)] struct ProxyState { + incoming_port: u16, destination: SocketAddr, control: Sender, } @@ -104,16 +107,43 @@ pub async fn process_command( destination_ip, id, } => { + // Check if ID or incoming_port already exists + if state.proxies.lock().unwrap().get(&id).is_some() { + return ( + StatusCode::CONFLICT, + Json(ProxyResponse { + message: "Id already exists. Use the modify command instead.".to_string(), + }), + ); + } + if !state.ports.write().unwrap().insert(incoming_port) { + return ( + StatusCode::CONFLICT, + Json(ProxyResponse { + message: format!("The `incoming_port` already in use: {incoming_port}"), + }), + ); + } + let addr = SocketAddr::new(destination_ip, destination_port); let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr }); state.proxies.lock().unwrap().insert( id, ProxyState { + incoming_port, destination: addr, control: tx, }, ); add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation?? + ( + StatusCode::ACCEPTED, + Json(ProxyResponse { + message: format!( + "Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}" + ), + }), + ) } Command::Modify { destination_port, @@ -129,20 +159,43 @@ pub async fn process_command( destination: proxy.destination, }) .unwrap(); + ( + StatusCode::ACCEPTED, + Json(ProxyResponse { + message: format!( + "Changed tunnel {id} to use {destination_ip}:{destination_port}" + ), + }), + ) + } else { + ( + StatusCode::NOT_FOUND, + Json(ProxyResponse { + message: format!("Id not found: {id}"), + }), + ) } } Command::Delete { id } => { - if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) { + if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) { proxy.control.send(ProxyControlMessage::Close).unwrap(); + state.ports.write().unwrap().remove(&proxy.incoming_port); + ( + StatusCode::ACCEPTED, + Json(ProxyResponse { + message: format!("Deleted tunnel: {id}"), + }), + ) + } else { + ( + StatusCode::NOT_FOUND, + Json(ProxyResponse { + message: format!("Id not found: {id}"), + }), + ) } } } - ( - StatusCode::ACCEPTED, - Json(ProxyResponse { - message: "success".to_string(), - }), - ) } #[derive(Debug)]