Add checks on incoming command
This commit is contained in:
parent
b684c1dd76
commit
f3dd96e3a6
75
src/lib.rs
75
src/lib.rs
@ -3,10 +3,10 @@ use axum::{http::StatusCode, Json};
|
|||||||
use p384::ecdsa::signature::Verifier;
|
use p384::ecdsa::signature::Verifier;
|
||||||
use p384::ecdsa::{Signature, VerifyingKey};
|
use p384::ecdsa::{Signature, VerifyingKey};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
use tokio::io::{self, AsyncWriteExt};
|
use tokio::io::{self, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::watch::{self, Receiver, Sender};
|
use tokio::sync::watch::{self, Receiver, Sender};
|
||||||
@ -59,6 +59,7 @@ pub struct ProxyResponse {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GlobalState {
|
pub struct GlobalState {
|
||||||
proxies: Mutex<HashMap<Uuid, ProxyState>>,
|
proxies: Mutex<HashMap<Uuid, ProxyState>>,
|
||||||
|
ports: RwLock<HashSet<u16>>,
|
||||||
verifying_key: Option<VerifyingKey>,
|
verifying_key: Option<VerifyingKey>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,6 +67,7 @@ impl GlobalState {
|
|||||||
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
|
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
proxies: Mutex::new(HashMap::new()),
|
proxies: Mutex::new(HashMap::new()),
|
||||||
|
ports: RwLock::new(HashSet::new()),
|
||||||
verifying_key: verifying_key
|
verifying_key: verifying_key
|
||||||
.map(|key| VerifyingKey::from_str(key.as_ref()).ok())
|
.map(|key| VerifyingKey::from_str(key.as_ref()).ok())
|
||||||
.flatten(),
|
.flatten(),
|
||||||
@ -75,6 +77,7 @@ impl GlobalState {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct ProxyState {
|
struct ProxyState {
|
||||||
|
incoming_port: u16,
|
||||||
destination: SocketAddr,
|
destination: SocketAddr,
|
||||||
control: Sender<ProxyControlMessage>,
|
control: Sender<ProxyControlMessage>,
|
||||||
}
|
}
|
||||||
@ -104,16 +107,43 @@ pub async fn process_command(
|
|||||||
destination_ip,
|
destination_ip,
|
||||||
id,
|
id,
|
||||||
} => {
|
} => {
|
||||||
|
// Check if ID or incoming_port already exists
|
||||||
|
if state.proxies.lock().unwrap().get(&id).is_some() {
|
||||||
|
return (
|
||||||
|
StatusCode::CONFLICT,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: "Id already exists. Use the modify command instead.".to_string(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !state.ports.write().unwrap().insert(incoming_port) {
|
||||||
|
return (
|
||||||
|
StatusCode::CONFLICT,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: format!("The `incoming_port` already in use: {incoming_port}"),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let addr = SocketAddr::new(destination_ip, destination_port);
|
let addr = SocketAddr::new(destination_ip, destination_port);
|
||||||
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
|
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
|
||||||
state.proxies.lock().unwrap().insert(
|
state.proxies.lock().unwrap().insert(
|
||||||
id,
|
id,
|
||||||
ProxyState {
|
ProxyState {
|
||||||
|
incoming_port,
|
||||||
destination: addr,
|
destination: addr,
|
||||||
control: tx,
|
control: tx,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
|
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
|
||||||
|
(
|
||||||
|
StatusCode::ACCEPTED,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: format!(
|
||||||
|
"Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}"
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
Command::Modify {
|
Command::Modify {
|
||||||
destination_port,
|
destination_port,
|
||||||
@ -129,20 +159,43 @@ pub async fn process_command(
|
|||||||
destination: proxy.destination,
|
destination: proxy.destination,
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
|
||||||
}
|
|
||||||
Command::Delete { id } => {
|
|
||||||
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
|
|
||||||
proxy.control.send(ProxyControlMessage::Close).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(
|
(
|
||||||
StatusCode::ACCEPTED,
|
StatusCode::ACCEPTED,
|
||||||
Json(ProxyResponse {
|
Json(ProxyResponse {
|
||||||
message: "success".to_string(),
|
message: format!(
|
||||||
|
"Changed tunnel {id} to use {destination_ip}:{destination_port}"
|
||||||
|
),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: format!("Id not found: {id}"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Command::Delete { id } => {
|
||||||
|
if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) {
|
||||||
|
proxy.control.send(ProxyControlMessage::Close).unwrap();
|
||||||
|
state.ports.write().unwrap().remove(&proxy.incoming_port);
|
||||||
|
(
|
||||||
|
StatusCode::ACCEPTED,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: format!("Deleted tunnel: {id}"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ProxyResponse {
|
||||||
|
message: format!("Id not found: {id}"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
Loading…
Reference in New Issue
Block a user