5 changes: 4 additions & 1 deletion .gitignore
@@ -171,4 +171,7 @@ src/outputs/

# scrapped data
scrapped_docs/*
llms/*
*_env/

# vscode
.vscode*
80 changes: 68 additions & 12 deletions client_test/call_agent.py
@@ -2,24 +2,80 @@
 import pathlib
 import requests
 
+USER_NAME = 'erfan_miahi'
+PASS = 'temp'
+
+def login(port: int = 8000):
+    # data = {'username': USER_NAME, 'email': 'mhi.erfan1@gmail.com', 'password': PASS}
+    # response = requests.post(f'http://localhost:{port}/register', json=data)
+    data = {'username': USER_NAME, 'password': PASS}
+    form_data = {
+        'username': USER_NAME,
+        'password': PASS
+    }
+
+    # Headers to be sent in the POST request
+    headers = {
+        'Content-Type': 'application/x-www-form-urlencoded',
+    }
+
+    response = requests.post(f'http://localhost:{port}/token', data=form_data, headers=headers)
+    if response.status_code == 200:
+        print('Request successful!')
+        # Accessing the response content
+        print('Response content:', response.json())
+        return response.json()['access_token']
+    else:
+        print('Request failed with status code:', response.status_code)
+        print('Response content:', response.text)
+        return None
+
 
 def send_request(input_text, proceeding_text, port=8000):
-    response = requests.put(f"http://localhost:{port}/generate", json={
-        "file_path": str(pathlib.Path(__file__).parent.absolute()), # This file's path
+    token = login(port)
+    if not token:
+        print("Failed to retrieve token. Exiting.")
+        return
+
+    # Headers to be sent in the POST request
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f'Bearer {token}'
+    }
+    print('Headers: ', headers)
+
+    # Ensure the endpoint and method are correct (change PUT to POST if necessary)
+    response = requests.post(f"http://localhost:{port}/generate", json={
+        "file_path": str(pathlib.Path(__file__).parent.absolute()), # This file's path
         "prior_context": input_text,
         "proceeding_context": proceeding_text,
         "max_decode_length": 128,
-    }, timeout=180)
-    output_data = response.json()
-    if "error" in output_data:
-        print(f"Error: {output_data['error']}")
-        return
-    output_text = output_data["generated_text"]
-    score = output_data["score"]
+    }, timeout=180, headers=headers)
+
+    # Check for successful request
+    if response.status_code == 200:
+        try:
+            # Check if the response is JSON
+            if 'application/json' in response.headers.get('Content-Type', ''):
+                output_data = response.json()
+                if "error" in output_data:
+                    print(f"Error: {output_data['error']}")
+                    return
+                output_text = output_data["generated_text"]
+                score = output_data["score"]
 
-    print("Input text: " + input_text)
-    print(f"Generated text ({score:.3f}):")
-    print(output_text)
+                print("Input text: " + input_text)
+                print(f"Generated text ({score:.3f}):")
+                print(output_text)
+            else:
+                print("Response is not JSON:")
+                print(response.text)
+        except ValueError as e:
+            print('Error decoding JSON:', e)
+            print('Response content:', response.text)
+    else:
+        print(f"Request failed with status code: {response.status_code}")
+        print(f"Response content: {response.text}")
 
 
 if __name__ == "__main__":
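Reviewer note: the if __name__ == "__main__": block is collapsed in the diff above. A minimal usage sketch of the updated client, assuming the FastAPI server from this repo is running on localhost:8000 and a user matching USER_NAME/PASS already exists (the /register call inside login() is commented out); this example is illustrative only and not part of the PR:

```python
# Hypothetical invocation of the updated client (run from client_test/).
from call_agent import send_request

prior = "def add(a, b):\n    "
following = "\n\nprint(add(1, 2))"

# send_request() first calls login() to obtain a bearer token from /token,
# then POSTs to /generate with the token attached and prints the generated
# completion (or the error / status text on failure).
send_request(prior, following, port=8000)
```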
3 changes: 2 additions & 1 deletion eval/run_eval.py
@@ -21,7 +21,8 @@
 
 sys.path.append('../')
 from src import config_handler
-from src.modeling import ModelProvider, FIM_HOLE_TOKEN
+from src.modeling.model_provider import ModelProvider
+from src.modeling.tokens import FIM_HOLE_TOKEN
 from src.routers.fine_tuner import finetune_model, ProjectFinetuneData
 from src.training import finetune
 from benchmarks import run_human_eval_benchmark
3 changes: 2 additions & 1 deletion eval/run_rag_eval.py
@@ -31,7 +31,8 @@
 from src import config_handler
 # from src import finetune, modeling
 # from src.data_formatting import IGNORE_INDEX, FIM_HOLE_TOKEN
-from src.modeling import ModelProvider, FIM_HOLE_TOKEN
+from src.modeling.model_provider import ModelProvider
+from src.modeling.tokens import FIM_HOLE_TOKEN
 from src.rag import retrieve_context, VectorStoreProvider
 from src.routers.fine_tuner import collect_item_data, finetune_model, ProjectFinetuneData
 from src.training import finetune
2 changes: 1 addition & 1 deletion eval/utils.py
@@ -8,7 +8,7 @@
 import torch
 
 sys.path.append('../')
-from src.modeling import ModelProvider
+from src.modeling.model_provider import ModelProvider
 
 
 def create_new_model_tuple(model_provider: ModelProvider):
26 changes: 13 additions & 13 deletions requirements.txt
Collaborator:
Why remove version numbers?

@@ -1,26 +1,27 @@
transformers==4.38.2
vllm==0.5.1
transformers
fastapi==0.108.0
torch==2.2.1
torch
jax
pydantic[email]==2.5.3
uvicorn[standard]==0.20.0
datasets==2.16.1
accelerate==0.27.2
datasets
accelerate
attrdict
tqdm==4.66.1
bitsandbytes==0.41.2.post2
peft==0.6.2
tqdm
bitsandbytes
peft
pytest==7.2.1
hydra-core==1.3.2
omegaconf==2.3.0
mock==5.1.0
numpy==1.24.4
numpy
aiosqlite==0.20.0
openai==1.13.3
trl==0.7.10
openai
trl
packaging==23.2
ninja==1.11.1.1
scipy==1.12.0
scipy
python-jose==3.3.0
passlib==1.7.4
python-multipart==0.0.5
Expand All @@ -38,5 +39,4 @@ charset-normalizer==3.3.2
pymupdf==1.24.5
pymupdf4llm==0.0.5
nougat-ocr==0.1.17
llama-index==0.10.48.post1
nougat-ocr==0.1.17
llama-index
3 changes: 3 additions & 0 deletions src/conf/config.yaml
@@ -45,6 +45,9 @@ model:
   quant_type: nf4
   optim: paged_adamw_32bit
   gradient_checkpointing: True
+
+  # inference model parameters
+  inference_model_type: vllm
 
 
 inference:
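Reviewer note: a minimal sketch of how the new inference_model_type flag could be read, assuming a Hydra/OmegaConf config shaped like the file above; the select_backend helper and the default value are hypothetical and not taken from this PR (the real ModelProvider is not shown in this diff):

```python
from omegaconf import OmegaConf

# Hypothetical dispatch on the new config key (illustration only).
def select_backend(cfg) -> str:
    model_type = cfg.model.get("inference_model_type", "transformers")
    if model_type == "vllm":
        return "vllm"          # e.g. wrap generation with vllm.LLM(...)
    return "transformers"      # e.g. fall back to a HF transformers pipeline

cfg = OmegaConf.create({"model": {"inference_model_type": "vllm"}})
print(select_backend(cfg))  # -> vllm
```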
4 changes: 3 additions & 1 deletion src/main.py
@@ -12,7 +12,9 @@
 
 sys.path.append('../')
 from src import config_handler, database
-from src.modeling import ModelProvider, set_main_thread_id
+from src.session import set_main_thread_id
+from src.modeling.model_provider import ModelProvider
+
 from src.rag import VectorStoreProvider
 from src.routers import auth, fine_tuner, generator
 from src.users import SessionTracker