import collections
import gym
import random
import torch

# Single CartPole-v1 environment instance; it is stepped by both the training
# loop and the evaluation loop further down and closed at the end of the script.
env = gym.make("CartPole-v1")

def _build_q_network():
    """Return the Q-network: a 4 -> 128 -> 128 -> 2 ReLU multilayer perceptron.

    The input is a 4-component state vector and the output holds one value
    per action.  Layers are passed positionally so the resulting
    ``state_dict`` keys stay ``0.*``, ``2.*`` and ``4.*``.
    """
    width = 128
    return torch.nn.Sequential(
        torch.nn.Linear(4, width),
        torch.nn.ReLU(),
        torch.nn.Linear(width, width),
        torch.nn.ReLU(),
        torch.nn.Linear(width, 2),
    )


model = _build_q_network()

# Train the Q-network with one-step Q-learning, a replay buffer and an
# epsilon-greedy behaviour policy.
#
# Naming convention: a trailing 0 marks quantities at the current time step,
# a trailing 1 marks quantities produced by the transition to the next step;
# upper-case names are mini-batch tensors.
#
# NOTE(review): env.reset() / env.step() are unpacked the way the classic gym
# API returns them (reset -> observation, step -> 4-tuple) — confirm the
# installed gym version still provides that interface.
epochs = 1000      # number of training episodes
epsilon = 0.3      # initial probability of taking a random action
batch_size = 100   # transitions sampled per gradient step
gamma = 0.99       # discount factor applied to the bootstrapped value
max_steps = 500    # episodes ending at this step count are not stored as
                   # terminal (presumably CartPole-v1's time limit)
replay = collections.deque(maxlen=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(epochs):
    state0 = torch.tensor(env.reset(), dtype=torch.float)
    step1 = 0
    gain1 = 0.
    done1 = False
    while not done1:
        # Epsilon-greedy action selection.  The network is only consulted for
        # greedy moves, and without autograd bookkeeping because this forward
        # pass never takes part in a backward pass.
        if torch.rand(()) < epsilon:
            action0 = torch.randint(2, ())
        else:
            with torch.no_grad():
                action0 = model(state0).argmax()
        state1, reward1, done1, info = env.step(action0.item())
        state1 = torch.tensor(state1, dtype=torch.float)
        step1 += 1
        gain1 += reward1
        # Only an episode that ends before the step limit counts as terminal;
        # otherwise the target below keeps bootstrapping from state1.
        terminal1 = done1 and step1 < max_steps
        replay.append((state0, action0, reward1, state1, terminal1))
        if len(replay) > batch_size:
            batch = random.sample(replay, batch_size)
            # Transpose the list of transitions into per-field tuples in one pass.
            states0, actions0, rewards1, states1, terminals1 = zip(*batch)
            STATE0 = torch.stack(states0)
            ACTION0 = torch.stack(actions0)
            REWARD1 = torch.tensor(rewards1)
            STATE1 = torch.stack(states1)
            TERMINAL1 = torch.tensor(terminals1)
            QUALITY0 = model(STATE0)
            # The regression target is treated as a constant, so it is computed
            # without recording a graph: r + gamma * max_a Q(s', a), with the
            # bootstrap term zeroed for terminal transitions.
            with torch.no_grad():
                QUALITY1 = model(STATE1)
                TARGET0 = REWARD1 + gamma * TERMINAL1.logical_not() * QUALITY1.max(1)[0]
            # Pick Q(s, a) for the action actually taken in each transition.
            loss = torch.nn.functional.mse_loss(QUALITY0[torch.arange(batch_size), ACTION0], TARGET0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        state0 = state1
    print("%4d %4d %12.3f" % (epoch, step1, gain1), flush=True)
    # Anneal exploration by 1/epochs per episode while epsilon is above 0.1.
    if epsilon > 0.1:
        epsilon -= 1. / epochs

# Persist the trained Q-network parameters (state_dict only, not the module
# object) to "replay10.pt", a path relative to the current working directory.
torch.save(model.state_dict(), "replay10.pt")

# Evaluate the trained network with a purely greedy policy (no exploration)
# and print the mean and standard deviation of the episode length.
epochs = 1000                # number of evaluation episodes
stats = torch.empty(epochs)  # episode length of every evaluation episode
# Inference only: no gradient is ever taken here, so skip autograd bookkeeping.
with torch.no_grad():
    for epoch in range(epochs):
        state0 = torch.tensor(env.reset(), dtype=torch.float)
        step1 = 0
        done1 = False
        while not done1:
            action0 = model(state0).argmax()
            # Reward and info are not needed for evaluation.
            state1, _, done1, _ = env.step(action0.item())
            state0 = torch.tensor(state1, dtype=torch.float)
            step1 += 1
        stats[epoch] = step1
print(stats.mean().item(), stats.std().item())

# Training and evaluation are finished; close the environment.
env.close()

# a: 426 +- 19  (recorded result; presumably the mean +- std of episode length printed by the evaluation above)
