59 changes: 42 additions & 17 deletions src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -17,7 +17,6 @@
 from io import StringIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
-from urllib.parse import ParseResult, urlparse
 from weakref import WeakKeyDictionary
 
 from cachetools import LRUCache
@@ -204,6 +203,12 @@ class _BasicCrawlerOptions(TypedDict):
     """If set to `True`, the crawler will automatically try to fetch the robots.txt file for each domain,
     and skip those that are not allowed. This also prevents disallowed URLs from being added via `EnqueueLinksFunction`."""
 
+    send_request_enqueue_strategy: NotRequired[EnqueueStrategy]
+    """Strategy applied by `BasicCrawlingContext.send_request` to validate the target URL against the current
+    request's URL. Defaults to `'all'`, which preserves the historical behaviour of allowing arbitrary URLs. Set it
+    to e.g. `'same-hostname'` to restrict `send_request` to the same host as the current request, for example when
+    handlers pass URLs extracted from page content directly to `send_request`."""
+
     status_message_logging_interval: NotRequired[timedelta]
     """Interval for logging the crawler status messages."""
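For reviewers who want to see the new option end to end, here is a minimal usage sketch based on this diff (the example.com start URL is a placeholder):

```python
import asyncio

from crawlee.crawlers import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    # Restrict context.send_request to the host of the request being handled.
    crawler = BasicCrawler(send_request_enqueue_strategy='same-hostname')

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        # Same host as the current request, so the strategy check passes.
        response = await context.send_request(context.request.url)
        context.log.info(f'Fetched {len(await response.read())} bytes')

    await crawler.run(['https://example.com/'])


if __name__ == '__main__':
    asyncio.run(main())
```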
Expand Down Expand Up @@ -299,6 +304,7 @@ def __init__(
configure_logging: bool = True,
statistics_log_format: Literal['table', 'inline'] = 'table',
respect_robots_txt_file: bool = False,
send_request_enqueue_strategy: EnqueueStrategy = 'all',
status_message_logging_interval: timedelta = timedelta(seconds=10),
status_message_callback: Callable[[StatisticsState, StatisticsState | None, str], Awaitable[str | None]]
| None = None,
@@ -352,6 +358,11 @@ def __init__(
             respect_robots_txt_file: If set to `True`, the crawler will automatically try to fetch the robots.txt
                 file for each domain, and skip those that are not allowed. This also prevents disallowed URLs
                 from being added via `EnqueueLinksFunction`.
+            send_request_enqueue_strategy: Strategy applied by `BasicCrawlingContext.send_request` to validate the
+                target URL against the current request's URL (`loaded_url` if available, otherwise `url`). Defaults
+                to `'all'`, preserving the historical behaviour of allowing arbitrary URLs. Set it to e.g.
+                `'same-hostname'` when handlers extract URLs from page content and pass them straight to
+                `send_request`, to keep cross-host fetches intentional.
             status_message_logging_interval: Interval for logging the crawler status messages.
             status_message_callback: Allows overriding the default status message. The default status message is
                 provided in the parameters. Returning `None` suppresses the status message.
Expand Down Expand Up @@ -432,6 +443,7 @@ def __init__(
self._max_session_rotations = max_session_rotations
self._max_crawl_depth = max_crawl_depth
self._respect_robots_txt_file = respect_robots_txt_file
self._send_request_enqueue_strategy = send_request_enqueue_strategy

# Timeouts
self._request_handler_timeout = request_handler_timeout
@@ -981,8 +993,8 @@ async def _check_url_after_redirects(self, context: TCrawlingContext) -> AsyncGenerator[TCrawlingContext, None]:
         """
         if context.request.loaded_url is not None and not self._check_enqueue_strategy(
             context.request.enqueue_strategy,
-            origin_url=urlparse(context.request.url),
-            target_url=urlparse(context.request.loaded_url),
+            origin_url=URL(context.request.url),
+            target_url=URL(context.request.loaded_url),
         ):
             raise ContextPipelineInterruptedError(
                 f'Skipping URL {context.request.loaded_url} (redirected from {context.request.url})'
@@ -1053,10 +1065,10 @@ def _enqueue_links_filter_iterator(
     ) -> Iterator[TRequestIterator]:
         """Filter requests based on the enqueue strategy and URL patterns."""
         limit = kwargs.get('limit')
-        parsed_origin_url = urlparse(origin_url)
+        parsed_origin_url = URL(origin_url)
         strategy = kwargs.get('strategy', 'all')
 
-        if strategy == 'all' and not parsed_origin_url.hostname:
+        if strategy == 'all' and not parsed_origin_url.host:
             self.log.warning(f'Skipping enqueue: Missing hostname in origin_url = {origin_url}.')
             return
@@ -1070,9 +1082,9 @@ def _enqueue_links_filter_iterator(
                 target_url = request.url
             else:
                 target_url = request
-            parsed_target_url = urlparse(target_url)
+            parsed_target_url = URL(target_url)
 
-            if warning_flag and strategy != 'all' and not parsed_target_url.hostname:
+            if warning_flag and strategy != 'all' and not parsed_target_url.host:
                 self.log.warning(f'Skipping enqueue url: Missing hostname in target_url = {target_url}.')
                 warning_flag = False
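The `urlparse` → `yarl.URL` swap is mostly mechanical (`.hostname` becomes `.host`), but one behavioural difference is worth flagging for the `same-origin` branch below: `yarl` reports the scheme's default port where `urlparse` reports `None`. A quick illustration:

```python
from urllib.parse import urlparse

from yarl import URL

old, new = urlparse('http://example.com/a'), URL('http://example.com/a')
assert old.hostname == new.host == 'example.com'
assert old.port is None  # urlparse only reports an explicitly written port
assert new.port == 80  # yarl falls back to the scheme's default port
```

As a consequence, `http://example.com` and `http://example.com:80` now compare as the same origin, whereas the `urlparse`-based comparison treated them as different (`None != 80`).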
@@ -1090,31 +1102,30 @@ def _check_enqueue_strategy(
         self,
         strategy: EnqueueStrategy,
         *,
-        target_url: ParseResult,
-        origin_url: ParseResult,
+        target_url: URL,
+        origin_url: URL,
     ) -> bool:
         """Check if a URL matches the enqueue_strategy."""
         if strategy == 'all':
             return True
 
-        if origin_url.hostname is None or target_url.hostname is None:
+        if origin_url.host is None or target_url.host is None:
             self.log.debug(
-                f'Skipping enqueue: Missing hostname in origin_url = {origin_url.geturl()} or '
-                f'target_url = {target_url.geturl()}'
+                f'Skipping enqueue: Missing hostname in origin_url = {origin_url} or target_url = {target_url}'
             )
             return False
 
         if strategy == 'same-hostname':
-            return target_url.hostname == origin_url.hostname
+            return target_url.host == origin_url.host
 
         if strategy == 'same-domain':
-            origin_domain = self._tld_extractor.extract_str(origin_url.hostname).top_domain_under_public_suffix
-            target_domain = self._tld_extractor.extract_str(target_url.hostname).top_domain_under_public_suffix
+            origin_domain = self._tld_extractor.extract_str(origin_url.host).top_domain_under_public_suffix
+            target_domain = self._tld_extractor.extract_str(target_url.host).top_domain_under_public_suffix
             return origin_domain == target_domain
 
         if strategy == 'same-origin':
             return (
-                target_url.hostname == origin_url.hostname
+                target_url.host == origin_url.host
                 and target_url.scheme == origin_url.scheme
                 and target_url.port == origin_url.port
             )
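The strategy semantics above can be condensed into a standalone sketch. This is illustrative only, not part of the crawler's API; it builds its own `TLDExtract` instance (which may download the public-suffix list on first use) and reuses the `top_domain_under_public_suffix` attribute seen in the diff:

```python
from tldextract import TLDExtract
from yarl import URL

_tld = TLDExtract()


def matches_strategy(strategy: str, origin: str, target: str) -> bool:
    o, t = URL(origin), URL(target)
    if strategy == 'all':
        return True
    if o.host is None or t.host is None:
        return False  # host-less (e.g. relative) URLs can never match
    if strategy == 'same-hostname':
        return t.host == o.host
    if strategy == 'same-domain':
        # Registrable-domain comparison: a.example.com matches b.example.com.
        origin_domain = _tld.extract_str(o.host).top_domain_under_public_suffix
        target_domain = _tld.extract_str(t.host).top_domain_under_public_suffix
        return origin_domain == target_domain
    if strategy == 'same-origin':
        return (t.scheme, t.host, t.port) == (o.scheme, o.host, o.port)
    raise ValueError(f'Unknown strategy: {strategy!r}')


assert matches_strategy('same-domain', 'https://a.example.com/', 'https://b.example.com/')
assert not matches_strategy('same-hostname', 'https://a.example.com/', 'https://b.example.com/')
```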
@@ -1275,14 +1286,28 @@ def _prepare_send_request_function(
         self,
         session: Session | None,
         proxy_info: ProxyInfo | None,
+        request: Request,
     ) -> SendRequestFunction:
+        strategy = self._send_request_enqueue_strategy
+        origin_url = request.loaded_url or request.url
+        origin_parsed = URL(origin_url) if strategy != 'all' else None
+
         async def send_request(
             url: str,
             *,
             method: HttpMethod = 'GET',
             payload: HttpPayload | None = None,
             headers: HttpHeaders | dict[str, str] | None = None,
         ) -> HttpResponse:
+            if origin_parsed is not None and not self._check_enqueue_strategy(
+                strategy,
+                target_url=URL(url),
+                origin_url=origin_parsed,
+            ):
+                raise ValueError(
+                    f'send_request() refusing to fetch {url!r}: does not match enqueue strategy '
+                    f'{strategy!r} relative to {origin_url!r}.'
+                )
             return await self._http_client.send_request(
                 url=url,
                 method=method,
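On the refusal path, `send_request` raises a plain `ValueError` before any network I/O happens, so handlers probing optional cross-host resources can treat the rejection as a soft failure. A sketch, with a hypothetical CDN URL:

```python
from crawlee.crawlers import BasicCrawler, BasicCrawlingContext

crawler = BasicCrawler(send_request_enqueue_strategy='same-hostname')


@crawler.router.default_handler
async def handler(context: BasicCrawlingContext) -> None:
    try:
        # Hypothetical cross-host asset; rejected by the strategy check above.
        response = await context.send_request('https://cdn.example.net/manifest.json')
    except ValueError:
        context.log.info('Cross-host fetch blocked by send_request_enqueue_strategy')
    else:
        await context.push_data({'manifest_bytes': len(await response.read())})
```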
@@ -1428,7 +1453,7 @@ async def __run_task_function(self) -> None:
                 request=result.request,
                 session=session,
                 proxy_info=proxy_info,
-                send_request=self._prepare_send_request_function(session, proxy_info),
+                send_request=self._prepare_send_request_function(session, proxy_info, request),
                 add_requests=result.add_requests,
                 push_data=result.push_data,
                 get_key_value_store=result.get_key_value_store,
39 changes: 38 additions & 1 deletion tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -19,7 +19,7 @@
 
 import pytest
 
-from crawlee import ConcurrencySettings, Glob, service_locator
+from crawlee import ConcurrencySettings, EnqueueStrategy, Glob, service_locator
 from crawlee._request import Request, RequestState
 from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, HttpMethod
 from crawlee._utils.robots import RobotsTxtFile
@@ -340,6 +340,43 @@ async def failed_request_handler(context: BasicCrawlingContext, error: Exception) -> None:
     await crawler.run(['https://a.placeholder.com', 'https://b.placeholder.com', 'https://c.placeholder.com'])
 
 
+@pytest.mark.parametrize(
+    ('strategy', 'target_path', 'should_succeed'),
+    [
+        pytest.param('all', 'get', True, id='default-all-allows-same-host'),
+        pytest.param('same-hostname', 'get', True, id='same-hostname-allows-same-host'),
+        pytest.param('same-hostname', 'http://other.test/payload', False, id='same-hostname-rejects-cross-host'),
+    ],
+)
+async def test_send_request_enqueue_strategy(
+    server_url: URL, strategy: EnqueueStrategy, target_path: str, *, should_succeed: bool
+) -> None:
+    bodies: list[bytes] = []
+    errors: list[Exception] = []
+
+    crawler = BasicCrawler(max_request_retries=1, send_request_enqueue_strategy=strategy)
+    target_url = target_path if target_path.startswith('http') else str(server_url / target_path)
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        try:
+            response = await context.send_request(target_url)
+        except ValueError as exc:
+            errors.append(exc)
+        else:
+            bodies.append(await response.read())
+
+    await crawler.run([str(server_url / 'a/page')])
+
+    if should_succeed:
+        assert bodies, 'expected the handler to receive a response'
+        assert not errors
+    else:
+        assert errors, 'expected send_request to refuse the target URL'
+        assert strategy in str(errors[0])
+        assert target_url in str(errors[0])
+
+
 @pytest.mark.parametrize(
     ('method', 'path', 'payload'),
     [