47 lines
2.1 KiB
Python
47 lines
2.1 KiB
Python
from typing import Optional, Dict, Any
|
|
from google.adk.sessions import InMemorySessionService
|
|
from google.adk.runners import Runner
|
|
from google.adk.agents import Agent
|
|
from google.genai import types
|
|
|
|
class AgentCaller:
    """Wrapper for interacting with an ADK agent.

    Bundles an agent with the ``Runner`` that executes it and the
    ``user_id`` / ``session_id`` pair identifying one conversation, so a
    message can be sent with a single ``await caller.call(...)``.
    """

    def __init__(self, agent: Agent, runner: Runner, user_id: str, session_id: str):
        """Store the agent, its runner and the session coordinates.

        Args:
            agent: The agent this caller talks to (stored for reference; the
                runner is what actually executes requests).
            runner: Runner used to execute requests; its ``app_name`` and
                ``session_service`` are used to look up the session.
            user_id: User identifier passed to every runner/session call.
            session_id: Session identifier passed to every runner/session call.
        """
        self.agent = agent
        self.runner = runner
        self.user_id = user_id
        self.session_id = session_id

    async def get_session(self):
        """Fetch this caller's session from the runner's session service.

        Returns:
            Whatever ``runner.session_service.get_session`` returns for this
            app/user/session triple (presumably the ADK session object, or
            ``None`` if it does not exist -- confirm for the ADK version used).
        """
        return await self.runner.session_service.get_session(
            app_name=self.runner.app_name,
            user_id=self.user_id,
            session_id=self.session_id,
        )

    async def call(self, user_message: str, verbose: bool = False) -> str:
        """Send ``user_message`` to the agent and return its final reply text.

        Events from ``runner.run_async`` are consumed until the first event
        for which ``is_final_response()`` is true; iteration stops there.

        Args:
            user_message: Plain-text message sent as a ``user`` content turn.
            verbose: If true, print the author and final-flag of every event.

        Returns:
            The concatenated (non-thought) text parts of the final response;
            ``"Agent escalated: <reason>"`` when the final event carries no
            content but requests escalation; otherwise the placeholder
            ``"Agent did not produce a final response."``.
        """
        content = types.Content(role="user", parts=[types.Part(text=user_message)])
        final_response_text = "Agent did not produce a final response."

        events = self.runner.run_async(
            user_id=self.user_id, session_id=self.session_id, new_message=content
        )
        try:
            async for event in events:
                if verbose:
                    print(f"[Event] Author: {event.author}, Final: {event.is_final_response()}")

                if not event.is_final_response():
                    continue

                if event.content and event.content.parts:
                    # A reply may be split over several parts, and a part may
                    # carry no text at all (e.g. a tool payload). Join every
                    # text part, skipping model "thought" parts, instead of
                    # trusting parts[0].text, which can be None.
                    text = "".join(
                        part.text
                        for part in event.content.parts
                        if part.text and not getattr(part, "thought", False)
                    )
                    if text:
                        final_response_text = text
                elif getattr(event, "actions", None) and getattr(event.actions, "escalate", False):
                    # The attribute may exist but be None; getattr's default
                    # only covers a *missing* attribute, so fall back with `or`.
                    reason = getattr(event, "error_message", None) or "No specific message."
                    final_response_text = f"Agent escalated: {reason}"
                break
        finally:
            # Breaking out of `async for` leaves an async generator suspended
            # until the event loop finalizes it later. Close it explicitly so
            # its cleanup runs now, in this task.
            aclose = getattr(events, "aclose", None)
            if aclose is not None:
                await aclose()

        return final_response_text
|
|
|
|
async def make_agent_caller(
    agent: Agent,
    initial_state: Optional[Dict[str, Any]] = None,
    *,
    app_name: Optional[str] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
) -> AgentCaller:
    """Create a fresh in-memory session for ``agent`` and wrap it in an AgentCaller.

    A new ``InMemorySessionService`` is created on every call, a session is
    created in it, and a ``Runner`` is built around ``agent`` and that service.

    Args:
        agent: The agent to run; ``agent.name`` seeds the default identifiers.
        initial_state: Initial session state; ``None`` means an empty dict.
        app_name: Application name; defaults to ``f"{agent.name}_app"``.
        user_id: User identifier; defaults to ``f"{agent.name}_user"``.
        session_id: Session identifier; defaults to
            ``f"{agent.name}_session_01"``.

    Returns:
        An ``AgentCaller`` bound to the newly created session.
    """
    # Keyword-only overrides let callers create several distinct sessions;
    # the defaults reproduce the previous hard-coded identifiers exactly.
    if app_name is None:
        app_name = agent.name + "_app"
    if user_id is None:
        user_id = agent.name + "_user"
    if session_id is None:
        session_id = agent.name + "_session_01"

    session_service = InMemorySessionService()
    await session_service.create_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        state=initial_state or {},
    )

    runner = Runner(agent=agent, app_name=app_name, session_service=session_service)
    return AgentCaller(agent, runner, user_id, session_id)
|