Fix Facebox face data parsing (#14951)
* Adds parse_faces * Update facebox.py
This commit is contained in:
parent
e014a84215
commit
cccd0deb65
2 changed files with 65 additions and 19 deletions
|
@ -10,16 +10,22 @@ import logging
|
|||
import requests
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_NAME
|
||||
from homeassistant.core import split_entity_id
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.components.image_processing import (
|
||||
PLATFORM_SCHEMA, ImageProcessingFaceEntity, CONF_SOURCE, CONF_ENTITY_ID,
|
||||
CONF_NAME)
|
||||
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
|
||||
CONF_ENTITY_ID, CONF_NAME)
|
||||
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ATTR_BOUNDING_BOX = 'bounding_box'
|
||||
ATTR_IMAGE_ID = 'image_id'
|
||||
ATTR_MATCHED = 'matched'
|
||||
CLASSIFIER = 'facebox'
|
||||
TIMEOUT = 9
|
||||
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||||
vol.Required(CONF_IP_ADDRESS): cv.string,
|
||||
|
@ -30,7 +36,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
|||
def encode_image(image):
|
||||
"""base64 encode an image stream."""
|
||||
base64_img = base64.b64encode(image).decode('ascii')
|
||||
return {"base64": base64_img}
|
||||
return base64_img
|
||||
|
||||
|
||||
def get_matched_faces(faces):
|
||||
|
@ -39,6 +45,24 @@ def get_matched_faces(faces):
|
|||
for face in faces if face['matched']}
|
||||
|
||||
|
||||
def parse_faces(api_faces):
|
||||
"""Parse the API face data into the format required."""
|
||||
known_faces = []
|
||||
for entry in api_faces:
|
||||
face = {}
|
||||
if entry['matched']: # This data is only in matched faces.
|
||||
face[ATTR_NAME] = entry['name']
|
||||
face[ATTR_IMAGE_ID] = entry['id']
|
||||
else: # Lets be explicit.
|
||||
face[ATTR_NAME] = None
|
||||
face[ATTR_IMAGE_ID] = None
|
||||
face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2)
|
||||
face[ATTR_MATCHED] = entry['matched']
|
||||
face[ATTR_BOUNDING_BOX] = entry['rect']
|
||||
known_faces.append(face)
|
||||
return known_faces
|
||||
|
||||
|
||||
def setup_platform(hass, config, add_devices, discovery_info=None):
|
||||
"""Set up the classifier."""
|
||||
entities = []
|
||||
|
@ -74,18 +98,18 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
|
|||
try:
|
||||
response = requests.post(
|
||||
self._url,
|
||||
json=encode_image(image),
|
||||
timeout=9
|
||||
json={"base64": encode_image(image)},
|
||||
timeout=TIMEOUT
|
||||
).json()
|
||||
except requests.exceptions.ConnectionError:
|
||||
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||
response['success'] = False
|
||||
|
||||
if response['success']:
|
||||
faces = response['faces']
|
||||
total = response['facesCount']
|
||||
self.process_faces(faces, total)
|
||||
total_faces = response['facesCount']
|
||||
faces = parse_faces(response['faces'])
|
||||
self._matched = get_matched_faces(faces)
|
||||
self.process_faces(faces, total_faces)
|
||||
|
||||
else:
|
||||
self.total_faces = None
|
||||
|
|
|
@ -7,7 +7,7 @@ import requests_mock
|
|||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, CONF_FRIENDLY_NAME,
|
||||
ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME,
|
||||
CONF_IP_ADDRESS, CONF_PORT, STATE_UNKNOWN)
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.components.image_processing as ip
|
||||
|
@ -16,6 +16,7 @@ import homeassistant.components.image_processing.facebox as fb
|
|||
MOCK_IP = '192.168.0.1'
|
||||
MOCK_PORT = '8080'
|
||||
|
||||
# Mock data returned by the facebox API.
|
||||
MOCK_FACE = {'confidence': 0.5812028911604818,
|
||||
'id': 'john.jpg',
|
||||
'matched': True,
|
||||
|
@ -28,6 +29,20 @@ MOCK_JSON = {"facesCount": 1,
|
|||
"faces": [MOCK_FACE]
|
||||
}
|
||||
|
||||
# Faces data after parsing.
|
||||
PARSED_FACES = [{ATTR_NAME: 'John Lennon',
|
||||
fb.ATTR_IMAGE_ID: 'john.jpg',
|
||||
fb.ATTR_CONFIDENCE: 58.12,
|
||||
fb.ATTR_MATCHED: True,
|
||||
fb.ATTR_BOUNDING_BOX: {
|
||||
'height': 75,
|
||||
'left': 63,
|
||||
'top': 262,
|
||||
'width': 74},
|
||||
}]
|
||||
|
||||
MATCHED_FACES = {'John Lennon': 58.12}
|
||||
|
||||
VALID_ENTITY_ID = 'image_processing.facebox_demo_camera'
|
||||
VALID_CONFIG = {
|
||||
ip.DOMAIN: {
|
||||
|
@ -45,12 +60,14 @@ VALID_CONFIG = {
|
|||
|
||||
def test_encode_image():
|
||||
"""Test that binary data is encoded correctly."""
|
||||
assert fb.encode_image(b'test')["base64"] == 'dGVzdA=='
|
||||
assert fb.encode_image(b'test') == 'dGVzdA=='
|
||||
|
||||
|
||||
def test_get_matched_faces():
|
||||
"""Test that matched faces are parsed correctly."""
|
||||
assert fb.get_matched_faces([MOCK_FACE]) == {MOCK_FACE['name']: 0.58}
|
||||
def test_parse_faces():
|
||||
"""Test parsing of raw face data, and generation of matched_faces."""
|
||||
parsed_faces = fb.parse_faces(MOCK_JSON['faces'])
|
||||
assert parsed_faces == PARSED_FACES
|
||||
assert fb.get_matched_faces(parsed_faces) == MATCHED_FACES
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -92,16 +109,21 @@ async def test_process_image(hass, mock_image):
|
|||
|
||||
state = hass.states.get(VALID_ENTITY_ID)
|
||||
assert state.state == '1'
|
||||
assert state.attributes.get('matched_faces') == {MOCK_FACE['name']: 0.58}
|
||||
assert state.attributes.get('matched_faces') == MATCHED_FACES
|
||||
|
||||
MOCK_FACE[ATTR_ENTITY_ID] = VALID_ENTITY_ID # Update.
|
||||
assert state.attributes.get('faces') == [MOCK_FACE]
|
||||
PARSED_FACES[0][ATTR_ENTITY_ID] = VALID_ENTITY_ID # Update.
|
||||
assert state.attributes.get('faces') == PARSED_FACES
|
||||
assert state.attributes.get(CONF_FRIENDLY_NAME) == 'facebox demo_camera'
|
||||
|
||||
assert len(face_events) == 1
|
||||
assert face_events[0].data['name'] == MOCK_FACE['name']
|
||||
assert face_events[0].data['confidence'] == MOCK_FACE['confidence']
|
||||
assert face_events[0].data['entity_id'] == VALID_ENTITY_ID
|
||||
assert face_events[0].data[ATTR_NAME] == PARSED_FACES[0][ATTR_NAME]
|
||||
assert (face_events[0].data[fb.ATTR_CONFIDENCE]
|
||||
== PARSED_FACES[0][fb.ATTR_CONFIDENCE])
|
||||
assert face_events[0].data[ATTR_ENTITY_ID] == VALID_ENTITY_ID
|
||||
assert (face_events[0].data[fb.ATTR_IMAGE_ID] ==
|
||||
PARSED_FACES[0][fb.ATTR_IMAGE_ID])
|
||||
assert (face_events[0].data[fb.ATTR_BOUNDING_BOX] ==
|
||||
PARSED_FACES[0][fb.ATTR_BOUNDING_BOX])
|
||||
|
||||
|
||||
async def test_connection_error(hass, mock_image):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue