Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TdLambdaReturns for alpha_zero_torch #940

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

mattrek
Copy link
Contributor

@mattrek mattrek commented Oct 6, 2022

I've been using this locally while hacking with backgammon. Lemme know if it's worth merging into the repo. Logs showing the state values of a trajectory and the training values returned from TdLambdaReturns for different settings are pasted below.

Sorry this is such a large change - I can break it up if needed for review:

  • Changed the value stored in a Trajectory::State to match the MCTS value of the chosen action. (Note TD(0) would train to this value) Far as I could tell, Trajectory::State::value was only being used for outcome prediction/accuracy, and this usage remains valid after the change.
  • Set missing config settings by merging with defaults, so that saved config files continue to work when new settings are added.
  • Added a min_simulations setting to allow the MCTS value of a forced action more sims. This required updates to all 3 implementations of alpha_zero, as well as pybind + julia.
  • Added config settings td_lambda and td_n_steps
  • Added a (default false) verbose setting to learner() for logging how TrainInputs are populated from Trajectory.
  • Considered adding TdLambdaReturns in a new file under algorithms, but it would have required some refactoring to make Trajectory visible outside of the AlphaZero namespace, so... left for possibly later.

Examples from tic_tac_toe for various settings:

td_lambda: 0.0
td_n_steps: 0
  StateIdx: 0  Value: 0.408  TrainTo: 0.408
  StateIdx: 1  Value: 0.312  TrainTo: 0.312
  StateIdx: 2  Value: 0.152  TrainTo: 0.152
  StateIdx: 3  Value: -0.430  TrainTo: -0.430
  StateIdx: 4  Value: -0.682  TrainTo: -0.682
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 1.0
td_n_steps: 0
  StateIdx: 0  Value: 0.408  TrainTo: -1.000
  StateIdx: 1  Value: 0.312  TrainTo: -1.000
  StateIdx: 2  Value: 0.152  TrainTo: -1.000
  StateIdx: 3  Value: -0.430  TrainTo: -1.000
  StateIdx: 4  Value: -0.682  TrainTo: -1.000
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 1.0
td_n_steps: 1
  StateIdx: 0  Value: 0.408  TrainTo: 0.312
  StateIdx: 1  Value: 0.312  TrainTo: 0.152
  StateIdx: 2  Value: 0.152  TrainTo: -0.430
  StateIdx: 3  Value: -0.430  TrainTo: -0.682
  StateIdx: 4  Value: -0.682  TrainTo: -1.000
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 1.0
td_n_steps: 3
  StateIdx: 0  Value: 0.408  TrainTo: -0.430
  StateIdx: 1  Value: 0.312  TrainTo: -0.682
  StateIdx: 2  Value: 0.152  TrainTo: -1.000
  StateIdx: 3  Value: -0.430  TrainTo: -1.000
  StateIdx: 4  Value: -0.682  TrainTo: -1.000
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.1
td_n_steps: 0
  StateIdx: 0  Value: 0.408  TrainTo: 0.396
  StateIdx: 1  Value: 0.312  TrainTo: 0.290
  StateIdx: 2  Value: 0.152  TrainTo: 0.091
  StateIdx: 3  Value: -0.430  TrainTo: -0.459
  StateIdx: 4  Value: -0.682  TrainTo: -0.714
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.1
td_n_steps: 1
  StateIdx: 0  Value: 0.408  TrainTo: 0.398
  StateIdx: 1  Value: 0.312  TrainTo: 0.296
  StateIdx: 2  Value: 0.152  TrainTo: 0.094
  StateIdx: 3  Value: -0.430  TrainTo: -0.456
  StateIdx: 4  Value: -0.682  TrainTo: -0.714
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.1
td_n_steps: 3
  StateIdx: 0  Value: 0.408  TrainTo: 0.396
  StateIdx: 1  Value: 0.312  TrainTo: 0.290
  StateIdx: 2  Value: 0.152  TrainTo: 0.091
  StateIdx: 3  Value: -0.430  TrainTo: -0.459
  StateIdx: 4  Value: -0.682  TrainTo: -0.714
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.5
td_n_steps: 0
  StateIdx: 0  Value: 0.408  TrainTo: 0.222
  StateIdx: 1  Value: 0.312  TrainTo: 0.035
  StateIdx: 2  Value: 0.152  TrainTo: -0.242
  StateIdx: 3  Value: -0.430  TrainTo: -0.636
  StateIdx: 4  Value: -0.682  TrainTo: -0.841
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.5
td_n_steps: 1
  StateIdx: 0  Value: 0.408  TrainTo: 0.360
  StateIdx: 1  Value: 0.312  TrainTo: 0.232
  StateIdx: 2  Value: 0.152  TrainTo: -0.139
  StateIdx: 3  Value: -0.430  TrainTo: -0.556
  StateIdx: 4  Value: -0.682  TrainTo: -0.841
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.5
td_n_steps: 3
  StateIdx: 0  Value: 0.408  TrainTo: 0.247
  StateIdx: 1  Value: 0.312  TrainTo: 0.055
  StateIdx: 2  Value: 0.152  TrainTo: -0.242
  StateIdx: 3  Value: -0.430  TrainTo: -0.636
  StateIdx: 4  Value: -0.682  TrainTo: -0.841
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.9
td_n_steps: 0
  StateIdx: 0  Value: 0.408  TrainTo: -0.585
  StateIdx: 1  Value: 0.312  TrainTo: -0.696
  StateIdx: 2  Value: 0.152  TrainTo: -0.808
  StateIdx: 3  Value: -0.430  TrainTo: -0.914
  StateIdx: 4  Value: -0.682  TrainTo: -0.968
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.9
td_n_steps: 1
  StateIdx: 0  Value: 0.408  TrainTo: 0.321
  StateIdx: 1  Value: 0.312  TrainTo: 0.168
  StateIdx: 2  Value: 0.152  TrainTo: -0.372
  StateIdx: 3  Value: -0.430  TrainTo: -0.657
  StateIdx: 4  Value: -0.682  TrainTo: -0.968
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
td_lambda: 0.9
td_n_steps: 3
  StateIdx: 0  Value: 0.408  TrainTo: -0.233
  StateIdx: 1  Value: 0.312  TrainTo: -0.487
  StateIdx: 2  Value: 0.152  TrainTo: -0.808
  StateIdx: 3  Value: -0.430  TrainTo: -0.914
  StateIdx: 4  Value: -0.682  TrainTo: -0.968
  StateIdx: 5  Value: -1.000  TrainTo: -1.000
'''

@lanctot
Copy link
Collaborator

lanctot commented Oct 6, 2022

Very cool!!! @christianjans are you able to take a look? Appreciate the continued help, this code is seeing some good use 👍

@christianjans
Copy link
Contributor

Thanks for another addition, @mattrek!

And yes, it's my pleasure. I will likely have time this weekend or early next week to take a look at it. Happy to see the code is being used!

@mattrek
Copy link
Contributor Author

mattrek commented Oct 6, 2022

Ok, sounds good, and thanks for the kind words. @lanctot: I emailed you some questions about OpenSpiel back in May, and you mentioned in reply something like this might help with backgammon... I finally got some time to work on it, and am happy to share.

@lanctot
Copy link
Collaborator

lanctot commented Oct 7, 2022

Yeah very cool! This is a great addition. The implementation is non-trivial :) So might take a bit of time to review, but it's great to support this case... and indeed I hope it helps for Backgammon!

@christianjans
Copy link
Contributor

Just an update on my end: Took a quick look-through but will look at it more thoroughly on the weekend and publish the review.

Copy link
Contributor

@christianjans christianjans left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks it looks really great. I just have a question to help my understanding.

I unfortunately do not have time right now to help verify that the implementation is correct, but I was wondering if you had any unit tests for this feature that you have already developed or can think of that can also be added to this PR?

As a related side note, when I get more time, I would like to add more unit tests for AlphaZero Torch in general.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants