278 lines
9.6 KiB
Plaintext
278 lines
9.6 KiB
Plaintext
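/**
 * NAC Knowledge Engine - RAG (retrieval-augmented generation) module
 *
 * Purpose: retrieve the compliance rule provisions most relevant to the user's
 * question from the MongoDB knowledge base and inject them as context into the
 * AI Agent's prompt, improving the accuracy and traceability of its answers.
 *
 * Retrieval strategy (three progressive tiers):
 *   1. MongoDB full-text search ($text index) - exact keyword matching
 *   2. Regex keyword matching - covers cases the full-text index misses
 *   3. Random sampling - fallback that ensures some context is always available
 *
 * No vector-database dependency and no Manus dependency; implemented purely
 * with native MongoDB features.
 */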
|
||
|
||
import { getMongoDb, COLLECTIONS } from "./mongodb";
|
||
|
||
// ─── 类型定义 ─────────────────────────────────────────────────────
|
||
|
||
/** One rule retrieved from the knowledge base. */
export interface RetrievedRule {
  /** Rule identifier. */
  ruleId: string;

  /** Rule name. */
  ruleName: string;

  /** Jurisdiction, e.g. CN / HK / SG / US / EU. */
  jurisdiction: string;

  /** Category, e.g. RWA / AML / KYC / securities / funds. */
  category: string;

  /** Rule text (truncated to 500 characters). */
  content: string;

  /** Short description. */
  description?: string;

  /** Relevance score, 0-1. */
  score: number;

  /** Source label (used by the front end when displaying citations). */
  source: string;
}
|
||
|
||
/** Result of one retrieval pass: the rules plus metadata about how they were found. */
export interface RAGContext {
  /** Retrieved rules. */
  rules: RetrievedRule[];

  /** Number of rules found. */
  totalFound: number;

  /** Which retrieval strategy produced `rules`; "none" when nothing was retrieved. */
  retrievalMethod:
    | "fulltext"
    | "regex"
    | "sample"
    | "none";

  /** Keywords extracted from the query. */
  queryKeywords: string[];
}
|
||
|
||
// ─── 关键词提取 ───────────────────────────────────────────────────
|
||
|
||
/**
 * Stop words (general Chinese/English function words plus NAC/RWA domain
 * filler such as "规定" / "要求"). Built once at module load instead of on
 * every call. English entries are lower-case; lookups lower-case the term.
 */
const KEYWORD_STOP_WORDS: ReadonlySet<string> = new Set([
  "的", "了", "是", "在", "我", "有", "和", "就", "不", "人", "都", "一", "一个",
  "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好",
  "自己", "这", "那", "什么", "如何", "怎么", "请问", "帮我", "告诉", "介绍",
  "关于", "对于", "针对", "需要", "可以", "应该", "必须", "规定", "要求",
  "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
  "have", "has", "had", "do", "does", "did", "will", "would", "could", "should",
  "what", "how", "when", "where", "why", "which", "who",
]);

/**
 * Matches any multi-character Chinese stop phrase. Chinese text has no word
 * separators, so a stop phrase embedded in a longer run (e.g. "请问香港关于")
 * can only be removed by splitting the run on it. Single-character stop words
 * are deliberately NOT used as split points: characters such as 人 / 上 / 不
 * also occur inside ordinary compound terms and splitting on them would
 * destroy those terms.
 */
const CJK_STOP_PHRASE_SPLITTER = new RegExp(
  Array.from(KEYWORD_STOP_WORDS)
    .filter(word => word.length >= 2 && /^[\u4e00-\u9fa5]+$/.test(word))
    .join("|"),
  "g"
);

/**
 * Extracts search keywords from a user question.
 *
 * Strategy: drop stop words, keep entity words and domain terms.
 *   - Chinese: each run of CJK characters is split on multi-character stop
 *     phrases, then cut into terms of 2-8 characters.
 *   - English: words of 3 or more letters.
 *   - Numbers: every run of digits.
 *
 * Terms are returned in the order Chinese, English, numbers, with the casing
 * they had in the query. Stop-word filtering and de-duplication are both
 * case-insensitive; the first spelling encountered is the one kept.
 *
 * @param query Raw user question.
 * @returns At most 8 distinct keywords; an empty array if none are found.
 */
function extractKeywords(query: string): string[] {
  const MAX_KEYWORDS = 8;

  const chineseTerms: string[] = [];
  for (const run of query.match(/[\u4e00-\u9fa5]+/g) ?? []) {
    for (const fragment of run.split(CJK_STOP_PHRASE_SPLITTER)) {
      // Fragments shorter than 2 characters yield no match and are dropped.
      chineseTerms.push(...(fragment.match(/[\u4e00-\u9fa5]{2,8}/g) ?? []));
    }
  }
  const englishTerms = query.match(/[a-zA-Z]{3,}/g) ?? [];
  const numbers = query.match(/\d+/g) ?? [];

  const seen = new Set<string>();
  const keywords: string[] = [];
  for (const term of [...chineseTerms, ...englishTerms, ...numbers]) {
    const key = term.toLowerCase();
    if (KEYWORD_STOP_WORDS.has(key) || seen.has(key)) {
      continue;
    }
    seen.add(key);
    keywords.push(term);
    if (keywords.length === MAX_KEYWORDS) {
      break;
    }
  }
  return keywords;
}
|
||
|
||
// ─── 主检索函数 ───────────────────────────────────────────────────
|
||
|
||
/**
 * Retrieves the rules most relevant to a user question from the MongoDB
 * knowledge base (core RAG function).
 *
 * Strategies are tried in order; the first one that returns documents wins:
 *   1. `$text` full-text search over the extracted keywords, sorted by text score.
 *   2. Case-insensitive regex match: each of the first four keywords must
 *      appear in at least one of ruleName / description / content /
 *      translations.zh.
 *   3. `$sample` random sampling as a fallback; these rules get a fixed score of 0.3.
 * A strategy that throws is logged with `console.warn` and the next one is tried.
 *
 * @param query   User question.
 * @param options Retrieval options:
 *   - `maxResults`    maximum number of rules to return (default 5). Fractions
 *                     are rounded down; a value below 1 or a non-finite value
 *                     returns an empty result without querying.
 *   - `jurisdictions` restrict results to these jurisdictions.
 *   - `categories`    restrict results to these categories.
 *   - `language`      preferred translation language (default "zh").
 * @returns A RAGContext with the formatted rules, the strategy that produced
 *          them ("none" if nothing was retrieved) and the extracted keywords.
 *          `totalFound` is the number of rules returned. When no database
 *          handle is available the keywords array is empty.
 */
export async function retrieveRelevantRules(
  query: string,
  options: {
    maxResults?: number;
    jurisdictions?: string[]; // restrict to these jurisdictions
    categories?: string[]; // restrict to these categories
    language?: string; // preferred language version
  } = {}
): Promise<RAGContext> {
  const { maxResults = 5, jurisdictions, categories, language = "zh" } = options;
  const db = await getMongoDb();

  if (!db) {
    return { rules: [], totalFound: 0, retrievalMethod: "none", queryKeywords: [] };
  }

  const keywords = extractKeywords(query);

  // MongoDB treats limit(0) as "no limit" and rejects a non-positive $sample
  // size, so a non-positive maxResults must never reach the queries below.
  const limit = Math.floor(maxResults);
  if (!Number.isFinite(limit) || limit < 1) {
    return { rules: [], totalFound: 0, retrievalMethod: "none", queryKeywords: keywords };
  }

  // Caught values are `unknown`; only read `.message` from real Error objects.
  const describeError = (e: unknown): string =>
    e instanceof Error ? e.message : String(e);

  const collection = db.collection(COLLECTIONS.COMPLIANCE_RULES);

  // Base filter shared by all three strategies.
  const baseFilter: Record<string, unknown> = {};
  if (jurisdictions && jurisdictions.length > 0) {
    baseFilter.jurisdiction = { $in: jurisdictions };
  }
  if (categories && categories.length > 0) {
    baseFilter.category = { $in: categories };
  }

  let rules: RetrievedRule[] = [];
  let retrievalMethod: RAGContext["retrievalMethod"] = "none";

  // ── Strategy 1: MongoDB full-text search ──────────────────────
  if (keywords.length > 0) {
    try {
      const textFilter = {
        ...baseFilter,
        $text: { $search: keywords.join(" ") },
      };

      const projection: Record<string, unknown> = {
        score: { $meta: "textScore" },
        ruleId: 1,
        ruleName: 1,
        jurisdiction: 1,
        category: 1,
        content: 1,
        description: 1,
        // Multi-language fields.
        "translations.zh": 1,
        "translations.en": 1,
      };
      // Also fetch the requested language; otherwise a language other than
      // zh/en would be silently unavailable on this path only.
      projection[`translations.${language}`] = 1;

      const textResults = await collection
        .find(textFilter, { projection })
        .sort({ score: { $meta: "textScore" } })
        .limit(limit)
        .toArray();

      if (textResults.length > 0) {
        rules = textResults.map((doc, idx) => formatRule(doc, language, idx, textResults.length));
        retrievalMethod = "fulltext";
      }
    } catch (e) {
      // Falls back to regex search, e.g. when the text index does not exist.
      console.warn("[RAG] 全文检索失败,降级到正则检索:", describeError(e));
    }
  }

  // ── Strategy 2: regex keyword match (when full-text found nothing) ──
  if (rules.length === 0 && keywords.length > 0) {
    try {
      const regexConditions = keywords.slice(0, 4).map(kw => ({
        $or: [
          { ruleName: { $regex: kw, $options: "i" } },
          { description: { $regex: kw, $options: "i" } },
          { content: { $regex: kw, $options: "i" } },
          { "translations.zh": { $regex: kw, $options: "i" } },
        ],
      }));

      const regexFilter = {
        ...baseFilter,
        $and: regexConditions,
      };

      const regexResults = await collection
        .find(regexFilter)
        .limit(limit)
        .toArray();

      if (regexResults.length > 0) {
        rules = regexResults.map((doc, idx) => formatRule(doc, language, idx, regexResults.length));
        retrievalMethod = "regex";
      }
    } catch (e) {
      console.warn("[RAG] 正则检索失败:", describeError(e));
    }
  }

  // ── Strategy 3: random sampling (fallback) ────────────────────
  if (rules.length === 0) {
    try {
      const sampleResults = await collection
        .aggregate([
          { $match: baseFilter },
          { $sample: { size: limit } },
        ])
        .toArray();

      if (sampleResults.length > 0) {
        rules = sampleResults.map((doc, idx) => formatRule(doc, language, idx, sampleResults.length, 0.3));
        retrievalMethod = "sample";
      }
    } catch (e) {
      console.warn("[RAG] 随机采样失败:", describeError(e));
    }
  }

  return {
    rules,
    totalFound: rules.length,
    retrievalMethod,
    queryKeywords: keywords,
  };
}
|
||
|
||
// ─── 格式化工具函数 ───────────────────────────────────────────────
|
||
|
||
/**
 * Converts a raw knowledge-base document into a RetrievedRule.
 *
 * Content is chosen in this order: the translation for `language`, then
 * `doc.content`, then the `zh` translation, then the `en` translation. Only
 * string values are accepted; anything else is skipped. Content longer than
 * 500 characters is cut and suffixed with "...".
 *
 * @param doc       Raw document; missing fields fall back to placeholder values.
 * @param language  Preferred translation key, e.g. "zh".
 * @param idx       Zero-based rank of this document within its result set.
 * @param total     Size of the result set; used for rank-based scoring.
 * @param baseScore If given, used as the score as-is instead of the rank-based value.
 * @returns The formatted rule. Without `baseScore` the score decreases
 *          linearly with rank from 1.0 and never drops below 0.4.
 */
function formatRule(
  doc: Record<string, unknown>,
  language: string,
  idx: number,
  total: number,
  baseScore?: number
): RetrievedRule {
  const MAX_CONTENT_LENGTH = 500;
  const MAX_SOURCE_NAME_LENGTH = 20;

  // Cuts text to at most `max` UTF-16 units without splitting a surrogate
  // pair, which would leave a lone, unrenderable half character at the end.
  const safePrefix = (text: string, max: number): string => {
    if (text.length <= max) {
      return text;
    }
    const lastUnit = text.charCodeAt(max - 1);
    const endsOnHighSurrogate = lastUnit >= 0xd800 && lastUnit <= 0xdbff;
    return text.slice(0, endsOnHighSurrogate ? max - 1 : max);
  };

  // Relevance score: explicit base score, otherwise decreasing with rank.
  // A non-positive total would divide by zero and yield NaN, so use the floor.
  let score: number;
  if (baseScore !== undefined) {
    score = baseScore;
  } else if (total > 0) {
    score = Math.max(0.4, 1.0 - (idx / total) * 0.5);
  } else {
    score = 0.4;
  }

  // Prefer the translation for the requested language. Values are checked to
  // be strings so that malformed data or an inherited key such as
  // "constructor" cannot be used as content.
  const translations =
    typeof doc.translations === "object" && doc.translations !== null
      ? (doc.translations as Record<string, unknown>)
      : undefined;
  const preferred = translations?.[language];

  let content = "";
  if (typeof preferred === "string" && preferred) {
    content = preferred;
  } else if (typeof doc.content === "string") {
    content = doc.content;
  } else if (typeof translations?.zh === "string" && translations.zh) {
    content = translations.zh;
  } else if (typeof translations?.en === "string" && translations.en) {
    content = translations.en;
  }

  // Truncate to keep the prompt within the LLM context budget.
  const truncatedContent = content.length > MAX_CONTENT_LENGTH
    ? safePrefix(content, MAX_CONTENT_LENGTH) + "..."
    : content;

  const ruleId = String(doc.ruleId || doc._id || "");
  const ruleName = String(doc.ruleName || "未命名规则");
  const jurisdiction = String(doc.jurisdiction || "未知");
  const category = String(doc.category || "通用");
  const description = doc.description ? String(doc.description) : undefined;

  return {
    ruleId,
    ruleName,
    jurisdiction,
    category,
    content: truncatedContent,
    description,
    score,
    source: `${jurisdiction}·${category}·${safePrefix(ruleName, MAX_SOURCE_NAME_LENGTH)}`,
  };
}
|
||
|
||
// ─── 构建RAG提示词上下文 ─────────────────────────────────────────
|
||
|
||
/**
 * Formats the retrieved rules as a context section for the AI prompt.
 *
 * The output is a header with the rule count and retrieval method, one
 * numbered section per rule (name, jurisdiction / category / relevance
 * percentage, optional summary, content), and a closing instruction that asks
 * the model to answer from these rules and cite them.
 *
 * @param ragCtx Retrieval result to render.
 * @returns The prompt section, or an empty string when there are no rules.
 */
export function buildRAGPromptContext(ragCtx: RAGContext): string {
  const { rules, totalFound, retrievalMethod } = ragCtx;

  if (rules.length === 0) {
    return "";
  }

  const header = [
    "【知识库检索结果】",
    `(共检索到 ${totalFound} 条相关规则,检索方式:${retrievalMethod})`,
    "",
  ];

  // Each section ends with an empty element so that sections are separated by
  // a blank line once everything is joined with "\n".
  const sections = rules.map((rule, position) => {
    const percent = Math.round(rule.score * 100);
    const section = [
      `【规则 ${position + 1}】${rule.ruleName}`,
      ` 管辖区:${rule.jurisdiction} | 分类:${rule.category} | 相关度:${percent}%`,
    ];
    if (rule.description) {
      section.push(` 摘要:${rule.description}`);
    }
    section.push(` 内容:${rule.content}`, "");
    return section.join("\n");
  });

  const footer = "请基于以上知识库内容回答用户问题,并在回答中注明引用的规则来源。";

  return [...header, ...sections, footer].join("\n");
}
|