/**
 * NAC dense-embedding retrieval module.
 *
 * Implements semantic vector retrieval on top of an OpenAI-compatible
 * text-embedding API. Supports mixed Chinese/English queries, ranked by
 * cosine similarity.
 *
 * Architecture:
 *   1. Query vectorization: turn the user query into an embedding vector.
 *   2. Rule vectorization: precompute rule embeddings and cache them in MongoDB.
 *   3. Cosine similarity: score the semantic similarity between query and rule.
 *   4. Hybrid ranking: combine the keyword score with the semantic score.
 */

import https from "https";
import http from "http";

// ─── Type definitions ─────────────────────────────────────────────

export interface EmbeddingVector {
  ruleId: string;
  vector: number[];
  model: string;
  createdAt: Date;
  textHash: string; // used to detect changes in the rule content
}

export interface SemanticSearchResult {
  ruleId: string;
  jurisdiction: string;
  assetClass: string;
  ruleType: string;
  title: string;
  content: string;
  ownershipRequirements?: string[];
  tradingRequirements?: string[];
  legalBasis?: string;
  officialSource?: string;
  semanticScore: number; // semantic similarity, 0-1
  keywordScore: number; // keyword match score, 0-1
  combinedScore: number; // combined score, 0-1
}

// ─── Embedding API configuration ──────────────────────────────────

/**
 * Reads the embedding API settings from the environment
 * (NAC_AI_API_URL, NAC_AI_API_KEY, NAC_EMBEDDING_MODEL).
 *
 * @returns The endpoint URL, API key and model name, or `null` when the URL
 *          or key is missing so that callers can fall back to TF-IDF.
 */
const getEmbeddingConfig = () => {
  const apiUrl = process.env.NAC_AI_API_URL || "";
  const apiKey = process.env.NAC_AI_API_KEY || "";
  const model = process.env.NAC_EMBEDDING_MODEL || "text-embedding-3-small";

  if (!apiUrl || !apiKey) {
    return null; // fall back to TF-IDF
  }

  return {
    // Drop a single trailing slash so the joined path has no "//".
    url: `${apiUrl.replace(/\/$/, "")}/v1/embeddings`,
    key: apiKey,
    model,
  };
};

// ─── HTTP request helper ──────────────────────────────────────────

/**
 * Sends a JSON POST request and resolves with the parsed JSON response body.
 *
 * The HTTP status code is not inspected: any response whose body parses as
 * JSON is resolved, and the caller decides what an error payload means.
 *
 * @param url     Absolute http:// or https:// URL.
 * @param headers Extra request headers; they override the default
 *                Content-Type / Content-Length when the same key is given.
 * @param body    Value serialized with JSON.stringify as the request body.
 * @returns The parsed response body.
 * @throws (rejects) on an invalid URL, a network error, a 30 s timeout,
 *         or a response body that is not valid JSON.
 */
async function httpPost(
  url: string,
  headers: Record<string, string>,
  body: object
): Promise<unknown> {
  return new Promise((resolve, reject) => {
    const bodyStr = JSON.stringify(body);
    const parsedUrl = new URL(url);
    const isHttps = parsedUrl.protocol === "https:";
    const lib = isHttps ? https : http;

    const options = {
      hostname: parsedUrl.hostname,
      port: parsedUrl.port || (isHttps ? 443 : 80),
      path: parsedUrl.pathname + parsedUrl.search,
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Content-Length": Buffer.byteLength(bodyStr),
        ...headers,
      },
      timeout: 30000,
    };

    const req = lib.request(options, (res) => {
      // Decode as UTF-8 at the stream level so that a multi-byte character
      // split across two chunks is not corrupted by per-chunk conversion.
      res.setEncoding("utf8");

      let data = "";
      res.on("data", (chunk: string) => {
        data += chunk;
      });
      // A connection dropped mid-body never emits "end"; without this
      // listener the promise would stay pending.
      res.on("error", reject);
      res.on("end", () => {
        try {
          resolve(JSON.parse(data));
        } catch {
          reject(new Error(`JSON 解析失败: ${data.slice(0, 200)}`));
        }
      });
    });

    req.on("error", reject);
    req.on("timeout", () => {
      // Destroy with the timeout error itself so the "error" event that
      // follows carries the same reason instead of a generic socket hang-up.
      const timeoutError = new Error("Embedding API 请求超时");
      req.destroy(timeoutError);
      reject(timeoutError);
    });

    req.write(bodyStr);
    req.end();
  });
}

// ─── Embedding generation ─────────────────────────────────────────

/**
 * Generates the embedding vector for a piece of text.
 *
 * @param text Input text; only the first 4000 characters are sent.
 * @returns The embedding vector, or `null` when the API is not configured,
 *          the API reports an error, the response carries no embedding, or
 *          the request fails (callers fall back to TF-IDF). Never rejects.
 */
export async function generateEmbedding(text: string): Promise<number[] | null> {
  const config = getEmbeddingConfig();
  if (!config) return null;

  // Truncate overly long text (embedding models typically limit input to 8192 tokens).
  const truncatedText = text.slice(0, 4000);

  try {
    const response = (await httpPost(
      config.url,
      { "Authorization": `Bearer ${config.key}` },
      {
        model: config.model,
        input: truncatedText,
        encoding_format: "float",
      }
    )) as { data?: Array<{ embedding: number[] }>; error?: { message: string } };

    if (response.error) {
      console.warn(`[Embedding] API 错误: ${response.error.message}`);
      return null;
    }

    if (response.data && response.data[0]?.embedding) {
      return response.data[0].embedding;
    }

    return null;
  } catch (err) {
    console.warn(`[Embedding] 生成失败: ${(err as Error).message}`);
    return null;
  }
}

/**
 * Generates embeddings for many texts, with simple rate limiting.
 *
 * Texts are processed in consecutive batches; requests inside a batch run
 * concurrently and a 500 ms pause is inserted between batches.
 *
 * @param texts     Texts to embed.
 * @param batchSize Number of concurrent requests per batch (default 20).
 *                  Values below 1 (or NaN) are treated as 1, because a
 *                  non-advancing step would otherwise loop forever.
 * @returns One entry per input text, in input order; `null` where the
 *          embedding could not be generated.
 */
export async function generateEmbeddingsBatch(
  texts: string[],
  batchSize = 20
): Promise<Array<number[] | null>> {
  const step = batchSize >= 1 ? Math.floor(batchSize) : 1;
  const results: Array<number[] | null> = [];

  for (let i = 0; i < texts.length; i += step) {
    const batch = texts.slice(i, i + step);
    const batchResults = await Promise.all(
      batch.map((text) => generateEmbedding(text))
    );
    results.push(...batchResults);

    // Rate limit: wait 500 ms between batches (not after the last one).
    if (i + step < texts.length) {
      await new Promise((resolve) => setTimeout(resolve, 500));
    }
  }

  return results;
}
// ─── Cosine similarity ────────────────────────────────────────────

/**
 * Computes the cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].
 *
 * @param vecA First vector.
 * @param vecB Second vector.
 * @returns `(cos + 1) / 2`, or 0 when the vectors are empty, differ in
 *          length, or either one has zero norm.
 */
export function cosineSimilarity(vecA: number[], vecB: number[]): number {
  if (vecA.length !== vecB.length || vecA.length === 0) return 0;

  let dotProduct = 0;
  let normA = 0;
  let normB = 0;

  for (let i = 0; i < vecA.length; i++) {
    dotProduct += vecA[i] * vecB[i];
    normA += vecA[i] * vecA[i];
    normB += vecB[i] * vecB[i];
  }

  const denominator = Math.sqrt(normA) * Math.sqrt(normB);
  if (denominator === 0) return 0;

  // Floating-point rounding can push the ratio marginally outside [-1, 1];
  // clamp so the rescaled score always stays within [0, 1].
  const cosine = Math.min(1, Math.max(-1, dotProduct / denominator));

  // Cosine similarity ranges over [-1, 1]; rescale to [0, 1].
  return (cosine + 1) / 2;
}

// ─── Rule text construction ───────────────────────────────────────

/**
 * Turns a rule object into the text that is sent to the embedding model.
 * Every populated key field contributes one labelled line; the lines are
 * joined with "\n".
 *
 * @param rule Rule document; absent or falsy fields are skipped.
 * @returns The labelled text, or "" when no known field is populated.
 */
export function buildRuleEmbeddingText(rule: Record<string, unknown>): string {
  const parts: string[] = [];

  // Title and basic information.
  if (rule.title) parts.push(`标题: ${rule.title}`);
  if (rule.jurisdiction) parts.push(`司法辖区: ${rule.jurisdiction}`);
  if (rule.assetClass) parts.push(`资产类别: ${rule.assetClass}`);
  if (rule.ruleType) parts.push(`规则类型: ${rule.ruleType}`);

  // Main content (first 1000 characters only).
  if (rule.content) parts.push(`内容: ${String(rule.content).slice(0, 1000)}`);

  // Description (kept for compatibility with the old rule format).
  if (rule.description) parts.push(`描述: ${rule.description}`);
  if (rule.descriptionI18n) {
    const i18n = rule.descriptionI18n as Record<string, string>;
    if (i18n.zh) parts.push(`中文描述: ${i18n.zh}`);
    if (i18n.en) parts.push(`English: ${i18n.en}`);
  }

  // Ownership requirements.
  if (Array.isArray(rule.ownershipRequirements) && rule.ownershipRequirements.length > 0) {
    parts.push(`所有权要求: ${(rule.ownershipRequirements as string[]).join("; ")}`);
  }

  // Trading requirements.
  if (Array.isArray(rule.tradingRequirements) && rule.tradingRequirements.length > 0) {
    parts.push(`交易规则: ${(rule.tradingRequirements as string[]).join("; ")}`);
  }

  // Legal basis.
  if (rule.legalBasis) parts.push(`法律依据: ${rule.legalBasis}`);

  // Tags.
  if (Array.isArray(rule.tags) && rule.tags.length > 0) {
    parts.push(`标签: ${(rule.tags as string[]).join(", ")}`);
  }

  return parts.join("\n");
}

// ─── Simple hash (used to detect content changes) ─────────────────

/**
 * 32-bit multiplicative string hash (hash * 31 + charCode), returned as the
 * hexadecimal form of its absolute value.
 *
 * The output is persisted as `textHash` in the embedding cache, so the
 * algorithm must stay bit-for-bit stable; changing it would invalidate
 * every cached vector.
 */
function simpleHash(text: string): string {
  let hash = 0;
  for (let i = 0; i < text.length; i++) {
    const char = text.charCodeAt(i);
    hash = ((hash << 5) - hash) + char;
    hash = hash & hash; // coerce to a 32-bit integer
  }
  return Math.abs(hash).toString(16);
}

// ─── MongoDB vector cache ─────────────────────────────────────────

/**
 * Loads every cached rule embedding from the `rule_embeddings` collection.
 *
 * @param db Object exposing a MongoDB-style `collection(name).find(q).toArray()`.
 * @returns Map from ruleId to cached entry; an empty map when loading fails
 *          (the failure is logged, never thrown).
 */
export async function loadEmbeddingCache(
  db: { collection: (name: string) => { find: (q: object) => { toArray: () => Promise<unknown[]> } } }
): Promise<Map<string, EmbeddingVector>> {
  const cache = new Map<string, EmbeddingVector>();

  try {
    const vectors = (await db.collection("rule_embeddings").find({}).toArray()) as EmbeddingVector[];
    for (const v of vectors) {
      cache.set(v.ruleId, v);
    }
    console.log(`[Embedding] 已加载 ${cache.size} 条 embedding 缓存`);
  } catch (err) {
    console.warn(`[Embedding] 加载缓存失败: ${(err as Error).message}`);
  }

  return cache;
}

/**
 * Upserts one embedding into the `rule_embeddings` collection, keyed by ruleId.
 * Failures are logged and swallowed: the cache is best-effort.
 *
 * @param db       Object exposing a MongoDB-style `collection(name).updateOne(...)`.
 * @param ruleId   Cache key.
 * @param vector   Embedding vector to store.
 * @param model    Name of the model that produced the vector.
 * @param textHash Hash of the text the vector was computed from.
 */
export async function saveEmbeddingCache(
  db: { collection: (name: string) => { updateOne: (filter: object, update: object, options: object) => Promise<unknown> } },
  ruleId: string,
  vector: number[],
  model: string,
  textHash: string
): Promise<void> {
  try {
    await db.collection("rule_embeddings").updateOne(
      { ruleId },
      {
        $set: {
          ruleId,
          vector,
          model,
          textHash,
          createdAt: new Date(),
        },
      },
      { upsert: true }
    );
  } catch (err) {
    console.warn(`[Embedding] 保存缓存失败: ${(err as Error).message}`);
  }
}

// ─── Main search function ─────────────────────────────────────────

/** Weight of the semantic score in the combined score. */
const SEMANTIC_WEIGHT = 0.7;
/** Weight of the keyword score in the combined score. */
const KEYWORD_WEIGHT = 0.3;

/** Builds one search result from a rule document and its two scores. */
function buildSearchResult(
  rule: Record<string, unknown>,
  ruleId: string,
  semanticScore: number,
  keywordScore: number
): SemanticSearchResult {
  return {
    ruleId,
    jurisdiction: String(rule.jurisdiction || ""),
    assetClass: String(rule.assetClass || rule.category || ""),
    ruleType: String(rule.ruleType || ""),
    title: String(rule.title || rule.ruleName || ""),
    content: String(rule.content || rule.description || ""),
    ownershipRequirements: Array.isArray(rule.ownershipRequirements)
      ? (rule.ownershipRequirements as string[])
      : undefined,
    tradingRequirements: Array.isArray(rule.tradingRequirements)
      ? (rule.tradingRequirements as string[])
      : undefined,
    legalBasis: rule.legalBasis ? String(rule.legalBasis) : undefined,
    officialSource: rule.officialSource ? String(rule.officialSource) : undefined,
    semanticScore,
    keywordScore,
    combinedScore: semanticScore * SEMANTIC_WEIGHT + keywordScore * KEYWORD_WEIGHT,
  };
}

/**
 * Dense-embedding semantic search.
 *
 * @param query         User query text.
 * @param rules         Candidate rules.
 * @param db            MongoDB-style handle used for the embedding cache, or
 *                      `null` to run without a cache.
 * @param topK          Number of best-scoring rules to return (default 5);
 *                      negative values are treated as 0.
 * @param keywordScores Optional keyword pre-match scores keyed by ruleId,
 *                      used for hybrid ranking (missing entries count as 0).
 * @returns Up to `topK` results sorted by descending combined score
 *          (0.7 × semantic + 0.3 × keyword). Returns `[]` when the query
 *          embedding cannot be generated so the caller can fall back to
 *          keyword search. Rules whose embedding cannot be generated are
 *          omitted.
 */
export async function semanticSearch(
  query: string,
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    };
  } | null,
  topK = 5,
  keywordScores: Map<string, number> = new Map()
): Promise<SemanticSearchResult[]> {
  // 1. Embed the query.
  const queryVector = await generateEmbedding(query);
  if (!queryVector) {
    console.warn("[Embedding] 查询向量生成失败,降级到关键词检索");
    return [];
  }

  // Cached vectors are only comparable with the query vector when they come
  // from the same model, so the configured model is part of cache validity
  // and is what gets recorded for newly computed vectors.
  const model = getEmbeddingConfig()?.model ?? "text-embedding-3-small";

  // 2. Load the embedding cache.
  const cache = db ? await loadEmbeddingCache(db) : new Map<string, EmbeddingVector>();

  // 3. Score rules that have a valid cached vector; queue the rest.
  const results: SemanticSearchResult[] = [];
  const toCompute: Array<{
    rule: Record<string, unknown>;
    text: string;
    textHash: string;
    ruleId: string;
  }> = [];

  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");
    const text = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(text);

    // Rules without an id would all share the key "", so they never use the cache.
    const cached = ruleId ? cache.get(ruleId) : undefined;

    if (cached && cached.textHash === textHash && cached.model === model) {
      results.push(
        buildSearchResult(
          rule,
          ruleId,
          cosineSimilarity(queryVector, cached.vector),
          keywordScores.get(ruleId) || 0
        )
      );
    } else {
      toCompute.push({ rule, text, textHash, ruleId });
    }
  }

  // 4. Embed the rules that had no usable cache entry.
  if (toCompute.length > 0) {
    console.log(`[Embedding] 计算 ${toCompute.length} 条规则的向量...`);

    const vectors = await generateEmbeddingsBatch(toCompute.map((r) => r.text));

    for (let i = 0; i < toCompute.length; i++) {
      const { rule, textHash, ruleId } = toCompute[i];
      const vector = vectors[i];
      if (!vector) continue;

      // Save to the cache without blocking the search; saveEmbeddingCache
      // logs and swallows its own failures.
      if (db && ruleId) {
        void saveEmbeddingCache(db, ruleId, vector, model, textHash);
      }

      results.push(
        buildSearchResult(
          rule,
          ruleId,
          cosineSimilarity(queryVector, vector),
          keywordScores.get(ruleId) || 0
        )
      );
    }
  }

  // 5. Sort by combined score and return the top K.
  results.sort((a, b) => b.combinedScore - a.combinedScore);
  return results.slice(0, Math.max(0, topK));
}

// ─── Precompute all rule embeddings ───────────────────────────────

/**
 * Precomputes the embeddings of all rules and stores them in MongoDB.
 * Intended to run after a crawl completes, or on a schedule.
 *
 * Rules are processed one at a time with a 200 ms pause after each API call.
 *
 * @param rules Rules to embed.
 * @param db    MongoDB-style handle used for the embedding cache.
 * @returns Counters: `success` (newly embedded and saved), `failed`
 *          (embedding could not be generated, or the rule has no id to cache
 *          under) and `skipped` (cache already up to date for the current
 *          text and model, or the API is not configured).
 */
export async function precomputeAllEmbeddings(
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    };
  }
): Promise<{ success: number; failed: number; skipped: number }> {
  const config = getEmbeddingConfig();
  if (!config) {
    console.warn("[Embedding] API 未配置,跳过预计算");
    return { success: 0, failed: 0, skipped: rules.length };
  }

  const cache = await loadEmbeddingCache(db);
  let success = 0;
  let failed = 0;
  let skipped = 0;

  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");

    // Without an id there is no cache key; storing under "" would make all
    // such rules overwrite each other.
    if (!ruleId) {
      failed++;
      console.warn("[Embedding] ❌ 规则缺少 ruleId,已跳过");
      continue;
    }

    const ruleText = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(ruleText);

    // Up to date only when both the text and the model are unchanged.
    const cached = cache.get(ruleId);
    if (cached && cached.textHash === textHash && cached.model === config.model) {
      skipped++;
      continue;
    }

    const vector = await generateEmbedding(ruleText);
    if (vector) {
      await saveEmbeddingCache(db, ruleId, vector, config.model, textHash);
      success++;
      console.log(`[Embedding] ✅ ${ruleId}`);
    } else {
      failed++;
      console.warn(`[Embedding] ❌ ${ruleId}`);
    }

    // Rate limit.
    await new Promise((resolve) => setTimeout(resolve, 200));
  }

  return { success, failed, skipped };
}