-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add Tool.from_component #159
Open
vblagoje
wants to merge
16
commits into
main
Choose a base branch
from
from_component
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
784a19c
Initial Tool.from_component
vblagoje 6d6f307
Simplify types conversion with TypeAdapter
vblagoje eb1496c
Pylint, small fixes
vblagoje 519d53f
Improve warning when component run pydocs are missing
vblagoje 7178e1d
Add Anthropic integration tests
vblagoje d83ca13
Minor test fix
vblagoje 6341268
Merge branch 'main' into from_component
vblagoje 1c259f9
Handle our own dataclasses (e.g. Document)
vblagoje 34a0861
For dataclasses don't check required fields, add more itegration tests
vblagoje 551a528
Small fix for better test
vblagoje ce864dd
Make sure we are only using non-pipeline components for Tools
vblagoje 931df70
Move modules around
vblagoje 3b29458
Refactor and simplify tools schema creation
vblagoje d8f722c
Better naming
vblagoje 2dbb8d2
Rename module
vblagoje b35c098
PR feedback
vblagoje File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .component_schema import create_tool_parameters_schema | ||
|
||
__all__ = ["create_tool_parameters_schema"] | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from dataclasses import fields, is_dataclass | ||
from inspect import getdoc | ||
from typing import Any, Callable, Dict, Union, get_args, get_origin | ||
|
||
from docstring_parser import parse | ||
from haystack import logging | ||
from haystack.core.component import Component | ||
|
||
from haystack_experimental.util.utils import is_pydantic_v2_model | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def create_tool_parameters_schema(component: Component) -> Dict[str, Any]: | ||
""" | ||
Creates an OpenAI tools schema from a component's run method parameters. | ||
|
||
:param component: The component to create the schema from. | ||
:returns: OpenAI tools schema for the component's run method parameters. | ||
""" | ||
properties = {} | ||
required = [] | ||
|
||
param_descriptions = get_param_descriptions(component.run) | ||
|
||
for input_name, socket in component.__haystack_input__._sockets_dict.items(): | ||
input_type = socket.type | ||
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") | ||
|
||
try: | ||
property_schema = create_property_schema(input_type, description) | ||
except ValueError as e: | ||
raise ValueError(f"Error processing input '{input_name}': {e}") | ||
|
||
properties[input_name] = property_schema | ||
|
||
# Use socket.is_mandatory to check if the input is required | ||
if socket.is_mandatory: | ||
required.append(input_name) | ||
|
||
parameters_schema = {"type": "object", "properties": properties} | ||
|
||
if required: | ||
parameters_schema["required"] = required | ||
|
||
return parameters_schema | ||
|
||
|
||
def get_param_descriptions(method: Callable) -> Dict[str, str]: | ||
""" | ||
Extracts parameter descriptions from the method's docstring using docstring_parser. | ||
|
||
:param method: The method to extract parameter descriptions from. | ||
:returns: A dictionary mapping parameter names to their descriptions. | ||
""" | ||
docstring = getdoc(method) | ||
if not docstring: | ||
return {} | ||
|
||
parsed_doc = parse(docstring) | ||
param_descriptions = {} | ||
for param in parsed_doc.params: | ||
if not param.description: | ||
logger.warning( | ||
"Missing description for parameter '%s'. Please add a description in the component's " | ||
"run() method docstring using the format ':param %s: <description>'. " | ||
"This description is used to generate the Tool and helps the LLM understand how to use this parameter.", | ||
param.arg_name, | ||
param.arg_name, | ||
) | ||
param_descriptions[param.arg_name] = param.description.strip() if param.description else "" | ||
return param_descriptions | ||
|
||
|
||
def is_nullable_type(python_type: Any) -> bool: | ||
""" | ||
Checks if the type is a Union with NoneType (i.e., Optional). | ||
|
||
:param python_type: The Python type to check. | ||
:returns: True if the type is a Union with NoneType, False otherwise. | ||
""" | ||
origin = get_origin(python_type) | ||
if origin is Union: | ||
return type(None) in get_args(python_type) | ||
return False | ||
|
||
|
||
def _create_list_schema(item_type: Any, description: str) -> Dict[str, Any]: | ||
""" | ||
Creates a schema for a list type. | ||
|
||
:param item_type: The type of items in the list. | ||
:param description: The description of the list. | ||
:returns: A dictionary representing the list schema. | ||
""" | ||
items_schema = create_property_schema(item_type, "") | ||
items_schema.pop("description", None) | ||
return {"type": "array", "description": description, "items": items_schema} | ||
|
||
|
||
def _create_dataclass_schema(python_type: Any, description: str) -> Dict[str, Any]: | ||
""" | ||
Creates a schema for a dataclass. | ||
|
||
:param python_type: The dataclass type. | ||
:param description: The description of the dataclass. | ||
:returns: A dictionary representing the dataclass schema. | ||
""" | ||
schema = {"type": "object", "description": description, "properties": {}} | ||
cls = python_type if isinstance(python_type, type) else python_type.__class__ | ||
for field in fields(cls): | ||
field_description = f"Field '{field.name}' of '{cls.__name__}'." | ||
if isinstance(schema["properties"], dict): | ||
schema["properties"][field.name] = create_property_schema(field.type, field_description) | ||
return schema | ||
|
||
|
||
def _create_pydantic_schema(python_type: Any, description: str) -> Dict[str, Any]: | ||
""" | ||
Creates a schema for a Pydantic model. | ||
|
||
:param python_type: The Pydantic model type. | ||
:param description: The description of the model. | ||
:returns: A dictionary representing the Pydantic model schema. | ||
""" | ||
schema = {"type": "object", "description": description, "properties": {}} | ||
required_fields = [] | ||
|
||
for m_name, m_field in python_type.model_fields.items(): | ||
field_description = f"Field '{m_name}' of '{python_type.__name__}'." | ||
if isinstance(schema["properties"], dict): | ||
schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description) | ||
if m_field.is_required(): | ||
required_fields.append(m_name) | ||
|
||
if required_fields: | ||
schema["required"] = required_fields | ||
return schema | ||
|
||
|
||
def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]: | ||
""" | ||
Creates a schema for a basic Python type. | ||
|
||
:param python_type: The Python type. | ||
:param description: The description of the type. | ||
:returns: A dictionary representing the basic type schema. | ||
""" | ||
type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} | ||
return {"type": type_mapping.get(python_type, "string"), "description": description} | ||
|
||
|
||
def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: | ||
""" | ||
Creates a property schema for a given Python type, recursively if necessary. | ||
|
||
:param python_type: The Python type to create a property schema for. | ||
:param description: The description of the property. | ||
:param default: The default value of the property. | ||
:returns: A dictionary representing the property schema. | ||
""" | ||
nullable = is_nullable_type(python_type) | ||
if nullable: | ||
non_none_types = [t for t in get_args(python_type) if t is not type(None)] | ||
python_type = non_none_types[0] if non_none_types else str | ||
|
||
origin = get_origin(python_type) | ||
if origin is list: | ||
schema = _create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) | ||
elif is_dataclass(python_type): | ||
schema = _create_dataclass_schema(python_type, description) | ||
elif is_pydantic_v2_model(python_type): | ||
schema = _create_pydantic_schema(python_type, description) | ||
else: | ||
schema = _create_basic_type_schema(python_type, description) | ||
|
||
if default is not None: | ||
schema["default"] = default | ||
|
||
return schema |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok makes sense, will do 🙏