Add checks on incoming command

This commit is contained in:
Erik 2023-04-25 14:56:49 +02:00
parent b684c1dd76
commit f3dd96e3a6

View File

@ -3,10 +3,10 @@ use axum::{http::StatusCode, Json};
use p384::ecdsa::signature::Verifier; use p384::ecdsa::signature::Verifier;
use p384::ecdsa::{Signature, VerifyingKey}; use p384::ecdsa::{Signature, VerifyingKey};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex, RwLock};
use tokio::io::{self, AsyncWriteExt}; use tokio::io::{self, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch::{self, Receiver, Sender}; use tokio::sync::watch::{self, Receiver, Sender};
@ -59,6 +59,7 @@ pub struct ProxyResponse {
#[derive(Debug)] #[derive(Debug)]
pub struct GlobalState { pub struct GlobalState {
proxies: Mutex<HashMap<Uuid, ProxyState>>, proxies: Mutex<HashMap<Uuid, ProxyState>>,
ports: RwLock<HashSet<u16>>,
verifying_key: Option<VerifyingKey>, verifying_key: Option<VerifyingKey>,
} }
@ -66,6 +67,7 @@ impl GlobalState {
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self { pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
Self { Self {
proxies: Mutex::new(HashMap::new()), proxies: Mutex::new(HashMap::new()),
ports: RwLock::new(HashSet::new()),
verifying_key: verifying_key verifying_key: verifying_key
.map(|key| VerifyingKey::from_str(key.as_ref()).ok()) .map(|key| VerifyingKey::from_str(key.as_ref()).ok())
.flatten(), .flatten(),
@ -75,6 +77,7 @@ impl GlobalState {
#[derive(Debug)] #[derive(Debug)]
struct ProxyState { struct ProxyState {
incoming_port: u16,
destination: SocketAddr, destination: SocketAddr,
control: Sender<ProxyControlMessage>, control: Sender<ProxyControlMessage>,
} }
@ -104,16 +107,43 @@ pub async fn process_command(
destination_ip, destination_ip,
id, id,
} => { } => {
// Check if ID or incoming_port already exists
if state.proxies.lock().unwrap().get(&id).is_some() {
return (
StatusCode::CONFLICT,
Json(ProxyResponse {
message: "Id already exists. Use the modify command instead.".to_string(),
}),
);
}
if !state.ports.write().unwrap().insert(incoming_port) {
return (
StatusCode::CONFLICT,
Json(ProxyResponse {
message: format!("The `incoming_port` already in use: {incoming_port}"),
}),
);
}
let addr = SocketAddr::new(destination_ip, destination_port); let addr = SocketAddr::new(destination_ip, destination_port);
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr }); let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
state.proxies.lock().unwrap().insert( state.proxies.lock().unwrap().insert(
id, id,
ProxyState { ProxyState {
incoming_port,
destination: addr, destination: addr,
control: tx, control: tx,
}, },
); );
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation?? add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!(
"Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}"
),
}),
)
} }
Command::Modify { Command::Modify {
destination_port, destination_port,
@ -129,20 +159,43 @@ pub async fn process_command(
destination: proxy.destination, destination: proxy.destination,
}) })
.unwrap(); .unwrap();
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!(
"Changed tunnel {id} to use {destination_ip}:{destination_port}"
),
}),
)
} else {
(
StatusCode::NOT_FOUND,
Json(ProxyResponse {
message: format!("Id not found: {id}"),
}),
)
} }
} }
Command::Delete { id } => { Command::Delete { id } => {
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) { if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) {
proxy.control.send(ProxyControlMessage::Close).unwrap(); proxy.control.send(ProxyControlMessage::Close).unwrap();
state.ports.write().unwrap().remove(&proxy.incoming_port);
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!("Deleted tunnel: {id}"),
}),
)
} else {
(
StatusCode::NOT_FOUND,
Json(ProxyResponse {
message: format!("Id not found: {id}"),
}),
)
} }
} }
} }
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: "success".to_string(),
}),
)
} }
#[derive(Debug)] #[derive(Debug)]