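Create embeddings with open source libraries

The Python library SentenceTransformers provides pre-trained models that generate embeddings for text and images, and lets you play around with embeddings without needing an account with OpenAI or other proprietary services.

This page assumes you have already imported the recommendations dataset and set up your environment. It shows how to generate and store embeddings for Movie nodes, based on their title and plot.

Embeddings are always generated outside of Neo4j, but stored in the Neo4j database.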

Set up the environment

As a final setup step, install the sentence-transformers package.

pip install sentence-transformers
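
Before running the example, it may help to confirm that the two packages used on this page can be imported. The following is a minimal sketch that relies only on the standard library; the helper name is_installed is hypothetical and not part of either package.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == '__main__':
    # Import names differ from the pip package names:
    # `sentence-transformers` is imported as `sentence_transformers`.
    for name in ('sentence_transformers', 'neo4j'):
        status = 'found' if is_installed(name) else 'missing'
        print(f'{name}: {status}')
```

If either module is reported as missing, install it with pip before continuing.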

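Create embeddings for movies

The example below fetches all Movie nodes from the database, generates an embedding from each movie's title and plot, and adds it to the node as an extra embedding property.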

from sentence_transformers import SentenceTransformer
import neo4j


URI = '<database-uri>'
AUTH = ('<username>', '<password>')
DB_NAME = '<database-name>'  # examples: 'recommendations-5.26', 'neo4j'


def main():
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:  # (1)
        driver.verify_connectivity()

        model = SentenceTransformer('all-MiniLM-L6-v2')  # vector size 384  (2)

        batch_size = 100
        batch_n = 1
        movies_with_embeddings = []
        with driver.session(database=DB_NAME) as session:
            # Fetch `Movie` nodes
            result = session.run('MATCH (m:Movie) RETURN m.plot AS plot, m.title AS title')
            for record in result:
                title = record.get('title')
                plot = record.get('plot')

                # Create embedding for title and plot
                if title is not None and plot is not None:
                    movies_with_embeddings.append({
                        'title': title,
                        'plot': plot,
                        'embedding': model.encode(f'''
                            Title: {title}\n
                            Plot: {plot}
                        '''),  # (3)
                    })

                # Import when a batch of movies has embeddings ready; flush buffer
                if len(movies_with_embeddings) == batch_size:  # (4)
                    import_batch(driver, movies_with_embeddings, batch_n)
                    movies_with_embeddings = []
                    batch_n += 1

            # Flush last batch (skipped if the buffer is empty)
            if movies_with_embeddings:
                import_batch(driver, movies_with_embeddings, batch_n)

        # Import complete, show counters
        records, _, _ = driver.execute_query('''
        MATCH (m:Movie WHERE m.embedding IS NOT NULL)
        RETURN count(*) AS countMoviesWithEmbeddings, size(m.embedding) AS embeddingSize
        ''', database_=DB_NAME)
        print(f"""
    Embeddings generated and attached to nodes.
    Movie nodes with embeddings: {records[0].get('countMoviesWithEmbeddings')}.
    Embedding size: {records[0].get('embeddingSize')}.
        """)


def import_batch(driver, nodes_with_embeddings, batch_n):
    # Add embeddings to Movie nodes  (5)
    driver.execute_query('''
    UNWIND $movies as movie
    MATCH (m:Movie {title: movie.title, plot: movie.plot})
    CALL db.create.setNodeVectorProperty(m, 'embedding', movie.embedding)
    ''', movies=nodes_with_embeddings, database_=DB_NAME)
    print(f'Processed batch {batch_n}.')


if __name__ == '__main__':
    main()

'''
Movie nodes with embeddings: 9083.
Embedding size: 384.
'''
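1 The driver object is the interface for interacting with your Neo4j instance. For more information, see Build applications with Neo4j and Python.
2 The model all-MiniLM-L6-v2 maps text to vectors of size 384 (i.e. lists of 384 numbers). Always use the same model to generate the embeddings for a dataset, because vectors produced by different models are not comparable: pick one and stick with it for the whole project.
3 The .encode() method generates an embedding for the given string (in this case, the title and plot together).
4 A number of embeddings are collected before the whole batch is submitted to the database. This avoids holding the embeddings for the entire dataset in memory and reduces the risk of timeouts (especially relevant for larger datasets).
5 The import query sets a new embedding property on each node m, with the embedding vector movie.embedding as its value. The Cypher procedure db.create.setNodeVectorProperty stores vector properties more efficiently than if they were stored as lists. To set a vector property on a relationship, use db.create.setRelationshipVectorProperty.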
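
The batching pattern described in callout 4 can be isolated from the database code. The sketch below is an illustration only, assuming a hypothetical helper named in_batches that is not part of the example above or of any library: it buffers items, yields a list each time the buffer reaches batch_size, and finally yields whatever is left.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def in_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield lists of at most `batch_size` items, including a final partial one."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == batch_size:
            yield buffer
            buffer = []  # start a fresh buffer for the next batch
    if buffer:
        # Flush the remainder; nothing is yielded if the buffer is empty
        yield buffer


if __name__ == '__main__':
    # Placeholder data standing in for movie records
    titles = [f'Movie {i}' for i in range(250)]
    for batch_n, batch in enumerate(in_batches(titles, 100), start=1):
        print(f'Batch {batch_n}: {len(batch)} items')
```

With 250 placeholder items and a batch size of 100, the loop sees batches of 100, 100, and 50 items. In the movie example, each yielded batch would correspond to one call to import_batch.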

In Enterprise Edition, instead of calling db.create.setNodeVectorProperty, you can pass the embedding as the driver's Vector type and set it as a property with the Cypher clause SET.

from sentence_transformers import SentenceTransformer
import neo4j
from neo4j.vector import Vector


URI = '<database-uri>'
AUTH = ('<username>', '<password>')
DB_NAME = '<database-name>'  # examples: 'recommendations-5.26', 'neo4j'


def main():
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        driver.verify_connectivity()

        model = SentenceTransformer('all-MiniLM-L6-v2')  # vector size 384

        batch_size = 100
        batch_n = 1
        movies_with_embeddings = []
        with driver.session(database=DB_NAME) as session:
            # Fetch `Movie` nodes
            result = session.run('MATCH (m:Movie) RETURN m.plot AS plot, m.title AS title')
            for record in result:
                title = record.get('title')
                plot = record.get('plot')

                # Create embedding for title and plot
                if title is not None and plot is not None:
                    movies_with_embeddings.append({
                        'title': title,
                        'plot': plot,
                        'embedding': Vector(model.encode(f'''
                            Title: {title}\n
                            Plot: {plot}
                        ''')),
                    })

                # Import when a batch of movies has embeddings ready; flush buffer
                if len(movies_with_embeddings) == batch_size:
                    import_batch(driver, movies_with_embeddings, batch_n)
                    movies_with_embeddings = []
                    batch_n += 1

            # Flush last batch (skipped if the buffer is empty)
            if movies_with_embeddings:
                import_batch(driver, movies_with_embeddings, batch_n)

        # Import complete, show counters
        records, _, _ = driver.execute_query('''
        MATCH (m:Movie WHERE m.embedding IS NOT NULL)
        RETURN count(*) AS countMoviesWithEmbeddings, size(m.embedding) AS embeddingSize
        ''', database_=DB_NAME)
        print(f"""
    Embeddings generated and attached to nodes.
    Movie nodes with embeddings: {records[0].get('countMoviesWithEmbeddings')}.
    Embedding size: {records[0].get('embeddingSize')}.
        """)


def import_batch(driver, nodes_with_embeddings, batch_n):
    # Add embeddings to Movie nodes
    driver.execute_query('''
    UNWIND $movies as movie
    MATCH (m:Movie {title: movie.title, plot: movie.plot})
    SET m.embedding = movie.embedding
    ''', movies=nodes_with_embeddings, database_=DB_NAME)
    print(f'Processed batch {batch_n}.')


if __name__ == '__main__':
    main()

'''
Movie nodes with embeddings: 9083.
Embedding size: 384.
'''

Once the embeddings are in the database, you can use them to compare how similar one movie is to another.
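
To give an intuition for what comparing two embeddings means, the sketch below computes cosine similarity, a common metric for this purpose, in plain Python. It is an illustration under stated assumptions rather than the way the database performs the comparison: the function name is hypothetical, and the three-dimensional vectors are made-up stand-ins for real 384-dimensional embeddings.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors: 1 = same direction, -1 = opposite."""
    if len(a) != len(b):
        raise ValueError('vectors must have the same length')
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError('cosine similarity is undefined for a zero vector')
    return dot / (norm_a * norm_b)


if __name__ == '__main__':
    # Made-up toy vectors, not real movie embeddings
    movie_a = [0.9, 0.1, 0.0]
    movie_b = [0.8, 0.2, 0.1]
    movie_c = [0.0, 0.1, 0.9]
    print('a vs b:', cosine_similarity(movie_a, movie_b))
    print('a vs c:', cosine_similarity(movie_a, movie_c))
```

Vectors pointing in similar directions (movie_a and movie_b here) score closer to 1 than vectors pointing in different directions (movie_a and movie_c).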