允许通过环境变量修改配置;优化格式

This commit is contained in:
yly 2023-11-19 05:24:55 +08:00
parent 01f90d4bf1
commit 27c463bdf1
3 changed files with 55 additions and 40 deletions

1
Cargo.lock generated
View File

@ -124,6 +124,7 @@ dependencies = [
"askama", "askama",
"axum", "axum",
"minijinja", "minijinja",
"once_cell",
"serde", "serde",
"sqlx", "sqlx",
"tokio", "tokio",

View File

@ -18,3 +18,4 @@ tower-http = { version = "0.4.4", features = ["trace"] }
tracing = "0.1.40" tracing = "0.1.40"
askama = "0.10" askama = "0.10"
minijinja = "1.0.9" minijinja = "1.0.9"
once_cell = "1.18.0"

View File

@ -1,40 +1,41 @@
use axum::extract::Query; use axum::extract::Query;
use axum::http::{HeaderMap, HeaderValue, Uri};
use axum::response::{Html, Redirect}; use axum::response::{Html, Redirect};
use axum::http::{Uri, HeaderMap, HeaderValue}; use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Form, Router};
use axum::{ use minijinja::{context, Environment};
extract::State,
http::StatusCode,
response::IntoResponse,
routing::get,
Form, Router,
};
use serde::Deserialize; use serde::Deserialize;
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
use std::{env, borrow::BorrowMut};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{borrow::BorrowMut, env};
use std::{collections::HashMap, str::FromStr}; use std::{collections::HashMap, str::FromStr};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
use uuid::Uuid;
use tower_http::trace::{self, TraceLayer}; use tower_http::trace::{self, TraceLayer};
use tracing::Level; use tracing::Level;
use minijinja::{Environment, context}; use uuid::Uuid;
pub struct ServerState { pub struct ServerState {
pub db: sqlx::Pool<sqlx::Sqlite>, pub db: sqlx::Pool<sqlx::Sqlite>,
pub session: Mutex<HashMap<Uuid, Instant>>, pub session: Mutex<HashMap<Uuid, Instant>>,
// started: Instant, // started: Instant,
} }
#[derive(Deserialize, sqlx::FromRow,Debug)] #[derive(Deserialize, sqlx::FromRow, Debug)]
pub struct UserLoginForm { pub struct UserLoginForm {
#[sqlx(rename = "NAME")] #[sqlx(rename = "NAME")]
pub username: String, pub username: String,
#[sqlx(rename = "KEY")] #[sqlx(rename = "KEY")]
pub otp: String, pub otp: String,
} }
const COOKIE_NAME: &str = "aaron_auth"; use once_cell::sync::Lazy;
const SESSION_ACTIVE_TIME: u64 = 600; static COOKIE_NAME: Lazy<String> =
Lazy::new(|| env::var("COOKIE_NAME").unwrap_or("aaron_auth".to_string()));
const SESSION_ACTIVE_TIME: Lazy<u64> = Lazy::new(|| {
env::var("SESSION_ACTIVE_TIME")
.ok()
.and_then(|value| value.parse().ok())
.unwrap_or(600)
});
const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html"); const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html");
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -69,15 +70,14 @@ async fn main() {
.unwrap(); .unwrap();
} }
async fn gc_task(state: Arc<ServerState>) {
async fn gc_task(state:Arc<ServerState>){ let mut interval = tokio::time::interval(Duration::from_secs(*SESSION_ACTIVE_TIME));
let mut interval = tokio::time::interval(Duration::from_secs(SESSION_ACTIVE_TIME));
loop { loop {
interval.tick().await; interval.tick().await;
let res = gc(state.clone()).await; let res = gc(state.clone()).await;
match res { match res {
Ok(_) => tracing::info!("gc completed"), Ok(_) => tracing::info!("gc completed"),
Err(s) => tracing::error!("gc failed:{}",s), Err(s) => tracing::error!("gc failed:{}", s),
} }
} }
} }
@ -88,15 +88,19 @@ async fn auth(
// Form(frm): Form<UserLoginForm>, // Form(frm): Form<UserLoginForm>,
cookies: Cookies, cookies: Cookies,
) -> StatusCode { ) -> StatusCode {
if let Some(session_token) = cookies.get(COOKIE_NAME) { if let Some(session_token) = cookies.get(&COOKIE_NAME) {
tracing::info!("session:{}",session_token.value()); tracing::info!("session:{}", session_token.value());
let Ok(s) = uuid::Uuid::from_str(session_token.value()) else { let Ok(s) = uuid::Uuid::from_str(session_token.value()) else {
return StatusCode::UNAUTHORIZED; return StatusCode::UNAUTHORIZED;
}; };
let mut locked = state.session.lock().await; let mut locked = state.session.lock().await;
if locked.contains_key(&s) { // FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed if locked.contains_key(&s) {
let Some(v) = locked.insert(s,Instant::now()+Duration::from_secs(SESSION_ACTIVE_TIME)) else { // FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed
tracing::info!("session:{} extended",session_token.value()); let Some(v) = locked.insert(
s,
Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME),
) else {
tracing::info!("session:{} extended", session_token.value());
return StatusCode::UNAUTHORIZED; return StatusCode::UNAUTHORIZED;
}; };
if Instant::now() < v { if Instant::now() < v {
@ -117,7 +121,7 @@ async fn login(
let Ok(mut conn) = conn else { let Ok(mut conn) = conn else {
return Err((StatusCode::BAD_GATEWAY, "db连接错误")); return Err((StatusCode::BAD_GATEWAY, "db连接错误"));
}; };
tracing::info!("{:?}",&frm); tracing::info!("{:?}", &frm);
let target = sqlx::query_as::<_, UserLoginForm>( let target = sqlx::query_as::<_, UserLoginForm>(
r#" r#"
SELECT NAME, KEY FROM USERS WHERE NAME = ? SELECT NAME, KEY FROM USERS WHERE NAME = ?
@ -126,14 +130,17 @@ async fn login(
.bind(frm.username) .bind(frm.username)
.fetch_optional(&mut *conn) .fetch_optional(&mut *conn)
.await; .await;
tracing::info!("{:?}",&target); tracing::info!("{:?}", &target);
if let Ok(Some(target)) = target { if let Ok(Some(target)) = target {
if check_otp(target.otp, frm.otp) { if check_otp(target.otp, frm.otp) {
let s = Uuid::new_v4(); let s = Uuid::new_v4();
let mut locked = state.session.lock().await; let mut locked = state.session.lock().await;
locked.insert(s.clone(), Instant::now() + Duration::from_secs(SESSION_ACTIVE_TIME)); locked.insert(
let mut new_cookie = Cookie::new(COOKIE_NAME, s.to_string()); s.clone(),
Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME),
);
let mut new_cookie = Cookie::new(&*COOKIE_NAME, s.to_string());
new_cookie.set_domain(".aaronhu.cn"); new_cookie.set_domain(".aaronhu.cn");
cookies.add(new_cookie); cookies.add(new_cookie);
if let Some(original_uri) = params.get("original_url") { if let Some(original_uri) = params.get("original_url") {
@ -142,29 +149,35 @@ async fn login(
return Err((StatusCode::ACCEPTED, "ok")); return Err((StatusCode::ACCEPTED, "ok"));
} else { } else {
return Err((StatusCode::UNAUTHORIZED,"wrong password")); return Err((StatusCode::UNAUTHORIZED, "wrong password"));
} }
} }
return Err((StatusCode::BAD_GATEWAY, "unreachable")); return Err((StatusCode::BAD_GATEWAY, "unreachable"));
} }
async fn login_page( async fn login_page(headers: HeaderMap<HeaderValue>) -> impl IntoResponse {
headers: HeaderMap<HeaderValue>, tracing::info!("Headers: {:#?}", headers);
)-> impl IntoResponse{
tracing::info!("Headers: {:#?}",headers);
let mut env = Environment::new(); let mut env = Environment::new();
env.add_template("login.html", LOGIN_PAGE_HTML).unwrap(); env.add_template("login.html", LOGIN_PAGE_HTML).unwrap();
let template = env.get_template("login.html").unwrap(); let template = env.get_template("login.html").unwrap();
if let Some(original_uri) = headers.get("X-Original-URI") { if let Some(original_uri) = headers.get("X-Original-URI") {
if let Ok(uri) = original_uri.to_str() { if let Ok(uri) = original_uri.to_str() {
tracing::info!("redirect to {}",uri); tracing::info!("redirect to {}", uri);
if !uri.is_empty() { if !uri.is_empty() {
let uri = "?original_url=".to_owned()+uri; let uri = "?original_url=".to_owned() + uri;
return Html(template.render(context! { url => uri }).unwrap_or("Error".to_string())) return Html(
template
.render(context! { url => uri })
.unwrap_or("Error".to_string()),
);
} }
} }
} }
return Html(template.render(context! { url => String::new() }).unwrap_or("Error".to_string())) return Html(
template
.render(context! { url => String::new() })
.unwrap_or("Error".to_string()),
);
} }
pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool { pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool {
@ -186,11 +199,11 @@ pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool {
return false; return false;
} }
async fn gc(state:Arc<ServerState>)->Result<(),String>{ async fn gc(state: Arc<ServerState>) -> Result<(), String> {
let mut locked = state.session.lock().await; let mut locked = state.session.lock().await;
let current_time = Instant::now(); let current_time = Instant::now();
tracing::info!("before gc ,active Sessions {:?}",locked); tracing::info!("before gc ,active Sessions {:?}", locked);
locked.borrow_mut().retain(|_,v| *v > current_time); locked.borrow_mut().retain(|_, v| *v > current_time);
tracing::info!("gc fired,active Sessions {:?}",locked); tracing::info!("gc fired,active Sessions {:?}", locked);
Ok(()) Ok(())
} }