O link shortener da cumperativa.xyz
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

1289 lines
48 KiB

use argh::FromArgs;
use validators::traits::ValidateString;
use warp::{http::Response, hyper::StatusCode, Filter};
#[macro_use]
/// Module containing custom-defined macros.
///
/// Thanks to `#[macro_use]`, the macros defined inside stay in scope for the
/// rest of this file after the module.
mod macros {
    /// Macros useful in debug contexts.
    ///
    /// For example, `ifdbg!(expr)` replaces `expr` with `()` when compiling
    /// without `debug_assertions` (as in the default release profile).
    #[macro_use]
    pub mod debug {
        #[cfg(debug_assertions)]
        /// `debuginfo!("debug info", "release info")` is functionally
        /// equivalent to
        ///
        /// ```ignore
        /// if DEBUG {
        ///     "debug info"
        /// } else {
        ///     "release info"
        /// }
        /// ```
        ///
        /// except that the choice is made at compile time through
        /// `cfg(debug_assertions)`. An overloaded `debuginfo!(str)` is
        /// defined to mean `debuginfo!(str, "Internal error.")`.
        ///
        /// Both arguments must be literals.
        macro_rules! debuginfo {
            ($log:literal) => {
                $log
            };
            ($log:literal,$alt:literal) => {
                $log
            };
        }
        // Release variant of `debuginfo!`: yields the release text, or
        // "Internal error." for the one-argument form.
        #[cfg(not(debug_assertions))]
        macro_rules! debuginfo {
            ($log:literal) => {
                "Internal error."
            };
            ($log:literal,$alt:literal) => {
                $alt
            };
        }
        #[cfg(debug_assertions)]
        /// `ifdbg!($expr)` is functionally equivalent to
        ///
        /// ```ignore
        /// if DEBUG {
        ///     $expr
        /// } else {
        ///     ()
        /// }
        /// ```
        ///
        /// It can be particularly useful in combination with `eprintln!()`,
        /// i.e.,
        ///
        /// ```ignore
        /// ifdbg!(eprintln!("Debug error information."))
        /// ```
        macro_rules! ifdbg {
            ($expr:expr) => {
                $expr;
            };
        }
        // Release variant of `ifdbg!`: the expression is discarded without
        // being evaluated or type-checked. NOTE(review): variables used only
        // inside `ifdbg!` therefore trigger unused-variable warnings in
        // release builds.
        #[cfg(not(debug_assertions))]
        macro_rules! ifdbg {
            ($expr:expr) => {
                ()
            };
        }
    }
}
/// Affine to static configuration.
mod conf {
use serde::{Deserialize, Serialize};
use std::{net::IpAddr, path::PathBuf, str::FromStr};
use validators::prelude::*;
use warp::{filters::BoxedFilter, Filter};
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Configuration settings specific to the (Redis) database.
/// See the `Default` implementation for sensible values.
pub struct DbConfig {
    /// The URL of the Redis database, e.g. `redis://127.0.0.1:6379`.
    pub address: String,
    /// The expiration time of entries, in seconds.
    pub expire_seconds: usize,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Rules for constructing (shortened URL) slugs.
pub struct SlugRules {
    /// (Exact) length of generated slugs. Parsing only rejects slugs that
    /// are longer than this.
    pub length: usize,
    /// Valid characters to include in the slug; every character of this
    /// string is a permitted slug character.
    pub chars: String,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Configuration settings for what happens for `GET /`, specifically
/// whether and which file or directory to serve.
///
/// With the default (externally tagged) serde representation this appears
/// in JSON as e.g. `{"Dir": "/etc/lonk/served"}`.
pub enum ServeDirRules {
    /// Serve the specified file.
    File(PathBuf),
    /// Serve the specified directory.
    /// NOTE(review): this used to say the directory is "enumerated if
    /// `index.html` is not present"; `warp::fs::dir` (see `to_filter`) is
    /// not known to produce directory listings — confirm.
    Dir(PathBuf),
}
#[derive(Serialize, Deserialize, Debug, Validator, Clone)]
#[validator(ip(local(Allow), port(Must)))]
/// Struct specifying where the HTTP server should be served.
///
/// This struct is meant to be parsed from a larger configuration struct.
/// NOTE(review): the `validator` attribute presumably allows local
/// addresses and requires a port; confirm against the `validators` docs.
pub struct ServeAddr {
    /// Serve the HTTP server at this IP.
    pub ip: IpAddr,
    /// Serve the HTTP server at this port.
    pub port: u16,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Configuration for the service of the HTTP server.
///
/// See the definitions of [`ServeDirRules`] and [`ServeAddr`] for more
/// information on the specific configuration.
pub struct ServeRules {
    /// Configuration for the contents served.
    pub dir: ServeDirRules,
    /// Configuration for the serve location.
    pub addr: ServeAddr,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Configuration of logging by lonk.
pub struct LogRules {
    /// Where to write error logs to. The file will be appended to.
    pub error_log_file: PathBuf,
    /// Where to write access logs to. The file will be appended to.
    pub access_log_file: PathBuf,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
/// Configuration struct. This struct is a typed representation of the
/// configuration file, with each of the domain-specific configurations
/// defined as their own type (in reflection of a JSON structure, for
/// example). See the definition of each of the member structs for more
/// information.
///
/// Field order here is also the order used by `--print-default-config`.
pub struct Config {
    /// The "version" of the configuration, corresponding to the MAJOR in
    /// semantic versioning. Should be increased every time the
    /// configuration structure suffers breaking changes.
    /// This value is optional because sufficiently old configuration files
    /// may not have a version field; a missing field is treated as 0
    /// during validation, and any value other than `config_version()` is
    /// rejected.
    pub version: Option<usize>,
    /// Configuration regarding the Redis database.
    pub db: DbConfig,
    /// Configuration regarding logging.
    pub log_rules: LogRules,
    /// Configuration regarding the types of (URL shorten) slugs produced.
    pub slug_rules: SlugRules,
    /// Configuration regarding where and how the HTTP server is served.
    pub serve_rules: ServeRules,
}
/// Configuration version this build of lonk expects.
///
/// This is the MAJOR component of the crate version, baked in at compile
/// time; configuration files must carry exactly this value.
pub fn config_version() -> usize {
    const MAJOR: &str = env!("CARGO_PKG_VERSION_MAJOR");
    <usize as FromStr>::from_str(MAJOR).unwrap()
}
/// Reasons why a configuration file could not be accepted.
///
/// Produced by `Config::from_sync_buffer`; `panic_with_message` turns each
/// variant into a user-facing message.
pub enum ConfigParseError {
    /// The file could not be read or deserialized into a `Config`.
    SerdeError(serde_json::error::Error),
    /// The `version` field (0 when absent) differs from `config_version()`.
    OldVersion(usize),
    /// The configured serve file exists but is not a file.
    ServeFileNotFile(PathBuf),
    /// The configured serve file does not exist.
    ServeFileNotExists(PathBuf),
    /// The configured serve directory exists but is not a directory.
    ServeDirNotDir(PathBuf),
    /// The configured serve directory does not exist.
    ServeDirNotExists(PathBuf),
    /// The parent directory of the access log file does not exist.
    AccessLogDirectoryNotExists(PathBuf),
    /// The parent directory of the error log file does not exist.
    ErrorLogDirectoryNotExists(PathBuf),
}
impl Config {
pub fn from_sync_buffer<R: std::io::Read>(
buffer: std::io::BufReader<R>,
) -> Result<Self, ConfigParseError> {
let parsed: Config =
serde_json::from_reader(buffer).map_err(|err| ConfigParseError::SerdeError(err))?;
parsed.validate()
}
fn validate(self) -> Result<Self, ConfigParseError> {
// Check configuration version
let parsed_version = self.version.unwrap_or(0);
if parsed_version != config_version() {
return Err(ConfigParseError::OldVersion(parsed_version));
}
// Check existence of serve file or directory
match &self.serve_rules.dir {
ServeDirRules::File(file) => {
if !file.exists() {
return Err(ConfigParseError::ServeFileNotExists(file.clone()));
}
if !file.is_file() {
return Err(ConfigParseError::ServeFileNotFile(file.clone()));
}
}
ServeDirRules::Dir(dir) => {
if !dir.exists() {
return Err(ConfigParseError::ServeDirNotExists(dir.clone()));
}
if !dir.is_dir() {
return Err(ConfigParseError::ServeDirNotDir(dir.clone()));
}
}
}
// Check access and error log parent directories
// - Access log file
let weak_canonical = normalize_path(&self.log_rules.access_log_file);
if let Some(parent) = weak_canonical.parent() {
if !parent.exists() {
return Err(ConfigParseError::AccessLogDirectoryNotExists(
parent.to_path_buf(),
));
}
}
// - Error log file
let weak_canonical = normalize_path(&self.log_rules.error_log_file);
if let Some(parent) = weak_canonical.parent() {
if !parent.exists() {
return Err(ConfigParseError::ErrorLogDirectoryNotExists(
parent.to_path_buf(),
));
}
}
Ok(self)
}
}
/// Lexically normalize `path`, resolving `.` and `..` components without
/// touching the file system.
///
/// Adapted from the source of cargo. Weaker than `canonicalize`, because it
/// doesn't require the target file to exist (and does not resolve symlinks).
///
/// A `..` that follows a normal component cancels it (`a/b/..` -> `a`). A
/// `..` directly after the root or a prefix is dropped, since one cannot go
/// above the root (`/..` -> `/`). In a relative path, a `..` with nothing
/// left to cancel is kept (`../logs/a.log` stays as is) rather than being
/// silently discarded, which would point at the wrong directory.
fn normalize_path(path: &std::path::Path) -> PathBuf {
    use std::path::*;
    let mut components = path.components().peekable();
    let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
        components.next();
        PathBuf::from(c.as_os_str())
    } else {
        PathBuf::new()
    };
    for component in components {
        match component {
            Component::Prefix(..) => unreachable!(),
            Component::RootDir => {
                ret.push(component.as_os_str());
            }
            Component::CurDir => {}
            Component::ParentDir => match ret.components().next_back() {
                // `a/b/..` -> `a`
                Some(Component::Normal(_)) => {
                    ret.pop();
                }
                // Cannot climb above the root (or a bare prefix).
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Relative path with nothing left to cancel: keep the `..`.
                Some(Component::ParentDir) | Some(Component::CurDir) | None => {
                    ret.push("..");
                }
            },
            Component::Normal(c) => {
                ret.push(c);
            }
        }
    }
    ret
}
impl ConfigParseError {
    /// Print a human-readable description of this error to stderr, then
    /// terminate the process with exit status 1.
    ///
    /// Despite its name this does not `panic!`: it calls
    /// `std::process::exit(1)`, so no unwinding takes place.
    ///
    /// `config_file_name` is used in the messages, and for JSON syntax/data
    /// errors it is combined with the line and column of the problem.
    pub fn panic_with_message(self, config_file_name: &str) -> ! {
        match self {
            ConfigParseError::SerdeError(err) => match err.classify() {
                serde_json::error::Category::Io => {
                    eprintln!("IO error when reading configuration file.")
                }
                serde_json::error::Category::Syntax => eprintln!(
                    concat!(
                        "Configuration file is syntactically incorrect.\n",
                        "See {}:{}:{}."
                    ),
                    config_file_name,
                    err.line(),
                    err.column()
                ),
                serde_json::error::Category::Data => eprintln!(
                    concat!(
                        "Error deserializing configuration file; expected different data type.\n",
                        "See {}:{}:{}."
                    ),
                    config_file_name,
                    err.line(),
                    err.column()
                ),
                serde_json::error::Category::Eof => {
                    eprintln!("Unexpected end of file when reading configuration file.")
                }
            },
            ConfigParseError::OldVersion(found_version) => {
                // "Expected ... {} but got {}": the version this build
                // expects comes first, the one found in the file second.
                eprintln!(
                    concat!(
                        "Configuration file has outdated version.\n",
                        "Expected version field to be {} but got {}."
                    ),
                    config_version(),
                    found_version
                );
            }
            ConfigParseError::ServeDirNotExists(dir) => {
                eprintln!(
                    "Configuration file indicates directory {} should be served, but it does not exist.",
                    dir.to_string_lossy()
                )
            }
            ConfigParseError::ServeDirNotDir(dir) => {
                eprintln!(
                    "Configuration file indicates directory {} should be served, but it is not a directory.",
                    dir.to_string_lossy()
                )
            }
            ConfigParseError::ServeFileNotExists(file) => {
                eprintln!(
                    "Configuration file indicates file {} should be served, but it does not exist.",
                    file.to_string_lossy()
                )
            }
            ConfigParseError::ServeFileNotFile(file) => {
                eprintln!(
                    "Configuration file indicates file {} should be served, but it is not a file.",
                    file.to_string_lossy()
                )
            }
            ConfigParseError::AccessLogDirectoryNotExists(dir) => {
                eprintln!("Access log file should have parent directory {}, but this directory does not exist.", dir.to_string_lossy())
            }
            ConfigParseError::ErrorLogDirectoryNotExists(dir) => {
                eprintln!("Error log file should have parent directory {}, but this directory does not exist.", dir.to_string_lossy())
            }
        }
        std::process::exit(1);
    }
}
// Default implementations
impl Default for DbConfig {
    /// A local Redis on its standard port; entries expire after three days.
    fn default() -> Self {
        const THREE_DAYS_IN_SECONDS: usize = 3 * 24 * 60 * 60;
        let address = String::from("redis://127.0.0.1:6379");
        Self {
            address,
            expire_seconds: THREE_DAYS_IN_SECONDS,
        }
    }
}
impl Default for SlugRules {
    /// Five-character slugs over ASCII letters, digits, `_` and `-`.
    fn default() -> Self {
        let chars =
            String::from("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-");
        Self { length: 5, chars }
    }
}
impl ServeDirRules {
    /// Build the boxed warp filter serving the configured file or directory.
    pub fn to_filter(&self) -> BoxedFilter<(warp::fs::File,)> {
        match self {
            Self::Dir(dir_path) => warp::fs::dir(dir_path.to_path_buf()).boxed(),
            Self::File(file_path) => warp::fs::file(file_path.to_path_buf()).boxed(),
        }
    }
}
impl Default for ServeDirRules {
fn default() -> Self {
ServeDirRules::Dir("/etc/lonk/served".into())
}
}
impl Default for ServeAddr {
fn default() -> Self {
Self {
ip: [127, 0, 0, 1].into(),
port: 8080,
}
}
}
impl Default for ServeRules {
fn default() -> Self {
Self {
dir: Default::default(),
addr: ServeAddr::default(),
}
}
}
impl Default for LogRules {
fn default() -> Self {
Self {
error_log_file: "/etc/lonk/log/error.log".into(),
access_log_file: "/etc/lonk/log/access.log".into(),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
version: Some(config_version()),
db: Default::default(),
slug_rules: Default::default(),
serve_rules: Default::default(),
log_rules: Default::default(),
}
}
}
}
/// Affine to live service.
mod service {
use validators::prelude::*;
#[derive(Validator)]
#[validator(http_url(local(NotAllow)))]
#[derive(Clone, Debug)]
// `is_https` is never read in this file, hence `dead_code` is allowed.
#[allow(dead_code)]
/// A struct representing a URL, validated by the `validators` crate's
/// `http_url` validator (values come from e.g. `HttpUrl::parse_string`).
///
/// NOTE(review): `local(NotAllow)` presumably rejects local hosts; confirm
/// against the `validators` documentation.
pub struct HttpUrl {
    url: validators::url::Url,
    is_https: bool,
}
#[derive(Validator)]
#[validator(domain(ipv4(Allow), local(NotAllow), at_least_two_labels(Must), port(Allow)))]
// Only used for its validator (`Domain::parse_string` in `HttpUrl::strict`);
// the fields are never read, hence `dead_code` is allowed.
#[allow(dead_code)]
/// A validated domain, used to vet the host of an [`HttpUrl`].
///
/// NOTE(review): the validator options presumably allow IPv4 hosts and
/// ports, forbid local names and require at least two labels; confirm
/// against the `validators` documentation.
pub struct Domain {
    domain: String,
    port: Option<u16>,
}
impl std::fmt::Display for HttpUrl {
fn fmt(&self, f: &mut validators_prelude::Formatter<'_>) -> std::fmt::Result {
self.url.fmt(f)
}
}
impl HttpUrl {
    /// Transform this into an `Err(())` if the url is not a `Domain`.
    ///
    /// The URL must have a domain host (`Url::domain` is `None` for IP
    /// hosts, so those are rejected) that also passes the [`Domain`]
    /// validator; otherwise `Err(())` is returned and `self` is dropped.
    pub fn strict(self) -> Result<Self, ()> {
        let host_is_valid_domain = self
            .url
            .domain()
            .map_or(false, |host| Domain::parse_string(host).is_ok());
        if host_is_valid_domain {
            Ok(self)
        } else {
            Err(())
        }
    }
}
/// Database management, including messaging and work stealing.
pub mod db {
use super::{slug::Slug, HttpUrl};
use async_object_pool::Pool;
use redis::Commands;
use std::sync::Arc;
use tokio::sync;
use validators::prelude::*;
#[derive(Debug)]
/// Struct representing a connection to the Redis database, for
/// management of Slug <-> URL registry.
///
/// Behind the curtains, `SlugDatabase` implements an asynchronous
/// scheme, based on message passing and a continuously running `Tokio`
/// worker. This results in `SlugDatabase` being a thin wrapper around
/// a single `mpsc::UnboundedSender` channel. Because this is the
/// single producer, when `SlugDatabase` is dropped the dispatch worker's
/// receive loop ends and the worker shuts down. Requests already handed
/// off run in their own spawned tasks.
///
/// See the documentation of [`SlugDbMessage`] for more information on
/// the specific messages to be exchanged with the `SlugDatabase`.
pub struct SlugDatabase {
    /// Sending half of the request channel read by the dispatch worker.
    tx: sync::mpsc::UnboundedSender<SlugDbMessage>,
}
#[derive(Clone, Debug)]
/// Response for a request to add a URL to the database.
pub enum AddResult {
    /// The URL was successfully added, and assigned this slug. If the URL
    /// was already present, this is the slug it already had.
    Success(Slug),
    /// The URL could not be added to the database (for example because of
    /// a Redis error or a slug collision).
    Fail,
}
#[derive(Clone, Debug)]
/// Response for a request to translate a slug to a URL.
pub enum GetResult {
    /// The corresponding URL was found, and has this value.
    Found(HttpUrl),
    /// The given slug does not exist in the database.
    NotFound,
    /// There was some internal error when trying to translate the slug
    /// (for example, Redis returned an error).
    InternalError,
}
/// Request to the slug database for a particular action.
///
/// Since the [`SlugDatabase`] operates asynchronously and on a
/// message-passing basis (see the documentation of [`SlugDatabase`] for
/// more information), actions are performed by sending such a message,
/// and then (asynchronously) listening on the provided `oneshot`
/// channel for a response from the database.
///
/// For example, when inserting a slug (this is what
/// `SlugDatabase::insert_slug` does):
///
/// ```ignore
/// let (tx, rx) = sync::oneshot::channel();
/// self.tx
///     .send(SlugDbMessage::Insert(requested_slug, url, tx))
///     .expect("The SlugDbMessage channel is unexpectedly closed.");
/// let db_response = rx.await;
/// ```
enum SlugDbMessage {
    /// Insert a Slug -> URL registry into the database; the outcome is sent
    /// on the `oneshot` sender.
    Insert(Slug, HttpUrl, sync::oneshot::Sender<AddResult>),
    /// Get the URL associated to a slug (if it exists); the outcome is sent
    /// on the `oneshot` sender.
    Get(Slug, sync::oneshot::Sender<GetResult>),
}
impl core::fmt::Debug for SlugDbMessage {
fn fmt(&self, f: &mut validators_prelude::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Insert(arg0, arg1, _) => f
.debug_tuple("Add")
.field(arg0)
.field(arg1)
.field(&"oneshot::Sender<AddResult>")
.finish(),
SlugDbMessage::Get(arg0, _) => f
.debug_tuple("Get")
.field(arg0)
.field(&"oneshot::Sender<Url>")
.finish(),
}
}
}
impl SlugDatabase {
    /// Create a new slug database from a Redis `Client` object.
    /// This is the preferred way to create a new `SlugDatabase`.
    ///
    /// Currently, every entry in the database is expected/set to
    /// expire. This may be subject to change in the future.
    ///
    /// The dispatch worker is started with `tokio::spawn`, so this must be
    /// called from within a Tokio runtime.
    ///
    /// Example:
    ///
    /// ```ignore
    /// let redis_client = redis::Client::open("redis://127.0.0.1:6379")
    ///     .expect("Error opening Redis database.");
    /// let expiration = 1000; // Entries expire after 1000 seconds.
    /// SlugDatabase::from_client(redis_client, expiration)
    /// ```
    pub fn from_client(client: redis::Client, expire_seconds: usize) -> Self {
        let (tx, rx) = sync::mpsc::unbounded_channel::<SlugDbMessage>();
        tokio::spawn(SlugDatabase::db_dispatch_worker(client, rx, expire_seconds));
        SlugDatabase { tx }
    }

    /// Tokio task that receives requests sent to the `SlugDatabase` and
    /// dispatches each one to its own spawned task.
    ///
    /// The worker runs until every sender (i.e. the `SlugDatabase`) has
    /// been dropped. If no Redis connection can be obtained for a request,
    /// that request is answered with a failure and the worker keeps
    /// running. Panicking here would close the channel for good, so every
    /// later request would fail even after Redis becomes reachable again.
    async fn db_dispatch_worker(
        client: redis::Client,
        mut rx: sync::mpsc::UnboundedReceiver<SlugDbMessage>,
        expire_seconds: usize,
    ) {
        // redis::Connection pool.
        // Per the documentation of the redis crate, these are not
        // pooled internally.
        // Pool items are `Option`s so that a failed connection attempt can
        // be expressed without panicking. Only live (`Some`) connections
        // are ever put back, so an item is `None` only when creating a new
        // connection just failed.
        let pool = Arc::new(sync::Mutex::new(Pool::new(100)));
        // Receive and dispatch the incoming messages.
        while let Some(msg) = rx.recv().await {
            // Get a connection from the pool
            // (or try to make a new one if needed)
            let pooled: Option<redis::Connection> = {
                (*pool.lock().await)
                    .take_or_create(|| match client.get_connection() {
                        Ok(connection) => Some(connection),
                        Err(err) => {
                            eprintln!("Could not open connection to Redis server: {}", err);
                            None
                        }
                    })
                    .await
            };
            let mut connection = match pooled {
                Some(connection) => connection,
                None => {
                    SlugDatabase::reply_failure(msg);
                    continue;
                }
            };
            let pool = pool.clone();
            tokio::spawn(async move {
                // Dispatch the message
                match msg {
                    SlugDbMessage::Insert(requested_slug, url, response_channel) => {
                        SlugDatabase::dispatch_insert(
                            &mut connection,
                            requested_slug,
                            url,
                            response_channel,
                            expire_seconds,
                        )
                        .await
                    }
                    SlugDbMessage::Get(slug, response_channel) => {
                        SlugDatabase::dispatch_get(&mut connection, slug, response_channel)
                            .await
                    }
                }
                // Put the redis connection item back into the pool
                (*pool.lock().await).put(Some(connection)).await;
            });
        }
    }

    /// Answer `msg` with a failure without touching the database.
    ///
    /// Used by `db_dispatch_worker` when no Redis connection could be
    /// obtained. If the requester already hung up, nothing is done.
    fn reply_failure(msg: SlugDbMessage) {
        match msg {
            SlugDbMessage::Insert(_, _, response_channel) => {
                response_channel.send(AddResult::Fail).ok();
            }
            SlugDbMessage::Get(_, response_channel) => {
                response_channel.send(GetResult::InternalError).ok();
            }
        }
    }

    /// Dispatch a request to the database to insert a Slug -> URL
    /// registry.
    ///
    /// If the URL is already stored, its existing slug is returned and both
    /// keys have their expiration refreshed. Otherwise `requested_slug` is
    /// stored, unless that slug is already taken, in which case the request
    /// fails.
    ///
    /// NOTE(review): the existence checks and the writes are separate Redis
    /// commands, so two concurrent inserts of the same URL can race.
    ///
    /// This function is not expected to be called directly, but rather
    /// by the `db_dispatch_worker` function, as a result of a
    /// `SlugDbMessage::Insert` message.
    async fn dispatch_insert(
        connection: &mut redis::Connection,
        requested_slug: Slug,
        url: HttpUrl,
        response_channel: sync::oneshot::Sender<AddResult>,
        expire_seconds: usize,
    ) {
        let url_str = url.to_string();
        let url_key = format!("url:{}", url_str);
        // Check that the URL is not already present in the DB
        // This is, to some extent, a protection against collision attacks.
        match connection.get::<String, Option<String>>(url_key.clone()) {
            Ok(Some(slug)) => {
                let slug_key = format!("slug:{}", slug);
                // The URL was already present.
                // Refresh the expiration.
                // (If this operation fails it cannot be corrected for.)
                connection
                    .expire::<String, ()>(url_key, expire_seconds)
                    .ok();
                connection
                    .expire::<String, ()>(slug_key, expire_seconds)
                    .ok();
                // Return the original slug.
                response_channel
                    .send(AddResult::Success(Slug::unchecked_from_str(slug)))
                    .ok();
                return;
            }
            Err(err) => {
                response_channel.send(AddResult::Fail).ok();
                ifdbg!(eprintln!("{}", err));
                return;
            }
            Ok(None) => {} // continue with insertion
        };
        // The URL is not present in the database; insert it.
        let slug_key = format!("slug:{}", requested_slug.inner_str());
        // Make sure that there's no collision with the slug; if so, we
        // are in one of two situations: either we got really unlucky,
        // or the slug space has been exhausted.
        // In any case, to be safe, fail the operation.
        match connection.get::<String, Option<String>>(slug_key.clone()) {
            Ok(Some(_)) => {
                // Collision!
                response_channel.send(AddResult::Fail).ok();
                eprintln!(
                    concat!(
                        "Collision for slug {}!\n",
                        "Slug space may have been exhausted.\n",
                        "If you see this message repeatedly, ",
                        "consider increasing the slug size."
                    ),
                    slug_key
                );
                return;
            }
            Err(err) => {
                // Internal error in communication.
                response_channel.send(AddResult::Fail).ok();
                ifdbg!(eprintln!("{}", err));
                return;
            }
            Ok(None) => {} // continue with insertion
        };
        let add_result = connection.set_ex::<String, String, ()>(
            slug_key,
            url_str.clone(),
            expire_seconds,
        );
        if add_result.is_ok() {
            connection
                .set_ex::<String, String, ()>(
                    url_key,
                    requested_slug.inner_str().to_string(),
                    expire_seconds,
                )
                .ok(); // If this failed we have no way of correcting for it.
        }
        response_channel
            .send(match add_result {
                Ok(_) => AddResult::Success(requested_slug),
                Err(err) => {
                    ifdbg!(eprintln!("{}", err));
                    AddResult::Fail
                }
            })
            .ok(); // If the receiver has hung up there's nothing we can do.
    }

    /// Dispatch a request to get from the database the URL associated
    /// to a given slug (if it exists).
    ///
    /// A stored value that is not a valid `HttpUrl` is reported as
    /// `GetResult::InternalError`. Panicking here would drop the task's
    /// connection instead of returning it to the pool.
    ///
    /// This function is not expected to be called directly, but rather
    /// by the `db_dispatch_worker` function, as a result of a
    /// `SlugDbMessage::Get` message.
    async fn dispatch_get(
        connection: &mut redis::Connection,
        slug: Slug,
        response_channel: sync::oneshot::Sender<GetResult>,
    ) {
        let result: Result<Option<String>, _> =
            connection.get(format!("slug:{}", slug.inner_str()));
        match result {
            Ok(Some(url)) => match HttpUrl::parse_string(url) {
                Ok(url) => response_channel.send(GetResult::Found(url)),
                Err(_) => {
                    eprintln!(
                        "Mismatched URL in the database for slug {}.",
                        slug.inner_str()
                    );
                    response_channel.send(GetResult::InternalError)
                }
            },
            Ok(None) => response_channel.send(GetResult::NotFound),
            Err(err) => {
                ifdbg!(eprintln!("{}", err));
                response_channel.send(GetResult::InternalError)
            }
        }
        .ok(); // If the receiver has hung up there's nothing we can do.
    }

    /// Request a slug <-> URL registry to be inserted into the
    /// database.
    ///
    /// This is an asynchronous operation; as such, a
    /// `oneshot::Receiver` is returned, rather than a result. One
    /// should await on this receiver for the result of the operation.
    /// (Note that, as a consequence of this, this function is *not*
    /// asynchronous.)
    ///
    /// Note that the `requested_slug` argument is just that: a request.
    /// If the response is a `Success`, it will include the actual new
    /// slug. In particular, if the URL was already present in the
    /// database, the already associated slug will be returned.
    ///
    /// # Panics
    /// If the dispatch worker is no longer receiving messages.
    ///
    /// Example:
    ///
    /// ```ignore
    /// match db.insert_slug(requested_slug, url).await {
    ///     Ok(AddResult::Success(slug)) => {} // `slug` now points to the URL.
    ///     Ok(AddResult::Fail) => {}          // The registry was not inserted.
    ///     Err(_) => {}                       // The worker dropped the request.
    /// }
    /// ```
    pub fn insert_slug(
        &self,
        requested_slug: Slug,
        url: HttpUrl,
    ) -> sync::oneshot::Receiver<AddResult> {
        let (tx, rx) = sync::oneshot::channel();
        self.tx
            .send(SlugDbMessage::Insert(requested_slug, url, tx))
            .expect("The SlugDbMessage channel is unexpectedly closed.");
        rx
    }

    /// Request the URL associated to the provided slug, if it exists.
    ///
    /// This is an asynchronous operation; as such, a
    /// `oneshot::Receiver` is returned, rather than a result. One
    /// should await on this receiver for the result of the operation.
    /// (Note that, as a consequence of this, this function is *not*
    /// asynchronous.)
    ///
    /// # Panics
    /// If the dispatch worker is no longer receiving messages.
    pub fn get_slug(&self, slug: Slug) -> sync::oneshot::Receiver<GetResult> {
        let (tx, rx) = sync::oneshot::channel();
        self.tx
            .send(SlugDbMessage::Get(slug, tx))
            .expect("The SlugDbMessage channel is unexpectedly closed.");
        rx
    }
}
}
/// Affine to slug definition, generation, parsing, etc.
pub mod slug {
use crate::conf::SlugRules;
use rand::prelude::*;
use std::collections::BTreeSet;
/// A struct responsible for constructing random slugs, or validating
/// existing ones.
pub struct SlugFactory {
    /// Length of generated slugs, and the maximum length `parse_str`
    /// accepts.
    slug_length: usize,
    /// Set of permitted characters, used for membership checks in
    /// `parse_str`.
    slug_chars: BTreeSet<char>,
    /// The permitted characters as an indexable list, used by `generate`
    /// to pick random characters.
    slug_chars_indexable: Vec<char>,
}
#[derive(Clone, Debug)]
/// A slug: the short sequence of characters in a shortened URL that
/// aliases a longer URL.
///
/// Usually this is the argument of the `GET` request to the link
/// shortener. `Slug`s are typically produced by a [`SlugFactory`], or
/// supplied by the user and checked by one.
pub struct Slug(String);
impl Slug {
    /// Wrap `slug_str` as a `Slug` as-is.
    ///
    /// No check is made against any [`SlugFactory`]'s rules, so use this
    /// only for strings known to be valid slugs (e.g. read back from the
    /// database).
    pub fn unchecked_from_str(slug_str: String) -> Slug {
        Self(slug_str)
    }
    /// Borrow the slug's characters as a string slice.
    pub fn inner_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Why a provided slug is invalid.
pub enum InvalidSlug {
    /// The slug has more characters than defined for the [`SlugFactory`].
    TooLong,
    /// The slug has a character that was not given to the [`SlugFactory`].
    BadChar,
}
impl SlugFactory {
    /// Create a new `SlugFactory`, according to the provided `SlugRules`.
    ///
    /// This is the preferred way to create a `SlugFactory`.
    ///
    /// Repeated characters in `rules.chars` are collapsed, so that every
    /// distinct character is equally likely in [`SlugFactory::generate`].
    pub fn from_rules(rules: SlugRules) -> Self {
        let slug_chars: BTreeSet<char> = rules.chars.chars().collect();
        // Index from the deduplicated set: indexing the raw string would
        // make a character listed twice twice as likely to be drawn.
        let slug_chars_indexable: Vec<char> = slug_chars.iter().copied().collect();
        SlugFactory {
            slug_length: rules.length,
            slug_chars,
            slug_chars_indexable,
        }
    }
    /// Transform a literal string into a `Slug` according to the rules
    /// of this `SlugFactory`.
    ///
    /// Characters are checked in order. The first one past the configured
    /// length yields `InvalidSlug::TooLong`, and the first character not in
    /// the allowed set yields `InvalidSlug::BadChar`. Slugs shorter than the
    /// configured length (including the empty string) are accepted.
    pub fn parse_str(&self, s: &str) -> Result<Slug, InvalidSlug> {
        for (i, char) in s.chars().enumerate() {
            if i >= self.slug_length {
                return Err(InvalidSlug::TooLong);
            }
            if !self.slug_chars.contains(&char) {
                return Err(InvalidSlug::BadChar);
            }
        }
        Ok(Slug(s.to_string()))
    }
    /// Generate a random `Slug` according to the rules of this
    /// `SlugFactory`.
    ///
    /// Internal randomness is handled by the `rand` crate (thread-local
    /// RNG). Each character is drawn from a `Uniform` distribution over
    /// the distinct valid characters.
    ///
    /// # Panics
    /// If the configured character set is empty (`Uniform::new(0, 0)`
    /// panics).
    pub fn generate(&self) -> Slug {
        // Generate indices then map
        let distribution =
            rand::distributions::Uniform::new(0, self.slug_chars_indexable.len());
        let slug_str = distribution
            .sample_iter(rand::thread_rng())
            .take(self.slug_length)
            .map(|i| self.slug_chars_indexable[i])
            .collect::<String>();
        Slug(slug_str)
    }
}
}
/// Affine to logging
pub mod log {
use std::path::PathBuf;
use tokio::{fs::OpenOptions, io::AsyncWriteExt, sync};
/// A struct responsible for logging events, per messages received from
/// other tasks.
///
/// Each field is the sending half of a channel whose receiving end is
/// owned by a logging task that appends to one file.
pub struct Logger {
    /// Messages destined for the access log.
    access_tx: sync::mpsc::UnboundedSender<String>,
    /// Messages destined for the error log.
    error_tx: sync::mpsc::UnboundedSender<String>,
}
impl Logger {
    /// Create a `Logger` that writes to the files named in `config`.
    ///
    /// One logging task per file is started with `tokio::spawn`, so this
    /// must be called from within a Tokio runtime.
    pub fn from_log_rules(config: &crate::conf::LogRules) -> Self {
        // Create the communication channels
        let (access_tx, access_rx) = sync::mpsc::unbounded_channel::<String>();
        let (error_tx, error_rx) = sync::mpsc::unbounded_channel::<String>();
        // Start the logging tasks
        tokio::spawn(Self::logging_task(
            access_rx,
            config.access_log_file.clone(),
        ));
        tokio::spawn(Self::logging_task(error_rx, config.error_log_file.clone()));
        // Done
        Logger {
            access_tx,
            error_tx,
        }
    }
    /// Log a message into the access log file.
    ///
    /// Returns a Result with empty type; if posting the log message
    /// failed for any reason, it's unlikely to recover, and the user
    /// should decide either to stop logging, ignore these errors, or
    /// halt the program.
    pub fn access(&self, msg: String) -> Result<(), ()> {
        self.access_tx.send(msg).map_err(|_| ())
    }
    /// Log a message into the error log file.
    ///
    /// Returns a Result with empty type; if posting the log message
    /// failed for any reason, it's unlikely to recover, and the user
    /// should decide either to stop logging, ignore these errors, or
    /// halt the program.
    pub fn error(&self, msg: String) -> Result<(), ()> {
        self.error_tx.send(msg).map_err(|_| ())
    }
    /// The task responsible for receiving the log messages and actually
    /// writing them into the corresponding files. One task is created
    /// for each target file.
    ///
    /// Each message is written as `"<unix seconds> <message>"`; callers
    /// supply any trailing newline. If the file cannot be opened the task
    /// ends, so later sends to this channel fail. Write errors are reported
    /// on stderr and the task keeps going.
    async fn logging_task(mut rx: sync::mpsc::UnboundedReceiver<String>, into: PathBuf) {
        // Open the log file in append mode
        let open_result = OpenOptions::new()
            .append(true)
            .create(true)
            .open(&into)
            .await;
        let mut file = match open_result {
            Ok(file) => file,
            Err(e) => {
                eprintln!(
                    concat!(
                        "Could not open {} for logging, with error:\n",
                        "{}\n",
                        "Future logging may result in errors."
                    ),
                    into.to_string_lossy(),
                    e
                );
                return;
            }
        };
        // Listen to the logging message channel
        while let Some(log) = rx.recv().await {
            // A clock set before 1970 is logged as 0 instead of panicking
            // (which would silently kill this logging task).
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|elapsed| elapsed.as_secs())
                .unwrap_or(0);
            let line = format!("{} {}", timestamp, log);
            // `write_all` keeps writing until the whole line is accepted
            // (`write_buf` performs a single, possibly partial, write). The
            // flush surfaces this line's I/O errors here rather than on the
            // next write.
            let write_result = match file.write_all(line.as_bytes()).await {
                Ok(()) => file.flush().await,
                Err(e) => Err(e),
            };
            if let Err(e) = write_result {
                eprintln!(
                    concat!(
                        "Error writing to {}!\n",
                        "{}\n",
                        "Continuing, but future logging may error again."
                    ),
                    into.to_string_lossy(),
                    e
                )
            }
        }
        // All logging tx channels were dropped, close this task
    }
}
}
}
use service::*;
/// Shorten a URL.
///
/// `b64str` is the request body: the URL to shorten, given as standard
/// Base64 of its UTF-8 bytes. On success the slug that now points at the URL
/// is returned (possibly one assigned earlier, if the database already knew
/// the URL), and a `"<slug> -> <url>"` line is sent to the access log.
///
/// # Errors
/// `400 Bad Request` if the body is not Base64, not UTF-8, or not an
/// acceptable URL. `500 Internal Server Error` if the database did not
/// store the URL or its response channel closed.
async fn shorten(
    slug_factory: &slug::SlugFactory,
    db: &db::SlugDatabase,
    b64str: &str,
    logger: &log::Logger,
) -> Result<slug::Slug, (StatusCode, String)> {
    let bad_request =
        |message: &str| (warp::http::StatusCode::BAD_REQUEST, message.to_string());
    // Base64 -> bytes -> UTF-8 text -> validated URL; any failure along the
    // way is a clean HTTP rejection.
    let decoded = base64::decode_config(b64str, base64::STANDARD)
        .map_err(|_| bad_request(debuginfo!("Could not decode base64 str.", "Invalid Base64.")))?;
    let url_text = std::str::from_utf8(&decoded).map_err(|_| {
        bad_request(debuginfo!(
            "Parsed bytes of base64 str, but could not decode as UTF8.",
            "Invalid Base64."
        ))
    })?;
    let url = HttpUrl::parse_string(url_text)
        .map_err(|_| bad_request("Invalid URL."))?
        .strict()
        .map_err(|_| bad_request("Invalid URL."))?;
    // Propose a fresh random slug; the database decides the final one.
    let candidate = slug_factory.generate();
    match db.insert_slug(candidate, url.clone()).await {
        Ok(service::db::AddResult::Success(slug)) => {
            logger
                .access(format!("{} -> {}\n", slug.inner_str(), url))
                .ok();
            Ok(slug)
        }
        Ok(service::db::AddResult::Fail) => Err((
            warp::http::StatusCode::INTERNAL_SERVER_ERROR,
            debuginfo!("Got insertion response, but it was error.").into(),
        )),
        Err(e) => {
            ifdbg!(eprintln!("{}", e));
            Err((
                warp::http::StatusCode::INTERNAL_SERVER_ERROR,
                debuginfo!("Response channel for insertion is unexpectedly closed").into(),
            ))
        }
    }
}
/// Redirect from a slug.
///
/// Validates `slug_str` against the factory's rules, then looks it up in
/// the database and returns the URL it points to.
///
/// # Errors
/// - `400 Bad Request` if the slug is malformed or unknown to the database.
/// - `500 Internal Server Error` if the database reports an internal error
///   or its response channel is closed.
async fn redirect(
    slug_str: &str,
    slug_factory: &slug::SlugFactory,
    db: &db::SlugDatabase,
) -> Result<HttpUrl, (StatusCode, String)> {
    // Check that the slug is valid.
    let slug = slug_factory.parse_str(slug_str).map_err(|e| match e {
        slug::InvalidSlug::TooLong => (
            warp::http::StatusCode::BAD_REQUEST,
            debuginfo!("Given slug is too long.", "Invalid URL.").into(),
        ),
        slug::InvalidSlug::BadChar => (
            warp::http::StatusCode::BAD_REQUEST,
            debuginfo!("Given slug has invalid characters.", "Invalid URL.").into(),
        ),
    })?;
    match db.get_slug(slug).await {
        Ok(db::GetResult::Found(url)) => Ok(url),
        Ok(db::GetResult::NotFound) => Err((
            warp::http::StatusCode::BAD_REQUEST,
            debuginfo!("The slug does not exist in the database.", "Invalid URL.").into(),
        )),
        // A database failure is the server's fault, not the client's.
        Ok(db::GetResult::InternalError) => Err((
            warp::http::StatusCode::INTERNAL_SERVER_ERROR,
            "Internal error.".into(),
        )),
        Err(e) => {
            ifdbg!(eprintln!("{}", e));
            Err((
                warp::http::StatusCode::INTERNAL_SERVER_ERROR,
                debuginfo!("Response channel for slug lookup is unexpectedly closed").into(),
            ))
        }
    }
}
/// Read the configuration, build the long-lived services (logger, slug
/// factory, database) and run the HTTP server forever.
///
/// The configuration file is named by the `LONK_CONFIG` environment
/// variable, defaulting to `lonk.json`. Configuration problems are reported
/// on stderr and terminate the process with exit status 1.
#[tokio::main]
async fn serve() {
    // Read configuration
    let config: conf::Config = {
        let config_file_name =
            std::env::var("LONK_CONFIG").unwrap_or_else(|_| "lonk.json".to_string());
        let config_file = std::fs::File::open(&config_file_name).unwrap_or_else(|err| {
            match err.kind() {
                std::io::ErrorKind::NotFound => {
                    eprintln!("Configuration file {} does not exist.", config_file_name)
                }
                std::io::ErrorKind::PermissionDenied => {
                    eprintln!("Read permission to {} was denied.", config_file_name)
                }
                _ => eprintln!(
                    "Error when trying to read configuration file {}: {}",
                    config_file_name, err
                ),
            };
            std::process::exit(1);
        });
        // Parsing does synchronous I/O, so keep it off the async workers.
        let parse_result = tokio::task::spawn_blocking(move || {
            conf::Config::from_sync_buffer(std::io::BufReader::new(config_file))
        })
        .await
        .expect("Tokio error from blocking task.");
        match parse_result {
            Err(err) => err.panic_with_message(&config_file_name),
            Ok(config) => config,
        }
    };
    // Create logger
    let logger = log::Logger::from_log_rules(&config.log_rules);
    // Create slug factory
    let slug_factory = slug::SlugFactory::from_rules(config.slug_rules);
    // Initialize database
    let db = {
        let client = redis::Client::open(config.db.address).expect("Error opening Redis database.");
        db::SlugDatabase::from_client(client, config.db.expire_seconds)
    };
    // We leak the slug factory, the database, and the logger, because we know
    // that these will live forever, and want them to have 'static lifetime so
    // that warp is happy.
    let slug_factory: &'static slug::SlugFactory = Box::leak(Box::new(slug_factory));
    let db: &'static db::SlugDatabase = Box::leak(Box::new(db));
    let logger: &'static log::Logger = Box::leak(Box::new(logger));
    // Warp logging compatibility layer: 4xx/5xx go to the error log,
    // everything else to the access log.
    let log = warp::log::custom(move |info| {
        let log_msg = format!(
            "{} ({}) {} {}, replied with status {}\n",
            info.remote_addr()
                .map(|x| x.to_string())
                .unwrap_or_else(|| "<Unknown remote address>".to_string()),
            info.user_agent().unwrap_or("<No user agent provided>"),
            info.method(),
            info.path(),
            info.status().as_u16(),
        );
        if info.status().is_client_error() || info.status().is_server_error() {
            logger.error(log_msg).ok();
        } else {
            logger.access(log_msg).ok();
        }
    });
    // POST /shorten/ with link in argument
    let shorten = warp::post()
        .and(warp::path("shorten"))
        .and(warp::body::content_length_limit(1024))
        .and(warp::body::bytes())
        .then(move |body: warp::hyper::body::Bytes| async move {
            let b64str = match std::str::from_utf8(&body[..]) {
                Ok(b64str) => b64str,
                Err(_) => {
                    return Response::builder()
                        .status(warp::http::StatusCode::BAD_REQUEST)
                        .body(String::new())
                        .unwrap();
                }
            };
            match shorten(slug_factory, db, b64str, logger).await {
                Ok(slug) => Response::builder()
                    .body(slug.inner_str().to_string())
                    .unwrap(),
                Err((status, message)) => Response::builder().status(status).body(message).unwrap(),
            }
        });
    // GET /l/:Slug
    let link = warp::path("l")
        .and(warp::path::param())
        .then(move |slug: String| async move {
            match redirect(&slug, slug_factory, db).await {
                Ok(url) => Response::builder()
                    .status(warp::http::StatusCode::FOUND)
                    .header("Location", url.to_string())
                    .body("".to_string())
                    .unwrap(),
                Err((status, message)) => Response::builder().status(status).body(message).unwrap(),
            }
        });
    // GET /
    // This should be the last thing matched, so that anything that doesn't
    // match another filter will try to match a file.
    let homepage = warp::get()
        .and(config.serve_rules.dir.to_filter());
    let get_routes = warp::get().and(link.or(homepage));
    let post_routes = warp::post().and(shorten);
    let routes = get_routes.or(post_routes).with(log);
    eprintln!(
        "Now serving lonk at IP {}, port {}!",
        config.serve_rules.addr.ip, config.serve_rules.addr.port
    );
    warp::serve(routes)
        .run((config.serve_rules.addr.ip, config.serve_rules.addr.port))
        .await;
    unreachable!("The warp server runs forever.")
}
// Command-line arguments. argh turns the `///` comments below into the
// `--help` text, so they are user-facing output: edit them deliberately.
#[derive(FromArgs, PartialEq, Debug)]
/// Start lonk.
struct Run {
    /// print the version and quit
    #[argh(switch)]
    version: bool,
    /// write a default configuration to stdout and quit
    #[argh(switch)]
    print_default_config: bool,
}
/// The full crate version, captured at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Entry point: handle the informational command-line switches, or start
/// the server.
///
/// `--version` takes precedence over `--print-default-config`; either one
/// exits with status 0 without starting the server.
fn main() {
    let run = argh::from_env::<Run>();
    if run.version {
        println!("lonk v{}", VERSION);
        std::process::exit(0);
    }
    if run.print_default_config {
        println!(
            "{}",
            serde_json::to_string_pretty(&conf::Config::default())
                .expect("Default configuration should always be JSON serializable")
        );
        std::process::exit(0);
    }
    serve();
}