Skip to content

Utilities - Prompt Embedding Function #33

@umar-anzar

Description

@umar-anzar

Use case: embedding class definitions in a prompt to request a JSON response that matches them.

import inspect

def class_to_string(cls):
    """
    Convert a Python class and its user-defined ancestors to a string.

    Collects the source code of ``cls`` and, recursively, of every ancestor
    class defined in the same module as ``cls``. Ancestors from other modules
    (for example ``builtins.object`` or Pydantic's ``BaseModel``) are skipped,
    and so is anything above them in the hierarchy. Each class appears exactly
    once, even with diamond inheritance. Every class is emitted after all of
    its included bases, so the result reads top-down like a source file.

    Args:
        cls (type): The class to render.

    Returns:
        str: The concatenated class sources wrapped in triple double quotes,
        with a newline after the opening quotes. Consecutive class sources
        are joined without any extra separator.

    Raises:
        OSError, TypeError: Propagated from ``inspect.getsource`` when a
        class's source cannot be retrieved (e.g. classes typed into a REPL).

    Example:
        >>> class Parent:
        ...     x: int
        ...
        >>> class Child(Parent):
        ...     y: str
        ...
        >>> print(class_to_string(Child))  # doctest: +SKIP
        \"\"\"
        class Parent:
            x: int
        class Child(Parent):
            y: str
        \"\"\"

    Limitations:
        - Only ancestors defined in the same module as ``cls`` are included.
        - Ancestors from external libraries (and their own bases) are omitted.
    """
    module = cls.__module__
    ordered = []
    seen = set()

    def visit(klass):
        # Post-order DFS over the declared bases: a base is appended only after
        # its own bases, so every class follows the classes it inherits from.
        for base in klass.__bases__:
            if base in seen or base.__module__ != module:
                continue
            seen.add(base)
            visit(base)
            ordered.append(base)

    visit(cls)
    ordered.append(cls)

    class_definitions = [inspect.getsource(c) for c in ordered]

    # Join all class definitions and wrap in triple quotes
    return f'"""\n{"".join(class_definitions)}"""'
import inspect

from pydantic import BaseModel


def _annotation_classes(annotation):
    """Yield every class nested anywhere inside a type annotation.

    Unwraps generic aliases and unions recursively via ``typing.get_args``, so
    ``Optional[list[Address]]``, ``dict[str, Address]``, ``Address | None`` and
    ``Annotated[Address, ...]`` all yield ``Address``. String (forward
    reference) annotations are not resolved and yield nothing.
    """
    from typing import get_args

    if isinstance(annotation, (list, tuple)):
        # Callable[[A, B], R] stores its parameter list as a plain list.
        for item in annotation:
            yield from _annotation_classes(item)
        return
    args = get_args(annotation)
    if args:
        # Check args before isclass(): on Python 3.9/3.10 a parameterized
        # builtin such as list[X] also passes inspect.isclass().
        for arg in args:
            yield from _annotation_classes(arg)
    elif inspect.isclass(annotation):
        yield annotation


def _is_user_model(tp):
    """Return True if *tp* is a BaseModel subclass not defined inside pydantic itself."""
    return (
        inspect.isclass(tp)
        and issubclass(tp, BaseModel)
        # Skip BaseModel and other library-provided models so their (large)
        # library source is never pulled into the rendered output.
        and tp.__module__.partition(".")[0] != "pydantic"
    )


def get_related_classes(
    cls,
    seen: "set[type] | None" = None,
    ordered_classes: "list[type] | None" = None,
) -> "list[type]":
    """
    Recursively collect the user-defined model classes that ``cls`` depends on.

    Dependencies are:

    - parent classes that are user-defined ``BaseModel`` subclasses;
    - user-defined ``BaseModel`` subclasses referenced by a field annotation,
      at any nesting depth (``Optional[...]``, ``list[...]``,
      ``dict[..., ...]``, ``X | None``, ``Annotated[...]``).

    Classes whose module belongs to the ``pydantic`` package itself (e.g.
    ``BaseModel``) are skipped. The result is ordered dependencies-first,
    with ``cls`` last.

    Args:
        cls (type): The class whose dependencies need to be collected.
            Non-class values are ignored.
        seen (set, optional): Classes already visited. It is shared across the
            recursion so that cycles and duplicates are visited only once.
        ordered_classes (list, optional): Accumulator for the result. It is
            appended to in place and returned.

    Returns:
        list: The collected classes, dependencies before dependents.

    Note:
        String (forward-reference) annotations, e.g. under
        ``from __future__ import annotations``, are not resolved. Models
        referenced only that way are therefore not collected.
    """
    if seen is None:
        seen = set()
    if ordered_classes is None:
        ordered_classes = []

    if cls in seen or not inspect.isclass(cls):
        return ordered_classes
    seen.add(cls)

    # Parent models first: the class's source names them as bases and
    # inherits their fields, so their definitions must precede it.
    for base in cls.__bases__:
        if _is_user_model(base):
            get_related_classes(base, seen, ordered_classes)

    # Then every model referenced by a field annotation, at any nesting depth.
    for annotation in getattr(cls, "__annotations__", {}).values():
        for referenced in _annotation_classes(annotation):
            if _is_user_model(referenced):
                get_related_classes(referenced, seen, ordered_classes)

    # Add the class at the end to maintain the correct order
    if cls not in ordered_classes:
        ordered_classes.append(cls)

    return ordered_classes

def class_to_string(cls):
    """
    Render a class and every user-defined class it depends on as one string.

    The dependencies are discovered by ``get_related_classes`` and emitted in
    the order it returns them. Each definition's source is taken verbatim
    from ``inspect.getsource``.

    Args:
        cls (type): The root class to render.

    Returns:
        str: The concatenated class sources wrapped in triple double quotes,
        with a newline after the opening quotes.
    """
    sources = "".join(map(inspect.getsource, get_related_classes(cls)))
    return '"""\n' + sources + '"""'

Metadata

Metadata

Assignees

Labels

enhancement — New feature or request

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions