Skip to content

Allow to pass both session and input list #1298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/agents/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .session import Session, SQLiteSession
from .util import SessionInputHandler, SessionMixerCallable

__all__ = ["Session", "SQLiteSession"]
__all__ = ["Session", "SessionInputHandler", "SessionMixerCallable", "SQLiteSession"]
29 changes: 29 additions & 0 deletions src/agents/memory/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from typing import Callable, Union

from ..items import TResponseInputItem
from ..util._types import MaybeAwaitable

SessionMixerCallable = Callable[
[list[TResponseInputItem], list[TResponseInputItem]],
MaybeAwaitable[list[TResponseInputItem]],
]
"""A function that combines session history with new input items.

Args:
history_items: The list of items from the session history.
new_items: The list of new input items for the current turn.

Returns:
A list of combined items to be used as input for the agent. Can be sync or async.
"""


SessionInputHandler = Union[SessionMixerCallable, None]
"""Defines how to handle session history when new input is provided.

- `None` (default): The new input is appended to the session history.
- `SessionMixerCallable`: A custom function that receives the history and new input, and
returns the desired combined list of items.
"""
52 changes: 38 additions & 14 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .lifecycle import RunHooks
from .logger import logger
from .memory import Session
from .memory import Session, SessionInputHandler
from .model_settings import ModelSettings
from .models.interface import Model, ModelProvider
from .models.multi_provider import MultiProvider
Expand Down Expand Up @@ -139,6 +139,14 @@ class RunConfig:
An optional dictionary of additional metadata to include with the trace.
"""

session_input_callback: SessionInputHandler = None
"""Defines how to handle session history when new input is provided.

- `None` (default): The new input is appended to the session history.
- `SessionMixerCallable`: A custom function that receives the history and new input, and
returns the desired combined list of items.
"""


class RunOptions(TypedDict, Generic[TContext]):
"""Arguments for ``AgentRunner`` methods."""
Expand Down Expand Up @@ -343,7 +351,9 @@ async def run(
run_config = RunConfig()

# Prepare input with session if enabled
prepared_input = await self._prepare_input_with_session(input, session)
prepared_input = await self._prepare_input_with_session(
input, session, run_config.session_input_callback
)

tool_use_tracker = AgentToolUseTracker()

Expand Down Expand Up @@ -662,7 +672,9 @@ async def _start_streaming(

try:
# Prepare input with session if enabled
prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session)
prepared_input = await AgentRunner._prepare_input_with_session(
starting_input, session, run_config.session_input_callback
)

# Update the streamed result with the prepared input
streamed_result.input = prepared_input
Expand Down Expand Up @@ -1191,18 +1203,18 @@ async def _prepare_input_with_session(
cls,
input: str | list[TResponseInputItem],
session: Session | None,
session_input_callback: SessionInputHandler,
) -> str | list[TResponseInputItem]:
"""Prepare input by combining it with session history if enabled."""
if session is None:
return input

# Validate that we don't have both a session and a list input, as this creates
# ambiguity about whether the list should append to or replace existing session history
if isinstance(input, list):
# If the user doesn't explicitly specify a mode, raise an error
if isinstance(input, list) and not session_input_callback:
raise UserError(
"Cannot provide both a session and a list of input items. "
"When using session memory, provide only a string input to append to the "
"conversation, or use session=None and provide a list to manually manage "
"You must specify the `session_input_callback` in the `RunConfig`. "
"Otherwise, when using session memory, provide only a string input to append to "
"the conversation, or use session=None and provide a list to manually manage "
"conversation history."
)

Expand All @@ -1212,10 +1224,18 @@ async def _prepare_input_with_session(
# Convert input to list format
new_input_list = ItemHelpers.input_to_new_input_list(input)

# Combine history with new input
combined_input = history + new_input_list

return combined_input
if session_input_callback is None:
return history + new_input_list
elif callable(session_input_callback):
res = session_input_callback(history, new_input_list)
if inspect.isawaitable(res):
return await res
return res
else:
raise UserError(
f"Invalid `session_input_callback` value: {session_input_callback}. "
"Choose between `None` or a custom callable function."
)

@classmethod
async def _save_result_to_session(
Expand All @@ -1224,7 +1244,11 @@ async def _save_result_to_session(
original_input: str | list[TResponseInputItem],
result: RunResult,
) -> None:
"""Save the conversation turn to session."""
"""
Save the conversation turn to session.
It does not account for any filtering or modification performed by
`RunConfig.session_input_handling`.
"""
if session is None:
return

Expand Down
52 changes: 50 additions & 2 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from agents import Agent, Runner, SQLiteSession, TResponseInputItem
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
from agents.exceptions import UserError

from .fake_model import FakeModel
Expand Down Expand Up @@ -394,7 +394,55 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
await run_agent_async(runner_method, agent, list_input, session=session)

# Verify the error message explains the issue
assert "Cannot provide both a session and a list of input items" in str(exc_info.value)
assert "You must specify the `session_input_callback` in" in str(exc_info.value)
assert "manually manage conversation history" in str(exc_info.value)

session.close()


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
@pytest.mark.asyncio
async def test_session_callback_prepared_input(runner_method):
"""Test if the user passes a list of items and want to append them."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "test_memory.db"

model = FakeModel()
agent = Agent(name="test", model=model)

# Session
session_id = "session_1"
session = SQLiteSession(session_id, db_path)

# Add first messages manually
initial_history: list[TResponseInputItem] = [
{"role": "user", "content": "Hello there."},
{"role": "assistant", "content": "Hi, I'm here to assist you."},
]
await session.add_items(initial_history)

def filter_assistant_messages(history, new_input):
# Only include user messages from history
return [item for item in history if item["role"] == "user"] + new_input

new_turn_input = [{"role": "user", "content": "What your name?"}]
model.set_next_output([get_text_message("I'm gpt-4o")])

# Run the agent with the callable
await run_agent_async(
runner_method,
agent,
new_turn_input,
session=session,
run_config=RunConfig(session_input_callback=filter_assistant_messages),
)

expected_model_input = [
initial_history[0], # From history
new_turn_input[0], # New input
]

assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"] == expected_model_input

session.close()