406 lines
12 KiB
Rust
406 lines
12 KiB
Rust
//! AI model invocation and management.
//!
//! Integrates three AI models: ChatGPT-4.1, DeepSeek-V3, and Doubao AI-Pro.
|
||
|
||
use serde::{Deserialize, Serialize};
|
||
use rust_decimal::Decimal;
|
||
use std::collections::HashMap;
|
||
use reqwest::Client;
|
||
use anyhow::{Result, Context};
|
||
|
||
use crate::{AIProvider, AIValuationResult, Asset, Jurisdiction, InternationalAgreement};
|
||
|
||
/// AI模型配置
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct AIModelConfig {
|
||
/// API端点
|
||
pub endpoint: String,
|
||
/// API密钥
|
||
pub api_key: String,
|
||
/// 模型名称
|
||
pub model_name: String,
|
||
/// 超时时间(秒)
|
||
pub timeout_secs: u64,
|
||
/// 最大重试次数
|
||
pub max_retries: u32,
|
||
}
|
||
|
||
impl AIModelConfig {
|
||
/// 创建ChatGPT配置
|
||
pub fn chatgpt(api_key: String) -> Self {
|
||
Self {
|
||
endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
|
||
api_key,
|
||
model_name: "gpt-4.1".to_string(),
|
||
timeout_secs: 30,
|
||
max_retries: 3,
|
||
}
|
||
}
|
||
|
||
/// 创建DeepSeek配置
|
||
pub fn deepseek(api_key: String) -> Self {
|
||
Self {
|
||
endpoint: "https://api.deepseek.com/v1/chat/completions".to_string(),
|
||
api_key,
|
||
model_name: "deepseek-v3".to_string(),
|
||
timeout_secs: 30,
|
||
max_retries: 3,
|
||
}
|
||
}
|
||
|
||
/// 创建豆包AI配置
|
||
pub fn doubao(api_key: String) -> Self {
|
||
Self {
|
||
endpoint: "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string(),
|
||
api_key,
|
||
model_name: "doubao-pro-32k".to_string(),
|
||
timeout_secs: 30,
|
||
max_retries: 3,
|
||
}
|
||
}
|
||
}
|
||
|
||
/// AI模型客户端
|
||
pub struct AIModelClient {
|
||
provider: AIProvider,
|
||
config: AIModelConfig,
|
||
client: Client,
|
||
}
|
||
|
||
impl AIModelClient {
|
||
/// 创建新的AI模型客户端
|
||
pub fn new(provider: AIProvider, config: AIModelConfig) -> Result<Self> {
|
||
let client = Client::builder()
|
||
.timeout(std::time::Duration::from_secs(config.timeout_secs))
|
||
.build()
|
||
.context("Failed to create HTTP client")?;
|
||
|
||
Ok(Self {
|
||
provider,
|
||
config,
|
||
client,
|
||
})
|
||
}
|
||
|
||
/// 调用AI模型进行资产估值
|
||
pub async fn appraise(
|
||
&self,
|
||
asset: &Asset,
|
||
jurisdiction: Jurisdiction,
|
||
agreement: InternationalAgreement,
|
||
xtzh_price_usd: Decimal,
|
||
) -> Result<AIValuationResult> {
|
||
let prompt = self.build_prompt(asset, jurisdiction, agreement, xtzh_price_usd);
|
||
|
||
let mut retries = 0;
|
||
loop {
|
||
match self.call_api(&prompt).await {
|
||
Ok(response) => {
|
||
return self.parse_response(response);
|
||
}
|
||
Err(e) if retries < self.config.max_retries => {
|
||
retries += 1;
|
||
log::warn!(
|
||
"AI模型 {:?} 调用失败 (尝试 {}/{}): {}",
|
||
self.provider,
|
||
retries,
|
||
self.config.max_retries,
|
||
e
|
||
);
|
||
tokio::time::sleep(std::time::Duration::from_secs(2_u64.pow(retries))).await;
|
||
}
|
||
Err(e) => {
|
||
return Err(e).context(format!(
|
||
"AI模型 {:?} 调用失败,已重试 {} 次",
|
||
self.provider, self.config.max_retries
|
||
));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 构建估值提示词
|
||
fn build_prompt(
|
||
&self,
|
||
asset: &Asset,
|
||
jurisdiction: Jurisdiction,
|
||
agreement: InternationalAgreement,
|
||
xtzh_price_usd: Decimal,
|
||
) -> String {
|
||
let jurisdiction_info = jurisdiction.info();
|
||
let agreement_info = agreement.info();
|
||
|
||
format!(
|
||
r#"你是一位资深的资产估值专家,请对以下资产进行专业估值分析。
|
||
|
||
# 资产信息
|
||
- 资产ID: {}
|
||
- 资产类型: {:?}
|
||
- GNACS编码: {}
|
||
- 资产名称: {}
|
||
- 资产描述: {}
|
||
- 基础估值: {} {}
|
||
|
||
# 司法辖区信息
|
||
- 辖区: {:?}
|
||
- 法律体系: {:?}
|
||
- 会计准则: {:?}
|
||
- 税收政策: {}
|
||
- 监管环境: {}
|
||
|
||
# 国际协定
|
||
- 协定: {}
|
||
- 关税调整: {}
|
||
- 市场准入: {}
|
||
- 投资保护: {}
|
||
|
||
# XTZH价格
|
||
- 当前XTZH价格: {} USD
|
||
|
||
# 估值要求
|
||
请综合考虑以下因素进行估值:
|
||
1. 资产的内在价值和市场价值
|
||
2. 司法辖区的法律、税收、监管影响
|
||
3. 国际贸易协定的影响
|
||
4. 市场流动性和风险因素
|
||
5. ESG因素(如适用)
|
||
|
||
请以JSON格式返回估值结果:
|
||
{{
|
||
"valuation_xtzh": "估值金额(XTZH)",
|
||
"confidence": 0.85,
|
||
"reasoning": "详细的估值推理过程,包括关键假设、调整因素、风险分析等"
|
||
}}
|
||
|
||
注意:
|
||
- valuation_xtzh必须是数字字符串
|
||
- confidence必须在0-1之间
|
||
- reasoning必须详细说明估值逻辑"#,
|
||
asset.id,
|
||
asset.asset_type,
|
||
asset.gnacs_code,
|
||
asset.name,
|
||
asset.description,
|
||
asset.base_valuation_local,
|
||
asset.local_currency,
|
||
jurisdiction,
|
||
jurisdiction_info.legal_system,
|
||
jurisdiction_info.accounting_standard,
|
||
format!("企业税率{:.1}%, 资本利得税{:.1}%, 增值税{:.1}%",
|
||
jurisdiction_info.corporate_tax_rate * 100.0,
|
||
jurisdiction_info.capital_gains_tax_rate * 100.0,
|
||
jurisdiction_info.vat_rate * 100.0),
|
||
format!("监管成本率{:.1}%, 流动性折扣{:.1}%",
|
||
jurisdiction_info.regulatory_cost_rate * 100.0,
|
||
jurisdiction_info.base_liquidity_discount * 100.0),
|
||
agreement_info.name,
|
||
agreement_info.tariff_adjustment,
|
||
agreement_info.market_access_discount,
|
||
agreement_info.investment_protection,
|
||
xtzh_price_usd,
|
||
)
|
||
}
|
||
|
||
/// 调用API
|
||
async fn call_api(&self, prompt: &str) -> Result<String> {
|
||
let request_body = serde_json::json!({
|
||
"model": self.config.model_name,
|
||
"messages": [
|
||
{
|
||
"role": "system",
|
||
"content": "你是一位专业的资产估值专家,精通全球资产估值标准和方法。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt
|
||
}
|
||
],
|
||
"temperature": 0.3,
|
||
"max_tokens": 2000,
|
||
});
|
||
|
||
let response = self.client
|
||
.post(&self.config.endpoint)
|
||
.header("Authorization", format!("Bearer {}", self.config.api_key))
|
||
.header("Content-Type", "application/json")
|
||
.json(&request_body)
|
||
.send()
|
||
.await
|
||
.context("Failed to send API request")?;
|
||
|
||
if !response.status().is_success() {
|
||
let status = response.status();
|
||
let error_text = response.text().await.unwrap_or_default();
|
||
anyhow::bail!("API request failed with status {}: {}", status, error_text);
|
||
}
|
||
|
||
let response_json: serde_json::Value = response.json().await
|
||
.context("Failed to parse API response")?;
|
||
|
||
let content = response_json["choices"][0]["message"]["content"]
|
||
.as_str()
|
||
.context("Missing content in API response")?
|
||
.to_string();
|
||
|
||
Ok(content)
|
||
}
|
||
|
||
/// 解析AI响应
|
||
fn parse_response(&self, response: String) -> Result<AIValuationResult> {
|
||
// 尝试从响应中提取JSON
|
||
let json_str = if let Some(start) = response.find('{') {
|
||
if let Some(end) = response.rfind('}') {
|
||
&response[start..=end]
|
||
} else {
|
||
&response
|
||
}
|
||
} else {
|
||
&response
|
||
};
|
||
|
||
#[derive(Deserialize)]
|
||
struct ResponseData {
|
||
valuation_xtzh: String,
|
||
confidence: f64,
|
||
reasoning: String,
|
||
}
|
||
|
||
let data: ResponseData = serde_json::from_str(json_str)
|
||
.context("Failed to parse AI response JSON")?;
|
||
|
||
let valuation_xtzh = data.valuation_xtzh.parse::<Decimal>()
|
||
.context("Failed to parse valuation_xtzh as Decimal")?;
|
||
|
||
if !(0.0..=1.0).contains(&data.confidence) {
|
||
anyhow::bail!("Confidence must be between 0 and 1, got {}", data.confidence);
|
||
}
|
||
|
||
Ok(AIValuationResult {
|
||
provider: self.provider,
|
||
valuation_xtzh,
|
||
confidence: data.confidence,
|
||
reasoning: data.reasoning,
|
||
timestamp: chrono::Utc::now(),
|
||
})
|
||
}
|
||
}
|
||
|
||
/// AI模型管理器
|
||
pub struct AIModelManager {
|
||
clients: HashMap<AIProvider, AIModelClient>,
|
||
}
|
||
|
||
impl AIModelManager {
|
||
/// 创建新的AI模型管理器
|
||
pub fn new(
|
||
chatgpt_api_key: String,
|
||
deepseek_api_key: String,
|
||
doubao_api_key: String,
|
||
) -> Result<Self> {
|
||
let mut clients = HashMap::new();
|
||
|
||
clients.insert(
|
||
AIProvider::ChatGPT,
|
||
AIModelClient::new(
|
||
AIProvider::ChatGPT,
|
||
AIModelConfig::chatgpt(chatgpt_api_key),
|
||
)?,
|
||
);
|
||
|
||
clients.insert(
|
||
AIProvider::DeepSeek,
|
||
AIModelClient::new(
|
||
AIProvider::DeepSeek,
|
||
AIModelConfig::deepseek(deepseek_api_key),
|
||
)?,
|
||
);
|
||
|
||
clients.insert(
|
||
AIProvider::DouBao,
|
||
AIModelClient::new(
|
||
AIProvider::DouBao,
|
||
AIModelConfig::doubao(doubao_api_key),
|
||
)?,
|
||
);
|
||
|
||
Ok(Self { clients })
|
||
}
|
||
|
||
/// 并行调用所有AI模型
|
||
pub async fn appraise_all(
|
||
&self,
|
||
asset: &Asset,
|
||
jurisdiction: Jurisdiction,
|
||
agreement: InternationalAgreement,
|
||
xtzh_price_usd: Decimal,
|
||
) -> Result<Vec<AIValuationResult>> {
|
||
let mut tasks = Vec::new();
|
||
|
||
for (provider, client) in &self.clients {
|
||
let asset = asset.clone();
|
||
let client_provider = *provider;
|
||
let task = async move {
|
||
client.appraise(&asset, jurisdiction, agreement, xtzh_price_usd).await
|
||
};
|
||
tasks.push((client_provider, task));
|
||
}
|
||
|
||
let mut results = Vec::new();
|
||
for (provider, task) in tasks {
|
||
match task.await {
|
||
Ok(result) => results.push(result),
|
||
Err(e) => {
|
||
log::error!("AI模型 {:?} 估值失败: {}", provider, e);
|
||
// 继续执行,不中断其他模型
|
||
}
|
||
}
|
||
}
|
||
|
||
if results.is_empty() {
|
||
anyhow::bail!("所有AI模型调用均失败");
|
||
}
|
||
|
||
Ok(results)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// The ChatGPT preset carries the expected model name and timeout.
    #[test]
    fn test_ai_model_config() {
        let AIModelConfig {
            model_name,
            timeout_secs,
            ..
        } = AIModelConfig::chatgpt("test_key".to_string());

        assert_eq!(model_name, "gpt-4.1");
        assert_eq!(timeout_secs, 30);
    }

    /// The rendered prompt contains the asset's id, GNACS code and name.
    #[test]
    fn test_build_prompt() {
        let client = AIModelClient::new(
            AIProvider::ChatGPT,
            AIModelConfig::chatgpt("test_key".to_string()),
        )
        .unwrap();

        let asset = Asset::new(
            "test_asset".to_string(),
            crate::AssetType::RealEstate,
            "GNACS-001".to_string(),
            "Test Property".to_string(),
            Decimal::new(1000000, 0),
            "USD".to_string(),
        );

        let prompt = client.build_prompt(
            &asset,
            Jurisdiction::US,
            InternationalAgreement::WTO,
            Decimal::new(100, 0),
        );

        for expected in ["test_asset", "GNACS-001", "Test Property"] {
            assert!(prompt.contains(expected));
        }
    }
}
|