diff --git a/pred.py b/pred.py
index 1c20cba9..5f0f5d48 100644
--- a/pred.py
+++ b/pred.py
@@ -53,7 +53,7 @@ def query_llm(prompt, model, tokenizer, client=None, temperature=0.5, max_new_to
             time.sleep(1)
     else:
         print("Max tries. Failed.")
-        return ''
+        return None
 
 def extract_answer(response):
     response = response.replace('*', '')
@@ -95,14 +95,14 @@ def get_pred(data, args, fout):
             output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=1024)
         else:
             output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=128)
-        if output == '':
+        if output is None:
             continue
         if args.cot: # extract answer
             response = output.strip()
             item['response_cot'] = response
             prompt = template_0shot_cot_ans.replace('$DOC$', context.strip()).replace('$Q$', item['question'].strip()).replace('$C_A$', item['choice_A'].strip()).replace('$C_B$', item['choice_B'].strip()).replace('$C_C$', item['choice_C'].strip()).replace('$C_D$', item['choice_D'].strip()).replace('$COT$', response)
             output = query_llm(prompt, model, tokenizer, client, temperature=0.1, max_new_tokens=128)
-            if output is None:
+            if output is None:
                 continue
             response = output.strip()
         item['response'] = response
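The motivation for switching the failure sentinel from `''` to `None`, as I read this patch: an empty string is a value a model can legitimately return, so having `query_llm` return `''` after exhausting its retries made a genuine API failure indistinguishable from an empty completion, and `get_pred` would silently skip both. Below is a minimal sketch of that distinction; `query_llm_stub` and the driver loop are hypothetical illustrations, not code from this repo.

```python
from typing import Optional

def query_llm_stub(prompt: str, fail: bool) -> Optional[str]:
    # Hypothetical stand-in for pred.py's query_llm after this patch:
    # None means "all retries exhausted", while '' is a valid completion.
    if fail:
        return None
    return ''

for fail in (True, False):
    output = query_llm_stub('...', fail)
    if output is None:  # only genuine API failures are skipped
        continue
    response = output.strip()  # an empty completion is kept and recorded
    print(repr(response))
```

With the old `''` sentinel, the `if output == '':` guard in `get_pred` would also have dropped the empty-but-successful completion above; after the patch, only the `None` case hits `continue`.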