结合 Spark 的 Aura 图分析

Open In Colab

此 Jupyter 笔记本托管在 Neo4j 图数据科学客户端 Github 仓库的此处

本笔记本展示了如何使用 graphdatascience Python 库在 Apache Spark 集群中创建、管理和使用 GDS 会话。

我们以自行车租赁图为例,展示如何将数据从 Spark 投影到 GDS 会话、运行算法,并最终将结果返回给 Spark。本笔记本侧重于与 Apache Spark 的交互,不会涵盖使用 GDS 会话的所有可能操作。有关更多详细信息,请参阅其他教程。

1. 前置条件

我们还需要安装 graphdatascience Python 库(版本 1.18 或更高版本)以及 pyspark

%pip install "graphdatascience>=1.18" python-dotenv "pyspark[sql]"
from dotenv import load_dotenv

# This allows to load required secrets from `.env` file in local directory
# This can include Aura API Credentials and Database Credentials.
# If file does not exist this is a noop.
load_dotenv("sessions.env")

1.1. 连接到 Spark 会话

要与 Spark 集群交互,我们需要首先实例化一个 Spark 会话。在此示例中,我们将使用本地 Spark 会话,它将在同一台机器上运行 Spark。使用远程 Spark 集群的方法类似。有关设置 pyspark 的更多信息,请访问 https://spark.apache.ac.cn/docs/latest/api/python/getting_started/

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").appName("GraphAnalytics").getOrCreate()

# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

2. Aura API 凭据

管理 GDS 会话的入口点是 GdsSessions 对象,该对象需要创建 Aura API 凭据

import os

from graphdatascience.session import AuraAPICredentials, GdsSessions

# you can also use AuraAPICredentials.from_env() to load credentials from environment variables
api_credentials = AuraAPICredentials(
    client_id=os.environ["CLIENT_ID"],
    client_secret=os.environ["CLIENT_SECRET"],
    # If your account is a member of several projects, you must also specify the project ID to use
    project_id=os.environ.get("PROJECT_ID", None),
)

sessions = GdsSessions(api_credentials=api_credentials)

3. 创建新会话

通过调用 sessions.get_or_create() 并传入以下参数来创建新会话:

  • 会话名称,允许您通过再次调用 get_or_create 重新连接到现有会话。

  • 会话内存大小。

  • 云区域位置。

  • 生存时间 (TTL),确保会话在设定的时间内未使用后自动删除,以避免产生额外费用。

有关参数的更多详细信息,请参阅 API 参考文档或手册。

from datetime import timedelta

from graphdatascience.session import CloudLocation, SessionMemory

# Create a GDS session!
gds = sessions.get_or_create(
    # we give it a representative name
    session_name="bike_trips",
    memory=SessionMemory.m_2GB,
    ttl=timedelta(minutes=30),
    cloud_location=CloudLocation("gcp", "europe-west1"),
)
# Verify the connectivity. Hints towards TLS or firewall issues if this fails directly after get_or_create
gds.verify_connectivity()

4. 添加数据集

下一步,我们将在 Spark 中设置数据集。在此示例中,我们将使用纽约自行车出行数据集 (https://www.kaggle.com/datasets/gabrielramos87/bike-trips)。自行车出行数据构成了一个图,其中节点代表自行车租赁站,关系代表自行车租赁的起点和终点。

import io
import os
import zipfile

import requests

download_path = "bike_trips_data"
if not os.path.exists(download_path):
    url = "https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips"

    response = requests.get(url)
    response.raise_for_status()

    # Unzip the content
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(download_path)

df = spark.read.csv(download_path, header=True, inferSchema=True)
df.createOrReplaceTempView("bike_trips")
df.limit(10).show()

5. 投影图

现在我们的数据集已在 Spark 会话中就绪,是时候将其投影到 GDS 会话中了。

我们首先需要访问 GDSArrowClient。该客户端允许我们直接与会话提供的 Arrow Flight 服务器进行通信。

我们的输入数据类似于三元组,其中每一行代表从源站点到目标站点的一条边。这使我们可以使用 Arrow 服务器的“从三元组导入图”功能,该功能需要遵循以下协议:

  1. 发送操作 v2/graph.project.fromTriplets。这将初始化导入过程,并允许我们指定图名称以及 undirected_relationship_types 等设置。它返回一个作业 ID,我们需要在后续步骤中引用该作业 ID。

  2. 将数据分批发送到 Arrow 服务器。

  3. 发送另一个名为 v2/graph.project.fromTriplets.done 的操作,以告知导入过程不会再有数据发送。这将触发 GDS 会话内的最终图创建。

  4. 等待导入过程达到 DONE 状态。

这里最复杂的一步是在每个 Spark 工作节点 (worker) 上运行实际的数据上传。我们将使用 mapInArrow 函数在每个 Spark 工作节点上运行自定义代码。每个工作节点将接收一定数量的 Arrow 记录批次,我们可以直接将其发送到 GDS 会话的 Arrow 服务器。

用户希望在等待导入作业完成的循环中添加 1 秒的延迟(sleep)。这需要导入 time 模块,并在单元格末尾的 while 循环中添加 time.sleep(1)

graph-analytics-serverless-spark.ipynb

import time

import pandas as pd
import pyarrow
from pyspark.sql import functions

graph_name = "bike_trips"

arrow_client = gds.arrow_client()

# 1. Start the import process
job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)


# Define a function that receives an arrow batch and uploads it to the GDS session
def upload_batch(iterator):
    for batch in iterator:
        arrow_client.upload_triplets(job_id, [batch])
        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({"batch_rows_imported": [len(batch)]}))


# Select the source target pairs from our source data
source_target_pairs = spark.sql("""
                                SELECT start_station_id AS sourceNode, end_station_id AS targetNode
                                FROM bike_trips
                                """)

# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.
uploaded_batches = source_target_pairs.mapInArrow(upload_batch, "batch_rows_imported long")

# Aggregate the batch sizes to receive the row count.
aggregated_batch_sizes = uploaded_batches.agg(functions.sum("batch_rows_imported").alias("rows_imported"))

# Show the result. This will trigger the computation and thus run the data upload.
aggregated_batch_sizes.show()

# 3. Finish the import process
arrow_client.triplet_load_done(job_id)

# 4. Wait for the import to finish
while not arrow_client.job_status(job_id).succeeded():
    time.sleep(1)

G = gds.v2.graph.get(graph_name)
G

6. 运行算法

我们可以使用标准的 GDS Python 客户端 API 在构建的图上运行算法。有关更多示例,请参阅其他教程。

print("Running PageRank ...")
pr_result = gds.v2.page_rank.mutate(G, mutate_property="pagerank")

7. 将计算结果返回给 Spark

计算完成后,我们可能希望在 Spark 中进一步使用结果。我们可以通过将数据批次流式传输到每个 Spark 工作节点,以类似于投影的方式执行此操作。由于我们需要输入 DataFrame 来触发 Spark 工作节点上的计算,因此检索数据稍微复杂一些。我们使用与集群中工作节点数量相等的数据范围作为驱动表。在工作节点上,我们将忽略输入,而是从 GDS 会话中流式传输计算数据。

# 1. Start the node property export on the GDS session
job_id = arrow_client.get_node_properties(G.name(), ["pagerank"])


# Define a function that receives data from the GDS Session and turns it into data batches
def retrieve_data(ignored):
    stream_data = arrow_client.stream_job(G.name(), job_id)
    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)
    for b in batches:
        yield b


# Create DataFrame with a single column and one row per worker
input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF("batch_id")
# 2. Stream the data from the GDS Session into the Spark workers
received_batches = input_partitions.mapInArrow(retrieve_data, "nodeId long, pagerank double")
# Optional: Repartition the data to make sure it is distributed equally
result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)

result.toPandas()

8. 清理

现在我们已经完成了分析,可以删除 GDS 会话并停止 Spark 会话。

删除 GDS 会话将释放与其关联的所有资源,并停止产生费用。

gds.delete()
spark.stop()