Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Tool.from_component #159

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions haystack_experimental/components/tools/openai/component_caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Union, get_args, get_origin

from docstring_parser import parse
from haystack import logging
from haystack.core.component import Component
from pydantic.fields import FieldInfo

from haystack_experimental.util.utils import is_pydantic_v2_model

logger = logging.getLogger(__name__)


def extract_component_parameters(component: Component) -> Dict[str, Any]:
"""
Extracts parameters from a Haystack component and converts them to OpenAI tools JSON format.

:param component: The component to extract parameters from.
:returns: A dictionary representing the component's input parameters schema.
"""
properties = {}
required = []

param_descriptions = get_param_descriptions(component.run)

for input_name, socket in component.__haystack_input__._sockets_dict.items():
input_type = socket.type
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

try:
property_schema = create_property_schema(input_type, description)
except ValueError as e:
raise ValueError(f"Error processing input '{input_name}': {e}")

properties[input_name] = property_schema

# Use socket.is_mandatory() to check if the input is required
if socket.is_mandatory:
required.append(input_name)

parameters_schema = {"type": "object", "properties": properties}

if required:
parameters_schema["required"] = required

return parameters_schema


def get_param_descriptions(method: Callable) -> Dict[str, str]:
"""
Extracts parameter descriptions from the method's docstring using docstring_parser.

:param method: The method to extract parameter descriptions from.
:returns: A dictionary mapping parameter names to their descriptions.
"""
docstring = getdoc(method)
if not docstring:
return {}

parsed_doc = parse(docstring)
param_descriptions = {}
for param in parsed_doc.params:
if not param.description:
logger.warning(
"Missing description for parameter '%s'. Please add a description in the component's "
"run() method docstring using the format ':param %s: <description>'. "
"This description is used to generate the Tool and helps the LLM understand how to use this parameter.",
param.arg_name,
param.arg_name,
)
param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
return param_descriptions


# ruff: noqa: PLR0912
def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
"""
Creates a property schema for a given Python type, recursively if necessary.

:param python_type: The Python type to create a property schema for.
:param description: The description of the property.
:param default: The default value of the property.
:returns: A dictionary representing the property schema.
"""
nullable = is_nullable_type(python_type)
if nullable:
non_none_types = [t for t in get_args(python_type) if t is not type(None)]
python_type = non_none_types[0] if non_none_types else str

origin = get_origin(python_type)
if origin is list:
item_type = get_args(python_type)[0] if get_args(python_type) else Any
# recursively call create_property_schema for the item type
items_schema = create_property_schema(item_type, "")
items_schema.pop("description", None)
schema = {"type": "array", "description": description, "items": items_schema}
elif is_dataclass(python_type) or is_pydantic_v2_model(python_type):
schema = {"type": "object", "description": description, "properties": {}}
required_fields = []

if is_dataclass(python_type):
# Get the actual class if python_type is an instance otherwise use the type
cls = python_type if isinstance(python_type, type) else python_type.__class__
for field in fields(cls):
field_description = f"Field '{field.name}' of '{cls.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][field.name] = create_property_schema(field.type, field_description)

else: # Pydantic model
for m_name, m_field in python_type.model_fields.items():
field_description = f"Field '{m_name}' of '{python_type.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description)
if m_field.is_required():
required_fields.append(m_name)

if required_fields:
schema["required"] = required_fields
else:
# Basic types
type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
schema = {"type": type_mapping.get(python_type, "string"), "description": description}

if default is not None:
schema["default"] = default

return schema


def is_nullable_type(python_type: Any) -> bool:
"""
Checks if the type is a Union with NoneType (i.e., Optional).

:param python_type: The Python type to check.
:returns: True if the type is a Union with NoneType, False otherwise.
"""
origin = get_origin(python_type)
if origin is Union:
return type(None) in get_args(python_type)
return False
69 changes: 67 additions & 2 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, get_args, get_origin

from haystack import logging
from haystack.core.component import Component
from haystack.lazy_imports import LazyImport
from haystack.utils import deserialize_callable, serialize_callable
from pydantic import create_model
from pydantic import TypeAdapter, create_model

with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
from jsonschema import Draft202012Validator
from jsonschema.exceptions import SchemaError


logger = logging.getLogger(__name__)


class ToolInvocationError(Exception):
"""
Exception raised when a Tool invocation fails.
Expand Down Expand Up @@ -198,6 +203,66 @@ def get_weather(

return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)

@classmethod
def from_component(cls, component: Component, name: str, description: str) -> "Tool":
"""
Create a Tool instance from a Haystack component.

:param component: The Haystack component to be converted into a Tool.
:param name: Name for the tool.
:param description: Description of the tool.
:returns: The Tool created from the Component.
:raises ValueError: If the component is invalid or schema generation fails.
"""

if not isinstance(component, Component):
raise ValueError(
f"{component} is not a Haystack component!" "Can only create a Tool from a Haystack component instance."
)

if getattr(component, "__haystack_added_to_pipeline__", None):
msg = (
"Component has been added in a Pipeline and can't be used to create a Tool. "
"Create Tool from a non-pipeline component instead."
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please explain this?

If I remember correctly, one of the requirements was about deserializing Tools from YAML (which should be feasible if Tools are components). I'm not totally sure...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I thought we can have a component declared but not be part of the pipeline. Maybe not, depending on that we can remove this check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand if this is a self-imposed limitation (I don't think so) or there are strong reasons to avoid that. Could you please explain this point further?

raise ValueError(msg)

from haystack_experimental.components.tools.openai.component_caller import extract_component_parameters
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

# Extract the parameters schema from the component
parameters = extract_component_parameters(component)

def component_invoker(**kwargs):
"""
Invokes the component using keyword arguments provided by the LLM function calling/tool generated response.

:param kwargs: The keyword arguments to invoke the component with.
:returns: The result of the component invocation.
"""
converted_kwargs = {}
input_sockets = component.__haystack_input__._sockets_dict
for param_name, param_value in kwargs.items():
param_type = input_sockets[param_name].type

# Check if the type (or list element type) has from_dict
target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
if hasattr(target_type, "from_dict"):
if isinstance(param_value, list):
param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
elif isinstance(param_value, dict):
param_value = target_type.from_dict(param_value)
else:
# Let TypeAdapter handle both single values and lists
type_adapter = TypeAdapter(param_type)
param_value = type_adapter.validate_python(param_value)

converted_kwargs[param_name] = param_value
logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
return component.run(**converted_kwargs)

# Return a new Tool instance with the component invoker as the function to be called
return Tool(name=name, description=description, parameters=parameters, function=component_invoker)


def _remove_title_from_schema(schema: Dict[str, Any]):
"""
Expand Down
12 changes: 11 additions & 1 deletion haystack_experimental/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union
from typing import Any, List, Union


def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
Expand Down Expand Up @@ -41,3 +41,13 @@ def expand_page_range(page_range: List[Union[str, int]]) -> List[int]:
raise ValueError("No valid page numbers or ranges found in the input list")

return expanded_page_range


def is_pydantic_v2_model(instance: Any) -> bool:
"""
Checks if the instance is a Pydantic v2 model.

:param instance: The instance to check.
:returns: True if the instance is a Pydantic v2 model, False otherwise.
"""
return hasattr(instance, "model_validate")
Loading
Loading