import gym
import torch

# Number of CartPole environments that are stepped side by side.
size = 64

# One independent environment instance per slot.
env = [gym.make("CartPole-v1") for _ in range(size)]

# Network mapping a 4-value state to 2 per-action scores, with two hidden
# layers of 128 ReLU units each.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Episode budget, initial exploration rate and finished-episode counter.
epochs = 1000
epsilon = 1.
epoch = 0

# Per-environment bookkeeping: current state, steps taken in the running
# episode, reward accumulated in the running episode, and the done flag.
state0 = [torch.tensor(instance.reset(), dtype=torch.float) for instance in env]
step1 = [0 for _ in env]
gain1 = [0. for _ in env]
done1 = [False for _ in env]
# Training loop: one-step Q-learning over all `size` environments in lockstep.
# Every pass of the outer loop steps each environment once, averages the
# squared temporal-difference errors of those transitions and applies a single
# optimizer step. `epoch` counts finished episodes and is only checked between
# passes, so slightly more than `epochs` episodes may be completed.
while epoch < epochs:
    loss = torch.tensor(0.)
    for index in range(size):
        quality0 = model(state0[index])
        # Epsilon-greedy selection: a uniformly random action with probability
        # `epsilon`, otherwise the action with the highest predicted value.
        action0 = torch.randint(2, ()) if torch.rand(()) < epsilon else quality0.argmax()
        state1, reward1, done1[index], info = env[index].step(action0.item())
        state1 = torch.tensor(state1, dtype=torch.float)
        step1[index] += 1
        gain1[index] += reward1
        # Only an episode that ends before step 500 is treated as terminal;
        # one that ends at step 500 (presumably the CartPole-v1 time limit,
        # confirm against the environment spec) still bootstraps from state1.
        terminal1 = done1[index] and step1[index] < 500
        # The bootstrap target is a constant with respect to the parameters,
        # so evaluate the next state without recording an autograd graph
        # instead of building one and detaching the result afterwards.
        with torch.no_grad():
            quality1 = model(state1)
            target0 = reward1 + 0.99 * (not terminal1) * quality1.max()
        loss += torch.square(quality0[action0] - target0)
        if done1[index]:
            # Report: finished-episode index, episode length, episode return.
            print("%4d %4d %12.3f" % (epoch, step1[index], gain1[index]), flush=True)
            state0[index] = torch.tensor(env[index].reset(), dtype=torch.float)
            step1[index] = 0
            gain1[index] = 0.
            done1[index] = False
            # Lower epsilon by 1/epochs per finished episode while it is
            # still above 0.1.
            epsilon = epsilon - 1. / epochs if 0.1 < epsilon else epsilon
            epoch += 1
        else:
            state0[index] = state1
    # Mean loss over the `size` transitions collected in this pass.
    loss /= size
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Training is over: release every training environment.
for instance in env:
    instance.close()

# Persist the trained parameters (state_dict only) to disk.
torch.save(model.state_dict(), "multi10.pt")

# Evaluation: run the trained model greedily (no exploration) on a fresh
# environment and report the mean and standard deviation of episode length.
env = gym.make("CartPole-v1")

epochs = 1000
stats = torch.empty(epochs)
# Nothing is backpropagated here, so disable autograd to avoid recording a
# graph for every forward pass of the rollout.
with torch.no_grad():
    for epoch in range(epochs):
        state0 = torch.tensor(env.reset(), dtype=torch.float)
        step1 = 0
        done1 = False
        while not done1:
            quality0 = model(state0)
            # Always take the action with the highest predicted value.
            action0 = quality0.argmax()
            state1, reward1, done1, info = env.step(action0.item())
            state1 = torch.tensor(state1, dtype=torch.float)
            step1 += 1
            state0 = state1
        # Episode length in steps.
        stats[epoch] = step1
print(stats.mean().item(), stats.std().item())

env.close()

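# Recorded results (presumably mean +- std of the evaluation episode length;
# the labels "a" and "b" are not defined in this file -- TODO confirm):
#a: 500+-00
#b: 067+-06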
