Add facebox auth (#15439)
* Adds auth * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Remove TIMEOUT * Update test_facebox.py * fix lint * Update facebox.py * Update test_facebox.py * Update facebox.py * Adds check_box_health * Adds test auth * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Ups coverage * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py
This commit is contained in:
parent
47fa928425
commit
61721478f3
2 changed files with 198 additions and 96 deletions
|
@ -17,25 +17,29 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.components.image_processing import (
|
||||
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
|
||||
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
|
||||
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)
|
||||
from homeassistant.const import (
|
||||
CONF_IP_ADDRESS, CONF_PORT, CONF_PASSWORD, CONF_USERNAME,
|
||||
HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ATTR_BOUNDING_BOX = 'bounding_box'
|
||||
ATTR_CLASSIFIER = 'classifier'
|
||||
ATTR_IMAGE_ID = 'image_id'
|
||||
ATTR_ID = 'id'
|
||||
ATTR_MATCHED = 'matched'
|
||||
FACEBOX_NAME = 'name'
|
||||
CLASSIFIER = 'facebox'
|
||||
DATA_FACEBOX = 'facebox_classifiers'
|
||||
EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier'
|
||||
FILE_PATH = 'file_path'
|
||||
SERVICE_TEACH_FACE = 'facebox_teach_face'
|
||||
TIMEOUT = 9
|
||||
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||||
vol.Required(CONF_IP_ADDRESS): cv.string,
|
||||
vol.Required(CONF_PORT): cv.port,
|
||||
vol.Optional(CONF_USERNAME): cv.string,
|
||||
vol.Optional(CONF_PASSWORD): cv.string,
|
||||
})
|
||||
|
||||
SERVICE_TEACH_SCHEMA = vol.Schema({
|
||||
|
@ -45,6 +49,26 @@ SERVICE_TEACH_SCHEMA = vol.Schema({
|
|||
})
|
||||
|
||||
|
||||
def check_box_health(url, username, password):
|
||||
"""Check the health of the classifier and return its id if healthy."""
|
||||
kwargs = {}
|
||||
if username:
|
||||
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
**kwargs
|
||||
)
|
||||
if response.status_code == HTTP_UNAUTHORIZED:
|
||||
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||
return None
|
||||
if response.status_code == HTTP_OK:
|
||||
return response.json()['hostname']
|
||||
except requests.exceptions.ConnectionError:
|
||||
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||
return None
|
||||
|
||||
|
||||
def encode_image(image):
|
||||
"""base64 encode an image stream."""
|
||||
base64_img = base64.b64encode(image).decode('ascii')
|
||||
|
@ -63,10 +87,10 @@ def parse_faces(api_faces):
|
|||
for entry in api_faces:
|
||||
face = {}
|
||||
if entry['matched']: # This data is only in matched faces.
|
||||
face[ATTR_NAME] = entry['name']
|
||||
face[FACEBOX_NAME] = entry['name']
|
||||
face[ATTR_IMAGE_ID] = entry['id']
|
||||
else: # Lets be explicit.
|
||||
face[ATTR_NAME] = None
|
||||
face[FACEBOX_NAME] = None
|
||||
face[ATTR_IMAGE_ID] = None
|
||||
face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2)
|
||||
face[ATTR_MATCHED] = entry['matched']
|
||||
|
@ -75,17 +99,46 @@ def parse_faces(api_faces):
|
|||
return known_faces
|
||||
|
||||
|
||||
def post_image(url, image):
|
||||
def post_image(url, image, username, password):
|
||||
"""Post an image to the classifier."""
|
||||
kwargs = {}
|
||||
if username:
|
||||
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"base64": encode_image(image)},
|
||||
timeout=TIMEOUT
|
||||
**kwargs
|
||||
)
|
||||
if response.status_code == HTTP_UNAUTHORIZED:
|
||||
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||
return None
|
||||
return response
|
||||
except requests.exceptions.ConnectionError:
|
||||
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||
return None
|
||||
|
||||
|
||||
def teach_file(url, name, file_path, username, password):
|
||||
"""Teach the classifier a name associated with a file."""
|
||||
kwargs = {}
|
||||
if username:
|
||||
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||
try:
|
||||
with open(file_path, 'rb') as open_file:
|
||||
response = requests.post(
|
||||
url,
|
||||
data={FACEBOX_NAME: name, ATTR_ID: file_path},
|
||||
files={'file': open_file},
|
||||
**kwargs
|
||||
)
|
||||
if response.status_code == HTTP_UNAUTHORIZED:
|
||||
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||
elif response.status_code == HTTP_BAD_REQUEST:
|
||||
_LOGGER.error("%s teaching of file %s failed with message:%s",
|
||||
CLASSIFIER, file_path, response.text)
|
||||
except requests.exceptions.ConnectionError:
|
||||
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||
|
||||
|
||||
def valid_file_path(file_path):
|
||||
|
@ -104,13 +157,20 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
|||
if DATA_FACEBOX not in hass.data:
|
||||
hass.data[DATA_FACEBOX] = []
|
||||
|
||||
ip_address = config[CONF_IP_ADDRESS]
|
||||
port = config[CONF_PORT]
|
||||
username = config.get(CONF_USERNAME)
|
||||
password = config.get(CONF_PASSWORD)
|
||||
url_health = "http://{}:{}/healthz".format(ip_address, port)
|
||||
hostname = check_box_health(url_health, username, password)
|
||||
if hostname is None:
|
||||
return
|
||||
|
||||
entities = []
|
||||
for camera in config[CONF_SOURCE]:
|
||||
facebox = FaceClassifyEntity(
|
||||
config[CONF_IP_ADDRESS],
|
||||
config[CONF_PORT],
|
||||
camera[CONF_ENTITY_ID],
|
||||
camera.get(CONF_NAME))
|
||||
ip_address, port, username, password, hostname,
|
||||
camera[CONF_ENTITY_ID], camera.get(CONF_NAME))
|
||||
entities.append(facebox)
|
||||
hass.data[DATA_FACEBOX].append(facebox)
|
||||
add_devices(entities)
|
||||
|
@ -129,33 +189,37 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
|||
classifier.teach(name, file_path)
|
||||
|
||||
hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_TEACH_FACE,
|
||||
service_handle,
|
||||
DOMAIN, SERVICE_TEACH_FACE, service_handle,
|
||||
schema=SERVICE_TEACH_SCHEMA)
|
||||
|
||||
|
||||
class FaceClassifyEntity(ImageProcessingFaceEntity):
|
||||
"""Perform a face classification."""
|
||||
|
||||
def __init__(self, ip, port, camera_entity, name=None):
|
||||
def __init__(self, ip_address, port, username, password, hostname,
|
||||
camera_entity, name=None):
|
||||
"""Init with the API key and model id."""
|
||||
super().__init__()
|
||||
self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
|
||||
self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER)
|
||||
self._url_check = "http://{}:{}/{}/check".format(
|
||||
ip_address, port, CLASSIFIER)
|
||||
self._url_teach = "http://{}:{}/{}/teach".format(
|
||||
ip_address, port, CLASSIFIER)
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._hostname = hostname
|
||||
self._camera = camera_entity
|
||||
if name:
|
||||
self._name = name
|
||||
else:
|
||||
camera_name = split_entity_id(camera_entity)[1]
|
||||
self._name = "{} {}".format(
|
||||
CLASSIFIER, camera_name)
|
||||
self._name = "{} {}".format(CLASSIFIER, camera_name)
|
||||
self._matched = {}
|
||||
|
||||
def process_image(self, image):
|
||||
"""Process an image."""
|
||||
response = post_image(self._url_check, image)
|
||||
if response is not None:
|
||||
response = post_image(
|
||||
self._url_check, image, self._username, self._password)
|
||||
if response:
|
||||
response_json = response.json()
|
||||
if response_json['success']:
|
||||
total_faces = response_json['facesCount']
|
||||
|
@ -173,34 +237,8 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
|
|||
if (not self.hass.config.is_allowed_path(file_path)
|
||||
or not valid_file_path(file_path)):
|
||||
return
|
||||
with open(file_path, 'rb') as open_file:
|
||||
response = requests.post(
|
||||
self._url_teach,
|
||||
data={ATTR_NAME: name, 'id': file_path},
|
||||
files={'file': open_file})
|
||||
|
||||
if response.status_code == 200:
|
||||
self.hass.bus.fire(
|
||||
EVENT_CLASSIFIER_TEACH, {
|
||||
ATTR_CLASSIFIER: CLASSIFIER,
|
||||
ATTR_NAME: name,
|
||||
FILE_PATH: file_path,
|
||||
'success': True,
|
||||
'message': None
|
||||
})
|
||||
|
||||
elif response.status_code == 400:
|
||||
_LOGGER.warning(
|
||||
"%s teaching of file %s failed with message:%s",
|
||||
CLASSIFIER, file_path, response.text)
|
||||
self.hass.bus.fire(
|
||||
EVENT_CLASSIFIER_TEACH, {
|
||||
ATTR_CLASSIFIER: CLASSIFIER,
|
||||
ATTR_NAME: name,
|
||||
FILE_PATH: file_path,
|
||||
'success': False,
|
||||
'message': response.text
|
||||
})
|
||||
teach_file(
|
||||
self._url_teach, name, file_path, self._username, self._password)
|
||||
|
||||
@property
|
||||
def camera_entity(self):
|
||||
|
@ -218,4 +256,5 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
|
|||
return {
|
||||
'matched_faces': self._matched,
|
||||
'total_matched_faces': len(self._matched),
|
||||
'hostname': self._hostname
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue