diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 59ab27a9c39989e347678a4d156c6b5be24feaec..ef1e78fb0eb6fae303450ce97cb921ae5b9becba 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -17,16 +17,11 @@ the class for Worker import os import socket -import warnings from dataclasses import dataclass import ray -from verl.utils.device import ( - get_torch_device, - get_visible_devices_keyword, - is_npu_available, -) +from verl.utils.device import get_torch_device, get_visible_devices_keyword from .decorator import Dispatch, Execute, register @@ -62,13 +57,6 @@ class WorkerHelper: return sock.getsockname()[1] def get_availale_master_addr_port(self): - warnings.warn( - "This function is deprecated due to typo in name; Please use `get_available_master_addr_port` instead", - stacklevel=2, - ) - return self.get_available_master_addr_port() - - def get_available_master_addr_port(self): return self._get_node_ip().strip("[]"), str(self._get_free_port()) @@ -83,52 +71,56 @@ class Worker(WorkerHelper): fused_worker_attr_name = "fused_worker_dict" - def _register_dispatch_collect_info(self, mesh_name: str, dp_rank: int, is_collect: bool): - """Register the dp_rank for a given mesh name. This function is meant to be called by the worker + def __new__(cls, *args, **kwargs): + """Create a new Worker instance with proper initialization based on environment settings.""" + instance = super().__new__(cls) - Args: - mesh_name (str): - Name of the mesh to register dp_rank for. - dp_rank (int): - dp_rank to register for the given mesh name. - is_collect (bool): - Whether the dp_rank is used for collect. - """ - if mesh_name in self.__dispatch_dp_rank or mesh_name in self.__collect_dp_rank: - raise ValueError(f"mesh_name {mesh_name} has been registered") - self.__dispatch_dp_rank[mesh_name] = dp_rank - self.__collect_dp_rank[mesh_name] = is_collect + # note that here we use int to distinguish + disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) + if disable_worker_init: + return instance - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def _query_dispatch_info(self, mesh_name: str): - """Query the dispatch info for a given mesh name. + rank = os.environ.get("RANK", None) + worker_group_prefix = os.environ.get("WG_PREFIX", None) - Args: - mesh_name (str): - Name of the mesh to query dispatch info for. + # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init + if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: + instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) - Returns: - int: - The dp_rank for the given mesh name. - """ - assert mesh_name in self.__dispatch_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}" - # note that each rank store its own dp_rank - return self.__dispatch_dp_rank[mesh_name] + return instance - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def _query_collect_info(self, mesh_name: str): - """Query the collect info for a given mesh name. + def _configure_before_init(self, register_center_name: str, rank: int): + """Configure worker settings before initialization. Args: - mesh_name (str): - Name of the mesh to query collect info for. - - Returns: - bool: - Whether the dp_rank is used for collect. + register_center_name (str): + Name of the register center Ray actor for worker coordination + rank (int): + Rank of the worker in the distributed setup """ - assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}" - return self.__collect_dp_rank[mesh_name] + assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" + + if rank == 0: + master_addr, master_port = self.get_availale_master_addr_port() + rank_zero_info = { + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, + } + + if os.getenv("WG_BACKEND", None) == "ray": + from verl.single_controller.base.register_center.ray import \ + create_worker_group_register_center + + self.register_center = create_worker_group_register_center( + name=register_center_name, info=rank_zero_info + ) + + os.environ.update(rank_zero_info) + else: + self.register_center = ray.get_actor(register_center_name) + + # set worker info for node affinity scheduling + ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) @classmethod def env_keys(cls): @@ -154,6 +146,8 @@ class Worker(WorkerHelper): # it is executed remotely import os + import torch + self._setup_env_cuda_visible_devices() world_size = int(os.environ["WORLD_SIZE"]) @@ -165,7 +159,9 @@ class Worker(WorkerHelper): master_port = os.environ["MASTER_PORT"] local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) + local_rank = int(os.getenv("RAY_LOCAL_RANK", "0")) + + print(f"world_size: {world_size}, rank: {rank}, local_world_size: {local_world_size}, local_rank: {local_rank}") store = { "_world_size": world_size, @@ -180,9 +176,9 @@ class Worker(WorkerHelper): self._configure_with_store(store=store) + torch.cuda.set_device(local_rank) + self.fused_worker_dict = {} - self.__dispatch_dp_rank = {} - self.__collect_dp_rank = {} def get_fused_worker_by_name(self, worker_name: str): """Get a fused worker by its name. @@ -199,9 +195,11 @@ class Worker(WorkerHelper): is_ray_noset_visible_devices = ray_noset_visible_devices() # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES`` - rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES", None) + rocr_val = None hip_val = os.environ.get("HIP_VISIBLE_DEVICES", None) - cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES", None) + cuda_val = None + + print(is_ray_noset_visible_devices, rocr_val, hip_val, cuda_val) if hip_val: # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency. # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES @@ -215,8 +213,8 @@ class Worker(WorkerHelper): ) else: cuda_val = val - os.environ["CUDA_VISIBLE_DEVICES"] = val - # os.environ["HIP_VISIBLE_DEVICES"] = val + # os.environ["CUDA_VISIBLE_DEVICES"] = val + os.environ["HIP_VISIBLE_DEVICES"] = val if rocr_val: # You must take care if both HIP/CUDA and ROCR env vars are set as they have @@ -240,8 +238,7 @@ class Worker(WorkerHelper): # environment variable for each actor, unless # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, # so we need to set local rank when the flag is set. - device_name = "NPU" if is_npu_available else "GPU" - local_rank = ray.get_runtime_context().get_accelerator_ids()[device_name][0] + local_rank = os.environ.get("RAY_LOCAL_RANK") os.environ["LOCAL_RANK"] = local_rank get_torch_device().set_device(int(local_rank))