import torch
import torchvision as tv

# Presumably the nominal train/test split sizes; not referenced anywhere else
# in the visible code.
samples0 = 60000
samples1 = 10000

# Suffix 0 is the training split, suffix 1 the test split; both are stored
# under ../MNIST and downloaded if not already present.
source0, source1 = (
    tv.datasets.MNIST("../MNIST", train=is_train, download=True)
    for is_train in (True, False)
)

# Insert a channel axis at dim 1 and cast the pixel tensors to float32.
# NOTE(review): no scaling/normalisation is applied after .float() — confirm
# this matches the preprocessing used when style.pt was trained.
DATA0, DATA1 = (split.data.unsqueeze(1).float() for split in (source0, source1))
TARGET0, TARGET1 = source0.targets, source1.targets

class Model(torch.nn.Module):
    """Two-conv / three-linear CNN producing 10 raw output scores per sample.

    The layer sizes are chosen for single-channel 28x28 input (spatial size
    28 -> 24 -> 12 -> 8 -> 4). The attribute names below become the
    state_dict keys, so they must not change or saved checkpoints will no
    longer load.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv0 = torch.nn.Conv2d(1, 8, 5)   # 28x28 -> 24x24
        self.relu0 = torch.nn.ReLU()
        self.pool0 = torch.nn.MaxPool2d(2)      # -> 12x12
        self.conv1 = torch.nn.Conv2d(8, 16, 5)  # -> 8x8
        self.relu1 = torch.nn.ReLU()
        self.pool1 = torch.nn.MaxPool2d(2)      # -> 4x4
        # Classifier head over the flattened 16 * 4 * 4 = 256 features.
        self.flat1 = torch.nn.Flatten()
        self.line2 = torch.nn.Linear(16 * 4 * 4, 128)
        self.relu2 = torch.nn.ReLU()
        self.line3 = torch.nn.Linear(128, 64)
        self.relu3 = torch.nn.ReLU()
        self.line4 = torch.nn.Linear(64, 10)

    def forward(self, SIGNAL):
        """Apply every layer in declaration order; no softmax at the end."""
        stages = (
            self.conv0, self.relu0, self.pool0,
            self.conv1, self.relu1, self.pool1,
            self.flat1,
            self.line2, self.relu2,
            self.line3, self.relu3,
            self.line4,
        )
        out = SIGNAL
        for stage in stages:
            out = stage(out)
        return out

model = Model()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the tensors into the model's parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True to restrict what is loaded).
model.load_state_dict(torch.load("style.pt", map_location="cpu"))
# Model has no dropout/batch-norm, so eval() does not change the output here,
# but it documents that this is inference and stays correct if layers change.
model.eval()

CONTENTS = DATA0[:1]  # first training image, kept as a batch of one
# Pure inference: skip building an autograd graph for the forward pass.
with torch.no_grad():
    SIGNALC = model(CONTENTS)
print(SIGNALC)
