/**
 * NAC知识引擎 — 语义检索模块 (v14)
 *
 * 实现自然语言查询的语义相似度检索:
 * 1. 调用AI Embedding API将文本向量化
 * 2. 计算查询向量与规则向量的余弦相似度
 * 3. 返回按相关度排序的结果(含匹配理由)
 *
 * 降级策略:AI不可用或无满足阈值的语义命中时,依次退回 MongoDB 全文检索($text)与关键词正则匹配
 */

import { getMongoDb, COLLECTIONS } from "./mongodb";

// ─── Type definitions ──────────────────────────────────────────────

/**
 * A single compliance rule returned by `semanticSearch`, together with its
 * relevance score and a short explanation of why it matched.
 */
export interface SemanticSearchResult {
  /** Stringified document id of the rule. */
  _id: string;

  jurisdiction: string;
  assetType: string;

  /** Rule name, localized to the requested language when a translation exists. */
  ruleName: string;
  /** Rule description, localized to the requested language when a translation exists. */
  description: string;
  /** Every stored translation of the rule name, keyed by language code. */
  ruleNameI18n?: Record<string, string>;
  /** Every stored translation of the rule description, keyed by language code. */
  descriptionI18n?: Record<string, string>;

  status: string;
  required: boolean;
  tags: string[];

  /** Relevance score, nominally in the range 0–1 (higher is more relevant). */
  score: number;
  /** Human-readable match explanation (Chinese text). */
  matchReason: string;
  /** Retrieval stage that produced this result. */
  searchMethod: "semantic" | "fulltext" | "keyword";
}

// ─── AI embedding call ─────────────────────────────────────────────

/**
 * Requests an embedding vector for `text` from an OpenAI-compatible
 * `/embeddings` endpoint.
 *
 * The endpoint is derived from `NAC_AI_API_URL`: a trailing `/chat/completions`
 * and any trailing slashes are stripped, then `/embeddings` is appended. The
 * model comes from `NAC_AI_EMBEDDING_MODEL` (default "text-embedding-v1").
 * Input is truncated to 2000 UTF-16 code units without splitting a surrogate
 * pair.
 *
 * Best-effort by design: returns `null` instead of throwing when the API is not
 * configured, the text is blank, the request fails or times out (10 s), or the
 * response lacks a non-empty numeric `data[0].embedding`. Callers treat `null`
 * as the signal to fall back to non-semantic search.
 *
 * @param text - Text to embed.
 * @returns The embedding vector, or `null` when none could be obtained.
 */
async function getEmbedding(text: string): Promise<number[] | null> {
  const apiUrl = process.env.NAC_AI_API_URL;
  const apiKey = process.env.NAC_AI_API_KEY;
  if (!apiUrl || !apiKey) return null;

  // Blank input has no meaningful embedding; skip the network round-trip.
  if (text.trim() === "") return null;

  // Accept either a base URL or a full chat-completions URL (OpenAI-compatible layout).
  const baseUrl = apiUrl.replace(/\/chat\/completions\/?$/, "").replace(/\/+$/, "");
  const embeddingUrl = `${baseUrl}/embeddings`;

  // Cap the request size. If the cut lands between the two halves of a surrogate
  // pair (e.g. an emoji), drop the dangling high surrogate rather than send it.
  let input = text.slice(0, 2000);
  const lastCode = input.charCodeAt(input.length - 1);
  if (lastCode >= 0xd800 && lastCode <= 0xdbff) input = input.slice(0, -1);

  try {
    const resp = await fetch(embeddingUrl, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${apiKey}`,
      },
      body: JSON.stringify({
        model: process.env.NAC_AI_EMBEDDING_MODEL || "text-embedding-v1",
        input,
      }),
      signal: AbortSignal.timeout(10000),
    });
    if (!resp.ok) return null;

    // Validate the payload rather than trusting its shape. A missing, empty or
    // non-numeric embedding (e.g. a base64 string) counts as "no embedding".
    const payload = (await resp.json()) as { data?: Array<{ embedding?: unknown }> } | null;
    const vector = payload?.data?.[0]?.embedding;
    if (!Array.isArray(vector) || vector.length === 0) return null;
    return vector.every((v) => typeof v === "number" && Number.isFinite(v)) ? vector : null;
  } catch {
    // Network error, timeout, or malformed JSON: deliberately reported as "no embedding".
    return null;
  }
}

// ─── Cosine similarity ─────────────────────────────────────────────

/**
 * Cosine similarity of two vectors: dot(a, b) / (|a| * |b|), in [-1, 1].
 *
 * Returns 0 when the vectors are empty, have different lengths, or when
 * either one has zero magnitude.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  // Vectors of different dimensionality (or empty ones) are not comparable.
  if (a.length === 0 || a.length !== b.length) return 0;

  let dotProduct = 0;
  let sumSqA = 0;
  let sumSqB = 0;
  for (const [i, x] of a.entries()) {
    const y = b[i];
    dotProduct += x * y;
    sumSqA += x * x;
    sumSqB += y * y;
  }

  const magnitude = Math.sqrt(sumSqA) * Math.sqrt(sumSqB);
  if (magnitude === 0) return 0;
  return dotProduct / magnitude;
}

// ─── Match-reason text ─────────────────────────────────────────────

/**
 * Builds the human-readable (Chinese) explanation attached to a search hit.
 *
 * @param query - The user query (not currently used in the message).
 * @param rule - Rule document; only `jurisdiction` and `assetType` are read.
 * @param score - Similarity score; only consulted for the "semantic" method,
 *   where it selects the wording (>= 0.85, >= 0.70, otherwise) and is shown
 *   as a whole-number percentage.
 * @param method - "semantic", "fulltext", or anything else (keyword wording).
 * @returns The explanation string.
 */
function generateMatchReason(
  query: string,
  rule: any,
  score: number,
  method: string
): string {
  const { jurisdiction, assetType } = rule;

  switch (method) {
    case "semantic": {
      const pct = (score * 100).toFixed(0);
      if (score >= 0.85) {
        return `与查询高度匹配(${jurisdiction}辖区${assetType}类资产,语义相似度${pct}%)`;
      }
      if (score >= 0.70) {
        return `与查询较为相关(${jurisdiction}辖区,语义相似度${pct}%)`;
      }
      return `与查询有一定关联(${jurisdiction}辖区,相似度${pct}%)`;
    }
    case "fulltext":
      return `关键词全文匹配(${jurisdiction}辖区${assetType}类资产)`;
    default:
      return `关键词匹配(${jurisdiction}辖区)`;
  }
}

// ─── Keyword extraction (fallback path) ────────────────────────────

/** Words too common to be useful search keywords; compared against lower-cased tokens. */
const KEYWORD_STOP_WORDS: ReadonlySet<string> = new Set(["的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一", "一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这", "那", "什么", "如何", "哪些", "需要", "可以", "应该", "必须", "the", "a", "an", "is", "are", "in", "of", "for", "to", "and", "or", "with", "that", "this"]);

/** Whitespace plus full-width and ASCII punctuation that separate keywords. */
const KEYWORD_SEPARATORS = /[\s,。?!、;:""''()【】\[\],.?!;:()\s]+/;

/**
 * Splits a free-text query into at most `maxKeywords` distinct keywords for the
 * regex fallback search.
 *
 * Tokens shorter than two characters and stop words are dropped. Stop-word
 * filtering and de-duplication are case-insensitive, because the keywords are
 * matched with case-insensitive regexes. The first spelling seen is kept.
 *
 * NOTE(review): CJK text without punctuation or spaces is not segmented, so a
 * whole clause becomes a single keyword.
 *
 * @param query - Raw user query.
 * @param maxKeywords - Maximum number of keywords to return (default 6).
 * @returns Keywords in order of first appearance.
 */
function extractKeywords(query: string, maxKeywords = 6): string[] {
  const seen = new Set<string>();
  const keywords: string[] = [];

  for (const token of query.split(KEYWORD_SEPARATORS)) {
    if (keywords.length >= maxKeywords) break;
    if (token.length < 2) continue;

    // Fold case so "The"/"THE" hit the stop list and "Token"/"token" count once.
    const folded = token.toLowerCase();
    if (KEYWORD_STOP_WORDS.has(folded) || seen.has(folded)) continue;

    seen.add(folded);
    keywords.push(token);
  }

  return keywords;
}

// ─── Main semantic search entry point ──────────────────────────────

/**
 * Searches active compliance rules for a natural-language query. It tries three
 * strategies in order and returns the first that produces results:
 *
 *  1. Semantic: embed the query and rank rules by cosine similarity against each
 *     rule's cached `_embedding`. Missing rule embeddings are computed (and cached
 *     in the background) only when the candidate set has at most 50 rules.
 *  2. Full-text: MongoDB `$text` search ranked by `textScore`. Any error (e.g. no
 *     text index) skips to the next stage.
 *  3. Keyword: case-insensitive regex match of extracted keywords against the
 *     name, description, localized name/description and tags.
 *
 * @param query - Natural-language query.
 * @param options.jurisdiction - Restrict results to this jurisdiction.
 * @param options.assetType - Restrict results to this asset type.
 * @param options.limit - Maximum results; missing or < 1 falls back to 10.
 * @param options.lang - Language code for localized name/description (default "zh").
 * @param options.minScore - Minimum cosine similarity for semantic hits
 *   (default 0.45; an explicit 0 is honoured).
 * @returns The results, which strategy produced them, whether a query embedding
 *   was obtained, and that strategy's hit count (semantic: rules at or above
 *   `minScore` before `limit` is applied; other stages: rows returned).
 */
export async function semanticSearch(
  query: string,
  options: {
    jurisdiction?: string;
    assetType?: string;
    limit?: number;
    lang?: string;
    minScore?: number;
  } = {}
): Promise<{
  results: SemanticSearchResult[];
  searchMethod: "semantic" | "fulltext" | "keyword";
  queryEmbeddingGenerated: boolean;
  totalCandidates: number;
}> {
  const db = await getMongoDb();
  if (!db) return { results: [], searchMethod: "keyword", queryEmbeddingGenerated: false, totalCandidates: 0 };

  // MongoDB treats limit(0) as "no limit" and negative limits specially, so only
  // a positive count is accepted; anything else falls back to 10.
  const requestedLimit = options.limit;
  const limit = requestedLimit !== undefined && requestedLimit >= 1 ? Math.floor(requestedLimit) : 10;
  // `??` rather than `||` so that an explicit minScore of 0 is honoured.
  const minScore = options.minScore ?? 0.45;
  const lang = options.lang || "zh";
  // Largest candidate set for which missing rule embeddings are computed on the fly.
  const maxLiveEmbeddings = 50;

  const ruleCollection = db.collection(COLLECTIONS.COMPLIANCE_RULES);

  // Shared by every stage: active rules only, optionally narrowed further.
  const baseFilter: Record<string, unknown> = { status: "active" };
  if (options.jurisdiction) baseFilter.jurisdiction = options.jurisdiction;
  if (options.assetType) baseFilter.assetType = options.assetType;

  // Maps a rule document to a result row; `score` is rounded to three decimals.
  const toResult = (
    rule: any,
    score: number,
    matchReason: string,
    searchMethod: SemanticSearchResult["searchMethod"]
  ): SemanticSearchResult => ({
    _id: rule._id.toString(),
    jurisdiction: rule.jurisdiction,
    assetType: rule.assetType,
    ruleName: rule.ruleNameI18n?.[lang] || rule.ruleName,
    description: rule.descriptionI18n?.[lang] || rule.description,
    ruleNameI18n: rule.ruleNameI18n,
    descriptionI18n: rule.descriptionI18n,
    status: rule.status,
    required: rule.required,
    tags: rule.tags || [],
    score: Math.round(score * 1000) / 1000,
    matchReason,
    searchMethod,
  });

  // ── Stage 1: semantic vector search ──────────────────────────────
  const queryEmbedding = await getEmbedding(query);
  const queryEmbeddingGenerated = queryEmbedding !== null;

  if (queryEmbedding) {
    const candidates = (await ruleCollection.find(baseFilter).toArray()) as any[];
    const scored: Array<{ rule: any; score: number }> = [];

    for (const rule of candidates) {
      let ruleEmbedding: number[] | null = rule._embedding || null;

      // Only fill in missing embeddings for small candidate sets: each one is a
      // sequential network round-trip made while the caller waits.
      if (!ruleEmbedding && candidates.length <= maxLiveEmbeddings) {
        const ruleText = [
          rule.ruleNameI18n?.[lang] || rule.ruleName,
          rule.descriptionI18n?.[lang] || rule.description,
          rule.jurisdiction,
          rule.assetType,
          ...(rule.tags || []),
        ].join(" ");

        ruleEmbedding = await getEmbedding(ruleText);
        if (ruleEmbedding) {
          // Cache the vector without blocking the response; write failures are ignored.
          void ruleCollection
            .updateOne(
              { _id: rule._id },
              { $set: { _embedding: ruleEmbedding, _embeddingUpdatedAt: new Date() } }
            )
            .catch(() => {});
        }
      }

      if (ruleEmbedding) {
        const score = cosineSimilarity(queryEmbedding, ruleEmbedding);
        if (score >= minScore) scored.push({ rule, score });
      }
    }

    if (scored.length > 0) {
      scored.sort((a, b) => b.score - a.score);
      const results = scored
        .slice(0, limit)
        .map(({ rule, score }) =>
          toResult(rule, score, generateMatchReason(query, rule, score, "semantic"), "semantic")
        );
      return { results, searchMethod: "semantic", queryEmbeddingGenerated, totalCandidates: scored.length };
    }
  }

  // ── Stage 2: MongoDB full-text search ────────────────────────────
  try {
    const ftResults = (await ruleCollection
      .find({ ...baseFilter, $text: { $search: query } }, { projection: { score: { $meta: "textScore" } } })
      .sort({ score: { $meta: "textScore" } })
      .limit(limit)
      .toArray()) as any[];

    if (ftResults.length > 0) {
      // Normalize against the best hit; full-text scores are capped at 0.8 so
      // they rank below strong semantic matches.
      const maxScore = ftResults[0].score || 1;
      const results = ftResults.map((rule: any) => {
        const relative = rule.score / maxScore;
        return toResult(rule, relative * 0.80, generateMatchReason(query, rule, relative, "fulltext"), "fulltext");
      });
      return { results, searchMethod: "fulltext", queryEmbeddingGenerated, totalCandidates: ftResults.length };
    }
  } catch {
    // Full-text search unavailable (e.g. no text index): fall through to keywords.
  }

  // ── Stage 3: keyword regex match ─────────────────────────────────
  const keywords = extractKeywords(query);
  if (keywords.length === 0) {
    return { results: [], searchMethod: "keyword", queryEmbeddingGenerated, totalCandidates: 0 };
  }

  const regexFilter = {
    ...baseFilter,
    $or: keywords.flatMap((kw) => {
      // Keywords come from user input: escape regex metacharacters.
      const re = new RegExp(kw.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"), "i");
      return [
        { ruleName: re },
        { description: re },
        { [`ruleNameI18n.${lang}`]: re },
        { [`descriptionI18n.${lang}`]: re },
        { tags: re },
      ];
    }),
  };

  const kwResults = (await ruleCollection.find(regexFilter).limit(limit).toArray()) as any[];

  // Keyword hits carry no real relevance score: assign a descending rank-based
  // score from 0.65, floored at 0 so long result lists never go negative.
  const results = kwResults.map((rule: any, idx: number) =>
    toResult(
      rule,
      Math.max(0, 0.65 - idx * 0.03),
      generateMatchReason(query, rule, 0.65, "keyword"),
      "keyword"
    )
  );

  return { results, searchMethod: "keyword", queryEmbeddingGenerated, totalCandidates: kwResults.length };
}

// ─── Background job: precompute and cache rule embeddings ──────────

/**
 * Computes and stores `_embedding` (plus `_embeddingUpdatedAt`) for every rule
 * that does not have one yet. Rules of any status are included.
 *
 * The embedded text (localized name, localized description, jurisdiction, asset
 * type, tags) matches the text `semanticSearch` builds when it embeds a rule on
 * the fly, so cached and live vectors are comparable.
 *
 * NOTE(review): the stored vector does not record `lang`, and `semanticSearch`
 * reuses it for queries in any language.
 *
 * @param lang - Language code for the localized name/description (default "zh").
 * @param delayMs - Pause between API calls to avoid rate limiting (default 200 ms).
 * @returns `processed`: rules whose embedding was stored; `failed`: rules whose
 *   embedding could not be obtained or whose database write failed.
 */
export async function precomputeEmbeddings(
  lang = "zh",
  delayMs = 200
): Promise<{ processed: number; failed: number }> {
  const db = await getMongoDb();
  if (!db) return { processed: 0, failed: 0 };

  const ruleCollection = db.collection(COLLECTIONS.COMPLIANCE_RULES);
  const pending = (await ruleCollection.find({ _embedding: { $exists: false } }).toArray()) as any[];

  let processed = 0;
  let failed = 0;

  for (const [index, rule] of pending.entries()) {
    const text = [
      rule.ruleNameI18n?.[lang] || rule.ruleName,
      rule.descriptionI18n?.[lang] || rule.description,
      rule.jurisdiction,
      rule.assetType,
      ...(rule.tags || []),
    ].join(" ");

    const embedding = await getEmbedding(text);
    if (!embedding) {
      failed++;
    } else {
      try {
        await ruleCollection.updateOne(
          { _id: rule._id },
          { $set: { _embedding: embedding, _embeddingUpdatedAt: new Date() } }
        );
        processed++;
      } catch {
        // A single failed write should not abort the whole batch.
        failed++;
      }
    }

    // Throttle between API calls; there is nothing to wait for after the last rule.
    if (index < pending.length - 1) {
      await new Promise<void>((resolve) => setTimeout(resolve, delayMs));
    }
  }

  return { processed, failed };
}