-
Notifications
You must be signed in to change notification settings - Fork 652
/
st_utils.py
191 lines (159 loc) · 6.18 KB
/
st_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Streamlit utils."""
from core.agent_builder.loader import (
load_meta_agent_and_tools,
AgentCacheRegistry,
)
from core.agent_builder.base import BaseRAGAgentBuilder
from core.param_cache import ParamCache
from core.constants import (
AGENT_CACHE_DIR,
)
from typing import Optional, cast
from pydantic import BaseModel
from llama_index.agent.types import BaseAgent
import streamlit as st
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
    """Make ``selected_id`` the active agent and reset derived session state.

    The sidebar sentinel "Create a new agent" is stored as ``None`` (no agent
    selected). The builder agent, agent builder and selected cache are reset
    to ``None`` so that ``get_current_state()`` rebuilds them for the new
    selection.

    Args:
        selected_id: Agent id to select, the "Create a new agent" sentinel,
            or ``None``.
    """
    state = st.session_state

    wants_new_agent = selected_id == "Create a new agent"
    state.selected_id = None if wants_new_agent else selected_id

    # Everything below was derived from the previous selection; drop it.
    state.builder_agent = None
    state.agent_builder = None
    state.selected_cache = None
## handler for sidebar specifically
def update_selected_agent() -> None:
    """Widget ``on_change`` callback that re-selects the sidebar's agent.

    Reads the current value of the ``agent_selector`` radio from session
    state and forwards it to ``update_selected_agent_with_id``.
    """
    update_selected_agent_with_id(st.session_state.agent_selector)
def get_cached_is_multimodal() -> bool:
    """Return whether the currently selected agent cache is multimodal.

    Returns:
        ``False`` when no cache is selected (the ``selected_cache`` key is
        missing or ``None``); otherwise whether the cache's ``builder_type``
        equals ``"multimodal"``.
    """
    # Guard clause: no selected cache means the default (non-multimodal).
    if (
        "selected_cache" not in st.session_state
        or st.session_state.selected_cache is None
    ):
        return False

    selected_cache = cast(ParamCache, st.session_state.selected_cache)
    # The comparison already yields a bool; no `True if ... else False` needed.
    return selected_cache.builder_type == "multimodal"
def get_is_multimodal() -> bool:
    """Return the ``is_multimodal_st`` session flag, defaulting it to ``False``.

    ``is_multimodal_st`` is the key of the "Enable multimodal search (beta)"
    checkbox. If the key is not yet in session state, it is first set to
    ``False``.
    """
    state = st.session_state
    if "is_multimodal_st" in state:
        return state.is_multimodal_st

    # Seed the default the first time the flag is read.
    state.is_multimodal_st = False
    return state.is_multimodal_st
def add_builder_config() -> None:
    """Render the "Builder Config (Advanced)" expander.

    The expander currently holds one checkbox, "Enable multimodal search
    (beta)":

    - it is bound to the ``is_multimodal_st`` session key;
    - its default comes from ``get_cached_is_multimodal()``;
    - it is disabled whenever an agent cache is selected;
    - changing it calls ``update_selected_agent``.
    """
    with st.expander("Builder Config (Advanced)"):
        # Lock the option while an existing agent's cache is selected.
        is_locked = (
            "selected_cache" in st.session_state
            and st.session_state.selected_cache is not None
        )

        st.checkbox(
            "Enable multimodal search (beta)",
            key="is_multimodal_st",
            on_change=update_selected_agent,
            value=get_cached_is_multimodal(),
            disabled=is_locked,
        )
def add_sidebar() -> None:
    """Render the agent-selection radio in the sidebar.

    On every call, ``cur_agent_ids`` is refreshed from the agent registry
    stored in session state. The radio then offers "Create a new agent"
    followed by every known agent id.

    The radio is pre-selected on ``selected_id`` when that id is set and is
    still among the choices. Otherwise the first entry ("Create a new agent")
    is selected. Changing the selection calls ``update_selected_agent``.
    """
    with st.sidebar:
        agent_registry = cast(AgentCacheRegistry, st.session_state.agent_registry)
        st.session_state.cur_agent_ids = agent_registry.get_agent_ids()
        choices = ["Create a new agent"] + st.session_state.cur_agent_ids

        # Default to index 0 ("Create a new agent"). `list.index` raises
        # ValueError when the stored id is no longer returned by the registry,
        # so only look it up when it is actually present among the choices.
        index = 0
        selected_id = (
            st.session_state.selected_id
            if "selected_id" in st.session_state
            else None
        )
        if selected_id is not None and selected_id in choices:
            index = choices.index(selected_id)

        # display buttons
        st.radio(
            "Agents",
            choices,
            index=index,
            on_change=update_selected_agent,
            key="agent_selector",
        )
class CurrentSessionState(BaseModel):
    """Snapshot of the Streamlit session values returned by ``get_current_state``.

    Fields:
        agent_registry: Registry of cached agents.
        selected_id: Id of the selected agent; ``None`` when creating a new
            agent.
        selected_cache: Cache of the selected agent; ``None`` when no agent
            is selected.
        agent_builder: Agent builder returned by
            ``load_meta_agent_and_tools``.
        cache: The agent builder's own ``ParamCache`` (``agent_builder.cache``).
        builder_agent: Builder agent returned by ``load_meta_agent_and_tools``.
    """

    agent_registry: AgentCacheRegistry
    selected_id: Optional[str]
    selected_cache: Optional[ParamCache]
    agent_builder: BaseRAGAgentBuilder
    cache: ParamCache
    builder_agent: BaseAgent

    class Config:
        # Allow field types pydantic has no built-in validator for;
        # presumably needed for the registry/builder/agent classes above
        # (TODO confirm).
        arbitrary_types_allowed = True
def get_current_state() -> CurrentSessionState:
    """Ensure derived session state exists, then return a snapshot of it.

    Missing session entries are initialised in order; each step reads values
    written by the previous one:

    1. agent registry, current agent ids and selected id;
    2. selected agent cache (loaded from the registry when an agent is
       selected);
    3. builder agent and agent builder.

    Returns:
        A ``CurrentSessionState`` built from the session values, with
        ``cache`` taken from ``agent_builder.cache``.
    """
    _init_session_defaults()
    _load_selected_cache()
    _load_builder_agent()

    state = st.session_state
    return CurrentSessionState(
        agent_registry=state.agent_registry,
        selected_id=state.selected_id,
        selected_cache=state.selected_cache,
        agent_builder=state.agent_builder,
        cache=state.agent_builder.cache,
        builder_agent=state.builder_agent,
    )


def _init_session_defaults() -> None:
    """Seed ``agent_registry``, ``cur_agent_ids`` and ``selected_id`` if absent."""
    # A registry is constructed on every call, but it is only stored when the
    # session does not already hold one.
    agent_registry = AgentCacheRegistry(str(AGENT_CACHE_DIR))
    if "agent_registry" not in st.session_state:
        st.session_state.agent_registry = agent_registry
    if "cur_agent_ids" not in st.session_state:
        st.session_state.cur_agent_ids = agent_registry.get_agent_ids()
    if "selected_id" not in st.session_state:
        st.session_state.selected_id = None


def _load_selected_cache() -> None:
    """Set ``selected_cache`` from the registry when it is missing or ``None``.

    It is set to ``None`` when no agent is selected.
    """
    if (
        "selected_cache" in st.session_state
        and st.session_state.selected_cache is not None
    ):
        return

    if st.session_state.selected_id is None:
        st.session_state.selected_cache = None
        return

    # load agent from directory
    agent_registry = cast(AgentCacheRegistry, st.session_state.agent_registry)
    st.session_state.selected_cache = agent_registry.get_agent_cache(
        st.session_state.selected_id
    )


def _load_builder_agent() -> None:
    """Create ``builder_agent``/``agent_builder`` if either is missing or ``None``.

    With a selected cache, both are built from that cache and its multimodal
    setting. Otherwise they are built from a new cache, using the multimodal
    checkbox value.
    """
    if (
        "builder_agent" in st.session_state
        and st.session_state.builder_agent is not None
        and "agent_builder" in st.session_state
        and st.session_state.agent_builder is not None
    ):
        return

    if (
        "selected_cache" in st.session_state
        and st.session_state.selected_cache is not None
    ):
        # create builder agent / tools from selected cache
        builder_agent, agent_builder = load_meta_agent_and_tools(
            cache=st.session_state.selected_cache,
            agent_registry=st.session_state.agent_registry,
            # NOTE: we will probably generalize this later into different
            # builder configs
            is_multimodal=get_cached_is_multimodal(),
        )
    else:
        # create builder agent / tools from new cache
        builder_agent, agent_builder = load_meta_agent_and_tools(
            agent_registry=st.session_state.agent_registry,
            is_multimodal=get_is_multimodal(),
        )

    st.session_state.builder_agent = builder_agent
    st.session_state.agent_builder = agent_builder