Skip to content

Commit 67fe4e7

Browse files
committed
Added a supervised metric pretraining option as described in the paper
1 parent 764e950 commit 67fe4e7

File tree

1 file changed

+64
-21
lines changed

1 file changed

+64
-21
lines changed

lib/SupervisedMetricPretraining.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,70 @@
1-
#!/usr/bin/env python3
2-
# -*- coding: utf-8 -*-
3-
"""
4-
Created on Wed Jan 27 11:06:53 2021
1+
import torch
2+
from torch import nn
3+
# Module-level cache: the label tensor is built once per process and reused.
train_labels_ = None


def get_train_labels(trainloader, device='cuda'):
    """
    Collect the label of every training sample into one tensor.

    The result is cached in the module-level ``train_labels_``. Once the
    cache is filled, later calls return the cached tensor as-is and ignore
    both ``trainloader`` and ``device``.

    Parameters
    ----------
    trainloader :
        Iterable yielding ``(data, label, index)`` batches. It must support
        ``len()`` and expose a ``dataset`` attribute that supports ``len()``.
    device : str or torch.device
        Device the label tensor is moved to when it is first built.

    Returns
    -------
    torch.Tensor
        1-D ``torch.long`` tensor of length ``len(trainloader.dataset)``;
        position ``index`` holds the label of the sample with that index.

    Raises
    ------
    AssertionError
        If at least one position was never written while iterating
        ``trainloader``.
    """
    global train_labels_
    if train_labels_ is None:
        print("=> loading all train labels")
        # -1 is the "not filled yet" sentinel checked after the loop.
        train_labels = torch.full(
            [len(trainloader.dataset)], -1, dtype=torch.long
        )
        for i, (_, label, index) in enumerate(trainloader):
            train_labels[index] = label
            if i % 10000 == 0:
                print("{}/{}".format(i, len(trainloader)))
        # Tensor.all() reduces in one native call; the builtin all() would
        # iterate the tensor element by element in Python.
        # Raised explicitly (not via ``assert``) so the check survives -O.
        if not bool((train_labels != -1).all()):
            raise AssertionError(
                "some training samples were never assigned a label"
            )
        train_labels_ = train_labels.to(device)
    return train_labels_
16+
class Supervised_Pretraining(object):
    """
    Softmax loss over per-class sums of instance scores.

    For every class the scores of all training samples carrying that label
    are summed, the sums are divided by the temperature ``t``, and the
    result is fed to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, trainloader, n, t=0.07, device='cuda'):
        """
        Parameters
        ----------
        trainloader :
            DataLoader containing training data.
        n : int
            Number of labels.
        t : float
            Temperature parameter as described in https://arxiv.org/pdf/1805.01978.pdf.
        device : str or torch.device
            Device requested for the training-label tensor. Forwarded to
            ``get_train_labels``; defaults to ``'cuda'`` as before.
        """
        super().__init__()
        # get train labels
        self.labels = get_train_labels(trainloader, device=device)
        # Softmax loss
        self.loss_fn = nn.CrossEntropyLoss()
        # init labels
        self.n_labels = n
        self.t = t

    def to(self, device):
        """
        Move the loss function and the stored training labels to ``device``.

        Returns
        -------
        Supervised_Pretraining
            ``self``, so calls can be chained.
        """
        self.loss_fn.to(device)
        self.labels = self.labels.to(device)
        return self

    def __call__(self, out, y):
        """Alias for :meth:`forward`."""
        return self.forward(out, y)

    def forward(self, out, y):
        """
        Parameters
        ----------
        out :
            Output from LinearAverage.py as described in https://arxiv.org/pdf/1812.08781.pdf.
            Indexed as ``out[batch, training_sample]``; the second dimension
            must line up with the stored training labels. Not modified.
        y : tensor
            Target Labels.

        Returns
        -------
        Softmax Loss.
        """
        # Keep the label tensor on the same device as the scores so the
        # boolean masks below can index ``out`` directly.
        if self.labels.device != out.device:
            self.labels = self.labels.to(out.device)

        # eq (4) in https://arxiv.org/pdf/1812.08781.pdf
        scores = torch.zeros(out.shape[0], self.n_labels, device=out.device)
        for i in range(self.n_labels):
            # Columns of training samples labelled ``i``; a class with no
            # samples selects zero columns and therefore sums to 0.
            scores[:, i] = out[:, self.labels == i].sum(dim=1)

        # making it more sensitive by dividing by temperature value as in
        # https://arxiv.org/pdf/1805.01978.pdf. Dividing the class sums
        # equals summing the divided scores, but leaves the caller's ``out``
        # untouched and avoids an in-place op on it.
        return self.loss_fn(scores / self.t, y)

0 commit comments

Comments
 (0)