Implement TdLambdaReturns for alpha_zero_torch #940

Open: wants to merge 2 commits into base: master
Changes from 1 commit
2 changes: 2 additions & 0 deletions open_spiel/algorithms/alpha_zero/alpha_zero.cc
@@ -124,6 +124,7 @@ std::unique_ptr<MCTSBot> InitAZBot(
game,
std::move(evaluator),
config.uct_c,
config.min_simulations,
config.max_simulations,
/*max_memory_mb=*/ 10,
/*solve=*/ false,
@@ -231,6 +232,7 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
game,
rand_evaluator,
config.uct_c,
/*min_simulations=*/0,
rand_max_simulations,
/*max_memory_mb=*/1000,
/*solve=*/true,
2 changes: 2 additions & 0 deletions open_spiel/algorithms/alpha_zero/alpha_zero.h
@@ -41,6 +41,7 @@ struct AlphaZeroConfig {
int evaluation_window;

double uct_c;
int min_simulations;
int max_simulations;
double policy_alpha;
double policy_epsilon;
@@ -74,6 +75,7 @@ struct AlphaZeroConfig {
{"checkpoint_freq", checkpoint_freq},
{"evaluation_window", evaluation_window},
{"uct_c", uct_c},
{"min_simulations", min_simulations},
{"max_simulations", max_simulations},
{"policy_alpha", policy_alpha},
{"policy_epsilon", policy_epsilon},
98 changes: 82 additions & 16 deletions open_spiel/algorithms/alpha_zero_torch/alpha_zero.cc
@@ -125,30 +125,43 @@ Trajectory PlayGame(Logger* logger, int game_num, const open_spiel::Game& game,
std::pow(c.explore_count, 1.0 / temperature));
}
NormalizePolicy(&policy);
open_spiel::Action action;
const SearchNode* action_node;
if (history.size() >= temperature_drop) {
action = root->BestChild().action;
action_node = &root->BestChild();
} else {
open_spiel::Action action;
action = open_spiel::SampleAction(policy, *rng).first;
for (const SearchNode& child : root->children) {
if (child.action == action) {
action_node = &child;
break;
}
}
}

double root_value = root->total_reward / root->explore_count;
double action_value =
action_node->outcome.empty()
? (action_node->total_reward / action_node->explore_count)
* (action_node->player == player ? 1 : -1)
: action_node->outcome[player];
trajectory.states.push_back(Trajectory::State{
state->ObservationTensor(), player, state->LegalActions(), action,
std::move(policy), root_value});
std::string action_str = state->ActionToString(player, action);
state->ObservationTensor(), player, state->LegalActions(),
action_node->action, std::move(policy), action_value});
std::string action_str =
state->ActionToString(player, action_node->action);
history.push_back(action_str);
state->ApplyAction(action);
state->ApplyAction(action_node->action);
if (verbose) {
logger->Print("Player: %d, action: %s", player, action_str);
logger->Print("Player: %d, action: %s, value: %6.3f",
player, action_str, action_value);
}
if (state->IsTerminal()) {
trajectory.returns = state->Returns();
break;
} else if (std::abs(root_value) > cutoff_value) {
} else if (std::abs(action_value) > cutoff_value) {
trajectory.returns.resize(2);
trajectory.returns[player] = root_value;
trajectory.returns[1 - player] = -root_value;
trajectory.returns[player] = action_value;
trajectory.returns[1 - player] = -action_value;
break;
}
}
@@ -165,7 +178,8 @@ std::unique_ptr<MCTSBot> InitAZBot(const AlphaZeroConfig& config,
std::shared_ptr<Evaluator> evaluator,
bool evaluation) {
return std::make_unique<MCTSBot>(
game, std::move(evaluator), config.uct_c, config.max_simulations,
game, std::move(evaluator), config.uct_c,
config.min_simulations, config.max_simulations,
/*max_memory_mb=*/10,
/*solve=*/false,
/*seed=*/0,
@@ -269,7 +283,8 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
bots.reserve(2);
bots.push_back(InitAZBot(config, game, vp_eval, true));
bots.push_back(std::make_unique<MCTSBot>(
game, rand_evaluator, config.uct_c, rand_max_simulations,
game, rand_evaluator, config.uct_c,
/*min_simulations=*/0, rand_max_simulations,
/*max_memory_mb=*/1000,
/*solve=*/true,
/*seed=*/num * 1000 + game_num,
@@ -295,12 +310,54 @@ void evaluator(const open_spiel::Game& game, const AlphaZeroConfig& config,
logger.Print("Got a quit.");
}

// Returns the 'lambda' discounted value of all future values of 'trajectory',
// including its outcome, beginning at 'state_idx'. The calculation is
// truncated after 'td_n_steps' if that parameter is greater than zero.
double TdLambdaReturns(const Trajectory& trajectory, int state_idx,
double td_lambda, int td_n_steps) {
double outcome = trajectory.returns[0];
if (td_lambda >= 1.0 || Near(td_lambda, 1.0)) {
// lambda == 1.0 simplifies to returning the outcome (or value at nth-step)
if (td_n_steps <= 0) {
return outcome;
}
int idx = state_idx + td_n_steps;
if (idx >= trajectory.states.size()) {
return outcome;
}
const Trajectory::State& n_state = trajectory.states[idx];
return n_state.value * (n_state.current_player == 0 ? 1 : -1);
}
const Trajectory::State& s_state = trajectory.states[state_idx];
double retval = s_state.value * (s_state.current_player == 0 ? 1 : -1);
if (td_lambda <= 0.0 || Near(td_lambda, 0.0)) {
// lambda == 0 simplifies to returning the start state's value
return retval;
}
double lambda_inv = (1.0 - td_lambda);
double lambda_pow = td_lambda;
retval *= lambda_inv;
for (int i = state_idx + 1; i < trajectory.states.size(); ++i) {
const Trajectory::State& i_state = trajectory.states[i];
double value = i_state.value * (i_state.current_player == 0 ? 1 : -1);
if (td_n_steps > 0 && i == state_idx + td_n_steps) {
retval += lambda_pow * value;
return retval;
}
retval += lambda_inv * lambda_pow * value;
lambda_pow *= td_lambda;
}
retval += lambda_pow * outcome;
return retval;
}

void learner(const open_spiel::Game& game, const AlphaZeroConfig& config,
DeviceManager* device_manager,
std::shared_ptr<VPNetEvaluator> eval,
ThreadedQueue<Trajectory>* trajectory_queue,
EvalResults* eval_results, StopToken* stop,
const StartInfo& start_info) {
const StartInfo& start_info,
bool verbose = false) {
FileLogger logger(config.path, "learner", "a");
DataLoggerJsonLines data_logger(
config.path, "learner", true, "a", start_info.start_time);
@@ -357,10 +414,19 @@ void learner(const open_spiel::Game& game, const AlphaZeroConfig& config,
double p1_outcome = trajectory->returns[0];
outcomes.Add(p1_outcome > 0 ? 0 : (p1_outcome < 0 ? 1 : 2));

for (const Trajectory::State& state : trajectory->states) {
for (int i = 0; i < trajectory->states.size(); ++i ) {
const Trajectory::State& state = trajectory->states[i];
double value = TdLambdaReturns(*trajectory, i,
config.td_lambda, config.td_n_steps);
replay_buffer.Add(VPNetModel::TrainInputs{state.legal_actions,
state.observation,
state.policy, p1_outcome});
state.policy,
value});
if (verbose && num_trajectories == 1) {
double v = state.value * (state.current_player == 0 ? 1 : -1);
logger.Print("StateIdx: %d Value: %0.3f TrainTo: %0.3f",
i, v, value);
}
num_states += 1;
}

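For reference, a minimal standalone sketch of the weighting that the new TdLambdaReturns applies in the untruncated case (td_n_steps <= 0), assuming the per-state search values have already been flipped to player 0's perspective. The names LambdaReturnSketch and values are illustrative only, not part of the PR:

#include <cstdio>
#include <vector>

// Illustrative sketch, not OpenSpiel code.
// G_s = (1 - lambda) * sum_{k >= 0} lambda^k * v_{s+k}  +  lambda^K * z,
// where v_t are the stored search values (player 0's perspective), z is the
// game outcome for player 0, and K is the number of states from s to the end.
double LambdaReturnSketch(const std::vector<double>& values, double outcome,
                          int start, double lambda) {
  double ret = 0.0;
  double lambda_pow = 1.0;
  for (int i = start; i < static_cast<int>(values.size()); ++i) {
    ret += (1.0 - lambda) * lambda_pow * values[i];
    lambda_pow *= lambda;
  }
  return ret + lambda_pow * outcome;
}

int main() {
  std::vector<double> values = {0.2, 0.5, 0.9};  // hypothetical search values
  double outcome = 1.0;                          // player 0 won the game
  // lambda == 0.0 -> 0.2     (bootstrap entirely on the start state's value)
  // lambda == 1.0 -> 1.0     (plain Monte Carlo outcome, the previous behaviour)
  // lambda == 0.5 -> 0.4625  (= 0.5*0.2 + 0.25*0.5 + 0.125*0.9 + 0.125*1.0)
  std::printf("%.4f\n",
              LambdaReturnSketch(values, outcome, /*start=*/0, /*lambda=*/0.5));
  return 0;
}
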
75 changes: 44 additions & 31 deletions open_spiel/algorithms/alpha_zero_torch/alpha_zero.h
@@ -49,13 +49,16 @@ struct AlphaZeroConfig {
int evaluation_window;

double uct_c;
int min_simulations;
int max_simulations;
double policy_alpha;
double policy_epsilon;
double temperature;
double temperature_drop;
double cutoff_probability;
double cutoff_value;
double td_lambda;
int td_n_steps;

int actors;
int evaluators;
@@ -83,51 +86,61 @@ struct AlphaZeroConfig {
{"checkpoint_freq", checkpoint_freq},
{"evaluation_window", evaluation_window},
{"uct_c", uct_c},
{"min_simulations", min_simulations},
{"max_simulations", max_simulations},
{"policy_alpha", policy_alpha},
{"policy_epsilon", policy_epsilon},
{"temperature", temperature},
{"temperature_drop", temperature_drop},
{"cutoff_probability", cutoff_probability},
{"cutoff_value", cutoff_value},
{"td_lambda", td_lambda},
{"td_n_steps", td_n_steps},
{"actors", actors},
{"evaluators", evaluators},
{"eval_levels", eval_levels},
{"max_steps", max_steps},
});
}

void FromJson(const json::Object& config_json) {
game = config_json.at("game").GetString();
path = config_json.at("path").GetString();
graph_def = config_json.at("graph_def").GetString();
nn_model = config_json.at("nn_model").GetString();
nn_width = config_json.at("nn_width").GetInt();
nn_depth = config_json.at("nn_depth").GetInt();
devices = config_json.at("devices").GetString();
explicit_learning = config_json.at("explicit_learning").GetBool();
learning_rate = config_json.at("learning_rate").GetDouble();
weight_decay = config_json.at("weight_decay").GetDouble();
train_batch_size = config_json.at("train_batch_size").GetInt();
inference_batch_size = config_json.at("inference_batch_size").GetInt();
inference_threads = config_json.at("inference_threads").GetInt();
inference_cache = config_json.at("inference_cache").GetInt();
replay_buffer_size = config_json.at("replay_buffer_size").GetInt();
replay_buffer_reuse = config_json.at("replay_buffer_reuse").GetInt();
checkpoint_freq = config_json.at("checkpoint_freq").GetInt();
evaluation_window = config_json.at("evaluation_window").GetInt();
uct_c = config_json.at("uct_c").GetDouble();
max_simulations = config_json.at("max_simulations").GetInt();
policy_alpha = config_json.at("policy_alpha").GetDouble();
policy_epsilon = config_json.at("policy_epsilon").GetDouble();
temperature = config_json.at("temperature").GetDouble();
temperature_drop = config_json.at("temperature_drop").GetDouble();
cutoff_probability = config_json.at("cutoff_probability").GetDouble();
cutoff_value = config_json.at("cutoff_value").GetDouble();
actors = config_json.at("actors").GetInt();
evaluators = config_json.at("evaluators").GetInt();
eval_levels = config_json.at("eval_levels").GetInt();
max_steps = config_json.at("max_steps").GetInt();
void FromJsonWithDefaults(const json::Object& config_json,
const json::Object& defaults_json) {
json::Object merged;
merged.insert(config_json.begin(), config_json.end());
merged.insert(defaults_json.begin(), defaults_json.end());
game = merged.at("game").GetString();
path = merged.at("path").GetString();
graph_def = merged.at("graph_def").GetString();
nn_model = merged.at("nn_model").GetString();
nn_width = merged.at("nn_width").GetInt();
nn_depth = merged.at("nn_depth").GetInt();
devices = merged.at("devices").GetString();
explicit_learning = merged.at("explicit_learning").GetBool();
learning_rate = merged.at("learning_rate").GetDouble();
weight_decay = merged.at("weight_decay").GetDouble();
train_batch_size = merged.at("train_batch_size").GetInt();
inference_batch_size = merged.at("inference_batch_size").GetInt();
inference_threads = merged.at("inference_threads").GetInt();
inference_cache = merged.at("inference_cache").GetInt();
replay_buffer_size = merged.at("replay_buffer_size").GetInt();
replay_buffer_reuse = merged.at("replay_buffer_reuse").GetInt();
checkpoint_freq = merged.at("checkpoint_freq").GetInt();
evaluation_window = merged.at("evaluation_window").GetInt();
uct_c = merged.at("uct_c").GetDouble();
min_simulations = merged.at("min_simulations").GetInt();
max_simulations = merged.at("max_simulations").GetInt();
policy_alpha = merged.at("policy_alpha").GetDouble();
policy_epsilon = merged.at("policy_epsilon").GetDouble();
temperature = merged.at("temperature").GetDouble();
temperature_drop = merged.at("temperature_drop").GetDouble();
cutoff_probability = merged.at("cutoff_probability").GetDouble();
cutoff_value = merged.at("cutoff_value").GetDouble();
td_lambda = merged.at("td_lambda").GetDouble();
td_n_steps = merged.at("td_n_steps").GetInt();
actors = merged.at("actors").GetInt();
evaluators = merged.at("evaluators").GetInt();
eval_levels = merged.at("eval_levels").GetInt();
max_steps = merged.at("max_steps").GetInt();
}
};

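FromJsonWithDefaults relies on insert() not overwriting keys that already exist in the destination: because config_json is inserted into merged first, explicit config values win and defaults_json only fills in missing fields such as td_lambda and td_n_steps. A minimal sketch of that precedence, using std::map as a stand-in for json::Object (an assumption made only for illustration):

#include <cassert>
#include <map>
#include <string>

int main() {
  // std::map stands in for json::Object here (assumption for illustration).
  std::map<std::string, double> config = {{"td_lambda", 0.7}};
  std::map<std::string, double> defaults = {{"td_lambda", 1.0},
                                            {"td_n_steps", 0.0}};
  std::map<std::string, double> merged;
  merged.insert(config.begin(), config.end());
  // Range insert() skips keys that are already present, so config wins.
  merged.insert(defaults.begin(), defaults.end());
  assert(merged.at("td_lambda") == 0.7);   // explicit value from config_json
  assert(merged.at("td_n_steps") == 0.0);  // filled in from defaults_json
  return 0;
}
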
7 changes: 4 additions & 3 deletions open_spiel/algorithms/mcts.cc
@@ -202,12 +202,13 @@ std::vector<double> dirichlet_noise(int count, double alpha,
}

MCTSBot::MCTSBot(const Game& game, std::shared_ptr<Evaluator> evaluator,
double uct_c, int max_simulations, int64_t max_memory_mb,
bool solve, int seed, bool verbose,
double uct_c, int min_simulations, int max_simulations,
int64_t max_memory_mb, bool solve, int seed, bool verbose,
ChildSelectionPolicy child_selection_policy,
double dirichlet_alpha, double dirichlet_epsilon,
bool dont_return_chance_node)
: uct_c_{uct_c},
min_simulations_{min_simulations},
max_simulations_{max_simulations},
max_nodes_((max_memory_mb << 20) / sizeof(SearchNode) + 1),
nodes_(0),
@@ -428,7 +429,7 @@ std::unique_ptr<SearchNode> MCTSBot::MCTSearch(const State& state) {
}

if (!root->outcome.empty() || // Full game tree is solved.
root->children.size() == 1) {
(root->children.size() == 1 && i >= min_simulations_)) {
break;
}
if (max_nodes_ > 1 && nodes_ >= max_nodes_) {
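
The effect of the new condition in MCTSearch is that a root with a single legal action no longer stops after the first simulation: the search keeps running until at least min_simulations iterations have refined its value, presumably so that forced-move value estimates are still meaningful when they become TD(lambda) training targets. A tiny sketch of the revised stopping rule, with hypothetical names not taken from the PR:

// Hypothetical helper mirroring the break condition in MCTSBot::MCTSearch.
bool ShouldStopSearch(bool root_solved, int num_root_children,
                      int simulations_done, int min_simulations) {
  // A solved root always stops; a forced move (single child) only stops once
  // the minimum number of simulations has been spent on its value estimate.
  return root_solved ||
         (num_root_children == 1 && simulations_done >= min_simulations);
}
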
3 changes: 2 additions & 1 deletion open_spiel/algorithms/mcts.h
@@ -160,7 +160,7 @@ class MCTSBot : public Bot {
// failing. We don't know why right now, but intend to fix this.
MCTSBot(
const Game& game, std::shared_ptr<Evaluator> evaluator, double uct_c,
int max_simulations,
int min_simulations, int max_simulations,
int64_t max_memory_mb, // Max memory use in megabytes.
bool solve, // Whether to back up solved states.
int seed, bool verbose,
@@ -203,6 +203,7 @@ class MCTSBot : public Bot {
void GarbageCollect(SearchNode* node);

double uct_c_;
int min_simulations_;
int max_simulations_;
int max_nodes_; // Max nodes allowed in the tree
int nodes_; // Nodes used in the tree.
2 changes: 1 addition & 1 deletion open_spiel/algorithms/mcts_test.cc
@@ -35,7 +35,7 @@ std::unique_ptr<open_spiel::Bot> InitBot(const open_spiel::Game& game,
int max_simulations,
std::shared_ptr<Evaluator> evaluator) {
return std::make_unique<open_spiel::algorithms::MCTSBot>(
game, std::move(evaluator), UCT_C, max_simulations,
game, std::move(evaluator), UCT_C, /*min_simulations=*/0, max_simulations,
/*max_memory_mb=*/5, /*solve=*/true, /*seed=*/42, /*verbose=*/false);
}

4 changes: 3 additions & 1 deletion open_spiel/examples/alpha_zero_example.cc
@@ -46,7 +46,8 @@ ABSL_FLAG(int, replay_buffer_size, 1 << 16,
ABSL_FLAG(double, replay_buffer_reuse, 3,
"How many times to reuse each state in the replay buffer.");
ABSL_FLAG(int, checkpoint_freq, 100, "Save a checkpoint every N steps.");
ABSL_FLAG(int, max_simulations, 300, "How many simulations to run.");
ABSL_FLAG(int, min_simulations, 0, "How many simulations to run (min).");
ABSL_FLAG(int, max_simulations, 300, "How many simulations to run (max).");
ABSL_FLAG(int, train_batch_size, 1 << 10,
"How many states to learn from per batch.");
ABSL_FLAG(int, inference_batch_size, 1,
@@ -102,6 +103,7 @@ int main(int argc, char** argv) {
config.checkpoint_freq = absl::GetFlag(FLAGS_checkpoint_freq);
config.evaluation_window = 100;
config.uct_c = absl::GetFlag(FLAGS_uct_c);
config.min_simulations = absl::GetFlag(FLAGS_min_simulations);
config.max_simulations = absl::GetFlag(FLAGS_max_simulations);
config.train_batch_size = absl::GetFlag(FLAGS_train_batch_size);
config.inference_batch_size = absl::GetFlag(FLAGS_inference_batch_size);