"""Implements triplet loss."""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchflare.criterion.utils import cosine_dist, euclidean_dist
def softmax_weights(dist, mask):
    """Compute a softmax along dim 1 of ``dist``, restricted to the entries selected by ``mask``.

    Args:
        dist: Tensor of scores; the softmax is taken along dim 1 (e.g. shape [N, M]).
        mask: Tensor multiplied element-wise with ``dist``. The callers in this module
            pass 0/1 float masks of the same shape, where 1 keeps an entry and 0
            excludes it.

    Returns:
        Tensor of the same shape as ``dist``. Excluded entries get weight 0. The kept
        entries of each row sum to slightly less than 1, because of the 1e-6 added to
        the denominator. A row whose mask is all zeros yields all-zero weights
        instead of NaN.
    """
    # Shift by the per-row maximum before exponentiating, for numerical stability.
    # The max is taken over ``dist * mask``, so excluded entries contribute 0 to it.
    # Softmax is shift-invariant, so this only changes the numeric range.
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    # Compute the exponential once and reuse it for both numerator and denominator.
    masked_exp = torch.exp(dist - max_v) * mask
    z = torch.sum(masked_exp, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return masked_exp / z
# Attribution note (presumably for the example-mining helpers that follow),
# kept as a bare module-level string statement. Such a statement has no runtime
# effect, and pylint reports it as W0105 (pointless-string-statement).
# W0107 is unnecessary-pass, which does not apply here.
# skipcq: PYL-W0105
"""Source :
https://github.com/earhian/Humpback-Whale-Identification-1st-/"""
def hard_example_mining(distance_matrix, pos_idxs, neg_idxs):
    """For each anchor, find the distance to its hardest positive and hardest negative.

    The hardest positive is the largest masked-positive distance in the anchor's
    row. The hardest negative is the smallest masked-negative distance. Selection is
    done with masks, so classes do not need to have the same number of samples.

    Args:
        distance_matrix: Pair-wise distances between samples, shape [N, M]. It must
            be 2-D.
        pos_idxs: 0/1 mask with shape [N, M] marking positive pairs.
        neg_idxs: 0/1 mask with shape [N, M] marking negative pairs.

    Returns:
        A tuple ``(dist_ap, dist_an)`` of tensors with shape [N]:

        - ``dist_ap``: distance(anchor, hardest positive).
        - ``dist_an``: distance(anchor, hardest negative). If an anchor has no
          negative pair, this is the large sentinel value (about 1e8).

    Note:
        - Masked-out entries contribute 0 to the positive max, so the result is
          only meaningful for non-negative distances.
        - An entry where both masks are 0 would evaluate to 0 and could be picked
          as the hardest negative. The two masks are therefore expected to be
          complementary.
    """
    assert len(distance_matrix.size()) == 2  # noqa: S101
    # Added to positive pairs so they are never selected as the hardest negative.
    positive_penalty = 99999999.0
    # Hardest positive: farthest sample among the anchor's positives.
    dist_ap, _ = torch.max(distance_matrix * pos_idxs, dim=1)
    # Hardest negative: closest sample among the anchor's negatives.
    dist_an, _ = torch.min(
        distance_matrix * neg_idxs + pos_idxs * positive_penalty, dim=1
    )
    return dist_ap, dist_an
def weighted_example_mining(distance_matrix, pos_idxs, neg_idxs):
    """Compute softmax-weighted positive and negative distances for every anchor.

    Instead of picking a single hardest sample, each anchor's positive (resp.
    negative) distances are averaged with softmax weights computed by
    ``softmax_weights``. Larger positive distances are weighted more, and smaller
    negative distances are weighted more.

    Args:
        distance_matrix: Pair-wise distances between samples, shape [N, M]. It must
            be 2-D.
        pos_idxs: 0/1 mask with shape [N, M] marking positive pairs.
        neg_idxs: 0/1 mask with shape [N, M] marking negative pairs.

    Returns:
        A tuple ``(dist_ap, dist_an)`` of tensors with shape [N]: the weighted
        anchor-positive and anchor-negative distances.
    """
    assert len(distance_matrix.size()) == 2  # noqa: S101
    masked_pos = distance_matrix * pos_idxs
    masked_neg = distance_matrix * neg_idxs
    # Negate the negative distances so the softmax favours the closest negatives.
    pos_weights = softmax_weights(masked_pos, pos_idxs)
    neg_weights = softmax_weights(-masked_neg, neg_idxs)
    return (
        (masked_pos * pos_weights).sum(dim=1),
        (masked_neg * neg_weights).sum(dim=1),
    )
class TripletLoss(nn.Module):
    """Computes Triplet loss.

    Args:
        normalize_features: If True, distances are computed with ``cosine_dist``;
            otherwise with ``euclidean_dist``. Default = True.
        margin: The value for margin. When it is ``None`` or not positive,
            soft-margin loss is used instead of margin ranking loss.
            Default = None.
        hard_mining: If True, use hard example mining; otherwise use
            softmax-weighted example mining. Default = True.
    """

    def __init__(
        self,
        normalize_features: bool = True,
        margin: Optional[float] = None,
        hard_mining: bool = True,
    ):
        """Constructor method for TripletLoss."""
        super().__init__()
        self.normalize_features = normalize_features
        self.margin = margin
        self.hard_mining = hard_mining

    def forward(self, embedding: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward Method.

        Args:
            embedding: The output of the network, one row per sample (presumably
                [N, D]; confirm against ``cosine_dist``/``euclidean_dist``).
            targets: The targets, one label per sample (must have N elements).

        Returns:
            The computed Triplet Loss.
        """
        distance_matrix = (
            cosine_dist(embedding, embedding)
            if self.normalize_features
            else euclidean_dist(embedding, embedding)
        )
        n = distance_matrix.size(0)
        # An [N, 1] column compared with its [1, N] transpose broadcasts to an
        # [N, N] label-equality matrix. ``reshape`` also accepts non-contiguous
        # targets.
        labels = targets.reshape(n, 1)
        pos_idxs = labels.eq(labels.t()).float()
        neg_idxs = labels.ne(labels.t()).float()

        mining_fn = hard_example_mining if self.hard_mining else weighted_example_mining
        dist_ap, dist_an = mining_fn(
            distance_matrix=distance_matrix, pos_idxs=pos_idxs, neg_idxs=neg_idxs
        )

        # Target of +1 means dist_an should rank above dist_ap.
        y = torch.ones_like(dist_an)
        if self.margin is not None and self.margin > 0:
            loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self.margin)
        else:
            loss = F.soft_margin_loss(dist_an - dist_ap, y)
            # soft_margin_loss can overflow to inf for very negative inputs.
            # Fall back to a fixed-margin ranking loss in that case only, so a
            # user-supplied margin is never overridden.
            # fmt: off
            if loss == float("Inf"):
                loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
            # fmt: on
        return loss
# Public API: only TripletLoss is exported by ``from <this module> import *``.
__all__ = ["TripletLoss"]