From 61721478f3715ad66d980c4b2682a1414204dea3 Mon Sep 17 00:00:00 2001 From: Robin Date: Tue, 7 Aug 2018 06:30:36 +0100 Subject: [PATCH] Add facebox auth (#15439) * Adds auth * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Remove TIMEOUT * Update test_facebox.py * fix lint * Update facebox.py * Update test_facebox.py * Update facebox.py * Adds check_box_health * Adds test auth * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Ups coverage * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py --- .../components/image_processing/facebox.py | 137 +++++++++------ .../image_processing/test_facebox.py | 157 ++++++++++++------ 2 files changed, 198 insertions(+), 96 deletions(-) diff --git a/homeassistant/components/image_processing/facebox.py b/homeassistant/components/image_processing/facebox.py index c863f804513..e5ce0b825d0 100644 --- a/homeassistant/components/image_processing/facebox.py +++ b/homeassistant/components/image_processing/facebox.py @@ -17,25 +17,29 @@ import homeassistant.helpers.config_validation as cv from homeassistant.components.image_processing import ( PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE, CONF_ENTITY_ID, CONF_NAME, DOMAIN) -from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT) +from homeassistant.const import ( + CONF_IP_ADDRESS, CONF_PORT, CONF_PASSWORD, CONF_USERNAME, + HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED) _LOGGER = logging.getLogger(__name__) ATTR_BOUNDING_BOX = 'bounding_box' ATTR_CLASSIFIER = 'classifier' ATTR_IMAGE_ID = 'image_id' +ATTR_ID = 'id' ATTR_MATCHED = 'matched' +FACEBOX_NAME = 'name' CLASSIFIER = 'facebox' DATA_FACEBOX = 'facebox_classifiers' -EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier' FILE_PATH = 'file_path' SERVICE_TEACH_FACE = 'facebox_teach_face' -TIMEOUT = 9 PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_IP_ADDRESS): cv.string, vol.Required(CONF_PORT): cv.port, + vol.Optional(CONF_USERNAME): cv.string, + vol.Optional(CONF_PASSWORD): cv.string, }) SERVICE_TEACH_SCHEMA = vol.Schema({ @@ -45,6 +49,26 @@ SERVICE_TEACH_SCHEMA = vol.Schema({ }) +def check_box_health(url, username, password): + """Check the health of the classifier and return its id if healthy.""" + kwargs = {} + if username: + kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) + try: + response = requests.get( + url, + **kwargs + ) + if response.status_code == HTTP_UNAUTHORIZED: + _LOGGER.error("AuthenticationError on %s", CLASSIFIER) + return None + if response.status_code == HTTP_OK: + return response.json()['hostname'] + except requests.exceptions.ConnectionError: + _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) + return None + + def encode_image(image): """base64 encode an image stream.""" base64_img = base64.b64encode(image).decode('ascii') @@ -63,10 +87,10 @@ def parse_faces(api_faces): for entry in api_faces: face = {} if entry['matched']: # This data is only in matched faces. - face[ATTR_NAME] = entry['name'] + face[FACEBOX_NAME] = entry['name'] face[ATTR_IMAGE_ID] = entry['id'] else: # Lets be explicit. - face[ATTR_NAME] = None + face[FACEBOX_NAME] = None face[ATTR_IMAGE_ID] = None face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2) face[ATTR_MATCHED] = entry['matched'] @@ -75,17 +99,46 @@ def parse_faces(api_faces): return known_faces -def post_image(url, image): +def post_image(url, image, username, password): """Post an image to the classifier.""" + kwargs = {} + if username: + kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) try: response = requests.post( url, json={"base64": encode_image(image)}, - timeout=TIMEOUT + **kwargs ) + if response.status_code == HTTP_UNAUTHORIZED: + _LOGGER.error("AuthenticationError on %s", CLASSIFIER) + return None return response except requests.exceptions.ConnectionError: _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) + return None + + +def teach_file(url, name, file_path, username, password): + """Teach the classifier a name associated with a file.""" + kwargs = {} + if username: + kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) + try: + with open(file_path, 'rb') as open_file: + response = requests.post( + url, + data={FACEBOX_NAME: name, ATTR_ID: file_path}, + files={'file': open_file}, + **kwargs + ) + if response.status_code == HTTP_UNAUTHORIZED: + _LOGGER.error("AuthenticationError on %s", CLASSIFIER) + elif response.status_code == HTTP_BAD_REQUEST: + _LOGGER.error("%s teaching of file %s failed with message:%s", + CLASSIFIER, file_path, response.text) + except requests.exceptions.ConnectionError: + _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) def valid_file_path(file_path): @@ -104,13 +157,20 @@ def setup_platform(hass, config, add_devices, discovery_info=None): if DATA_FACEBOX not in hass.data: hass.data[DATA_FACEBOX] = [] + ip_address = config[CONF_IP_ADDRESS] + port = config[CONF_PORT] + username = config.get(CONF_USERNAME) + password = config.get(CONF_PASSWORD) + url_health = "http://{}:{}/healthz".format(ip_address, port) + hostname = check_box_health(url_health, username, password) + if hostname is None: + return + entities = [] for camera in config[CONF_SOURCE]: facebox = FaceClassifyEntity( - config[CONF_IP_ADDRESS], - config[CONF_PORT], - camera[CONF_ENTITY_ID], - camera.get(CONF_NAME)) + ip_address, port, username, password, hostname, + camera[CONF_ENTITY_ID], camera.get(CONF_NAME)) entities.append(facebox) hass.data[DATA_FACEBOX].append(facebox) add_devices(entities) @@ -129,33 +189,37 @@ def setup_platform(hass, config, add_devices, discovery_info=None): classifier.teach(name, file_path) hass.services.register( - DOMAIN, - SERVICE_TEACH_FACE, - service_handle, + DOMAIN, SERVICE_TEACH_FACE, service_handle, schema=SERVICE_TEACH_SCHEMA) class FaceClassifyEntity(ImageProcessingFaceEntity): """Perform a face classification.""" - def __init__(self, ip, port, camera_entity, name=None): + def __init__(self, ip_address, port, username, password, hostname, + camera_entity, name=None): """Init with the API key and model id.""" super().__init__() - self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER) - self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER) + self._url_check = "http://{}:{}/{}/check".format( + ip_address, port, CLASSIFIER) + self._url_teach = "http://{}:{}/{}/teach".format( + ip_address, port, CLASSIFIER) + self._username = username + self._password = password + self._hostname = hostname self._camera = camera_entity if name: self._name = name else: camera_name = split_entity_id(camera_entity)[1] - self._name = "{} {}".format( - CLASSIFIER, camera_name) + self._name = "{} {}".format(CLASSIFIER, camera_name) self._matched = {} def process_image(self, image): """Process an image.""" - response = post_image(self._url_check, image) - if response is not None: + response = post_image( + self._url_check, image, self._username, self._password) + if response: response_json = response.json() if response_json['success']: total_faces = response_json['facesCount'] @@ -173,34 +237,8 @@ class FaceClassifyEntity(ImageProcessingFaceEntity): if (not self.hass.config.is_allowed_path(file_path) or not valid_file_path(file_path)): return - with open(file_path, 'rb') as open_file: - response = requests.post( - self._url_teach, - data={ATTR_NAME: name, 'id': file_path}, - files={'file': open_file}) - - if response.status_code == 200: - self.hass.bus.fire( - EVENT_CLASSIFIER_TEACH, { - ATTR_CLASSIFIER: CLASSIFIER, - ATTR_NAME: name, - FILE_PATH: file_path, - 'success': True, - 'message': None - }) - - elif response.status_code == 400: - _LOGGER.warning( - "%s teaching of file %s failed with message:%s", - CLASSIFIER, file_path, response.text) - self.hass.bus.fire( - EVENT_CLASSIFIER_TEACH, { - ATTR_CLASSIFIER: CLASSIFIER, - ATTR_NAME: name, - FILE_PATH: file_path, - 'success': False, - 'message': response.text - }) + teach_file( + self._url_teach, name, file_path, self._username, self._password) @property def camera_entity(self): @@ -218,4 +256,5 @@ class FaceClassifyEntity(ImageProcessingFaceEntity): return { 'matched_faces': self._matched, 'total_matched_faces': len(self._matched), + 'hostname': self._hostname } diff --git a/tests/components/image_processing/test_facebox.py b/tests/components/image_processing/test_facebox.py index 86811f94db3..b1d9fb8bf79 100644 --- a/tests/components/image_processing/test_facebox.py +++ b/tests/components/image_processing/test_facebox.py @@ -7,19 +7,19 @@ import requests_mock from homeassistant.core import callback from homeassistant.const import ( - ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME, - CONF_IP_ADDRESS, CONF_PORT, STATE_UNKNOWN) + ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME, CONF_PASSWORD, + CONF_USERNAME, CONF_IP_ADDRESS, CONF_PORT, + HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED, STATE_UNKNOWN) from homeassistant.setup import async_setup_component import homeassistant.components.image_processing as ip import homeassistant.components.image_processing.facebox as fb -# pylint: disable=redefined-outer-name - MOCK_IP = '192.168.0.1' MOCK_PORT = '8080' # Mock data returned by the facebox API. -MOCK_ERROR = "No face found" +MOCK_BOX_ID = 'b893cc4f7fd6' +MOCK_ERROR_NO_FACE = "No face found" MOCK_FACE = {'confidence': 0.5812028911604818, 'id': 'john.jpg', 'matched': True, @@ -28,14 +28,21 @@ MOCK_FACE = {'confidence': 0.5812028911604818, MOCK_FILE_PATH = '/images/mock.jpg' +MOCK_HEALTH = {'success': True, + 'hostname': 'b893cc4f7fd6', + 'metadata': {'boxname': 'facebox', 'build': 'development'}, + 'errors': []} + MOCK_JSON = {"facesCount": 1, "success": True, "faces": [MOCK_FACE]} MOCK_NAME = 'mock_name' +MOCK_USERNAME = 'mock_username' +MOCK_PASSWORD = 'mock_password' # Faces data after parsing. -PARSED_FACES = [{ATTR_NAME: 'John Lennon', +PARSED_FACES = [{fb.FACEBOX_NAME: 'John Lennon', fb.ATTR_IMAGE_ID: 'john.jpg', fb.ATTR_CONFIDENCE: 58.12, fb.ATTR_MATCHED: True, @@ -62,6 +69,15 @@ VALID_CONFIG = { } +@pytest.fixture +def mock_healthybox(): + """Mock fb.check_box_health.""" + check_box_health = 'homeassistant.components.image_processing.' \ + 'facebox.check_box_health' + with patch(check_box_health, return_value=MOCK_BOX_ID) as _mock_healthybox: + yield _mock_healthybox + + @pytest.fixture def mock_isfile(): """Mock os.path.isfile.""" @@ -70,6 +86,14 @@ def mock_isfile(): yield _mock_isfile +@pytest.fixture +def mock_image(): + """Return a mock camera image.""" + with patch('homeassistant.components.camera.demo.DemoCamera.camera_image', + return_value=b'Test') as image: + yield image + + @pytest.fixture def mock_open_file(): """Mock open.""" @@ -79,6 +103,22 @@ def mock_open_file(): yield _mock_open +def test_check_box_health(caplog): + """Test check box health.""" + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/healthz".format(MOCK_IP, MOCK_PORT) + mock_req.get(url, status_code=HTTP_OK, json=MOCK_HEALTH) + assert fb.check_box_health(url, 'user', 'pass') == MOCK_BOX_ID + + mock_req.get(url, status_code=HTTP_UNAUTHORIZED) + assert fb.check_box_health(url, None, None) is None + assert "AuthenticationError on facebox" in caplog.text + + mock_req.get(url, exc=requests.exceptions.ConnectTimeout) + fb.check_box_health(url, None, None) + assert "ConnectionError: Is facebox running?" in caplog.text + + def test_encode_image(): """Test that binary data is encoded correctly.""" assert fb.encode_image(b'test') == 'dGVzdA==' @@ -100,22 +140,24 @@ def test_valid_file_path(): assert not fb.valid_file_path('test_path') -@pytest.fixture -def mock_image(): - """Return a mock camera image.""" - with patch('homeassistant.components.camera.demo.DemoCamera.camera_image', - return_value=b'Test') as image: - yield image - - -async def test_setup_platform(hass): +async def test_setup_platform(hass, mock_healthybox): """Setup platform with one entity.""" await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG) assert hass.states.get(VALID_ENTITY_ID) -async def test_process_image(hass, mock_image): - """Test processing of an image.""" +async def test_setup_platform_with_auth(hass, mock_healthybox): + """Setup platform with one entity and auth.""" + valid_config_auth = VALID_CONFIG.copy() + valid_config_auth[ip.DOMAIN][CONF_USERNAME] = MOCK_USERNAME + valid_config_auth[ip.DOMAIN][CONF_PASSWORD] = MOCK_PASSWORD + + await async_setup_component(hass, ip.DOMAIN, valid_config_auth) + assert hass.states.get(VALID_ENTITY_ID) + + +async def test_process_image(hass, mock_healthybox, mock_image): + """Test successful processing of an image.""" await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG) assert hass.states.get(VALID_ENTITY_ID) @@ -157,11 +199,12 @@ async def test_process_image(hass, mock_image): PARSED_FACES[0][fb.ATTR_BOUNDING_BOX]) -async def test_connection_error(hass, mock_image): - """Test connection error.""" +async def test_process_image_errors(hass, mock_healthybox, mock_image, caplog): + """Test process_image errors.""" await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG) assert hass.states.get(VALID_ENTITY_ID) + # Test connection error. with requests_mock.Mocker() as mock_req: url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT) mock_req.register_uri( @@ -171,34 +214,40 @@ async def test_connection_error(hass, mock_image): ip.SERVICE_SCAN, service_data=data) await hass.async_block_till_done() + assert "ConnectionError: Is facebox running?" in caplog.text state = hass.states.get(VALID_ENTITY_ID) assert state.state == STATE_UNKNOWN assert state.attributes.get('faces') == [] assert state.attributes.get('matched_faces') == {} + # Now test with bad auth. + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT) + mock_req.register_uri( + 'POST', url, status_code=HTTP_UNAUTHORIZED) + data = {ATTR_ENTITY_ID: VALID_ENTITY_ID} + await hass.services.async_call(ip.DOMAIN, + ip.SERVICE_SCAN, + service_data=data) + await hass.async_block_till_done() + assert "AuthenticationError on facebox" in caplog.text -async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file): + +async def test_teach_service( + hass, mock_healthybox, mock_image, + mock_isfile, mock_open_file, caplog): """Test teaching of facebox.""" await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG) assert hass.states.get(VALID_ENTITY_ID) - teach_events = [] - - @callback - def mock_teach_event(event): - """Mock event.""" - teach_events.append(event) - - hass.bus.async_listen( - 'image_processing.teach_classifier', mock_teach_event) - # Patch out 'is_allowed_path' as the mock files aren't allowed hass.config.is_allowed_path = Mock(return_value=True) + # Test successful teach. with requests_mock.Mocker() as mock_req: url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) - mock_req.post(url, status_code=200) + mock_req.post(url, status_code=HTTP_OK) data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, ATTR_NAME: MOCK_NAME, fb.FILE_PATH: MOCK_FILE_PATH} @@ -206,17 +255,10 @@ async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file): ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data) await hass.async_block_till_done() - assert len(teach_events) == 1 - assert teach_events[0].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER - assert teach_events[0].data[ATTR_NAME] == MOCK_NAME - assert teach_events[0].data[fb.FILE_PATH] == MOCK_FILE_PATH - assert teach_events[0].data['success'] - assert not teach_events[0].data['message'] - - # Now test the failed teaching. + # Now test with bad auth. with requests_mock.Mocker() as mock_req: url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) - mock_req.post(url, status_code=400, text=MOCK_ERROR) + mock_req.post(url, status_code=HTTP_UNAUTHORIZED) data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, ATTR_NAME: MOCK_NAME, fb.FILE_PATH: MOCK_FILE_PATH} @@ -224,16 +266,37 @@ async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file): fb.SERVICE_TEACH_FACE, service_data=data) await hass.async_block_till_done() + assert "AuthenticationError on facebox" in caplog.text - assert len(teach_events) == 2 - assert teach_events[1].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER - assert teach_events[1].data[ATTR_NAME] == MOCK_NAME - assert teach_events[1].data[fb.FILE_PATH] == MOCK_FILE_PATH - assert not teach_events[1].data['success'] - assert teach_events[1].data['message'] == MOCK_ERROR + # Now test the failed teaching. + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) + mock_req.post(url, status_code=HTTP_BAD_REQUEST, + text=MOCK_ERROR_NO_FACE) + data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, + ATTR_NAME: MOCK_NAME, + fb.FILE_PATH: MOCK_FILE_PATH} + await hass.services.async_call(ip.DOMAIN, + fb.SERVICE_TEACH_FACE, + service_data=data) + await hass.async_block_till_done() + assert MOCK_ERROR_NO_FACE in caplog.text + + # Now test connection error. + with requests_mock.Mocker() as mock_req: + url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT) + mock_req.post(url, exc=requests.exceptions.ConnectTimeout) + data = {ATTR_ENTITY_ID: VALID_ENTITY_ID, + ATTR_NAME: MOCK_NAME, + fb.FILE_PATH: MOCK_FILE_PATH} + await hass.services.async_call(ip.DOMAIN, + fb.SERVICE_TEACH_FACE, + service_data=data) + await hass.async_block_till_done() + assert "ConnectionError: Is facebox running?" in caplog.text -async def test_setup_platform_with_name(hass): +async def test_setup_platform_with_name(hass, mock_healthybox): """Setup platform with one entity and a name.""" named_entity_id = 'image_processing.{}'.format(MOCK_NAME)