Add checks on incoming command
This commit is contained in:
parent
b684c1dd76
commit
f3dd96e3a6
71
src/lib.rs
71
src/lib.rs
@ -3,10 +3,10 @@ use axum::{http::StatusCode, Json};
|
||||
use p384::ecdsa::signature::Verifier;
|
||||
use p384::ecdsa::{Signature, VerifyingKey};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use tokio::io::{self, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::watch::{self, Receiver, Sender};
|
||||
@ -59,6 +59,7 @@ pub struct ProxyResponse {
|
||||
#[derive(Debug)]
|
||||
pub struct GlobalState {
|
||||
proxies: Mutex<HashMap<Uuid, ProxyState>>,
|
||||
ports: RwLock<HashSet<u16>>,
|
||||
verifying_key: Option<VerifyingKey>,
|
||||
}
|
||||
|
||||
@ -66,6 +67,7 @@ impl GlobalState {
|
||||
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
|
||||
Self {
|
||||
proxies: Mutex::new(HashMap::new()),
|
||||
ports: RwLock::new(HashSet::new()),
|
||||
verifying_key: verifying_key
|
||||
.map(|key| VerifyingKey::from_str(key.as_ref()).ok())
|
||||
.flatten(),
|
||||
@ -75,6 +77,7 @@ impl GlobalState {
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ProxyState {
|
||||
incoming_port: u16,
|
||||
destination: SocketAddr,
|
||||
control: Sender<ProxyControlMessage>,
|
||||
}
|
||||
@ -104,16 +107,43 @@ pub async fn process_command(
|
||||
destination_ip,
|
||||
id,
|
||||
} => {
|
||||
// Check if ID or incoming_port already exists
|
||||
if state.proxies.lock().unwrap().get(&id).is_some() {
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
Json(ProxyResponse {
|
||||
message: "Id already exists. Use the modify command instead.".to_string(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
if !state.ports.write().unwrap().insert(incoming_port) {
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
Json(ProxyResponse {
|
||||
message: format!("The `incoming_port` already in use: {incoming_port}"),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
let addr = SocketAddr::new(destination_ip, destination_port);
|
||||
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
|
||||
state.proxies.lock().unwrap().insert(
|
||||
id,
|
||||
ProxyState {
|
||||
incoming_port,
|
||||
destination: addr,
|
||||
control: tx,
|
||||
},
|
||||
);
|
||||
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
|
||||
(
|
||||
StatusCode::ACCEPTED,
|
||||
Json(ProxyResponse {
|
||||
message: format!(
|
||||
"Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}"
|
||||
),
|
||||
}),
|
||||
)
|
||||
}
|
||||
Command::Modify {
|
||||
destination_port,
|
||||
@ -129,20 +159,43 @@ pub async fn process_command(
|
||||
destination: proxy.destination,
|
||||
})
|
||||
.unwrap();
|
||||
(
|
||||
StatusCode::ACCEPTED,
|
||||
Json(ProxyResponse {
|
||||
message: format!(
|
||||
"Changed tunnel {id} to use {destination_ip}:{destination_port}"
|
||||
),
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ProxyResponse {
|
||||
message: format!("Id not found: {id}"),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
Command::Delete { id } => {
|
||||
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
|
||||
if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) {
|
||||
proxy.control.send(ProxyControlMessage::Close).unwrap();
|
||||
state.ports.write().unwrap().remove(&proxy.incoming_port);
|
||||
(
|
||||
StatusCode::ACCEPTED,
|
||||
Json(ProxyResponse {
|
||||
message: format!("Deleted tunnel: {id}"),
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ProxyResponse {
|
||||
message: format!("Id not found: {id}"),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
(
|
||||
StatusCode::ACCEPTED,
|
||||
Json(ProxyResponse {
|
||||
message: "success".to_string(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user