1- #!/usr/bin/env python3
2- # -*- coding: utf-8 -*-
3- """
4- Created on Wed Jan 27 11:06:53 2021
1+ import torch
2+ from torch import nn
3+ train_labels_ = None
4+ def get_train_labels (trainloader , device = 'cuda' ):
5+ global train_labels_
6+ if train_labels_ is None :
7+ print ("=> loading all train labels" )
8+ train_labels = - 1 * torch .ones ([len (trainloader .dataset )], dtype = torch .long )
9+ for i , (_ , label , index ) in enumerate (trainloader ):
10+ train_labels [index ] = label
11+ if i % 10000 == 0 :
12+ print ("{}/{}" .format (i , len (trainloader )))
13+ assert all (train_labels != - 1 )
14+ train_labels_ = train_labels .to (device )
15+ return train_labels_
16+ class Supervised_Pretraining (object ):
17+ def __init__ (self ,trainloader ,n , t = 0.07 ):
18+ """
19+ Parameters
20+ ----------
21+ trainloader :
22+ DataLoader containing training data.
23+ n : int
24+ Number of labels.
25+ t : float
26+ Temperature parameter as described in https://arxiv.org/pdf/1805.01978.pdf.
527
6- @author: nuvilabs
7- """
28+ """
29+ super (Supervised_Pretraining ,self ).__init__ ()
30+ # get train labels
31+ self .labels = get_train_labels (trainloader )
32+ # Softmax loss
33+ self .loss_fn = nn .CrossEntropyLoss ()
34+ #init labels
35+ self .n_labels = n
36+ self .t = t
37+ def to (self ,device ):
38+ #send to a device
39+ self .loss_fn .to (device )
40+ def __call__ (self ,out ,y ):
41+ return self .forward (out ,y )
42+ def forward (self ,out ,y ):
43+ """
44+ Parameters
45+ ----------
46+ out :
47+ Output from LinearAverage.py as described in https://arxiv.org/pdf/1812.08781.pdf.
48+ y : tensor
49+ Target Labels.
850
9- class NuviS3KeySensor (S3KeySensor ):
51+
52+ Returns
53+ -------
54+ Softmax Loss.
1055
11-
12- def poke (self , context ):
13- to_return = super ().poke (context )
56+ """
57+ #making it more sensitive by dividing by temperature value as in https://arxiv.org/pdf/1805.01978.pdf
58+ out .div_ (self .t )
59+ #eq (4) in https://arxiv.org/pdf/1812.08781.pdf
60+ scores = torch .zeros (out .shape [0 ],self .n_labels ).cuda ()
61+ for i in range (self .n_labels ):
62+ yi = self .labels == i
1463
15- if to_return :
16- hook = S3Hook ( aws_conn_id = self . aws_conn_id )
17- keys = hook . list_keys ( self . bucket_name )
64+ candidates = yi . view ( 1 , - 1 ). expand ( out . shape [ 0 ], - 1 )
65+ retrieval = out [ candidates ]
66+ retrieval = retrieval . reshape ( out . shape [ 0 ], - 1 )
1867
19- check_key = lambda key : 'ID' in key
20- check_key_v = np .vectorize (check_key )
21- keys = np .asarray (keys )
22- idx = check_key_v (keys )
23- ID_keys = keys [idx ].tolist ()
24- to_return = len (ID_keys ) != 0
25- if to_return :self .xcom_push (context , key = "IDS" , value = ID_keys )
68+ scores [:,i ] = retrieval .sum (1 ,keepdim = True ).view (1 ,- 1 )
2669
27- return to_return
70+ return self . loss_fn ( scores , y )
0 commit comments