Commit ffd2f60

Add TransE

1 parent 1fad8be commit ffd2f60

File tree

2 files changed: +736 -0 lines changed


transE_Bernoulli_pytorch.py

Lines changed: 370 additions & 0 deletions
@@ -0,0 +1,370 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2018-01-07 01:05:44
# @Author  : jimmy (jimmywangheng@qq.com)
# @Link    : http://sdcs.sysu.edu.cn
# @Version : $Id$

import os
import pickle  # needed below to load the Bernoulli head/tail proportions

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import time
import datetime
import random

from utils import *
from data import *
from evaluation import *
import loss
import model

from hyperboard import Agent

USE_CUDA = torch.cuda.is_available()

if USE_CUDA:
    longTensor = torch.cuda.LongTensor
    floatTensor = torch.cuda.FloatTensor
else:
    longTensor = torch.LongTensor
    floatTensor = torch.FloatTensor

"""
40+
The meaning of parameters:
41+
self.dataset: Which dataset is used to train the model? Such as 'FB15k', 'WN18', etc.
42+
self.learning_rate: Initial learning rate (lr) of the model.
43+
self.early_stopping_round: How many times will lr decrease? If set to 0, it remains constant.
44+
self.L1_flag: If set to True, use L1 distance as dissimilarity; else, use L2.
45+
self.embedding_size: The embedding size of entities and relations.
46+
self.num_batches: How many batches to train in one epoch?
47+
self.train_times: The maximum number of epochs for training.
48+
self.margin: The margin set for MarginLoss.
49+
self.filter: Whether to check a generated negative sample is false negative.
50+
self.momentum: The momentum of the optimizer.
51+
self.optimizer: Which optimizer to use? Such as SGD, Adam, etc.
52+
self.loss_function: Which loss function to use? Typically, we use margin loss.
53+
self.entity_total: The number of different entities.
54+
self.relation_total: The number of different relations.
55+
self.batch_size: How many instances is contained in one batch?
56+
"""
57+
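# For reference: TransE scores a triple (h, r, t) by the dissimilarity
# d(h + r, t) between the translated head embedding and the tail embedding;
# L1_flag above switches between the L1 and L2 variants. A minimal
# illustrative sketch of the score (training below uses model.TransEModel,
# not this helper):
def transe_score_sketch(h, t, r, l1_flag=True):
    # h, t, r: float tensors of shape (batch_size, embedding_size)
    if l1_flag:
        return torch.sum(torch.abs(h + r - t), 1)
    return torch.sum((h + r - t) ** 2, 1)
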
class Config(object):
    def __init__(self):
        self.dataset = None
        self.learning_rate = 0.001
        self.early_stopping_round = 0
        self.L1_flag = True
        self.embedding_size = 100
        self.num_batches = 100
        self.train_times = 1000
        self.margin = 1.0
        self.filter = True
        self.momentum = 0.9
        self.optimizer = optim.Adam
        self.loss_function = loss.marginLoss
        self.entity_total = 0
        self.relation_total = 0
        self.batch_size = 0

if __name__ == "__main__":

    import argparse
    argparser = argparse.ArgumentParser()

    """
    The meaning of some parameters:
    seed: Fix the random seed. A value of 0 means no random seed is set.
    port: The port number used by hyperboard,
        which is a demo showing training curves in real time.
        You can refer to https://github.com/WarBean/hyperboard to learn more.
    num_processes: The number of processes used to evaluate the result.
    """

    argparser.add_argument('-d', '--dataset', type=str)
    argparser.add_argument('-l', '--learning_rate', type=float, default=0.001)
    argparser.add_argument('-es', '--early_stopping_round', type=int, default=0)
    argparser.add_argument('-L', '--L1_flag', type=int, default=1)
    argparser.add_argument('-em', '--embedding_size', type=int, default=100)
    argparser.add_argument('-nb', '--num_batches', type=int, default=100)
    argparser.add_argument('-n', '--train_times', type=int, default=1000)
    argparser.add_argument('-m', '--margin', type=float, default=1.0)
    argparser.add_argument('-f', '--filter', type=int, default=1)
    argparser.add_argument('-mo', '--momentum', type=float, default=0.9)
    argparser.add_argument('-s', '--seed', type=int, default=0)
    argparser.add_argument('-op', '--optimizer', type=int, default=1)
    argparser.add_argument('-lo', '--loss_type', type=int, default=0)
    argparser.add_argument('-p', '--port', type=int, default=5000)
    argparser.add_argument('-np', '--num_processes', type=int, default=4)

    args = argparser.parse_args()

    # Start the hyperboard agent
    agent = Agent(address='127.0.0.1', port=args.port)

    if args.seed != 0:
        torch.manual_seed(args.seed)
        # Also seed the Python and NumPy RNGs, which drive batch shuffling
        # and negative sampling, so a fixed seed actually fixes the run.
        random.seed(args.seed)
        np.random.seed(args.seed)

    trainTotal, trainList, trainDict = loadTriple('./data/' + args.dataset, 'train2id.txt')
    validTotal, validList, validDict = loadTriple('./data/' + args.dataset, 'valid2id.txt')
    tripleTotal, tripleList, tripleDict = loadTriple('./data/' + args.dataset, 'triple2id.txt')
    with open(os.path.join('./data/', args.dataset, 'head_tail_proportion.pkl'), 'rb') as fr:
        tail_per_head = pickle.load(fr)
        head_per_tail = pickle.load(fr)

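    # tail_per_head (tph) and head_per_tail (hpt) hold, for each relation, the
    # average number of tails per head and of heads per tail. The Bernoulli
    # corruption strategy (Wang et al., 2014) corrupts the head of a triple
    # (h, r, t) with probability tph / (tph + hpt) and the tail otherwise,
    # which reduces false negatives for 1-to-N and N-to-1 relations. The
    # sampling itself is assumed to live in the getBatch_*_v2 helpers; the
    # decision amounts to:
    #     p_head = tail_per_head[r] / (tail_per_head[r] + head_per_tail[r])
    #     corrupt_head = random.random() < p_head
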
    config = Config()
    config.dataset = args.dataset
    config.learning_rate = args.learning_rate

    config.early_stopping_round = args.early_stopping_round

    config.L1_flag = (args.L1_flag == 1)

    config.embedding_size = args.embedding_size
    config.num_batches = args.num_batches
    config.train_times = args.train_times
    config.margin = args.margin

    config.filter = (args.filter == 1)

    config.momentum = args.momentum

    if args.optimizer == 0:
        config.optimizer = optim.SGD
    elif args.optimizer == 1:
        config.optimizer = optim.Adam
    elif args.optimizer == 2:
        config.optimizer = optim.RMSprop

    if args.loss_type == 0:
        config.loss_function = loss.marginLoss

    config.entity_total = getAnythingTotal('./data/' + config.dataset, 'entity2id.txt')
    config.relation_total = getAnythingTotal('./data/' + config.dataset, 'relation2id.txt')
    config.batch_size = trainTotal // config.num_batches

    shareHyperparameters = {'dataset': args.dataset,
                            'learning_rate': args.learning_rate,
                            'early_stopping_round': args.early_stopping_round,
                            'L1_flag': args.L1_flag,
                            'embedding_size': args.embedding_size,
                            'margin': args.margin,
                            'filter': args.filter,
                            'momentum': args.momentum,
                            'seed': args.seed,
                            'optimizer': args.optimizer,
                            'loss_type': args.loss_type,
                            }

    trainHyperparameters = shareHyperparameters.copy()
    trainHyperparameters.update({'type': 'train_loss'})

    validHyperparameters = shareHyperparameters.copy()
    validHyperparameters.update({'type': 'valid_loss'})

    hit10Hyperparameters = shareHyperparameters.copy()
    hit10Hyperparameters.update({'type': 'hit10'})

    meanrankHyperparameters = shareHyperparameters.copy()
    meanrankHyperparameters.update({'type': 'mean_rank'})

    trainCurve = agent.register(trainHyperparameters, 'train loss', overwrite=True)
    validCurve = agent.register(validHyperparameters, 'valid loss', overwrite=True)
    hit10Curve = agent.register(hit10Hyperparameters, 'hit@10', overwrite=True)
    meanrankCurve = agent.register(meanrankHyperparameters, 'mean rank', overwrite=True)

    loss_function = config.loss_function()
    model = model.TransEModel(config)

    if USE_CUDA:
        model.cuda()
        loss_function.cuda()

    optimizer = config.optimizer(model.parameters(), lr=config.learning_rate)
    margin = autograd.Variable(floatTensor([config.margin]))

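    # loss.marginLoss is assumed to implement the margin ranking loss of
    # Bordes et al. (2013):
    #     L = sum(max(0, margin + d(h + r, t) - d(h' + r', t')))
    # where (h, r, t) is a positive triple and (h', r', t') its corruption.
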
    start_time = time.time()

    filename = '_'.join(
        ['l', str(args.learning_rate),
         'es', str(args.early_stopping_round),
         'L', str(args.L1_flag),
         'em', str(args.embedding_size),
         'nb', str(args.num_batches),
         'n', str(args.train_times),
         'm', str(args.margin),
         'f', str(args.filter),
         'mo', str(args.momentum),
         's', str(args.seed),
         'op', str(args.optimizer),
         'lo', str(args.loss_type),]) + '_TransE_Bernoulli.ckpt'

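    # The checkpoint filename above encodes the full hyperparameter setting,
    # so runs with different settings do not overwrite each other's models.
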
    trainBatchList = getBatchList(trainList, config.num_batches)

    for epoch in range(config.train_times):
        total_loss = floatTensor([0.0])
        random.shuffle(trainBatchList)
        for batchList in trainBatchList:
            if config.filter:
                pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch = getBatch_filter_all_v2(batchList,
                    config.entity_total, tripleDict, tail_per_head, head_per_tail)
            else:
                pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch = getBatch_raw_all_v2(batchList,
                    config.entity_total, tail_per_head, head_per_tail)

            batch_entity_set = set(pos_h_batch + pos_t_batch + neg_h_batch + neg_t_batch)
            batch_relation_set = set(pos_r_batch + neg_r_batch)
            batch_entity_list = list(batch_entity_set)
            batch_relation_list = list(batch_relation_set)

            pos_h_batch = autograd.Variable(longTensor(pos_h_batch))
            pos_t_batch = autograd.Variable(longTensor(pos_t_batch))
            pos_r_batch = autograd.Variable(longTensor(pos_r_batch))
            neg_h_batch = autograd.Variable(longTensor(neg_h_batch))
            neg_t_batch = autograd.Variable(longTensor(neg_t_batch))
            neg_r_batch = autograd.Variable(longTensor(neg_r_batch))

            model.zero_grad()
            pos, neg = model(pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch)

            if args.loss_type == 0:
                losses = loss_function(pos, neg, margin)
            else:
                losses = loss_function(pos, neg)
            ent_embeddings = model.ent_embeddings(torch.cat([pos_h_batch, pos_t_batch, neg_h_batch, neg_t_batch]))
            rel_embeddings = model.rel_embeddings(torch.cat([pos_r_batch, neg_r_batch]))
            losses = losses + loss.normLoss(ent_embeddings) + loss.normLoss(rel_embeddings)
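            # loss.normLoss is assumed to softly penalize embedding vectors
            # whose norm exceeds 1, a relaxation of the hard renormalization
            # step in the original TransE; without it, entity norms could grow
            # arbitrarily and satisfy the margin trivially.
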
            losses.backward()
            optimizer.step()
            total_loss += losses.data

        agent.append(trainCurve, epoch, total_loss[0])

        if epoch % 10 == 0:
            now_time = time.time()
            print(now_time - start_time)
            print("Train total loss: %d %f" % (epoch, total_loss[0]))

        if epoch % 10 == 0:
            if config.filter:
                pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch = getBatch_filter_random_v2(validList,
                    config.batch_size, config.entity_total, tripleDict, tail_per_head, head_per_tail)
            else:
                pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch = getBatch_raw_random_v2(validList,
                    config.batch_size, config.entity_total, tail_per_head, head_per_tail)
            pos_h_batch = autograd.Variable(longTensor(pos_h_batch))
            pos_t_batch = autograd.Variable(longTensor(pos_t_batch))
            pos_r_batch = autograd.Variable(longTensor(pos_r_batch))
            neg_h_batch = autograd.Variable(longTensor(neg_h_batch))
            neg_t_batch = autograd.Variable(longTensor(neg_t_batch))
            neg_r_batch = autograd.Variable(longTensor(neg_r_batch))

            pos, neg = model(pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch)

            if args.loss_type == 0:
                losses = loss_function(pos, neg, margin)
            else:
                losses = loss_function(pos, neg)
            ent_embeddings = model.ent_embeddings(torch.cat([pos_h_batch, pos_t_batch, neg_h_batch, neg_t_batch]))
            rel_embeddings = model.rel_embeddings(torch.cat([pos_r_batch, neg_r_batch]))
            losses = losses + loss.normLoss(ent_embeddings) + loss.normLoss(rel_embeddings)
            print("Valid batch loss: %d %f" % (epoch, losses.data[0]))
            agent.append(validCurve, epoch, losses.data[0])

        if config.early_stopping_round > 0:
            if epoch == 0:
                ent_embeddings = model.ent_embeddings.weight.data.cpu().numpy()
                rel_embeddings = model.rel_embeddings.weight.data.cpu().numpy()
                L1_flag = model.L1_flag
                filter = model.filter
                hit10, best_meanrank = evaluation_transE(validList, tripleDict, ent_embeddings, rel_embeddings,
                    L1_flag, filter, config.batch_size, num_processes=args.num_processes)
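                # hit10 is the fraction of validation triples whose correct
                # entity ranks in the top 10 candidates; mean rank is the
                # average rank of the correct entity (lower is better).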
                agent.append(hit10Curve, epoch, hit10)
                agent.append(meanrankCurve, epoch, best_meanrank)
                torch.save(model, os.path.join('./model/' + args.dataset, filename))
                best_epoch = 0
                meanrank_not_decrease_time = 0
                lr_decrease_time = 0
                # if USE_CUDA:
                #     model.cuda()

            # Check the result on the validation set every 5 epochs
            elif epoch % 5 == 0:
                ent_embeddings = model.ent_embeddings.weight.data.cpu().numpy()
                rel_embeddings = model.rel_embeddings.weight.data.cpu().numpy()
                L1_flag = model.L1_flag
                filter = model.filter
                hit10, now_meanrank = evaluation_transE(validList, tripleDict, ent_embeddings, rel_embeddings,
                    L1_flag, filter, config.batch_size, num_processes=args.num_processes)
                agent.append(hit10Curve, epoch, hit10)
                agent.append(meanrankCurve, epoch, now_meanrank)
                if now_meanrank < best_meanrank:
                    meanrank_not_decrease_time = 0
                    best_meanrank = now_meanrank
                    torch.save(model, os.path.join('./model/' + args.dataset, filename))
                else:
                    meanrank_not_decrease_time += 1
                    # If the mean rank hasn't improved for 5 consecutive
                    # validation checks (25 epochs), halve the learning rate
                    if meanrank_not_decrease_time == 5:
                        lr_decrease_time += 1
                        if lr_decrease_time == config.early_stopping_round:
                            break
                        else:
                            optimizer.param_groups[0]['lr'] *= 0.5
                            meanrank_not_decrease_time = 0
                # if USE_CUDA:
                #     model.cuda()

        elif (epoch + 1) % 10 == 0 or epoch == 0:
            torch.save(model, os.path.join('./model/' + args.dataset, filename))

    testTotal, testList, testDict = loadTriple('./data/' + args.dataset, 'test2id.txt')
    oneToOneTotal, oneToOneList, oneToOneDict = loadTriple('./data/' + args.dataset, 'one_to_one.txt')
    oneToManyTotal, oneToManyList, oneToManyDict = loadTriple('./data/' + args.dataset, 'one_to_many.txt')
    manyToOneTotal, manyToOneList, manyToOneDict = loadTriple('./data/' + args.dataset, 'many_to_one.txt')
    manyToManyTotal, manyToManyList, manyToManyDict = loadTriple('./data/' + args.dataset, 'many_to_many.txt')

    ent_embeddings = model.ent_embeddings.weight.data.cpu().numpy()
    rel_embeddings = model.rel_embeddings.weight.data.cpu().numpy()
    L1_flag = model.L1_flag
    filter = model.filter

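    # The head argument below is assumed to select the prediction side:
    # head=0 evaluates both head and tail prediction, head=1 only head
    # prediction, head=2 only tail prediction.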
    hit10Test, meanrankTest = evaluation_transE(testList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=0)

    hit10OneToOneHead, meanrankOneToOneHead = evaluation_transE(oneToOneList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=1)
    hit10OneToManyHead, meanrankOneToManyHead = evaluation_transE(oneToManyList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=1)
    hit10ManyToOneHead, meanrankManyToOneHead = evaluation_transE(manyToOneList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=1)
    hit10ManyToManyHead, meanrankManyToManyHead = evaluation_transE(manyToManyList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=1)

    hit10OneToOneTail, meanrankOneToOneTail = evaluation_transE(oneToOneList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=2)
    hit10OneToManyTail, meanrankOneToManyTail = evaluation_transE(oneToManyList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=2)
    hit10ManyToOneTail, meanrankManyToOneTail = evaluation_transE(manyToOneList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=2)
    hit10ManyToManyTail, meanrankManyToManyTail = evaluation_transE(manyToManyList, tripleDict, ent_embeddings, rel_embeddings, L1_flag, filter, head=2)

    writeList = [filename,
                 'testSet', '%.6f' % hit10Test, '%.6f' % meanrankTest,
                 'one_to_one_head', '%.6f' % hit10OneToOneHead, '%.6f' % meanrankOneToOneHead,
                 'one_to_many_head', '%.6f' % hit10OneToManyHead, '%.6f' % meanrankOneToManyHead,
                 'many_to_one_head', '%.6f' % hit10ManyToOneHead, '%.6f' % meanrankManyToOneHead,
                 'many_to_many_head', '%.6f' % hit10ManyToManyHead, '%.6f' % meanrankManyToManyHead,
                 'one_to_one_tail', '%.6f' % hit10OneToOneTail, '%.6f' % meanrankOneToOneTail,
                 'one_to_many_tail', '%.6f' % hit10OneToManyTail, '%.6f' % meanrankOneToManyTail,
                 'many_to_one_tail', '%.6f' % hit10ManyToOneTail, '%.6f' % meanrankManyToOneTail,
                 'many_to_many_tail', '%.6f' % hit10ManyToManyTail, '%.6f' % meanrankManyToManyTail,]

    # Write the results to a file
    with open(os.path.join('./result/', args.dataset + '.txt'), 'a') as fw:
        fw.write('\t'.join(writeList) + '\n')
