/**
 * NAC Dense Embedding 检索模块
 *
 * 使用 OpenAI 兼容的 text-embedding API 实现语义向量检索,
 * 支持中英文混合查询,并按余弦相似度排序。
 *
 * 架构:
 * 1. 查询向量化:将用户查询转为 embedding 向量
 * 2. 规则向量化:预计算规则 embedding 并缓存到 MongoDB
 * 3. 余弦相似度:计算查询与规则的语义相似度
 * 4. 混合排序:结合关键词分数和语义分数
 */
|
||
|
||
import https from "https";
|
||
import http from "http";
|
||
|
||
// ─── 类型定义 ─────────────────────────────────────────────────────
|
||
|
||
/**
 * A cached embedding for a single rule, as written to and read from the
 * `rule_embeddings` collection.
 */
export interface EmbeddingVector {
  /** Identifier of the rule this vector belongs to (the cache key). */
  ruleId: string;
  /** Name of the embedding model that produced `vector`. */
  model: string;
  /** The embedding values. */
  vector: number[];
  /** Hash of the rule's embedding text; a mismatch means the rule changed and needs re-embedding. */
  textHash: string;
  /** Time the entry was written. */
  createdAt: Date;
}
|
||
|
||
/**
 * One ranked rule produced by the semantic search, carrying the rule's
 * descriptive fields plus the scores used to rank it.
 */
export interface SemanticSearchResult {
  // ── Rule identity and metadata ──
  /** Rule identifier. */
  ruleId: string;
  /** Jurisdiction the rule applies to. */
  jurisdiction: string;
  /** Asset class the rule covers. */
  assetClass: string;
  /** Kind of rule. */
  ruleType: string;
  /** Rule title. */
  title: string;
  /** Rule body text. */
  content: string;

  // ── Optional details ──
  /** Ownership requirements, when the rule lists any. */
  ownershipRequirements?: string[];
  /** Trading requirements, when the rule lists any. */
  tradingRequirements?: string[];
  /** Legal basis of the rule, if known. */
  legalBasis?: string;
  /** Official source reference, if known. */
  officialSource?: string;

  // ── Scores ──
  /** Semantic similarity, rescaled to the range 0-1. */
  semanticScore: number;
  /** Keyword match score, expected in the range 0-1. */
  keywordScore: number;
  /** Weighted blend of the two scores (0-1 when both inputs are 0-1). */
  combinedScore: number;
}
|
||
|
||
// ─── Embedding API 配置 ───────────────────────────────────────────
|
||
|
||
/**
 * Reads the embedding API configuration from environment variables.
 *
 * - `NAC_AI_API_URL`: base URL of an OpenAI-compatible API. Trailing slashes are
 *   ignored and a trailing `/v1` segment is not duplicated, so both
 *   `https://host` and `https://host/v1` resolve to `https://host/v1/embeddings`.
 * - `NAC_AI_API_KEY`: bearer token for the `Authorization` header.
 * - `NAC_EMBEDDING_MODEL`: model name, defaulting to `text-embedding-3-small`.
 *
 * Values are trimmed, so a stray newline (e.g. from a secrets file) is harmless.
 *
 * @returns The resolved config, or `null` when the URL or key is missing/blank
 *   (callers then degrade to TF-IDF / keyword retrieval).
 */
const getEmbeddingConfig = () => {
  const apiUrl = (process.env.NAC_AI_API_URL ?? "").trim();
  const apiKey = (process.env.NAC_AI_API_KEY ?? "").trim();
  const model = process.env.NAC_EMBEDDING_MODEL?.trim() || "text-embedding-3-small";

  if (!apiUrl || !apiKey) {
    return null; // not configured: degrade to TF-IDF
  }

  // Strip every trailing slash, then append "/v1" only if the base doesn't already end with it.
  const base = apiUrl.replace(/\/+$/, "");
  const versionedBase = /\/v1$/.test(base) ? base : `${base}/v1`;

  return {
    url: `${versionedBase}/embeddings`,
    key: apiKey,
    model,
  };
};
|
||
|
||
// ─── HTTP 请求工具 ────────────────────────────────────────────────
|
||
|
||
/**
 * POSTs `body` as JSON to `url` and resolves with the parsed JSON response.
 *
 * The promise resolves with the parsed body regardless of HTTP status, so callers
 * can inspect API-level `error` objects. It rejects on request errors, response
 * stream errors, a 30 s socket timeout, or a response body that is not valid JSON.
 *
 * @param url     Absolute http:// or https:// URL.
 * @param headers Extra request headers (they can override the defaults).
 * @param body    Value serialized with `JSON.stringify`.
 */
async function httpPost(url: string, headers: Record<string, string>, body: object): Promise<unknown> {
  return new Promise((resolve, reject) => {
    const bodyStr = JSON.stringify(body);
    const parsedUrl = new URL(url);
    const isHttps = parsedUrl.protocol === "https:";
    const lib = isHttps ? https : http;

    const options = {
      hostname: parsedUrl.hostname,
      port: parsedUrl.port || (isHttps ? 443 : 80),
      path: parsedUrl.pathname + parsedUrl.search,
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Content-Length": Buffer.byteLength(bodyStr),
        ...headers,
      },
      timeout: 30000,
    };

    const req = lib.request(options, (res) => {
      // Collect raw Buffers and decode once at the end: decoding each chunk on its
      // own corrupts multi-byte UTF-8 characters that straddle a chunk boundary.
      const chunks: Buffer[] = [];
      res.on("data", (chunk: Buffer) => { chunks.push(chunk); });
      // Without this, a connection dropped mid-body would leave the promise pending.
      res.on("error", reject);
      res.on("end", () => {
        const data = Buffer.concat(chunks).toString("utf8");
        try {
          resolve(JSON.parse(data));
        } catch {
          reject(new Error(`JSON 解析失败: ${data.slice(0, 200)}`));
        }
      });
    });

    req.on("error", reject);
    req.on("timeout", () => {
      const timeoutError = new Error("Embedding API 请求超时");
      // Reject first so the caller sees the timeout rather than a follow-up socket error.
      reject(timeoutError);
      req.destroy(timeoutError);
    });

    req.write(bodyStr);
    req.end();
  });
}
|
||
|
||
// ─── Embedding 生成 ───────────────────────────────────────────────
|
||
|
||
/**
 * Generates an embedding vector for `text` using the configured embedding API.
 *
 * Never throws. It returns `null` on any failure so callers can degrade to TF-IDF /
 * keyword retrieval. Failures include:
 * - the API is not configured;
 * - the input is blank;
 * - a transport or parse error;
 * - an API-reported error;
 * - a response without a non-empty numeric `embedding`.
 *
 * @param text Text to embed; at most the first 4000 UTF-16 code units are sent.
 * @returns The embedding vector, or `null`.
 */
export async function generateEmbedding(text: string): Promise<number[] | null> {
  type EmbeddingApiResponse = {
    data?: Array<{ embedding?: unknown }>;
    error?: { message?: string };
  };

  const config = getEmbeddingConfig();
  if (!config) return null;

  // Blank input cannot produce a meaningful vector; skip the round trip.
  if (!text || !text.trim()) return null;

  // Truncate overly long text (embedding models commonly cap input around 8192 tokens).
  let truncatedText = text.slice(0, 4000);
  // If the cut landed inside a surrogate pair, drop the dangling high surrogate.
  const lastCode = truncatedText.charCodeAt(truncatedText.length - 1);
  if (text.length > truncatedText.length && lastCode >= 0xd800 && lastCode <= 0xdbff) {
    truncatedText = truncatedText.slice(0, -1);
  }

  try {
    const raw = await httpPost(
      config.url,
      { "Authorization": `Bearer ${config.key}` },
      {
        model: config.model,
        input: truncatedText,
        encoding_format: "float",
      }
    );

    if (raw === null || typeof raw !== "object") {
      console.warn("[Embedding] API 响应格式无效");
      return null;
    }
    const response = raw as EmbeddingApiResponse;

    if (response.error) {
      console.warn(`[Embedding] API 错误: ${response.error.message}`);
      return null;
    }

    const embedding = response.data?.[0]?.embedding;
    if (
      Array.isArray(embedding) &&
      embedding.length > 0 &&
      embedding.every((value) => typeof value === "number")
    ) {
      return embedding as number[];
    }

    console.warn("[Embedding] 响应中缺少有效的 embedding 向量");
    return null;
  } catch (err) {
    console.warn(`[Embedding] 生成失败: ${(err as Error).message}`);
    return null;
  }
}
|
||
|
||
/**
 * Generates embeddings for many texts.
 *
 * Up to `batchSize` requests run concurrently per batch, with a 500 ms pause
 * between batches as a simple rate limit.
 *
 * @param texts     Texts to embed.
 * @param batchSize Maximum concurrent requests per batch (default 20). Values below 1
 *                  or non-finite values are treated as 1; fractions are floored.
 * @returns One entry per input, in input order; `null` where embedding failed.
 */
export async function generateEmbeddingsBatch(
  texts: string[],
  batchSize = 20
): Promise<Array<number[] | null>> {
  // A zero, negative or NaN step would never advance the loop index (infinite loop).
  const step = Number.isFinite(batchSize) && batchSize >= 1 ? Math.floor(batchSize) : 1;
  const results: Array<number[] | null> = [];

  for (let i = 0; i < texts.length; i += step) {
    const batch = texts.slice(i, i + step);
    const batchResults = await Promise.all(batch.map((text) => generateEmbedding(text)));
    results.push(...batchResults);

    // Rate limit: wait 500 ms between batches, but not after the last one.
    if (i + step < texts.length) {
      await new Promise((resolve) => setTimeout(resolve, 500));
    }
  }

  return results;
}
|
||
|
||
// ─── 余弦相似度计算 ───────────────────────────────────────────────
|
||
|
||
/**
 * Cosine similarity of two vectors, linearly rescaled from [-1, 1] to [0, 1].
 *
 * Returns 0 when the vectors are empty, have different lengths, or either one
 * has zero magnitude.
 */
export function cosineSimilarity(vecA: number[], vecB: number[]): number {
  const size = vecA.length;
  if (size === 0 || size !== vecB.length) return 0;

  let dot = 0;
  let sumSqA = 0;
  let sumSqB = 0;
  for (const [idx, a] of vecA.entries()) {
    const b = vecB[idx];
    dot += a * b;
    sumSqA += a * a;
    sumSqB += b * b;
  }

  const magnitude = Math.sqrt(sumSqA) * Math.sqrt(sumSqB);
  // Map [-1, 1] onto [0, 1] so the score blends with 0-1 keyword scores.
  return magnitude === 0 ? 0 : (dot / magnitude + 1) / 2;
}
|
||
|
||
// ─── 规则文本构建 ─────────────────────────────────────────────────
|
||
|
||
/**
 * Flattens a rule document into labelled lines of text for embedding.
 *
 * Fields are emitted in a fixed order and skipped when falsy; array fields are
 * skipped when not a non-empty array. `content` is truncated to its first 1000
 * characters. Lines are joined with "\n".
 */
export function buildRuleEmbeddingText(rule: Record<string, unknown>): string {
  const lines: string[] = [];

  // Appends "label: value" when the value is truthy.
  const addField = (label: string, value: unknown): void => {
    if (value) lines.push(`${label}: ${value}`);
  };
  // Appends "label: a<sep>b..." when the value is a non-empty array.
  const addList = (label: string, value: unknown, separator: string): void => {
    if (Array.isArray(value) && value.length > 0) {
      lines.push(`${label}: ${(value as string[]).join(separator)}`);
    }
  };

  // Title and basic metadata
  addField("标题", rule.title);
  addField("司法辖区", rule.jurisdiction);
  addField("资产类别", rule.assetClass);
  addField("规则类型", rule.ruleType);

  // Main body, first 1000 characters only
  if (rule.content) lines.push(`内容: ${String(rule.content).slice(0, 1000)}`);

  // Description fields (legacy document format)
  addField("描述", rule.description);
  if (rule.descriptionI18n) {
    const localized = rule.descriptionI18n as Record<string, string>;
    addField("中文描述", localized.zh);
    addField("English", localized.en);
  }

  addList("所有权要求", rule.ownershipRequirements, "; ");
  addList("交易规则", rule.tradingRequirements, "; ");
  addField("法律依据", rule.legalBasis);
  addList("标签", rule.tags, ", ");

  return lines.join("\n");
}
|
||
|
||
// ─── 简单哈希(用于检测内容变化)────────────────────────────────
|
||
|
||
/**
 * 32-bit string hash in the Java `String.hashCode` style.
 *
 * Computes `h = h * 31 + charCode` over UTF-16 code units, wrapped to int32, and
 * returns the lowercase hex of its absolute value. It is only used to detect
 * changes in rule text and is not cryptographic.
 */
function simpleHash(text: string): string {
  let acc = 0;
  for (let idx = 0; idx < text.length; idx += 1) {
    // Math.imul keeps the multiply in 32-bit space; `| 0` wraps the sum back to int32.
    acc = (Math.imul(acc, 31) + text.charCodeAt(idx)) | 0;
  }
  return Math.abs(acc).toString(16);
}
|
||
|
||
// ─── MongoDB 向量缓存 ─────────────────────────────────────────────
|
||
|
||
/**
 * Loads every document of the `rule_embeddings` collection into a Map keyed by `ruleId`.
 *
 * Documents are cast, not validated. A later document with the same `ruleId`
 * replaces an earlier one. The load is best-effort: on failure the error is logged
 * and whatever was loaded before the failure (normally nothing) is returned.
 */
export async function loadEmbeddingCache(
  db: { collection: (name: string) => { find: (q: object) => { toArray: () => Promise<unknown[]> } } }
): Promise<Map<string, EmbeddingVector>> {
  const byRuleId = new Map<string, EmbeddingVector>();

  try {
    const collection = db.collection("rule_embeddings");
    const docs = (await collection.find({}).toArray()) as EmbeddingVector[];
    for (const doc of docs) {
      byRuleId.set(doc.ruleId, doc);
    }
    console.log(`[Embedding] 已加载 ${byRuleId.size} 条 embedding 缓存`);
  } catch (error) {
    console.warn(`[Embedding] 加载缓存失败: ${(error as Error).message}`);
  }

  return byRuleId;
}
|
||
|
||
/**
 * Upserts one rule's embedding into the `rule_embeddings` collection, keyed by `ruleId`.
 *
 * `createdAt` is set to the current time on every write. Failures are logged
 * and swallowed, so this never rejects.
 */
export async function saveEmbeddingCache(
  db: { collection: (name: string) => { updateOne: (filter: object, update: object, options: object) => Promise<unknown> } },
  ruleId: string,
  vector: number[],
  model: string,
  textHash: string
): Promise<void> {
  const filter = { ruleId };
  const update = {
    $set: { ruleId, vector, model, textHash, createdAt: new Date() },
  };

  try {
    await db.collection("rule_embeddings").updateOne(filter, update, { upsert: true });
  } catch (error) {
    console.warn(`[Embedding] 保存缓存失败: ${(error as Error).message}`);
  }
}
|
||
|
||
// ─── 主检索函数 ───────────────────────────────────────────────────
|
||
|
||
/**
 * Dense-embedding semantic search over candidate rules.
 *
 * A rule's vector comes from the `rule_embeddings` cache only when the cached entry
 * was built from identical rule text AND by the currently configured model. Otherwise
 * it is recomputed and written back to the cache in the background. Results are
 * ranked by `0.7 * semanticScore + 0.3 * keywordScore`.
 *
 * @param query         User query text.
 * @param rules         Candidate rule documents (identified by `ruleId`, falling back to `_id`).
 * @param db            MongoDB-like handle for the embedding cache, or `null` to disable caching.
 * @param topK          Maximum number of results to return.
 * @param keywordScores Optional keyword pre-match scores keyed by rule id (hybrid ranking).
 * @returns Up to `topK` results sorted by `combinedScore`, descending.
 *   - Returns `[]` when the query cannot be embedded (callers fall back to keyword search).
 *   - Rules whose embedding cannot be generated are omitted.
 */
export async function semanticSearch(
  query: string,
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    }
  } | null,
  topK = 5,
  keywordScores: Map<string, number> = new Map()
): Promise<SemanticSearchResult[]> {
  // Nothing could be returned: skip the query-embedding API call entirely.
  if (rules.length === 0 || topK <= 0) return [];

  // 1. Embed the query.
  const queryVector = await generateEmbedding(query);
  if (!queryVector) {
    console.warn("[Embedding] 查询向量生成失败,降级到关键词检索");
    return [];
  }

  // Model used by generateEmbedding (same config source). Cached vectors from a different
  // model live in a different vector space (possibly another dimension) and must not be reused.
  const activeModel = getEmbeddingConfig()?.model ?? "text-embedding-3-small";

  // 2. Load the embedding cache (disabled when no db is given).
  const cache = db ? await loadEmbeddingCache(db) : new Map<string, EmbeddingVector>();

  // Builds the scored result for one rule from its embedding vector.
  const toResult = (rule: Record<string, unknown>, ruleId: string, vector: number[]): SemanticSearchResult => {
    const semanticScore = cosineSimilarity(queryVector, vector);
    const keywordScore = keywordScores.get(ruleId) || 0;
    return {
      ruleId,
      jurisdiction: String(rule.jurisdiction || ""),
      assetClass: String(rule.assetClass || rule.category || ""),
      ruleType: String(rule.ruleType || ""),
      title: String(rule.title || rule.ruleName || ""),
      content: String(rule.content || rule.description || ""),
      ownershipRequirements: Array.isArray(rule.ownershipRequirements)
        ? (rule.ownershipRequirements as string[]) : undefined,
      tradingRequirements: Array.isArray(rule.tradingRequirements)
        ? (rule.tradingRequirements as string[]) : undefined,
      legalBasis: rule.legalBasis ? String(rule.legalBasis) : undefined,
      officialSource: rule.officialSource ? String(rule.officialSource) : undefined,
      semanticScore,
      keywordScore,
      // Hybrid ranking: 70% semantic, 30% keyword.
      combinedScore: semanticScore * 0.7 + keywordScore * 0.3,
    };
  };

  // 3. Score rules with a valid cached vector; queue the rest for embedding.
  const results: SemanticSearchResult[] = [];
  const toCompute: Array<{ rule: Record<string, unknown>; text: string; ruleId: string; textHash: string }> = [];

  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");
    const text = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(text);

    // Id-less rules would all share the "" cache key, so never read the cache for them.
    const cached = ruleId ? cache.get(ruleId) : undefined;
    if (
      cached &&
      cached.textHash === textHash &&
      cached.model === activeModel &&
      Array.isArray(cached.vector)
    ) {
      results.push(toResult(rule, ruleId, cached.vector));
    } else {
      toCompute.push({ rule, text, ruleId, textHash });
    }
  }

  // 4. Embed uncached or stale rules.
  if (toCompute.length > 0) {
    console.log(`[Embedding] 计算 ${toCompute.length} 条规则的向量...`);
    const vectors = await generateEmbeddingsBatch(toCompute.map((item) => item.text));

    toCompute.forEach(({ rule, ruleId, textHash }, i) => {
      const vector = vectors[i];
      if (!vector) return;

      // Persist in the background (not awaited) under the model that actually produced it.
      if (db && ruleId) {
        void saveEmbeddingCache(db, ruleId, vector, activeModel, textHash)
          .catch((err: unknown) => console.warn(`[Embedding] 缓存保存失败: ${(err as Error).message}`));
      }

      results.push(toResult(rule, ruleId, vector));
    });
  }

  // 5. Rank by combined score and keep the top K.
  results.sort((a, b) => b.combinedScore - a.combinedScore);
  return results.slice(0, topK);
}
|
||
|
||
// ─── 预计算所有规则 embedding ─────────────────────────────────────
|
||
|
||
/**
 * Precomputes embeddings for all rules and stores them in the `rule_embeddings` cache.
 *
 * Recommended to run after the crawler finishes, or on a schedule. A rule is skipped
 * when its cached entry has the same text hash AND was produced by the currently
 * configured model, so changing `NAC_EMBEDDING_MODEL` triggers a re-embed.
 * Consecutive API calls are spaced 200 ms apart.
 *
 * @returns Counts of rules in each outcome:
 *   - `success`: embedded and saved.
 *   - `failed`: not embedded, including rules with no `ruleId`/`_id`.
 *   - `skipped`: already up to date.
 *   When the API is not configured, every rule is reported as skipped.
 */
export async function precomputeAllEmbeddings(
  rules: Record<string, unknown>[],
  db: {
    collection: (name: string) => {
      find: (q: object) => { toArray: () => Promise<unknown[]> };
      updateOne: (filter: object, update: object, options: object) => Promise<unknown>;
    }
  }
): Promise<{ success: number; failed: number; skipped: number }> {
  const config = getEmbeddingConfig();
  if (!config) {
    console.warn("[Embedding] API 未配置,跳过预计算");
    return { success: 0, failed: 0, skipped: rules.length };
  }

  const cache = await loadEmbeddingCache(db);
  let success = 0;
  let failed = 0;
  let skipped = 0;
  let calledApi = false;

  for (const rule of rules) {
    const ruleId = String(rule.ruleId || rule._id || "");
    if (!ruleId) {
      // Stored under "", the vector would be overwritten by the next id-less rule.
      failed++;
      console.warn("[Embedding] ❌ 规则缺少 ruleId/_id,无法缓存");
      continue;
    }

    const ruleText = buildRuleEmbeddingText(rule);
    const textHash = simpleHash(ruleText);

    // Up to date only if both the rule text and the producing model are unchanged.
    const cached = cache.get(ruleId);
    if (cached && cached.textHash === textHash && cached.model === config.model) {
      skipped++;
      continue;
    }

    // Rate limit: 200 ms between consecutive API calls (no wait before the first or after the last).
    if (calledApi) {
      await new Promise((resolve) => setTimeout(resolve, 200));
    }
    calledApi = true;

    const vector = await generateEmbedding(ruleText);
    if (vector) {
      await saveEmbeddingCache(db, ruleId, vector, config.model, textHash);
      success++;
      console.log(`[Embedding] ✅ ${ruleId}`);
    } else {
      failed++;
      console.warn(`[Embedding] ❌ ${ruleId}`);
    }
  }

  return { success, failed, skipped };
}
|