Update: Here https://colab.research.google.com/drive/1VS6-dYk3YAzoRmALhgTK7bb2_tBPrB4c?usp=sharing是一个小型协作笔记本,用于演示这个答案。
想象一下,您有一个数据集:[1, 2, 3, 4, 5, 6]
, then:
ds.shuffle() 的工作原理
dataset.shuffle(buffer_size=3)
将分配一个大小为 3 的缓冲区来选择随机条目。该缓冲区将连接到源数据集。
我们可以这样想象:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
我们假设该条目2
是从随机缓冲区中取出的。可用空间由源缓冲区中的下一个元素填充,即4
:
2 <= [1,3,4] <= [5,6]
我们继续阅读,直到什么都没有剩下:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
ds.repeat() 的工作原理
一旦从数据集中读取了所有条目并且您尝试读取下一个元素,数据集就会抛出错误。
那就是那里ds.repeat()
发挥作用。它将重新初始化数据集,使其再次如下所示:
[1,2,3] <= [4,5,6]
ds.batch() 会产生什么
The ds.batch()
将采取第一个batch_size
条目并从中制作一批。因此,我们的示例数据集的批量大小为 3 将生成两个批量记录:
[2,1,5]
[3,6,4]
因为我们有一个ds.repeat()
在批量之前,数据的生成将继续。但元素的顺序会有所不同,因为ds.random()
。应该考虑的是6
由于随机缓冲区的大小,永远不会出现在第一批中。