Objective Functions in Neural Networks

2018-07-02

Triplet Loss

Strategies in online mining

In online mining, we have computed a batch of B embeddings from a batch of B inputs. Whenever we have three indices i, j, k in [1, B], if examples i and j are distinct but have the same label, and example k has a different label, we say that (i, j, k) is a valid triplet.

Suppose that you have a batch of faces as input, of size B = PK, composed of P different persons with K images each.

There are two strategies:

  • batch all: select all the valid triplets, and average the loss over the hard and semi-hard triplets.
    • a crucial point here is to not take the easy triplets (those with loss 0) into account, as averaging over them would make the overall loss very small.
    • this produces a total of PK(K - 1)(PK - K) triplets (PK anchors, K - 1 possible positives per anchor, and PK - K possible negatives); see the brute-force check after this list.
  • batch hard: for each anchor, select the hardest positive (biggest distance d(a, p)) and the hardest negative (smallest distance d(a, n)) among the batch.
    • this produces PK triplets.
    • the selected triplets are the hardest in the batch.
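
To sanity-check the PK(K - 1)(PK - K) count, here is a small brute-force sketch in plain Python (the toy values P = 3 and K = 2 are arbitrary):

import itertools

# Toy batch: P = 3 persons with K = 2 images each, so B = PK = 6
P, K = 3, 2
labels = [p for p in range(P) for _ in range(K)]  # [0, 0, 1, 1, 2, 2]
B = len(labels)

# Enumerate every index triple (i, j, k) and keep only the valid ones:
# i and j distinct with the same label, k with a different label
valid = [(i, j, k)
         for i, j, k in itertools.product(range(B), repeat=3)
         if i != j and labels[i] == labels[j] and labels[k] != labels[i]]

print(len(valid))                     # 24
print(P * K * (K - 1) * (P * K - K))  # 24 as well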

According to the paper In Defense of the Triplet Loss for Person Re-Identification, which introduced it, the batch hard strategy yields the best performance:

"Additionally, the selected triplets can be considered moderate triplets, since they are the hardest within a small subset of the data, which is exactly what is best for learning with the triplet loss."

A TensorFlow implementation of online triplet mining

Compute the distance matrix

As the final triplet loss depends on the distances d(a, p) and d(a, n), we first need to efficiently compute the pairwise distance matrix. We implement this for both the euclidean norm and the squared euclidean norm in the _pairwise_distances function:

def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    # Get the dot product between all embeddings
    # shape (batch_size, batch_size)
    dot_product = tf.matmul(embeddings, tf.transpose(embeddings))

    # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
    # This also provides more numerical stability (the diagonal of the result will be exactly 0).
    # shape (batch_size,)
    square_norm = tf.diag_part(dot_product)

    # Compute the pairwise distance matrix as we have:
    # ||a - b||^2 = ||a||^2  - 2 <a, b> + ||b||^2
    # shape (batch_size, batch_size)
    distances = tf.expand_dims(square_norm, 0) - 2.0 * dot_product + tf.expand_dims(square_norm, 1)

    # Because of computation errors, some distances might be negative so we put everything >= 0.0
    distances = tf.maximum(distances, 0.0)

    if not squared:
        # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
        # we need to add a small epsilon where distances == 0.0
        mask = tf.to_float(tf.equal(distances, 0.0))
        distances = distances + mask * 1e-16

        distances = tf.sqrt(distances)

        # Correct the epsilon added: set the distances on the mask to be exactly 0.0
        distances = distances * (1.0 - mask)

    return distances


A NumPy example:

import numpy as np

# Toy embeddings: row i is [i, i], so ||e_i - e_j||^2 = 2 * (i - j)^2
embeddings = np.arange(8)
embeddings = np.expand_dims(embeddings, -1)
embeddings = np.concatenate([embeddings, embeddings], -1)
print(embeddings)

# Dot products between all pairs of embeddings, shape (8, 8)
dot_product = np.matmul(embeddings, np.transpose(embeddings))
print(dot_product)

# Squared L2 norms are on the diagonal, shape (8,)
square_norm = np.diag(dot_product)
print(square_norm)

# Broadcasting: (8, 1) - (8, 8) + (1, 8) gives the (8, 8) squared distance matrix
distances = np.expand_dims(square_norm, 1) - 2 * dot_product + np.expand_dims(square_norm, 0)
print(distances)  # distances[i, j] == 2 * (i - j) ** 2

Batch hard strategy

In this strategy, we want to find the hardest positive and negative for each anchor.

  • Hardest positive

    To compute the hardest positive, we begin with the pairwise distance matrix. We then get a 2D mask of the valid pairs (a, p) (i.e. a != p and a and p have the same label) and set to 0 every element outside the mask.

    The last step is just to take the maximum distance over each row of this modified distance matrix. The result should be a valid pair (a, p) since invalid elements are set to 0.

  • Hardest negative

    The hardest negative is similar but a bit trickier to compute. Here we need the minimum distance over each row, so we cannot set the invalid pairs (a, n) to 0 (a pair is invalid if a and n have the same label).

    The trick is to add, to each invalid pair (a, n) in a row, the maximum value of that row. We then take the minimum over each row. The result corresponds to a valid pair (a, n), since the invalid elements were pushed up to at least the row maximum.
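
The implementation below relies on two mask helpers, _get_anchor_positive_triplet_mask and _get_anchor_negative_triplet_mask, which are built from the labels alone. A minimal sketch consistent with the masks described above:

def _get_anchor_positive_triplet_mask(labels):
    """Return a 2D boolean mask where mask[a, p] is True iff a and p are distinct and share a label."""
    # Check that i and j are distinct indices
    indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
    indices_not_equal = tf.logical_not(indices_equal)

    # Check if labels[i] == labels[j] using broadcasting
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

    return tf.logical_and(indices_not_equal, labels_equal)


def _get_anchor_negative_triplet_mask(labels):
    """Return a 2D boolean mask where mask[a, n] is True iff a and n have different labels."""
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))

    return tf.logical_not(labels_equal)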

The final step is to combine these into the triplet loss:

triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)

Everything is implemented in the batch_hard_triplet_loss function:

def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.

    For each anchor, we get the hardest positive and hardest negative to form a triplet.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, use the pairwise squared euclidean distances.
                 If false, use the pairwise euclidean distances.

    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # For each anchor, get the hardest positive
    # First, we need to get a mask for every valid positive (they should have same label)
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels)
    mask_anchor_positive = tf.to_float(mask_anchor_positive)

    # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
    anchor_positive_dist = tf.multiply(mask_anchor_positive, pairwise_dist)

    # shape (batch_size, 1)
    hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)

    # For each anchor, get the hardest negative
    # First, we need to get a mask for every valid negative (they should have different labels)
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels)
    mask_anchor_negative = tf.to_float(mask_anchor_negative)

    # We add the maximum value in each row to the invalid negatives (label(a) == label(n))
    max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size, 1)
    hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)

    # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
    triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)

    # Get final mean triplet loss
    triplet_loss = tf.reduce_mean(triplet_loss)

    return triplet_loss
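
A minimal usage sketch, assuming the TF1-style API used throughout (the batch layout, embedding size, and margin value are illustrative):

import numpy as np
import tensorflow as tf

# P = 3 persons with K = 2 images each; random 16-d embeddings for illustration
labels = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int64)
embeddings = tf.constant(np.random.rand(6, 16), dtype=tf.float32)

loss = batch_hard_triplet_loss(labels, embeddings, margin=0.5)

with tf.Session() as sess:
    print(sess.run(loss))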