from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
import torch
import torchvision as tv

# Small fully-convolutional network. With PyTorch's default stride/padding,
# a single-channel 64x64 input becomes:
#   1x64x64 -> 4x58x58 (7x7 conv, no padding) -> 4x29x29 (2x2 max-pool)
#   -> 2x29x29 (1x1 conv, i.e. a per-pixel 2-way output).
# NOTE(review): the 7x7 kernel matches the 7x7 squares this script draws into
# its test image -- presumably intentional; confirm. The meaning of the
# original `#2` / `#1` markers is not explained in this file.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 4, 7), #2
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2), #1
    torch.nn.Conv2d(4, 2, 1))
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; this script only ever feeds CPU tensors to the model.
model.load_state_dict(torch.load("detect.pt", map_location="cpu"))
# The network has no dropout/batch-norm today, so eval() changes nothing now,
# but it keeps inference correct if such layers are ever added.
model.eval()

# Build a synthetic test image: a random number (0-7) of filled white 7x7
# squares at random centres on a black 64x64 greyscale ('L') canvas. PIL
# rectangles include both corner points, so +/-3 around a centre is 7 pixels
# wide; squares near the border are simply clipped.
count = torch.randint(8, ()).item()
POSITIONS = torch.randint(64, (count, 2))  # one (x, y) centre per square
image = Image.new('L', (64, 64))
draw = ImageDraw.Draw(image)
# .tolist() yields plain Python ints, so PIL gets ordinary numeric coordinates
# instead of relying on it coercing 0-d tensors to numbers.
for x, y in POSITIONS.tolist():
    draw.rectangle([(x - 3, y - 3), (x + 3, y + 3)], fill = 255)
# to_tensor gives a (1, 64, 64) float tensor in [0, 1]; unsqueeze(0) adds a
# batch axis -> (1, 1, 64, 64), the layout Conv2d expects.
DATA = tv.transforms.functional.to_tensor(image).unsqueeze(0)

plt.imshow(DATA[0, 0])

# Pure inference: disable autograd so no graph is recorded. As a bonus the
# result does not require grad, so it can go straight to NumPy / plt.imshow.
with torch.no_grad():
    ACTIVATION = model(DATA)
# For the 1x1x64x64 input built in this script the network defined here
# yields torch.Size([1, 2, 29, 29]).
# flush so the shape is visible before plt.show(), which typically blocks.
print(ACTIVATION.shape, flush = True)

plt.show()
