jaxchem.loss

binary_cross_entropy_with_logits(inputs, targets, average=True)

Binary cross-entropy loss computed directly from logits.

This function is based on the PyTorch implementation (BCEWithLogitsLoss), which computes the loss in a numerically stable form.

See: https://discuss.pytorch.org/t/numerical-stability-of-bcewithlogitsloss/8246
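Let x be an element of inputs, y the matching element of targets, and σ the sigmoid. The element-wise loss can be rearranged into a stable form as follows. This is a sketch of the math behind the linked thread, not a transcription of the jaxchem source:

```latex
\begin{aligned}
\ell(x, y) &= -\left[\, y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \right] \\
           &= (1 - y)\, x + \log\left(1 + e^{-x}\right) \\
           &= \max(x, 0) - x y + \log\left(1 + e^{-\lvert x \rvert}\right)
\end{aligned}
```

The last line never exponentiates a large positive number, and it never takes the logarithm of a sigmoid that has rounded to 0 or 1. It therefore stays finite for large |x|, where the first line would overflow or produce log(0).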

Parameters
  • inputs (jnp.ndarray) – Model output (logits), i.e. values before the sigmoid function is applied.

  • targets (jnp.ndarray) – Labels, with the same shape as inputs.

  • average (bool) – If True (the default), return the mean of the element-wise losses; if False, return their sum.
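The sketch below shows how these parameters fit together. The function name bce_with_logits_sketch is hypothetical, and this is an illustration under the stated assumptions, not jaxchem's actual source. It assumes the stable element-wise form max(x, 0) - x*y + log(1 + exp(-|x|)), which is algebraically equivalent to the PyTorch formulation. It also assumes the result is reduced over all elements by mean or sum.

```python
import jax.numpy as jnp


def bce_with_logits_sketch(inputs: jnp.ndarray, targets: jnp.ndarray,
                           average: bool = True) -> jnp.ndarray:
    """Hypothetical illustration of a stable BCE-with-logits; not jaxchem's code."""
    # Stable element-wise loss: max(x, 0) - x * y + log(1 + exp(-|x|)).
    # exp(-|x|) is at most 1, so nothing overflows for large |x|.
    losses = (jnp.maximum(inputs, 0.0)
              - inputs * targets
              + jnp.log1p(jnp.exp(-jnp.abs(inputs))))
    # Reduce over all elements: mean when average=True, sum otherwise.
    return jnp.mean(losses) if average else jnp.sum(losses)


# Extreme logits that would make a naive sigmoid-then-log version return nan.
logits = jnp.array([-100.0, 0.0, 100.0])
labels = jnp.array([0.0, 1.0, 1.0])

mean_loss = bce_with_logits_sketch(logits, labels)                # about log(2) / 3
sum_loss = bce_with_logits_sketch(logits, labels, average=False)  # about log(2)
```

With these inputs, the two extreme logits are classified correctly and contribute almost nothing. The logit 0 contributes log(2), so the sum is about log(2) and the mean is a third of that.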

Returns

loss – The binary cross-entropy loss, reduced by mean or sum according to average.

Return type

jnp.ndarray