from matplotlib import pyplot as plt
import torch
import torchvision as tv

# Load both MNIST splits (downloading them on first use) and keep them on the
# GPU as one flattened float row per image, divided by 255 (MNIST pixels are
# uint8, so this maps them to [0, 1]). This is an autoencoder, so each split's
# target is the split itself.
source0 = tv.datasets.MNIST("../MNIST", train = True, download = True)
source1 = tv.datasets.MNIST("../MNIST", train = False, download = True)
DATA0 = source0.data.flatten(1).float().cuda() / 255.
DATA1 = source1.data.flatten(1).float().cuda() / 255.
TARGET0 = DATA0
TARGET1 = DATA1
# Derive the split sizes from the loaded tensors instead of hard-coding them,
# so the batch loops always cover exactly the rows that were actually loaded.
samples0, samples1 = DATA0.size(0), DATA1.size(0)

def _mlp(widths):
    """Chain Linear layers through consecutive `widths` entries.

    A ReLU follows every Linear except the last, which is followed by a
    Sigmoid so the output lies in (0, 1).
    """
    layers = []
    for width_in, width_out in zip(widths, widths[1:]):
        layers += [torch.nn.Linear(width_in, width_out), torch.nn.ReLU()]
    # Swap the trailing ReLU for the output squashing nonlinearity.
    layers[-1] = torch.nn.Sigmoid()
    return torch.nn.Sequential(*layers)


# Symmetric autoencoder: 784 -> 256 -> 64 -> 16 (bottleneck) -> 64 -> 256 -> 784.
model = _mlp([28 * 28, 16 * 16, 8 * 8, 4 * 4, 8 * 8, 16 * 16, 28 * 28]).cuda()

# Train with Adam for 100 epochs on mini-batches of `batch` rows. Each epoch
# prints: the epoch index, the sample-weighted mean MSE accumulated over the
# training pass (while the weights are still changing), and the
# sample-weighted mean MSE on the test split evaluated after that pass.
batch = 1000
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
    model.train()
    LOSS0 = torch.zeros((), device = "cuda")
    count0 = 0
    for index in range(0, samples0, batch):
        optimizer.zero_grad()
        DATA = DATA0[index : index + batch]
        TARGET = TARGET0[index : index + batch]
        count = TARGET.size(0)  # the last slice may be shorter than `batch`
        ACTIVATION = model(DATA)
        LOSS = torch.nn.functional.mse_loss(ACTIVATION, TARGET)
        # Detach before accumulating. The running total is for reporting
        # only; adding the live LOSS would attach every batch's autograd
        # graph to LOSS0 for the whole epoch, even though it is never
        # backpropagated.
        LOSS0 += LOSS.detach() * count
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    model.eval()
    with torch.no_grad():
        LOSS1 = torch.zeros((), device = "cuda")
        count1 = 0
        for index in range(0, samples1, batch):
            DATA = DATA1[index : index + batch]
            TARGET = TARGET1[index : index + batch]
            count = TARGET.size(0)
            ACTIVATION = model(DATA)
            LOSS = torch.nn.functional.mse_loss(ACTIVATION, TARGET)
            LOSS1 += LOSS * count
            count1 += count
        LOSS1 /= count1
    print("%4d %12.3f %12.3f" % (epoch, LOSS0, LOSS1), flush = True)

# Show the first 10 test images (top row) above the model's reconstructions
# of them (bottom row). Inference runs without gradient tracking.
with torch.no_grad():
    DATA = DATA1[:10]
    VALUE = model(DATA)
DATA, VALUE = (
    images.reshape(10, 1, 28, 28).cpu() for images in (DATA, VALUE)
)
# make_grid expands single-channel images to 3 channels (C, H, W);
# imshow wants channels last, hence the permute.
GRID = tv.utils.make_grid(torch.cat([DATA, VALUE]), nrow = 10)
plt.imshow(GRID.permute(1, 2, 0))

plt.show()
