"""Schemas for the LangSmith API."""
from __future__ import annotations
import json
import logging
import sys
from datetime import datetime, timezone
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
from uuid import UUID, uuid4
try:
from pydantic.v1 import Field, root_validator # type: ignore[import]
except ImportError:
from pydantic import ( # type: ignore[assignment, no-redef]
Field,
root_validator,
)
import threading
import urllib.parse
from langsmith import schemas as ls_schemas
from langsmith import utils
from langsmith.client import ID_TYPE, RUN_TYPE_T, Client, _dumps_json, _ensure_uuid
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared prefix for every LangSmith tracing header / baggage key below.
LANGSMITH_PREFIX = "langsmith-"

# Key carrying the dotted order of a run (str and UTF-8 bytes forms).
LANGSMITH_DOTTED_ORDER = sys.intern(LANGSMITH_PREFIX + "trace")
LANGSMITH_DOTTED_ORDER_BYTES = LANGSMITH_DOTTED_ORDER.encode("utf-8")

# Baggage keys for run metadata, tags and project name. The strings are
# interned so that repeated comparisons against them stay cheap.
LANGSMITH_METADATA = sys.intern(LANGSMITH_PREFIX + "metadata")
LANGSMITH_TAGS = sys.intern(LANGSMITH_PREFIX + "tags")
LANGSMITH_PROJECT = sys.intern(LANGSMITH_PREFIX + "project")
# Module-wide client cache: stays None until get_cached_client() creates a
# Client, after which the same instance is returned on every call.
_CLIENT: Optional[Client] = None
# NOTE(review): nothing visible in this file acquires this lock; it appears
# to be retained only so that existing references to it keep working.
_LOCK = threading.Lock() # Keeping around for a while for backwards compat
# Note, this is called directly by langchain. Do not remove.
[docs]
def get_cached_client(**init_kwargs: Any) -> Client:
    """Return the module-wide client, building it on the first call.

    Args:
        **init_kwargs (Any): Keyword arguments forwarded to ``Client`` when
            the cached instance is created. They are ignored on later calls
            because the already-created client is returned unchanged.

    Returns:
        Client: The shared client instance stored in ``_CLIENT``.
    """
    global _CLIENT
    # Fast path: a client has already been created and cached.
    if _CLIENT is not None:
        return _CLIENT
    _CLIENT = Client(**init_kwargs)
    return _CLIENT
[docs]
class RunTree(ls_schemas.RunBase):
    """Run Schema with back-references for posting runs.

    Each instance keeps a reference to its parent run and a list of the
    child runs created through ``create_child``. Runs are sent to the API
    through ``client``, which is resolved lazily the first time it is needed.
    """

    name: str
    id: UUID = Field(default_factory=uuid4)
    run_type: str = Field(default="chain")
    start_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Back-reference to the parent; excluded from serialization.
    parent_run: Optional[RunTree] = Field(default=None, exclude=True)
    child_runs: List[RunTree] = Field(
        default_factory=list,
        exclude={"__all__": {"parent_run_id"}},
    )
    # Exposed to callers as ``project_name`` / ``project_id`` via aliases.
    session_name: str = Field(
        default_factory=lambda: utils.get_tracer_project() or "default",
        alias="project_name",
    )
    session_id: Optional[UUID] = Field(default=None, alias="project_id")
    extra: Dict = Field(default_factory=dict)
    tags: Optional[List[str]] = Field(default_factory=list)
    events: List[Dict] = Field(default_factory=list)
    """List of events associated with the run, like
    start and end events."""
    # Client used for API calls; excluded from serialization. See ``client``.
    ls_client: Optional[Any] = Field(default=None, exclude=True)
    dotted_order: str = Field(
        default="", description="The order of the run in the tree."
    )
    # The "" default is a placeholder: ``infer_defaults`` always fills this in.
    trace_id: UUID = Field(default="", description="The trace id of the run.")  # type: ignore
    dangerously_allow_filesystem: Optional[bool] = Field(
        default=False, description="Whether to allow filesystem access for attachments."
    )

    class Config:
        """Pydantic model configuration."""

        arbitrary_types_allowed = True
        allow_population_by_field_name = True
        extra = "ignore"

    @root_validator(pre=True)
    def infer_defaults(cls, values: dict) -> dict:
        """Fill in values the caller omitted, before field validation runs.

        Args:
            values (dict): Raw keyword arguments passed to the constructor.
                The mapping is updated in place.

        Returns:
            dict: The same mapping with the inferred defaults applied.
        """
        serialized = values.get("serialized")
        if values.get("name") is None and serialized is not None:
            # Prefer an explicit serialized name, else the last "id" segment.
            if "name" in serialized:
                values["name"] = serialized["name"]
            elif "id" in serialized:
                values["name"] = serialized["id"][-1]
        if values.get("name") is None:
            values["name"] = "Unnamed"

        if "client" in values:  # Handle user-constructed clients
            values["ls_client"] = values.pop("client")
        elif "_client" in values:
            values["ls_client"] = values.pop("_client")
        if not values.get("ls_client"):
            values["ls_client"] = None

        parent_run = values.get("parent_run")
        if parent_run is not None:
            values["parent_run_id"] = parent_run.id
        if "id" not in values:
            values["id"] = uuid4()
        if "trace_id" not in values:
            # Check the value, not just the key: an explicit parent_run=None
            # must behave like a root run instead of failing on None.trace_id.
            if parent_run is not None:
                values["trace_id"] = parent_run.trace_id
            else:
                values["trace_id"] = values["id"]

        values.setdefault("extra", {})
        if values.get("events") is None:
            values["events"] = []
        if values.get("tags") is None:
            values["tags"] = []
        if values.get("outputs") is None:
            values["outputs"] = {}
        if values.get("attachments") is None:
            values["attachments"] = {}
        return values

    @root_validator(pre=False)
    def ensure_dotted_order(cls, values: dict) -> dict:
        """Ensure the dotted order of the run.

        A non-blank ``dotted_order`` supplied by the caller is kept as-is.
        Otherwise one is built from the run's start time and id, appended to
        the parent's dotted order when the run has a parent.

        Args:
            values (dict): The validated field values.

        Returns:
            dict: The same mapping with ``dotted_order`` populated.
        """
        current_dotted_order = values.get("dotted_order")
        if current_dotted_order and current_dotted_order.strip():
            return values
        current_dotted_order = _create_current_dotted_order(
            values["start_time"], values["id"]
        )
        if values["parent_run"]:
            values["dotted_order"] = (
                values["parent_run"].dotted_order + "." + current_dotted_order
            )
        else:
            values["dotted_order"] = current_dotted_order
        return values

    @property
    def client(self) -> Client:
        """Return the client."""
        # Lazily load the client
        # If you never use this for API calls, it will never be loaded
        if self.ls_client is None:
            self.ls_client = get_cached_client()
        return self.ls_client

    @property
    def _client(self) -> Optional[Client]:
        # For backwards compat
        return self.ls_client

    def __setattr__(self, name, value):
        """Set the _client specially."""
        # For backwards compat
        if name == "_client":
            self.ls_client = value
        else:
            return super().__setattr__(name, value)

    def add_outputs(self, outputs: Dict[str, Any]) -> None:
        """Upsert the given outputs into the run.

        Args:
            outputs (Dict[str, Any]): A dictionary containing the outputs to be added.

        Returns:
            None
        """
        if self.outputs is None:
            self.outputs = {}
        self.outputs.update(outputs)

    def add_event(
        self,
        events: Union[
            ls_schemas.RunEvent,
            Sequence[ls_schemas.RunEvent],
            Sequence[dict],
            dict,
            str,
        ],
    ) -> None:
        """Add an event to the list of events.

        Args:
            events (Union[ls_schemas.RunEvent, Sequence[ls_schemas.RunEvent],
                    Sequence[dict], dict, str]):
                The event(s) to be added. It can be a single event, a sequence
                of events, a sequence of dictionaries, a dictionary, or a string.
                A string is wrapped in an event named "event" that is
                timestamped with the current UTC time.

        Returns:
            None
        """
        if self.events is None:
            self.events = []
        if isinstance(events, dict):
            self.events.append(events)  # type: ignore[arg-type]
        elif isinstance(events, str):
            self.events.append(
                {
                    "name": "event",
                    "time": datetime.now(timezone.utc).isoformat(),
                    "message": events,
                }
            )
        else:
            self.events.extend(events)  # type: ignore[arg-type]

    def end(
        self,
        *,
        outputs: Optional[Dict] = None,
        error: Optional[str] = None,
        end_time: Optional[datetime] = None,
        events: Optional[Sequence[ls_schemas.RunEvent]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set the end time of the run and record its final state.

        Args:
            outputs (Optional[Dict]): Outputs to store. They replace the
                current outputs when those are empty, otherwise they are
                merged into them.
            error (Optional[str]): Error message to store on the run.
            end_time (Optional[datetime]): End time to record. Defaults to
                the current UTC time.
            events (Optional[Sequence[ls_schemas.RunEvent]]): Events to
                append via ``add_event``.
            metadata (Optional[Dict[str, Any]]): Metadata passed on to
                ``add_metadata``.

        Returns:
            None
        """
        self.end_time = end_time or datetime.now(timezone.utc)
        if outputs is not None:
            if not self.outputs:
                self.outputs = outputs
            else:
                self.outputs.update(outputs)
        if error is not None:
            self.error = error
        if events is not None:
            self.add_event(events)
        if metadata is not None:
            # NOTE(review): add_metadata is not defined in the visible part of
            # this class; presumably it is provided elsewhere - confirm.
            self.add_metadata(metadata)

    def create_child(
        self,
        name: str,
        run_type: RUN_TYPE_T = "chain",
        *,
        run_id: Optional[ID_TYPE] = None,
        serialized: Optional[Dict] = None,
        inputs: Optional[Dict] = None,
        outputs: Optional[Dict] = None,
        error: Optional[str] = None,
        reference_example_id: Optional[UUID] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        tags: Optional[List[str]] = None,
        extra: Optional[Dict] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> RunTree:
        """Add a child run to the run tree.

        The child inherits this run's project name, client and
        ``dangerously_allow_filesystem`` setting, has this run as its
        ``parent_run``, and is appended to ``child_runs``.

        Returns:
            RunTree: The newly created child run.
        """
        serialized_ = serialized or {"name": name}
        run = RunTree(
            name=name,
            id=_ensure_uuid(run_id),
            serialized=serialized_,
            inputs=inputs or {},
            outputs=outputs or {},
            error=error,
            run_type=run_type,
            reference_example_id=reference_example_id,
            start_time=start_time or datetime.now(timezone.utc),
            end_time=end_time,
            extra=extra or {},
            parent_run=self,
            project_name=self.session_name,
            ls_client=self.ls_client,
            tags=tags,
            attachments=attachments or {},  # type: ignore
            dangerously_allow_filesystem=self.dangerously_allow_filesystem,
        )
        self.child_runs.append(run)
        return run

    def _get_dicts_safe(self):
        """Return the run as a dict with inputs and outputs shallow-copied."""
        # Things like generators cannot be copied
        self_dict = self.dict(
            exclude={"child_runs", "inputs", "outputs"}, exclude_none=True
        )
        if self.inputs is not None:
            # shallow copy. deep copying will occur in the client
            self_dict["inputs"] = self.inputs.copy()
        if self.outputs is not None:
            # shallow copy; deep copying will occur in the client
            self_dict["outputs"] = self.outputs.copy()
        return self_dict

    def post(self, exclude_child_runs: bool = True) -> None:
        """Post the run to the API, optionally followed by its child runs.

        Args:
            exclude_child_runs (bool): When False, every run in
                ``child_runs`` is posted too, recursively.

        Returns:
            None
        """
        kwargs = self._get_dicts_safe()
        self.client.create_run(**kwargs)
        if attachments := kwargs.get("attachments"):
            # Remember which attachments were sent so that patch() can skip
            # them instead of uploading them a second time.
            keys = [str(name) for name in attachments]
            self.events.append(
                {
                    "name": "uploaded_attachment",
                    "time": datetime.now(timezone.utc).isoformat(),
                    "message": set(keys),
                }
            )
        if not exclude_child_runs:
            for child_run in self.child_runs:
                child_run.post(exclude_child_runs=False)

    def patch(self) -> None:
        """Send the run's current state to the API as an update.

        The run is ended first if it has no end time. Only tuple-valued
        attachments are sent, minus any already recorded by ``post``.

        Returns:
            None
        """
        if not self.end_time:
            self.end()
        attachments = {
            a: v for a, v in self.attachments.items() if isinstance(v, tuple)
        }
        try:
            # Avoid loading the same attachment twice
            if attachments:
                uploaded = next(
                    (
                        ev
                        for ev in self.events
                        if ev.get("name") == "uploaded_attachment"
                    ),
                    None,
                )
                if uploaded:
                    attachments = {
                        a: v
                        for a, v in attachments.items()
                        if a not in uploaded["message"]
                    }
        except Exception as e:
            # Best effort only: on failure the unfiltered attachments are sent.
            logger.warning(f"Error filtering attachments to upload: {e}")
        self.client.update_run(
            name=self.name,
            run_id=self.id,
            inputs=self.inputs.copy() if self.inputs else None,
            outputs=self.outputs.copy() if self.outputs else None,
            error=self.error,
            parent_run_id=self.parent_run_id,
            reference_example_id=self.reference_example_id,
            end_time=self.end_time,
            dotted_order=self.dotted_order,
            trace_id=self.trace_id,
            events=self.events,
            tags=self.tags,
            extra=self.extra,
            attachments=attachments,
        )

    def wait(self) -> None:
        """Do nothing; this method is currently a no-op."""
        pass

    def get_url(self) -> str:
        """Return the URL of the run."""
        return self.client.get_run_url(run=self)

    @classmethod
    def from_dotted_order(
        cls,
        dotted_order: str,
        **kwargs: Any,
    ) -> RunTree:
        """Create a new 'child' span from the provided dotted order.

        Args:
            dotted_order (str): The dotted order to continue from.
            **kwargs (Any): Extra arguments forwarded to ``from_headers``.

        Returns:
            RunTree: The new span.
        """
        headers = {
            LANGSMITH_DOTTED_ORDER: dotted_order,
        }
        # NOTE(review): from_headers is not defined in the visible part of
        # this class; presumably it is provided elsewhere - confirm.
        return cast(RunTree, cls.from_headers(headers, **kwargs))  # type: ignore[arg-type]

    @classmethod
    def from_runnable_config(
        cls,
        config: Optional[dict],
        **kwargs: Any,
    ) -> Optional[RunTree]:
        """Create a new 'child' span from the provided runnable config.

        Requires langchain to be installed.

        Args:
            config (Optional[dict]): The runnable config. When None, the
                config returned by ``ensure_config(None)`` is used.
            **kwargs (Any): Extra constructor arguments. When the parent run
                is found in the tracer's run map, several of them (run type,
                inputs, outputs, times, name) are overwritten from it, and
                its tags and metadata are merged in.

        Returns:
            Optional[RunTree]: The new span or None if
                no parent span information is found.

        Raises:
            ImportError: If langchain-core is not installed.
        """
        try:
            from langchain_core.callbacks.manager import (
                AsyncCallbackManager,
                CallbackManager,
            )
            from langchain_core.runnables import RunnableConfig, ensure_config
            from langchain_core.tracers.langchain import LangChainTracer
        except ImportError as e:
            raise ImportError(
                "RunTree.from_runnable_config requires langchain-core to be installed. "
                "You can install it with `pip install langchain-core`."
            ) from e
        if config is None:
            config_ = ensure_config(None)
        else:
            config_ = cast(RunnableConfig, config)

        if (
            (cb := config_.get("callbacks"))
            and isinstance(cb, (CallbackManager, AsyncCallbackManager))
            and cb.parent_run_id
            and (
                tracer := next(
                    (t for t in cb.handlers if isinstance(t, LangChainTracer)),
                    None,
                )
            )
        ):
            if (run := tracer.run_map.get(str(cb.parent_run_id))) and run.dotted_order:
                dotted_order = run.dotted_order
                kwargs["run_type"] = run.run_type
                kwargs["inputs"] = run.inputs
                kwargs["outputs"] = run.outputs
                kwargs["start_time"] = run.start_time
                kwargs["end_time"] = run.end_time
                # Union both tag sources. The previous expression
                # `run.tags or [] + kwargs.get("tags", [])` parsed as
                # `run.tags or (...)`, which dropped the caller's tags
                # whenever the parent run had any of its own.
                kwargs["tags"] = sorted(
                    set(run.tags or []) | set(kwargs.get("tags") or [])
                )
                kwargs["name"] = run.name
                extra_ = kwargs.setdefault("extra", {})
                metadata_ = extra_.setdefault("metadata", {})
                metadata_.update(run.metadata)
            elif hasattr(tracer, "order_map") and cb.parent_run_id in tracer.order_map:
                dotted_order = tracer.order_map[cb.parent_run_id][1]
            else:
                return None
            kwargs["client"] = tracer.client
            kwargs["project_name"] = tracer.project_name
            return RunTree.from_dotted_order(dotted_order, **kwargs)
        return None

    def __repr__(self):
        """Return a string representation of the RunTree object."""
        return (
            f"RunTree(id={self.id}, name='{self.name}', "
            f"run_type='{self.run_type}', dotted_order='{self.dotted_order}')"
        )
class _Baggage:
    """Baggage header information.

    Holds the metadata, tags and project name that are carried in a
    ``baggage`` header under the ``langsmith-`` prefixed keys.
    """

    def __init__(
        self,
        metadata: Optional[Dict[str, str]] = None,
        tags: Optional[List[str]] = None,
        project_name: Optional[str] = None,
    ):
        """Initialize the Baggage object."""
        self.metadata = metadata or {}
        self.tags = tags or []
        self.project_name = project_name

    @classmethod
    def from_header(cls, header_value: Optional[str]) -> _Baggage:
        """Create a Baggage object from the given header value.

        Args:
            header_value (Optional[str]): Comma-separated ``key=value``
                items. Keys other than the LangSmith ones are ignored.

        Returns:
            _Baggage: The parsed baggage; empty when the header is missing
                or blank.
        """
        if not header_value:
            return cls()
        metadata = {}
        tags = []
        project_name = None
        for item in header_value.split(","):
            if not item.strip():
                # Tolerate empty list members such as a trailing comma.
                continue
            # Handle errors per item so that one malformed entry does not
            # discard the well-formed entries that follow it.
            try:
                key, value = item.split("=", 1)
                # Items may be separated by ", " - ignore the padding.
                key = key.strip()
                value = value.strip()
                if key == LANGSMITH_METADATA:
                    metadata = json.loads(urllib.parse.unquote(value))
                elif key == LANGSMITH_TAGS:
                    tags = urllib.parse.unquote(value).split(",")
                elif key == LANGSMITH_PROJECT:
                    project_name = urllib.parse.unquote(value)
            except Exception as e:
                logger.warning(f"Error parsing baggage header: {e}")
        return cls(metadata=metadata, tags=tags, project_name=project_name)

    @classmethod
    def from_headers(cls, headers: Mapping[Union[str, bytes], Any]) -> _Baggage:
        """Create a Baggage object from a mapping of headers.

        Looks for a ``"baggage"`` key first, then for a ``b"baggage"`` key
        whose value is decoded as UTF-8.

        Args:
            headers (Mapping[Union[str, bytes], Any]): The headers to read.

        Returns:
            _Baggage: The parsed baggage; empty when neither key is present.
        """
        if "baggage" in headers:
            return cls.from_header(headers["baggage"])
        elif b"baggage" in headers:
            return cls.from_header(cast(bytes, headers[b"baggage"]).decode("utf-8"))
        else:
            return cls.from_header(None)

    def to_header(self) -> str:
        """Return the Baggage object as a header value.

        Returns:
            str: Comma-separated ``key=value`` items with percent-encoded
                values; an empty string when there is nothing to send.
        """
        items = []
        if self.metadata:
            serialized_metadata = _dumps_json(self.metadata)
            items.append(
                f"{LANGSMITH_METADATA}={urllib.parse.quote(serialized_metadata)}"
            )
        if self.tags:
            serialized_tags = ",".join(self.tags)
            items.append(f"{LANGSMITH_TAGS}={urllib.parse.quote(serialized_tags)}")
        if self.project_name:
            items.append(
                f"{LANGSMITH_PROJECT}={urllib.parse.quote(self.project_name)}"
            )
        return ",".join(items)
def _parse_dotted_order(dotted_order: str) -> List[Tuple[datetime, UUID]]:
    """Split a dotted order into ``(timestamp, run id)`` pairs.

    Args:
        dotted_order (str): Segments joined by ".". In each segment the last
            36 characters are a UUID and everything before them is a
            timestamp in ``%Y%m%dT%H%M%S%fZ`` format.

    Returns:
        List[Tuple[datetime, UUID]]: One pair per segment, in order. The
            datetimes carry no tzinfo.

    Raises:
        ValueError: If a segment's timestamp or UUID cannot be parsed.
    """
    uuid_length = 36
    timestamp_format = "%Y%m%dT%H%M%S%fZ"
    parsed: List[Tuple[datetime, UUID]] = []
    for segment in dotted_order.split("."):
        stamp_text = segment[:-uuid_length]
        id_text = segment[-uuid_length:]
        parsed.append((datetime.strptime(stamp_text, timestamp_format), UUID(id_text)))
    return parsed
def _create_current_dotted_order(
    start_time: Optional[datetime], run_id: Optional[UUID]
) -> str:
    """Build the dotted-order segment for a single run.

    Args:
        start_time (Optional[datetime]): The run's start time. The current
            UTC time is used when this is missing.
        run_id (Optional[UUID]): The run's id. A random UUID is used when
            this is missing.

    Returns:
        str: The start time formatted as ``%Y%m%dT%H%M%S%fZ`` immediately
            followed by the string form of the id.
    """
    if start_time:
        moment = start_time
    else:
        moment = datetime.now(timezone.utc)
    if run_id:
        identifier = run_id
    else:
        identifier = uuid4()
    timestamp = moment.strftime("%Y%m%dT%H%M%S%fZ")
    return "".join((timestamp, str(identifier)))
# Public API of this module. "RunTree" was previously listed twice; the
# duplicate entry was redundant and has been removed.
__all__ = ["RunTree"]