Skip to content

Commit f621b31

Browse files
author
Andrei Bratu
committed
Parse '|' annotations in tool decorator
1 parent 03261c4 commit f621b31

File tree

3 files changed

+107
-3
lines changed

3 files changed

+107
-3
lines changed

.github/workflows/ci.yml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
- name: Set up python
1111
uses: actions/setup-python@v4
1212
with:
13-
python-version: 3.9
13+
python-version: 3.12
1414
- name: Bootstrap poetry
1515
run: |
1616
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
@@ -41,6 +41,32 @@ jobs:
4141
REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
4242
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
4343
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
44+
test_3_12:
45+
# Test with Python 3.12
46+
# Some tool decorator tests assert the ability to parse the signature
47+
# if functions that use typing features introduced in Python 3.10 e.g. '|'
48+
runs-on: ubuntu-20.04
49+
steps:
50+
- name: Checkout repo
51+
uses: actions/checkout@v3
52+
- name: Set up python
53+
uses: actions/setup-python@v4
54+
with:
55+
python-version: 3.12
56+
- name: Bootstrap poetry
57+
run: |
58+
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
59+
- name: Install dependencies
60+
run: poetry install
61+
62+
- name: Test
63+
run: poetry run pytest -rP .
64+
env:
65+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
66+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
67+
REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
68+
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
69+
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
4470

4571
publish:
4672
needs: [compile, test]

src/humanloop/decorators/tool.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import builtins
22
import inspect
33
import logging
4+
import sys
45
import textwrap
56
import typing
67
from dataclasses import dataclass
@@ -24,6 +25,9 @@
2425
from humanloop.requests.tool_function import ToolFunctionParams
2526
from humanloop.requests.tool_kernel_request import ToolKernelRequestParams
2627

28+
if sys.version_info >= (3, 10):
29+
import types
30+
2731
logger = logging.getLogger("humanloop.sdk")
2832

2933

@@ -335,7 +339,7 @@ def _parse_annotation(annotation: typing.Type) -> _ParsedAnnotation:
335339
annotation=[_parse_annotation(arg) for arg in typing.get_args(annotation)],
336340
)
337341

338-
if origin is typing.Union:
342+
if origin is typing.Union or (sys.version_info >= (3, 10) and origin is types.UnionType):
339343
sub_types = typing.get_args(annotation)
340344
if sub_types[-1] is type(None):
341345
# type(None) in sub_types indicates Optional type
@@ -495,4 +499,8 @@ def _parameter_is_optional(
495499
origin = typing.get_origin(parameter.annotation)
496500
# sub_types refers to T inside the annotation
497501
sub_types = typing.get_args(parameter.annotation)
498-
return origin is typing.Union and len(sub_types) > 0 and sub_types[-1] is type(None)
502+
return (
503+
(origin is typing.Union or (sys.version_info >= (3, 10) and origin is types.UnionType))
504+
and len(sub_types) > 0
505+
and sub_types[-1] is type(None)
506+
)

tests/decorators/test_tool_decorator.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from typing import Any, Optional, TypedDict, Union
23

34
import pytest
@@ -460,3 +461,72 @@ def calculator(operation: str, num1: float, num2: float) -> float:
460461
key=HUMANLOOP_FILE_KEY,
461462
)
462463
assert hl_file_higher_order_fn["tool"]["source_code"] == hl_file_decorated_fn["tool"]["source_code"] # type: ignore
464+
465+
466+
def test_python310_syntax(
467+
opentelemetry_test_configuration: tuple[Tracer, InMemorySpanExporter],
468+
):
469+
if sys.version_info < (3, 10):
470+
pytest.skip("Requires Python 3.10")
471+
# GIVEN an OTel configuration
472+
tracer, _ = opentelemetry_test_configuration
473+
474+
# GIVEN a function annotated with @tool where a parameter uses `|` for Optional
475+
@tool(opentelemetry_tracer=tracer)
476+
def calculator(a: float, b: float | None = None) -> float:
477+
if a is None:
478+
a = 0
479+
return a + b
480+
481+
# WHEN building the Tool kernel
482+
# THEN the JSON schema is correct
483+
assert calculator.json_schema == {
484+
"description": "",
485+
"name": "calculator",
486+
"parameters": {
487+
"properties": {
488+
"a": {"type": "number"},
489+
"b": {"type": ["number", "null"]},
490+
},
491+
"required": ("a",),
492+
"type": "object",
493+
"additionalProperties": False,
494+
},
495+
"strict": True,
496+
}
497+
498+
Validator.check_schema(calculator.json_schema)
499+
500+
501+
def test_python310_union_syntax(
502+
opentelemetry_test_configuration: tuple[Tracer, InMemorySpanExporter],
503+
):
504+
if sys.version_info < (3, 10):
505+
pytest.skip("Requires Python 3.10")
506+
507+
# GIVEN an OTel configuration
508+
tracer, _ = opentelemetry_test_configuration
509+
510+
# GIVEN a function annotated with @tool where a parameter uses `|` for Union
511+
@tool(opentelemetry_tracer=tracer)
512+
def calculator(a: float, b: float | int | str) -> float:
513+
return a + b
514+
515+
# WHEN building the Tool kernel
516+
# THEN the JSON schema is correct
517+
assert calculator.json_schema == {
518+
"description": "",
519+
"name": "calculator",
520+
"parameters": {
521+
"properties": {
522+
"a": {"type": "number"},
523+
"b": {"anyOf": [{"type": "number"}, {"type": "integer"}, {"type": "string"}]},
524+
},
525+
"required": ("a", "b"),
526+
"type": "object",
527+
"additionalProperties": False,
528+
},
529+
"strict": True,
530+
}
531+
532+
Validator.check_schema(calculator.json_schema)

0 commit comments

Comments
 (0)