From 4c6b2c204283bd086c422e3fb66cd7ee8894df9f Mon Sep 17 00:00:00 2001 From: Shav Kinderlehrer Date: Sat, 6 Apr 2024 09:19:28 -0400 Subject: [PATCH] Implement post --- Cargo.lock | 118 +++++++++++++++++++++++++++++++++++----- Cargo.toml | 5 +- src/get.rs | 91 +++++++++++++++++++++++-------- src/main.rs | 55 +++++++++++++------ src/post.rs | 154 +++++++++++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 368 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ce79ab..1beafb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -192,6 +192,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chela" +version = "0.1.0" +dependencies = [ + "axum", + "color-eyre", + "eyre", + "info_utils", + "serde", + "sqids", + "sqlx", + "tokio", + "url", +] + [[package]] name = "color-eyre" version = "0.6.3" @@ -274,6 +289,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.58", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.58", +] + [[package]] name = "der" version = "0.7.9" @@ -285,6 +335,37 @@ dependencies = [ "zeroize", ] +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.58", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core", + "syn 2.0.58", +] + [[package]] name = "digest" version = "0.10.7" @@ -640,6 +721,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1314,6 +1401,18 @@ dependencies = [ "der", ] +[[package]] +name = "sqids" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f328f10ae594f0da04e5b2f82c089232697312661bca22d5d015a680c84639d" +dependencies = [ + "derive_builder", + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "sqlformat" version = "0.2.3" @@ -1533,6 +1632,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -1807,20 +1912,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", -] - -[[package]] -name = "url_shortener" -version = "0.1.0" -dependencies = [ - "axum", - "color-eyre", - "eyre", - "info_utils", "serde", - "sqlx", - "tokio", - "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index f7d1a5e..3358738 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "url_shortener" +name = "chela" version = "0.1.0" edition = "2021" @@ -11,6 +11,7 @@ color-eyre = "0.6.3" eyre = "0.6.12" info_utils = "2.2.3" serde = "1.0.197" +sqids = "0.4.1" sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "macros", "migrate", "tls-rustls"] } tokio = { version = "1.37.0", features = ["full"] } -url = "2.5.0" +url = { version = "2.5.0", features = ["serde"] } diff --git a/src/get.rs b/src/get.rs index eb2eb05..d79705e 100644 --- a/src/get.rs +++ b/src/get.rs @@ -3,7 +3,7 @@ use std::net::SocketAddr; use axum::extract::{ConnectInfo, Path}; use axum::http::HeaderMap; use axum::http::StatusCode; -use axum::response::{Html, IntoResponse, Redirect}; +use axum::response::{Html, IntoResponse}; use axum::Extension; use info_utils::prelude::*; @@ -11,11 +11,13 @@ use info_utils::prelude::*; use crate::ServerState; use crate::UrlRow; -pub async fn get_index() -> Html<&'static str> { +pub async fn index() -> Html<&'static str> { Html("hello, world!") } -pub async fn get_id( +/// # Panics +/// Will panic if `parse()` fails +pub async fn id( headers: HeaderMap, ConnectInfo(addr): ConnectInfo, Extension(state): Extension, @@ -29,34 +31,48 @@ pub async fn get_id( use_id.pop(); } - let item = sqlx::query_as!(UrlRow, "SELECT * FROM chela.urls WHERE id = $1", use_id) - .fetch_one(&state.db_pool) - .await; + let item: Result = + sqlx::query_as("SELECT * FROM chela.urls WHERE id = $1") + .bind(use_id) + .fetch_one(&state.db_pool) + .await; if let Ok(it) = item { if url::Url::parse(&it.url).is_ok() { if show_request { return Html(format!( - "
http://{}/{} -> {}
", + r#"
http://{}/{} -> {}
"#, state.host, it.id, it.url, it.url )) .into_response(); - } else { - log!("Redirecting {} -> {}", it.id, it.url); - save_analytics(headers, it.clone(), addr, state).await; - return Redirect::temporary(it.url.as_str()).into_response(); } + log!("Redirecting {} -> {}", it.id, it.url); + save_analytics(headers, it.clone(), addr, state).await; + let mut response_headers = HeaderMap::new(); + response_headers.insert("Cache-Control", "private, max-age=90".parse().unwrap()); + response_headers.insert("Location", it.url.parse().unwrap()); + return ( + StatusCode::MOVED_PERMANENTLY, + response_headers, + Html(format!( + r#"Redirecting to {}"#, + it.url, it.url + )), + ) + .into_response(); } + } else if let Err(err) = item { + warn!("{}", err); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Html(format!("
Internal error: {err}.
")), + ) + .into_response(); } - return (StatusCode::NOT_FOUND, Html("
404
")).into_response(); + (StatusCode::NOT_FOUND, Html("
Not found.
")).into_response() } -pub async fn save_analytics( - headers: HeaderMap, - item: UrlRow, - addr: SocketAddr, - state: ServerState, -) { +async fn save_analytics(headers: HeaderMap, item: UrlRow, addr: SocketAddr, state: ServerState) { let id = item.id; let ip = addr.ip().to_string(); let referer = match headers.get("referer") { @@ -80,16 +96,16 @@ pub async fn save_analytics( None => None, }; - let res = sqlx::query!( + let res = sqlx::query( " INSERT INTO chela.tracking (id,ip,referrer,user_agent) VALUES ($1,$2,$3,$4) ", - id, - ip, - referer, - user_agent ) + .bind(id.clone()) + .bind(ip.clone()) + .bind(referer) + .bind(user_agent) .execute(&state.db_pool) .await; @@ -97,3 +113,32 @@ VALUES ($1,$2,$3,$4) log!("Saved analytics for '{id}' from {ip}"); } } + +pub async fn create_id(Extension(state): Extension) -> Html { + Html(format!( + r#" + + + + {} URL Shortener + + +
+ + + + + +
+ + + "#, + state.host + )) +} diff --git a/src/main.rs b/src/main.rs index 57c6430..c775aa5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,10 @@ use axum::Router; use sqlx::postgres::PgPoolOptions; use sqlx::{Pool, Postgres}; +use sqids::Sqids; + +use serde::Deserialize; + use info_utils::prelude::*; pub mod get; @@ -15,15 +19,22 @@ pub mod post; pub struct ServerState { pub db_pool: Pool, pub host: String, + pub sqids: Sqids, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, sqlx::FromRow, PartialEq, Eq)] pub struct UrlRow { - pub index: i32, + pub index: i64, pub id: String, pub url: String, } +#[derive(Deserialize, Debug, Clone)] +pub struct CreateForm { + pub id: String, + pub url: url::Url, +} + #[tokio::main] async fn main() -> eyre::Result<()> { color_eyre::install()?; @@ -31,13 +42,26 @@ async fn main() -> eyre::Result<()> { let db_pool = init_db().await?; let host = std::env::var("CHELA_HOST").unwrap_or("localhost".to_string()); - let server_state = ServerState { db_pool, host }; + + let sqids = Sqids::builder() + .alphabet( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + .chars() + .collect(), + ) + .blocklist(["create".to_string()].into()) + .build()?; + let server_state = ServerState { + db_pool, + host, + sqids, + }; let address = std::env::var("LISTEN_ADDRESS").unwrap_or("0.0.0.0".to_string()); let port = std::env::var("LISTEN_PORT").unwrap_or("3000".to_string()); - let router = init_routes(server_state)?; - let listener = tokio::net::TcpListener::bind(format!("{}:{}", address, port)).await?; + let router = init_routes(server_state); + let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?; log!("Listening at {}:{}", address, port); axum::serve( listener, @@ -54,15 +78,15 @@ async fn init_db() -> eyre::Result> { .await?; log!("Successfully connected to database"); - sqlx::query!("CREATE SCHEMA IF NOT EXISTS chela") + sqlx::query("CREATE SCHEMA IF NOT EXISTS chela") .execute(&db_pool) .await?; log!("Created schema chela"); - sqlx::query!( + sqlx::query( " CREATE TABLE IF NOT EXISTS chela.urls ( - index SERIAL PRIMARY KEY, + index BIGSERIAL PRIMARY KEY, id TEXT NOT NULL UNIQUE, url TEXT NOT NULL ) @@ -72,7 +96,7 @@ CREATE TABLE IF NOT EXISTS chela.urls ( .await?; log!("Created table chela.urls"); - sqlx::query!( + sqlx::query( " CREATE TABLE IF NOT EXISTS chela.tracking ( timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, @@ -90,12 +114,11 @@ CREATE TABLE IF NOT EXISTS chela.tracking ( Ok(db_pool) } -fn init_routes(state: ServerState) -> eyre::Result { - let router = Router::new() - .route("/", get(get::get_index)) - .route("/:id", get(get::get_id)) +fn init_routes(state: ServerState) -> Router { + Router::new() + .route("/", get(get::index)) + .route("/:id", get(get::id)) + .route("/create", get(get::create_id)) .route("/", post(post::create_link)) - .layer(axum::Extension(state)); - - Ok(router) + .layer(axum::Extension(state)) } diff --git a/src/post.rs b/src/post.rs index d3d1257..b597f99 100644 --- a/src/post.rs +++ b/src/post.rs @@ -1,3 +1,155 @@ -pub async fn create_link() { +use axum::extract::Form; +use axum::http::StatusCode; +use axum::response::{Html, IntoResponse}; +use axum::Extension; +use info_utils::prelude::*; + +use crate::CreateForm; +use crate::ServerState; +use crate::UrlRow; + +#[derive(Debug, Clone, sqlx::FromRow, PartialEq, Eq)] +struct NextId { + id: String, + index: Option, + exists: bool, +} + +#[derive(Debug, Clone, sqlx::FromRow, PartialEq, Eq)] +struct NextIndex { + new_index: Option, +} + +pub async fn create_link( + Extension(state): Extension, + Form(form): Form, +) -> impl IntoResponse { + log!("Request to create '{}' -> {}", form.id, form.url.as_str()); + + let try_id = generate_id(form.clone(), state.clone()).await; + if let Ok(id) = try_id { + if id.exists { + log!("Serving cached id {} -> {}", id.id, form.url.as_str()); + return Html(format!( + r#"
http://{}/{} -> {}
"#, + state.host, + id.id, + form.url.as_str(), + form.url.as_str(), + )) + .into_response(); + } + let res; + if let Some(index) = id.index { + res = sqlx::query( + " +INSERT INTO chela.urls (index,id,url) +VALUES ($1,$2,$3) + ", + ) + .bind(index) + .bind(id.id.clone()) + .bind(form.url.as_str()) + .execute(&state.db_pool) + .await; + } else { + res = sqlx::query( + " +INSERT INTO chela.urls (id,url) +VALUES ($1,$2) + ", + ) + .bind(id.id.clone()) + .bind(form.url.as_str()) + .execute(&state.db_pool) + .await; + } + + match res { + Ok(_) => { + log!("Created new id {} -> {}", id.id, form.url.as_str()); + return ( + StatusCode::OK, + Html(format!( + r#"
http://{}/{} -> {}
"#, + state.host, + id.id, + form.url.as_str(), + form.url.as_str(), + )), + ) + .into_response(); + } + Err(err) => { + warn!("{}", err); + return (StatusCode::INTERNAL_SERVER_ERROR, Html("Internal error.")) + .into_response(); + } + } + } else if let Err(err) = try_id { + warn!("{}", err); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Html(format!("Internal error: {err}")), + ) + .into_response(); + } + + (StatusCode::INTERNAL_SERVER_ERROR, Html("Internal error.")).into_response() +} + +async fn generate_id(form: CreateForm, state: ServerState) -> eyre::Result { + if form.id.is_empty() { + let existing_row: Result = + sqlx::query_as("SELECT * FROM chela.urls WHERE url = $1") + .bind(form.url.as_str()) + .fetch_one(&state.db_pool) + .await; + if let Ok(row) = existing_row { + return Ok(NextId { + id: row.id, + index: None, + exists: true, + }); + } + + let next_index: NextIndex = sqlx::query_as( + "SELECT nextval(pg_get_serial_sequence('chela.urls', 'index')) as new_index", + ) + .fetch_one(&state.db_pool) + .await?; + + if let Some(index) = next_index.new_index { + let new_id = state.sqids.encode(&[index.try_into()?])?; + return Ok(NextId { + id: new_id, + index: Some(index), + exists: false, + }); + } + } else { + let existing_row: Result = + sqlx::query_as("SELECT * FROM chela.urls WHERE id = $1") + .bind(form.id.clone()) + .fetch_one(&state.db_pool) + .await; + if let Ok(row) = existing_row { + if row.url == form.url.as_str() { + return Ok(NextId { + id: row.id, + index: None, + exists: true, + }); + } + return Err(eyre::eyre!("id '{}' is already taken", row.id)); + } + return Ok(NextId { + id: form.id, + index: None, + exists: false, + }); + } + + Err(eyre::eyre!("Internal error")) }