From 3f9250ab6e2c1bd7bd49d222be918360c5d13458 Mon Sep 17 00:00:00 2001 From: Brianna Major Date: Tue, 23 Sep 2025 14:49:37 -0400 Subject: [PATCH 1/3] Separate CLI and client concerns - Extract VTKPromptClient class from prompt.py to client.py - Extract CLI interface and main() function to cli.py - Convert prompt.py to backward compatibility module with re-exports Maintains full backward compatibility for existing imports. --- pyproject.toml | 2 +- src/vtk_prompt/cli.py | 148 +++++++++++ src/vtk_prompt/client.py | 294 ++++++++++++++++++++++ src/vtk_prompt/prompt.py | 433 +------------------------------- src/vtk_prompt/vtk_prompt_ui.py | 2 +- 5 files changed, 456 insertions(+), 423 deletions(-) create mode 100644 src/vtk_prompt/cli.py create mode 100644 src/vtk_prompt/client.py diff --git a/pyproject.toml b/pyproject.toml index 2d6c94c..21da840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ Repository = "https://github.com/vicentebolea/vtk-prompt" Issues = "https://github.com/vicentebolea/vtk-prompt/issues" [project.scripts] -vtk-prompt = "vtk_prompt.prompt:main" +vtk-prompt = "vtk_prompt.cli:main" gen-vtk-file = "vtk_prompt.generate_files:main" vtk-build-rag = "vtk_prompt.build_rag_db:main" vtk-test-rag = "vtk_prompt.test_rag:main" diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py new file mode 100644 index 0000000..628c8ae --- /dev/null +++ b/src/vtk_prompt/cli.py @@ -0,0 +1,148 @@ +""" +VTK Prompt Command Line Interface. + +This module provides the CLI interface for VTK code generation using LLMs. +It handles argument parsing, validation, and orchestrates the VTKPromptClient. + +Example: + >>> vtk-prompt "create sphere" --rag --model gpt-4o +""" + +import sys +from typing import Optional + +import click + +from . import get_logger +from .client import VTKPromptClient + +logger = get_logger(__name__) + + +@click.command() +@click.argument("input_string") +@click.option( + "--provider", + type=click.Choice(["openai", "anthropic", "gemini", "nim"]), + default="openai", + help="LLM provider to use", +) +@click.option("-m", "--model", default="gpt-4o", help="Model name to use") +@click.option("-k", "--max-tokens", type=int, default=1000, help="Max # of tokens to generate") +@click.option( + "--temperature", + type=float, + default=0.7, + help="Temperature for generation (0.0-2.0)", +) +@click.option("-t", "--token", required=True, help="API token for the selected provider") +@click.option("--base-url", help="Base URL for API (auto-detected or custom)") +@click.option("-r", "--rag", is_flag=True, help="Use RAG to improve code generation") +@click.option("-v", "--verbose", is_flag=True, help="Show generated source code") +@click.option("--collection", default="vtk-examples", help="Collection name for RAG") +@click.option( + "--database", + default="./db/codesage-codesage-large-v2", + help="Database path for RAG", +) +@click.option("--top-k", type=int, default=5, help="Number of examples to retrieve from RAG") +@click.option( + "--retry-attempts", + type=int, + default=1, + help="Number of times to retry if AST validation fails", +) +@click.option( + "--conversation", + help="Path to conversation file for maintaining chat history", +) +def main( + input_string: str, + provider: str, + model: str, + max_tokens: int, + temperature: float, + token: str, + base_url: Optional[str], + rag: bool, + verbose: bool, + collection: str, + database: str, + top_k: int, + retry_attempts: int, + conversation: Optional[str], +) -> None: + """ + Generate and execute VTK code using 
LLMs. + + INPUT_STRING: The code description to generate VTK code for + """ + # Set default base URLs + if base_url is None: + base_urls = { + "anthropic": "https://api.anthropic.com/v1", + "gemini": "https://generativelanguage.googleapis.com/v1beta/openai/", + "nim": "https://integrate.api.nvidia.com/v1", + } + base_url = base_urls.get(provider) + + # Set default models based on provider + if model == "gpt-4o": + default_models = { + "anthropic": "claude-3-5-sonnet-20241022", + "gemini": "gemini-1.5-pro", + "nim": "meta/llama3-70b-instruct", + } + model = default_models.get(provider, model) + + try: + client = VTKPromptClient( + collection_name=collection, + database_path=database, + verbose=verbose, + conversation_file=conversation, + ) + result = client.query( + input_string, + api_key=token, + model=model, + base_url=base_url, + max_tokens=max_tokens, + temperature=temperature, + top_k=top_k, + rag=rag, + retry_attempts=retry_attempts, + ) + + if isinstance(result, tuple) and len(result) == 3: + _explanation, generated_code, usage = result + if verbose and usage: + logger.info( + "Used tokens: input=%d output=%d", + usage.prompt_tokens, + usage.completion_tokens, + ) + client.run_code(generated_code) + else: + # Handle string result + logger.info("Result: %s", result) + + except ValueError as e: + if "RAG components" in str(e): + logger.error("RAG components not found") + sys.exit(1) + elif "Failed to load RAG snippets" in str(e): + logger.error("Failed to load RAG snippets") + sys.exit(2) + elif "max_tokens" in str(e): + logger.error("Error: %s", e) + logger.error("Current max_tokens: %d", max_tokens) + logger.error("Try increasing with: --max-tokens ") + sys.exit(3) + else: + logger.error("Error: %s", e) + sys.exit(4) + + +if __name__ == "__main__": + main() diff --git a/src/vtk_prompt/client.py b/src/vtk_prompt/client.py new file mode 100644 index 0000000..686d21a --- /dev/null +++ b/src/vtk_prompt/client.py @@ -0,0 +1,294 @@ +""" +VTK Code Generation Client. + +This module provides the core VTKPromptClient class which handles conversation management, +code generation, execution, and error handling with retry logic. + +Features: +- Singleton pattern for conversation persistence +- RAG (Retrieval-Augmented Generation) integration for context-aware code generation +- Automatic code execution and error handling +- Conversation history management and file persistence +- Multiple model provider support (OpenAI, Anthropic, Gemini, NIM) +- Template-based prompt construction with VTK-specific context +""" + +import ast +import json +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Union + +import openai + +from . 
import get_logger +from .prompts import ( + get_no_rag_context, + get_python_role, + get_rag_context, +) + +logger = get_logger(__name__) + + +@dataclass +class VTKPromptClient: + """OpenAI client for VTK code generation.""" + + _instance: Optional["VTKPromptClient"] = None + _initialized: bool = False + collection_name: str = "vtk-examples" + database_path: str = "./db/codesage-codesage-large-v2" + verbose: bool = False + conversation_file: Optional[str] = None + conversation: Optional[list[dict[str, str]]] = None + + def __new__(cls, **kwargs: Any) -> "VTKPromptClient": + """Create singleton instance of VTKPromptClient.""" + # Make sure that this is a singleton + if cls._instance is None: + cls._instance = super(VTKPromptClient, cls).__new__(cls) + cls._instance._initialized = False + cls._instance.conversation = [] + return cls._instance + + def __post_init__(self) -> None: + """Post-init hook to prevent double initialization in singleton.""" + if hasattr(self, "_initialized") and self._initialized: + return + self._initialized = True + + def load_conversation(self) -> list[dict[str, str]]: + """Load conversation history from file.""" + if not self.conversation_file or not Path(self.conversation_file).exists(): + return [] + + try: + with open(self.conversation_file, "r") as f: + data = json.load(f) + if isinstance(data, list): + return data + else: + logger.warning("Invalid conversation file format, no history loaded.") + return [] + except Exception as e: + logger.error("Could not load conversation file: %s", e) + return [] + + def save_conversation(self) -> None: + """Save conversation history to file.""" + if not self.conversation_file or not self.conversation: + return + + try: + # Ensure directory exists + Path(self.conversation_file).parent.mkdir(parents=True, exist_ok=True) + + with open(self.conversation_file, "w") as f: + json.dump(self.conversation, f, indent=2) + except Exception as e: + logger.error("Could not save conversation file: %s", e) + + def update_conversation( + self, new_convo: list[dict[str, str]], new_convo_file: Optional[str] = None + ) -> None: + """Update conversation history with new conversation.""" + if not self.conversation: + self.conversation = [] + self.conversation.extend(new_convo) + + if new_convo_file: + self.conversation_file = new_convo_file + + def validate_code_syntax(self, code_string: str) -> tuple[bool, Optional[str]]: + """Validate Python code syntax using AST.""" + try: + ast.parse(code_string) + return True, None + except SyntaxError as e: + return False, f"Syntax error: {e.msg} at line {e.lineno}" + except Exception as e: + return False, f"AST parsing error: {str(e)}" + + def run_code(self, code_string: str) -> None: + """Execute VTK code using exec() after AST validation.""" + is_valid, error_msg = self.validate_code_syntax(code_string) + if not is_valid: + logger.error("Code validation failed: %s", error_msg) + if self.verbose: + logger.debug("Generated code:\n%s", code_string) + return + + if self.verbose: + logger.debug("Executing code:\n%s", code_string) + + try: + exec(code_string, globals(), {}) + except Exception as e: + logger.error("Error executing code: %s", e) + if not self.verbose: + logger.debug("Failed code:\n%s", code_string) + return + + def query( + self, + message: str = "", + api_key: Optional[str] = None, + model: str = "gpt-4o", + base_url: Optional[str] = None, + max_tokens: int = 1000, + temperature: float = 0.1, + top_k: int = 5, + rag: bool = False, + retry_attempts: int = 1, + ) -> Union[tuple[str, str, Any], 
str]: + """Generate VTK code with optional RAG enhancement and retry logic. + + Args: + message: The user query + api_key: API key for the service + model: Model name to use + base_url: API base URL + max_tokens: Maximum tokens to generate + temperature: Temperature for generation + top_k: Number of RAG examples to retrieve + rag: Whether to use RAG enhancement + retry_attempts: Number of times to retry if AST validation fails + """ + if not api_key: + api_key = os.environ.get("OPENAI_API_KEY") + + if not api_key: + raise ValueError("No API key provided. Set OPENAI_API_KEY or pass api_key parameter.") + + # Create client with current parameters + client = openai.OpenAI(api_key=api_key, base_url=base_url) + + # Load existing conversation if present + if self.conversation_file and not self.conversation: + self.conversation = self.load_conversation() + + if not message and not self.conversation: + raise ValueError("No prompt or conversation file provided") + + if rag: + from .rag_chat_wrapper import ( + check_rag_components_available, + get_rag_snippets, + ) + + if not check_rag_components_available(): + raise ValueError("RAG components not available") + + rag_snippets = get_rag_snippets( + message, + collection_name=self.collection_name, + database_path=self.database_path, + top_k=top_k, + ) + + if not rag_snippets: + raise ValueError("Failed to load RAG snippets") + + context_snippets = "\n\n".join(rag_snippets["code_snippets"]) + context = get_rag_context(message, context_snippets) + + if self.verbose: + logger.debug("RAG context: %s", context) + references = rag_snippets.get("references") + if references: + logger.info("Using examples from: %s", ", ".join(references)) + else: + context = get_no_rag_context(message) + if self.verbose: + logger.debug("No-RAG context: %s", context) + + # Initialize conversation with system message if empty + if not self.conversation: + self.conversation = [] + self.conversation.append({"role": "system", "content": get_python_role()}) + + # Add current user message + if message: + self.conversation.append({"role": "user", "content": context}) + + # Retry loop for AST validation + for attempt in range(retry_attempts): + if self.verbose: + if attempt > 0: + logger.debug("Retry attempt %d/%d", attempt + 1, retry_attempts) + logger.debug("Making request with model: %s, temperature: %s", model, temperature) + for i, msg in enumerate(self.conversation): + logger.debug("Message %d (%s): %s...", i, msg["role"], msg["content"][:100]) + + response = client.chat.completions.create( + model=model, + messages=self.conversation, # type: ignore[arg-type] + max_tokens=max_tokens, + temperature=temperature, + ) + + if hasattr(response, "choices") and len(response.choices) > 0: + content = response.choices[0].message.content or "No content in response" + finish_reason = response.choices[0].finish_reason + + if finish_reason == "length": + raise ValueError( + f"Output was truncated due to max_tokens limit ({max_tokens}).\n" + "Please increase max_tokens." 
+ ) + + generated_explanation = re.findall( + "(.*?)", content, re.DOTALL + )[0] + generated_code = re.findall("(.*?)", content, re.DOTALL)[0] + if "import vtk" not in generated_code: + generated_code = "import vtk\n" + generated_code + else: + pos = generated_code.find("import vtk") + if pos != -1: + generated_code = generated_code[pos:] + else: + generated_code = generated_code + + is_valid, error_msg = self.validate_code_syntax(generated_code) + if is_valid: + if message: + self.conversation.append({"role": "assistant", "content": content}) + self.save_conversation() + return generated_explanation, generated_code, response.usage + + elif attempt < retry_attempts - 1: # Don't log on last attempt + if self.verbose: + logger.warning("AST validation failed: %s. Retrying...", error_msg) + # Add error feedback to context for retry + self.conversation.append({"role": "assistant", "content": content}) + self.conversation.append( + { + "role": "user", + "content": ( + f"The generated code has a syntax error: {error_msg}. " + "Please fix the syntax and generate valid Python code." + ), + } + ) + else: + # Last attempt failed + if self.verbose: + logger.error("Final attempt failed AST validation: %s", error_msg) + + if message: + self.conversation.append({"role": "assistant", "content": content}) + self.save_conversation() + return ( + generated_explanation, + generated_code, + response.usage or {}, + ) # Return anyway, let caller handle + else: + if attempt == retry_attempts - 1: + return ("No response generated", "", response.usage or {}) + + return "No response generated" diff --git a/src/vtk_prompt/prompt.py b/src/vtk_prompt/prompt.py index 4281c9d..d75c331 100755 --- a/src/vtk_prompt/prompt.py +++ b/src/vtk_prompt/prompt.py @@ -1,429 +1,20 @@ """ -VTK Code Generation with OpenAI Integration. +VTK Code Generation with LLM Integration (Backward Compatibility Module). -This module provides the core functionality for VTK code generation using OpenAI's language models. -It includes the main VTKPromptClient class which handles conversation management, code generation, -execution, and error handling with retry logic. +This module maintains backward compatibility for the VTKPromptClient and main CLI function. +The actual implementations have been moved to: +- client.py: VTKPromptClient class +- cli.py: CLI interface -Features: -- Singleton pattern for conversation persistence -- RAG (Retrieval-Augmented Generation) integration for context-aware code generation -- Automatic code execution and error handling -- Conversation history management and file persistence -- Multiple model provider support (OpenAI, Anthropic, Gemini, NIM) -- Template-based prompt construction with VTK-specific context +This module re-exports them to maintain existing import patterns. -Example: - >>> vtk-prompt "create sphere" --rag --model gpt-4o +Deprecated: Direct imports from this module are deprecated. +Please import from .client or .cli instead for new code. """ -import ast -import json -import os -import re -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Optional, Union +from .cli import main -import click -import openai +# Re-export for backward compatibility +from .client import VTKPromptClient -from . 
import get_logger -from .prompts import ( - get_no_rag_context, - get_python_role, - get_rag_context, -) - -logger = get_logger(__name__) - - -@dataclass -class VTKPromptClient: - """OpenAI client for VTK code generation.""" - - _instance: Optional["VTKPromptClient"] = None - _initialized: bool = False - collection_name: str = "vtk-examples" - database_path: str = "./db/codesage-codesage-large-v2" - verbose: bool = False - conversation_file: Optional[str] = None - conversation: Optional[list[dict[str, str]]] = None - - def __new__(cls, **kwargs: Any) -> "VTKPromptClient": - """Create singleton instance of VTKPromptClient.""" - # Make sure that this is a singleton - if cls._instance is None: - cls._instance = super(VTKPromptClient, cls).__new__(cls) - cls._instance._initialized = False - cls._instance.conversation = [] - return cls._instance - - def __post_init__(self) -> None: - """Post-init hook to prevent double initialization in singleton.""" - if hasattr(self, "_initialized") and self._initialized: - return - self._initialized = True - - def load_conversation(self) -> list[dict[str, str]]: - """Load conversation history from file.""" - if not self.conversation_file or not Path(self.conversation_file).exists(): - return [] - - try: - with open(self.conversation_file, "r") as f: - data = json.load(f) - if isinstance(data, list): - return data - else: - logger.warning("Invalid conversation file format, no history loaded.") - return [] - except Exception as e: - logger.error("Could not load conversation file: %s", e) - return [] - - def save_conversation(self) -> None: - """Save conversation history to file.""" - if not self.conversation_file or not self.conversation: - return - - try: - # Ensure directory exists - Path(self.conversation_file).parent.mkdir(parents=True, exist_ok=True) - - with open(self.conversation_file, "w") as f: - json.dump(self.conversation, f, indent=2) - except Exception as e: - logger.error("Could not save conversation file: %s", e) - - def update_conversation( - self, new_convo: list[dict[str, str]], new_convo_file: Optional[str] = None - ) -> None: - """Update conversation history with new conversation.""" - if not self.conversation: - self.conversation = [] - self.conversation.extend(new_convo) - - if new_convo_file: - self.conversation_file = new_convo_file - - def validate_code_syntax(self, code_string: str) -> tuple[bool, Optional[str]]: - """Validate Python code syntax using AST.""" - try: - ast.parse(code_string) - return True, None - except SyntaxError as e: - return False, f"Syntax error: {e.msg} at line {e.lineno}" - except Exception as e: - return False, f"AST parsing error: {str(e)}" - - def run_code(self, code_string: str) -> None: - """Execute VTK code using exec() after AST validation.""" - is_valid, error_msg = self.validate_code_syntax(code_string) - if not is_valid: - logger.error("Code validation failed: %s", error_msg) - if self.verbose: - logger.debug("Generated code:\n%s", code_string) - return - - if self.verbose: - logger.debug("Executing code:\n%s", code_string) - - try: - exec(code_string, globals(), {}) - except Exception as e: - logger.error("Error executing code: %s", e) - if not self.verbose: - logger.debug("Failed code:\n%s", code_string) - return - - def query( - self, - message: str = "", - api_key: Optional[str] = None, - model: str = "gpt-4o", - base_url: Optional[str] = None, - max_tokens: int = 1000, - temperature: float = 0.1, - top_k: int = 5, - rag: bool = False, - retry_attempts: int = 1, - ) -> Union[tuple[str, str, Any], 
str]: - """Generate VTK code with optional RAG enhancement and retry logic. - - Args: - message: The user query - api_key: API key for the service - model: Model name to use - base_url: API base URL - max_tokens: Maximum tokens to generate - temperature: Temperature for generation - top_k: Number of RAG examples to retrieve - rag: Whether to use RAG enhancement - retry_attempts: Number of times to retry if AST validation fails - """ - if not api_key: - api_key = os.environ.get("OPENAI_API_KEY") - - if not api_key: - raise ValueError("No API key provided. Set OPENAI_API_KEY or pass api_key parameter.") - - # Create client with current parameters - client = openai.OpenAI(api_key=api_key, base_url=base_url) - - # Load existing conversation if present - if self.conversation_file and not self.conversation: - self.conversation = self.load_conversation() - - if not message and not self.conversation: - raise ValueError("No prompt or conversation file provided") - - if rag: - from .rag_chat_wrapper import ( - check_rag_components_available, - get_rag_snippets, - ) - - if not check_rag_components_available(): - raise ValueError("RAG components not available") - - rag_snippets = get_rag_snippets( - message, - collection_name=self.collection_name, - database_path=self.database_path, - top_k=top_k, - ) - - if not rag_snippets: - raise ValueError("Failed to load RAG snippets") - - context_snippets = "\n\n".join(rag_snippets["code_snippets"]) - context = get_rag_context(message, context_snippets) - - if self.verbose: - logger.debug("RAG context: %s", context) - references = rag_snippets.get("references") - if references: - logger.info("Using examples from: %s", ", ".join(references)) - else: - context = get_no_rag_context(message) - if self.verbose: - logger.debug("No-RAG context: %s", context) - - # Initialize conversation with system message if empty - if not self.conversation: - self.conversation = [] - self.conversation.append({"role": "system", "content": get_python_role()}) - - # Add current user message - if message: - self.conversation.append({"role": "user", "content": context}) - - # Retry loop for AST validation - for attempt in range(retry_attempts): - if self.verbose: - if attempt > 0: - logger.debug("Retry attempt %d/%d", attempt + 1, retry_attempts) - logger.debug("Making request with model: %s, temperature: %s", model, temperature) - for i, msg in enumerate(self.conversation): - logger.debug("Message %d (%s): %s...", i, msg["role"], msg["content"][:100]) - - response = client.chat.completions.create( - model=model, - messages=self.conversation, # type: ignore[arg-type] - max_tokens=max_tokens, - temperature=temperature, - ) - - if hasattr(response, "choices") and len(response.choices) > 0: - content = response.choices[0].message.content or "No content in response" - finish_reason = response.choices[0].finish_reason - - if finish_reason == "length": - raise ValueError( - f"Output was truncated due to max_tokens limit ({max_tokens}).\n" - "Please increase max_tokens." 
- ) - - generated_explanation = re.findall( - "(.*?)", content, re.DOTALL - )[0] - generated_code = re.findall("(.*?)", content, re.DOTALL)[0] - if "import vtk" not in generated_code: - generated_code = "import vtk\n" + generated_code - else: - pos = generated_code.find("import vtk") - if pos != -1: - generated_code = generated_code[pos:] - else: - generated_code = generated_code - - is_valid, error_msg = self.validate_code_syntax(generated_code) - if is_valid: - if message: - self.conversation.append({"role": "assistant", "content": content}) - self.save_conversation() - return generated_explanation, generated_code, response.usage - - elif attempt < retry_attempts - 1: # Don't log on last attempt - if self.verbose: - logger.warning("AST validation failed: %s. Retrying...", error_msg) - # Add error feedback to context for retry - self.conversation.append({"role": "assistant", "content": content}) - self.conversation.append( - { - "role": "user", - "content": ( - f"The generated code has a syntax error: {error_msg}. " - "Please fix the syntax and generate valid Python code." - ), - } - ) - else: - # Last attempt failed - if self.verbose: - logger.error("Final attempt failed AST validation: %s", error_msg) - - if message: - self.conversation.append({"role": "assistant", "content": content}) - self.save_conversation() - return ( - generated_explanation, - generated_code, - response.usage or {}, - ) # Return anyway, let caller handle - else: - if attempt == retry_attempts - 1: - return ("No response generated", "", response.usage or {}) - - return "No response generated" - - -@click.command() -@click.argument("input_string") -@click.option( - "--provider", - type=click.Choice(["openai", "anthropic", "gemini", "nim"]), - default="openai", - help="LLM provider to use", -) -@click.option("-m", "--model", default="gpt-4o", help="Model name to use") -@click.option("-k", "--max-tokens", type=int, default=1000, help="Max # of tokens to generate") -@click.option( - "--temperature", - type=float, - default=0.7, - help="Temperature for generation (0.0-2.0)", -) -@click.option("-t", "--token", required=True, help="API token for the selected provider") -@click.option("--base-url", help="Base URL for API (auto-detected or custom)") -@click.option("-r", "--rag", is_flag=True, help="Use RAG to improve code generation") -@click.option("-v", "--verbose", is_flag=True, help="Show generated source code") -@click.option("--collection", default="vtk-examples", help="Collection name for RAG") -@click.option( - "--database", - default="./db/codesage-codesage-large-v2", - help="Database path for RAG", -) -@click.option("--top-k", type=int, default=5, help="Number of examples to retrieve from RAG") -@click.option( - "--retry-attempts", - type=int, - default=1, - help="Number of times to retry if AST validation fails", -) -@click.option( - "--conversation", - help="Path to conversation file for maintaining chat history", -) -def main( - input_string: str, - provider: str, - model: str, - max_tokens: int, - temperature: float, - token: str, - base_url: Optional[str], - rag: bool, - verbose: bool, - collection: str, - database: str, - top_k: int, - retry_attempts: int, - conversation: Optional[str], -) -> None: - """ - Generate and execute VTK code using LLMs. 
- - INPUT_STRING: The code description to generate VTK code for - """ - # Set default base URLs - if base_url is None: - base_urls = { - "anthropic": "https://api.anthropic.com/v1", - "gemini": "https://generativelanguage.googleapis.com/v1beta/openai/", - "nim": "https://integrate.api.nvidia.com/v1", - } - base_url = base_urls.get(provider) - - # Set default models based on provider - if model == "gpt-4o": - default_models = { - "anthropic": "claude-3-5-sonnet-20241022", - "gemini": "gemini-1.5-pro", - "nim": "meta/llama3-70b-instruct", - } - model = default_models.get(provider, model) - - try: - client = VTKPromptClient( - collection_name=collection, - database_path=database, - verbose=verbose, - conversation_file=conversation, - ) - result = client.query( - input_string, - api_key=token, - model=model, - base_url=base_url, - max_tokens=max_tokens, - temperature=temperature, - top_k=top_k, - rag=rag, - retry_attempts=retry_attempts, - ) - - if isinstance(result, tuple) and len(result) == 3: - _explanation, generated_code, usage = result - if verbose and usage: - logger.info( - "Used tokens: input=%d output=%d", - usage.prompt_tokens, - usage.completion_tokens, - ) - client.run_code(generated_code) - else: - # Handle string result - logger.info("Result: %s", result) - - except ValueError as e: - if "RAG components" in str(e): - logger.error("RAG components not found") - sys.exit(1) - elif "Failed to load RAG snippets" in str(e): - logger.error("Failed to load RAG snippets") - sys.exit(2) - elif "max_tokens" in str(e): - logger.error("Error: %s", e) - logger.error("Current max_tokens: %d", max_tokens) - logger.error("Try increasing with: --max-tokens ") - sys.exit(3) - else: - logger.error("Error: %s", e) - sys.exit(4) - - -if __name__ == "__main__": - main() +__all__ = ["VTKPromptClient", "main"] diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py index b335bfd..7b55e3e 100644 --- a/src/vtk_prompt/vtk_prompt_ui.py +++ b/src/vtk_prompt/vtk_prompt_ui.py @@ -36,7 +36,7 @@ from . 
import get_logger # Import our prompt functionality -from .prompt import VTKPromptClient +from .client import VTKPromptClient # Import our template system from .prompts import get_ui_post_prompt From e373a392fbbc55a359cde04729a9cbeee7bc1e4d Mon Sep 17 00:00:00 2001 From: Brianna Major Date: Wed, 24 Sep 2025 09:56:23 -0400 Subject: [PATCH 2/3] Update model defaults - Update defaults: gpt-4o -> gpt-5, claude-haiku -> claude-opus-4 - Add provider_utils.py for centralized model management - Support temperature control detection per model --- README.md | 16 +++--- pyproject.toml | 1 - src/vtk_prompt/cli.py | 22 ++++++--- src/vtk_prompt/client.py | 5 +- src/vtk_prompt/generate_files.py | 11 +++-- src/vtk_prompt/provider_utils.py | 78 ++++++++++++++++++++++++++++++ src/vtk_prompt/rag_chat_wrapper.py | 6 +-- src/vtk_prompt/vtk_prompt_ui.py | 63 ++++++++++++------------ 8 files changed, 146 insertions(+), 56 deletions(-) create mode 100644 src/vtk_prompt/provider_utils.py diff --git a/README.md b/README.md index 2c50547..9efeff0 100644 --- a/README.md +++ b/README.md @@ -98,13 +98,13 @@ vtk-prompt "Create a red sphere" # Advanced options vtk-prompt "Create a textured cone with 32 resolution" \ --provider anthropic \ - --model claude-3-5-sonnet-20241022 \ + --model claude-opus-4-1-20250805 \ --max-tokens 4000 \ --rag \ --verbose # Using different providers -vtk-prompt "Create a blue cube" --provider openai --model gpt-4o +vtk-prompt "Create a blue cube" --provider openai --model gpt-5 vtk-prompt "Create a cylinder" --provider nim --model meta/llama3-70b-instruct ``` @@ -149,12 +149,12 @@ print(code) ### Supported Providers & Models -| Provider | Default Model | Base URL | -| ------------- | ------------------------- | ----------------------------------- | -| **anthropic** | claude-3-5-haiku-20241022 | https://api.anthropic.com/v1 | -| **openai** | gpt-4o | https://api.openai.com/v1 | -| **nim** | meta/llama3-70b-instruct | https://integrate.api.nvidia.com/v1 | -| **custom** | User-defined | User-defined (for local models) | +| Provider | Default Model | Base URL | +| ------------- | ------------------------ | ----------------------------------- | +| **anthropic** | claude-opus-4-1-20250805 | https://api.anthropic.com/v1 | +| **openai** | gpt-5 | https://api.openai.com/v1 | +| **nim** | meta/llama3-70b-instruct | https://integrate.api.nvidia.com/v1 | +| **custom** | User-defined | User-defined (for local models) | ### Custom/Local Models diff --git a/pyproject.toml b/pyproject.toml index 21da840..e8f5620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ authors = [ {name = "Vicente Adolfo Bolea Sanchez", email = "vicente.bolea@kitware.com"}, ] dependencies = [ - "anthropic>=0.22.0", "chromadb>=0.6.3", "click>=8.0.0", "importlib_resources>=5.0.0", diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py index 628c8ae..a85cc2c 100644 --- a/src/vtk_prompt/cli.py +++ b/src/vtk_prompt/cli.py @@ -5,7 +5,7 @@ It handles argument parsing, validation, and orchestrates the VTKPromptClient. Example: - >>> vtk-prompt "create sphere" --rag --model gpt-4o + >>> vtk-prompt "create sphere" --rag --model gpt-5 """ import sys @@ -15,6 +15,7 @@ from . 
import get_logger from .client import VTKPromptClient +from .provider_utils import supports_temperature logger = get_logger(__name__) @@ -27,7 +28,7 @@ default="openai", help="LLM provider to use", ) -@click.option("-m", "--model", default="gpt-4o", help="Model name to use") +@click.option("-m", "--model", default="gpt-5", help="Model name to use") @click.option("-k", "--max-tokens", type=int, default=1000, help="Max # of tokens to generate") @click.option( "--temperature", @@ -87,14 +88,23 @@ def main( base_url = base_urls.get(provider) # Set default models based on provider - if model == "gpt-4o": + if model == "gpt-5": default_models = { - "anthropic": "claude-3-5-sonnet-20241022", - "gemini": "gemini-1.5-pro", + "anthropic": "claude-opus-4-1-20250805", + "gemini": "gemini-2.5-pro", "nim": "meta/llama3-70b-instruct", } model = default_models.get(provider, model) + # Handle temperature override for unsupported models + if not supports_temperature(model) and temperature != 0.7: + logger.warning( + "Model %s does not support temperature control. " + "Temperature parameter will be ignored (using 1.0).", + model, + ) + temperature = 1.0 + try: client = VTKPromptClient( collection_name=collection, @@ -115,7 +125,7 @@ def main( ) if isinstance(result, tuple) and len(result) == 3: - _explanation, generated_code, usage = result + explanation, generated_code, usage = result if verbose and usage: logger.info( "Used tokens: input=%d output=%d", diff --git a/src/vtk_prompt/client.py b/src/vtk_prompt/client.py index 686d21a..65d3267 100644 --- a/src/vtk_prompt/client.py +++ b/src/vtk_prompt/client.py @@ -136,7 +136,7 @@ def query( self, message: str = "", api_key: Optional[str] = None, - model: str = "gpt-4o", + model: str = "gpt-5", base_url: Optional[str] = None, max_tokens: int = 1000, temperature: float = 0.1, @@ -226,7 +226,8 @@ def query( response = client.chat.completions.create( model=model, messages=self.conversation, # type: ignore[arg-type] - max_tokens=max_tokens, + max_completion_tokens=max_tokens, + # max_tokens=max_tokens, temperature=temperature, ) diff --git a/src/vtk_prompt/generate_files.py b/src/vtk_prompt/generate_files.py index e616d75..769101c 100644 --- a/src/vtk_prompt/generate_files.py +++ b/src/vtk_prompt/generate_files.py @@ -71,7 +71,8 @@ def generate_xml( {"role": "system", "content": get_xml_role()}, {"role": "user", "content": context}, ], - max_tokens=max_tokens, + max_completion_tokens=max_tokens, + # max_tokens=max_tokens, temperature=temperature, ) @@ -112,7 +113,7 @@ def openai_query( default="openai", help="LLM provider to use", ) -@click.option("-m", "--model", default="gpt-4o", help="Model to use for generation") +@click.option("-m", "--model", default="gpt-5", help="Model to use for generation") @click.option("-t", "--token", required=True, help="API token for the selected provider") @click.option("--base-url", help="Base URL for API (auto-detected or custom)") @click.option( @@ -154,10 +155,10 @@ def main( base_url = base_urls.get(provider) # Set default models based on provider - if model == "gpt-4o": + if model == "gpt-5": default_models = { - "anthropic": "claude-3-5-sonnet-20241022", - "gemini": "gemini-1.5-pro", + "anthropic": "claude-opus-4-1-20250805", + "gemini": "gemini-2.5-pro", "nim": "meta/llama3-70b-instruct", } model = default_models.get(provider, model) diff --git a/src/vtk_prompt/provider_utils.py b/src/vtk_prompt/provider_utils.py new file mode 100644 index 0000000..063de83 --- /dev/null +++ b/src/vtk_prompt/provider_utils.py @@ -0,0 +1,78 @@ 
+""" +Provider utilities for managing curated model lists and provider configurations. + +This module provides curated lists of models that work well for VTK code generation, +rather than dynamically fetching all available models from providers. +""" + +import logging +from typing import Dict, List + +logger = logging.getLogger(__name__) + +# Curated models for each provider - selected for VTK code generation quality +OPENAI_MODELS = ["gpt-5", "gpt-4.1", "o4-mini", "o3"] + +ANTHROPIC_MODELS = [ + "claude-opus-4-1-20250805", + "claude-sonnet-4-20250514", + "claude-3-7-sonnet-20250219", +] + +GEMINI_MODELS = ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"] + +NIM_MODELS = [ + "meta/llama3-70b-instruct", + "meta/llama3-8b-instruct", + "microsoft/phi-3-medium-4k-instruct", + "nvidia/llama-3.1-nemotron-70b-instruct", +] + + +# Models that don't support temperature control (must use temperature=1.0) +TEMPERATURE_UNSUPPORTED_MODELS = ["gpt-5", "o4-mini", "o3"] + + +def supports_temperature(model: str) -> bool: + """Check if a model supports temperature control.""" + return model not in TEMPERATURE_UNSUPPORTED_MODELS + + +def get_model_temperature(model: str, requested_temperature: float = 0.7) -> float: + """Get the appropriate temperature for a model.""" + if supports_temperature(model): + return requested_temperature + else: + return 1.0 + + +def get_available_models() -> Dict[str, List[str]]: + """Get curated models for all providers.""" + return { + "openai": OPENAI_MODELS, + "anthropic": ANTHROPIC_MODELS, + "gemini": GEMINI_MODELS, + "nim": NIM_MODELS, + } + + +def get_provider_models(provider: str) -> List[str]: + """Get curated models for a specific provider.""" + models = get_available_models() + return models.get(provider, []) + + +def get_supported_providers() -> List[str]: + """Get list of supported providers.""" + return ["openai", "anthropic", "gemini", "nim"] + + +def get_default_model(provider: str) -> str: + """Get the default/recommended model for a provider.""" + defaults = { + "openai": "gpt-5", + "anthropic": "claude-opus-4-1-20250805", + "gemini": "gemini-2.5-pro", + "nim": "meta/llama3-70b-instruct", + } + return defaults.get(provider, "gpt-5") diff --git a/src/vtk_prompt/rag_chat_wrapper.py b/src/vtk_prompt/rag_chat_wrapper.py index e0eab4e..fda5bca 100644 --- a/src/vtk_prompt/rag_chat_wrapper.py +++ b/src/vtk_prompt/rag_chat_wrapper.py @@ -14,7 +14,7 @@ - CLI interface for standalone RAG chat testing Example: - >>> vtk-rag-chat --query "sphere creation" --model gpt-4o + >>> vtk-rag-chat --query "sphere creation" --model gpt-5 """ import importlib.util @@ -93,7 +93,7 @@ class OpenAIRAGChat: """OpenAI-compatible wrapper for RAG chat functionality.""" def __init__( - self, model: str = "gpt-4o", database: str = "./db/codesage-codesage-large-v2" + self, model: str = "gpt-5", database: str = "./db/codesage-codesage-large-v2" ) -> None: """Initialize the OpenAI RAG chat system. 
@@ -207,7 +207,7 @@ def generate_urls_from_references(self, references: list[str]) -> list[str]: default=15, help="Retrieve the top k examples from the database", ) -@click.option("--model", default="gpt-4o", help="OpenAI model to use") +@click.option("--model", default="gpt-5", help="OpenAI model to use") def main(database: str, collection_name: str, top_k: int, model: str) -> None: """Query database for code snippets using OpenAI API only.""" # Initialize the chat system diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py index 7b55e3e..4df1964 100644 --- a/src/vtk_prompt/vtk_prompt_ui.py +++ b/src/vtk_prompt/vtk_prompt_ui.py @@ -28,18 +28,17 @@ from trame.widgets import html from trame.widgets import vuetify3 as vuetify from trame_vtk.widgets import vtk as vtk_widgets - -# Add VTK and Trame imports from vtkmodules.vtkInteractionStyle import vtkInteractorStyleSwitch # noqa -# Import logging from . import get_logger - -# Import our prompt functionality from .client import VTKPromptClient - -# Import our template system from .prompts import get_ui_post_prompt +from .provider_utils import ( + get_available_models, + get_default_model, + get_supported_providers, + supports_temperature, +) logger = get_logger(__name__) @@ -142,28 +141,12 @@ def _add_default_scene(self) -> None: # Cloud model configuration self.state.provider = "openai" - self.state.model = "gpt-4o" - self.state.available_providers = [ - "openai", - "anthropic", - "gemini", - "nim", - ] - self.state.available_models = { - "openai": ["gpt-4o", "gpt-4o-mini", "o1-preview", "o1-mini"], - "anthropic": [ - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", - ], - "gemini": ["gemini-1.5-pro", "gemini-1.5-flash", "gemini-pro"], - "nim": [ - "meta/llama3-70b-instruct", - "meta/llama3-8b-instruct", - "microsoft/phi-3-medium-4k-instruct", - "nvidia/llama3-chatqa-1.5-70b", - ], - } + self.state.model = "gpt-5" + self.state.temperature_supported = True + + # Initialize with supported providers and fallback models + self.state.available_providers = get_supported_providers() + self.state.available_models = get_available_models() self.state.api_token = "" @@ -214,7 +197,7 @@ def _get_base_url(self) -> Optional[str]: def _get_model(self) -> str: """Get model name based on configuration mode.""" if self.state.use_cloud_models: - return getattr(self.state, "model", "gpt-4o") + return getattr(self.state, "model", "gpt-5") else: local_model = getattr(self.state, "local_model", "") return local_model.strip() if local_model and local_model.strip() else "llama3.2:latest" @@ -264,6 +247,14 @@ def on_tab_change(self, tab_index: int, **_: Any) -> None: """Handle tab change to sync use_cloud_models state.""" self.state.use_cloud_models = tab_index == 0 + @change("model", "local_model") + def _on_model_change(self, **_: Any) -> None: + """Handle model change to update temperature support.""" + current_model = self._get_model() + self.state.temperature_supported = supports_temperature(current_model) + if not self.state.temperature_supported: + self.state.temperature = 1 + @controller.set("generate_code") def generate_code(self) -> None: """Generate VTK code from user query.""" @@ -572,6 +563,15 @@ def save_conversation(self) -> str: return json.dumps(self.prompt_client.conversation, indent=2) return "" + @change("provider") + def _on_provider_change(self, provider, **kwargs) -> None: + """Handle provider selection change.""" + # Set default model for the provider if current model not 
available + if provider in self.state.available_models: + models = self.state.available_models[provider] + if models and self.state.model not in models: + self.state.model = get_default_model(provider) + def _build_ui(self) -> None: """Build a simplified Vuetify UI.""" # Initialize drawer state as collapsed @@ -653,7 +653,7 @@ def _build_ui(self) -> None: # Model selection vuetify.VSelect( label="Model", - v_model=("model", "gpt-4o"), + v_model=("model", "gpt-5"), items=("available_models[provider] || []",), density="compact", variant="outlined", @@ -750,6 +750,7 @@ def _build_ui(self) -> None: color="orange", prepend_icon="mdi-thermometer", classes="mt-2", + disabled=("!temperature_supported",), ) vuetify.VTextField( label="Max Tokens", From 40c4d1c1d8f17544507a69d95556f241782f3393 Mon Sep 17 00:00:00 2001 From: Brianna Major Date: Fri, 17 Oct 2025 14:08:09 -0400 Subject: [PATCH 3/3] Fix conditional statement --- src/vtk_prompt/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vtk_prompt/cli.py b/src/vtk_prompt/cli.py index a85cc2c..6005296 100644 --- a/src/vtk_prompt/cli.py +++ b/src/vtk_prompt/cli.py @@ -97,7 +97,7 @@ def main( model = default_models.get(provider, model) # Handle temperature override for unsupported models - if not supports_temperature(model) and temperature != 0.7: + if not supports_temperature(model): logger.warning( "Model %s does not support temperature control. " "Temperature parameter will be ignored (using 1.0).",
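
Below is a minimal usage sketch of the modules introduced in this series (client.py and provider_utils.py). The prompt text, the choice to skip RAG, and reading the key from OPENAI_API_KEY are illustrative only; the imports, call signatures, and return shape follow the patches above.

    import os

    from vtk_prompt.client import VTKPromptClient
    from vtk_prompt.provider_utils import get_default_model, get_model_temperature

    provider = "openai"
    model = get_default_model(provider)              # "gpt-5" after PATCH 2/3
    temperature = get_model_temperature(model, 0.7)  # returns 1.0 for models without temperature control

    client = VTKPromptClient(verbose=True)
    result = client.query(
        "Create a red sphere",
        api_key=os.environ["OPENAI_API_KEY"],
        model=model,
        temperature=temperature,
        max_tokens=2000,
        rag=False,
    )

    # query() returns (explanation, code, usage) on success, or a plain string otherwise
    if isinstance(result, tuple):
        explanation, code, usage = result
        client.run_code(code)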