Skip to content

Commit 426df10

Browse files
authored
Fix parsing schema for extract with no arguments (#158)
1 parent cd3dc7f commit 426df10

File tree

6 files changed

+49
-9
lines changed

6 files changed

+49
-9
lines changed

.changeset/daffy-rapid-turaco.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"stagehand": patch
3+
---
4+
5+
Fix parsing schema for extract with no arguments (full page extract)

format

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
3+
# Define source directories (adjust as needed)
4+
SOURCE_DIRS="stagehand"
5+
6+
# Apply Black formatting first
7+
echo "Applying Black formatting..."
8+
black $SOURCE_DIRS
9+
10+
# Apply Ruff with autofix for all issues (including import sorting)
11+
echo "Applying Ruff autofixes (including import sorting)..."
12+
ruff check --fix $SOURCE_DIRS
13+
14+
echo "Checking for remaining issues..."
15+
ruff check $SOURCE_DIRS
16+
17+
echo "Done! Code has been formatted and linted."

stagehand/handlers/extract_handler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from stagehand.a11y.utils import get_accessibility_tree
88
from stagehand.llm.inference import extract as extract_inference
99
from stagehand.metrics import StagehandFunctionName # Changed import location
10-
from stagehand.types import DefaultExtractSchema, ExtractOptions, ExtractResult
10+
from stagehand.types import (
11+
DefaultExtractSchema,
12+
EmptyExtractSchema,
13+
ExtractOptions,
14+
ExtractResult,
15+
)
1116
from stagehand.utils import inject_urls, transform_url_strings_to_ids
1217

1318
T = TypeVar("T", bound=BaseModel)
@@ -166,4 +171,6 @@ async def _extract_page_text(self) -> ExtractResult:
166171

167172
tree = await get_accessibility_tree(self.stagehand_page, self.logger)
168173
output_string = tree["simplified"]
169-
return ExtractResult(data=output_string)
174+
output_dict = {"page_text": output_string}
175+
validated_model = EmptyExtractSchema.model_validate(output_dict)
176+
return ExtractResult(data=validated_model).data

stagehand/page.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ObserveOptions,
1717
ObserveResult,
1818
)
19-
from .types import DefaultExtractSchema
19+
from .types import DefaultExtractSchema, EmptyExtractSchema
2020

2121
_INJECTION_SCRIPT = None
2222

@@ -361,12 +361,17 @@ async def extract(
361361
processed_data_payload = result_dict
362362
if schema_to_validate_with and isinstance(processed_data_payload, dict):
363363
try:
364-
validated_model = schema_to_validate_with.model_validate(
365-
processed_data_payload
366-
)
367-
processed_data_payload = (
368-
validated_model # Payload is now the Pydantic model instance
369-
)
364+
# For extract with no params
365+
if not options_obj:
366+
validated_model = EmptyExtractSchema.model_validate(
367+
processed_data_payload
368+
)
369+
processed_data_payload = validated_model
370+
else:
371+
validated_model = schema_to_validate_with.model_validate(
372+
processed_data_payload
373+
)
374+
processed_data_payload = validated_model
370375
except Exception as e:
371376
self._stagehand.logger.error(
372377
f"Failed to validate extracted data against schema {schema_to_validate_with.__name__}: {e}. Keeping raw data dict in .data field."

stagehand/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ActOptions,
2424
ActResult,
2525
DefaultExtractSchema,
26+
EmptyExtractSchema,
2627
ExtractOptions,
2728
ExtractResult,
2829
MetadataSchema,
@@ -56,4 +57,5 @@
5657
"AgentConfig",
5758
"AgentExecuteOptions",
5859
"AgentResult",
60+
"EmptyExtractSchema",
5961
]

stagehand/types/page.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class DefaultExtractSchema(BaseModel):
99
extraction: str
1010

1111

12+
class EmptyExtractSchema(BaseModel):
13+
page_text: str
14+
15+
1216
class ObserveElementSchema(BaseModel):
1317
element_id: int
1418
description: str = Field(

0 commit comments

Comments
 (0)