================================================================= FAILURES =================================================================
_________________________________________ test_backward_compatibility[aloha-act-extra_overrides2] __________________________________________
env_name = 'aloha', policy_name = 'act', extra_overrides = ['policy.n_action_steps=10']
@pytest.mark.parametrize(
    "env_name, policy_name, extra_overrides",
    [
        ("xarm", "tdmpc", []),
        (
            "pusht",
            "diffusion",
            ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
        ),
        ("aloha", "act", ["policy.n_action_steps=10"]),
    ],
)
# As artifacts have been generated on an x86_64 kernel, this test won't
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
def test_backward_compatibility(env_name, policy_name, extra_overrides):
    """
    NOTE: If this test does not pass, and you have intentionally changed something in the policy:
    1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
       include a report on what changed and how that affected the outputs.
    2. Go to the `if __name__ == "__main__"` block of `test/scripts/save_policy_to_safetensors.py` and
       add the policies you want to update the test artifacts for.
    3. Run `python test/scripts/save_policy_to_safetensors.py`. The test artifact should be updated.
    4. Check that this test now passes.
    5. Remember to restore `test/scripts/save_policy_to_safetensors.py` to its original state.
    6. Remember to stage and commit the resulting changes to `tests/data`.
    """
    env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
    saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
    saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
    saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
    saved_actions = load_file(env_policy_dir / "actions.safetensors")
    output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
    for key in saved_output_dict:
        assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
    for key in saved_grad_stats:
>       assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
E assert tensor(False)
E + where tensor(False) = <built-in method all of Tensor object at 0x7d36a6cda660>()
E + where <built-in method all of Tensor object at 0x7d36a6cda660> = tensor(False).all
E + where tensor(False) = <built-in method isclose of type object at 0x7d3855386760>(tensor(0.0026), tensor(0.0005), rtol=0.1, atol=1e-07)
E + where <built-in method isclose of type object at 0x7d3855386760> = torch.isclose
tests/test_policies.py:274: AssertionError
========================================================= short test summary info ==========================================================
FAILED tests/test_policies.py::test_backward_compatibility[aloha-act-extra_overrides2] - assert tensor(False)
================================================ 1 failed, 38 passed, 26 skipped in 14.17s =================================================
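For reference, `torch.isclose(a, b, rtol, atol)` treats two values as close when `|a - b| <= atol + rtol * |b|`. Plugging in the failing grad-stats values from the traceback above shows why the assertion trips; a minimal sketch of that check in plain Python (not the actual test code):

```python
def isclose(a: float, b: float, rtol: float = 0.1, atol: float = 1e-7) -> bool:
    # Same criterion torch.isclose uses: |a - b| <= atol + rtol * |b|
    return abs(a - b) <= atol + rtol * abs(b)

# Values from the failing assertion: grad stat 0.0026 vs saved 0.0005.
# |0.0026 - 0.0005| = 0.0021, but the allowed band is only
# 1e-7 + 0.1 * 0.0005 ≈ 5.01e-5, so the check fails.
print(isclose(0.0026, 0.0005))  # False
```

So this is not a borderline tolerance miss: the new gradient statistic differs from the saved artifact by roughly 40x the allowed deviation.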
Expected behavior
All tests to pass
Yes, unfortunately this test is very sensitive to the platform it runs on.
It's passing on the CI right now (here and here), so while it shouldn't be cause for concern, it's definitely not ideal.
We're thinking about how to improve it (ideally it should run and pass on any platform). If you have any ideas on how to do that, please don't hesitate to share them, here or in a PR ;)
System Info
`main` / 89c6be8