NAC_Blockchain/services/nac-admin/server/embeddingRetrieval.ts

627 lines
19 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

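/**
 * NAC Public Chain - Vector Embedding Retrieval Module
 *
 * Features:
 * 1. Convert rule text into vector representations (using the built-in LLM API)
 * 2. Compute the cosine similarity between the query vector and rule vectors
 * 3. Return the most semantically relevant rules (replacing regex keyword matching)
 * 4. Support MongoDB Atlas Vector Search (production environment)
 *    and in-memory vector retrieval (fallback mode)
 *
 * Architecture:
 * - Prefer MongoDB Atlas Vector Search (when available)
 * - Fall back to in-memory vector retrieval (TF-IDF + cosine similarity)
 * - Finally fall back to the existing regex keyword retrieval
 */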
import { MongoClient, Collection, Document } from "mongodb";
// ─── Dense Embedding API (OpenAI-compatible) ─────────────────────
import https from "https";
import http from "http";
/** Socket inactivity timeout, in milliseconds, applied to each embedding request. */
const EMBEDDING_API_TIMEOUT = 30000;

/**
 * Requests dense embeddings for `texts` from an OpenAI-compatible
 * `/v1/embeddings` endpoint.
 *
 * Configuration is read from the environment on every call:
 * `NAC_AI_API_URL` (base URL), `NAC_AI_API_KEY` (bearer token) and
 * `NAC_EMBEDDING_MODEL` (defaults to `text-embedding-3-small`).
 *
 * The function never rejects: missing configuration, transport errors,
 * timeouts, API errors and malformed responses all resolve to `null`, so
 * callers can fall back to another vectorization strategy.
 *
 * @param texts Texts to embed, sent as a single request.
 * @returns One vector per input text, in input order, or `null` when the
 *          embeddings could not be obtained.
 */
async function callEmbeddingAPI(texts: string[]): Promise<number[][] | null> {
  const apiUrl = process.env.NAC_AI_API_URL || "";
  const apiKey = process.env.NAC_AI_API_KEY || "";
  const model = process.env.NAC_EMBEDDING_MODEL || "text-embedding-3-small";
  if (!apiUrl || !apiKey) return null;
  // Nothing to embed: skip the network round trip.
  if (texts.length === 0) return [];

  let endpoint: URL;
  try {
    endpoint = new URL(`${apiUrl.replace(/\/$/, "")}/v1/embeddings`);
  } catch {
    console.warn("[DenseEmbedding] NAC_AI_API_URL is not a valid URL");
    return null;
  }

  const isHttps = endpoint.protocol === "https:";
  const lib = isHttps ? https : http;
  const body = JSON.stringify({ model, input: texts, encoding_format: "float" });

  return new Promise<number[][] | null>((resolve) => {
    // Several events can fire for one failure (e.g. "timeout" followed by
    // "error" after destroy()); settle and log only once.
    let settled = false;
    const finish = (result: number[][] | null, warning?: string): void => {
      if (settled) return;
      settled = true;
      if (warning) console.warn(`[DenseEmbedding] ${warning}`);
      resolve(result);
    };

    try {
      const req = lib.request(
        {
          hostname: endpoint.hostname,
          port: endpoint.port || (isHttps ? 443 : 80),
          // Keep any query string present on the configured URL.
          path: `${endpoint.pathname}${endpoint.search}`,
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            "Content-Length": Buffer.byteLength(body),
            "Authorization": `Bearer ${apiKey}`,
          },
          timeout: EMBEDDING_API_TIMEOUT,
        },
        (res) => {
          // Collect raw buffers and decode once, so multi-byte characters
          // split across chunks are not corrupted.
          const chunks: Buffer[] = [];
          res.on("data", (chunk: Buffer) => {
            chunks.push(chunk);
          });
          res.on("error", (err: Error) => finish(null, `response stream failed: ${err.message}`));
          res.on("end", () => {
            try {
              const json = JSON.parse(Buffer.concat(chunks).toString("utf8")) as {
                data?: Array<{ embedding: number[]; index?: number }>;
                error?: { message: string };
              };
              if (json.error) {
                finish(null, `API 错误: ${json.error.message}`);
                return;
              }
              const status = res.statusCode ?? 0;
              if (status < 200 || status >= 300) {
                finish(null, `unexpected HTTP status ${status}`);
                return;
              }
              const rows = json.data;
              if (!Array.isArray(rows) || rows.length !== texts.length) {
                finish(null, "response did not contain exactly one embedding per input");
                return;
              }
              // Restore input order when the server reports positions.
              const ordered = rows.every((row) => typeof row.index === "number")
                ? rows.slice().sort((l, r) => (l.index as number) - (r.index as number))
                : rows;
              if (!ordered.every((row) => Array.isArray(row.embedding))) {
                finish(null, "response contained a malformed embedding");
                return;
              }
              finish(ordered.map((row) => row.embedding));
            } catch {
              finish(null, "response body was not valid JSON");
            }
          });
        }
      );
      req.on("error", (err: Error) => finish(null, `request failed: ${err.message}`));
      req.on("timeout", () => {
        finish(null, `request timed out after ${EMBEDDING_API_TIMEOUT} ms`);
        req.destroy();
      });
      req.write(body);
      req.end();
    } catch {
      finish(null, "request could not be started");
    }
  });
}
// Vector flavour reported by getVectorType(). It starts as "tfidf"; nothing in
// the visible part of this file reassigns it.
let currentVectorType: "dense" | "tfidf" = "tfidf";

/**
 * Embeds `texts` through `callEmbeddingAPI` in sequential batches.
 *
 * Each text is truncated to 4000 characters, batches hold 20 texts, and the
 * loop pauses 300 ms between batches.
 *
 * @param texts Texts to embed.
 * @returns One vector per text in input order, or `null` as soon as any batch
 *          fails or returns a different number of vectors than it was sent
 *          (which would otherwise misalign vectors and texts).
 */
async function generateDenseEmbeddings(texts: string[]): Promise<number[][] | null> {
  const BATCH_SIZE = 20;
  const MAX_TEXT_CHARS = 4000;
  const PAUSE_BETWEEN_BATCHES_MS = 300;

  const allVectors: number[][] = [];
  for (let start = 0; start < texts.length; start += BATCH_SIZE) {
    const batch = texts.slice(start, start + BATCH_SIZE).map((t) => t.slice(0, MAX_TEXT_CHARS));
    const vectors = await callEmbeddingAPI(batch);
    // A short or long batch would shift every later vector onto the wrong text.
    if (!vectors || vectors.length !== batch.length) return null;
    for (const vector of vectors) {
      allVectors.push(vector);
    }
    const hasMoreBatches = start + BATCH_SIZE < texts.length;
    if (hasMoreBatches) {
      await new Promise((resolve) => setTimeout(resolve, PAUSE_BETWEEN_BATCHES_MS));
    }
  }
  return allVectors;
}
// ─── 类型定义 ─────────────────────────────────────────────────────
/**
 * A rule's embedding together with the text it was computed from.
 *
 * NOTE(review): not referenced anywhere in the visible part of this file;
 * presumably intended for persisted embeddings — confirm against callers.
 */
export interface EmbeddingVector {
// Identifier of the rule this vector belongs to.
ruleId: string;
// Numeric vector components.
vector: number[];
// Source text of the vector.
text: string;
// Creation timestamp of this record.
createdAt: Date;
}
/**
 * One rule returned by `semanticSearch` / `hybridSearch`.
 *
 * The fields are copied from the stored rule document; string fields fall
 * back to placeholder values when the document lacks them.
 */
export interface SemanticSearchResult {
// `ruleId` of the document, falling back to its `_id`.
ruleId: string;
ruleName: string;
jurisdiction: string;
// Taken from `assetClass`, falling back to `category`.
assetClass: string;
ruleType: string;
// Rule body (`content` or `description`), truncated to 800 characters by semanticSearch.
content: string;
legalBasis?: string;
ownershipRequirements?: Record<string, unknown>;
tradingRequirements?: Record<string, unknown>;
// Display score: the similarity rescaled as 0.4 + raw * 0.6 (hybridSearch replaces it with the fused equivalent).
score: number;
// Raw cosine similarity between the query vector and the rule vector.
similarityScore: number;
sourceUrl?: string;
tags?: string[];
complianceLevel?: string;
}
// ─── TF-IDF vectorization (in-memory mode, no external API required) ─────
/**
 * TF-IDF vectorizer for mixed Chinese/English text.
 *
 * `fit()` learns a vocabulary and smoothed IDF weights from a corpus;
 * `transform()` maps a text onto an L2-normalized vector whose length equals
 * the vocabulary size. Tokens unknown to the fitted vocabulary are ignored.
 */
class TFIDFVectorizer {
  private vocabulary: Map<string, number> = new Map();
  private idf: Map<string, number> = new Map();
  private documents: string[][] = [];

  /**
   * Splits text into tokens.
   *
   * English: lower-cased alphanumeric words longer than two characters.
   * Chinese: every overlapping 2-character and 3-character n-gram of the
   * CJK characters (U+4E00–U+9FA5) left after removing Latin letters,
   * digits and whitespace.
   *
   * @param text Raw input text.
   * @returns English tokens first, followed by the Chinese n-grams.
   */
  tokenize(text: string): string[] {
    const normalized = text
      .toLowerCase()
      .replace(/[^\u4e00-\u9fa5a-z0-9\s]/g, " ")
      .replace(/\s+/g, " ")
      .trim();
    const tokens: string[] = [];

    // English words; very short words carry little signal and are dropped.
    const englishWords = normalized.match(/[a-z][a-z0-9]*/g) || [];
    for (const word of englishWords) {
      if (word.length > 2) tokens.push(word);
    }

    // Chinese has no word delimiters, so overlapping n-grams approximate words.
    const chineseText = normalized.replace(/[a-z0-9\s]/g, "");
    for (let i = 0; i < chineseText.length - 1; i++) {
      tokens.push(chineseText.slice(i, i + 2));
      if (i < chineseText.length - 2) {
        tokens.push(chineseText.slice(i, i + 3));
      }
    }
    return tokens;
  }

  /**
   * Learns the vocabulary and IDF weights from `documents`.
   *
   * Any previously fitted state is discarded first. Without that reset a
   * second call would keep stale tokens whose indices collide with the
   * freshly assigned ones and would inflate the vector dimension.
   *
   * IDF uses the smoothed form `ln((N + 1) / (df + 1)) + 1`.
   *
   * @param documents Corpus texts, one entry per document.
   */
  fit(documents: string[]): void {
    this.vocabulary = new Map();
    this.idf = new Map();
    this.documents = documents.map((doc) => this.tokenize(doc));

    // Single pass: assign indices in first-seen order and count, per token,
    // how many documents contain it.
    const documentFrequency = new Map<string, number>();
    for (const tokens of this.documents) {
      const uniqueTokens = Array.from(new Set(tokens));
      for (const token of uniqueTokens) {
        if (!this.vocabulary.has(token)) {
          this.vocabulary.set(token, this.vocabulary.size);
        }
        documentFrequency.set(token, (documentFrequency.get(token) || 0) + 1);
      }
    }

    const totalDocuments = this.documents.length;
    documentFrequency.forEach((df, token) => {
      this.idf.set(token, Math.log((totalDocuments + 1) / (df + 1)) + 1);
    });
  }

  /**
   * Converts `text` into an L2-normalized TF-IDF vector.
   *
   * @param text Text to vectorize.
   * @returns A vector of length `getVocabularySize()`; all zeros when the
   *          text shares no token with the fitted vocabulary.
   */
  transform(text: string): number[] {
    const tokens = this.tokenize(text);
    const vector = new Array<number>(this.vocabulary.size).fill(0);

    // Raw term counts.
    const termCounts = new Map<string, number>();
    for (const token of tokens) {
      termCounts.set(token, (termCounts.get(token) || 0) + 1);
    }

    // TF (count relative to the text length) times IDF.
    termCounts.forEach((count, token) => {
      const idx = this.vocabulary.get(token);
      if (idx !== undefined) {
        const tfScore = count / tokens.length;
        const idfScore = this.idf.get(token) || 1;
        vector[idx] = tfScore * idfScore;
      }
    });

    // L2 normalization; an all-zero vector is returned unchanged.
    const norm = Math.sqrt(vector.reduce((sum, v) => sum + v * v, 0));
    if (norm > 0) {
      return vector.map((v) => v / norm);
    }
    return vector;
  }

  /** Number of distinct tokens learned by the last `fit()`. */
  getVocabularySize(): number {
    return this.vocabulary.size;
  }
}
// ─── 余弦相似度计算 ───────────────────────────────────────────────
/**
 * Cosine similarity of two equally sized vectors.
 *
 * @param a First vector.
 * @param b Second vector.
 * @returns `dot(a, b) / (|a| * |b|)`, or 0 when the lengths differ or either
 *          vector has zero magnitude.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  const size = a.length;
  if (size !== b.length) {
    return 0;
  }

  let dot = 0;
  let sumSquaresA = 0;
  let sumSquaresB = 0;
  let position = 0;
  while (position < size) {
    const left = a[position];
    const right = b[position];
    dot += left * right;
    sumSquaresA += left * left;
    sumSquaresB += right * right;
    position += 1;
  }

  const magnitude = Math.sqrt(sumSquaresA) * Math.sqrt(sumSquaresB);
  return magnitude === 0 ? 0 : dot / magnitude;
}
// ─── 内存向量检索引擎 ─────────────────────────────────────────────
/** A rule document paired with its vector and the text that produced it. */
interface RuleVector {
  doc: Record<string, unknown>;
  vector: number[];
  text: string;
}

/**
 * In-memory search engine: rules are vectorized with TF-IDF and ranked by
 * cosine similarity against the query vector.
 */
class InMemoryVectorSearch {
  private vectorizer: TFIDFVectorizer = new TFIDFVectorizer();
  private ruleVectors: RuleVector[] = [];
  private isBuilt = false;

  /**
   * Builds the index from `rules`.
   *
   * An empty list marks the engine as not ready and leaves the stored
   * vectors untouched.
   *
   * @param rules Rule documents to index.
   */
  buildIndex(rules: Record<string, unknown>[]): void {
    if (rules.length === 0) {
      this.isBuilt = false;
      return;
    }

    // One text per rule, fitted as the TF-IDF corpus.
    const corpus = rules.map((rule) => this.buildRuleText(rule));
    this.vectorizer.fit(corpus);

    const entries: RuleVector[] = [];
    for (let i = 0; i < rules.length; i++) {
      const text = corpus[i];
      entries.push({ doc: rules[i], vector: this.vectorizer.transform(text), text });
    }
    this.ruleVectors = entries;
    this.isBuilt = true;
    console.log(`[EmbeddingRetrieval] 向量索引构建完成: ${rules.length} 条规则, 词汇表大小: ${this.vectorizer.getVocabularySize()}`);
  }

  /**
   * Flattens a rule document into the text that gets vectorized.
   *
   * Non-empty segments are joined with single spaces in this order: rule
   * name (twice), jurisdiction, asset class, rule type, content (first 500
   * characters), legal basis, tags, ownership proof documents (first 200
   * characters).
   */
  private buildRuleText(rule: Record<string, unknown>): string {
    const segments: string[] = [];
    const addIfPresent = (segment: string): void => {
      if (segment) segments.push(segment);
    };

    // The name is added twice so it weighs more in the term frequencies.
    const name = String(rule.ruleName || rule.ruleNameEn || "");
    addIfPresent(name);
    addIfPresent(name);

    addIfPresent(String(rule.jurisdiction || ""));
    addIfPresent(String(rule.assetClass || rule.category || ""));
    addIfPresent(String(rule.ruleType || ""));

    // Main body, capped so long rules do not dominate.
    const body = String(rule.content || rule.description || "");
    if (body) segments.push(body.slice(0, 500));

    addIfPresent(String(rule.legalBasis || ""));
    addIfPresent(Array.isArray(rule.tags) ? rule.tags.join(" ") : "");

    const ownership = rule.ownershipRequirements as Record<string, unknown> | undefined;
    if (ownership) {
      const proofs = Array.isArray(ownership.proofDocuments)
        ? ownership.proofDocuments.join(" ")
        : "";
      if (proofs) segments.push(proofs.slice(0, 200));
    }

    return segments.join(" ");
  }

  /**
   * Ranks the indexed rules against `query`.
   *
   * @param query Free-text query.
   * @param topK Maximum number of hits to return.
   * @param minScore Hits scoring below this cosine similarity are dropped.
   * @returns Hits sorted by descending score; empty when the index is not built.
   */
  search(query: string, topK = 5, minScore = 0.1): Array<{ doc: Record<string, unknown>; score: number }> {
    const unavailable = !this.isBuilt || this.ruleVectors.length === 0;
    if (unavailable) {
      return [];
    }

    const queryVector = this.vectorizer.transform(query);
    const hits: Array<{ doc: Record<string, unknown>; score: number }> = [];
    for (const entry of this.ruleVectors) {
      const score = cosineSimilarity(queryVector, entry.vector);
      if (score >= minScore) {
        hits.push({ doc: entry.doc, score });
      }
    }

    hits.sort((left, right) => right.score - left.score);
    return hits.slice(0, topK);
  }

  /** Whether an index has been built and can be searched. */
  isReady(): boolean {
    return this.isBuilt;
  }
}
// ─── Global vector search engine instance ─────────────────────────
// Single module-level engine used by buildVectorIndex() and semanticSearch().
const globalVectorEngine = new InMemoryVectorSearch();
// Epoch-millisecond timestamp of the last completed index build; 0 means
// "never built" and is also written by rebuildVectorIndex() to force a rebuild.
let lastIndexBuildTime = 0;
const INDEX_REBUILD_INTERVAL = 5 * 60 * 1000; // a ready index is reused for 5 minutes before being rebuilt
// ─── MongoDB connection ───────────────────────────────────────────
// The connection string, including any credentials, must come from the
// NAC_MONGO_URL environment variable. Secrets are deliberately not embedded in
// source: the fallback is an unauthenticated local URL intended for
// development only.
const MONGO_URL = process.env.NAC_MONGO_URL || "mongodb://localhost:27017/nac_knowledge_engine";
if (!process.env.NAC_MONGO_URL) {
  console.warn(
    "[EmbeddingRetrieval] NAC_MONGO_URL is not set; falling back to an unauthenticated local MongoDB URL"
  );
}
// Database and collection that hold the compliance rules.
const DB_NAME = "nac_knowledge_engine";
const COLLECTION_NAME = "compliance_rules";
// Connected client shared by every getCollection() call. Stored as a promise
// so concurrent first calls wait on the same connection attempt.
let sharedMongoClient: Promise<MongoClient> | null = null;

/**
 * Returns a handle to the compliance-rules collection.
 *
 * A single `MongoClient` is created lazily and reused. Opening a new client
 * on every call would leak connections, because callers only receive the
 * collection and have no way to close the client behind it. A failed
 * connection attempt is forgotten so the next call retries.
 *
 * @returns The `COLLECTION_NAME` collection of database `DB_NAME`.
 * @throws Whatever `MongoClient.connect()` rejects with when the connection fails.
 */
async function getCollection(): Promise<Collection<Document>> {
  let pending = sharedMongoClient;
  if (!pending) {
    const attempt = (async (): Promise<MongoClient> => {
      const client = new MongoClient(MONGO_URL);
      await client.connect();
      return client;
    })();
    // Drop the cached promise on failure; the rejection itself still reaches
    // the caller through the await below.
    attempt.catch(() => {
      if (sharedMongoClient === attempt) {
        sharedMongoClient = null;
      }
    });
    sharedMongoClient = attempt;
    pending = attempt;
  }
  const client = await pending;
  return client.db(DB_NAME).collection(COLLECTION_NAME);
}
// ─── 向量索引构建 ─────────────────────────────────────────────────
/**
 * Loads every rule from MongoDB and (re)builds the in-memory vector index.
 *
 * The call is a no-op while the engine is ready and the last build is younger
 * than `INDEX_REBUILD_INTERVAL`. Failures are logged, not thrown; after a
 * failure the previous index (if any) remains in place and
 * `lastIndexBuildTime` is not advanced, so the next call retries.
 */
export async function buildVectorIndex(): Promise<void> {
  const now = Date.now();
  if (globalVectorEngine.isReady() && now - lastIndexBuildTime < INDEX_REBUILD_INTERVAL) {
    return; // index is still fresh
  }
  const client = new MongoClient(MONGO_URL);
  try {
    await client.connect();
    const collection = client.db(DB_NAME).collection(COLLECTION_NAME);
    // Load all rules.
    const rules = await collection.find({}).toArray();
    if (rules.length === 0) {
      console.log("[EmbeddingRetrieval] 知识库为空,跳过向量索引构建");
      return;
    }
    // InMemoryVectorSearch only exposes the synchronous buildIndex(); the
    // previously called buildIndexAsync() does not exist, so every build
    // threw and the engine never became ready.
    globalVectorEngine.buildIndex(rules as unknown as Record<string, unknown>[]);
    lastIndexBuildTime = now;
    console.log(`[EmbeddingRetrieval] 向量索引构建完成,共 ${rules.length} 条规则`);
  } catch (e) {
    const reason = e instanceof Error ? e.message : String(e);
    console.error(`[EmbeddingRetrieval] 向量索引构建失败: ${reason}`);
  } finally {
    // A failing close must not turn a finished build into a rejected call.
    try {
      await client.close();
    } catch (closeError) {
      const reason = closeError instanceof Error ? closeError.message : String(closeError);
      console.warn(`[EmbeddingRetrieval] failed to close MongoDB client: ${reason}`);
    }
  }
}
// ─── 语义检索主函数 ───────────────────────────────────────────────
/**
 * Semantic search over the compliance rules using TF-IDF vector similarity.
 *
 * The query is extended with the requested jurisdiction / asset class / rule
 * type, `topK * 3` candidates are fetched, and each requested facet then
 * narrows the candidates only if enough of them survive (two for jurisdiction
 * and asset class, one for rule type); otherwise that facet is ignored.
 *
 * @param query   Query text.
 * @param options Search options: `topK` (default 5), `minScore` (default
 *                0.05) and the optional facet filters.
 * @returns Up to `topK` rules ordered by descending similarity; empty when
 *          the vector engine is not ready.
 */
export async function semanticSearch(
  query: string,
  options: {
    topK?: number;
    minScore?: number;
    jurisdiction?: string;
    assetClass?: string;
    ruleType?: string;
  } = {}
): Promise<SemanticSearchResult[]> {
  type Hit = { doc: Record<string, unknown>; score: number };

  const { jurisdiction, assetClass, ruleType } = options;
  const topK = options.topK === undefined ? 5 : options.topK;
  const minScore = options.minScore === undefined ? 0.05 : options.minScore;

  // Make sure an index exists before searching.
  await buildVectorIndex();
  if (!globalVectorEngine.isReady()) {
    console.log("[EmbeddingRetrieval] 向量引擎未就绪,返回空结果");
    return [];
  }

  // Append the facet values so they take part in the similarity score.
  let enhancedQuery = query;
  for (const facet of [jurisdiction, assetClass, ruleType]) {
    if (facet) {
      enhancedQuery += ` ${facet}`;
    }
  }

  // Applies a facet filter, but falls back to the unfiltered pool when fewer
  // than `minimum` candidates would remain.
  const narrowIfEnough = (pool: Hit[], accept: (hit: Hit) => boolean, minimum: number): Hit[] => {
    const kept = pool.filter(accept);
    return kept.length >= minimum ? kept : pool;
  };

  let hits: Hit[] = globalVectorEngine.search(enhancedQuery, topK * 3, minScore);

  if (jurisdiction) {
    const wanted = jurisdiction.toUpperCase();
    hits = narrowIfEnough(
      hits,
      (hit) => {
        const value = String(hit.doc.jurisdiction || "").toUpperCase();
        return value === wanted || value === "GLOBAL";
      },
      2
    );
  }

  if (assetClass) {
    const wanted = assetClass.toLowerCase();
    hits = narrowIfEnough(
      hits,
      (hit) => {
        const value = String(hit.doc.assetClass || hit.doc.category || "").toLowerCase();
        // Rules without an asset class, or marked "all", apply to every class.
        return value.includes(wanted) || value === "all" || !value;
      },
      2
    );
  }

  if (ruleType) {
    const wanted = ruleType.toLowerCase();
    hits = narrowIfEnough(
      hits,
      (hit) => String(hit.doc.ruleType || "").toLowerCase().includes(wanted),
      1
    );
  }

  const formatHit = (hit: Hit): SemanticSearchResult => {
    const doc = hit.doc;
    const similarity = hit.score;
    // Rescale the similarity to 0.4 + raw * 0.6 for display, clamped to [0, 1].
    const rescaled = 0.4 + similarity * 0.6;
    const displayScore = isNaN(rescaled) ? 0.5 : Math.min(1.0, Math.max(0.0, rescaled));
    return {
      ruleId: String(doc.ruleId || doc._id || ""),
      ruleName: String(doc.ruleName || doc.ruleNameEn || "未命名规则"),
      jurisdiction: String(doc.jurisdiction || "未知"),
      assetClass: String(doc.assetClass || doc.category || "通用"),
      ruleType: String(doc.ruleType || "compliance_general"),
      content: String(doc.content || doc.description || "").slice(0, 800),
      legalBasis: doc.legalBasis ? String(doc.legalBasis) : undefined,
      ownershipRequirements: doc.ownershipRequirements as Record<string, unknown> | undefined,
      tradingRequirements: doc.tradingRequirements as Record<string, unknown> | undefined,
      score: displayScore,
      similarityScore: similarity,
      sourceUrl: doc.sourceUrl ? String(doc.sourceUrl) : undefined,
      tags: Array.isArray(doc.tags) ? doc.tags.map(String) : undefined,
      complianceLevel: doc.complianceLevel ? String(doc.complianceLevel) : undefined,
    };
  };

  // Keep the best topK and convert them to the public result shape.
  return hits.slice(0, topK).map(formatHit);
}
/**
 * Hybrid search: re-ranks semantic hits using keyword-search scores.
 *
 * For every semantic hit the fused score is
 * `similarity * semanticWeight + keywordScore * (1 - semanticWeight)`, where
 * the keyword score is 0 for rules the keyword search did not return. Rules
 * found only by the keyword search are not part of the output, because
 * `keywordResults` entries are not guaranteed to carry the rule fields a
 * `SemanticSearchResult` needs.
 *
 * @param query          Query text.
 * @param keywordResults Keyword-search hits (from ragRetrieval.ts); a NaN
 *                       score is treated as 0.5.
 * @param options        `topK` (default 5), facet filters forwarded to
 *                       `semanticSearch`, and `semanticWeight` (0–1, default
 *                       0.6; out-of-range values are clamped, non-finite
 *                       values use the default).
 * @returns Up to `topK` results ordered by descending fused score; `score`
 *          holds the fused score rescaled to 0.4 + fused * 0.6 and clamped
 *          to [0, 1].
 */
export async function hybridSearch(
  query: string,
  keywordResults: Array<{ ruleId: string; score: number; [key: string]: unknown }>,
  options: {
    topK?: number;
    jurisdiction?: string;
    assetClass?: string;
    ruleType?: string;
    semanticWeight?: number; // weight of the semantic score, 0-1, default 0.6
  } = {}
): Promise<SemanticSearchResult[]> {
  const { topK = 5, semanticWeight = 0.6, jurisdiction, assetClass, ruleType } = options;

  // Keep both weights inside [0, 1] so neither score is subtracted.
  const semanticShare = isFinite(semanticWeight) ? Math.min(1, Math.max(0, semanticWeight)) : 0.6;
  const keywordShare = 1 - semanticShare;

  // Fetch twice as many semantic candidates as requested so re-ranking has
  // room to work. The facets are passed explicitly: spreading `options` after
  // `topK` let the caller's own topK overwrite the widened value.
  const semanticResults = await semanticSearch(query, {
    topK: topK * 2,
    jurisdiction,
    assetClass,
    ruleType,
  });

  // ruleId -> keyword score
  const keywordScoreMap = new Map<string, number>();
  for (const r of keywordResults) {
    keywordScoreMap.set(r.ruleId, isNaN(r.score) ? 0.5 : r.score);
  }

  // ruleId -> semantic result (a repeated ruleId keeps its first position and
  // its last value)
  const semanticMap = new Map<string, SemanticSearchResult>();
  for (const r of semanticResults) {
    semanticMap.set(r.ruleId, r);
  }

  const fusedResults: Array<SemanticSearchResult & { fusedScore: number }> = [];
  semanticMap.forEach((semanticResult, ruleId) => {
    const semanticScore = semanticResult.similarityScore || 0;
    const keywordScore = keywordScoreMap.get(ruleId) || 0;
    const fusedScore = semanticScore * semanticShare + keywordScore * keywordShare;
    fusedResults.push({
      ...semanticResult,
      // Same display scaling as semanticSearch, clamped because keyword
      // scores are not guaranteed to lie in [0, 1].
      score: Math.min(1, Math.max(0, 0.4 + fusedScore * 0.6)),
      fusedScore,
    });
  });

  return fusedResults
    .sort((a, b) => b.fusedScore - a.fusedScore)
    .slice(0, topK)
    .map(({ fusedScore: _fs, ...rest }) => rest);
}
/**
 * Forces a rebuild of the vector index (for use after the knowledge base
 * has been updated).
 *
 * @returns `success` is true only when this call actually produced a fresh
 *          index. A stale index that survived a failed or skipped rebuild is
 *          reported as a failure. `rulesIndexed` is the number of documents
 *          in the rules collection counted right after the rebuild, or 0 on
 *          failure.
 */
export async function rebuildVectorIndex(): Promise<{ success: boolean; rulesIndexed: number }> {
  lastIndexBuildTime = 0; // force buildVectorIndex() to skip its freshness check
  await buildVectorIndex();

  // buildVectorIndex() advances lastIndexBuildTime only after a completed
  // build, so a value still at 0 means nothing was rebuilt.
  const rebuilt = globalVectorEngine.isReady() && lastIndexBuildTime > 0;
  if (!rebuilt) {
    return { success: false, rulesIndexed: 0 };
  }

  // The engine does not expose its size, so count the collection the index
  // was just built from. If counting fails, fall back to 1: a ready engine
  // holds at least one rule.
  let rulesIndexed = 1;
  const client = new MongoClient(MONGO_URL);
  try {
    await client.connect();
    rulesIndexed = await client.db(DB_NAME).collection(COLLECTION_NAME).countDocuments({});
  } catch (e) {
    const reason = e instanceof Error ? e.message : String(e);
    console.warn(`[EmbeddingRetrieval] could not count indexed rules: ${reason}`);
  } finally {
    try {
      await client.close();
    } catch {
      // Closing is best effort; the rebuild result is already known.
    }
  }
  return { success: true, rulesIndexed };
}
/**
 * Reports which vector representation is currently in use.
 *
 * @returns The module-level `currentVectorType` value, `"dense"` or `"tfidf"`.
 */
export function getVectorType(): "dense" | "tfidf" {
  const activeType = currentVectorType;
  return activeType;
}
/**
 * Returns a snapshot of the vector search engine's state.
 *
 * @returns `isReady`: whether an index is built; `lastBuildTime`: time of the
 *          last recorded build, or `null` when none is recorded; `engine`: a
 *          fixed engine label.
 */
export function getEmbeddingStatus(): {
  isReady: boolean;
  lastBuildTime: Date | null;
  engine: string;
} {
  const ready = globalVectorEngine.isReady();
  // A timestamp of 0 means no build has been recorded.
  const builtAt = lastIndexBuildTime > 0 ? new Date(lastIndexBuildTime) : null;
  return { isReady: ready, lastBuildTime: builtAt, engine: "TF-IDF InMemory (v1.0)" };
}