from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
import torchvision as tv
import torch

batch = 100
# Toy binary dataset: each sample is an 8x8 single-channel image that either
# contains one filled box with corners (1, 1) and (7, 7) (label 1) or is left
# blank (label 0).
DATA = torch.empty(batch, 1, 8, 8)
TARGET = torch.empty(batch, dtype=torch.int64)
for sample in range(batch):
    # 0-dim tensor holding 0 or 1; one draw per sample keeps the RNG call
    # sequence identical to a per-sample loop (matters when a seed is set).
    count = torch.randint(2, ()).item()
    image = Image.new('L', (8, 8))  # blank (all-zero) grayscale canvas
    if count:
        # The box geometry is fixed, so drawing it more than once would not
        # change the image; draw it at most once rather than `count` times.
        # NOTE(review): if the label range is ever widened beyond {0, 1},
        # each extra shape needs a distinct position or the labels become
        # indistinguishable.
        ImageDraw.Draw(image).rectangle([(1, 1), (7, 7)], fill=255)
    # to_tensor turns the 'L' image into a (1, 8, 8) float tensor in [0, 1].
    DATA[sample] = tv.transforms.functional.to_tensor(image)
    TARGET[sample] = count

# Tile all samples into one image, ten per row, for visual inspection.
# make_grid returns a channels-first (C, H, W) tensor ...
GRID = tv.utils.make_grid(tensor=DATA, nrow=10)
# ... while imshow wants channels-last (H, W, C), so move the channel axis.
plt.imshow(GRID.permute((1, 2, 0)))

# Dump the labels in sample order so they can be read off the grid
# row by row; flush so they appear before the blocking window opens.
print(TARGET, flush=True)

plt.show()
