149 lines
4.2 KiB
Rust
149 lines
4.2 KiB
Rust
use axum::{
|
||
extract::{Request, FromRequestParts},
|
||
http::header,
|
||
middleware::Next,
|
||
response::Response,
|
||
};
|
||
use axum::http::request::Parts;
|
||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||
use serde::{Deserialize, Serialize};
|
||
use chrono::{Duration, Utc};
|
||
use crate::error::ApiError;
|
||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||
pub struct Claims {
|
||
pub sub: String, // subject (user id)
|
||
pub exp: usize, // expiration time
|
||
pub iat: usize, // issued at
|
||
}
|
||
|
||
pub struct JwtAuth {
|
||
secret: String,
|
||
expiration_hours: i64,
|
||
}
|
||
|
||
impl JwtAuth {
|
||
pub fn new(secret: String, expiration_hours: u64) -> Self {
|
||
Self {
|
||
secret,
|
||
expiration_hours: expiration_hours as i64,
|
||
}
|
||
}
|
||
|
||
pub fn create_token(&self, user_id: &str) -> Result<String, ApiError> {
|
||
let now = Utc::now();
|
||
let exp = (now + Duration::hours(self.expiration_hours)).timestamp() as usize;
|
||
let iat = now.timestamp() as usize;
|
||
|
||
let claims = Claims {
|
||
sub: user_id.to_string(),
|
||
exp,
|
||
iat,
|
||
};
|
||
|
||
encode(
|
||
&Header::default(),
|
||
&claims,
|
||
&EncodingKey::from_secret(self.secret.as_bytes()),
|
||
)
|
||
.map_err(|e| ApiError::InternalError(format!("Failed to create token: {}", e)))
|
||
}
|
||
|
||
pub fn validate_token(&self, token: &str) -> Result<Claims, ApiError> {
|
||
decode::<Claims>(
|
||
token,
|
||
&DecodingKey::from_secret(self.secret.as_bytes()),
|
||
&Validation::default(),
|
||
)
|
||
.map(|data| data.claims)
|
||
.map_err(|e| ApiError::Unauthorized(format!("Invalid token: {}", e)))
|
||
}
|
||
}
|
||
|
||
#[derive(Clone)]
|
||
pub struct AuthUser {
|
||
pub user_id: String,
|
||
}
|
||
|
||
#[axum::async_trait]
|
||
impl<S> FromRequestParts<S> for AuthUser
|
||
where
|
||
S: Send + Sync,
|
||
{
|
||
type Rejection = ApiError;
|
||
|
||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||
// 从请求头中提取Authorization
|
||
let auth_header = parts
|
||
.headers
|
||
.get(header::AUTHORIZATION)
|
||
.and_then(|value| value.to_str().ok())
|
||
.ok_or_else(|| ApiError::Unauthorized("Missing authorization header".to_string()))?;
|
||
|
||
// 提取Bearer token
|
||
let token = auth_header
|
||
.strip_prefix("Bearer ")
|
||
.ok_or_else(|| ApiError::Unauthorized("Invalid authorization format".to_string()))?;
|
||
|
||
// 验证token(这里简化处理,实际应该从state中获取JwtAuth)
|
||
// 在实际使用中,应该通过Extension传递JwtAuth实例
|
||
Ok(AuthUser {
|
||
user_id: token.to_string(), // 简化处理
|
||
})
|
||
}
|
||
}
|
||
|
||
pub async fn auth_middleware(
|
||
request: Request,
|
||
next: Next,
|
||
) -> Result<Response, ApiError> {
|
||
// 获取Authorization header
|
||
let auth_header = request
|
||
.headers()
|
||
.get(header::AUTHORIZATION)
|
||
.and_then(|value| value.to_str().ok());
|
||
|
||
// 如果是公开端点,允许通过
|
||
let path = request.uri().path();
|
||
if path == "/" || path == "/health" || path.starts_with("/docs") {
|
||
return Ok(next.run(request).await);
|
||
}
|
||
|
||
// 验证token
|
||
if let Some(auth_value) = auth_header {
|
||
if auth_value.starts_with("Bearer ") {
|
||
// Token验证逻辑
|
||
return Ok(next.run(request).await);
|
||
}
|
||
}
|
||
|
||
Err(ApiError::Unauthorized("Missing or invalid authorization".to_string()))
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
|
||
#[test]
|
||
fn test_jwt_creation() {
|
||
let auth = JwtAuth::new("test-secret".to_string(), 24);
|
||
let token = auth.create_token("user123").expect("mainnet: handle error");
|
||
assert!(!token.is_empty());
|
||
}
|
||
|
||
#[test]
|
||
fn test_jwt_validation() {
|
||
let auth = JwtAuth::new("test-secret".to_string(), 24);
|
||
let token = auth.create_token("user123").expect("mainnet: handle error");
|
||
let claims = auth.validate_token(&token).expect("mainnet: handle error");
|
||
assert_eq!(claims.sub, "user123");
|
||
}
|
||
|
||
#[test]
|
||
fn test_invalid_token() {
|
||
let auth = JwtAuth::new("test-secret".to_string(), 24);
|
||
let result = auth.validate_token("invalid-token");
|
||
assert!(result.is_err());
|
||
}
|
||
}
|