Add checks on incoming command

This commit is contained in:
Erik 2023-04-25 14:56:49 +02:00
parent b684c1dd76
commit f3dd96e3a6

View File

@ -3,10 +3,10 @@ use axum::{http::StatusCode, Json};
use p384::ecdsa::signature::Verifier;
use p384::ecdsa::{Signature, VerifyingKey};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use tokio::io::{self, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch::{self, Receiver, Sender};
@ -59,6 +59,7 @@ pub struct ProxyResponse {
#[derive(Debug)]
pub struct GlobalState {
proxies: Mutex<HashMap<Uuid, ProxyState>>,
ports: RwLock<HashSet<u16>>,
verifying_key: Option<VerifyingKey>,
}
@ -66,6 +67,7 @@ impl GlobalState {
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
Self {
proxies: Mutex::new(HashMap::new()),
ports: RwLock::new(HashSet::new()),
verifying_key: verifying_key
.map(|key| VerifyingKey::from_str(key.as_ref()).ok())
.flatten(),
@ -75,6 +77,7 @@ impl GlobalState {
#[derive(Debug)]
struct ProxyState {
incoming_port: u16,
destination: SocketAddr,
control: Sender<ProxyControlMessage>,
}
@ -104,16 +107,43 @@ pub async fn process_command(
destination_ip,
id,
} => {
// Check if ID or incoming_port already exists
if state.proxies.lock().unwrap().get(&id).is_some() {
return (
StatusCode::CONFLICT,
Json(ProxyResponse {
message: "Id already exists. Use the modify command instead.".to_string(),
}),
);
}
if !state.ports.write().unwrap().insert(incoming_port) {
return (
StatusCode::CONFLICT,
Json(ProxyResponse {
message: format!("The `incoming_port` already in use: {incoming_port}"),
}),
);
}
let addr = SocketAddr::new(destination_ip, destination_port);
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
state.proxies.lock().unwrap().insert(
id,
ProxyState {
incoming_port,
destination: addr,
control: tx,
},
);
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!(
"Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}"
),
}),
)
}
Command::Modify {
destination_port,
@ -129,20 +159,43 @@ pub async fn process_command(
destination: proxy.destination,
})
.unwrap();
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!(
"Changed tunnel {id} to use {destination_ip}:{destination_port}"
),
}),
)
} else {
(
StatusCode::NOT_FOUND,
Json(ProxyResponse {
message: format!("Id not found: {id}"),
}),
)
}
}
Command::Delete { id } => {
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) {
proxy.control.send(ProxyControlMessage::Close).unwrap();
state.ports.write().unwrap().remove(&proxy.incoming_port);
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: format!("Deleted tunnel: {id}"),
}),
)
} else {
(
StatusCode::NOT_FOUND,
Json(ProxyResponse {
message: format!("Id not found: {id}"),
}),
)
}
}
}
(
StatusCode::ACCEPTED,
Json(ProxyResponse {
message: "success".to_string(),
}),
)
}
#[derive(Debug)]