385 lines
13 KiB
Rust
385 lines
13 KiB
Rust
use axum::extract::State;
|
|
use axum::{http::StatusCode, Json};
|
|
use p384::ecdsa::signature::Verifier;
|
|
use p384::ecdsa::{Signature, VerifyingKey};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::str::FromStr;
|
|
use std::sync::{Arc, Mutex, RwLock};
|
|
use tokio::io::{self, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::watch::{self, Receiver, Sender};
|
|
use uuid::Uuid;
|
|
|
|
#[derive(Deserialize, Serialize, Debug)]
|
|
pub struct ProxyCommand {
|
|
#[serde(flatten)]
|
|
command: Command,
|
|
signature: Option<Signature>,
|
|
}
|
|
|
|
impl ProxyCommand {
|
|
fn verify_signature(&self, verifying_key: &Option<VerifyingKey>) -> bool {
|
|
match (verifying_key, &self.signature) {
|
|
(Some(key), Some(signature)) => {
|
|
let message = serde_json::to_string(&self.command).unwrap();
|
|
key.verify(message.as_bytes(), signature).is_ok()
|
|
}
|
|
(Some(_), None) => false,
|
|
(None, _) => true,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize, Serialize, Debug)]
|
|
#[serde(rename_all = "snake_case")]
|
|
enum Command {
|
|
Create {
|
|
incoming_port: u16,
|
|
destination_port: u16,
|
|
destination_ip: IpAddr,
|
|
id: Uuid,
|
|
},
|
|
Modify {
|
|
destination_port: u16,
|
|
destination_ip: IpAddr,
|
|
id: Uuid,
|
|
},
|
|
Delete {
|
|
id: Uuid,
|
|
},
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
pub struct ProxyResponse {
|
|
message: String,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct GlobalState {
|
|
proxies: Mutex<HashMap<Uuid, ProxyState>>,
|
|
ports: RwLock<HashSet<u16>>,
|
|
verifying_key: Option<VerifyingKey>,
|
|
}
|
|
|
|
impl GlobalState {
|
|
pub fn new<S: AsRef<str>>(verifying_key: Option<S>) -> Self {
|
|
Self {
|
|
proxies: Mutex::new(HashMap::new()),
|
|
ports: RwLock::new(HashSet::new()),
|
|
verifying_key: verifying_key
|
|
.map(|key| VerifyingKey::from_str(key.as_ref()).ok())
|
|
.flatten(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ProxyState {
|
|
incoming_port: u16,
|
|
destination: SocketAddr,
|
|
control: Sender<ProxyControlMessage>,
|
|
}
|
|
|
|
/// Handler returning a fixed greeting string.
pub async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}
|
|
|
|
pub async fn process_command(
|
|
State(state): State<Arc<GlobalState>>,
|
|
Json(payload): Json<ProxyCommand>,
|
|
) -> (StatusCode, Json<ProxyResponse>) {
|
|
tracing::info!("Received payload: {:?}", payload);
|
|
|
|
if !payload.verify_signature(&state.verifying_key) {
|
|
return (
|
|
StatusCode::UNAUTHORIZED,
|
|
Json(ProxyResponse {
|
|
message: "Invalid signature".to_string(),
|
|
}),
|
|
);
|
|
}
|
|
match payload.command {
|
|
Command::Create {
|
|
incoming_port,
|
|
destination_port,
|
|
destination_ip,
|
|
id,
|
|
} => {
|
|
// Check if ID or incoming_port already exists
|
|
if state.proxies.lock().unwrap().get(&id).is_some() {
|
|
return (
|
|
StatusCode::CONFLICT,
|
|
Json(ProxyResponse {
|
|
message: "Id already exists. Use the modify command instead.".to_string(),
|
|
}),
|
|
);
|
|
}
|
|
if !state.ports.write().unwrap().insert(incoming_port) {
|
|
return (
|
|
StatusCode::CONFLICT,
|
|
Json(ProxyResponse {
|
|
message: format!("The `incoming_port` already in use: {incoming_port}"),
|
|
}),
|
|
);
|
|
}
|
|
|
|
let addr = SocketAddr::new(destination_ip, destination_port);
|
|
let (tx, rx) = watch::channel(ProxyControlMessage::Open { destination: addr });
|
|
state.proxies.lock().unwrap().insert(
|
|
id,
|
|
ProxyState {
|
|
incoming_port,
|
|
destination: addr,
|
|
control: tx,
|
|
},
|
|
);
|
|
add_proxy(incoming_port, rx).await.unwrap(); // TODO: error propagation??
|
|
(
|
|
StatusCode::ACCEPTED,
|
|
Json(ProxyResponse {
|
|
message: format!(
|
|
"Created tunnel {id} on port {incoming_port} to use {destination_ip}:{destination_port}"
|
|
),
|
|
}),
|
|
)
|
|
}
|
|
Command::Modify {
|
|
destination_port,
|
|
destination_ip,
|
|
id,
|
|
} => {
|
|
if let Some(proxy) = state.proxies.lock().unwrap().get_mut(&id) {
|
|
proxy.destination.set_port(destination_port);
|
|
proxy.destination.set_ip(destination_ip);
|
|
proxy
|
|
.control
|
|
.send(ProxyControlMessage::Open {
|
|
destination: proxy.destination,
|
|
})
|
|
.unwrap();
|
|
(
|
|
StatusCode::ACCEPTED,
|
|
Json(ProxyResponse {
|
|
message: format!(
|
|
"Changed tunnel {id} to use {destination_ip}:{destination_port}"
|
|
),
|
|
}),
|
|
)
|
|
} else {
|
|
(
|
|
StatusCode::NOT_FOUND,
|
|
Json(ProxyResponse {
|
|
message: format!("Id not found: {id}"),
|
|
}),
|
|
)
|
|
}
|
|
}
|
|
Command::Delete { id } => {
|
|
if let Some(proxy) = state.proxies.lock().unwrap().remove(&id) {
|
|
proxy.control.send(ProxyControlMessage::Close).unwrap();
|
|
state.ports.write().unwrap().remove(&proxy.incoming_port);
|
|
(
|
|
StatusCode::ACCEPTED,
|
|
Json(ProxyResponse {
|
|
message: format!("Deleted tunnel: {id}"),
|
|
}),
|
|
)
|
|
} else {
|
|
(
|
|
StatusCode::NOT_FOUND,
|
|
Json(ProxyResponse {
|
|
message: format!("Id not found: {id}"),
|
|
}),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Value carried on a tunnel's watch channel, telling its tasks what to do.
#[derive(Debug)]
enum ProxyControlMessage {
    /// Forward traffic to `destination`.
    Open { destination: SocketAddr },
    /// Stop accepting and forwarding traffic.
    Close,
}
|
|
|
|
async fn add_proxy(in_port: u16, control: Receiver<ProxyControlMessage>) -> anyhow::Result<()> {
|
|
let listener = TcpListener::bind(("127.0.0.1", in_port)).await?;
|
|
|
|
tracing::info!("proxying port {in_port} to {:?}", *control.borrow());
|
|
|
|
tokio::spawn(proxy(listener, control));
|
|
Ok(())
|
|
}
|
|
|
|
async fn proxy(listener: TcpListener, mut control: Receiver<ProxyControlMessage>) {
|
|
loop {
|
|
tokio::select! {
|
|
l = listener.accept()=> {
|
|
if let Ok((inbound, _)) = l {
|
|
let transfer = transfer(inbound, control.clone());
|
|
|
|
tokio::spawn(transfer);
|
|
}
|
|
}
|
|
_ = control.changed() => {
|
|
match *control.borrow() {
|
|
ProxyControlMessage::Open { destination } => {
|
|
tracing::info!("destination for proxy port {} changed to {}", listener.local_addr().unwrap(), destination);
|
|
},
|
|
ProxyControlMessage::Close => {
|
|
tracing::info!("destination for proxy port {} closed", listener.local_addr().unwrap());
|
|
return;
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn transfer(
|
|
mut inbound: TcpStream,
|
|
mut control: Receiver<ProxyControlMessage>,
|
|
) -> anyhow::Result<()> {
|
|
loop {
|
|
let current_destination =
|
|
if let ProxyControlMessage::Open { destination } = *control.borrow() {
|
|
Some(destination)
|
|
} else {
|
|
break Ok(());
|
|
};
|
|
let mut outbound = TcpStream::connect(current_destination.unwrap()).await?;
|
|
|
|
let (mut ri, mut wi) = inbound.split();
|
|
let (mut ro, mut wo) = outbound.split();
|
|
|
|
let client_to_server = async {
|
|
io::copy(&mut ri, &mut wo).await?;
|
|
wo.shutdown().await
|
|
};
|
|
|
|
let server_to_client = async {
|
|
io::copy(&mut ro, &mut wi).await?;
|
|
wi.shutdown().await
|
|
};
|
|
|
|
// Select between the copy tasks and watch channel
|
|
tokio::select! {
|
|
// Join the two copy streams and wait for the connection to clone
|
|
result = async move { tokio::join!(client_to_server, server_to_client) } => {
|
|
match result {
|
|
(Ok(_), Ok(_)) => {
|
|
break Ok(());
|
|
}
|
|
(r1, r2) => {
|
|
if r1.is_err() {
|
|
tracing::error!("error closing client->server of {:?}: {:?}", inbound, &r1);
|
|
}
|
|
if r2.is_err() {
|
|
tracing::error!("error closing server->client of {:?}: {:?}", inbound, &r2);
|
|
}
|
|
r1?;
|
|
r2?;
|
|
},
|
|
}
|
|
}
|
|
_ = control.changed() => {
|
|
match *control.borrow() {
|
|
ProxyControlMessage::Open { destination } => {
|
|
eprintln!("Switching to new destination: {destination}");
|
|
// Disconnect the current outbound connection and restart the loop
|
|
drop(outbound);
|
|
continue;
|
|
},
|
|
ProxyControlMessage::Close => {
|
|
break Ok(());
|
|
},
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use crate::{Command, ProxyCommand};
    use p384::{
        ecdsa::{signature::Signer, Signature, SigningKey, VerifyingKey},
        elliptic_curve::rand_core::OsRng,
    };
    use uuid::{uuid, Uuid};

    /// Builds the `Create` command used by the signature tests, varying only
    /// the destination port so a "tampered" variant is easy to produce.
    fn sample_create(id: Uuid, destination_port: u16) -> Command {
        Command::Create {
            incoming_port: 4567,
            destination_port,
            destination_ip: IpAddr::V4(Ipv4Addr::new(123, 23, 76, 21)),
            id,
        }
    }

    /// Signs the JSON serialization of `command`, the same message that
    /// `ProxyCommand::verify_signature` verifies.
    fn sign(signing_key: &SigningKey, command: &Command) -> Signature {
        let message = serde_json::to_string(command).unwrap();
        signing_key.sign(message.as_bytes())
    }

    #[test]
    fn serialize_proxy_command_create() {
        let key = SigningKey::from_slice(&[1; 48]).unwrap();
        let signature = key.sign(&[]); // Not a valid signature
        let proxy_command = ProxyCommand {
            command: Command::Create {
                incoming_port: 5555,
                destination_port: 6666,
                destination_ip: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"),
            },
            signature: Some(signature),
        };
        let expected = "{\"create\":{\"incoming_port\":5555,\"destination_port\":6666,\"\
            destination_ip\":\"127.0.0.1\",\"id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\"},\
            \"signature\":\"\
            5C912C4B3BFF2ADB49885DCBDB53D6D3041D0632E498CDFF\
            2114CD2DCAC936AB0901B47C411E5BB57FE77BEF96044940\
            81680ADAD0775CD144E2D2678537F621ED587E13EB430126\
            C7A757AEC99CE08A2D0F3A5C9FB45E9349F36408DFD7BA17\"}";

        assert_eq!(serde_json::to_string(&proxy_command).unwrap(), expected);
    }

    #[test]
    fn serialize_proxy_command_delete() {
        let key = SigningKey::from_slice(&[1; 48]).unwrap();
        let signature = key.sign(&[]); // Not a valid signature
        let proxy_command = ProxyCommand {
            command: Command::Delete {
                id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"),
            },
            signature: Some(signature),
        };
        let expected = "{\"delete\":{\"id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\"},\
            \"signature\":\"\
            5C912C4B3BFF2ADB49885DCBDB53D6D3041D0632E498CDFF\
            2114CD2DCAC936AB0901B47C411E5BB57FE77BEF96044940\
            81680ADAD0775CD144E2D2678537F621ED587E13EB430126\
            C7A757AEC99CE08A2D0F3A5C9FB45E9349F36408DFD7BA17\"}";

        assert_eq!(serde_json::to_string(&proxy_command).unwrap(), expected);
    }

    #[test]
    fn verify_signature() {
        let command = sample_create(Uuid::new_v4(), 7654);

        // Create signed message
        let signing_key = SigningKey::random(&mut OsRng);
        let signature = sign(&signing_key, &command);
        let bytes = signature.to_bytes();
        assert_eq!(bytes.len(), 96);
        let proxy_command = ProxyCommand {
            command,
            signature: Some(signature),
        };

        // Verify signed message
        let verifying_key = VerifyingKey::from(&signing_key);
        assert!(proxy_command.verify_signature(&Some(verifying_key)));
    }

    #[test]
    fn verify_signature_rejects_tampered_command() {
        let id = Uuid::new_v4();
        let signing_key = SigningKey::random(&mut OsRng);
        let signature = sign(&signing_key, &sample_create(id, 7654));

        // Same signature, but the command now targets a different port.
        let tampered = ProxyCommand {
            command: sample_create(id, 7655),
            signature: Some(signature),
        };

        let verifying_key = VerifyingKey::from(&signing_key);
        assert!(!tampered.verify_signature(&Some(verifying_key)));
    }

    #[test]
    fn verify_signature_rejects_wrong_key() {
        let command = sample_create(Uuid::new_v4(), 7654);
        let signing_key = SigningKey::random(&mut OsRng);
        let signature = sign(&signing_key, &command);
        let proxy_command = ProxyCommand {
            command,
            signature: Some(signature),
        };

        let other_key = SigningKey::random(&mut OsRng);
        let verifying_key = VerifyingKey::from(&other_key);
        assert!(!proxy_command.verify_signature(&Some(verifying_key)));
    }

    #[test]
    fn verify_signature_rejects_missing_signature() {
        let proxy_command = ProxyCommand {
            command: sample_create(Uuid::new_v4(), 7654),
            signature: None,
        };

        let signing_key = SigningKey::random(&mut OsRng);
        let verifying_key = VerifyingKey::from(&signing_key);
        assert!(!proxy_command.verify_signature(&Some(verifying_key)));
    }

    #[test]
    fn verify_signature_is_skipped_without_key() {
        let unsigned = ProxyCommand {
            command: sample_create(Uuid::new_v4(), 7654),
            signature: None,
        };
        assert!(unsigned.verify_signature(&None));
    }
}
|