Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(platform): Support multiple credentials inputs on blocks #8932

Open
wants to merge 17 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 35 additions & 33 deletions autogpt_platform/backend/backend/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from backend.util.settings import Config

from .model import (
CREDENTIALS_FIELD_NAME,
ContributorDetails,
Credentials,
CredentialsMetaInput,
is_credentials_field_name,
)

app_config = Config()
Expand Down Expand Up @@ -140,17 +140,38 @@ def get_required_fields(cls) -> set[str]:
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""Validates the schema definition. Rules:
- Only one `CredentialsMetaInput` field may be present.
- This field MUST be called `credentials`.
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
- Fields with annotation `CredentialsMetaInput` MUST be
named `credentials` or `*_credentials`
- Fields named `credentials` or `*_credentials` MUST be
of type `CredentialsMetaInput`
"""
super().__pydantic_init_subclass__(**kwargs)

# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}

credentials_fields = [
field_name
credentials_fields = cls.get_credentials_fields()

for field_name in cls.get_fields():
if is_credentials_field_name(field_name):
if field_name not in credentials_fields:
raise TypeError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
f"is not of type {CredentialsMetaInput.__name__}"
)

credentials_fields[field_name].validate_credentials_field_schema(cls)

elif field_name in credentials_fields:
raise KeyError(
f"Credentials field '{field_name}' on {cls.__qualname__} "
"has invalid name: must be 'credentials' or *_credentials"
)

@classmethod
def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
return {
field_name: info.annotation
for field_name, info in cls.model_fields.items()
if (
inspect.isclass(info.annotation)
Expand All @@ -159,32 +180,7 @@ def __pydantic_init_subclass__(cls, **kwargs):
CredentialsMetaInput,
)
)
]
if len(credentials_fields) > 1:
raise ValueError(
f"{cls.__qualname__} can only have one CredentialsMetaInput field"
)
elif (
len(credentials_fields) == 1
and credentials_fields[0] != CREDENTIALS_FIELD_NAME
):
raise ValueError(
f"CredentialsMetaInput field on {cls.__qualname__} "
"must be named 'credentials'"
)
elif (
len(credentials_fields) == 0
and CREDENTIALS_FIELD_NAME in cls.model_fields.keys()
):
raise TypeError(
f"Field 'credentials' on {cls.__qualname__} "
f"must be of type {CredentialsMetaInput.__name__}"
)
if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
credentials_input_type = cast(
CredentialsMetaInput, credentials_field.annotation
)
credentials_input_type.validate_credentials_field_schema(cls)
}


BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
Expand Down Expand Up @@ -242,7 +238,7 @@ def __init__(
test_input: BlockInput | list[BlockInput] | None = None,
test_output: BlockData | list[BlockData] | None = None,
test_mock: dict[str, Any] | None = None,
test_credentials: Optional[Credentials] = None,
test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
disabled: bool = False,
static_output: bool = False,
block_type: BlockType = BlockType.STANDARD,
Expand Down Expand Up @@ -299,6 +295,12 @@ def __init__(
"field must be a BaseModel and all its fields must be boolean"
)

# Disallow multiple credentials inputs on webhook blocks
if len(self.input_schema.get_credentials_fields()) > 1:
raise ValueError(
"Multiple credentials input fields not supported on webhook blocks"
)

# Enforce presence of 'payload' input
if "payload" not in self.input_schema.model_fields:
raise TypeError(
Expand Down
44 changes: 27 additions & 17 deletions autogpt_platform/backend/backend/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ class UserIntegrations(BaseModel):
CT = TypeVar("CT", bound=CredentialsType)


CREDENTIALS_FIELD_NAME = "credentials"
def is_credentials_field_name(field_name: str) -> bool:
return field_name == "credentials" or field_name.endswith("_credentials")


class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
Expand All @@ -247,21 +248,21 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT

@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = get_args(
cls.model_fields["provider"].annotation
)
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
@classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...]:
return get_args(cls.model_fields["provider"].annotation)

model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
@classmethod
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
return get_args(cls.model_fields["type"].annotation)

@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a `credentials` field"""
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
"""Validates the schema of a credentials input field"""
field_name = next(
name for name, type in model.get_credentials_fields().items() if type is cls
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
Expand All @@ -275,11 +276,20 @@ def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
f"{field_schema}"
) from e

if (
len(schema_extra.credentials_provider) > 1
and not schema_extra.discriminator
):
raise TypeError("Multi-provider CredentialsField requires discriminator!")
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
raise TypeError(
f"Multi-provider CredentialsField '{field_name}' "
"requires discriminator!"
)

@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = cls.allowed_providers()
schema["credentials_types"] = cls.allowed_cred_types()

model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
Expand Down
78 changes: 41 additions & 37 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

from pydantic import BaseModel
from redis.lock import Lock as RedisLock

if TYPE_CHECKING:
Expand All @@ -20,7 +19,14 @@

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.execution import (
ExecutionQueue,
ExecutionResult,
Expand All @@ -31,7 +37,6 @@
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
Expand Down Expand Up @@ -170,10 +175,11 @@ def update_execution(status: ExecutionStatus) -> ExecutionResult:
# one (running) block at a time; simultaneous execution of blocks using same
# credentials is not supported.
creds_lock = None
if CREDENTIALS_FIELD_NAME in input_data:
credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
input_model = cast(type[BlockSchema], node_block.input_schema)
for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name])
credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
extra_exec_kwargs["credentials"] = credentials
extra_exec_kwargs[field_name] = credentials

output_size = 0
end_status = ExecutionStatus.COMPLETED
Expand Down Expand Up @@ -890,41 +896,39 @@ def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")

# Find any fields of type CredentialsMetaInput
model_fields = cast(type[BaseModel], block.input_schema).model_fields
if CREDENTIALS_FIELD_NAME not in model_fields:
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue

field = model_fields[CREDENTIALS_FIELD_NAME]

# The BlockSchema class enforces that a `credentials` field is always a
# `CredentialsMetaInput`, so we can safely assume this here.
credentials_meta_type = cast(CredentialsMetaInput, field.annotation)
credentials_meta = credentials_meta_type.model_validate(
node.input_default[CREDENTIALS_FIELD_NAME]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id}"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)


# ------- UTILITIES ------- #
Expand Down
Loading
Loading