diff --git a/script/download_dataset.py b/script/download_dataset.py index b6f290dc..5e12d59a 100644 --- a/script/download_dataset.py +++ b/script/download_dataset.py @@ -21,7 +21,7 @@ os.makedirs(destination_dir) os.chdir(destination_dir) - for language in ('python', 'javascript', 'java', 'ruby', 'php', 'go'): + for language in ['java']: call(['wget', 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/{}.zip'.format(language), '-P', destination_dir, '-O', '{}.zip'.format(language)]) call(['unzip', '{}.zip'.format(language)]) call(['rm', '{}.zip'.format(language)]) diff --git a/src/data_dirs_test.txt b/src/data_dirs_test.txt index d5a4f530..71a2246f 100644 --- a/src/data_dirs_test.txt +++ b/src/data_dirs_test.txt @@ -1,6 +1 @@ -../resources/data/python/final/jsonl/test -../resources/data/javascript/final/jsonl/test ../resources/data/java/final/jsonl/test -../resources/data/php/final/jsonl/test -../resources/data/ruby/final/jsonl/test -../resources/data/go/final/jsonl/test \ No newline at end of file diff --git a/src/data_dirs_train.txt b/src/data_dirs_train.txt index 28827543..2709b415 100644 --- a/src/data_dirs_train.txt +++ b/src/data_dirs_train.txt @@ -1,6 +1 @@ -../resources/data/python/final/jsonl/train -../resources/data/javascript/final/jsonl/train ../resources/data/java/final/jsonl/train -../resources/data/php/final/jsonl/train -../resources/data/ruby/final/jsonl/train -../resources/data/go/final/jsonl/train \ No newline at end of file diff --git a/src/data_dirs_valid.txt b/src/data_dirs_valid.txt index 949e70a3..721ed021 100644 --- a/src/data_dirs_valid.txt +++ b/src/data_dirs_valid.txt @@ -1,6 +1 @@ -../resources/data/python/final/jsonl/valid -../resources/data/javascript/final/jsonl/valid ../resources/data/java/final/jsonl/valid -../resources/data/php/final/jsonl/valid -../resources/data/ruby/final/jsonl/valid -../resources/data/go/final/jsonl/valid \ No newline at end of file diff --git a/src/predict.py b/src/predict.py index 08690f46..60e7cfbe 100755 --- a/src/predict.py +++ b/src/predict.py @@ -113,7 +113,7 @@ def query_model(query, model, indices, language, topk=100): hyper_overrides={}) predictions = [] - for language in ('python', 'go', 'javascript', 'java', 'php', 'ruby'): + for language in ['java']: print("Evaluating language: %s" % language) definitions = pickle.load(open('../resources/data/{}_dedupe_definitions_v2.pkl'.format(language), 'rb')) indexes = [{'code_tokens': d['function_tokens'], 'language': d['language']} for d in tqdm(definitions)]