想要的效果是:意义相近的句子的embedding夹角小,更近似。反之亦然。
给定句子,先用encoder模型(bert)做每一个token的embedding,然后取所有embedding的平均(average pooling)。也可以用max pooling或者用bert自带的CLS token,但工业界普遍用average pooling,效果最好。

训练:传统的bert并不能确保意义相近的句子的embedding夹角小,于是要单独训练。使用同一个transformer结构(embedding → encoder → average pooling → (optional) linear),处理两个训练样本句子,计算它们的余弦相似度,和Label的余弦相似度比较(这里余弦相似度是人标的一个-1到1的值),计算MSE。也可以用三个句子,两个相似,另一个不相似,计算triplet loss,使得相似句子的距离更近,不相似句子的距离更远。具体来说,给定anchor, positive和negative三个句子,计算$max(||anchor - positive|| - ||anchor - negative|| + \epsilon, 0)$。

MSE

triplet loss
传统方法:使用传统的knn,也就是暴力搜索,找top k余弦相似度最小的句子向量返回。缺点是时间O(N*D),N是可能上亿的向量数,D是可能上百的向量维度。

解决办法1:使用速度更快的近似查找。一个最开始的想法是navigable small worlds。受到“人与人之间的距离只有3.5个人”的思想启发,把所有数据库中的向量组织成一张图,每个向量是一个节点,链接最与之最相似的n个向量。在进行查找的时候,随机找m次初始节点,每次比较节点和邻居节点哪个和query向量最近似。跳转到最近似的节点。一直到无法跳转,就找到了局部最近似。重复m次后,返回最相似的k个节点。

解决方法2(grok使用的算法):受到skip linked list结构的启发,把图变成有层级的skip graph,做hierarchial navigable small worlds。每次从最上层随机找m次初始节点,先找本层最近似,然后顺着往下找,一直找到最底层的最近似。重复m次后,返回最相似的k个节点。

skip linked list

skip graph
图数据库的结构:有节点和边。节点有type(eg. Person)和键值对组成的properties(age: …, name: ….)。边有type(eg. Teach),键值对组成的properties,和从一个节点到另一个节点的方向。可以通过查找语句找到所有符合要求的节点、边、和属性。给定任何文件,可以提取自动/手动提取文件里面的信息做成一个图。

