Changes to filename and path validation (#45529)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Joakim Sørensen 2021-01-26 15:53:21 +01:00 committed by GitHub
parent 4739e8a207
commit b1c2cde40b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 19 deletions

View file

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.const import HTTP_OK from homeassistant.const import HTTP_OK
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util import sanitize_filename from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -70,8 +70,8 @@ def setup(hass, config):
overwrite = service.data.get(ATTR_OVERWRITE) overwrite = service.data.get(ATTR_OVERWRITE)
if subdir: # Check the path
subdir = sanitize_filename(subdir) raise_if_invalid_path(subdir)
final_path = None final_path = None
@ -101,8 +101,8 @@ def setup(hass, config):
if not filename: if not filename:
filename = "ha_download" filename = "ha_download"
# Remove stuff to ruin paths # Check the filename
filename = sanitize_filename(filename) raise_if_invalid_filename(filename)
# Do we want to download to subdir, create if needed # Do we want to download to subdir, create if needed
if subdir: if subdir:
@ -148,6 +148,16 @@ def setup(hass, config):
{"url": url, "filename": filename}, {"url": url, "filename": filename},
) )
# Remove file if we started downloading but failed
if final_path and os.path.isfile(final_path):
os.remove(final_path)
except ValueError:
_LOGGER.exception("Invalid value")
hass.bus.fire(
f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}",
{"url": url, "filename": filename},
)
# Remove file if we started downloading but failed # Remove file if we started downloading but failed
if final_path and os.path.isfile(final_path): if final_path and os.path.isfile(final_path):
os.remove(final_path) os.remove(final_path)

View file

@ -12,7 +12,6 @@ from homeassistant.helpers import collection, config_validation as cv
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType
from homeassistant.loader import async_get_integration from homeassistant.loader import async_get_integration
from homeassistant.util import sanitize_path
from . import dashboard, resources, websocket from . import dashboard, resources, websocket
from .const import ( from .const import (
@ -47,7 +46,7 @@ YAML_DASHBOARD_SCHEMA = vol.Schema(
{ {
**DASHBOARD_BASE_CREATE_FIELDS, **DASHBOARD_BASE_CREATE_FIELDS,
vol.Required(CONF_MODE): MODE_YAML, vol.Required(CONF_MODE): MODE_YAML,
vol.Required(CONF_FILENAME): vol.All(cv.string, sanitize_path), vol.Required(CONF_FILENAME): cv.path,
} }
) )

View file

@ -10,7 +10,7 @@ from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
from homeassistant.components.media_player.errors import BrowseError from homeassistant.components.media_player.errors import BrowseError
from homeassistant.components.media_source.error import Unresolvable from homeassistant.components.media_source.error import Unresolvable
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import sanitize_path from homeassistant.util import raise_if_invalid_filename
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
@ -50,8 +50,10 @@ class LocalSource(MediaSource):
if source_dir_id not in self.hass.config.media_dirs: if source_dir_id not in self.hass.config.media_dirs:
raise Unresolvable("Unknown source directory.") raise Unresolvable("Unknown source directory.")
if location != sanitize_path(location): try:
raise Unresolvable("Invalid path.") raise_if_invalid_filename(location)
except ValueError as err:
raise Unresolvable("Invalid path.") from err
return source_dir_id, location return source_dir_id, location
@ -189,8 +191,10 @@ class LocalMediaView(HomeAssistantView):
self, request: web.Request, source_dir_id: str, location: str self, request: web.Request, source_dir_id: str, location: str
) -> web.FileResponse: ) -> web.FileResponse:
"""Start a GET request.""" """Start a GET request."""
if location != sanitize_path(location): try:
raise web.HTTPNotFound() raise_if_invalid_filename(location)
except ValueError as err:
raise web.HTTPBadRequest() from err
if source_dir_id not in self.hass.config.media_dirs: if source_dir_id not in self.hass.config.media_dirs:
raise web.HTTPNotFound() raise web.HTTPNotFound()

View file

@ -23,7 +23,7 @@ from homeassistant.const import SERVICE_RELOAD
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import sanitize_filename from homeassistant.util import raise_if_invalid_filename
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.yaml.loader import load_yaml from homeassistant.util.yaml.loader import load_yaml
@ -135,7 +135,8 @@ def discover_scripts(hass):
def execute_script(hass, name, data=None): def execute_script(hass, name, data=None):
"""Execute a script.""" """Execute a script."""
filename = f"{name}.py" filename = f"{name}.py"
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil: raise_if_invalid_filename(filename)
with open(hass.config.path(FOLDER, filename)) as fil:
source = fil.read() source = fil.read()
execute(hass, filename, source, data) execute(hass, filename, source, data)

View file

@ -87,7 +87,7 @@ from homeassistant.helpers import (
template as template_helper, template as template_helper,
) )
from homeassistant.helpers.logging import KeywordStyleAdapter from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import sanitize_path, slugify as util_slugify from homeassistant.util import raise_if_invalid_path, slugify as util_slugify
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
# pylint: disable=invalid-name # pylint: disable=invalid-name
@ -118,8 +118,10 @@ def path(value: Any) -> str:
if not isinstance(value, str): if not isinstance(value, str):
raise vol.Invalid("Expected a string") raise vol.Invalid("Expected a string")
if sanitize_path(value) != value: try:
raise vol.Invalid("Invalid path") raise_if_invalid_path(value)
except ValueError as err:
raise vol.Invalid("Invalid path") from err
return value return value

View file

@ -1,4 +1,5 @@
"""Deprecation helpers for Home Assistant.""" """Deprecation helpers for Home Assistant."""
import functools
import inspect import inspect
import logging import logging
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
@ -73,3 +74,25 @@ def get_deprecated(
) )
return config.get(old_name) return config.get(old_name)
return config.get(new_name, default) return config.get(new_name, default)
def deprecated_function(replacement: str) -> Callable[..., Callable]:
"""Mark function as deprecated and provide a replacement function to be used instead."""
def deprecated_decorator(func: Callable) -> Callable:
"""Decorate function as deprecated."""
@functools.wraps(func)
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any:
"""Wrap for the original function."""
logger = logging.getLogger(func.__module__)
logger.warning(
"%s is a deprecated function. Use %s instead",
func.__name__,
replacement,
)
return func(*args, **kwargs)
return deprecated_func
return deprecated_decorator

View file

@ -22,6 +22,7 @@ from typing import (
import slugify as unicode_slug import slugify as unicode_slug
from ..helpers.deprecation import deprecated_function
from .dt import as_local, utcnow from .dt import as_local, utcnow
T = TypeVar("T") T = TypeVar("T")
@ -32,6 +33,27 @@ RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)") RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
def raise_if_invalid_filename(filename: str) -> None:
"""
Check if a filename is valid.
Raises a ValueError if the filename is invalid.
"""
if RE_SANITIZE_FILENAME.sub("", filename) != filename:
raise ValueError(f"{filename} is not a safe filename")
def raise_if_invalid_path(path: str) -> None:
"""
Check if a path is valid.
Raises a ValueError if the path is invalid.
"""
if RE_SANITIZE_PATH.sub("", path) != path:
raise ValueError(f"{path} is not a safe path")
@deprecated_function(replacement="raise_if_invalid_filename")
def sanitize_filename(filename: str) -> str: def sanitize_filename(filename: str) -> str:
"""Check if a filename is safe. """Check if a filename is safe.
@ -47,6 +69,7 @@ def sanitize_filename(filename: str) -> str:
return filename return filename
@deprecated_function(replacement="raise_if_invalid_path")
def sanitize_path(path: str) -> str: def sanitize_path(path: str) -> str:
"""Check if a path is safe. """Check if a path is safe.

View file

@ -23,7 +23,7 @@ async def test_async_browse_media(hass):
await media_source.async_browse_media( await media_source.async_browse_media(
hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist" hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist"
) )
assert str(excinfo.value) == "Path does not exist." assert str(excinfo.value) == "Invalid path."
# Test browse file # Test browse file
with pytest.raises(media_source.BrowseError) as excinfo: with pytest.raises(media_source.BrowseError) as excinfo:

View file

@ -1,7 +1,11 @@
"""Test deprecation helpers.""" """Test deprecation helpers."""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from homeassistant.helpers.deprecation import deprecated_substitute, get_deprecated from homeassistant.helpers.deprecation import (
deprecated_function,
deprecated_substitute,
get_deprecated,
)
class MockBaseClass: class MockBaseClass:
@ -78,3 +82,17 @@ def test_config_get_deprecated_new(mock_get_logger):
config = {"new_name": True} config = {"new_name": True}
assert get_deprecated(config, "new_name", "old_name") is True assert get_deprecated(config, "new_name", "old_name") is True
assert not mock_logger.warning.called assert not mock_logger.warning.called
def test_deprecated_function(caplog):
"""Test deprecated_function decorator."""
@deprecated_function("new_function")
def mock_deprecated_function():
pass
mock_deprecated_function()
assert (
"mock_deprecated_function is a deprecated function. Use new_function instead"
in caplog.text
)

View file

@ -24,6 +24,34 @@ def test_sanitize_path():
assert util.sanitize_path("~/../test/path") == "" assert util.sanitize_path("~/../test/path") == ""
def test_raise_if_invalid_filename():
"""Test raise_if_invalid_filename."""
assert util.raise_if_invalid_filename("test") is None
with pytest.raises(ValueError):
util.raise_if_invalid_filename("/test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("..test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("\\test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("\\../test")
def test_raise_if_invalid_path():
"""Test raise_if_invalid_path."""
assert util.raise_if_invalid_path("test/path") is None
with pytest.raises(ValueError):
assert util.raise_if_invalid_path("~test/path")
with pytest.raises(ValueError):
assert util.raise_if_invalid_path("~/../test/path")
def test_slugify(): def test_slugify():
"""Test slugify.""" """Test slugify."""
assert util.slugify("T-!@#$!#@$!$est") == "t_est" assert util.slugify("T-!@#$!#@$!$est") == "t_est"