import torch

torch.set_printoptions(precision = 2, linewidth = 120, sci_mode = False)

# Problem sizes: number of states, actions, and distinct reward indices.
states = 10
actions = 5
rewards = 2

# Transition table laid out as system[next_state, reward_index, state, action];
# every (state, action) column holds a probability distribution over
# (next_state, reward_index) pairs.
system = torch.zeros(states, rewards, states, actions)

# Deterministic dynamics: taking `move` from `origin` lands on
# (origin + move) % states and emits reward index 1. The index tensors
# broadcast to a (states, actions) grid, so one assignment fills every pair.
origin = torch.arange(states)[:, None]
move = torch.arange(actions)[None, :]
stepping = system[:, 1]  # view onto the reward-index-1 slice
stepping[(origin + move) % states, origin, move] = 1.

# State 0 is absorbing: whatever the action, it stays in state 0 and emits
# reward index 0.
system[:, :, 0, :] = 0.
system[0, 0, 0, :] = 1.

# Q-network: takes a vector of length `states` (the training loop feeds it
# one-hot state encodings) and returns one value estimate per action.
hidden = 64
model = torch.nn.Sequential(
    torch.nn.Linear(in_features = states, out_features = hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features = hidden, out_features = actions),
)

# One-step Q-learning on uniformly sampled transitions: each iteration draws
# a random (state, action) pair, samples the outcome from `system`, and
# regresses Q(state, action) toward reward + max_a' Q(next_state, a').
# No discount factor is applied.
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
for epoch in range(10000):
    state0 = torch.randint(states, ())
    data0 = torch.eq(state0, torch.arange(states)).float()  # one-hot encoding
    quality0 = model(data0)
    action0 = torch.randint(actions, ())

    # Joint draw over (next state, reward index). The flattened index equals
    # next_state * rewards + reward_index, hence the divmod-style decoding.
    index1 = torch.multinomial(system[:, :, state0, action0].flatten(), 1).squeeze()
    # The reward is the negated reward index: 0 or -1 with rewards == 2.
    state1, reward1 = index1 // rewards, -(index1 % rewards)

    # The bootstrap target is a constant for this update, so build it without
    # recording an autograd graph instead of recording one and detaching it.
    with torch.no_grad():
        if state1 == 0:
            # State 0 is treated as terminal: no future value is added.
            quality1 = torch.zeros(actions)
        else:
            data1 = torch.eq(state1, torch.arange(states)).float()
            quality1 = model(data1)
        target0 = reward1 + quality1.max()

    loss = torch.nn.functional.mse_loss(quality0[action0], target0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # .item() converts the scalar loss to a Python float explicitly rather
    # than relying on implicit conversion of a tensor that requires grad.
    print("%4d %12.3f" % (epoch, loss.item()), flush = True)

# Read out the learned values: each row of the identity matrix is the one-hot
# encoding of one state, so row i of `quality` holds that state's per-action
# estimates and `action` holds the highest-valued action for each state.
data = torch.eye(states)
quality = model(data)
action = quality.argmax(dim = 1)
print(quality)
print(action)
