import torch
import torchvision as tv
import torchsummary

# Load MNIST once and keep the full train (0) / test (1) tensors resident on
# the compute device, so the loops below can take batches as plain slices.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

source0 = tv.datasets.MNIST("../MNIST", train = True, download = True)
source1 = tv.datasets.MNIST("../MNIST", train = False, download = True)

# Take the sample counts from the datasets rather than hard-coding them
# (60000 train / 10000 test for MNIST).
samples0, samples1 = len(source0), len(source1)

# Add the channel axis expected by Conv2d and scale raw 0-255 pixel
# intensities to [0, 1]; unscaled inputs make first-layer activations far
# larger than the default initialisation assumes, which slows optimisation.
DATA0 = (source0.data.unsqueeze(1).float() / 255).to(DEVICE)
DATA1 = (source1.data.unsqueeze(1).float() / 255).to(DEVICE)
TARGET0 = source0.targets.to(DEVICE)
TARGET1 = source1.targets.to(DEVICE)

# Small CNN: one 5x5 convolution (28 -> 24), 2x2 max-pool (24 -> 12), then a
# linear layer mapping the 8 * 12 * 12 pooled features to 10 class logits.
# The model is placed on whatever device the data tensors live on.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 5), #24
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2), #12
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 12 * 12, 10)).to(DATA0.device)
# Materialise the parameters: model.parameters() is a one-shot generator, so
# any second iteration over `variables` would otherwise see nothing.
variables = list(model.parameters())

# torchsummary defaults to device="cuda"; pass the actual device type so the
# summary also works when the data lives on the CPU.
torchsummary.summary(model, input_size = DATA0.shape[1:], device = DATA0.device.type)

# Train for 100 epochs with Adam, then evaluate on the test split after each
# epoch. Per epoch it prints: epoch, mean train loss, train accuracy, mean
# test loss, test accuracy.
batch = 1000
optimizer = torch.optim.Adam(variables, lr = 0.0001)
device = DATA0.device
for epoch in range(100):
    # ---- training pass ----
    model.train()
    # Reshuffle each epoch so batches are not seen in an identical fixed order.
    order = torch.randperm(samples0, device = device)
    LOSS0 = torch.zeros((), device = device)
    ACCURACY0 = torch.zeros((), device = device)
    count0 = 0
    for index in range(0, samples0, batch):
        optimizer.zero_grad()
        SELECT = order[index : index + batch]
        DATA = DATA0[SELECT]
        TARGET = TARGET0[SELECT]
        count = TARGET.size(0)
        ACTIVATION = model(DATA)
        LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
        # Detach before accumulating: adding the live LOSS in place would turn
        # LOSS0 into an autograd node chaining every batch's graph together.
        # cross_entropy returns the batch mean, so weight by batch size.
        LOSS0 += LOSS.detach() * count
        VALUE = torch.argmax(ACTIVATION, 1)
        ACCURACY0 += torch.sum(VALUE == TARGET)
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    ACCURACY0 /= count0

    # ---- evaluation pass (no gradients) ----
    model.eval()
    with torch.no_grad():
        LOSS1 = torch.zeros((), device = device)
        ACCURACY1 = torch.zeros((), device = device)
        count1 = 0
        for index in range(0, samples1, batch):
            DATA = DATA1[index : index + batch]
            TARGET = TARGET1[index : index + batch]
            ACTIVATION = model(DATA)
            LOSS1 += torch.nn.functional.cross_entropy(ACTIVATION, TARGET, reduction = "sum")
            VALUE = torch.argmax(ACTIVATION, 1)
            ACCURACY1 += torch.sum(VALUE == TARGET)
            count1 += TARGET.size(0)
        LOSS1 /= count1
        ACCURACY1 /= count1
    print("%5d %12.3f %4.3f %12.3f %4.3f" %
          (epoch, LOSS0.item(), ACCURACY0.item(), LOSS1.item(), ACCURACY1.item()),
          flush = True)

#parameters: 11,738 (Conv2d 8*5*5 + 8 = 208; Linear 1152*10 + 10 = 11,530)
#slow convergence
