diff --git a/pred.py b/pred.py index 1c20cba9..efbf9bcb 100644 --- a/pred.py +++ b/pred.py @@ -95,14 +95,14 @@ def get_pred(data, args, fout): output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=1024) else: output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=128) - if output == '': + if output == '' or output is None: continue if args.cot: # extract answer response = output.strip() item['response_cot'] = response prompt = template_0shot_cot_ans.replace('$DOC$', context.strip()).replace('$Q$', item['question'].strip()).replace('$C_A$', item['choice_A'].strip()).replace('$C_B$', item['choice_B'].strip()).replace('$C_C$', item['choice_C'].strip()).replace('$C_D$', item['choice_D'].strip()).replace('$COT$', response) output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=128) - if output == '': + if output == '' or output is None: continue response = output.strip() item['response'] = response @@ -156,4 +156,4 @@ def main(): parser.add_argument("--rag", "-rag", type=int, default=0) # set to 0 if RAG is not used, otherwise set to N when using top-N retrieved context parser.add_argument("--n_proc", "-n", type=int, default=16) args = parser.parse_args() - main() \ No newline at end of file + main()