NAC_Blockchain/ops/nac-admin/server/denseEmbeddingRetrieval.ts

482 lines
16 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* NAC Dense Embedding 检索模块
*
* 使用 OpenAI 兼容的 text-embedding API 实现语义向量检索
* 支持中英文混合查询,余弦相似度排序
*
* 架构:
* 1. 查询向量化:将用户查询转为 embedding 向量
* 2. 规则向量化:预计算规则 embedding 并缓存到 MongoDB
* 3. 余弦相似度:计算查询与规则的语义相似度
* 4. 混合排序:综合分数 = 语义分数 × 0.7 + 关键词分数 × 0.3
*/
import https from "https";
import http from "http";
// ─── 类型定义 ─────────────────────────────────────────────────────
/**
 * One rule's cached embedding, as persisted in the `rule_embeddings` collection.
 */
export interface EmbeddingVector {
  /** Identifier of the rule this vector was computed for. */
  ruleId: string;
  /** Name of the embedding model that produced `vector`. */
  model: string;
  /** Raw embedding values. */
  vector: number[];
  /** Hash of the rule's embedding text; compared with the current hash to detect that the rule changed. */
  textHash: string;
  /** Timestamp written when the vector was saved. */
  createdAt: Date;
}
/**
 * A ranked rule returned by semantic search, carrying the rule fields
 * needed for display plus the individual and combined relevance scores.
 */
export interface SemanticSearchResult {
  /** Rule identifier. */
  ruleId: string;
  /** Jurisdiction the rule applies to. */
  jurisdiction: string;
  /** Asset class of the rule. */
  assetClass: string;
  /** Rule type label. */
  ruleType: string;
  /** Rule title. */
  title: string;
  /** Rule body text. */
  content: string;
  /** Ownership requirements, when the rule lists any. */
  ownershipRequirements?: string[];
  /** Trading requirements, when the rule lists any. */
  tradingRequirements?: string[];
  /** Legal basis, if recorded. */
  legalBasis?: string;
  /** Official source reference, if recorded. */
  officialSource?: string;
  /** Semantic (embedding) similarity, 0-1. */
  semanticScore: number;
  /** Keyword match score, 0-1. */
  keywordScore: number;
  /** Weighted blend of the semantic and keyword scores, 0-1. */
  combinedScore: number;
}
// ─── Embedding API 配置 ───────────────────────────────────────────
/**
 * Read the embedding API settings from the environment.
 *
 * Returns `null` when either NAC_AI_API_URL or NAC_AI_API_KEY is missing,
 * so callers can fall back to TF-IDF. Otherwise the endpoint is the base URL
 * (one trailing slash stripped) plus `/v1/embeddings`. The model defaults to
 * text-embedding-3-small.
 */
const getEmbeddingConfig = () => {
  const { NAC_AI_API_URL, NAC_AI_API_KEY, NAC_EMBEDDING_MODEL } = process.env;
  const baseUrl = NAC_AI_API_URL || "";
  const key = NAC_AI_API_KEY || "";
  if (baseUrl === "" || key === "") {
    return null;
  }
  const url = baseUrl.replace(/\/$/, "") + "/v1/embeddings";
  const model = NAC_EMBEDDING_MODEL || "text-embedding-3-small";
  return { url, key, model };
};
// ─── HTTP 请求工具 ────────────────────────────────────────────────
/**
 * POST `body` as JSON to `url` and resolve with the parsed JSON response.
 *
 * The promise resolves whatever the HTTP status is, so callers can inspect
 * API error payloads such as `{ error: { message } }`.
 *
 * @param url Absolute http:// or https:// URL.
 * @param headers Extra request headers; they override the defaults.
 * @param body Value serialised with `JSON.stringify`.
 * @returns The parsed response body, as `unknown` for the caller to narrow.
 * @throws Error on a request or response stream error, a 30 s socket
 *   timeout, or a response body that is not valid JSON.
 */
async function httpPost(url: string, headers: Record<string, string>, body: object): Promise<unknown> {
  return new Promise((resolve, reject) => {
    const bodyStr = JSON.stringify(body);
    const parsedUrl = new URL(url);
    const isHttps = parsedUrl.protocol === "https:";
    const lib = isHttps ? https : http;
    const options = {
      hostname: parsedUrl.hostname,
      port: parsedUrl.port || (isHttps ? 443 : 80),
      path: parsedUrl.pathname + parsedUrl.search,
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        // Byte length, not string length: the body may contain multi-byte UTF-8 text.
        "Content-Length": Buffer.byteLength(bodyStr),
        ...headers,
      },
      timeout: 30000,
    };
    const req = lib.request(options, (res) => {
      // Keep raw Buffers and decode once at the end. Appending each chunk to a
      // string decodes chunks separately, which corrupts any multi-byte UTF-8
      // character (e.g. CJK error text) that straddles a chunk boundary.
      const chunks: Buffer[] = [];
      res.on("data", (chunk: Buffer) => {
        chunks.push(chunk);
      });
      res.on("error", reject);
      res.on("end", () => {
        const data = Buffer.concat(chunks).toString("utf8");
        try {
          resolve(JSON.parse(data));
        } catch {
          reject(new Error(`JSON 解析失败: ${data.slice(0, 200)}`));
        }
      });
    });
    req.on("error", reject);
    req.on("timeout", () => {
      req.destroy();
      reject(new Error("Embedding API 请求超时"));
    });
    req.write(bodyStr);
    req.end();
  });
}
// ─── Embedding 生成 ───────────────────────────────────────────────
/**
 * Generate an embedding vector for `text` via the configured
 * OpenAI-compatible `/v1/embeddings` endpoint.
 *
 * Never throws. Returns `null` when:
 * - the API is not configured,
 * - the request fails,
 * - the API reports an error, or
 * - the response carries no embedding array.
 * Callers can then fall back to TF-IDF retrieval.
 *
 * @param text Text to embed; at most the first 4000 UTF-16 code units are sent.
 * @returns The embedding vector, or `null` on any failure.
 */
export async function generateEmbedding(text: string): Promise<number[] | null> {
  const config = getEmbeddingConfig();
  if (!config) return null;
  // Cap the input length (embedding models commonly limit input to ~8192 tokens).
  let truncatedText = text.slice(0, 4000);
  // If the cut landed between the two halves of a surrogate pair (emoji, rare
  // CJK), drop the dangling high surrogate: an unpaired surrogate is not valid
  // Unicode and may be rejected by the API or its tokenizer.
  const lastUnit = truncatedText.charCodeAt(truncatedText.length - 1);
  if (lastUnit >= 0xd800 && lastUnit <= 0xdbff) {
    truncatedText = truncatedText.slice(0, -1);
  }
  try {
    const response = (await httpPost(
      config.url,
      { "Authorization": `Bearer ${config.key}` },
      {
        model: config.model,
        input: truncatedText,
        encoding_format: "float",
      }
    )) as { data?: Array<{ embedding?: unknown }>; error?: { message: string } };
    if (response.error) {
      console.warn(`[Embedding] API 错误: ${response.error.message}`);
      return null;
    }
    const embedding = response.data?.[0]?.embedding;
    // Accept only a real array so malformed payloads never reach the similarity math.
    return Array.isArray(embedding) ? (embedding as number[]) : null;
  } catch (err) {
    console.warn(`[Embedding] 生成失败: ${(err as Error).message}`);
    return null;
  }
}
/**
 * Generate embeddings for many texts in rate-limited batches.
 *
 * Requests within a batch run concurrently. Batches run one after another,
 * with a 500 ms pause between them. The output has the same order and length
 * as `texts`; an item whose embedding could not be generated is `null`.
 *
 * @param texts Texts to embed.
 * @param batchSize Maximum concurrent requests per batch. Values below 1 (or
 *   NaN) are treated as 1; fractional values are floored.
 * @returns One entry per input text.
 */
export async function generateEmbeddingsBatch(
  texts: string[],
  batchSize = 20
): Promise<Array<number[] | null>> {
  // A step of 0 or less would never advance the loop index and would hang forever.
  const step = batchSize >= 1 ? Math.floor(batchSize) : 1;
  const results: Array<number[] | null> = [];
  for (let i = 0; i < texts.length; i += step) {
    const batch = texts.slice(i, i + step);
    const batchResults = await Promise.all(batch.map((text) => generateEmbedding(text)));
    for (const vector of batchResults) {
      results.push(vector);
    }
    // Rate limit: pause between batches, but not after the last one.
    if (i + step < texts.length) {
      await new Promise((resolve) => setTimeout(resolve, 500));
    }
  }
  return results;
}
// ─── 余弦相似度计算 ───────────────────────────────────────────────
/**
 * Cosine similarity of two vectors, mapped linearly from [-1, 1] to [0, 1]:
 * 1 means same direction, 0.5 orthogonal, 0 opposite.
 *
 * Returns 0 when:
 * - the vectors differ in length or are empty,
 * - either vector has zero magnitude, or
 * - the result is not finite (e.g. NaN components).
 *
 * @param vecA First vector.
 * @param vecB Second vector, same length as `vecA`.
 * @returns Normalized similarity in [0, 1].
 */
export function cosineSimilarity(vecA: number[], vecB: number[]): number {
  if (vecA.length !== vecB.length || vecA.length === 0) return 0;
  let dotProduct = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < vecA.length; i++) {
    dotProduct += vecA[i] * vecB[i];
    normA += vecA[i] * vecA[i];
    normB += vecB[i] * vecB[i];
  }
  const denominator = Math.sqrt(normA) * Math.sqrt(normB);
  if (denominator === 0) return 0;
  const cosine = dotProduct / denominator;
  // A NaN score would make sort comparisons inconsistent; treat it as "no similarity".
  if (!Number.isFinite(cosine)) return 0;
  // Floating-point rounding can push |cosine| slightly past 1; clamp so the
  // mapped score stays within the documented [0, 1] range.
  const clamped = Math.min(1, Math.max(-1, cosine));
  return (clamped + 1) / 2;
}
// ─── 规则文本构建 ─────────────────────────────────────────────────
/**
 * Flatten a rule document into the labelled, newline-separated text that is
 * embedded for semantic search and hashed to validate cached vectors.
 *
 * Falsy scalar fields and empty or non-array list fields are omitted, and
 * `content` is capped at 1000 characters. The exact output format matters:
 * changing it changes every rule's text hash, so all rules get re-embedded.
 *
 * @param rule Rule document. Both current fields and the legacy
 *   `description` / `descriptionI18n` fields are read.
 * @returns One `label: value` line per present field, joined with "\n".
 */
export function buildRuleEmbeddingText(rule: Record<string, unknown>): string {
  const lines: string[] = [];

  // Append `prefix + value` only when the value is truthy.
  const addScalar = (prefix: string, value: unknown): void => {
    if (value) lines.push(`${prefix}${value}`);
  };

  // Append `prefix + items joined by separator` only for non-empty arrays.
  const addList = (prefix: string, value: unknown, separator: string): void => {
    if (Array.isArray(value) && value.length > 0) {
      lines.push(`${prefix}${(value as string[]).join(separator)}`);
    }
  };

  // Title and classification.
  addScalar("标题: ", rule.title);
  addScalar("司法辖区: ", rule.jurisdiction);
  addScalar("资产类别: ", rule.assetClass);
  addScalar("规则类型: ", rule.ruleType);

  // Main body, truncated to keep the embedding input bounded.
  if (rule.content) {
    lines.push(`内容: ${String(rule.content).slice(0, 1000)}`);
  }

  // Legacy description fields.
  addScalar("描述: ", rule.description);
  if (rule.descriptionI18n) {
    const localized = rule.descriptionI18n as Record<string, string>;
    addScalar("中文描述: ", localized.zh);
    addScalar("English: ", localized.en);
  }

  // Requirement lists, legal basis and tags.
  addList("所有权要求: ", rule.ownershipRequirements, "; ");
  addList("交易规则: ", rule.tradingRequirements, "; ");
  addScalar("法律依据: ", rule.legalBasis);
  addList("标签: ", rule.tags, ", ");

  return lines.join("\n");
}
// ─── 简单哈希(用于检测内容变化)────────────────────────────────
/**
 * 32-bit polynomial string hash (h = h * 31 + charCode, wrapped to int32),
 * rendered as the lowercase hex of its absolute value.
 *
 * Used only to detect that a rule's embedding text has changed; it is not
 * collision-resistant. The output must stay stable because it is persisted
 * as `textHash` alongside cached vectors.
 */
function simpleHash(text: string): string {
  let acc = 0;
  for (let idx = 0; idx < text.length; idx += 1) {
    // Math.imul(acc, 31) + code, truncated with `| 0`, is congruent modulo 2^32
    // to `((acc << 5) - acc) + code` wrapped to int32, so hashes are unchanged.
    acc = (Math.imul(acc, 31) + text.charCodeAt(idx)) | 0;
  }
  return Math.abs(acc).toString(16);
}
// ─── MongoDB 向量缓存 ─────────────────────────────────────────────
/**
 * Load cached rule embeddings from the `rule_embeddings` collection into a
 * Map keyed by `ruleId`.
 *
 * This is best-effort. On a database error the failure is logged, and
 * whatever was loaded before the error (usually nothing) is returned.
 * Documents without a string `ruleId` or an array `vector` are skipped,
 * because downstream similarity code would crash on them.
 *
 * @param db MongoDB-like handle; only `collection(name).find(q).toArray()` is used.
 * @param query Optional filter passed to `find`; defaults to `{}` (all documents).
 * @returns Map from ruleId to its cached embedding record.
 */
export async function loadEmbeddingCache(
  db: { collection: (name: string) => { find: (q: object) => { toArray: () => Promise<unknown[]> } } },
  query: object = {}
): Promise<Map<string, EmbeddingVector>> {
  const cache = new Map<string, EmbeddingVector>();
  try {
    const docs = await db.collection("rule_embeddings").find(query).toArray();
    let invalid = 0;
    for (const doc of docs) {
      const candidate = doc as Partial<EmbeddingVector> | null | undefined;
      if (candidate && typeof candidate.ruleId === "string" && Array.isArray(candidate.vector)) {
        cache.set(candidate.ruleId, candidate as EmbeddingVector);
      } else {
        invalid++;
      }
    }
    console.log(`[Embedding] 已加载 ${cache.size} 条 embedding 缓存`);
    if (invalid > 0) {
      console.warn(`[Embedding] 跳过 ${invalid} 条无效缓存记录`);
    }
  } catch (err) {
    console.warn(`[Embedding] 加载缓存失败: ${(err as Error).message}`);
  }
  return cache;
}
/**
 * Upsert one rule's embedding into the `rule_embeddings` collection,
 * matching on `ruleId`.
 *
 * Never rejects: database errors are logged and swallowed, so a failed cache
 * write cannot break retrieval. `createdAt` is set to the current time on
 * every save, including updates.
 *
 * @param db MongoDB-like handle; only `collection(name).updateOne(...)` is used.
 * @param ruleId Rule identifier (upsert key).
 * @param vector Embedding values.
 * @param model Model name recorded with the vector.
 * @param textHash Hash of the text the vector was computed from.
 */
export async function saveEmbeddingCache(
  db: { collection: (name: string) => { updateOne: (filter: object, update: object, options: object) => Promise<unknown> } },
  ruleId: string,
  vector: number[],
  model: string,
  textHash: string
): Promise<void> {
  try {
    const collection = db.collection("rule_embeddings");
    const filter = { ruleId };
    const update = {
      $set: { ruleId, vector, model, textHash, createdAt: new Date() },
    };
    await collection.updateOne(filter, update, { upsert: true });
  } catch (error) {
    const reason = (error as Error).message;
    console.warn(`[Embedding] 保存缓存失败: ${reason}`);
  }
}
// ─── 主检索函数 ───────────────────────────────────────────────────
/**
 * Build one search result for `rule`, scoring `ruleVector` against the query.
 * The combined score is 0.7 x semantic + 0.3 x keyword; a missing keyword score counts as 0.
 */
function toSemanticSearchResult(
  rule: Record<string, unknown>,
  ruleId: string,
  ruleVector: number[],
  queryVector: number[],
  keywordScores: Map<string, number>
): SemanticSearchResult {
  const semanticScore = cosineSimilarity(queryVector, ruleVector);
  const keywordScore = keywordScores.get(ruleId) || 0;
  return {
    ruleId,
    jurisdiction: String(rule.jurisdiction || ""),
    assetClass: String(rule.assetClass || rule.category || ""),
    ruleType: String(rule.ruleType || ""),
    title: String(rule.title || rule.ruleName || ""),
    content: String(rule.content || rule.description || ""),
    ownershipRequirements: Array.isArray(rule.ownershipRequirements)
      ? (rule.ownershipRequirements as string[])
      : undefined,
    tradingRequirements: Array.isArray(rule.tradingRequirements)
      ? (rule.tradingRequirements as string[])
      : undefined,
    legalBasis: rule.legalBasis ? String(rule.legalBasis) : undefined,
    officialSource: rule.officialSource ? String(rule.officialSource) : undefined,
    semanticScore,
    keywordScore,
    combinedScore: semanticScore * 0.7 + keywordScore * 0.3,
  };
}

/**
 * Dense-embedding semantic search over candidate rules.
 *
 * Steps:
 * 1. Embed the query.
 * 2. Reuse cached rule vectors whose text hash and model both still match.
 * 3. Embed the remaining rules and persist the new vectors to `db` without
 *    waiting for the writes.
 * 4. Rank by 0.7 x semanticScore + 0.3 x keywordScore.
 *
 * Returns `[]` when the query cannot be embedded, so callers can fall back
 * to keyword retrieval. Rules whose embedding cannot be generated are
 * omitted from the results.
 *
 * @param query User query text.
 * @param rules Candidate rules, identified by `ruleId` with `_id` as fallback.
 * @param db MongoDB handle for the embedding cache, or `null` to skip caching.
 * @param topK Maximum number of results to return.
 * @param keywordScores Optional keyword pre-match scores keyed by ruleId.
 * @returns Up to `topK` results, highest `combinedScore` first.
 */
export async function semanticSearch(
  query: string,
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    }
  } | null,
  topK = 5,
  keywordScores: Map<string, number> = new Map()
): Promise<SemanticSearchResult[]> {
  // 1. Embed the query.
  const queryVector = await generateEmbedding(query);
  if (!queryVector) {
    console.warn("[Embedding] 查询向量生成失败,降级到关键词检索");
    return [];
  }
  // The model that produced queryVector. Cached vectors from any other model
  // (possibly with a different dimension) are not comparable with it.
  const model = getEmbeddingConfig()?.model ?? "text-embedding-3-small";

  // 2. Load the embedding cache.
  const cache = db ? await loadEmbeddingCache(db) : new Map<string, EmbeddingVector>();

  // 3. Score rules that have a fresh cached vector; queue the rest.
  const results: SemanticSearchResult[] = [];
  const toCompute: Array<{ rule: Record<string, unknown>; ruleId: string; text: string; textHash: string }> = [];
  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");
    const text = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(text);
    const cached = cache.get(ruleId);
    const isFresh =
      cached !== undefined &&
      cached.textHash === textHash &&
      cached.model === model &&
      Array.isArray(cached.vector);
    if (cached && isFresh) {
      results.push(toSemanticSearchResult(rule, ruleId, cached.vector, queryVector, keywordScores));
    } else {
      toCompute.push({ rule, ruleId, text, textHash });
    }
  }

  // 4. Embed rules that were missing from the cache or stale.
  if (toCompute.length > 0) {
    console.log(`[Embedding] 计算 ${toCompute.length} 条规则的向量...`);
    const vectors = await generateEmbeddingsBatch(toCompute.map((item) => item.text));
    for (let i = 0; i < toCompute.length; i++) {
      const { rule, ruleId, textHash } = toCompute[i];
      const vector = vectors[i];
      if (!vector) continue;
      if (db) {
        // Fire-and-forget: persisting the cache must not delay the search.
        // Record the model that actually produced the vector.
        void saveEmbeddingCache(db, ruleId, vector, model, textHash)
          .catch((err: unknown) => console.warn(`[Embedding] 缓存保存失败: ${(err as Error).message}`));
      }
      results.push(toSemanticSearchResult(rule, ruleId, vector, queryVector, keywordScores));
    }
  }

  // 5. Rank by combined score and keep the top K.
  results.sort((a, b) => b.combinedScore - a.combinedScore);
  return results.slice(0, topK);
}
// ─── 预计算所有规则 embedding ─────────────────────────────────────
/**
 * Precompute embeddings for every rule and upsert them into MongoDB.
 * Intended to run after the crawler finishes, or on a schedule.
 *
 * A rule is skipped when the cache already holds a vector that the currently
 * configured model produced from identical text (same hash). Otherwise the
 * rule is re-embedded. Rules are processed one at a time, with a 200 ms
 * pause after each embedding attempt as a simple rate limit.
 *
 * @param rules Rules to embed, identified by `ruleId` with `_id` as fallback.
 * @param db MongoDB handle holding the `rule_embeddings` cache.
 * @returns How many rules were embedded, failed or skipped. When the
 *   embedding API is not configured, every rule is reported as skipped.
 */
export async function precomputeAllEmbeddings(
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    }
  }
): Promise<{ success: number; failed: number; skipped: number }> {
  const config = getEmbeddingConfig();
  if (!config) {
    console.warn("[Embedding] API 未配置,跳过预计算");
    return { success: 0, failed: 0, skipped: rules.length };
  }
  const cache = await loadEmbeddingCache(db);
  let success = 0;
  let failed = 0;
  let skipped = 0;
  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");
    const ruleText = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(ruleText);
    // Skip only if text AND model are unchanged. A vector from another model
    // cannot be compared with new query vectors, so it must be rebuilt.
    const cached = cache.get(ruleId);
    if (
      cached &&
      cached.textHash === textHash &&
      cached.model === config.model &&
      Array.isArray(cached.vector)
    ) {
      skipped++;
      continue;
    }
    const vector = await generateEmbedding(ruleText);
    if (vector) {
      await saveEmbeddingCache(db, ruleId, vector, config.model, textHash);
      success++;
      console.log(`[Embedding] ✅ ${ruleId}`);
    } else {
      failed++;
      console.warn(`[Embedding] ❌ ${ruleId}`);
    }
    // Rate limit between API calls.
    await new Promise((resolve) => setTimeout(resolve, 200));
  }
  return { success, failed, skipped };
}