pytorch_lightning_spells.samplers module

Classes:

SortSampler(data_source, key)

Go through the text data by order of length (longest to shortest).

SortishSampler(data_source, key, bs[, ...])

Go through the text data by order of length with a bit of randomness.

class pytorch_lightning_spells.samplers.SortSampler(data_source, key)[source]

Bases: Sampler

Go through the text data by order of length (longest to shortest).

Taken from the fast.ai library.

Parameters:
  • data_source (Iterable) – The data you want to sample from.

  • key (Callable) – A function that maps the index of an entry in data_source to its sort key.

Example

>>> data_source = [[0], [0, 1], [0, 1, 2, 3]]
>>> sampler = SortSampler(data_source, key=lambda idx: len(data_source[idx]))
>>> len(list(sampler))
3
>>> next(iter(sampler)) # the longest entry is the third one.
2
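The ordering above can be reproduced in plain Python as a descending sort of indices by key. The helper below is an illustrative sketch, not the library's actual source:

```python
def sort_sampler_indices(data_source, key):
    """Return indices of data_source sorted by key, largest first."""
    return sorted(range(len(data_source)), key=key, reverse=True)

data_source = [[0], [0, 1], [0, 1, 2, 3]]
order = sort_sampler_indices(data_source, key=lambda idx: len(data_source[idx]))
# the longest entry (index 2) comes first: [2, 1, 0]
```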
class pytorch_lightning_spells.samplers.SortishSampler(data_source, key, bs, chunk_size=100)[source]

Bases: Sampler

Go through the text data by order of length with a bit of randomness.

Returns an iterator that traverses the data in randomly ordered batches of approximately the same size.

The data is first randomly shuffled and then put into a number of chunks. The data in each chunk is then sorted and sliced to get batches that are approximately the same size.

The batch with the largest key is always returned first because of PyTorch's CUDA memory allocation sequencing: if the first buffer allocated is not large enough to hold a later batch, additional buffers have to be allocated instead of reusing the first one.
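The shuffle, chunk, sort, and slice procedure described above can be sketched in plain Python. The function below is an approximation for illustration only; the helper name, seeding, and tie-breaking are assumptions, not the library's implementation:

```python
import random

def sortish_indices(data_source, key, bs, chunk_size=100, seed=0):
    """Approximate the shuffle -> chunk -> sort -> batch procedure."""
    rng = random.Random(seed)
    idxs = list(range(len(data_source)))
    rng.shuffle(idxs)                        # 1. random shuffle
    span = chunk_size * bs                   # indices per chunk
    # 2. sort each chunk by key, largest first
    idxs = [i for start in range(0, len(idxs), span)
            for i in sorted(idxs[start:start + span], key=key, reverse=True)]
    # 3. slice into batches of size bs
    batches = [idxs[start:start + bs] for start in range(0, len(idxs), bs)]
    # 4. shuffle the batch order, then move the batch holding the
    #    largest key to the front so the biggest buffer is allocated first
    rng.shuffle(batches)
    largest = max(range(len(batches)),
                  key=lambda b: max(key(i) for i in batches[b]))
    batches.insert(0, batches.pop(largest))
    return [i for batch in batches for i in batch]
```

Because each chunk is sorted independently after a global shuffle, batches are only approximately length-homogeneous, which keeps the traversal order different between epochs.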

Taken from the fast.ai library.

Parameters:
  • data_source (Iterable) – The data you want to sample from.

  • key (Callable) – A function that maps the index of an entry in data_source to its sort key.

  • bs (int) – The batch size used by the data loader.

  • chunk_size (int, optional) – The number of batches in each chunk. Defaults to 100.

Example

>>> data_source = [[0], [0, 1]] * 100
>>> sampler = SortishSampler(data_source, key=lambda idx: len(data_source[idx]), bs=2, chunk_size=2)
>>> len(list(sampler))
200
>>> len(data_source[next(iter(sampler))]) # the largest/longest batch always goes first
2
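The payoff of length-ordered batches is minimal padding: sequences within a batch have similar lengths, so little space is wasted on pad tokens. A torch-free sketch of the idea (pad_batch is a hypothetical helper, not part of the library):

```python
def pad_batch(batch, pad_value=0):
    """Pad a list of variable-length sequences to the batch's max length."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in batch]

data_source = [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]]
# indices sorted longest-first, as SortSampler would yield them
order = sorted(range(len(data_source)),
               key=lambda i: len(data_source[i]), reverse=True)
batches = [order[i:i + 2] for i in range(0, len(order), 2)]
padded = [pad_batch([data_source[i] for i in b]) for b in batches]
# each batch is only padded to its own max length, not the global max
```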