在数据量为1mb时,ray.get()速度不如ZeroMQ的问题,以及ray.get()在1.13.0版本下优于2.3.0

【Ray版本和类库】ray-1.13.0, zmq-4.3.4, python-3.8.10,
【操作系统和内核版本】Ubuntu 20.04.3 LTS, Linux version 5.10.0-20-amd64
【问题复现】ray在不同节点之间传输不同大小的数据的速度对比

ray.get()的测试代码:

import time
from collections import defaultdict
import ray 
from multiprocessing import Event, Process
import numpy as np

class EventProcess(Process):
    def __init__(self, begin_event: Event, stop_event: Event):
        super().__init__()
        self._begin_event = begin_event
        self._stop_event = stop_event
 
generate_large_uint_object = lambda size: np.random.randint(low=255, size=size * 1024 * 1024, dtype=np.uint8)


class TestActor(object):
    _name = None
 
    def __init__(self, id):
        self.id = id
 
    def name(self):
        return "{}_{}".format(self._name, self.id)

if __name__ == "__main__":
    test_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    test_round = 100
 
    ray.init(address="auto")
 
 
    @ray.remote(num_cpus=1, resources={"RemoteWorker": 1})
    class RemoteWorker(TestActor):
        def __init__(self, id=0):
            super().__init__(id)
            self._name = "Worker"
            self._weight = None
            self._cur_round = 0
            self._size_map_time = defaultdict(lambda: list())
 
        def ready(self):
            pass
 
        def get_weight(self, weight, send_time, size):
            self._weight = weight
            self._cur_round += 1
            recv_time = time.time()
            time_consume = recv_time - send_time
            self._size_map_time[size].append(time_consume)
 
        def get_final_stat(self):
            for size, time_consume_list in self._size_map_time.items():
                if len(time_consume_list) > 0:
                    avg_time_consume = sum(time_consume_list) / len(time_consume_list)
                    print("{}Mb broatcast to remote worker cost:{} seconds".format(
                        size, avg_time_consume))
 
 
    @ray.remote(num_cpus=1, resources={"Router": 1})
    class Router(TestActor):
        def __init__(self, id=0):
            super().__init__(id)
            self._name = "Router"
            self.remote_workers = [RemoteWorker.remote(i) for i in range(1)]
            # wait for worker ready
            ray.get([worker.ready.remote() for worker in self.remote_workers])
 
        def ready(self):
            pass
 
        def send_to_remote_worker(self, size=1):
            for i in range(test_round):
                weight = generate_large_uint_object(size)
                send_time = time.time()
                ray.get([worker.get_weight.remote(weight, send_time, size) for worker in self.remote_workers])
 
        def get_final_stat(self):
            ray.get([worker.get_final_stat.remote() for worker in self.remote_workers])
 
 
    router = Router.remote()
    # wait for router ready
    ray.get(router.ready.remote())
 
    for size in test_sizes:
        ray.get(router.send_to_remote_worker.remote(size))
        time.sleep(1)
 
    ray.get(router.get_final_stat.remote())
 
    time.sleep(5)
    ray.shutdown()

启动脚本

// Head Node Start command
ray start --head --port 6379 --num-cpus=3 --resources '{"Router":1}' --object-store-memory 3221225472
// Worker Node Start command
ray start --address 'ray-master:6379' --num-cpus=2 --resources '{"RemoteWorker":2}' --object-store-memory 2147483648

这个是zmq和ray在不同条件下对比结果。在数据量为1mb时ray.get()在remote模式下的耗时是zmq的remote模式下的两倍。

补充zmq的remote代码
recv:

import pickle
import time
from collections import defaultdict
 
import blosc
import zmq
 
if __name__ == "__main__":
    test_size = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    test_round = 100
 
    host = "0.0.0.0"
    port = 6005
    addr = "tcp://{}:{}".format(host, port)
    use_blosc = False
 
    size_map_time = defaultdict(lambda: list())
    ctx = zmq.Context()
    socket = ctx.socket(zmq.SUB)
    socket.bind(addr)
    socket.setsockopt(zmq.SUBSCRIBE, b'test')
 
    for size in test_size:
        for _ in range(test_round):
            _, data = socket.recv_multipart()
            if use_blosc:
                data = blosc.decompress(data)
            weight, send_time, size = pickle.loads(data)
            recv_time = time.time()
            time_consume = recv_time - send_time
            size_map_time[size].append(time_consume)
 
    for size, total_time_consume in size_map_time.items():
        avg_time_consume = sum(total_time_consume) / len(total_time_consume)
        print("{}Mb Test Round:{} Avg Time Cousume:{}".format(size, test_round, avg_time_consume))
 
    socket.close()
    ctx.term()

send:

import pickle
import time
 
import blosc
import zmq
import numpy as np
generate_large_uint_object = lambda size: np.random.randint(low=255, size=size * 1024 * 1024, dtype=np.uint8)
if __name__ == "__main__":
    test_size = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    test_round = 100
 
    host = "ybq-rlease3test"
    port = 6005
    addr = "tcp://{}:{}".format(host, port)
    use_blosc = False
 
    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUB)
    socket.bind(addr)
 
    # wait for receiver connect
    input("wait for recv connect")
    for size in test_size:
        for _ in range(test_round):
            weight = generate_large_uint_object(size)
            send_time = time.time()
            data = pickle.dumps((weight, send_time, size), pickle.HIGHEST_PROTOCOL)
            if use_blosc:
                data = blosc.compress(data)
            socket.send_multipart([b'test', data])
        time.sleep(5)
 
    time.sleep(10)
    socket.close()
    ctx.term()

之前怀疑是ray的版本太老导致的,又补充了一个ray2.3.0版本的实验,发现除了大于等于64mb的文件传输,其他比ray1.13.0还要慢

很详细的测试:+1:

ray vs zmq的场景,看起来只有1M是慢的?
ray 1.13.0 vs 2.3.0看测试数据差距也不太大,这种不确定好不好分析原因。

@Peter 首先补充一下,ray 的数据传输和正常的 mq 是有不同的,ray 的 worker 间通讯主要是两部分:

  1. 小数据(默认是100kb以内,可以通过环境变量 RAY_max_direct_call_object_size 来调整,单位 bytest):这个是 rpc 直传
  2. 大数据:这个是在发送端把数据 put 到 object store 中,然后把数据凭证(object ref)通过 rpc 发送到对端,对端再 get 出来。

所以在你的测试里面,1、2、4、8 这类比较小的数据不如 mq 是可以理解的,我们的数据传输过程中的控制流会比 mq 的要长很多,但是数据大了之后,raylet 进程里面的并行传输之类的优化就有用了。

据我所知,2.3.0 在 object store 上有一些新的优化和 bug fix,在大一点的数据上更加优秀倒是 make sense 的,我建议你要是使用的话应该吧 RAY_max_direct_call_object_size 调整到一个合理的值。

另外,看你的测试表我有一点不是很懂,就是这个 local 下面的 diff node 是啥意思?

这个逻辑是讲不通的。
理论上越小的数据,差距应该越小。 因为大家都是通信开销,不应该出现有复杂的控制逻辑。

Local下面的Same Node指发送方和接收方在同一个“节点"内,共享同一个对象池,Diff Node指发送方和接收方不在同一个"节点”内,对象池不能共享

这里应该是数据越小(在大于 inline 阈值的前提下)差距越大,ray 这里需要进 object store,控制流就会变长,因为要和 raylet 通信,但是 mq 没有这段 “和 raylet 进程通信” 的额外开销,数据越小,通信开销占的比例就越大

1B 2B为什么会进plamsa?

我在使用了 环境变量 RAY_max_direct_call_object_size 来调整直传大小之后发现并没有特别大的改善,在1Mb数据场景下基本还是mq的两倍耗时。
这是我的启动命令

RAY_max_direct_call_object_size=1048576 python /root/ray/remote/remote_ray.py

以下是我分别尝试了不同大小的RAY_max_direct_call_object_size条件下,在1Mb,2Mb的数据传输速度比较

这个object在序列化之后我们会给他一个meta data,而且np.ndarray也是有metadata的,你给个2MB吧

@jovany-wang 论坛设置了不能连续恢复两次。。。,在这里说一下,他这里最小的都是1MB,他测试的message全都大于 inline 阈值

2Mb 4Mb 都给过了 表中ray下的四列对应了RAY_max_direct_call_object_size分别为100b、1Mb、2Mb 4Mb。测的时候发现ray的波动比较大,会有上下20%的浮动,但是都是zmq耗时的2倍左右。

我看了一下代码,得在启动 node 节点的时候也要加上环境变量 RAY_max_direct_call_object_size,worker 也是通过 rpc 来获取配置的

1 Like

感谢回复,在修改了启动命令后,1、2Mb的传输速度和mq差不多了。

1 Like

所以理论上现在是可以去除zeromq的依赖了对吗?

理论上是的