jaxchem.loss
-
binary_cross_entropy_with_logits(inputs, targets, average=True)
Binary cross entropy loss.
This function is based on the PyTorch implementation.
See: https://discuss.pytorch.org/t/numerical-stability-of-bcewithlogitsloss/8246
- Parameters
inputs (jnp.ndarray) – Model outputs (logits), i.e. the values before the sigmoid function is applied.
targets (jnp.ndarray) – Labels, with the same shape as inputs.
average (bool) – Whether to average the loss values (True) or sum them (False). Defaults to True.
- Returns
loss – The binary cross entropy loss.
- Return type
jnp.ndarray
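The numerically stable formulation discussed in the linked PyTorch thread can be sketched as follows. This is a minimal illustration, not the jaxchem source: the name `bce_with_logits_sketch` is hypothetical, and the sketch assumes the per-element loss is `max(x, 0) - x * t + log(1 + exp(-|x|))`, which avoids overflow for logits of large magnitude.

```python
import jax.numpy as jnp


def bce_with_logits_sketch(inputs, targets, average=True):
    """Hypothetical stand-in for a stable binary cross entropy on logits."""
    # Stable per-element loss: max(x, 0) - x * t + log(1 + exp(-|x|)).
    # exp is only ever called on non-positive values, so it cannot overflow.
    per_element = (
        jnp.maximum(inputs, 0.0)
        - inputs * targets
        + jnp.log1p(jnp.exp(-jnp.abs(inputs)))
    )
    # Reduce to a single value: mean if average is True, otherwise sum.
    return jnp.mean(per_element) if average else jnp.sum(per_element)


# Example values chosen for illustration only.
logits = jnp.array([2.0, -1.0, 0.0])
labels = jnp.array([1.0, 0.0, 1.0])

mean_loss = bce_with_logits_sketch(logits, labels)
sum_loss = bce_with_logits_sketch(logits, labels, average=False)
```

Passing `average=False` returns the summed loss, which may be preferable when batches differ in size and the caller normalizes the total afterwards.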