Skip to content

Commit d0bfa83

Browse files
committed
Further refine types in overload.py
1 parent f664214 commit d0bfa83

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/humanloop/overload.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import logging
33
import types
4-
from typing import Any, Callable, Dict, Optional, Union, TypeVar, Protocol
4+
from typing import Any, Callable, Dict, Optional, TypeVar, Union
55

66
from humanloop.agents.client import AgentsClient
77
from humanloop.context import (
@@ -59,13 +59,13 @@ def _get_file_type_from_client(
5959
return "dataset"
6060
elif isinstance(client, EvaluatorsClient):
6161
return "evaluator"
62+
else:
63+
raise ValueError(f"Unsupported client type: {type(client)}")
6264

63-
raise ValueError(f"Unsupported client type: {type(client)}")
6465

65-
66-
def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, Any]:
66+
def _handle_tracing_context(kwargs: Dict[str, Any], client: T) -> Dict[str, Any]:
6767
"""Handle tracing context for both log and call methods."""
68-
trace_id = get_trace_id()
68+
trace_id = get_trace_id()
6969
if trace_id is not None:
7070
if "flow" in str(type(client).__name__).lower():
7171
context = get_decorator_context()
@@ -90,7 +90,7 @@ def _handle_tracing_context(kwargs: Dict[str, Any], client: Any) -> Dict[str, An
9090

9191
def _handle_local_files(
9292
kwargs: Dict[str, Any],
93-
client: Any,
93+
client: T,
9494
sync_client: Optional[SyncClient],
9595
use_local_files: bool,
9696
) -> Dict[str, Any]:
@@ -140,7 +140,7 @@ def _handle_evaluation_context(kwargs: Dict[str, Any]) -> tuple[Dict[str, Any],
140140
return kwargs, None
141141

142142

143-
def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
143+
def _overload_log(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> LogResponseType:
144144
try:
145145
# Special handling for flows - prevent direct log usage
146146
if type(self) is FlowsClient and get_trace_id() is not None:
@@ -162,7 +162,7 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
162162
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
163163

164164
kwargs, eval_callback = _handle_evaluation_context(kwargs)
165-
response = self._log(**kwargs) # Use stored original method
165+
response = self._log(**kwargs) # type: ignore[union-attr] # Use stored original method
166166
if eval_callback is not None:
167167
eval_callback(response.id)
168168
return response
@@ -174,11 +174,11 @@ def _overload_log(self: Any, sync_client: Optional[SyncClient], use_local_files:
174174
raise HumanloopRuntimeError from e
175175

176176

177-
def _overload_call(self: Any, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
177+
def _overload_call(self: T, sync_client: Optional[SyncClient], use_local_files: bool, **kwargs) -> CallResponseType:
178178
try:
179179
kwargs = _handle_tracing_context(kwargs, self)
180180
kwargs = _handle_local_files(kwargs, self, sync_client, use_local_files)
181-
return self._call(**kwargs) # Use stored original method
181+
return self._call(**kwargs) # type: ignore[union-attr] # Use stored original method
182182
except HumanloopRuntimeError:
183183
# Re-raise HumanloopRuntimeError without wrapping to preserve the message
184184
raise
@@ -200,7 +200,7 @@ def overload_client(
200200
setattr(client, "_log", original_log)
201201

202202
# Create a closure to capture sync_client and use_local_files
203-
def log_wrapper(self: Any, **kwargs) -> LogResponseType:
203+
def log_wrapper(self: T, **kwargs) -> LogResponseType:
204204
return _overload_log(self, sync_client, use_local_files, **kwargs)
205205

206206
# Replace the log method
@@ -217,7 +217,7 @@ def log_wrapper(self: Any, **kwargs) -> LogResponseType:
217217
setattr(client, "_call", original_call)
218218

219219
# Create a closure to capture sync_client and use_local_files
220-
def call_wrapper(self: Any, **kwargs) -> CallResponseType:
220+
def call_wrapper(self: T, **kwargs) -> CallResponseType:
221221
return _overload_call(self, sync_client, use_local_files, **kwargs)
222222

223223
# Replace the call method

0 commit comments

Comments
 (0)