diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index ffbd1b3..1fedcf2 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -19,6 +19,7 @@ warnings.simplefilter('default') OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions' +OPENAI_API_BASE = os.getenv('OPENAI_API_BASE', OPENAI_API_BASE) class GPTAPI(BaseAPILLM): @@ -903,6 +904,12 @@ def generate_request_data(self, **gen_params } } + elif 'llama' in model_type.lower(): + data = { + 'model': model_type, + 'messages': messages, + **gen_params + } else: raise NotImplementedError( f'Model type {model_type} is not supported')