[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
2021/8/18 9:06:07
本文主要是介绍[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
目录- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
- 0x00 摘要
- 0x01 前情回顾
- 0x02 DataLoader
- 2.1 初始化
- 2.2 关键函数
- 2.3 单进程加载
- 2.3.1 区分生成
- 2.3.2 迭代器基类
- 2.3.3 单进程迭代器
- 2.3.4 获取样本
- 2.4 多进程加载
- 2.4.1 总体逻辑
- 2.4.2 初始化
- 2.4.3 业务重置
- 2.4.4 获取 index
- 2.4.5 worker主函数
- 2.4.6 Pin memory thread
- 2.4.7 用户获取data
- 2.4.8 小结
- 2.5 Pipleline
- 0xFF 参考
0x00 摘要
为了更好的介绍参数服务器Paracel的数据加载,我们临时插入两篇PyTorch的数据加载,主要是从分布式的角度进行切入。本文只算是开胃甜点,后续会有专门系列分析PyTorch分布式。
参数服务器系列其他文章如下:
[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice
[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van
[源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer
[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现
[源码解析] 机器学习参数服务器 Paracel (1)-----总体架构
[源码解析] 机器学习参数服务器 Paracel (2)--------SSP控制协议实现
[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler
0x01 前情回顾
关于数据加载,上回书我们说到了 DistributedSampler,本文接下来就进行 DataLoader的分析。
为了更好说明,我们首先给出上文的流水线图,本文会对这个图进行细化。
+------------+ +--------+ | | | | | Process 1 | | Data 1 +--------> | +------+ | | | Load Data | | +--------+ | | | +------------+ | | | | +------------+ | +-----------------------------------+ +--------+ | | | | | | | | Process 2 | +------> | Pin-memory process | | Data 2 +--------> | | | | | | | Load Data +-------------> | | +--------+ | | | Transfer to Pinned Memory | +------------+ +-----> | | | | | | +-----------------------------------+ | +--------+ +------------+ | | | | | | | Data 3 +--------> | Process 3 +-------+ | | | | +--------+ | Load Data | | | +------------+
其次,我们再看看数据加载总体逻辑,具体如下图,简要说就是:
- DataSet 把数据集数目发给DistributedSampler。
- Sampler 按照某种规则生成数据indices并发送给DataLoader。
- DataLoader 依据indices来从DataSet之中加载数据(其内部的DataLoaderIter对象负责协调单进程/多进程加载Dataset)。
- DataLoader 把数据发给模型,进行训练。
+------------------------+ +-----------+ |DistributedSampler | |DataLoader | | | 2 indices | | | Some strategy +-------------------> | | | | | | |-------------+----------| | | ^ | | 4 data +-------+ | | -------------->+ train | 1 | length | | +-------+ | | | +-------------+----------+ | | |DataSet | | | | +---------+ | 3 Load | | | | Data +-------------------------> | | | +---------+ | | | | | | | +------------------------+ +-----------+
接下来,我们就正式进入 DataLoader。
0x02 DataLoader
DataLoader的作用是:结合Dataset和Sampler之后,在数据集上提供了一个迭代器。
可以这么理解:
DataSet 是原始数据,Sampler 提供了如何切分数据的策略(或者说是提供了切分数据的维度),DataLoader就是依据策略来具体打工干活的,其中单进程加载就是一个人干活,多进程加载就是多拉几个人一起干活。
2.1 初始化
初始化的主要参数如下:
- dataset (Dataset) :所加载的数据集。
- batch_size (int, optional) :每个批次加载多少个样本。
- shuffle (bool, optional) :如果为 True,则每个epoch 都会再打乱数据。
- sampler (Sampler or Iterable, optional) :定义了如何从样本采样的策略。可以是任何实现了
__len__
的迭代器。 - batch_sampler (Sampler or Iterable, optional) :与
sampler
类似,但是每次返回一个批次的数据索引。 - num_workers (int, optional) :数据加载的子进程数目。如果是 0,表示从主进程加载数据。
- collate_fn (callable, optional):从一个小批次( mini-batch)张量中合并出一个样本列表。当从 map-style 数据集做批量加载时候使用。
- pin_memory (bool, optional) : 如果为true,则在返回张量之前把张量拷贝到CUDA固定内存之中。
- drop_last (bool, optional) :当数据集不能被均匀分割时,如果为true,丢掉最后一个不完整的批次。如果为False,那么最后一个批次的数据较小。
- timeout (numeric, optional): 如果是整数,则是worker收集批次数据的超时值。
- worker_init_fn (callable, optional):如果非空,则会在seeding和数据加载之前被每个子进程调用,以Iworker id (
[0, num_workers - 1]
)作为输入参数。 - generator (torch.Generator, optional):如果非空,则被RandomSampler 用来产生随机索引,也被多进程用来产生
base_seed
。 - prefetch_factor (int, optional, keyword-only arg):每个 worker 提前加载 的 sample 数量。
- persistent_workers (bool, optional):如果为
True
, 则在消费一次之后,data loader也 不会关掉worker进程。这允许workerDataset
实例维持活动状态。
具体初始化代码如下,主要就是各种设置,为了更好的说明,去除了异常处理代码:
class DataLoader(Generic[T_co]): dataset: Dataset[T_co] batch_size: Optional[int] num_workers: int pin_memory: bool drop_last: bool timeout: float sampler: Sampler prefetch_factor: int _iterator : Optional['_BaseDataLoaderIter'] __initialized = False def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: bool = False, sampler: Optional[Sampler[int]] = None, batch_sampler: Optional[Sampler[Sequence[int]]] = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: int = 2, persistent_workers: bool = False): torch._C._log_api_usage_once("python.data_loader") self.dataset = dataset self.num_workers = num_workers self.prefetch_factor = prefetch_factor self.pin_memory = pin_memory self.timeout = timeout self.worker_init_fn = worker_init_fn self.multiprocessing_context = multiprocessing_context if isinstance(dataset, IterableDataset): self._dataset_kind = _DatasetKind.Iterable # 省略异常处理 else: self._dataset_kind = _DatasetKind.Map if batch_sampler is not None: # auto_collation with custom batch_sampler # 省略异常处理 batch_size = None drop_last = False elif batch_size is None: # no auto_collation if drop_last: raise ValueError('batch_size=None option disables auto-batching ' 'and is mutually exclusive with drop_last') if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset, generator=generator) else: sampler = SequentialSampler(dataset) if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.batch_size = batch_size self.drop_last = drop_last self.sampler = sampler self.batch_sampler = batch_sampler self.generator = generator if collate_fn is None: if self._auto_collation: collate_fn = _utils.collate.default_collate else: collate_fn = _utils.collate.default_convert self.collate_fn = collate_fn self.persistent_workers = persistent_workers self.__initialized = True self._IterableDataset_len_called = None self._iterator = None self.check_worker_number_rationality()
2.2 关键函数
这里关键函数之一就是_index_sampler,用来让迭代器调用sampler,我们接下来就会讲到
@property def _index_sampler(self): # The actual sampler used for generating indices for `_DatasetFetcher` # (see _utils/fetch.py) to read data at each time. This would be # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. # We can't change `.sampler` and `.batch_sampler` attributes for BC # reasons. if self._auto_collation: return self.batch_sampler else: return self.sampler
2.3 单进程加载
单进程模式下,Data Loader会在计算进程内加载数据,所以加载过程中可能会阻塞计算。
for 语句会调用enumerate 会返回一个迭代器,以此来遍历数据集。在eumerate之中,dataloader 的 __next__(self)
方法会被调用,逐一获取下一个对象,从而遍历数据集。
cuda0 = torch.device('cuda:0') # CUDA GPU 0 for i, x in enumerate(train_loader): x = x.to(cuda0)
2.3.1 区分生成
当多进程加载时候,在DataLoader声明周期之中,迭代器只被建立一次,这样worker可以重用迭代器。
在单进程加载时候,应该每次生成,以避免重置状态。
def __iter__(self) -> '_BaseDataLoaderIter': if self.persistent_workers and self.num_workers > 0: # 如果是多进程或者设置了持久化 if self._iterator is None: # 如果没有,才会新生成 self._iterator = self._get_iterator() else: self._iterator._reset(self) return self._iterator else: # 单进程 return self._get_iterator() # 每次都直接生成新的
具体会依据是否是多进程来区别生成。
def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self)
2.3.2 迭代器基类
_BaseDataLoaderIter 是迭代器基类,我们挑选关键函数看看。
这里关键成员变量就是:
- _index_sampler:这里设置了loader 的 sampler,所以迭代器可以据此获取采样策略。
- _sampler_iter:得到 sampler 的迭代器。
class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: # 初始化参数 self._dataset = loader.dataset self._dataset_kind = loader._dataset_kind self._IterableDataset_len_called = loader._IterableDataset_len_called self._auto_collation = loader._auto_collation self._drop_last = loader.drop_last self._index_sampler = loader._index_sampler # 得到采样策略 self._num_workers = loader.num_workers self._prefetch_factor = loader.prefetch_factor self._pin_memory = loader.pin_memory and torch.cuda.is_available() self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器 self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__) def __next__(self) -> Any: with torch.autograd.profiler.record_function(self._profile_name): if self._sampler_iter is None: self._reset() data = self._next_data() # 获取数据 self._num_yielded += 1 if self._dataset_kind == _DatasetKind.Iterable and \ self._IterableDataset_len_called is not None and \ self._num_yielded > self._IterableDataset_len_called: # 忽略错误提示处理 warnings.warn(warn_msg) return data
2.3.3 单进程迭代器
_SingleProcessDataLoaderIter
继承了 _BaseDataLoaderIter
,可以看到,其增加了 _dataset_fetcher
,在构造时候传入了 _collate_fn
等各种参数。
回忆下,__next__
会调用 self._next_data()
获取数据,而在这里,_next_data
就会:
- 使用
self._next_index()
,其又会使用_sampler_iter
(采样器的迭代器)来获取indices 。 - 使用
self._dataset_fetcher.fetch(index)
来依据indices获取数据。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 # 获取样本方法 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def _next_data(self): index = self._next_index() # may raise StopIteration # 获取样本 data = self._dataset_fetcher.fetch(index) # may raise StopIteration if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data def _next_index(self): # 得到indices return next(self._sampler_iter) # may raise StopIteration
2.3.4 获取样本
我们接下来看看如何获取样本。就是通过索引传入 fetcher,从而获取想要的样本。
fetcher生成如下,这是在_SingleProcessDataLoaderIter初始化时候生成的:
class _DatasetKind(object): Map = 0 Iterable = 1 @staticmethod def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): if kind == _DatasetKind.Map: return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last) else: return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
对于Map-style,就使用 _MapDatasetFetcher 处理,就是使用 possibly_batched_index 从数据集之中提取数据,possibly_batched_index 是key。
如果有batch sampler,就使用 batch sampler。
如果需要从一个小批次( mini-batch)张量中合并出一个样本列表。就使用 collate_fn后处理。
class _MapDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) def fetch(self, possibly_batched_index): if self.auto_collation: # 如果配置了batch_sampler,_auto_collation就为True, # 那么就优先使用batch_sampler,此时fetcher中传入的就是一个batch的索引 data = [self.dataset[idx] for idx in possibly_batched_index] else: data = self.dataset[possibly_batched_index] return self.collate_fn(data)
对于 Iterable-style,因为 __init__
方法内设置了 dataset 初始的迭代器,所以在fetch 方法内获取元素的时候,如果是常规 sampler,index 其实已经不起作用,直接从dataset迭代器获取。如果是batch sampler,则index有效果。
class _IterableDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) self.dataset_iter = iter(dataset) def fetch(self, possibly_batched_index): if self.auto_collation: # 即auto_collation为True,表示使用batch_sampler。 # 则使用possibly_batched_index,获取1个batch大小的样本 data = [] for _ in possibly_batched_index: try: data.append(next(self.dataset_iter)) except StopIteration: break if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)): raise StopIteration else: # sampler则直接往后遍历,提取1个样本 data = next(self.dataset_iter) return self.collate_fn(data)
此时总逻辑如下:
+--------------------------+ +-------------------------------+ | DataLoader | | _SingleProcessDataLoaderIter | | | | | | | | __next__ | +---------------+ Sampler | | | | | | | _next_data +-----------+ | | Dataset | | | | | | | | _next_index | | | | __iter__ | | | | | | | | _index_sampler | | | | _get_iterator +--------------> | + | | | | | | | | | | +--------------------------+ +-------------------------------+ | | | | | | | | | | | | | | | | | +----------------------------+ | | | |Sampler | | | +------------------------> | | <------+ | | | | | | | | | | +----------------------------+ | | | +----------------------------+ | |_BaseDatasetFetcher | | | | | | | | | dataset | | | | <----------------------+ | collate_fn | | | +----------------------------+
动态流程如下:
User DataLoader _SingleProcessDataLoaderIter _DatasetKind Sampler + + + + + | | | | | | 1 | | | | enumerate--------> __iter__ | | | | + | | | | | | | | | | | | | | | 2 v 3 v | | _get_iterator--------> __init__ +----------> create_fetcher | | 4 | + + | | <-----------------+ | | | | iterator | | | | | | 5 | | | for loop +------------------------------> __next__ | | | | | | | | | | | | | | | | | | | _next_data | | | | | | | | | | | | | | | 6 next | | | | _next_index +-------------------------> | | | | | | | | | <---------------------------------+ | | | 7 index | | | | | | | | | | | | | | | 8 fetch(index) | | | | | +--------------------> | | | | | | | | | | <---------------------+ | | | | 9 data | | | <-------------------------------------+ | | | 10 data | | | | | | | | | v v v v v
2.4 多进程加载
为了加速,PyTorch提供了多进程下载,只要把将参数 num_workers
设置为正整数,系统就会相应生成多进程处理,在这种模式下,每个worker都是一个独立进程。
由上节我们可以知道,_SingleProcessDataLoaderIter 是单进程加载数据的核心,loader通过它来与sampler,dataset交互。在多进程中,这个核心对应的就是 _MultiProcessingDataLoaderIter。
def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self)
我们接下来就从 _MultiProcessingDataLoaderIter 开始分析。
2.4.1 总体逻辑
_MultiProcessingDataLoaderIter 中的注释十分详尽,值得大家深读,而且给出了逻辑流程图如下,其基本流程是围绕着三个queue进行的:
- 主进程把需要获取的数据 index 放入index_queue,这是指定子进程需要获取哪些数据的队列。同时也给子进程传入结果队列,关于结果队列,有两个分支:
- 如果设置了pin memory,则传入的是 worker_result_queue。
- 否则传入 data_queue。
- 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue,这是向主进程返回结果的队列。
- 主进程进行处理,这里有两个分支:
- 如果设置了pin memory,则主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue,这是处理结果的队列。
- 如果不需要pin memory,则结果已经存在 data_queue 之中,不做新操作。
可以看到,每个进程的输入是一个队列index_queue ,输出也是一个队列worker_result_queue。主进程和子进程通过这2~3个 queue 联系了起来,从而达到解耦合和加速的作用。
# NOTE [ Data Loader Multiprocessing Shutdown Logic ] # # Preliminary: # # Our data model looks like this (queues are indicated with curly brackets): # # main process || # | || # {index_queue} || # | || # worker processes || DATA # | || # {worker_result_queue} || FLOW # | || # pin_memory_thread of main process || DIRECTION # | || # {data_queue} || # | || # data output \/ # # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if # `pin_memory=False`.
具体如下图所示,如果不需要 pin memory,则为:
+-----------+ indices -------------+ indices | Worker | Data +--------->+index queue +-------->+ Process +------+ | | | | | | | -------------+ +-----------+ | | | +------------+ | | | | +---------+ | +---> | | Main | | indices -------------+ indices +-----------+ | | | Process +------------>+index queue +-------->+ Worker | Data | Data Queue | | | | | | | Process +----------> | +---------+ | -------------+ | | | | | +-----------+ +---> | | | +------------+ | | | indices -------------+ indices +-----------+ | +--------->+index queue +-------->+ Worker | Data | | | | Process +------+ -------------+ | | +-----------+
当有pin memory时候,则是先进入 result queue,然后 pin_memory_thread 处理之后会转入到 data queue:
+-----------+ indices -------------+ indices | Worker | Data +--------->+index queue +-------->+ Process +------+ | | | | | | | -------------+ +-----------+ | | | --------------+ | | | | +---------+ | +---> | | Main | | indices -------------+ indices +-----------+ | | | Process +------------>+index queue +-------->+ Worker | Data | result_queue| | | | | | | Process +----------> | +---------+ | -------------+ | | | | | +-----------+ +---> | | | ---------+----+ | | | | indices -------------+ indices +-----------+ | | +--------->+index queue +-------->+ Worker | Data | +---------+--------+ | | | Process +------+ | pin_memory_thread| -------------+ | | | | | +-----------+ | | | | | | +------------------+ | | | v +-----+------+ | Data Queue | | | +------------+
2.4.2 初始化
初始化函数如下,主要是:
- 配置,生成各种成员变量,配置各种queue。
- 启动各个子进程。
- 启动主进程中的pin_memory的线程。
主要成员变量为:
_index_queues
: 这是一个queue 列表,列表的每一个元素是一个 queue,就是每个子进程的队列需要处理的数据index,每个子进程对应一个 queue。_worker_result_queue
: 子进程处理完的 (idx, data)。data_queue
: 经过主进程 pin_memory 线程处理之后的数据队列,如果不需要pin,则直接会使用_worker_result_queue
。_worker_queue_idx_cycle
用以找出下一个工作的worker。
具体代码如下:
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" def __init__(self, loader): super(_MultiProcessingDataLoaderIter, self).__init__(loader) assert self._num_workers > 0 assert self._prefetch_factor > 0 if loader.multiprocessing_context is None: multiprocessing_context = multiprocessing else: multiprocessing_context = loader.multiprocessing_context self._worker_init_fn = loader.worker_init_fn self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) # No certainty which module multiprocessing_context is self._worker_result_queue = multiprocessing_context.Queue() # 子进程输出,读取完数据的index self._worker_pids_set = False self._shutdown = False self._workers_done_event = multiprocessing_context.Event() self._index_queues = [] # 子进程输入,需读取数据的index self._workers = [] for i in range(self._num_workers): # No certainty which module multiprocessing_context is index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] # Need to `cancel_join_thread` here! # See sections (2) and (3b) above. index_queue.cancel_join_thread() w = multiprocessing_context.Process( target=_utils.worker._worker_loop, # worker进程主函数,把各种queue和函数传进去 args=(self._dataset_kind, self._dataset, index_queue, self._worker_result_queue, self._workers_done_event, self._auto_collation, self._collate_fn, self._drop_last, self._base_seed, self._worker_init_fn, i, self._num_workers, self._persistent_workers)) w.daemon = True w.start() self._index_queues.append(index_queue) # 把这个worker对应的index_queue放到主进程这里存起来,以后就可以交互了 self._workers.append(w) if self._pin_memory: self._pin_memory_thread_done_event = threading.Event() # Queue is not type-annotated self._data_queue = queue.Queue() # pin 处理之后的数据结果 pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=(self._worker_result_queue, self._data_queue, torch.cuda.current_device(), self._pin_memory_thread_done_event)) pin_memory_thread.daemon = True pin_memory_thread.start() # Similar to workers (see comment above), we only register # pin_memory_thread once it is started. self._pin_memory_thread = pin_memory_thread else: self._data_queue = self._worker_result_queue # 如果不需要pin,则直接使用_worker_result_queue # .pid can be None only before process is spawned (not the case, so ignore) _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) # 继续完善业务
2.4.3 业务重置
__init__
函数最后会调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。
上小节函数中,已经启动了worker子进程,但是没有分配任务,所以_reset函数会进行任务分配,预取。
_MultiProcessingDataLoaderIter有如下 flag 参数来协调各个 worker (包括各种queue)之间的工作:
-
_send_idx
: 发送索引,用来记录这次要放 index_queue 中 batch 的 idx -
_rcvd_idx
: 接受索引,记录要从 data_queue 中取出的 batch 的 idx -
_task_info
: 存储将要产生的 data 信息的 dict,key为 task idx(由 0 开始的整型索引),value 为(worker_id,)
或(worker_id, data)
,分别对应数据 未取 和 已取 的情况 -
_tasks_outstanding
: 整型,代表已经准备好的 task/batch 的数量(可能有些正在准备中) -
_send_idx
: 发送索引,记录下一次要放 index_queue 中 task batch 的 idx。 -
_rcvd_idx
: 接受索引,记录下一次要从 data_queue 中取出的 task batch 的 idx。_send_idx
和_rcvd_idx
主要用来进行流量控制和确保接受索引有意义。 -
_task_info
: 存储将要产生的 data 信息的 dict,key为 task batch idx(由 0 开始的整型索引),value 为(worker_id,)
或(worker_id, data)
,分别对应数据 未取 和 已取 的情况。_task_info
的作用是依据 task batch idx 获取对应的 worker id 和暂存乱序数据。 -
_tasks_outstanding
: 整型,正在准备的 task/batch 的数量,实际上就是进行一些确认工作,没有太实际的意义。
对于加载数据,每个 worker 一次产生一个 batch 的数据,返回 batch 数据前,会放入下一个批次要处理的数据下标,所以 reset 函数会把 _send_idx
,_rcvd_idx
都恢复成0,这样下次迭代就可以重新处理。
在 reset 方法最后,有一个预取数据操作。我们会在后面结合乱序处理进行讲解。
def _reset(self, loader, first_iter=False): super()._reset(loader, first_iter) self._send_idx = 0 # idx of the next task to be sent to workers self._rcvd_idx = 0 # idx of the next task to be returned in __next__ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). # map: task idx => - (worker_id,) if data isn't fetched (outstanding) # \ (worker_id, data) if data is already fetched (out-of-order) self._task_info = {} self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1) # A list of booleans representing whether each worker still has work to # do, i.e., not having exhausted its iterable dataset object. It always # contains all `True`s if not using an iterable-style dataset # (i.e., if kind != Iterable). # Not that this indicates that a worker still has work to do *for this epoch*. # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. # 每个worker的状态 self._workers_status = [True for i in range(self._num_workers)] # We resume the prefetching in case it was enabled if not first_iter: for idx in range(self._num_workers): self._index_queues[idx].put(_utils.worker._ResumeIteration()) resume_iteration_cnt = self._num_workers while resume_iteration_cnt > 0: return_idx, return_data = self._get_data() if isinstance(return_idx, _utils.worker._ResumeIteration): assert return_data is None resume_iteration_cnt -= 1 # prime the prefetch loop # 预取若干index,目的是为了配合后续的乱序处理。 for _ in range(self._prefetch_factor * self._num_workers): self._try_put_index()
2.4.4 获取 index
_try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。
- 从sampler获取下一批次的index。
- 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
- 并且调整主进程的信息。
def _next_index(self): # 定义在基类 _BaseDataLoaderIter 之中,就是获取下一批index return next(self._sampler_iter) # may raise StopIteration def _try_put_index(self): assert self._tasks_outstanding < self._prefetch_factor * self._num_workers try: index = self._next_index() # 获取下一批index except StopIteration: return for _ in range(self._num_workers): # find the next active worker, if any worker_queue_idx = next(self._worker_queue_idx_cycle) if self._workers_status[worker_queue_idx]: # 如果已经工作,就继续找 break else: # not found (i.e., didn't break) return # 以下是主进程进行相关记录 # 给下一个工作worker放入 (任务index, 数据index), 就是给queue放入数据,所以worker loop之中就立刻会从queue中得到index,从而开始获取数据。 self._index_queues[worker_queue_idx].put((self._send_idx, index)) # 记录 将要产生的 data 信息 self._task_info[self._send_idx] = (worker_queue_idx,) # 正在处理的batch个数+1 self._tasks_outstanding += 1 # send_idx 记录从sample_iter中发送索引到index_queue的次数 self._send_idx += 1 # 递增下一批发送的task index
2.4.5 worker主函数
_worker_loop 是 worker进程的主函数,主要逻辑如其注释所示:
# [ worker processes ] # While loader process is alive: # Get from `index_queue`. # If get anything else, # Check `workers_done_event`. # If set, continue to next iteration # i.e., keep getting until see the `None`, then exit. # Otherwise, process data: # If is fetching from an `IterableDataset` and the iterator # is exhausted, send an `_IterableDatasetStopIteration` # object to signal iteration end. The main process, upon # receiving such an object, will send `None` to this # worker and not use the corresponding `index_queue` # anymore. # If timed out, # No matter `workers_done_event` is set (still need to see `None`) # or not, must continue to next iteration. # (outside loop) # If `workers_done_event` is set, (this can be False with `IterableDataset`) # `data_queue.cancel_join_thread()`. (Everything is ending here: # main process won't read from it; # other workers will also call # `cancel_join_thread`.)
就是通过index_queue, data_queue与主进程交互。
- 从 index_queue 获取新的数据index;
- 如果没有设置本worker结束,就使用 fetcher获取数据。
- 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue, 否则传入 data_queue。
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event, auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id, num_workers, persistent_workers): # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the # logic of this function. try: # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal # module's handlers are executed after Python returns from C low-level # handlers, likely when the same fatal signal had already happened # again. # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers signal_handling._set_worker_signal_handlers() torch.set_num_threads(1) seed = base_seed + worker_id random.seed(seed) torch.manual_seed(seed) if HAS_NUMPY: np_seed = _generate_state(base_seed, worker_id) import numpy as np np.random.seed(np_seed) global _worker_info _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset) from torch.utils.data import _DatasetKind init_exception = None try: if init_fn is not None: init_fn(worker_id) fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) except Exception: init_exception = ExceptionWrapper( where="in DataLoader worker process {}".format(worker_id)) iteration_end = False watchdog = ManagerWatchdog() while watchdog.is_alive(): # 等待在这里 try: # _try_put_index 如果放入了数据index,这里就被激活,开始工作 r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue if isinstance(r, _ResumeIteration): # Acknowledge the main process data_queue.put((r, None)) iteration_end = False # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher( dataset_kind, dataset, auto_collation, collate_fn, drop_last) continue elif r is None: # Received the final signal assert done_event.is_set() or iteration_end break elif done_event.is_set() or iteration_end: # `done_event` is set. But I haven't received the final signal # (None) yet. I will keep continuing until get it, and skip the # processing steps. continue idx, index = r data: Union[_IterableDatasetStopIteration, ExceptionWrapper] if init_exception is not None: data = init_exception init_exception = None else: try: data = fetcher.fetch(index) except Exception as e: # 省略处理代码 data_queue.put((idx, data)) # 放入数据,通知主进程 del data, idx, index, r # save memory except KeyboardInterrupt: # Main process will raise KeyboardInterrupt anyways. pass if done_event.is_set(): data_queue.cancel_join_thread() data_queue.close()
2.4.6 Pin memory thread
在主进程之中,如果设置了需要pin memory,主进程的 pin_memory_thread 会从 worker_result_queue 读取数据,进行处理(加速CPU和GPU的数据拷贝),把结果放入 data_queue。
# [ pin_memory_thread ] # # No need to check main thread. If this thread is alive, the main loader # # thread must be alive, because this thread is set as daemonic. # While `pin_memory_thread_done_event` is not set: # Get from `index_queue`. # If timed out, continue to get in the next iteration. # Otherwise, process data. # While `pin_memory_thread_done_event` is not set: # Put processed data to `data_queue` (a `queue.Queue` with blocking put) # If timed out, continue to put in the next iteration. # Otherwise, break, i.e., continuing to the out loop. # # NOTE: we don't check the status of the main thread because # 1. if the process is killed by fatal signal, `pin_memory_thread` # ends. # 2. in other cases, either the cleaning-up in __del__ or the # automatic exit of daemonic thread will take care of it. # This won't busy-wait either because `.get(timeout)` does not # busy-wait.
具体代码如下:
def _pin_memory_loop(in_queue, out_queue, device_id, done_event): # This setting is thread local, and prevents the copy in pin_memory from # consuming all CPU cores. torch.set_num_threads(1) torch.cuda.set_device(device_id) # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the # logic of this function. while not done_event.is_set(): try: r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue idx, data = r if not done_event.is_set() and not isinstance(data, ExceptionWrapper): data = pin_memory(data) # 省略异常处理代码 r = (idx, data) while not done_event.is_set(): try: out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) break except queue.Full: continue del r # save memory def pin_memory(data): if isinstance(data, torch.Tensor): return data.pin_memory() elif isinstance(data, string_classes): return data elif isinstance(data, collections.abc.Mapping): return {k: pin_memory(sample) for k, sample in data.items()} elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple return type(data)(*(pin_memory(sample) for sample in data)) elif isinstance(data, collections.abc.Sequence): return [pin_memory(sample) for sample in data] elif hasattr(data, "pin_memory"): return data.pin_memory() else: return data
2.4.7 用户获取data
现在数据已经加载完毕,我们接下来看用户如何从DataLoader之中获取数据。
这里有一个很关键的地方:如何保持在不同实验之中数据读取顺序的一致性。为了让多次实验之间可以比对,就需要尽量保证在这些实验中,每次读取数据的顺序都是一致的,这样才不会因为数据原因造成结果的误差。
打破顺序一致性的最大可能就是乱序数据。而造成乱序问题的原因就是:多进程读取,可能某个进程快,某个进程慢。比如,用户这次需要读取6-19,16-26,37-46。但是某一个worker慢,6-19不能即时返回,另一个worker 的 16-26 先返回了,于是就会造成乱序。
如何处理乱序数据?PyTorch的具体做法就是:DataLoader严格按照Sampler的顺序返回数据。如果某一个数据是乱序的,则会把它暂存起来,转而去获取下一个数据,见下面代码中 "store out-of-order samples" 注释处。等到应该返回时候(这个数据顺序到了)才返回。
但是其风险就是数据返回会比当前请求慢,比如应该获取 6,但是Data queue里面没有这个数据,只有 16,27,于是用户只能等待 6 加载完成。
解决慢的方法是:预取(prefetch)。就是在reset方法最后,提前提取若干index,让DataLoader提前去取,这虽然不能保证任意两次训练的数据返回顺序完全一致,但是可以最大限度保证。
具体代码如下,首先,回忆基类的 __next__
函数 ,可以看到其调用了 _next_data 获取数据。
class _BaseDataLoaderIter(object): def __next__(self) -> Any: with torch.autograd.profiler.record_function(self._profile_name): if self._sampler_iter is None: self._reset() data = self._next_data() # 获取数据 self._num_yielded += 1 if self._dataset_kind == _DatasetKind.Iterable and \ self._IterableDataset_len_called is not None and \ self._num_yielded > self._IterableDataset_len_called: # 忽略错误提示处理 warnings.warn(warn_msg) return data
所以,我们要看 _MultiProcessingDataLoaderIter
的_next_data
。
- 因为之前有预取了index,worker进程已经开始获取数据,所以主进程此时可以得到数据,如果没有数据,就继续while True等待。
- 如果获取成功,则使用
_process_data
设定下一次的indx,准备下一次迭代。 - 通过
_task_info
来记录乱序数据,如果暂时无法处理,就在这里保存。
def _next_data(self): while True: # If the worker responsible for `self._rcvd_idx` has already ended # and was unable to fulfill this task (due to exhausting an `IterableDataset`), # we try to advance `self._rcvd_idx` to find the next valid index. # # This part needs to run in the loop because both the `self._get_data()` # call and `_IterableDatasetStopIteration` check below can mark # extra worker(s) as dead. # 找到待取idx while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx info = self._task_info[self._rcvd_idx] worker_id = info[0] if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active break # 有数据或者正在工作,就跳出内部这个while del self._task_info[self._rcvd_idx] self._rcvd_idx += 1 else: # no valid `self._rcvd_idx` is found (i.e., didn't break) if not self._persistent_workers: self._shutdown_workers() raise StopIteration # Now `self._rcvd_idx` is the batch index we want to fetch # Check if the next sample has already been generated if len(self._task_info[self._rcvd_idx]) == 2: data = self._task_info.pop(self._rcvd_idx)[1] return self._process_data(data) # 设定下一次的indx,进行下一次迭代 assert not self._shutdown and self._tasks_outstanding > 0 idx, data = self._get_data() # 从 self._data_queue 中取数据 self._tasks_outstanding -= 1 # 正在准备的batch个数需要减1 if self._dataset_kind == _DatasetKind.Iterable: # Check for _IterableDatasetStopIteration if isinstance(data, _utils.worker._IterableDatasetStopIteration): if self._persistent_workers: self._workers_status[data.worker_id] = False else: self._mark_worker_as_unavailable(data.worker_id) self._try_put_index() continue if idx != self._rcvd_idx: # 乱序数据 # store out-of-order samples self._task_info[idx] += (data,) else: del self._task_info[idx] # 正常数据 return self._process_data(data) # 设定下一次的indx,进行下一次迭代
其次,我们看看 _get_data
如何从 self._data_queue
中取数据。具体是使用 _try_get_data 来提取。
- 如果有超时配置,就按照超时读取。
- 如果设置了pin memory,则从pin 线程处理之后的数据读取。
- 否则循环读取worker处理的数据,直至获取到数据为止。
def _get_data(self): # Fetches data from `self._data_queue`. # # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` # in a loop. This is the only mechanism to detect worker failures for # Windows. For other platforms, a SIGCHLD handler is also used for # worker failure detection. # # If `pin_memory=True`, we also need check if `pin_memory_thread` had # died at timeouts. if self._timeout > 0: # 如果有超时配置,就按照超时读取 success, data = self._try_get_data(self._timeout) if success: return data else: raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout)) elif self._pin_memory: # 从pin 线程处理之后的数据读取 while self._pin_memory_thread.is_alive(): success, data = self._try_get_data() if success: return data else: # while condition is false, i.e., pin_memory_thread died. raise RuntimeError('Pin memory thread exited unexpectedly') # In this case, `self._data_queue` is a `queue.Queue`,. But we don't # need to call `.task_done()` because we don't use `.join()`. else: while True: success, data = self._try_get_data() # 读取worker处理的数据 if success: return data
_try_get_data
就是从 _data_queue
读取。主进程和worker进程通过queue上的put, get进行通讯交互。
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): # Tries to fetch data from `self._data_queue` once for a given timeout. # This can also be used as inner loop of fetching without timeout, with # the sender status as the loop condition. # # This raises a `RuntimeError` if any worker died expectedly. This error # can come from either the SIGCHLD handler in `_utils/signal_handling.py` # (only for non-Windows platforms), or the manual check below on errors # and timeouts. # # Returns a 2-tuple: # (bool: whether successfully get data, any: data if successful else None) try: data = self._data_queue.get(timeout=timeout) return (True, data) except Exception as e: # At timeout and error, we manually check whether any worker has # failed. Note that this is the only mechanism for Windows to detect # worker failures. failed_workers = [] for worker_id, w in enumerate(self._workers): if self._workers_status[worker_id] and not w.is_alive(): failed_workers.append(w) self._mark_worker_as_unavailable(worker_id) # 省略异常处理代码 import tempfile import errno try: # Raise an exception if we are this close to the FDs limit. # Apparently, trying to open only one file is not a sufficient # test. # See NOTE [ DataLoader on Linux and open files limit ] fds_limit_margin = 10 fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] except OSError as e: # 省略异常处理代码 raise
设置下一次迭代是使用_process_data
。
def _process_data(self, data): self._rcvd_idx += 1 self._try_put_index() # 设定下一次的indx,进行下一次迭代 if isinstance(data, ExceptionWrapper): data.reraise() return data # 返回数据
2.4.8 小结
我们小结一下多进程逻辑。
总体逻辑如下:
- 主进程把需要获取的数据 index 放入index_queue。
- 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue。
- 主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue。
具体流程如下图:
- 在 _MultiProcessingDataLoaderIter 的初始化函数
__init__
之中会进行初始化:- 配置,生成各种成员变量,配置各种queue。
- 启动各个子进程。
- 启动主进程中的pin_memory的线程。
- 调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。上面已经启动了worker子进程,但是没有分配任务,所以reset函数会进行任务分配,预取。
- 接下来是一个预取操作(在看下图中一定要留意)。
- _try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。
- 使用 _next_index 从sampler获取下一批次的index。
- 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
- 并且调整主进程的信息。
- 拿到index之后,回到主线程。这里会进行数据提取。就是通过index_queue, data_queue与主进程交互。
- 从 index_queue 获取新的数据index;
- 如果没有设置本worker结束,就使用 fetcher获取数据。
- 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue,否则传入 data_queue。
- _try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。
- 当用户迭代时,调用了Loader基类的
__next__
函数 ,其调用 _next_data 从 DataLoader 之中获取数据。- 使用
_get_data
如何从self._data_queue
中取数据。 - 使用
_process_data
设置下一次迭代的 index,即使用_try_put_index
,_next_index
来进行下一轮设置。
- 使用
具体如下图:
user _MultiProcessingDataLoaderIter Sampler Queue(index_queue) Queue(data_queue) _worker_loop Fetcher + + + + + + + | | | | | | | | | | | | | | | v | | | | | | __init__ | | | | | | 1 _reset | | | | | | + | | | | | | | | | | | | | | | | | | | | v | | | | | | 2 _try_put_index next | | | | | | _next_index +------------> | | | | | | + | | | | | | | <-----------------+ | | | | | | | index | | | | | | | | | | | | | | +------------------------------------> | | | | | | put | | | get | | | | | +--------------------------------------> | | | | | | | | index | | | | | | +------------> | | next | | | | | <----------+ | +---------------------> | | | | <----------------+ data | | | | | | data | | | + | | | | | | _next_data | | | | | | 3 _get_data get | | | | | | _try_get_data +--------------------------------------------------> | | | | + | | | | | | | <----------------------------------------------------------+ | | | | | data | | | | | | + | | | | | | _process_data | | | | | | _try_put_index next | | | | | | _next_index +-------------> | | | | | | + <--------------------+ | | | | | | index | | | | | | +---------------------------------------> | | get | | | <-------------------+ | put | +-------------------------------------> | index | | data | | | | | +----------> | | | | | | +<-----------+ | v v v v v v data v
手机上如下:
2.5 Pipleline
至此,我们把之前的pipeline图进一步细化,具体如下:
+------------+ +--------+ | | | | | Process 1 | +-----> | Data 1 +--------> | +------+ | | | | Load Data | | | +--------+ | | | | +------------+ | | | | | | | +----------------+ | +------------+ | +-------------------------+ |Main process | | +--------+ | | | | pin_memory_thread | | | | | | | Process 2 | +------> +------------------------+ | | +------------+ | index_queue +----------> | Data 2 +--------> | | | | | | | | | | | | | | Load Data +-------------> | _worker_result_queue +-----> | Write to pinned memory +--------> | data_queue | | | | +--------+ | | | | | | | | +----------------+ | +------------+ +-----> | | | | +------------+ | | +------------------------+ | | | | +-------------------------+ | | | +--------+ +------------+ | | | | | | | +-----> | Data 3 +--------> | Process 3 +-------+ | | | | +--------+ | Load Data | | | +------------+
手机如下:
至此,PyTorch 分布式的数据加载部分分析完毕,下一篇我们回归到 Paracel 如何处理数据加载。
0xFF 参考
卷积神经网络的并行化模型--One weird trick for parallelizing convolutional neural networks
AI框架中数据处理的挑战与解决思路
PyTorch 源码解读之 torch.utils.data:解析数据处理全流程
谈谈你对大规模机器学习这个领域的理解和认识?
Nvidia-DALI 从放弃到入门
pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel
Pytorch数据Pipeline设计总结
深度学习框架数据Pipeline设计
这篇关于[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-29uni-app 中使用 Vant Weapp,怎么安装和配置npm ?-icode9专业技术文章分享
- 2024-12-27Nacos多环境配置学习入门
- 2024-12-27Nacos快速入门学习入门
- 2024-12-27Nacos快速入门学习入门
- 2024-12-27Nacos配置中心学习入门指南
- 2024-12-27Nacos配置中心学习入门
- 2024-12-27Nacos做项目隔离学习入门
- 2024-12-27Nacos做项目隔离学习入门
- 2024-12-27Nacos初识学习入门:轻松掌握服务发现与配置管理
- 2024-12-27Nacos初识学习入门:轻松掌握Nacos基础操作