def _rmtree(path, image=None, is_studio=False):
    """Remove a directory tree, handling root-owned files from Docker containers.

    A plain ``shutil.rmtree`` is attempted first. If it fails with
    ``PermissionError`` and an ``image`` is available, a throwaway container
    is started from that image to ``chmod -R 777`` the tree, after which the
    removal is retried once.

    Args:
        path: Directory to remove.
        image: Docker image used to run the ``chmod`` fallback. When ``None``
            no fallback is attempted and the ``PermissionError`` is re-raised.
        is_studio: When true, ``--network sagemaker`` is added to the
            ``docker run`` command line.

    Raises:
        PermissionError: The tree could not be removed and ``image`` is None.
        Exception: Whatever the docker fallback or the retried removal raised
            (e.g. ``subprocess.CalledProcessError``); it is logged and
            re-raised unchanged.
    """
    # Single definition of the user-facing hint so both failure paths agree.
    cleanup_failed_msg = (
        "Failed to clean up root-owned files in %s. "
        "You may need to remove them manually with: sudo rm -rf %s"
    )
    try:
        shutil.rmtree(path)
    except PermissionError:
        # Files created by Docker containers are owned by root.
        # Use docker to chmod as root, then retry shutil.rmtree.
        if image is None:
            logger.warning(cleanup_failed_msg, path, path)
            raise
        try:
            # docker interprets a non-absolute ``-v`` source as a volume name
            # rather than a host directory, so always mount the absolute path.
            host_path = os.path.abspath(path)
            cmd = ["docker", "run", "--rm"]
            if is_studio:
                cmd += ["--network", "sagemaker"]
            # NOTE(review): if ``image`` declares an ENTRYPOINT, the chmod
            # arguments are presumably handed to it instead of being executed
            # directly -- confirm for custom images.
            cmd += ["-v", f"{host_path}:/delete", image, "chmod", "-R", "777", "/delete"]
            subprocess.run(cmd, check=True, capture_output=True)
            shutil.rmtree(path)
        except Exception as exc:
            logger.warning(cleanup_failed_msg, path, path)
            # capture_output=True hides docker's stderr; surface it so the
            # failure can be diagnosed from the logs.
            stderr = getattr(exc, "stderr", None)
            if isinstance(stderr, bytes):
                stderr = stderr.decode("utf-8", errors="replace")
            if stderr:
                logger.warning("docker chmod fallback output: %s", stderr.strip())
            raise
from unittest.mock import patch, call
import pytest

from sagemaker.train.local.local_container import _rmtree

IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1-cpu-py310"


class TestRmtree:
    """Unit tests for the ``_rmtree`` cleanup helper."""

    # Patch targets shared by every test in this class.
    _RMTREE = "sagemaker.train.local.local_container.shutil.rmtree"
    _RUN = "sagemaker.train.local.local_container.subprocess.run"

    def test_rmtree_success(self):
        """A removable tree needs exactly one ``shutil.rmtree`` call."""
        with patch(self._RMTREE) as rmtree_mock:
            _rmtree("/tmp/test", IMAGE)

        rmtree_mock.assert_called_once_with("/tmp/test")

    def test_rmtree_permission_error_docker_chmod_fallback(self):
        """A PermissionError leads to a docker chmod and one retry."""
        with patch(self._RMTREE) as rmtree_mock, patch(self._RUN) as run_mock:
            rmtree_mock.side_effect = [PermissionError("Permission denied"), None]
            _rmtree("/tmp/test", IMAGE)

        expected_cmd = [
            "docker", "run", "--rm",
            "-v", "/tmp/test:/delete", IMAGE,
            "chmod", "-R", "777", "/delete",
        ]
        run_mock.assert_called_once_with(expected_cmd, check=True, capture_output=True)
        assert rmtree_mock.call_count == 2

    def test_rmtree_studio_adds_network(self):
        """With ``is_studio=True`` the command carries ``--network sagemaker``."""
        with patch(self._RMTREE) as rmtree_mock, patch(self._RUN) as run_mock:
            rmtree_mock.side_effect = [PermissionError("Permission denied"), None]
            _rmtree("/tmp/test", IMAGE, is_studio=True)

        expected_cmd = [
            "docker", "run", "--rm",
            "--network", "sagemaker",
            "-v", "/tmp/test:/delete", IMAGE,
            "chmod", "-R", "777", "/delete",
        ]
        run_mock.assert_called_once_with(expected_cmd, check=True, capture_output=True)

    def test_rmtree_docker_fallback_fails_raises(self):
        """An error raised by the docker fallback reaches the caller."""
        with patch(self._RMTREE) as rmtree_mock, patch(self._RUN) as run_mock:
            rmtree_mock.side_effect = PermissionError("Permission denied")
            run_mock.side_effect = Exception("docker failed")

            with pytest.raises(Exception, match="docker failed"):
                _rmtree("/tmp/test", IMAGE)

    def test_rmtree_no_image_raises(self):
        """Without an image the PermissionError is raised as-is."""
        with patch(self._RMTREE) as rmtree_mock:
            rmtree_mock.side_effect = PermissionError("Permission denied")

            with pytest.raises(PermissionError):
                _rmtree("/tmp/test")