import torch

def reset():
    """Build the fixed starting gridworld.

    The state is a 4x4x4 tensor of 0/1 masks indexed [layer, row, column]:
    layer 0 marks the player, 1 a wall, 2 the goal and 3 a pit.

    Returns:
        (row, column, state): the player's starting coordinates, which are
        (0, 0), and the freshly allocated state tensor.
    """
    start_row = start_column = 0
    grid = torch.zeros(4, 4, 4)
    placements = (
        (3, 2, 0),                      # pit
        (2, 3, 0),                      # goal
        (1, 1, 1),                      # wall
        (0, start_row, start_column),   # player
    )
    for layer, row, column in placements:
        grid[layer, row, column] = 1.
    return start_row, start_column, grid

def step(row0, column0, state0, action0):
    """Apply one action to the gridworld and return the transition.

    Actions: 0 = down (row + 1), 1 = up (row - 1), 2 = right (column + 1),
    3 = left (column - 1). Any other action value leaves the player in place.
    A move is also ignored when it would leave the 4x4 grid or enter a cell
    whose wall layer (layer 1) is non-zero.

    ``state0`` is not modified; the returned state is a clone with the
    player mask (layer 0) moved.

    Returns:
        (row, column, state, reward, done), where reward/done are
        -10./True if the player lands on a pit (layer 3),
        +10./True if it lands on the goal (layer 2), and -1./False otherwise.
    """
    grid = state0.clone()
    row, col = row0, column0
    grid[0, row, col] = 0.      # lift the player off its current cell
    walls = grid[1]             # view of the wall layer; only layer 0 is written below

    if action0 == 0:
        if row + 1 < 4 and walls[row + 1, col] == 0.:
            row += 1
    elif action0 == 1:
        if 0 < row and walls[row - 1, col] == 0.:
            row -= 1
    elif action0 == 2:
        if col + 1 < 4 and walls[row, col + 1] == 0.:
            col += 1
    elif action0 == 3:
        if 0 < col and walls[row, col - 1] == 0.:
            col -= 1

    grid[0, row, col] = 1.
    # The pit is checked before the goal, matching the terminal priority.
    if grid[3, row, col] == 1.:
        return row, col, grid, -10., True
    if grid[2, row, col] == 1.:
        return row, col, grid, +10., True
    return row, col, grid, -1., False

# Q-network: flattens the trailing (layer, row, column) = 4x4x4 dims of a
# grid state into 64 features and outputs one value per action (4 actions).
model = torch.nn.Sequential(
    torch.nn.Flatten(start_dim=-3, end_dim=-1),
    torch.nn.Linear(in_features=64, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=4),
)

# Online one-step Q-learning with an epsilon-greedy behaviour policy.
epochs = 1000
epsilon = 1.
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
    # One episode per epoch, always starting from the same initial grid.
    row0, column0, state0 = reset()
    gain1 = 0.  # sum of rewards collected during this episode
    done1 = False
    while not done1:
        quality0 = model(state0)  # Q(s, .) for all 4 actions; kept in the graph for backprop
        # Explore with probability epsilon, otherwise act greedily.
        action0 = torch.randint(4, ()) if torch.rand(()) < epsilon else quality0.argmax()
        row1, column1, state1, reward1, done1 = step(row0, column0, state0, action0)
        gain1 += reward1
        # TD target r + max_a' Q(s', a') (no discount factor); the bootstrap
        # term is zeroed on terminal transitions. The target must not receive
        # gradients, so evaluate it without building an autograd graph.
        with torch.no_grad():
            quality1 = model(state1)
        target0 = reward1 + (not done1) * quality1.max()
        loss = torch.nn.functional.mse_loss(quality0[action0], target0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        row0, column0, state0 = row1, column1, state1
    print("%4d %12.3f" % (epoch, gain1), flush=True)
    # Linear decay by 1/epochs per episode, clamped so float drift cannot
    # push epsilon below the 0.1 exploration floor.
    epsilon = max(0.1, epsilon - 1. / epochs)
