
The definition of bins in Predicted Aligned Error Head(PAE) may be wrong #929

Open
xiergo opened this issue Apr 17, 2024 · 0 comments

xiergo commented Apr 17, 2024

I am confused about the definition of the bins in the Predicted Aligned Error (PAE) head. The breaks are defined as [0, 0.5, 1, ..., 31]:

```python
# self.config.max_error_bin = 31, self.config.num_bins = 64
breaks = jnp.linspace(
    0., self.config.max_error_bin, self.config.num_bins - 1)
```
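A quick NumPy check of these edges confirms 63 edges spaced 0.5 apart. This is a sketch that assumes `np.linspace` matches `jnp.linspace` for these arguments; `MAX_ERROR_BIN` and `NUM_BINS` are hypothetical stand-ins for the two config fields.

```python
import numpy as np

# Hypothetical constants standing in for self.config.max_error_bin / num_bins.
MAX_ERROR_BIN = 31.0
NUM_BINS = 64

# Same call as above, with NumPy instead of jax.numpy: num_bins - 1 = 63 edges.
breaks = np.linspace(0.0, MAX_ERROR_BIN, NUM_BINS - 1)
step = breaks[1] - breaks[0]

print(len(breaks), step)  # 63 0.5
```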

and the bin centers are [0.25, 0.75, ..., 30.75, 31.25, 31.75], according to:

```python
def _calculate_bin_centers(breaks: np.ndarray):
  """Gets the bin centers from the bin edges.

  Args:
    breaks: [num_bins - 1] the error bin edges.

  Returns:
    bin_centers: [num_bins] the error bin centers.
  """
  step = (breaks[1] - breaks[0])

  # Add half-step to get the center
  bin_centers = breaks + step / 2
  # Add a catch-all bin at the end.
  bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
                               axis=0)
  return bin_centers
```
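Reproducing the quoted function in plain NumPy gives 64 centers running from 0.25 to 31.75, the last one being the catch-all bin. This is a sketch: `calculate_bin_centers` is a hypothetical local copy of the snippet above, not an import from AlphaFold.

```python
import numpy as np


def calculate_bin_centers(breaks: np.ndarray) -> np.ndarray:
  """Local copy of the `_calculate_bin_centers` logic quoted above."""
  step = breaks[1] - breaks[0]
  # Midpoint of each interval [breaks[i], breaks[i] + step].
  bin_centers = breaks + step / 2
  # Catch-all bin placed one step past the last center.
  return np.concatenate([bin_centers, [bin_centers[-1] + step]], axis=0)


breaks = np.linspace(0.0, 31.0, 63)  # max_error_bin = 31, num_bins - 1 = 63 edges
centers = calculate_bin_centers(breaks)

print(len(centers))   # 64
print(centers[:2])    # [0.25 0.75]
print(centers[-2:])   # [31.25 31.75]
```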

So the 64 bins are [0, 0.5], [0.5, 1], ..., [31, 31.5], [31.5, +inf].

However, the bins implied by the PAE loss are [-inf, 0], [0, 0.5], ..., [30.5, 31], [31, +inf], i.e. shifted left by one bin, based on the definition in alphafold/alphafold/model/modules.py, line 1200:

```python
sq_breaks = jnp.square(breaks)  # breaks = [0, 0.5, ..., 31], so sq_breaks = [0, 0.25, 1, ..., 961]
true_bins = jnp.sum((
    error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)

errors = softmax_cross_entropy(
    labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)
```
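The labeling rule can be checked in isolation with NumPy. This is a sketch: `loss_bin_index` is a hypothetical helper mirroring the `true_bins` expression. Because both distances and breaks are non-negative, comparing squares is equivalent to comparing the values themselves, so label k goes to errors in (breaks[k-1], breaks[k]], and label 0 only to an error of exactly 0.

```python
import numpy as np

breaks = np.linspace(0.0, 31.0, 63)  # [0, 0.5, ..., 31]
sq_breaks = np.square(breaks)


def loss_bin_index(error_dist: np.ndarray) -> np.ndarray:
  """Hypothetical NumPy mirror of the loss's `true_bins`: count of squared
  edges strictly below the squared error."""
  error_dist2 = np.square(error_dist)
  return np.sum(error_dist2[..., None] > sq_breaks, axis=-1).astype(np.int32)


errors = np.array([0.0, 0.25, 0.75, 31.0, 40.0])
print(loss_bin_index(errors))
```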

For example, an error of error_dist = 0.75 should fall into the second bin, [0.5, 1] (index 1, center 0.75). However, (0.75 > breaks).sum() is 2, so the one-hot label is [0, 0, 1, 0, ..., 0] with the third entry (index 2, center 1.25) set, which is incorrect.
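The mismatch can be reproduced in a few lines. The sketch below assumes, as this issue does, that bin i covers [breaks[i], breaks[i] + 0.5) with a catch-all bin [31.5, +inf); the hypothetical `center_bin_index` helper encodes that reading and is compared with the loss's labeling rule. For errors strictly between two edges, the loss label is one index higher, so its center is 0.5 (one bin width) larger; an error lying exactly on an edge gets the same index under both rules.

```python
import numpy as np

breaks = np.linspace(0.0, 31.0, 63)              # edges [0, 0.5, ..., 31]
step = breaks[1] - breaks[0]                     # 0.5
centers = np.concatenate([breaks + step / 2,     # [0.25, ..., 31.25]
                          [breaks[-1] + 1.5 * step]])  # catch-all center 31.75


def loss_bin_index(error: float) -> int:
  """Label used by the loss: number of edges strictly below the error."""
  return int(np.sum(error ** 2 > np.square(breaks)))


def center_bin_index(error: float) -> int:
  """Hypothetical: bin implied by the centers, assuming bin i is
  [breaks[i], breaks[i] + step) and the last bin is [31.5, +inf)."""
  return min(int(np.floor(error / step)), len(centers) - 1)


for e in [0.75, 3.1, 10.1, 20.3]:
  k_loss, k_center = loss_bin_index(e), center_bin_index(e)
  print(f"error={e:5.2f}  loss label {k_loss:2d} (center {centers[k_loss]:5.2f})"
        f"  center-based bin {k_center:2d} (center {centers[k_center]:5.2f})")
```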
