diff --git a/codes/dataloader.py b/codes/dataloader.py index 70d43a25..ed3f3492 100644 --- a/codes/dataloader.py +++ b/codes/dataloader.py @@ -59,8 +59,8 @@ def __getitem__(self, idx): negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size] - negative_sample = torch.from_numpy(negative_sample) - + negative_sample = torch.LongTensor(negative_sample) + positive_sample = torch.LongTensor(positive_sample) return positive_sample, negative_sample, subsampling_weight, self.mode @@ -181,4 +181,4 @@ def one_shot_iterator(dataloader): ''' while True: for data in dataloader: - yield data \ No newline at end of file + yield data