반응형
에러가 일어난 곳: https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/train
Openflamingo model을 gloo로 설정해서 돌리는데 해당 부분에서 에러가 일어났다.
검색을 해보니, _allgather_base함수가 nccl에서 작동되지 않는다고 한다...
어떻게 해결해야할까!?
Method: [torch_list] 형태로 all_gather 이용하기
해당 문제를 풀기 위해서는 일단 _exec_order_utils.py라는 system file로 들어와야 한다.
그리고 해당 파일에서 밑의 코드와 같은 부분을 찾을 수 있다!
world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)
local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload]
dist._allgather_base(
world_num_valid_indices,
local_num_valid_indices,
group=self.process_group,
)
해당 코드는 nccl에서는 작동하지만, gloo에서는 작동하지 않는다.
따라서 밑의 코드를 아래와 같이 변경해줘야 한다.
world_num_valid_indices = [torch.zeros(self.world_size, **tensor_kwargs) for _ in range(1)]
local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload]
dist.all_gather(
world_num_valid_indices,
local_num_valid_indices,
group=self.process_group,
)
일단 all_gather 함수를 보면 input type이 tensor_list가 기댓값이다.
따라서 world_num_valid_indices를 list()를 이용해서 tensor를 저장해야한다.
*추가로, range(1)에서 숫자는 할당되는 rank 숫자이다. 저는 단일 GPU라서 1로 설정했습니다.
또한, gloo에서 적용가능하도록 dist.all_gather()로 변경해주면 된다.
2023.07.25 Kyujinpy 작성
반응형