diff --git a/Cargo.lock b/Cargo.lock index 131718f..15a7252 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -194,16 +194,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chela" -version = "1.0.0" +version = "1.1.0" dependencies = [ "axum", "color-eyre", "eyre", + "hyper", + "hyper-util", "info_utils", "serde", "sqids", "sqlx", "tokio", + "tower", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 741083f..7f8867f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,17 +1,20 @@ [package] name = "chela" -version = "1.0.0" +version = "1.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -axum = "0.7.5" +axum = { version = "0.7.5", features = ["tokio"] } color-eyre = "0.6.3" eyre = "0.6.12" +hyper = "1.2.0" +hyper-util = { version = "0.1.3", features = ["tokio"] } info_utils = "2.2.3" serde = "1.0.197" sqids = "0.4.1" sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] } tokio = { version = "1.37.0", features = ["full"] } +tower = "0.4.13" url = { version = "2.5.0", features = ["serde"] } diff --git a/README.md b/README.md index 2f24b88..60ba570 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,12 @@ A page that Chela will redirect to when `/` is requested instead of replying wit ##### `CHELA_BEHIND_PROXY` If this variable is set, Chela will use the `X-Real-IP` header as the client IP address rather than the connection address. +##### `CHELA_UNIX_SOCKET` +If you would like Chela to listen for HTTP requests over a Unix socket, set this variable to the socket path that it should use. By default, Chela will listen via a Tcp socket. + +##### `CHELA_ALPHABET` +If this variable is set, Chela will use the characters in `CHELA_ALPHABET` to create IDs for URLs. The default alphabet is `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ`. See [here](https://sqids.org/faq#unique) + ### Manually #### Build ```bash diff --git a/src/get.rs b/src/get.rs index 61c5dd2..efe2711 100644 --- a/src/get.rs +++ b/src/get.rs @@ -9,6 +9,7 @@ use axum::Extension; use info_utils::prelude::*; use crate::ServerState; +use crate::UdsConnectInfo; use crate::UrlRow; pub async fn index(Extension(state): Extension) -> impl IntoResponse { @@ -34,6 +35,16 @@ pub async fn index(Extension(state): Extension) -> impl IntoRespons .into_response() } +pub async fn id_unix( + headers: HeaderMap, + ConnectInfo(addr): ConnectInfo, + Extension(state): Extension, + Path(id): Path, +) -> impl IntoResponse { + let ip = format!("{:?}", addr.peer_addr); + run_id(headers, ip, state, id).await +} + /// # Panics /// Will panic if `parse()` fails pub async fn id( @@ -42,8 +53,17 @@ pub async fn id( Extension(state): Extension, Path(id): Path, ) -> impl IntoResponse { - let mut show_request = false; let ip = get_ip(&headers, addr, &state).unwrap_or_default(); + run_id(headers, ip, state, id).await +} + +async fn run_id( + headers: HeaderMap, + ip: String, + state: ServerState, + id: String, +) -> impl IntoResponse { + let mut show_request = false; log!("Request for '{}' from {}", id.clone(), ip); let mut use_id = id; if use_id.ends_with('+') { @@ -66,7 +86,7 @@ pub async fn id( .into_response(); } log!("Redirecting {} -> {}", it.id, it.url); - save_analytics(headers, it.clone(), addr, state).await; + save_analytics(headers, it.clone(), ip, state).await; let mut response_headers = HeaderMap::new(); response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap()); response_headers.insert("Location", it.url.parse().unwrap()); @@ -92,9 +112,8 @@ pub async fn id( (StatusCode::NOT_FOUND, Html("
Not found.
")).into_response() } -async fn save_analytics(headers: HeaderMap, item: UrlRow, addr: SocketAddr, state: ServerState) { +async fn save_analytics(headers: HeaderMap, item: UrlRow, ip: String, state: ServerState) { let id = item.id; - let ip = get_ip(&headers, addr, &state); let referer = match headers.get("referer") { Some(it) => { if let Ok(i) = it.to_str() { @@ -130,7 +149,7 @@ VALUES ($1,$2,$3,$4) .await; if res.is_ok() { - log!("Saved analytics for '{id}' from {}", ip.unwrap_or_default()); + log!("Saved analytics for '{id}' from {}", ip); } } diff --git a/src/main.rs b/src/main.rs index e307b8e..1dde8a7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,23 @@ -use std::net::SocketAddr; - -use url::Url; - +use axum::extract::connect_info; +use axum::http::Request; use axum::routing::{get, post}; use axum::Router; use sqlx::postgres::PgPoolOptions; use sqlx::{Pool, Postgres}; -use sqids::Sqids; - -use serde::Deserialize; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server; use info_utils::prelude::*; +use serde::Deserialize; +use sqids::Sqids; +use tower::Service; +use url::Url; + +use std::env; +use std::sync::Arc; pub mod get; pub mod post; @@ -39,22 +44,39 @@ pub struct CreateForm { pub url: url::Url, } +#[derive(Clone)] +#[allow(dead_code)] +pub struct UdsConnectInfo { + pub peer_addr: Arc, + pub peer_cred: tokio::net::unix::UCred, +} + +impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo { + fn connect_info(target: &tokio::net::UnixStream) -> Self { + let peer_addr = target.peer_addr().unwrap(); + let peer_cred = target.peer_cred().unwrap(); + + Self { + peer_addr: Arc::new(peer_addr), + peer_cred, + } + } +} + #[tokio::main] async fn main() -> eyre::Result<()> { color_eyre::install()?; let db_pool = init_db().await?; - let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string()); + let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string()); + let alphabet = env::var("CHELA_ALPHABET") + .unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string()); let sqids = Sqids::builder() - .alphabet( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - .chars() - .collect(), - ) + .alphabet(alphabet.chars().collect()) .blocklist(["create".to_string()].into()) .build()?; - let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default(); - let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok(); + let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default(); + let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok(); let server_state = ServerState { db_pool, host, @@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> { behind_proxy, }; - let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); - let port = 3000; + serve(server_state).await?; + Ok(()) +} + +async fn serve(state: ServerState) -> eyre::Result<()> { + let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default(); + if unix_socket.is_empty() { + let router = Router::new() + .route("/", get(get::index)) + .route("/:id", get(get::id)) + .route("/create", get(get::create_id)) + .route("/", post(post::create_link)) + .layer(axum::Extension(state)); + let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); + let port = 3000; + let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?; + log!("Listening at {}:{}", address, port); + axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ) + .await?; + } else { + let router = Router::new() + .route("/", get(get::index)) + .route("/:id", get(get::id_unix)) + .route("/create", get(get::create_id)) + .route("/", post(post::create_link)) + .layer(axum::Extension(state)); + let unix_socket_path = std::path::Path::new(&unix_socket); + if unix_socket_path.exists() { + tokio::fs::remove_file(unix_socket_path).await?; + } + let listener = tokio::net::UnixListener::bind(unix_socket_path)?; + log!("Listening via Unix socket at {}", unix_socket); + tokio::spawn(async move { + let mut service = router.into_make_service_with_connect_info::(); + loop { + let (socket, _remote_addr) = listener.accept().await.unwrap(); + let tower_service = match service.call(&socket).await { + Ok(value) => value, + Err(err) => match err {}, + }; + + tokio::spawn(async move { + let socket = TokioIo::new(socket); + let hyper_service = + hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(socket, hyper_service) + .await + { + warn!("Failed to serve connection: {}", err); + } + }); + } + }) + .await?; + } - let router = init_routes(server_state); - let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?; - log!("Listening at {}:{}", address, port); - axum::serve( - listener, - router.into_make_service_with_connect_info::(), - ) - .await?; Ok(()) } @@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result> { let db_pool = PgPoolOptions::new() .max_connections(15) .connect( - std::env::var("DATABASE_URL") + env::var("DATABASE_URL") .expect("DATABASE_URL must be set") .as_str(), ) @@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking ( fn init_routes(state: ServerState) -> Router { Router::new() .route("/", get(get::index)) - .route("/:id", get(get::id)) .route("/create", get(get::create_id)) .route("/", post(post::create_link)) .layer(axum::Extension(state))