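/**
 * NAC Public Chain (NAC 公链) — Vector Embedding Retrieval Module
 *
 * Purpose:
 *   1. Convert rule text into a vector representation (the original design
 *      calls for the built-in LLM API).
 *   2. Compute the cosine similarity between the query vector and each rule vector.
 *   3. Return the semantically most relevant rules (replacing regex keyword matching).
 *   4. Support MongoDB Atlas Vector Search (production) and in-memory vector
 *      retrieval (degraded mode).
 *
 * Intended fallback chain:
 *   - Prefer MongoDB Atlas Vector Search when it is available.
 *   - Fall back to in-memory vector retrieval (TF-IDF + cosine similarity).
 *   - Finally fall back to the existing regex keyword retrieval.
 *
 * NOTE: the code in this file implements only the in-memory TF-IDF tier.
 * No LLM embedding call and no Atlas Vector Search query appears here; MongoDB
 * is used only to load the rule documents that are indexed in memory.
 */
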
import { MongoClient, Collection, Document } from "mongodb";

// ─── Type definitions ─────────────────────────────────────────────

/**
 * A stored vector for a single rule.
 *
 * Exported for consumers of this module; nothing else in this file
 * constructs or reads it.
 */
export interface EmbeddingVector {
  /** Identifier of the rule the vector belongs to. */
  ruleId: string;
  /** Numeric vector representation of `text`. */
  vector: number[];
  /** The text that was vectorized. */
  text: string;
  /** Creation timestamp of this record. */
  createdAt: Date;
}

/**
 * One rule returned by `semanticSearch` / `hybridSearch`.
 *
 * Required string fields are always populated; when the underlying rule
 * document lacks a value, `semanticSearch` substitutes a placeholder.
 */
export interface SemanticSearchResult {
  /** Rule identifier (`ruleId`, else `_id`, else an empty string). */
  ruleId: string;
  /** Display name of the rule. */
  ruleName: string;
  /** Jurisdiction the rule applies to. */
  jurisdiction: string;
  /** Asset class (or category) the rule applies to. */
  assetClass: string;
  /** Rule type label. */
  ruleType: string;
  /** Rule body text; `semanticSearch` truncates it to 800 characters. */
  content: string;
  /** Legal basis, when the rule document carries one. */
  legalBasis?: string;
  /** Ownership requirements, passed through from the rule document unchanged. */
  ownershipRequirements?: Record<string, unknown>;
  /** Trading requirements, passed through from the rule document unchanged. */
  tradingRequirements?: Record<string, unknown>;
  /** Display score: the similarity remapped as `0.4 + raw * 0.6`. */
  score: number;
  /** Raw cosine similarity between the query and the rule text. */
  similarityScore: number;
  /** Source URL, when the rule document carries one. */
  sourceUrl?: string;
  /** Tags, stringified, when the rule document carries an array. */
  tags?: string[];
  /** Compliance level, when the rule document carries one. */
  complianceLevel?: string;
}

// ─── TF-IDF vectorization (in-memory mode, no external API needed) ─

/**
 * Minimal in-memory TF-IDF vectorizer for mixed Chinese/English text.
 *
 * Usage: call `fit()` once with the corpus, then `transform()` for every
 * document and query. Vectors returned by `transform()` are L2-normalized and
 * have one dimension per vocabulary token.
 */
class TFIDFVectorizer {
  /** token -> dimension index in the output vector. */
  private vocabulary: Map<string, number> = new Map();
  /** token -> smoothed inverse document frequency. */
  private idf: Map<string, number> = new Map();
  /** Tokenized corpus from the most recent `fit()` call. */
  private documents: string[][] = [];

  /**
   * Splits text into tokens.
   *
   * - The text is lower-cased and everything except CJK ideographs
   *   (U+4E00–U+9FA5), ASCII letters, digits and whitespace becomes a space.
   * - English tokens: runs matching `[a-z][a-z0-9]*` longer than 2 characters.
   * - Chinese tokens: every 2-character and 3-character sliding window over
   *   the text with ASCII letters, digits and whitespace stripped out. Because
   *   the stripping happens first, ideographs separated only by ASCII or
   *   spaces are treated as adjacent.
   *
   * @param text Raw input text.
   * @returns English tokens first, then Chinese n-grams in positional order.
   */
  tokenize(text: string): string[] {
    const normalized = text
      .toLowerCase()
      .replace(/[^\u4e00-\u9fa5a-z0-9\s]/g, " ")
      .replace(/\s+/g, " ")
      .trim();

    const tokens: string[] = [];

    // English words; a plain loop avoids spreading a potentially huge array
    // into push() arguments.
    const englishWords = normalized.match(/[a-z][a-z0-9]*/g) || [];
    for (const word of englishWords) {
      if (word.length > 2) tokens.push(word);
    }

    // Chinese bigrams and trigrams.
    const chineseText = normalized.replace(/[a-z0-9\s]/g, "");
    for (let i = 0; i < chineseText.length - 1; i++) {
      tokens.push(chineseText.slice(i, i + 2));
      if (i < chineseText.length - 2) {
        tokens.push(chineseText.slice(i, i + 3));
      }
    }

    return tokens;
  }

  /**
   * Fits the vocabulary and IDF weights to a corpus.
   *
   * Any state from a previous `fit()` is discarded first. Without that reset,
   * tokens from the old corpus would keep their old indices and collide with
   * the indices assigned to new tokens.
   *
   * IDF uses the smoothed form `ln((N + 1) / (df + 1)) + 1`, where `N` is the
   * number of documents and `df` the number of documents containing the token.
   *
   * @param documents Corpus, one string per document.
   */
  fit(documents: string[]): void {
    this.documents = documents.map(doc => this.tokenize(doc));
    this.vocabulary = new Map();
    this.idf = new Map();

    // One pass over the corpus: count, for each token, how many documents
    // contain it. Map insertion order is first-seen order, which also fixes
    // the dimension index of every token.
    const documentFrequency = new Map<string, number>();
    for (const tokens of this.documents) {
      new Set(tokens).forEach(token => {
        documentFrequency.set(token, (documentFrequency.get(token) ?? 0) + 1);
      });
    }

    const N = this.documents.length;
    let nextIndex = 0;
    documentFrequency.forEach((df, token) => {
      this.vocabulary.set(token, nextIndex++);
      this.idf.set(token, Math.log((N + 1) / (df + 1)) + 1);
    });
  }

  /**
   * Converts text into an L2-normalized TF-IDF vector.
   *
   * Tokens that are not in the fitted vocabulary are ignored. If no token is
   * known (or the text yields no tokens) the result is an all-zero vector.
   *
   * @param text Text to vectorize.
   * @returns Vector whose length equals the vocabulary size.
   */
  transform(text: string): number[] {
    const tokens = this.tokenize(text);
    const vector = new Array<number>(this.vocabulary.size).fill(0);

    // Raw term counts.
    const termCounts = new Map<string, number>();
    for (const token of tokens) {
      termCounts.set(token, (termCounts.get(token) ?? 0) + 1);
    }

    // TF (count / total tokens) times IDF, for known tokens only.
    termCounts.forEach((count, token) => {
      const idx = this.vocabulary.get(token);
      if (idx === undefined) return;
      const tfScore = count / tokens.length;
      const idfScore = this.idf.get(token) || 1;
      vector[idx] = tfScore * idfScore;
    });

    // L2 normalization; a zero vector is returned unchanged.
    const norm = Math.sqrt(vector.reduce((sum, v) => sum + v * v, 0));
    return norm > 0 ? vector.map(v => v / norm) : vector;
  }

  /** @returns Number of distinct tokens learned by the last `fit()`. */
  getVocabularySize(): number {
    return this.vocabulary.size;
  }
}

// ─── Cosine similarity ────────────────────────────────────────────

/**
 * Cosine similarity of two equal-length vectors.
 *
 * @param a First vector.
 * @param b Second vector.
 * @returns `dot(a, b) / (|a| * |b|)`; 0 when the lengths differ or when
 *          either vector has zero magnitude (including two empty vectors).
 */
function cosineSimilarity(a: number[], b: number[]): number {
  if (a.length !== b.length) return 0;

  let dotProduct = 0;
  let squaredNormA = 0;
  let squaredNormB = 0;

  for (let i = 0; i < a.length; i++) {
    dotProduct += a[i] * b[i];
    squaredNormA += a[i] * a[i];
    squaredNormB += b[i] * b[i];
  }

  const denominator = Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB);
  return denominator === 0 ? 0 : dotProduct / denominator;
}

// ─── In-memory vector search engine ───────────────────────────────

/** One indexed rule: the source document plus its vector and the text it was built from. */
interface RuleVector {
  /** The rule document as loaded from storage. */
  doc: Record<string, unknown>;
  /** TF-IDF vector of `text`. */
  vector: number[];
  /** Flattened text representation of the rule that was vectorized. */
  text: string;
}

/**
 * Brute-force in-memory search over TF-IDF vectors of rule documents.
 *
 * `buildIndex()` vectorizes every rule; `search()` vectorizes the query with
 * the same vocabulary and ranks all rules by cosine similarity.
 */
class InMemoryVectorSearch {
  private vectorizer: TFIDFVectorizer = new TFIDFVectorizer();
  private ruleVectors: RuleVector[] = [];
  private isBuilt = false;

  /**
   * Builds (or rebuilds) the index from scratch.
   *
   * An empty `rules` array marks the engine as not ready and returns.
   * Otherwise a brand-new vectorizer is fitted and the new state is swapped
   * in only after every vector has been computed. That way a rebuild never
   * inherits vocabulary from the previous corpus, and an exception part-way
   * through leaves the previous index intact instead of mixing vectors of
   * different dimensions.
   *
   * @param rules Rule documents to index.
   */
  buildIndex(rules: Record<string, unknown>[]): void {
    if (rules.length === 0) {
      this.isBuilt = false;
      return;
    }

    // Text representation of every rule.
    const texts = rules.map(rule => this.buildRuleText(rule));

    // Fit a fresh vectorizer on the new corpus.
    const vectorizer = new TFIDFVectorizer();
    vectorizer.fit(texts);

    // Vectorize every rule with that vocabulary.
    const ruleVectors: RuleVector[] = rules.map((doc, idx) => ({
      doc,
      vector: vectorizer.transform(texts[idx]),
      text: texts[idx],
    }));

    // Commit the new index.
    this.vectorizer = vectorizer;
    this.ruleVectors = ruleVectors;
    this.isBuilt = true;
    console.log(`[EmbeddingRetrieval] 向量索引构建完成: ${rules.length} 条规则, 词汇表大小: ${this.vectorizer.getVocabularySize()}`);
  }

  /**
   * Flattens a rule document into the single string that gets vectorized.
   *
   * Concatenated, space-separated, in this order (missing fields are skipped):
   * rule name twice (to raise its term frequency), jurisdiction, asset class
   * or category, rule type, the first 500 characters of content or
   * description, legal basis, tags, and the first 200 characters of the
   * joined `ownershipRequirements.proofDocuments`.
   *
   * @param rule Rule document.
   * @returns Text representation of the rule.
   */
  private buildRuleText(rule: Record<string, unknown>): string {
    const parts: string[] = [];

    // Rule name, repeated so it weighs more than other fields.
    const ruleName = String(rule.ruleName || rule.ruleNameEn || "");
    if (ruleName) parts.push(ruleName, ruleName);

    // Jurisdiction, asset class and rule type.
    const jurisdiction = String(rule.jurisdiction || "");
    const assetClass = String(rule.assetClass || rule.category || "");
    const ruleType = String(rule.ruleType || "");
    if (jurisdiction) parts.push(jurisdiction);
    if (assetClass) parts.push(assetClass);
    if (ruleType) parts.push(ruleType);

    // Main body text, capped at 500 characters.
    const content = String(rule.content || rule.description || "");
    if (content) parts.push(content.slice(0, 500));

    // Legal basis.
    const legalBasis = String(rule.legalBasis || "");
    if (legalBasis) parts.push(legalBasis);

    // Tags.
    const tags = Array.isArray(rule.tags) ? rule.tags.join(" ") : "";
    if (tags) parts.push(tags);

    // Ownership proof documents, capped at 200 characters.
    const ownerReqs = rule.ownershipRequirements as Record<string, unknown> | undefined;
    if (ownerReqs) {
      const docs = Array.isArray(ownerReqs.proofDocuments)
        ? ownerReqs.proofDocuments.join(" ")
        : "";
      if (docs) parts.push(docs.slice(0, 200));
    }

    return parts.join(" ");
  }

  /**
   * Ranks all indexed rules against a query.
   *
   * @param query Free-text query.
   * @param topK Maximum number of results to return.
   * @param minScore Results with a cosine similarity below this are dropped.
   * @returns Up to `topK` `{ doc, score }` pairs, highest score first; an
   *          empty array when the index has not been built.
   */
  search(query: string, topK = 5, minScore = 0.1): Array<{ doc: Record<string, unknown>; score: number }> {
    if (!this.isBuilt || this.ruleVectors.length === 0) {
      return [];
    }

    const queryVector = this.vectorizer.transform(query);

    // Score every rule, then filter, sort and truncate.
    return this.ruleVectors
      .map(rv => ({
        doc: rv.doc,
        score: cosineSimilarity(queryVector, rv.vector),
      }))
      .filter(item => item.score >= minScore)
      .sort((a, b) => b.score - a.score)
      .slice(0, topK);
  }

  /** @returns `true` once `buildIndex()` has completed with at least one rule. */
  isReady(): boolean {
    return this.isBuilt;
  }
}

// ─── Global vector search engine instance ─────────────────────────

/** Module-wide search engine shared by every function in this file. */
const globalVectorEngine = new InMemoryVectorSearch();

/**
 * `Date.now()` value recorded by the last successful index build.
 * 0 means "never built" and is also written to force a rebuild.
 */
let lastIndexBuildTime = 0;

/** Maximum index age in milliseconds before it is rebuilt (5 minutes). */
const INDEX_REBUILD_INTERVAL = 5 * 60 * 1000;

// ─── MongoDB connection ───────────────────────────────────────────

/**
 * MongoDB connection string, taken from the `NAC_MONGO_URL` environment
 * variable.
 *
 * SECURITY: the fallback used to embed a root username and password directly
 * in source code. Credentials must never be committed; that password should
 * be treated as leaked and rotated. The fallback is now a credential-free
 * local URL, so any deployment that needs authentication MUST set
 * `NAC_MONGO_URL`.
 */
const MONGO_URL = process.env.NAC_MONGO_URL || "mongodb://localhost:27017/nac_knowledge_engine";

/** Database that holds the knowledge engine data. */
const DB_NAME = "nac_knowledge_engine";

/** Collection that stores the compliance rule documents. */
const COLLECTION_NAME = "compliance_rules";

/**
 * Connection shared by every `getCollection()` call.
 * Holds the in-flight or resolved connect promise; `null` until first use
 * and again after a failed connect.
 */
let sharedClientPromise: Promise<MongoClient> | null = null;

/**
 * Returns a handle to the compliance-rules collection.
 *
 * The returned collection is only usable while its client stays open, so the
 * caller cannot be handed a client that is closed afterwards. Previously
 * every call created and connected a new `MongoClient` that nothing could
 * ever close. Now a single client is connected lazily and reused; if the
 * connect attempt fails the client is closed and the cache cleared so the
 * next call retries.
 *
 * @returns The `compliance_rules` collection of the knowledge-engine database.
 * @throws Whatever `MongoClient.connect()` rejects with.
 */
async function getCollection(): Promise<Collection<Document>> {
  if (sharedClientPromise === null) {
    const client = new MongoClient(MONGO_URL);
    const pending: Promise<MongoClient> = client.connect().then(
      () => client,
      async (err: unknown) => {
        // Forget the failed attempt, unless a newer one already replaced it.
        if (sharedClientPromise === pending) sharedClientPromise = null;
        try {
          await client.close();
        } catch {
          // Best effort: the connect error below is the one worth reporting.
        }
        throw err;
      }
    );
    sharedClientPromise = pending;
  }

  const client = await sharedClientPromise;
  return client.db(DB_NAME).collection(COLLECTION_NAME);
}

// ─── Vector index construction ────────────────────────────────────

/**
 * Loads every rule from MongoDB and (re)builds the in-memory vector index.
 *
 * Does nothing while the engine is ready and the last successful build is
 * younger than `INDEX_REBUILD_INTERVAL`. When the collection is empty the
 * existing index (if any) is left untouched and the build time is not
 * updated, so the next call tries again.
 *
 * This function never rejects: load/build failures and failures while
 * closing the connection are logged and swallowed so that callers fall back
 * to whatever index is currently loaded.
 */
export async function buildVectorIndex(): Promise<void> {
  const now = Date.now();
  const indexIsFresh =
    globalVectorEngine.isReady() && now - lastIndexBuildTime < INDEX_REBUILD_INTERVAL;
  if (indexIsFresh) {
    return;
  }

  const client = new MongoClient(MONGO_URL);
  try {
    await client.connect();
    const collection = client.db(DB_NAME).collection(COLLECTION_NAME);

    // Load every rule document.
    const rules = await collection.find({}).toArray();

    if (rules.length === 0) {
      console.log("[EmbeddingRetrieval] 知识库为空,跳过向量索引构建");
      return;
    }

    // Build the index; the timestamp is the moment this build started.
    globalVectorEngine.buildIndex(rules as unknown as Record<string, unknown>[]);
    lastIndexBuildTime = now;

    console.log(`[EmbeddingRetrieval] 向量索引构建完成,共 ${rules.length} 条规则`);
  } catch (e) {
    // Anything can be thrown; only Error instances have a usable `.message`.
    const message = e instanceof Error ? e.message : String(e);
    console.error(`[EmbeddingRetrieval] 向量索引构建失败: ${message}`);
  } finally {
    // A failure while closing must not turn into a rejected search request.
    try {
      await client.close();
    } catch (closeError) {
      const message = closeError instanceof Error ? closeError.message : String(closeError);
      console.error(`[EmbeddingRetrieval] 关闭 MongoDB 连接失败: ${message}`);
    }
  }
}

// ─── Semantic search entry points ─────────────────────────────────

/**
 * Semantic search based on TF-IDF vector similarity.
 *
 * Steps:
 *  1. Make sure the vector index exists (`buildVectorIndex`).
 *  2. Append jurisdiction / asset class / rule type to the query text.
 *  3. Fetch `topK * 3` candidates scoring at least `minScore`.
 *  4. Apply the filters below as *soft* filters: a filter is only kept when
 *     enough candidates survive it, otherwise the unfiltered list is used.
 *       - jurisdiction: exact match (case-insensitive) or "GLOBAL"; needs >= 2 hits
 *       - assetClass: substring match, "all", or empty; needs >= 2 hits
 *       - ruleType: substring match; needs >= 1 hit
 *  5. Truncate to `topK` and map documents to `SemanticSearchResult`.
 *
 * @param query Query text.
 * @param options Search options; `topK` defaults to 5, `minScore` to 0.05.
 * @returns Semantically related rules, best match first; an empty array when
 *          the index is not available.
 */
export async function semanticSearch(
  query: string,
  options: {
    topK?: number;
    minScore?: number;
    jurisdiction?: string;
    assetClass?: string;
    ruleType?: string;
  } = {}
): Promise<SemanticSearchResult[]> {
  const { topK = 5, minScore = 0.05, jurisdiction, assetClass, ruleType } = options;

  // Make sure the index is built.
  await buildVectorIndex();

  if (!globalVectorEngine.isReady()) {
    console.log("[EmbeddingRetrieval] 向量引擎未就绪,返回空结果");
    return [];
  }

  // Enrich the query with the requested jurisdiction / asset class / rule type.
  let enhancedQuery = query;
  if (jurisdiction) enhancedQuery += ` ${jurisdiction}`;
  if (assetClass) enhancedQuery += ` ${assetClass}`;
  if (ruleType) enhancedQuery += ` ${ruleType}`;

  // Over-fetch so the post-filters below have candidates to work with.
  let results = globalVectorEngine.search(enhancedQuery, topK * 3, minScore);

  // Soft jurisdiction filter.
  if (jurisdiction) {
    const wanted = jurisdiction.toUpperCase();
    const jurisdictionResults = results.filter(r => {
      const j = String(r.doc.jurisdiction || "").toUpperCase();
      return j === wanted || j === "GLOBAL";
    });
    // Too few matches: keep the unfiltered candidates instead.
    if (jurisdictionResults.length >= 2) {
      results = jurisdictionResults;
    }
  }

  // Soft asset-class filter.
  if (assetClass) {
    const wanted = assetClass.toLowerCase();
    const assetResults = results.filter(r => {
      const a = String(r.doc.assetClass || r.doc.category || "").toLowerCase();
      return a.includes(wanted) || a === "all" || !a;
    });
    if (assetResults.length >= 2) {
      results = assetResults;
    }
  }

  // Soft rule-type filter.
  if (ruleType) {
    const wanted = ruleType.toLowerCase();
    const typeResults = results.filter(r => {
      const t = String(r.doc.ruleType || "").toLowerCase();
      return t.includes(wanted);
    });
    if (typeResults.length >= 1) {
      results = typeResults;
    }
  }

  // Keep the best topK and shape the output.
  return results.slice(0, topK).map(r => {
    const doc = r.doc;
    const rawScore = r.score;

    // Map the similarity onto 0.4–1.0 so weak matches are not shown with a
    // very low score; NaN falls back to 0.5 and the value is clamped to [0, 1].
    const normalizedScore = 0.4 + rawScore * 0.6;
    const safeScore = isNaN(normalizedScore) ? 0.5 : Math.min(1.0, Math.max(0.0, normalizedScore));

    return {
      ruleId: String(doc.ruleId || doc._id || ""),
      ruleName: String(doc.ruleName || doc.ruleNameEn || "未命名规则"),
      jurisdiction: String(doc.jurisdiction || "未知"),
      assetClass: String(doc.assetClass || doc.category || "通用"),
      ruleType: String(doc.ruleType || "compliance_general"),
      content: String(doc.content || doc.description || "").slice(0, 800),
      legalBasis: doc.legalBasis ? String(doc.legalBasis) : undefined,
      ownershipRequirements: doc.ownershipRequirements as Record<string, unknown> | undefined,
      tradingRequirements: doc.tradingRequirements as Record<string, unknown> | undefined,
      score: safeScore,
      similarityScore: rawScore,
      sourceUrl: doc.sourceUrl ? String(doc.sourceUrl) : undefined,
      tags: Array.isArray(doc.tags) ? doc.tags.map(String) : undefined,
      complianceLevel: doc.complianceLevel ? String(doc.complianceLevel) : undefined,
    };
  });
}

/**
 * Hybrid search: fuses semantic scores with caller-supplied keyword scores.
 *
 * The semantic search is asked for `topK * 2` candidates. For each of them
 * the fused score is
 * `similarityScore * semanticWeight + keywordScore * (1 - semanticWeight)`,
 * where the keyword score is 0 for rules absent from `keywordResults` and
 * 0.5 when the supplied score is NaN.
 *
 * Rules that appear only in `keywordResults` are not returned: this function
 * has no rule document to build a `SemanticSearchResult` from for them.
 *
 * @param query Query text.
 * @param keywordResults Keyword search results (from ragRetrieval.ts).
 * @param options Search options; `topK` defaults to 5, `semanticWeight`
 *                (expected range 0–1) to 0.6.
 * @returns Up to `topK` results ordered by fused score; `score` is the fused
 *          score remapped as `0.4 + fused * 0.6` and clamped to [0, 1].
 */
export async function hybridSearch(
  query: string,
  keywordResults: Array<{ ruleId: string; score: number; [key: string]: unknown }>,
  options: {
    topK?: number;
    jurisdiction?: string;
    assetClass?: string;
    ruleType?: string;
    semanticWeight?: number; // weight of the semantic score (0-1), default 0.6
  } = {}
): Promise<SemanticSearchResult[]> {
  const { topK = 5, semanticWeight = 0.6 } = options;
  const keywordWeight = 1 - semanticWeight;

  // `topK` is written AFTER the spread: with the opposite order a
  // caller-supplied `topK` silently replaced the 2x over-fetch.
  const semanticResults = await semanticSearch(query, {
    ...options,
    topK: topK * 2,
  });

  // ruleId -> keyword score.
  const keywordScoreMap = new Map<string, number>();
  for (const r of keywordResults) {
    keywordScoreMap.set(r.ruleId, isNaN(r.score) ? 0.5 : r.score);
  }

  // ruleId -> semantic result (deduplicated; the last duplicate wins).
  const semanticMap = new Map<string, SemanticSearchResult>();
  for (const r of semanticResults) {
    semanticMap.set(r.ruleId, r);
  }

  // Same display mapping as semanticSearch: NaN -> 0.5, clamp to [0, 1].
  // Keyword scores are not guaranteed to lie in 0–1, so clamping matters.
  const toDisplayScore = (fused: number): number => {
    const mapped = 0.4 + fused * 0.6;
    return isNaN(mapped) ? 0.5 : Math.min(1.0, Math.max(0.0, mapped));
  };

  // Weighted fusion over every rule the semantic search returned.
  const fusedResults: Array<SemanticSearchResult & { fusedScore: number }> = [];
  semanticMap.forEach((semanticResult, ruleId) => {
    const semanticScore = semanticResult.similarityScore || 0;
    const keywordScore = keywordScoreMap.get(ruleId) ?? 0;
    const fusedScore = semanticScore * semanticWeight + keywordScore * keywordWeight;

    fusedResults.push({
      ...semanticResult,
      score: toDisplayScore(fusedScore),
      fusedScore,
    });
  });

  // Best fused score first; drop the internal field before returning.
  return fusedResults
    .sort((a, b) => b.fusedScore - a.fusedScore)
    .slice(0, topK)
    .map(({ fusedScore: _fs, ...rest }) => rest);
}

/**
 * Forces a rebuild of the vector index (call after the knowledge base changed).
 *
 * The build timestamp is reset to 0 so `buildVectorIndex` cannot skip the
 * rebuild. `buildVectorIndex` only writes a new timestamp after a successful
 * build, so a non-zero timestamp afterwards is what proves THIS rebuild
 * worked. Checking `isReady()` alone would also report success when the
 * rebuild failed but an older index is still loaded.
 *
 * @returns `success`: whether the rebuild completed (false on error or when
 *          the knowledge base is empty); `rulesIndexed`: number of rules in
 *          the rebuilt index, 0 when the rebuild did not complete.
 */
export async function rebuildVectorIndex(): Promise<{ success: boolean; rulesIndexed: number }> {
  lastIndexBuildTime = 0; // force the rebuild
  await buildVectorIndex();

  const success = lastIndexBuildTime > 0 && globalVectorEngine.isReady();

  // The engine exposes no size accessor; bracket notation is TypeScript's
  // sanctioned way to read its private `ruleVectors` list for the real count
  // (previously this field was hard-coded to 1).
  const rulesIndexed = success ? globalVectorEngine["ruleVectors"].length : 0;

  return { success, rulesIndexed };
}

/**
 * Reports the current state of the vector search engine.
 *
 * @returns `isReady`: whether an index is loaded; `lastBuildTime`: time of
 *          the last successful build, or `null` when none is recorded;
 *          `engine`: fixed label identifying the implementation.
 */
export function getEmbeddingStatus(): {
  isReady: boolean;
  lastBuildTime: Date | null;
  engine: string;
} {
  // 0 is the "never built / rebuild forced" sentinel, not a real timestamp.
  const lastBuildTime = lastIndexBuildTime > 0 ? new Date(lastIndexBuildTime) : null;

  return {
    isReady: globalVectorEngine.isReady(),
    lastBuildTime,
    engine: "TF-IDF InMemory (v1.0)",
  };
}