Upgrade to TensorFlow 2 (#38384)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
parent 7e34c2582f
commit 3546a82cfb
6 changed files with 115 additions and 48 deletions
```diff
@@ -3,9 +3,11 @@ import io
 import logging
 import os
 import sys
+import time
 
 from PIL import Image, ImageDraw, UnidentifiedImageError
 import numpy as np
+import tensorflow as tf
 import voluptuous as vol
 
 from homeassistant.components.image_processing import (
```
```diff
@@ -16,16 +18,21 @@ from homeassistant.components.image_processing import (
     PLATFORM_SCHEMA,
     ImageProcessingEntity,
 )
+from homeassistant.const import EVENT_HOMEASSISTANT_START
 from homeassistant.core import split_entity_id
 from homeassistant.helpers import template
 import homeassistant.helpers.config_validation as cv
 from homeassistant.util.pil import draw_box
 
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
+DOMAIN = "tensorflow"
 _LOGGER = logging.getLogger(__name__)
 
 ATTR_MATCHES = "matches"
 ATTR_SUMMARY = "summary"
 ATTR_TOTAL_MATCHES = "total_matches"
+ATTR_PROCESS_TIME = "process_time"
 
 CONF_AREA = "area"
 CONF_BOTTOM = "bottom"
```
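Moving `os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"` to module level matters because the variable is read once, when TensorFlow's C++ runtime first initializes; setting it after `import tensorflow` has no effect. A minimal sketch of the ordering, assuming TensorFlow 2.x is installed:

```python
import os

# Must be set before the first `import tensorflow`; the C++ logger reads it once.
# 0 = all messages, 1 = hide INFO, 2 = hide INFO and WARNING, 3 = also hide ERROR.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf  # noqa: E402  # deliberately imported after setting the env var

print(tf.__version__)  # imports without the usual INFO/WARNING chatter
```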
```diff
@@ -34,6 +41,7 @@ CONF_CATEGORY = "category"
 CONF_FILE_OUT = "file_out"
 CONF_GRAPH = "graph"
 CONF_LABELS = "labels"
+CONF_LABEL_OFFSET = "label_offset"
 CONF_LEFT = "left"
 CONF_MODEL = "model"
 CONF_MODEL_DIR = "model_dir"
```
```diff
@@ -58,12 +66,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
         vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
         vol.Required(CONF_MODEL): vol.Schema(
             {
-                vol.Required(CONF_GRAPH): cv.isfile,
+                vol.Required(CONF_GRAPH): cv.isdir,
                 vol.Optional(CONF_AREA): AREA_SCHEMA,
                 vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
                     cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
                 ),
                 vol.Optional(CONF_LABELS): cv.isfile,
+                vol.Optional(CONF_LABEL_OFFSET, default=1): int,
                 vol.Optional(CONF_MODEL_DIR): cv.isdir,
             }
         ),
```
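This schema change is the user-facing part of the upgrade: `graph` now validates a model directory (the TF2 `pipeline.config` plus `checkpoint/` layout) instead of a single frozen `.pb` file. A hedged sketch of what that validation amounts to, with a stand-in `isdir` in place of Home Assistant's `cv.isdir`:

```python
import os

import voluptuous as vol


def isdir(value):
    """Stand-in for Home Assistant's cv.isdir validator."""
    if not os.path.isdir(str(value)):
        raise vol.Invalid(f"not a directory: {value}")
    return str(value)


# `graph` now points at a directory, not at a frozen_inference_graph.pb file
# as in the TF1 integration.
MODEL_SCHEMA = vol.Schema({vol.Required("graph"): isdir})
MODEL_SCHEMA({"graph": "/tmp"})  # passes for any existing directory
```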
```diff
@@ -71,17 +80,40 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
 )
 
 
+def get_model_detection_function(model):
+    """Get a tf.function for detection."""
+
+    @tf.function
+    def detect_fn(image):
+        """Detect objects in image."""
+
+        image, shapes = model.preprocess(image)
+        prediction_dict = model.predict(image, shapes)
+        detections = model.postprocess(prediction_dict, shapes)
+
+        return detections
+
+    return detect_fn
+
+
 def setup_platform(hass, config, add_entities, discovery_info=None):
     """Set up the TensorFlow image processing platform."""
-    model_config = config.get(CONF_MODEL)
+    model_config = config[CONF_MODEL]
     model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow")
     labels = model_config.get(CONF_LABELS) or hass.config.path(
         "tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt"
     )
+    checkpoint = os.path.join(model_config[CONF_GRAPH], "checkpoint")
+    pipeline_config = os.path.join(model_config[CONF_GRAPH], "pipeline.config")
 
     # Make sure locations exist
-    if not os.path.isdir(model_dir) or not os.path.exists(labels):
-        _LOGGER.error("Unable to locate tensorflow models or label map")
+    if (
+        not os.path.isdir(model_dir)
+        or not os.path.isdir(checkpoint)
+        or not os.path.exists(pipeline_config)
+        or not os.path.exists(labels)
+    ):
+        _LOGGER.error("Unable to locate tensorflow model or label map")
         return
 
     # append custom model path to sys.path
```
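The new `get_model_detection_function` wraps the model's preprocess/predict/postprocess chain in `@tf.function`, which compiles it into a graph on first call. A self-contained sketch of that tracing behavior (the toy function is illustrative, not the integration's model):

```python
import tensorflow as tf


@tf.function
def double(x):
    # Traced into a graph on the first call for each new input signature;
    # later calls with a compatible shape/dtype reuse the compiled graph.
    return x * 2


print(double(tf.constant([1, 2, 3])))  # tf.Tensor([2 4 6], shape=(3,), dtype=int32)
```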
```diff
@@ -89,18 +121,17 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
 
     try:
         # Verify that the TensorFlow Object Detection API is pre-installed
-        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
         # These imports shouldn't be moved to the top, because they depend on code from the model_dir.
         # (The model_dir is created during the manual setup process. See integration docs.)
-        import tensorflow as tf  # pylint: disable=import-outside-toplevel
 
         # pylint: disable=import-outside-toplevel
-        from object_detection.utils import label_map_util
+        from object_detection.utils import config_util, label_map_util
+        from object_detection.builders import model_builder
     except ImportError:
         _LOGGER.error(
             "No TensorFlow Object Detection library found! Install or compile "
             "for your system following instructions here: "
-            "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md"
+            "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#installation"
         )
         return
 
```
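The `try/except ImportError` keeps the Object Detection API out of the module's top-level imports: the dependency lives in the user-created `model_dir`, so a missing install degrades to a logged error instead of breaking Home Assistant startup. A minimal sketch of the guard pattern (`verify_object_detection_api` is a hypothetical helper, not part of the integration):

```python
import logging

_LOGGER = logging.getLogger(__name__)


def verify_object_detection_api():
    """Hypothetical helper: return True if the TF2 Object Detection API imports."""
    try:
        # pylint: disable=import-outside-toplevel
        from object_detection.builders import model_builder  # noqa: F401
        from object_detection.utils import config_util, label_map_util  # noqa: F401
    except ImportError:
        _LOGGER.error("TensorFlow Object Detection API is not installed")
        return False
    return True
```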
```diff
@@ -113,22 +144,45 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
             "PIL at reduced resolution"
         )
 
-    # Set up Tensorflow graph, session, and label map to pass to processor
-    # pylint: disable=no-member
-    detection_graph = tf.Graph()
-    with detection_graph.as_default():
-        od_graph_def = tf.GraphDef()
-        with tf.gfile.GFile(model_config.get(CONF_GRAPH), "rb") as fid:
-            serialized_graph = fid.read()
-            od_graph_def.ParseFromString(serialized_graph)
-            tf.import_graph_def(od_graph_def, name="")
+    hass.data[DOMAIN] = {CONF_MODEL: None}
 
-    session = tf.Session(graph=detection_graph)
-    label_map = label_map_util.load_labelmap(labels)
-    categories = label_map_util.convert_label_map_to_categories(
-        label_map, max_num_classes=90, use_display_name=True
+    def tensorflow_hass_start(_event):
+        """Set up TensorFlow model on hass start."""
+        start = time.perf_counter()
+
+        # Load pipeline config and build a detection model
+        pipeline_configs = config_util.get_configs_from_pipeline_file(pipeline_config)
+        detection_model = model_builder.build(
+            model_config=pipeline_configs["model"], is_training=False
+        )
+
+        # Restore checkpoint
+        ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
+        ckpt.restore(os.path.join(checkpoint, "ckpt-0")).expect_partial()
+
+        _LOGGER.debug(
+            "Model checkpoint restore took %d seconds", time.perf_counter() - start
+        )
+
+        model = get_model_detection_function(detection_model)
+
+        # Preload model cache with empty image tensor
+        inp = np.zeros([2160, 3840, 3], dtype=np.uint8)
+        # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
+        input_tensor = tf.convert_to_tensor(inp, dtype=tf.float32)
+        # The model expects a batch of images, so add an axis with `tf.newaxis`.
+        input_tensor = input_tensor[tf.newaxis, ...]
+        # Run inference
+        model(input_tensor)
+
+        _LOGGER.debug("Model load took %d seconds", time.perf_counter() - start)
+        hass.data[DOMAIN][CONF_MODEL] = model
+
+    hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
+
+    category_index = label_map_util.create_category_index_from_labelmap(
+        labels, use_display_name=True
     )
-    category_index = label_map_util.create_category_index(categories)
 
     entities = []
 
```
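This is the structural heart of the change: instead of building a TF1 graph and session during platform setup, the model is now constructed in a listener that fires once Home Assistant has fully started, so TensorFlow's multi-second load no longer blocks startup; the warm-up inference on a blank 4K frame also pays the one-time `tf.function` tracing cost up front. A hedged sketch of the deferred-load pattern (`load_detection_model` is a hypothetical stand-in for the checkpoint-restore code above):

```python
import logging

DOMAIN = "tensorflow"
CONF_MODEL = "model"
EVENT_HOMEASSISTANT_START = "homeassistant_start"  # value of the HA constant

_LOGGER = logging.getLogger(__name__)


def setup_platform(hass, config, add_entities, discovery_info=None):
    """Sketch: publish an empty model slot, defer the expensive load."""
    hass.data[DOMAIN] = {CONF_MODEL: None}

    def tensorflow_hass_start(_event):
        _LOGGER.debug("Loading detection model after startup")
        hass.data[DOMAIN][CONF_MODEL] = load_detection_model()  # hypothetical loader

    hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
```

Until the listener runs, entities see `None` in the model slot and skip processing, which is exactly what `process_image` checks for below.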
```diff
@@ -138,8 +192,6 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
                 hass,
                 camera[CONF_ENTITY_ID],
                 camera.get(CONF_NAME),
-                session,
-                detection_graph,
                 category_index,
                 config,
             )
```
```diff
@@ -152,14 +204,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
     """Representation of an TensorFlow image processor."""
 
     def __init__(
-        self,
-        hass,
-        camera_entity,
-        name,
-        session,
-        detection_graph,
-        category_index,
-        config,
+        self, hass, camera_entity, name, category_index, config,
     ):
         """Initialize the TensorFlow entity."""
         model_config = config.get(CONF_MODEL)
```
```diff
@@ -169,13 +214,12 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
             self._name = name
         else:
             self._name = "TensorFlow {}".format(split_entity_id(camera_entity)[1])
-        self._session = session
-        self._graph = detection_graph
         self._category_index = category_index
         self._min_confidence = config.get(CONF_CONFIDENCE)
         self._file_out = config.get(CONF_FILE_OUT)
 
         # handle categories and specific detection areas
+        self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
         categories = model_config.get(CONF_CATEGORIES)
         self._include_categories = []
         self._category_areas = {}
```
```diff
@@ -212,6 +256,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
         self._matches = {}
         self._total_matches = 0
         self._last_image = None
+        self._process_time = 0
 
     @property
     def camera_entity(self):
```
```diff
@@ -237,6 +282,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
                 category: len(values) for category, values in self._matches.items()
             },
             ATTR_TOTAL_MATCHES: self._total_matches,
+            ATTR_PROCESS_TIME: self._process_time,
         }
 
     def _save_image(self, image, matches, paths):
```
```diff
@@ -281,10 +327,16 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
 
     def process_image(self, image):
         """Process the image."""
+        model = self.hass.data[DOMAIN][CONF_MODEL]
+        if not model:
+            _LOGGER.debug("Model not yet ready.")
+            return
+
+        start = time.perf_counter()
         try:
             import cv2  # pylint: disable=import-error, import-outside-toplevel
 
             # pylint: disable=no-member
             img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
             inp = img[:, :, [2, 1, 0]]  # BGR->RGB
             inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
```
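For reference, the OpenCV fast path above decodes the raw camera bytes and reorders channels because OpenCV yields BGR while the detector expects RGB. A hedged sketch of that decode step (requires `opencv-python`; `decode_camera_image` is an illustrative name, not part of the integration):

```python
import cv2
import numpy as np


def decode_camera_image(image: bytes) -> np.ndarray:
    """Illustrative: decode raw JPEG/PNG bytes into a batched RGB uint8 array."""
    img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
    rgb = img[:, :, [2, 1, 0]]  # OpenCV decodes to BGR; the model expects RGB
    return rgb.reshape(1, rgb.shape[0], rgb.shape[1], 3)  # add the batch axis
```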
```diff
@@ -303,15 +355,15 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
             )
             inp_expanded = np.expand_dims(inp, axis=0)
 
-        image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
-        boxes = self._graph.get_tensor_by_name("detection_boxes:0")
-        scores = self._graph.get_tensor_by_name("detection_scores:0")
-        classes = self._graph.get_tensor_by_name("detection_classes:0")
-        boxes, scores, classes = self._session.run(
-            [boxes, scores, classes], feed_dict={image_tensor: inp_expanded}
-        )
-        boxes, scores, classes = map(np.squeeze, [boxes, scores, classes])
-        classes = classes.astype(int)
+        # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
+        input_tensor = tf.convert_to_tensor(inp_expanded, dtype=tf.float32)
+
+        detections = model(input_tensor)
+        boxes = detections["detection_boxes"][0].numpy()
+        scores = detections["detection_scores"][0].numpy()
+        classes = (
+            detections["detection_classes"][0].numpy() + self._label_id_offset
+        ).astype(int)
 
         matches = {}
         total_matches = 0
```
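The `label_id_offset` applied here compensates for the TF2 Object Detection API emitting zero-based class ids, while label maps such as `mscoco_label_map.pbtxt` are one-based; hence the new `label_offset` option with a default of 1. A small numeric sketch (class ids are illustrative):

```python
import numpy as np

raw = np.array([0.0, 17.0, 56.0])  # illustrative zero-based ids from the model
label_offset = 1                   # CONF_LABEL_OFFSET default
print((raw + label_offset).astype(int))  # [ 1 18 57] -> matches a one-based label map
```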
```diff
@@ -367,3 +419,4 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
 
         self._matches = matches
         self._total_matches = total_matches
+        self._process_time = time.perf_counter() - start
```
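Finally, the new `process_time` attribute is measured with `time.perf_counter()`, a monotonic high-resolution clock suited to elapsed-time measurement (unlike `time.time()`, it cannot jump backwards on clock adjustments). The pattern in miniature:

```python
import time

start = time.perf_counter()
# ... decode the frame and run the detector ...
process_time = time.perf_counter() - start  # elapsed seconds, exposed as an attribute
```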