数据过大时dataloader怎么设计?

数据太大无法加载?

前言:

最近笔者在跑项目的时候遇到一个场景:训练数据过大比如100G,那么是不可能全部加载到内存后训练的,那怎么办呢?

其实具体来说当数据过大时,其会导致两个问题

(a)加载数据时间过长

(b)时间长也就忍了,关键是内存会爆掉

怎么解决呢?可能想到的是边加载数据边训练,那如果我们的代码恰好又是多机多卡的呢?那我们的dataloader该怎么实现呢?

下面会涉及到一系列小碎点,我们逐步深入的看看怎么实现。

本篇涉及的知识较多,比较绕,建议收藏反复琢磨。一个基本能跑通的dataloader在文末已给出~

加载数据快起来

本节需要的是python的file.seek()和file.tell(),不熟悉的小伙伴,建议先查查相关资料,很简单~

一般来说,我们在写dataloader的时候,是会把数据全部加载进来一遍的进而导致崩掉,这里通常采用的一个trick手段是:file.seek()和file.tell()

file.seek()是会将文件指针移动到指定的地方,那么假设我们的训练数据train.txt有10000行,那么我们在写dataloader的时候就不需要将数据全部加载进来,而是通过如下:

f = open("train.txt")
with open(dataset_idx_path, 'rb') as fp:
    offsets = pickle.load(fp)
    print('loaded dataset idx', offsets[:10])

#后续通过seek不断的取行idx的数据
f.seek(offsets[idx], 0)
cur_data = f.readline()

注意看open(“train.txt”)一行基本上很快,而且不用占用很大内存,是通过seek不断移动到对应行取对应的数据,细心的读者可能已经发现这里有一个offsets列表,其实是因为f.seek是不能直接按行移动的,其只能按照字符移动,但是每一行的字符数又不一样,所以offsets列表就是记录每行之前字符数的,比如[45, 67, 89],那么以后就直接可以根据idx取出对应字符数,进而seek根据其可以直接定位到行的启始位置。

那么offsets(dataset_idx_path)怎么得到呢?

很简单

offsets = [0]
with open(data_file, "r", encoding="utf-8") as fp:
    while fp.readline() != "":
        offsets.append(fp.tell())

offsets.pop()
print(len(offsets))
with open("dataset_tmp.id", "wb") as f:
    pickle.dump(offsets, f)

这里的data_file就是你的100G的大文件,”dataset_tmp.id”就是产生的dataset_idx_path,供后续加载。

说到这里,需要注意两点

(1)数据能离线预处理的就离线预处理,比如如果使用的是bert等预训练模型,那么最好离线处理成tokenizer id即data_file中的数据是模型直接可以使用的,而且上述”dataset_tmp.id”其实也是预先离线得到好的文件。

(2)线上dataloader加载数据训练的时候其实只需要花一点时间加载”dataset_tmp.id”和open(data_file),后者基本不花时间。而且全程没有加载很大的数据到内存,因为其是不断通过文件指针找对应数据的。

通过以上方式,我们就可以基本避免因为全部加载数据而奔溃掉的问题。是不是有点时间(文件指针不断找)换空间(内存不爆掉)的味道~~

所以总的来说这块解决了内存爆掉已经加载过长的问题。

GPU利用率低

本节需要的是python的multiprocessing包,不熟悉的小伙伴,建议先看看相关的教程,也不是很难,熟悉最基本的API即可,比如熟悉Process,Manager,Lock,Pool概念。

如果我们的数据很大的时候,使用上述方式,虽然可以基本跑起来了,但是会发现GPU的利用率还是不够高,或者说波动很大,那是为什么呢?

其实是GPU在等待数据从CPU传输过来,GPU处理速度很多,马上结束后,会等待下一个batch的传入。之所以要等是因为文件指针要花时间不断的找对应数据,这是需要换时间的。

所以我们为了空间,牺牲掉了时间,但是时间又影响了GPU利用率。

那怎么解决呢?多进程!!!

既然找的慢,那我们多开几个进程来找,这样不就快起来了吗?注意不是使用的多线程哦,多线程本质上还是时间切片,不能真真意义上的加速。

另外一个问题是怎么实现一边装数据一边训练呢?那就是创建一个缓冲buffer。一个进程是不断往buffer装数据,另外一个是不断取数据去训练,两则互不影响。也就是说你在训练的时候,我还在不停的往buffer里面装数据以至于你想用的时候可以直接取就行。

下面我们再具体来理一下这里的实现逻辑:

因为要用多线程,所以我们这里可以申请两个全局buff列表即multiprocessing.manager.list(),这里简单起名为

buff_single和buff_batch,下面我们来定义一些变量,比较复杂,大家可以花点时间理一下,

装数据的进程不断的seek到一个个样本数据后装到buff_single列表。因为训练的数据进程取数据的时候希望是一个batch一个batch的取,所以我们还需要将buffsingle列表的数据组装成一个一个batch形式放到buffer_batch列表中,这个函数就暂时叫做fill_batch吧。

假设我们buff_single的大小设置为25600即最多装25600个样本。同时我们在装数据的时候(本质上是上述第一节seek的逻辑)希望又是一个多进程,这里假设用5个进程吧即num_process=5。

这里简单提示一下:看到这里很多人可能有个疑问,装数据和取数据本身就是两个进程?那装数据又是多个进程?一边装一边取的关键trick是怎么实现的呢?下面我们来看:

(1)我们首先来看取数据的逻辑

class DataLoader(object):
    def __iter__(self):
         if self.buffer_thread is None:
             self._fill_buf()
         while True:
             if len(self.buff_batch):
                 yield self.buffer_batch.pop(0)

通常来说我们声明了一个dataloader后,采用迭代器来不断的取即

train_dataloader = DataLoader()
train_dataloader = iter(train_dataloader)
while True:
    if step == total_step:
        break
    data = next(train_dataloader)

这里构造出来的train_dataloader迭代器,其实就是_iter_函数。

所以我们重点来看DataLoader这个类,可以看到_iter_函数是一个while True的逻辑,也就是说不断循环,而且不会停,即就是不断从self.buffer_batch中取batch,即使self.buffer_batch没有数据了,其还是会不断在这里循环等,直到有了数据后yield返回。这就是取数据的逻辑。很简单吧

那装数据的逻辑在那里呢?那就是_fill_buf这个函数,可以看到其实在第一次从train_dataloader这个迭代器中取数据的时候,也即第一次调用这个_iter_函数的时候,其实是先会进行_fill_buf这个函数的,它就是会开启多个进程不断的向self.buffer_batch装数据,特别需要注意,multiprocessing开启进程后,就是不断的自己去运行了,程序会接着往下走,也就是说_fill_buf相当于开启了后台进程,主程序继续往下走即while True这里。

上面这段话主要就是想解释怎么实现的一边装数据、一边取数据。大家可以多理解理解,说简单其实也很简单,就是_fill_buf在后台开启几个进程,往self.buffer_batch里面不断装数据。外部取数据就是不断从self.buffer_batch取,没有的话就是不停的while True进行循环直到self.buffer_batch里面装了数据就行,不需要多,只有有一个数据装进来就可以啦。

(2)接着我们来看怎么实现多进程装数据

class DataLoader(object):
    def __init__(self, num_process=1):
        self.buffer_thread = None
        self.num_process = num_process
    def _fill_buf(self):
         if self.buffer_thread is None:
             self.buffer_thread = []
             for process in range(self.num_process):
                 buffer_thread = Process(target=self.buf_thread, args=(process, self.num_process))
                 buffer_thread.start()
                 self.buffer_thread.append(buffer_thread)
    def __iter__(self):
         if self.buffer_thread is None:
             self._fill_buf()
         while True:
             if len(self.buff_batch):
                 yield self.buffer_batch.pop(0)

可以看到我们这里就是很简单的开启了多个进程,即开启了self.num_process个进程运行self.buf_thread程序,然后就是start,其就在后台不断的运行啦,为什么说不断?是的就是不断!self.buf_thread函数是不会停的,其就是不断的装数据,一遍又一遍的循环数据,即达到一个epoch后,再重新进行取新一轮的epoch。

这里再理一下这里的逻辑:

大家可以看到不论是取数据还是装数据,其实内部都是一个while True的逻辑即不断循环,进而实现不断的取不断的装的逻辑,这里弱化了epoch的概念,是不断的取数据流,重复一遍一遍的取,那程序整个结束的标志在哪里呢?其实是在外面的,即在训练流程(看上面):

if step == total_step:
    break

所以看的是全程是step来控制,没有epoch,当然了有了step自己推一下epoch是很轻松的事。

好啦,言归正传,现在回到正题,目前最重要的应该是self.buf_thread这个函数了,其的作用就是不断的往self.buffer_batch里面装batch数据,当然了self.buffer_batch是一个全局变量,因为我们是开启了多个进程同时都往self.buffer_batch里面装数据的。下面我们看看buf_thread的实现

from multiprocessing import Process, Manager, Lock

class DataLoader(object):
    def __init__(self, args, num_process=1):
        self.buffer_thread = None
        self.dataset_idx_path = args.dataset_idx_path
        self.dataset_path = args.dataset_idx_path
        self.instances_buffer_size = args.instances_buffer_size
        self.batch_size = args.batch_size
        self.num_process = num_process
        self.buffer_single = manager.list()
        self.buffer_batch = manager.list()

    def _fill_buf(self):
        if self.buffer_thread is None:
            self.buffer_thread = []
            for process in range(self.num_process):
                buffer_thread = Process(target=self.buf_thread, args=(process, self.num_process))
                buffer_thread.start()
                self.buffer_thread.append(buffer_thread)

    def buf_thread(self, process, num_process):
        print('=========start buf thread')
        read_count = 0
        while True:
            with open(self.dataset_idx_path, 'rb') as f:
                dataset_idx_list = pickle.load(f)
            f_read = open(self.dataset_path, "rb")
            num_data = len(dataset_idx_list)
            while True:
                #self.buffer_single装满啦,等等训练取走一些数据再装
                if len(self.buffer_single) >= self.instances_buffer_size:
                    max_batch_buffer = max(self.instances_buffer_size / self.batch_size, 256)
                    if len(self.buffer_batch) < max_batch_buffer:
                        self._fill_batch()
                    else:
                        time.sleep(0.1)
                        continue
                #装到当前epcoh最后一个数据了
                if read_count >= num_data:
                    break
                read_count += 1
                #多个进程互不影响,确保不重复装数据
                if read_count % num_process != process:
                    continue
                start_idx = dataset_idx_list[read_count]
                #一个个装对应位置的数据
                f_read.seek(start_idx, 0)
                self.buffer_single.append(f_read.readline())
            f_read.close()
            read_count = 0

    def __iter__(self):
        if self.buffer_thread is None:
            self._fill_buf()
        while True:
            if len(self.buffer_batch):
                yield self.buffer_batch.pop(0)

首先可以看到有两个while True,最外面那个确保的是不断的一遍一遍的取整个数据集。最里面那个就是一个个取装到buffer,首先我们取一个个的是先装到self.buffer_single,当装慢时(>=self.instances_buffer_size),就开始组batch,即调用的是_fill_batch()这个函数(上面说过),说白了起就是将self.buffer_single里面的单个样本组合成batch的形式放到另一个buffer中即self.buffer_batch供程序取数据进行训练。

需要关注的是:

if read_count % num_process != process:
    continue

这里本质上是为了确保不同进程装不同的数据,进而避免装重复了。

if len(self.buffer_batch) < max_batch_buffer:
    self._fill_batch()
else:
    time.sleep(0.1)
    continue

同时需要注意一点的是文件的读取等都要放到buf_thread子函数中,比如不能把 f_read

 f_read = open(self.dataset_path, "rb")

作为一个类变量,因为多进程是可以抢的,假设当成一个类全局变量,那么可能遇到这种情况:进程1刚将 f_read seek到自己想要的位置,下一步马上就要f_read.readline()了,结果另外一个进程抢先又seek了一下,所以最后结果就是乱的,所以每个进程都自己open一个数据流f_read避免乱。

下面我们再看组bacth的函数_fill_batch

import numpy as np
from multiprocessing import Process, Manager, Lock

class DataLoader(object):
    def __init__(self, args, num_process=1):
        self.buffer_thread = None
        self.dataset_idx_path = args.dataset_idx_path
        self.dataset_path = args.dataset_idx_path
        self.instances_buffer_size = args.instances_buffer_size
        self.batch_size = args.batch_size
        self.num_process = num_process
        self.buffer_single = manager.list()
        self.buffer_batch = manager.list()
        self.buffer_lock = Lock()

    def _fill_buf(self):
        if self.buffer_thread is None:
            self.buffer_thread = []
            for process in range(self.num_process):
                buffer_thread = Process(target=self.buf_thread, args=(process, self.num_process))
                buffer_thread.start()
                self.buffer_thread.append(buffer_thread)

    def _fill_batch(self):
        while len(self.buffer_single) > self.instances_buffer_size * 0.75:
            self.buffer_lock.acquire()
            num_data = len(self.buffer_single)
            batch_idx = np.random.choice(num_data, self.batch_size, replace=num_data < self.batch_size)
            batch_idx.sort()
            instances = [self.buffer_single.pop(i) for i in batch_idx[::-1]]
            self.buffer_lock.release()
            for ins in instances:
                self.buffer_batch.append(ins)

    def buf_thread(self, process, num_process):
        print('=========start buf thread')
        read_count = 0
        while True:
            with open(self.dataset_idx_path, 'rb') as f:
                dataset_idx_list = pickle.load(f)
            f_read = open(self.dataset_path, "rb")
            num_data = len(dataset_idx_list)
            while True:
                #self.buffer_single装满啦,等等训练取走一些数据再装
                if len(self.buffer_single) >= self.instances_buffer_size:
                    max_batch_buffer = max(self.instances_buffer_size / self.batch_size, 256)
                    if len(self.buffer_batch) < max_batch_buffer:
                        self._fill_batch()
                    else:
                        time.sleep(0.1)
                        continue
                #装到当前epcoh最后一个数据了
                if read_count >= num_data:
                    break
                read_count += 1
                #多个进程互不影响,确保不重复装数据
                if read_count % num_process != process:  # skip
                    continue
                start_idx = dataset_idx_list[read_count]
                #一个个装对应位置的数据
                f_read.seek(start_idx, 0)
                self.buffer_single.append(f_read.readline())
            f_read.close()
            read_count = 0

    def __iter__(self):
        if self.buffer_thread is None:
            self._fill_buf()
        while True:
            if len(self.buffer_batch):
                yield self.buffer_batch.pop(0)

可以看到_fill_batch其实就是随机取self.buffer_single这个buffer中的一批数据来做batch放到self.buffer_batch中,并且把对应的用掉的数据从self.buffer_single中pop出来,小插曲:为啥这里要用sort呢?其实是由于python的pop函数导致的,我们是按index从大到小pop,这样就不会乱,不然pop的过程中,self.buffer_single的大小是变化的,不从大到小确保不了pop出对应位置的数据。

注意这里使用了multiprocessing的Lock,因为所有进程都是可能进到_fill_batch这个函数进行操作的,那大家都可能会同时pop self.buffer_single这个数据流,那 self.buffer_single的长度都是变化的,进而导致都是乱的,你pop一下,另外一个进程冷不丁的pop一下,所以这里用了Lock即同一时间只能有一个进程进行这里的操作。

当然了对于往self.buffer_single和 self.buffer_batch装数据这一过程(append)都没使用Lock,那是因为不需要,大家同时合力装就可以啦,你一下我一下,顺序无所谓。

多机多卡

本节需要的知识是多机多卡的基本概念,比如:local_rank、world_size、rank,不需要的小伙伴可以先查阅一下相关知识。

我们同时还想在单机多卡上面,甚至多机多卡上面跑,那这里又有什么变化呢?

试想使用了多卡后,对于dataloder我们本质要解决的问题是什么呢?

那就是不同卡应该加载不同的数据,之间不能有重复,说白了,如果我们可以把数据分段,假设现在一共8张卡,那我们把数据分成8段,只让对应卡加载对应的数据段上面的数据不就可以啦。

那怎么实现呢?非常非常简单

这里我们来说最复杂的场景:多机多卡

那在多机多卡的场景中每一张卡的全局ID可以使用rank,那每一张卡对应的自己数据片段的全部数据就应该是:

to_read_idx = dataset_idx_list[rank:: world_size]

即在上面dataloder中补上这个逻辑就可以啦。

当然了pytorch也提供了torch.utils.data.distributed.DistributedSampler这个API,其会自动取对应自己数据片段的index,而不会去取别人的数据,本质上和上面的实现逻辑一样。如果懒得写,可以直接使用torch.utils.data.distributed.DistributedSampler。

那如果是单机多卡呢?没有rank和world_size,其只有一个local_rank参数。其实更好办了,我们只需要理解了其实我们需要两个参数,一个是当前卡的ID,另外一个是总卡数,那么就可以通过如下代码实现唯一切片数据

to_read_idx = dataset_idx_list[当前卡的ID:: 总卡数]

所以对应到单机多卡便可以这样写:

to_read_idx = dataset_idx_list[duo]

dev dataloader怎么设计?

最后再说一个大家可能遇到的问题,就是在训练的时候,有可能需要在固定步数看一下验证集上的效果,那必然需要一个dev dataloader,可能有人会说,这有啥,就用上面的就行,可是这里会有一些问题,大家可以停几分钟想一下直接用上面的会有什么问题?

好,下面我们来说一下,不知道大家注意到没有,之前我们也说了,上面的dataloader是一直处于while True的逻辑的,不会停,循环装数据,而我们的dev dataloader是希望过一遍dev数据即可,所以适当的stop就尤为重要。

这时候可能有的小伙伴会说,计算出对应的最后一个数据stop就行啦,但是这里涉及到多进程,要等所有进程结束,所以这里又是一个比较复杂的逻辑实现,注意train dataloader其实通过弱化epoch,使用step不断取的过程简化了该步骤。

当然了这不是什么难事,大家看buf_thread函数最里面while True这个函数的break,其实就是一个epoch的结束。

这里将train dataloder和dev dataloader统一写出一个dataloder,为了简单,我们dev dataloader可以写成之前的最基本的dataloader形式,主要基于以下两个原因:

(a)通常来说dev数据不大,不太需要多进程以及边加载边取数据

(b)在多卡训练场景,dev数据可以只在一张固定的卡上进行(比如卡0),所以也不会涉及到多机多卡一节中说到的数据切片,直接在一张卡上过所有数据即可。

总结

看完后最好是将其逻辑应用到自己的当前dataloder,代码不一定要按上面写,可以结合自己的场景随意改动,甚至可以用到tensorflow中,所以理解其逻辑最重要,如果能跑通,恭喜你,说明理解了所有逻辑。

最后给一份笔者目前整理的一个能跑通的dataloder,欢迎star~,后续有时间会持续优化一下相关代码支持更多的逻辑比如train dataloder不想用边装边取,只想普通的等等。

GitHub – Mryangkaitong/big_dataloaderContribute to Mryangkaitong/big_dataloader development by creating an account on GitHub.数据过大时dataloader怎么设计?https://github.com/Mryangkaitong/big_dataloader

彩蛋

文章开头我们介绍了要尽可能的将我们的数据预处理好,那当我们的数据非常大的时候,势必要用到一些大数据工具比如spark,haddop等等,下面一篇我们会简单介绍一下pyspark基本dmeo,供大家快速上手~

关注

欢迎关注,下期再见啦~

欢迎关注笔者微信公众号:

数据过大时dataloader怎么设计?

github:

https://github.com/Mryangkaitong​github.com

知乎:

小小梦想 – 知乎ML/NLP研究员,欢迎关注微信公众号“算法让生活更美好” 回答数 111,获得 85 次赞同数据过大时dataloader怎么设计?https://www.zhihu.com/people/sa-tuo-de-yisheng/posts

版权声明:本文为博主weixin_42001089原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/weixin_42001089/article/details/122641298

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年1月22日 下午6:21
下一篇 2022年1月22日 下午6:53

相关推荐