diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py
index 7106c13a56..a1d60df1a2 100644
--- a/src/crawlee/crawlers/_basic/_basic_crawler.py
+++ b/src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -17,7 +17,6 @@
 from io import StringIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, cast
-from urllib.parse import ParseResult, urlparse
 from weakref import WeakKeyDictionary
 
 from cachetools import LRUCache
@@ -204,6 +203,12 @@ class _BasicCrawlerOptions(TypedDict):
     """If set to `True`, the crawler will automatically try to fetch the robots.txt file for each domain,
     and skip those that are not allowed. This also prevents disallowed URLs to be added via `EnqueueLinksFunction`."""
 
+    send_request_enqueue_strategy: NotRequired[EnqueueStrategy]
+    """Strategy applied by `BasicCrawlingContext.send_request` to validate the target URL against the current
+    request's URL. Defaults to `'all'`, which preserves the historical behaviour of allowing arbitrary URLs. Set to
+    e.g. `'same-hostname'` to restrict `send_request` to the same host as the current request, for example when
+    handlers pass URLs extracted from page content to `send_request` directly."""
+
     status_message_logging_interval: NotRequired[timedelta]
     """Interval for logging the crawler status messages."""
 
@@ -299,6 +304,7 @@ def __init__(
         configure_logging: bool = True,
         statistics_log_format: Literal['table', 'inline'] = 'table',
         respect_robots_txt_file: bool = False,
+        send_request_enqueue_strategy: EnqueueStrategy = 'all',
         status_message_logging_interval: timedelta = timedelta(seconds=10),
         status_message_callback: Callable[[StatisticsState, StatisticsState | None, str], Awaitable[str | None]]
         | None = None,
@@ -352,6 +358,11 @@ def __init__(
             respect_robots_txt_file: If set to `True`, the crawler will automatically try to fetch the robots.txt
                 file for each domain, and skip those that are not allowed. This also prevents disallowed URLs
                 to be added via `EnqueueLinksFunction`
+            send_request_enqueue_strategy: Strategy applied by `BasicCrawlingContext.send_request` to validate the
+                target URL against the current request's URL (`loaded_url` if available, otherwise `url`). Defaults
+                to `'all'`, preserving the historical behaviour of allowing arbitrary URLs. Set to e.g.
+                `'same-hostname'` when handlers extract URLs from page content and pass them straight to
+                `send_request`, to keep cross-host fetches intentional.
             status_message_logging_interval: Interval for logging the crawler status messages.
             status_message_callback: Allows overriding the default status message. The default status message is
                 provided in the parameters. Returning `None` suppresses the status message.
@@ -432,6 +443,7 @@ def __init__(
         self._max_session_rotations = max_session_rotations
         self._max_crawl_depth = max_crawl_depth
         self._respect_robots_txt_file = respect_robots_txt_file
+        self._send_request_enqueue_strategy = send_request_enqueue_strategy
 
         # Timeouts
         self._request_handler_timeout = request_handler_timeout
@@ -981,8 +993,8 @@ async def _check_url_after_redirects(self, context: TCrawlingContext) -> AsyncGe
         """
         if context.request.loaded_url is not None and not self._check_enqueue_strategy(
             context.request.enqueue_strategy,
-            origin_url=urlparse(context.request.url),
-            target_url=urlparse(context.request.loaded_url),
+            origin_url=URL(context.request.url),
+            target_url=URL(context.request.loaded_url),
         ):
             raise ContextPipelineInterruptedError(
                 f'Skipping URL {context.request.loaded_url} (redirected from {context.request.url})'
@@ -1053,10 +1065,10 @@ def _enqueue_links_filter_iterator(
     ) -> Iterator[TRequestIterator]:
         """Filter requests based on the enqueue strategy and URL patterns."""
         limit = kwargs.get('limit')
-        parsed_origin_url = urlparse(origin_url)
+        parsed_origin_url = URL(origin_url)
         strategy = kwargs.get('strategy', 'all')
 
-        if strategy == 'all' and not parsed_origin_url.hostname:
+        if strategy == 'all' and not parsed_origin_url.host:
             self.log.warning(f'Skipping enqueue: Missing hostname in origin_url = {origin_url}.')
             return
 
@@ -1070,9 +1082,9 @@ def _enqueue_links_filter_iterator(
                 target_url = request.url
             else:
                 target_url = request
-            parsed_target_url = urlparse(target_url)
+            parsed_target_url = URL(target_url)
 
-            if warning_flag and strategy != 'all' and not parsed_target_url.hostname:
+            if warning_flag and strategy != 'all' and not parsed_target_url.host:
                 self.log.warning(f'Skipping enqueue url: Missing hostname in target_url = {target_url}.')
                 warning_flag = False
 
@@ -1090,31 +1102,30 @@ def _check_enqueue_strategy(
         self,
         strategy: EnqueueStrategy,
         *,
-        target_url: ParseResult,
-        origin_url: ParseResult,
+        target_url: URL,
+        origin_url: URL,
     ) -> bool:
         """Check if a URL matches the enqueue_strategy."""
         if strategy == 'all':
             return True
 
-        if origin_url.hostname is None or target_url.hostname is None:
+        if origin_url.host is None or target_url.host is None:
             self.log.debug(
-                f'Skipping enqueue: Missing hostname in origin_url = {origin_url.geturl()} or '
-                f'target_url = {target_url.geturl()}'
+                f'Skipping enqueue: Missing hostname in origin_url = {origin_url} or target_url = {target_url}'
             )
             return False
 
         if strategy == 'same-hostname':
-            return target_url.hostname == origin_url.hostname
+            return target_url.host == origin_url.host
 
         if strategy == 'same-domain':
-            origin_domain = self._tld_extractor.extract_str(origin_url.hostname).top_domain_under_public_suffix
-            target_domain = self._tld_extractor.extract_str(target_url.hostname).top_domain_under_public_suffix
+            origin_domain = self._tld_extractor.extract_str(origin_url.host).top_domain_under_public_suffix
+            target_domain = self._tld_extractor.extract_str(target_url.host).top_domain_under_public_suffix
             return origin_domain == target_domain
 
         if strategy == 'same-origin':
             return (
-                target_url.hostname == origin_url.hostname
+                target_url.host == origin_url.host
                 and target_url.scheme == origin_url.scheme
                 and target_url.port == origin_url.port
             )
@@ -1275,7 +1286,12 @@ def _prepare_send_request_function(
         self,
         session: Session | None,
         proxy_info: ProxyInfo | None,
+        request: Request,
     ) -> SendRequestFunction:
+        strategy = self._send_request_enqueue_strategy
+        origin_url = request.loaded_url or request.url
+        origin_parsed = URL(origin_url) if strategy != 'all' else None
+
         async def send_request(
             url: str,
             *,
@@ -1283,6 +1299,15 @@
             payload: HttpPayload | None = None,
             headers: HttpHeaders | dict[str, str] | None = None,
         ) -> HttpResponse:
+            if origin_parsed is not None and not self._check_enqueue_strategy(
+                strategy,
+                target_url=URL(url),
+                origin_url=origin_parsed,
+            ):
+                raise ValueError(
+                    f'send_request() refusing to fetch {url!r}: does not match enqueue strategy '
+                    f'{strategy!r} relative to {origin_url!r}.'
+                )
             return await self._http_client.send_request(
                 url=url,
                 method=method,
@@ -1428,7 +1453,7 @@ async def __run_task_function(self) -> None:
                     request=result.request,
                     session=session,
                     proxy_info=proxy_info,
-                    send_request=self._prepare_send_request_function(session, proxy_info),
+                    send_request=self._prepare_send_request_function(session, proxy_info, request),
                     add_requests=result.add_requests,
                     push_data=result.push_data,
                     get_key_value_store=result.get_key_value_store,
diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py
index f358ac4a9d..557712484d 100644
--- a/tests/unit/crawlers/_basic/test_basic_crawler.py
+++ b/tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -19,7 +19,7 @@
 import pytest
 
-from crawlee import ConcurrencySettings, Glob, service_locator
+from crawlee import ConcurrencySettings, EnqueueStrategy, Glob, service_locator
 from crawlee._request import Request, RequestState
 from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, HttpMethod
 from crawlee._utils.robots import RobotsTxtFile
 
@@ -340,6 +340,43 @@ async def failed_request_handler(context: BasicCrawlingContext, error: Exception
     await crawler.run(['https://a.placeholder.com', 'https://b.placeholder.com', 'https://c.placeholder.com'])
 
 
+@pytest.mark.parametrize(
+    ('strategy', 'target_path', 'should_succeed'),
+    [
+        pytest.param('all', 'get', True, id='default-all-allows-same-host'),
+        pytest.param('same-hostname', 'get', True, id='same-hostname-allows-same-host'),
+        pytest.param('same-hostname', 'http://other.test/payload', False, id='same-hostname-rejects-cross-host'),
+    ],
+)
+async def test_send_request_enqueue_strategy(
+    server_url: URL, strategy: EnqueueStrategy, target_path: str, *, should_succeed: bool
+) -> None:
+    bodies: list[bytes] = []
+    errors: list[Exception] = []
+
+    crawler = BasicCrawler(max_request_retries=1, send_request_enqueue_strategy=strategy)
+    target_url = target_path if target_path.startswith('http') else str(server_url / target_path)
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        try:
+            response = await context.send_request(target_url)
+        except ValueError as exc:
+            errors.append(exc)
+        else:
+            bodies.append(await response.read())
+
+    await crawler.run([str(server_url / 'a/page')])
+
+    if should_succeed:
+        assert bodies, 'expected the handler to receive a response'
+        assert not errors
+    else:
+        assert errors, 'expected send_request to refuse the target URL'
+        assert strategy in str(errors[0])
+        assert target_url in str(errors[0])
+
+
 @pytest.mark.parametrize(
     ('method', 'path', 'payload'),
     [
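
A minimal usage sketch of the new option (not part of the diff). The handler body and URLs are hypothetical; the import paths mirror the ones used in the test file above, and `send_request_enqueue_strategy`, `context.send_request`, and `response.read()` are the APIs exercised by the new test:

import asyncio

from crawlee._types import BasicCrawlingContext
from crawlee.crawlers import BasicCrawler


async def main() -> None:
    # Restrict in-handler fetches to the host of the request being handled.
    crawler = BasicCrawler(send_request_enqueue_strategy='same-hostname')

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        # Same-host fetch: passes the 'same-hostname' check.
        response = await context.send_request(context.request.url)
        context.log.info(f'fetched {len(await response.read())} bytes')

        # Cross-host fetch: rejected before any network call is made.
        try:
            await context.send_request('https://tracker.example/pixel')
        except ValueError as exc:
            context.log.warning(f'refused: {exc}')

    await crawler.run(['https://example.com/'])


if __name__ == '__main__':
    asyncio.run(main())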
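
A small illustration (not part of the diff) of one behavioural difference the `urlparse` to `yarl.URL` switch brings to the `same-origin` branch: `ParseResult.port` is `None` unless the port is written out, so the old check treated `http://a.example` and `http://a.example:80` as different origins, while `yarl.URL.port` substitutes the scheme's default port, so they now compare equal:

from urllib.parse import urlparse

from yarl import URL

old_a, old_b = urlparse('http://a.example'), urlparse('http://a.example:80')
assert old_a.port is None
assert old_b.port == 80  # None != 80: rejected by the old same-origin check

new_a, new_b = URL('http://a.example'), URL('http://a.example:80')
assert new_a.port == new_b.port == 80  # equal: accepted after this change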