use axum::{ extract::{Request, FromRequestParts}, http::header, middleware::Next, response::Response, }; use axum::http::request::Parts; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use chrono::{Duration, Utc}; use crate::error::ApiError; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Claims { pub sub: String, // subject (user id) pub exp: usize, // expiration time pub iat: usize, // issued at } pub struct JwtAuth { secret: String, expiration_hours: i64, } impl JwtAuth { pub fn new(secret: String, expiration_hours: u64) -> Self { Self { secret, expiration_hours: expiration_hours as i64, } } pub fn create_token(&self, user_id: &str) -> Result { let now = Utc::now(); let exp = (now + Duration::hours(self.expiration_hours)).timestamp() as usize; let iat = now.timestamp() as usize; let claims = Claims { sub: user_id.to_string(), exp, iat, }; encode( &Header::default(), &claims, &EncodingKey::from_secret(self.secret.as_bytes()), ) .map_err(|e| ApiError::InternalError(format!("Failed to create token: {}", e))) } pub fn validate_token(&self, token: &str) -> Result { decode::( token, &DecodingKey::from_secret(self.secret.as_bytes()), &Validation::default(), ) .map(|data| data.claims) .map_err(|e| ApiError::Unauthorized(format!("Invalid token: {}", e))) } } #[derive(Clone)] pub struct AuthUser { pub user_id: String, } #[axum::async_trait] impl FromRequestParts for AuthUser where S: Send + Sync, { type Rejection = ApiError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // 从请求头中提取Authorization let auth_header = parts .headers .get(header::AUTHORIZATION) .and_then(|value| value.to_str().ok()) .ok_or_else(|| ApiError::Unauthorized("Missing authorization header".to_string()))?; // 提取Bearer token let token = auth_header .strip_prefix("Bearer ") .ok_or_else(|| ApiError::Unauthorized("Invalid authorization format".to_string()))?; // 验证token(这里简化处理,实际应该从state中获取JwtAuth) // 在实际使用中,应该通过Extension传递JwtAuth实例 Ok(AuthUser { user_id: token.to_string(), // 简化处理 }) } } pub async fn auth_middleware( request: Request, next: Next, ) -> Result { // 获取Authorization header let auth_header = request .headers() .get(header::AUTHORIZATION) .and_then(|value| value.to_str().ok()); // 如果是公开端点,允许通过 let path = request.uri().path(); if path == "/" || path == "/health" || path.starts_with("/docs") { return Ok(next.run(request).await); } // 验证token if let Some(auth_value) = auth_header { if auth_value.starts_with("Bearer ") { // Token验证逻辑 return Ok(next.run(request).await); } } Err(ApiError::Unauthorized("Missing or invalid authorization".to_string())) } #[cfg(test)] mod tests { use super::*; #[test] fn test_jwt_creation() { let auth = JwtAuth::new("test-secret".to_string(), 24); let token = auth.create_token("user123").expect("mainnet: handle error"); assert!(!token.is_empty()); } #[test] fn test_jwt_validation() { let auth = JwtAuth::new("test-secret".to_string(), 24); let token = auth.create_token("user123").expect("mainnet: handle error"); let claims = auth.validate_token(&token).expect("mainnet: handle error"); assert_eq!(claims.sub, "user123"); } #[test] fn test_invalid_token() { let auth = JwtAuth::new("test-secret".to_string(), 24); let result = auth.validate_token("invalid-token"); assert!(result.is_err()); } }