Add support for Unix sockets and custom alphabet

This commit is contained in:
Shav Kinderlehrer 2024-04-07 20:06:27 -04:00
parent f080854b84
commit 3b7d5454cc
5 changed files with 139 additions and 35 deletions

5
Cargo.lock generated
View File

@ -194,16 +194,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "chela" name = "chela"
version = "1.0.0" version = "1.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"color-eyre", "color-eyre",
"eyre", "eyre",
"hyper",
"hyper-util",
"info_utils", "info_utils",
"serde", "serde",
"sqids", "sqids",
"sqlx", "sqlx",
"tokio", "tokio",
"tower",
"url", "url",
] ]

View File

@ -1,17 +1,20 @@
[package] [package]
name = "chela" name = "chela"
version = "1.0.0" version = "1.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
axum = "0.7.5" axum = { version = "0.7.5", features = ["tokio"] }
color-eyre = "0.6.3" color-eyre = "0.6.3"
eyre = "0.6.12" eyre = "0.6.12"
hyper = "1.2.0"
hyper-util = { version = "0.1.3", features = ["tokio"] }
info_utils = "2.2.3" info_utils = "2.2.3"
serde = "1.0.197" serde = "1.0.197"
sqids = "0.4.1" sqids = "0.4.1"
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] }
tokio = { version = "1.37.0", features = ["full"] } tokio = { version = "1.37.0", features = ["full"] }
tower = "0.4.13"
url = { version = "2.5.0", features = ["serde"] } url = { version = "2.5.0", features = ["serde"] }

View File

@ -63,6 +63,12 @@ A page that Chela will redirect to when `/` is requested instead of replying wit
##### `CHELA_BEHIND_PROXY` ##### `CHELA_BEHIND_PROXY`
If this variable is set, Chela will use the `X-Real-IP` header as the client IP address rather than the connection address. If this variable is set, Chela will use the `X-Real-IP` header as the client IP address rather than the connection address.
##### `CHELA_UNIX_SOCKET`
If you would like Chela to listen for HTTP requests over a Unix socket, set this variable to the socket path that it should use. By default, Chela will listen via a Tcp socket.
##### `CHELA_ALPHABET`
If this variable is set, Chela will use the characters in `CHELA_ALPHABET` to create IDs for URLs. The default alphabet is `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ`. See [here](https://sqids.org/faq#unique)
### Manually ### Manually
#### Build #### Build
```bash ```bash

View File

@ -9,6 +9,7 @@ use axum::Extension;
use info_utils::prelude::*; use info_utils::prelude::*;
use crate::ServerState; use crate::ServerState;
use crate::UdsConnectInfo;
use crate::UrlRow; use crate::UrlRow;
pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoResponse { pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoResponse {
@ -34,6 +35,16 @@ pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoRespons
.into_response() .into_response()
} }
pub async fn id_unix(
headers: HeaderMap,
ConnectInfo(addr): ConnectInfo<UdsConnectInfo>,
Extension(state): Extension<ServerState>,
Path(id): Path<String>,
) -> impl IntoResponse {
let ip = format!("{:?}", addr.peer_addr);
run_id(headers, ip, state, id).await
}
/// # Panics /// # Panics
/// Will panic if `parse()` fails /// Will panic if `parse()` fails
pub async fn id( pub async fn id(
@ -42,8 +53,17 @@ pub async fn id(
Extension(state): Extension<ServerState>, Extension(state): Extension<ServerState>,
Path(id): Path<String>, Path(id): Path<String>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let mut show_request = false;
let ip = get_ip(&headers, addr, &state).unwrap_or_default(); let ip = get_ip(&headers, addr, &state).unwrap_or_default();
run_id(headers, ip, state, id).await
}
async fn run_id(
headers: HeaderMap,
ip: String,
state: ServerState,
id: String,
) -> impl IntoResponse {
let mut show_request = false;
log!("Request for '{}' from {}", id.clone(), ip); log!("Request for '{}' from {}", id.clone(), ip);
let mut use_id = id; let mut use_id = id;
if use_id.ends_with('+') { if use_id.ends_with('+') {
@ -66,7 +86,7 @@ pub async fn id(
.into_response(); .into_response();
} }
log!("Redirecting {} -> {}", it.id, it.url); log!("Redirecting {} -> {}", it.id, it.url);
save_analytics(headers, it.clone(), addr, state).await; save_analytics(headers, it.clone(), ip, state).await;
let mut response_headers = HeaderMap::new(); let mut response_headers = HeaderMap::new();
response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap()); response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap());
response_headers.insert("Location", it.url.parse().unwrap()); response_headers.insert("Location", it.url.parse().unwrap());
@ -92,9 +112,8 @@ pub async fn id(
(StatusCode::NOT_FOUND, Html("<pre>Not found.</pre>")).into_response() (StatusCode::NOT_FOUND, Html("<pre>Not found.</pre>")).into_response()
} }
async fn save_analytics(headers: HeaderMap, item: UrlRow, addr: SocketAddr, state: ServerState) { async fn save_analytics(headers: HeaderMap, item: UrlRow, ip: String, state: ServerState) {
let id = item.id; let id = item.id;
let ip = get_ip(&headers, addr, &state);
let referer = match headers.get("referer") { let referer = match headers.get("referer") {
Some(it) => { Some(it) => {
if let Ok(i) = it.to_str() { if let Ok(i) = it.to_str() {
@ -130,7 +149,7 @@ VALUES ($1,$2,$3,$4)
.await; .await;
if res.is_ok() { if res.is_ok() {
log!("Saved analytics for '{id}' from {}", ip.unwrap_or_default()); log!("Saved analytics for '{id}' from {}", ip);
} }
} }

View File

@ -1,18 +1,23 @@
use std::net::SocketAddr; use axum::extract::connect_info;
use axum::http::Request;
use url::Url;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::Router; use axum::Router;
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use sqlx::{Pool, Postgres}; use sqlx::{Pool, Postgres};
use sqids::Sqids; use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use serde::Deserialize; use hyper_util::server;
use info_utils::prelude::*; use info_utils::prelude::*;
use serde::Deserialize;
use sqids::Sqids;
use tower::Service;
use url::Url;
use std::env;
use std::sync::Arc;
pub mod get; pub mod get;
pub mod post; pub mod post;
@ -39,22 +44,39 @@ pub struct CreateForm {
pub url: url::Url, pub url: url::Url,
} }
#[derive(Clone)]
#[allow(dead_code)]
pub struct UdsConnectInfo {
pub peer_addr: Arc<tokio::net::unix::SocketAddr>,
pub peer_cred: tokio::net::unix::UCred,
}
impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo {
fn connect_info(target: &tokio::net::UnixStream) -> Self {
let peer_addr = target.peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap();
Self {
peer_addr: Arc::new(peer_addr),
peer_cred,
}
}
}
#[tokio::main] #[tokio::main]
async fn main() -> eyre::Result<()> { async fn main() -> eyre::Result<()> {
color_eyre::install()?; color_eyre::install()?;
let db_pool = init_db().await?; let db_pool = init_db().await?;
let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string()); let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string());
let alphabet = env::var("CHELA_ALPHABET")
.unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string());
let sqids = Sqids::builder() let sqids = Sqids::builder()
.alphabet( .alphabet(alphabet.chars().collect())
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
.chars()
.collect(),
)
.blocklist(["create".to_string()].into()) .blocklist(["create".to_string()].into())
.build()?; .build()?;
let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default(); let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok(); let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok();
let server_state = ServerState { let server_state = ServerState {
db_pool, db_pool,
host, host,
@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> {
behind_proxy, behind_proxy,
}; };
let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); serve(server_state).await?;
let port = 3000; Ok(())
}
async fn serve(state: ServerState) -> eyre::Result<()> {
let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default();
if unix_socket.is_empty() {
let router = Router::new()
.route("/", get(get::index))
.route("/:id", get(get::id))
.route("/create", get(get::create_id))
.route("/", post(post::create_link))
.layer(axum::Extension(state));
let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
let port = 3000;
let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
log!("Listening at {}:{}", address, port);
axum::serve(
listener,
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.await?;
} else {
let router = Router::new()
.route("/", get(get::index))
.route("/:id", get(get::id_unix))
.route("/create", get(get::create_id))
.route("/", post(post::create_link))
.layer(axum::Extension(state));
let unix_socket_path = std::path::Path::new(&unix_socket);
if unix_socket_path.exists() {
tokio::fs::remove_file(unix_socket_path).await?;
}
let listener = tokio::net::UnixListener::bind(unix_socket_path)?;
log!("Listening via Unix socket at {}", unix_socket);
tokio::spawn(async move {
let mut service = router.into_make_service_with_connect_info::<UdsConnectInfo>();
loop {
let (socket, _remote_addr) = listener.accept().await.unwrap();
let tower_service = match service.call(&socket).await {
Ok(value) => value,
Err(err) => match err {},
};
tokio::spawn(async move {
let socket = TokioIo::new(socket);
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(socket, hyper_service)
.await
{
warn!("Failed to serve connection: {}", err);
}
});
}
})
.await?;
}
let router = init_routes(server_state);
let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
log!("Listening at {}:{}", address, port);
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(()) Ok(())
} }
@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result<Pool<Postgres>> {
let db_pool = PgPoolOptions::new() let db_pool = PgPoolOptions::new()
.max_connections(15) .max_connections(15)
.connect( .connect(
std::env::var("DATABASE_URL") env::var("DATABASE_URL")
.expect("DATABASE_URL must be set") .expect("DATABASE_URL must be set")
.as_str(), .as_str(),
) )
@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking (
fn init_routes(state: ServerState) -> Router { fn init_routes(state: ServerState) -> Router {
Router::new() Router::new()
.route("/", get(get::index)) .route("/", get(get::index))
.route("/:id", get(get::id))
.route("/create", get(get::create_id)) .route("/create", get(get::create_id))
.route("/", post(post::create_link)) .route("/", post(post::create_link))
.layer(axum::Extension(state)) .layer(axum::Extension(state))