Add support for Unix sockets and custom alphabet
This commit is contained in:
parent
f080854b84
commit
3b7d5454cc
5
Cargo.lock
generated
5
Cargo.lock
generated
@ -194,16 +194,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chela"
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"color-eyre",
|
||||
"eyre",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"info_utils",
|
||||
"serde",
|
||||
"sqids",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tower",
|
||||
"url",
|
||||
]
|
||||
|
||||
|
@ -1,17 +1,20 @@
|
||||
[package]
|
||||
name = "chela"
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = "0.7.5"
|
||||
axum = { version = "0.7.5", features = ["tokio"] }
|
||||
color-eyre = "0.6.3"
|
||||
eyre = "0.6.12"
|
||||
hyper = "1.2.0"
|
||||
hyper-util = { version = "0.1.3", features = ["tokio"] }
|
||||
info_utils = "2.2.3"
|
||||
serde = "1.0.197"
|
||||
sqids = "0.4.1"
|
||||
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] }
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
tower = "0.4.13"
|
||||
url = { version = "2.5.0", features = ["serde"] }
|
||||
|
@ -63,6 +63,12 @@ A page that Chela will redirect to when `/` is requested instead of replying wit
|
||||
##### `CHELA_BEHIND_PROXY`
|
||||
If this variable is set, Chela will use the `X-Real-IP` header as the client IP address rather than the connection address.
|
||||
|
||||
##### `CHELA_UNIX_SOCKET`
|
||||
If you would like Chela to listen for HTTP requests over a Unix socket, set this variable to the socket path that it should use. By default, Chela will listen via a Tcp socket.
|
||||
|
||||
##### `CHELA_ALPHABET`
|
||||
If this variable is set, Chela will use the characters in `CHELA_ALPHABET` to create IDs for URLs. The default alphabet is `abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ`. See [here](https://sqids.org/faq#unique)
|
||||
|
||||
### Manually
|
||||
#### Build
|
||||
```bash
|
||||
|
29
src/get.rs
29
src/get.rs
@ -9,6 +9,7 @@ use axum::Extension;
|
||||
use info_utils::prelude::*;
|
||||
|
||||
use crate::ServerState;
|
||||
use crate::UdsConnectInfo;
|
||||
use crate::UrlRow;
|
||||
|
||||
pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoResponse {
|
||||
@ -34,6 +35,16 @@ pub async fn index(Extension(state): Extension<ServerState>) -> impl IntoRespons
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn id_unix(
|
||||
headers: HeaderMap,
|
||||
ConnectInfo(addr): ConnectInfo<UdsConnectInfo>,
|
||||
Extension(state): Extension<ServerState>,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let ip = format!("{:?}", addr.peer_addr);
|
||||
run_id(headers, ip, state, id).await
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Will panic if `parse()` fails
|
||||
pub async fn id(
|
||||
@ -42,8 +53,17 @@ pub async fn id(
|
||||
Extension(state): Extension<ServerState>,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
let mut show_request = false;
|
||||
let ip = get_ip(&headers, addr, &state).unwrap_or_default();
|
||||
run_id(headers, ip, state, id).await
|
||||
}
|
||||
|
||||
async fn run_id(
|
||||
headers: HeaderMap,
|
||||
ip: String,
|
||||
state: ServerState,
|
||||
id: String,
|
||||
) -> impl IntoResponse {
|
||||
let mut show_request = false;
|
||||
log!("Request for '{}' from {}", id.clone(), ip);
|
||||
let mut use_id = id;
|
||||
if use_id.ends_with('+') {
|
||||
@ -66,7 +86,7 @@ pub async fn id(
|
||||
.into_response();
|
||||
}
|
||||
log!("Redirecting {} -> {}", it.id, it.url);
|
||||
save_analytics(headers, it.clone(), addr, state).await;
|
||||
save_analytics(headers, it.clone(), ip, state).await;
|
||||
let mut response_headers = HeaderMap::new();
|
||||
response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap());
|
||||
response_headers.insert("Location", it.url.parse().unwrap());
|
||||
@ -92,9 +112,8 @@ pub async fn id(
|
||||
(StatusCode::NOT_FOUND, Html("<pre>Not found.</pre>")).into_response()
|
||||
}
|
||||
|
||||
async fn save_analytics(headers: HeaderMap, item: UrlRow, addr: SocketAddr, state: ServerState) {
|
||||
async fn save_analytics(headers: HeaderMap, item: UrlRow, ip: String, state: ServerState) {
|
||||
let id = item.id;
|
||||
let ip = get_ip(&headers, addr, &state);
|
||||
let referer = match headers.get("referer") {
|
||||
Some(it) => {
|
||||
if let Ok(i) = it.to_str() {
|
||||
@ -130,7 +149,7 @@ VALUES ($1,$2,$3,$4)
|
||||
.await;
|
||||
|
||||
if res.is_ok() {
|
||||
log!("Saved analytics for '{id}' from {}", ip.unwrap_or_default());
|
||||
log!("Saved analytics for '{id}' from {}", ip);
|
||||
}
|
||||
}
|
||||
|
||||
|
115
src/main.rs
115
src/main.rs
@ -1,18 +1,23 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use url::Url;
|
||||
|
||||
use axum::extract::connect_info;
|
||||
use axum::http::Request;
|
||||
use axum::routing::{get, post};
|
||||
use axum::Router;
|
||||
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::{Pool, Postgres};
|
||||
|
||||
use sqids::Sqids;
|
||||
|
||||
use serde::Deserialize;
|
||||
use hyper::body::Incoming;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use hyper_util::server;
|
||||
|
||||
use info_utils::prelude::*;
|
||||
use serde::Deserialize;
|
||||
use sqids::Sqids;
|
||||
use tower::Service;
|
||||
use url::Url;
|
||||
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod get;
|
||||
pub mod post;
|
||||
@ -39,22 +44,39 @@ pub struct CreateForm {
|
||||
pub url: url::Url,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct UdsConnectInfo {
|
||||
pub peer_addr: Arc<tokio::net::unix::SocketAddr>,
|
||||
pub peer_cred: tokio::net::unix::UCred,
|
||||
}
|
||||
|
||||
impl connect_info::Connected<&tokio::net::UnixStream> for UdsConnectInfo {
|
||||
fn connect_info(target: &tokio::net::UnixStream) -> Self {
|
||||
let peer_addr = target.peer_addr().unwrap();
|
||||
let peer_cred = target.peer_cred().unwrap();
|
||||
|
||||
Self {
|
||||
peer_addr: Arc::new(peer_addr),
|
||||
peer_cred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
let db_pool = init_db().await?;
|
||||
let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string());
|
||||
let host = env::var("CHELA_HOST").unwrap_or("localhost".to_string());
|
||||
let alphabet = env::var("CHELA_ALPHABET")
|
||||
.unwrap_or("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".to_string());
|
||||
let sqids = Sqids::builder()
|
||||
.alphabet(
|
||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
.chars()
|
||||
.collect(),
|
||||
)
|
||||
.alphabet(alphabet.chars().collect())
|
||||
.blocklist(["create".to_string()].into())
|
||||
.build()?;
|
||||
let main_page_redirect = std::env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
|
||||
let behind_proxy = std::env::var("CHELA_BEHIND_PROXY").is_ok();
|
||||
let main_page_redirect = env::var("CHELA_MAIN_PAGE_REDIRECT").unwrap_or_default();
|
||||
let behind_proxy = env::var("CHELA_BEHIND_PROXY").is_ok();
|
||||
let server_state = ServerState {
|
||||
db_pool,
|
||||
host,
|
||||
@ -63,17 +85,69 @@ async fn main() -> eyre::Result<()> {
|
||||
behind_proxy,
|
||||
};
|
||||
|
||||
let address = std::env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
|
||||
let port = 3000;
|
||||
serve(server_state).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let router = init_routes(server_state);
|
||||
async fn serve(state: ServerState) -> eyre::Result<()> {
|
||||
let unix_socket = env::var("CHELA_UNIX_SOCKET").unwrap_or_default();
|
||||
if unix_socket.is_empty() {
|
||||
let router = Router::new()
|
||||
.route("/", get(get::index))
|
||||
.route("/:id", get(get::id))
|
||||
.route("/create", get(get::create_id))
|
||||
.route("/", post(post::create_link))
|
||||
.layer(axum::Extension(state));
|
||||
let address = env::var("CHELA_LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string());
|
||||
let port = 3000;
|
||||
let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
|
||||
log!("Listening at {}:{}", address, port);
|
||||
axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
let router = Router::new()
|
||||
.route("/", get(get::index))
|
||||
.route("/:id", get(get::id_unix))
|
||||
.route("/create", get(get::create_id))
|
||||
.route("/", post(post::create_link))
|
||||
.layer(axum::Extension(state));
|
||||
let unix_socket_path = std::path::Path::new(&unix_socket);
|
||||
if unix_socket_path.exists() {
|
||||
tokio::fs::remove_file(unix_socket_path).await?;
|
||||
}
|
||||
let listener = tokio::net::UnixListener::bind(unix_socket_path)?;
|
||||
log!("Listening via Unix socket at {}", unix_socket);
|
||||
tokio::spawn(async move {
|
||||
let mut service = router.into_make_service_with_connect_info::<UdsConnectInfo>();
|
||||
loop {
|
||||
let (socket, _remote_addr) = listener.accept().await.unwrap();
|
||||
let tower_service = match service.call(&socket).await {
|
||||
Ok(value) => value,
|
||||
Err(err) => match err {},
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let socket = TokioIo::new(socket);
|
||||
let hyper_service =
|
||||
hyper::service::service_fn(move |request: Request<Incoming>| {
|
||||
tower_service.clone().call(request)
|
||||
});
|
||||
|
||||
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
|
||||
.serve_connection_with_upgrades(socket, hyper_service)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to serve connection: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -81,7 +155,7 @@ async fn init_db() -> eyre::Result<Pool<Postgres>> {
|
||||
let db_pool = PgPoolOptions::new()
|
||||
.max_connections(15)
|
||||
.connect(
|
||||
std::env::var("DATABASE_URL")
|
||||
env::var("DATABASE_URL")
|
||||
.expect("DATABASE_URL must be set")
|
||||
.as_str(),
|
||||
)
|
||||
@ -128,7 +202,6 @@ CREATE TABLE IF NOT EXISTS chela.tracking (
|
||||
fn init_routes(state: ServerState) -> Router {
|
||||
Router::new()
|
||||
.route("/", get(get::index))
|
||||
.route("/:id", get(get::id))
|
||||
.route("/create", get(get::create_id))
|
||||
.route("/", post(post::create_link))
|
||||
.layer(axum::Extension(state))
|
||||
|
Loading…
Reference in New Issue
Block a user