from matplotlib import pyplot as plt
import torch
import torchvision as tv

# Dataset location; download=True fetches STL10 here if it is not already present.
ROOT = "../STL10"
# One shared transform for both splits. The grid below is permuted with
# (1, 2, 0) before display, i.e. the tensors are treated as channels-first.
TRANSFORM = tv.transforms.ToTensor()

set0 = tv.datasets.STL10(ROOT, split = "train", download = True,
                         transform = TRANSFORM)
set1 = tv.datasets.STL10(ROOT, split = "test", download = True,
                         transform = TRANSFORM)

# Number of samples in the train and test splits.
print(len(set0), len(set1), flush = True)

# Inspect a single sample to confirm the tensor dtype and shape.
DATA, target = set0[0]
print(DATA.dtype, DATA.shape, flush = True)

# Browse the training split 25 images at a time, laid out 5 per row.
# Press Enter for the next batch; type "q" (or send EOF) to stop early.
loader0 = torch.utils.data.DataLoader(set0, batch_size = 25)
for DATA, TARGET in loader0:
    GRID = tv.utils.make_grid(DATA, nrow = 5)
    # matplotlib expects (H, W, C); convert explicitly instead of relying on
    # implicit tensor -> array coercion.
    plt.imshow(GRID.permute(1, 2, 0).numpy())
    print(TARGET, flush = True)
    plt.show()
    try:
        answer = input()
    except EOFError:
        # stdin closed (e.g. run non-interactively): stop instead of crashing.
        break
    if answer.strip().lower() == "q":
        break
