Changes to filename and path validation (#45529)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
4739e8a207
commit
b1c2cde40b
10 changed files with 127 additions and 19 deletions
|
@ -9,7 +9,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import HTTP_OK
|
from homeassistant.const import HTTP_OK
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util import sanitize_filename
|
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -70,8 +70,8 @@ def setup(hass, config):
|
||||||
|
|
||||||
overwrite = service.data.get(ATTR_OVERWRITE)
|
overwrite = service.data.get(ATTR_OVERWRITE)
|
||||||
|
|
||||||
if subdir:
|
# Check the path
|
||||||
subdir = sanitize_filename(subdir)
|
raise_if_invalid_path(subdir)
|
||||||
|
|
||||||
final_path = None
|
final_path = None
|
||||||
|
|
||||||
|
@ -101,8 +101,8 @@ def setup(hass, config):
|
||||||
if not filename:
|
if not filename:
|
||||||
filename = "ha_download"
|
filename = "ha_download"
|
||||||
|
|
||||||
# Remove stuff to ruin paths
|
# Check the filename
|
||||||
filename = sanitize_filename(filename)
|
raise_if_invalid_filename(filename)
|
||||||
|
|
||||||
# Do we want to download to subdir, create if needed
|
# Do we want to download to subdir, create if needed
|
||||||
if subdir:
|
if subdir:
|
||||||
|
@ -148,6 +148,16 @@ def setup(hass, config):
|
||||||
{"url": url, "filename": filename},
|
{"url": url, "filename": filename},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove file if we started downloading but failed
|
||||||
|
if final_path and os.path.isfile(final_path):
|
||||||
|
os.remove(final_path)
|
||||||
|
except ValueError:
|
||||||
|
_LOGGER.exception("Invalid value")
|
||||||
|
hass.bus.fire(
|
||||||
|
f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}",
|
||||||
|
{"url": url, "filename": filename},
|
||||||
|
)
|
||||||
|
|
||||||
# Remove file if we started downloading but failed
|
# Remove file if we started downloading but failed
|
||||||
if final_path and os.path.isfile(final_path):
|
if final_path and os.path.isfile(final_path):
|
||||||
os.remove(final_path)
|
os.remove(final_path)
|
||||||
|
|
|
@ -12,7 +12,6 @@ from homeassistant.helpers import collection, config_validation as cv
|
||||||
from homeassistant.helpers.service import async_register_admin_service
|
from homeassistant.helpers.service import async_register_admin_service
|
||||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType
|
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType
|
||||||
from homeassistant.loader import async_get_integration
|
from homeassistant.loader import async_get_integration
|
||||||
from homeassistant.util import sanitize_path
|
|
||||||
|
|
||||||
from . import dashboard, resources, websocket
|
from . import dashboard, resources, websocket
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -47,7 +46,7 @@ YAML_DASHBOARD_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
**DASHBOARD_BASE_CREATE_FIELDS,
|
**DASHBOARD_BASE_CREATE_FIELDS,
|
||||||
vol.Required(CONF_MODE): MODE_YAML,
|
vol.Required(CONF_MODE): MODE_YAML,
|
||||||
vol.Required(CONF_FILENAME): vol.All(cv.string, sanitize_path),
|
vol.Required(CONF_FILENAME): cv.path,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
|
||||||
from homeassistant.components.media_player.errors import BrowseError
|
from homeassistant.components.media_player.errors import BrowseError
|
||||||
from homeassistant.components.media_source.error import Unresolvable
|
from homeassistant.components.media_source.error import Unresolvable
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import sanitize_path
|
from homeassistant.util import raise_if_invalid_filename
|
||||||
|
|
||||||
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
|
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
|
||||||
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
|
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
|
||||||
|
@ -50,8 +50,10 @@ class LocalSource(MediaSource):
|
||||||
if source_dir_id not in self.hass.config.media_dirs:
|
if source_dir_id not in self.hass.config.media_dirs:
|
||||||
raise Unresolvable("Unknown source directory.")
|
raise Unresolvable("Unknown source directory.")
|
||||||
|
|
||||||
if location != sanitize_path(location):
|
try:
|
||||||
raise Unresolvable("Invalid path.")
|
raise_if_invalid_filename(location)
|
||||||
|
except ValueError as err:
|
||||||
|
raise Unresolvable("Invalid path.") from err
|
||||||
|
|
||||||
return source_dir_id, location
|
return source_dir_id, location
|
||||||
|
|
||||||
|
@ -189,8 +191,10 @@ class LocalMediaView(HomeAssistantView):
|
||||||
self, request: web.Request, source_dir_id: str, location: str
|
self, request: web.Request, source_dir_id: str, location: str
|
||||||
) -> web.FileResponse:
|
) -> web.FileResponse:
|
||||||
"""Start a GET request."""
|
"""Start a GET request."""
|
||||||
if location != sanitize_path(location):
|
try:
|
||||||
raise web.HTTPNotFound()
|
raise_if_invalid_filename(location)
|
||||||
|
except ValueError as err:
|
||||||
|
raise web.HTTPBadRequest() from err
|
||||||
|
|
||||||
if source_dir_id not in self.hass.config.media_dirs:
|
if source_dir_id not in self.hass.config.media_dirs:
|
||||||
raise web.HTTPNotFound()
|
raise web.HTTPNotFound()
|
||||||
|
|
|
@ -23,7 +23,7 @@ from homeassistant.const import SERVICE_RELOAD
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.service import async_set_service_schema
|
from homeassistant.helpers.service import async_set_service_schema
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util import sanitize_filename
|
from homeassistant.util import raise_if_invalid_filename
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
from homeassistant.util.yaml.loader import load_yaml
|
from homeassistant.util.yaml.loader import load_yaml
|
||||||
|
|
||||||
|
@ -135,7 +135,8 @@ def discover_scripts(hass):
|
||||||
def execute_script(hass, name, data=None):
|
def execute_script(hass, name, data=None):
|
||||||
"""Execute a script."""
|
"""Execute a script."""
|
||||||
filename = f"{name}.py"
|
filename = f"{name}.py"
|
||||||
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil:
|
raise_if_invalid_filename(filename)
|
||||||
|
with open(hass.config.path(FOLDER, filename)) as fil:
|
||||||
source = fil.read()
|
source = fil.read()
|
||||||
execute(hass, filename, source, data)
|
execute(hass, filename, source, data)
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ from homeassistant.helpers import (
|
||||||
template as template_helper,
|
template as template_helper,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.logging import KeywordStyleAdapter
|
from homeassistant.helpers.logging import KeywordStyleAdapter
|
||||||
from homeassistant.util import sanitize_path, slugify as util_slugify
|
from homeassistant.util import raise_if_invalid_path, slugify as util_slugify
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -118,8 +118,10 @@ def path(value: Any) -> str:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise vol.Invalid("Expected a string")
|
raise vol.Invalid("Expected a string")
|
||||||
|
|
||||||
if sanitize_path(value) != value:
|
try:
|
||||||
raise vol.Invalid("Invalid path")
|
raise_if_invalid_path(value)
|
||||||
|
except ValueError as err:
|
||||||
|
raise vol.Invalid("Invalid path") from err
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Deprecation helpers for Home Assistant."""
|
"""Deprecation helpers for Home Assistant."""
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
@ -73,3 +74,25 @@ def get_deprecated(
|
||||||
)
|
)
|
||||||
return config.get(old_name)
|
return config.get(old_name)
|
||||||
return config.get(new_name, default)
|
return config.get(new_name, default)
|
||||||
|
|
||||||
|
|
||||||
|
def deprecated_function(replacement: str) -> Callable[..., Callable]:
|
||||||
|
"""Mark function as deprecated and provide a replacement function to be used instead."""
|
||||||
|
|
||||||
|
def deprecated_decorator(func: Callable) -> Callable:
|
||||||
|
"""Decorate function as deprecated."""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""Wrap for the original function."""
|
||||||
|
logger = logging.getLogger(func.__module__)
|
||||||
|
logger.warning(
|
||||||
|
"%s is a deprecated function. Use %s instead",
|
||||||
|
func.__name__,
|
||||||
|
replacement,
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return deprecated_func
|
||||||
|
|
||||||
|
return deprecated_decorator
|
||||||
|
|
|
@ -22,6 +22,7 @@ from typing import (
|
||||||
|
|
||||||
import slugify as unicode_slug
|
import slugify as unicode_slug
|
||||||
|
|
||||||
|
from ..helpers.deprecation import deprecated_function
|
||||||
from .dt import as_local, utcnow
|
from .dt import as_local, utcnow
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -32,6 +33,27 @@ RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
|
||||||
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
|
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
|
||||||
|
|
||||||
|
|
||||||
|
def raise_if_invalid_filename(filename: str) -> None:
|
||||||
|
"""
|
||||||
|
Check if a filename is valid.
|
||||||
|
|
||||||
|
Raises a ValueError if the filename is invalid.
|
||||||
|
"""
|
||||||
|
if RE_SANITIZE_FILENAME.sub("", filename) != filename:
|
||||||
|
raise ValueError(f"{filename} is not a safe filename")
|
||||||
|
|
||||||
|
|
||||||
|
def raise_if_invalid_path(path: str) -> None:
|
||||||
|
"""
|
||||||
|
Check if a path is valid.
|
||||||
|
|
||||||
|
Raises a ValueError if the path is invalid.
|
||||||
|
"""
|
||||||
|
if RE_SANITIZE_PATH.sub("", path) != path:
|
||||||
|
raise ValueError(f"{path} is not a safe path")
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated_function(replacement="raise_if_invalid_filename")
|
||||||
def sanitize_filename(filename: str) -> str:
|
def sanitize_filename(filename: str) -> str:
|
||||||
"""Check if a filename is safe.
|
"""Check if a filename is safe.
|
||||||
|
|
||||||
|
@ -47,6 +69,7 @@ def sanitize_filename(filename: str) -> str:
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated_function(replacement="raise_if_invalid_path")
|
||||||
def sanitize_path(path: str) -> str:
|
def sanitize_path(path: str) -> str:
|
||||||
"""Check if a path is safe.
|
"""Check if a path is safe.
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ async def test_async_browse_media(hass):
|
||||||
await media_source.async_browse_media(
|
await media_source.async_browse_media(
|
||||||
hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist"
|
hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist"
|
||||||
)
|
)
|
||||||
assert str(excinfo.value) == "Path does not exist."
|
assert str(excinfo.value) == "Invalid path."
|
||||||
|
|
||||||
# Test browse file
|
# Test browse file
|
||||||
with pytest.raises(media_source.BrowseError) as excinfo:
|
with pytest.raises(media_source.BrowseError) as excinfo:
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
"""Test deprecation helpers."""
|
"""Test deprecation helpers."""
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from homeassistant.helpers.deprecation import deprecated_substitute, get_deprecated
|
from homeassistant.helpers.deprecation import (
|
||||||
|
deprecated_function,
|
||||||
|
deprecated_substitute,
|
||||||
|
get_deprecated,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockBaseClass:
|
class MockBaseClass:
|
||||||
|
@ -78,3 +82,17 @@ def test_config_get_deprecated_new(mock_get_logger):
|
||||||
config = {"new_name": True}
|
config = {"new_name": True}
|
||||||
assert get_deprecated(config, "new_name", "old_name") is True
|
assert get_deprecated(config, "new_name", "old_name") is True
|
||||||
assert not mock_logger.warning.called
|
assert not mock_logger.warning.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecated_function(caplog):
|
||||||
|
"""Test deprecated_function decorator."""
|
||||||
|
|
||||||
|
@deprecated_function("new_function")
|
||||||
|
def mock_deprecated_function():
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_deprecated_function()
|
||||||
|
assert (
|
||||||
|
"mock_deprecated_function is a deprecated function. Use new_function instead"
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
|
@ -24,6 +24,34 @@ def test_sanitize_path():
|
||||||
assert util.sanitize_path("~/../test/path") == ""
|
assert util.sanitize_path("~/../test/path") == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_if_invalid_filename():
|
||||||
|
"""Test raise_if_invalid_filename."""
|
||||||
|
assert util.raise_if_invalid_filename("test") is None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
util.raise_if_invalid_filename("/test")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
util.raise_if_invalid_filename("..test")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
util.raise_if_invalid_filename("\\test")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
util.raise_if_invalid_filename("\\../test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_if_invalid_path():
|
||||||
|
"""Test raise_if_invalid_path."""
|
||||||
|
assert util.raise_if_invalid_path("test/path") is None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert util.raise_if_invalid_path("~test/path")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
assert util.raise_if_invalid_path("~/../test/path")
|
||||||
|
|
||||||
|
|
||||||
def test_slugify():
|
def test_slugify():
|
||||||
"""Test slugify."""
|
"""Test slugify."""
|
||||||
assert util.slugify("T-!@#$!#@$!$est") == "t_est"
|
assert util.slugify("T-!@#$!#@$!$est") == "t_est"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue