from __future__ import annotations
import functools
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from pydantic import TypeAdapter
from typing_extensions import Self, TypedDict
from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, UsageMetadata
if TYPE_CHECKING:
import httpx
from anthropic import Anthropic, AsyncAnthropic
from anthropic.types import Completion, Message, MessageStreamEvent
C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])
logger = logging.getLogger(__name__)
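# NOTE: reconstructed from this module's imports and call sites -- `tracing_extra`
# is unpacked into `run_helpers.traceable(...)`, so the fields mirror keyword
# arguments that decorator accepts.
class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[List[str]]
    client: Optional[ls_client.Client]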
@functools.lru_cache
def _get_not_given() -> Optional[Type]:
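    """Return the anthropic ``NotGiven`` sentinel type, or None if unavailable."""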
try:
from anthropic._types import NotGiven
return NotGiven
except ImportError:
return None
def _strip_not_given(d: dict) -> dict:
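    """Drop ``NotGiven``/None values and inline ``system`` as a system message.

    Illustrative example::

        _strip_not_given(
            {"system": "Be brief.", "messages": [{"role": "user", "content": "hi"}]}
        )
        # -> {"messages": [{"role": "system", "content": "Be brief."},
        #                  {"role": "user", "content": "hi"}]}
    """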
try:
if (not_given := _get_not_given()) is not None:
d = {k: v for k, v in d.items() if not isinstance(v, not_given)}
except Exception as e:
logger.error(f"Error stripping NotGiven: {e}")
if "system" in d:
d["messages"] = [{"role": "system", "content": d["system"]}] + d.get(
"messages", []
)
d.pop("system")
return {k: v for k, v in d.items() if v is not None}
def _infer_ls_params(kwargs: dict) -> dict:
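    """Extract LangSmith invocation parameters (model, temperature, etc.) from kwargs."""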
stripped = _strip_not_given(kwargs)
stop = stripped.get("stop")
if stop and isinstance(stop, str):
stop = [stop]
return {
"ls_provider": "anthropic",
"ls_model_type": "chat",
"ls_model_name": stripped.get("model", None),
"ls_temperature": stripped.get("temperature", None),
"ls_max_tokens": stripped.get("max_tokens", None),
"ls_stop": stop,
}
def _accumulate_event(
*, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
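    """Fold a single ``MessageStreamEvent`` into the running ``Message`` snapshot."""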
try:
from anthropic.types import ContentBlock
except ImportError:
logger.debug("Error importing ContentBlock")
return current_snapshot
if current_snapshot is None:
if event.type == "message_start":
return event.message
raise RuntimeError(
f'Unexpected event order, got {event.type} before "message_start"'
)
if event.type == "content_block_start":
# TODO: check index <-- from anthropic SDK :)
adapter: TypeAdapter = TypeAdapter(ContentBlock)
content_block_instance = adapter.validate_python(
event.content_block.model_dump()
)
current_snapshot.content.append(
content_block_instance, # type: ignore[attr-defined]
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
current_snapshot.usage.output_tokens = event.usage.output_tokens
return current_snapshot
def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
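    """Replay streamed events into a single message dict, or fall back to raw chunks."""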
full_message = None
for chunk in all_chunks:
try:
full_message = _accumulate_event(event=chunk, current_snapshot=full_message)
except RuntimeError as e:
logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
return {"output": all_chunks}
if full_message is None:
return {"output": all_chunks}
d = full_message.model_dump()
d["usage_metadata"] = _create_usage_metadata(d.pop("usage", {}))
d.pop("type", None)
return {"message": d}
def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
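    """Convert an Anthropic usage payload into LangSmith ``UsageMetadata``.

    For example, ``{"input_tokens": 10, "output_tokens": 5}`` yields
    ``total_tokens=15`` with zero cached input tokens.
    """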
input_tokens = anthropic_token_usage.get("input_tokens") or 0
output_tokens = anthropic_token_usage.get("output_tokens") or 0
total_tokens = input_tokens + output_tokens
    # ``or 0`` also guards against explicit ``None`` values in the usage payload.
    input_token_details: dict = {
        "cache_read": (anthropic_token_usage.get("cache_creation_input_tokens") or 0)
        + (anthropic_token_usage.get("cache_read_input_tokens") or 0)
    }
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
input_token_details=InputTokenDetails(
**{k: v for k, v in input_token_details.items() if v is not None}
),
)
def _reduce_completions(all_chunks: List[Completion]) -> dict:
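    """Concatenate streamed completion chunks into a single ``choices`` payload."""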
all_content = []
for chunk in all_chunks:
content = chunk.completion
if content is not None:
all_content.append(content)
content = "".join(all_content)
if all_chunks:
d = all_chunks[-1].model_dump()
d["choices"] = [{"text": content}]
else:
d = {"choices": [{"text": content}]}
return d
def _process_chat_completion(outputs: Any) -> dict:
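    """Serialize a chat ``Message`` into a dict with normalized usage metadata."""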
try:
rdict = outputs.model_dump()
anthropic_token_usage = rdict.pop("usage", None)
rdict["usage_metadata"] = (
_create_usage_metadata(anthropic_token_usage)
if anthropic_token_usage
else None
)
rdict.pop("type", None)
return {"message": rdict}
    except Exception as e:
logger.debug(f"Error processing chat completion: {e}")
return {"output": outputs}
def _get_wrapper(
original_create: Callable,
name: str,
reduce_fn: Callable,
tracing_extra: TracingExtra,
) -> Callable:
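    """Wrap a sync or async ``create`` method with LangSmith's ``traceable``."""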
@functools.wraps(original_create)
def create(*args, **kwargs):
stream = kwargs.get("stream")
decorator = run_helpers.traceable(
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
process_outputs=_process_chat_completion,
_invocation_params_fn=_infer_ls_params,
**tracing_extra,
)
result = decorator(original_create)(*args, **kwargs)
return result
@functools.wraps(original_create)
async def acreate(*args, **kwargs):
stream = kwargs.get("stream")
decorator = run_helpers.traceable(
name=name,
run_type="llm",
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
process_outputs=_process_chat_completion,
_invocation_params_fn=_infer_ls_params,
**tracing_extra,
)
result = await decorator(original_create)(*args, **kwargs)
return result
return acreate if run_helpers.is_async(original_create) else create
def _get_stream_wrapper(
original_stream: Callable,
name: str,
tracing_extra: TracingExtra,
) -> Callable:
"""Create a wrapper for Anthropic's streaming context manager."""
import anthropic
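    # ``messages.stream`` returns a stream-manager object rather than a coroutine,
    # so the async client is detected from the bound method's repr instead of
    # ``inspect.iscoroutinefunction``.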
is_async = "async" in str(original_stream).lower()
configured_traceable = run_helpers.traceable(
name=name,
reduce_fn=_reduce_chat_chunks,
run_type="llm",
process_inputs=_strip_not_given,
_invocation_params_fn=_infer_ls_params,
**tracing_extra,
)
configured_traceable_text = run_helpers.traceable(
name=name,
run_type="llm",
process_inputs=_strip_not_given,
process_outputs=_process_chat_completion,
_invocation_params_fn=_infer_ls_params,
**tracing_extra,
)
if is_async:
class AsyncMessageStreamWrapper:
def __init__(
self,
wrapped: anthropic.lib.streaming._messages.AsyncMessageStream,
**kwargs,
) -> None:
self._wrapped = wrapped
self._kwargs = kwargs
@property
def text_stream(self):
@configured_traceable_text
async def _text_stream(**_):
async for chunk in self._wrapped.text_stream:
yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = await self._wrapped.get_final_message()
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)
return _text_stream(**self._kwargs)
@property
def response(self) -> httpx.Response:
return self._wrapped.response
@property
def request_id(self) -> str | None:
return self._wrapped.request_id
            async def __anext__(self) -> MessageStreamEvent:
                # Reuse one traced iterator across calls so each call advances the
                # same underlying stream (and the same traced run).
                if not hasattr(self, "_aiter"):
                    self._aiter = self.__aiter__()
                return await self._aiter.__anext__()
async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
@configured_traceable
def traced_iter(**_):
return self._wrapped.__aiter__()
async for chunk in traced_iter(**self._kwargs):
yield chunk
async def __aenter__(self) -> Self:
await self._wrapped.__aenter__()
return self
async def __aexit__(self, *exc) -> None:
await self._wrapped.__aexit__(*exc)
async def close(self) -> None:
await self._wrapped.close()
async def get_final_message(self) -> Message:
return await self._wrapped.get_final_message()
async def get_final_text(self) -> str:
return await self._wrapped.get_final_text()
async def until_done(self) -> None:
await self._wrapped.until_done()
@property
def current_message_snapshot(self) -> Message:
return self._wrapped.current_message_snapshot
class AsyncMessagesStreamManagerWrapper:
def __init__(self, **kwargs):
self._kwargs = kwargs
async def __aenter__(self):
self._manager = original_stream(**self._kwargs)
stream = await self._manager.__aenter__()
return AsyncMessageStreamWrapper(stream, **self._kwargs)
async def __aexit__(self, *exc):
await self._manager.__aexit__(*exc)
return AsyncMessagesStreamManagerWrapper
else:
class MessageStreamWrapper:
def __init__(
self,
wrapped: anthropic.lib.streaming._messages.MessageStream,
**kwargs,
) -> None:
self._wrapped = wrapped
self._kwargs = kwargs
@property
            def response(self) -> httpx.Response:
return self._wrapped.response
@property
def request_id(self) -> str | None:
return self._wrapped.request_id # type: ignore[no-any-return]
@property
def text_stream(self):
@configured_traceable_text
def _text_stream(**_):
yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = self._wrapped.get_final_message()
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)
return _text_stream(**self._kwargs)
            def __next__(self) -> MessageStreamEvent:
                # Reuse one traced iterator across calls so each call advances the
                # same underlying stream (and the same traced run).
                if not hasattr(self, "_iter"):
                    self._iter = self.__iter__()
                return next(self._iter)
def __iter__(self):
@configured_traceable
def traced_iter(**_):
return self._wrapped.__iter__()
return traced_iter(**self._kwargs)
def __enter__(self) -> Self:
self._wrapped.__enter__()
return self
def __exit__(self, *exc) -> None:
self._wrapped.__exit__(*exc)
def close(self) -> None:
self._wrapped.close()
def get_final_message(self) -> Message:
return self._wrapped.get_final_message()
def get_final_text(self) -> str:
return self._wrapped.get_final_text()
def until_done(self) -> None:
return self._wrapped.until_done()
@property
def current_message_snapshot(self) -> Message:
return self._wrapped.current_message_snapshot
class MessagesStreamManagerWrapper:
def __init__(self, **kwargs):
self._kwargs = kwargs
def __enter__(self):
self._manager = original_stream(**self._kwargs)
return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)
def __exit__(self, *exc):
self._manager.__exit__(*exc)
return MessagesStreamManagerWrapper
def wrap_anthropic(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C:
"""Patch the Anthropic client to make it traceable.
Args:
client (Union[Anthropic, AsyncAnthropic]): The client to patch.
tracing_extra (Optional[TracingExtra], optional): Extra tracing information.
Defaults to None.
Returns:
Union[Anthropic, AsyncAnthropic]: The patched client.
Example:
.. code-block:: python
import anthropic
from langsmith import wrappers
client = wrappers.wrap_anthropic(anthropic.Anthropic())
# Use Anthropic client same as you normally would:
system = "You are a helpful assistant."
messages = [
{
"role": "user",
"content": "What physics breakthroughs do you predict will happen by 2300?",
}
]
completion = client.messages.create(
model="claude-3-5-sonnet-latest",
messages=messages,
max_tokens=1000,
system=system,
)
print(completion.content)
# You can also use the streaming context manager:
with client.messages.stream(
model="claude-3-5-sonnet-latest",
messages=messages,
max_tokens=1000,
system=system,
) as stream:
for text in stream.text_stream:
print(text, end="", flush=True)
message = stream.get_final_message()
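
            # The async client is wrapped the same way; the traced calls are
            # awaited (run this inside an async function):
            async_client = wrappers.wrap_anthropic(anthropic.AsyncAnthropic())
            completion = await async_client.messages.create(
                model="claude-3-5-sonnet-latest",
                messages=messages,
                max_tokens=1000,
            )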
""" # noqa: E501
tracing_extra = tracing_extra or {}
client.messages.create = _get_wrapper( # type: ignore[method-assign]
client.messages.create,
"ChatAnthropic",
_reduce_chat_chunks,
tracing_extra,
)
client.messages.stream = _get_stream_wrapper( # type: ignore[method-assign]
client.messages.stream,
"ChatAnthropic",
tracing_extra,
)
client.completions.create = _get_wrapper( # type: ignore[method-assign]
client.completions.create,
"Anthropic",
_reduce_completions,
tracing_extra,
)
if (
hasattr(client, "beta")
and hasattr(client.beta, "messages")
and hasattr(client.beta.messages, "create")
):
client.beta.messages.create = _get_wrapper( # type: ignore[method-assign]
client.beta.messages.create, # type: ignore
"ChatAnthropic",
_reduce_chat_chunks,
tracing_extra,
)
return client