Hey guys! I might be mistaken, but I think the way the samplers are currently implemented, when using a distributed backend (such as ddp or ddp-sharded), they sample the same examples on all accelerators (GPUs).

Instead of inheriting from `torch.utils.data.Sampler`, I suggest inheriting from `torch.utils.data.distributed.DistributedSampler` and partitioning the data across accelerators, doing something like this:
```python
from torch.utils.data.distributed import DistributedSampler


class RandomSampler(DistributedSampler):
    r"""Implementation of a random sampler for sampling the dataset.

    Args:
        data_source (torch.utils.data.Dataset): dataset to sample from
        batch_size (int): size of a batch
        drop_last (bool): flag indicating whether to drop the last batch or not
    """

    def __init__(self, data_source, batch_size: int = 32, drop_last: bool = True) -> None:
        super(RandomSampler, self).__init__(data_source, drop_last=drop_last)
        self.data_source = data_source
        self.batch_size = batch_size
        ids = list(range(0, len(data_source)))
        # each replica owns the [start, end) slice of the dataset indices
        start = int(len(data_source) * self.rank / self.num_replicas)
        end = int(len(data_source) * (self.rank + 1) / self.num_replicas)
        # clamp each bin to `end` so neighbouring ranks do not share samples
        self.bins = [ids[i:min(i + batch_size, end)] for i in range(start, end, batch_size)]
        self.drop_last = drop_last

    def __iter__(self):
        for ids in self.bins:
            yield ids

    def __len__(self):
        return len(self.bins)
```
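To show what the partitioning above does, here is a minimal pure-Python sketch of the rank-based binning logic (the `partition_bins` helper is hypothetical, just extracted for illustration): each replica computes a disjoint `[start, end)` slice of the dataset indices and chunks its slice into batches, so no two GPUs see the same samples.

```python
def partition_bins(dataset_len, rank, num_replicas, batch_size):
    """Return the list of index batches ("bins") owned by this rank."""
    ids = list(range(dataset_len))
    # each replica owns the [start, end) slice of the dataset indices
    start = int(dataset_len * rank / num_replicas)
    end = int(dataset_len * (rank + 1) / num_replicas)
    # clamp each bin to `end` so neighbouring ranks do not share samples
    return [ids[i:min(i + batch_size, end)] for i in range(start, end, batch_size)]


# With 10 samples, 2 replicas, and batch size 3:
bins_rank0 = partition_bins(10, rank=0, num_replicas=2, batch_size=3)
bins_rank1 = partition_bins(10, rank=1, num_replicas=2, batch_size=3)
# rank 0 gets [[0, 1, 2], [3, 4]]; rank 1 gets [[5, 6, 7], [8, 9]]
```

Note that without the clamp on `end`, the last bin of each rank could spill into the next rank's slice whenever the slice length is not a multiple of `batch_size`, which would re-introduce duplicated samples across GPUs.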