Upgrade to TensorFlow 2 (#38384)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Jason Hunter 2020-08-07 02:56:28 -04:00 committed by GitHub
parent 7e34c2582f
commit 3546a82cfb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 115 additions and 48 deletions

View file

@ -89,5 +89,6 @@ jobs:
sed -i "s|# py_noaa|py_noaa|g" ${requirement_file} sed -i "s|# py_noaa|py_noaa|g" ${requirement_file}
sed -i "s|# bme680|bme680|g" ${requirement_file} sed -i "s|# bme680|bme680|g" ${requirement_file}
sed -i "s|# python-gammu|python-gammu|g" ${requirement_file} sed -i "s|# python-gammu|python-gammu|g" ${requirement_file}
sed -i "s|# tf-models-official|tf-models-official|g" ${requirement_file}
done done
displayName: 'Prepare requirements files for Home Assistant wheels' displayName: 'Prepare requirements files for Home Assistant wheels'

View file

@ -3,9 +3,11 @@ import io
import logging import logging
import os import os
import sys import sys
import time
from PIL import Image, ImageDraw, UnidentifiedImageError from PIL import Image, ImageDraw, UnidentifiedImageError
import numpy as np import numpy as np
import tensorflow as tf
import voluptuous as vol import voluptuous as vol
from homeassistant.components.image_processing import ( from homeassistant.components.image_processing import (
@ -16,16 +18,21 @@ from homeassistant.components.image_processing import (
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
ImageProcessingEntity, ImageProcessingEntity,
) )
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import split_entity_id from homeassistant.core import split_entity_id
from homeassistant.helpers import template from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.pil import draw_box from homeassistant.util.pil import draw_box
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
DOMAIN = "tensorflow"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_MATCHES = "matches" ATTR_MATCHES = "matches"
ATTR_SUMMARY = "summary" ATTR_SUMMARY = "summary"
ATTR_TOTAL_MATCHES = "total_matches" ATTR_TOTAL_MATCHES = "total_matches"
ATTR_PROCESS_TIME = "process_time"
CONF_AREA = "area" CONF_AREA = "area"
CONF_BOTTOM = "bottom" CONF_BOTTOM = "bottom"
@ -34,6 +41,7 @@ CONF_CATEGORY = "category"
CONF_FILE_OUT = "file_out" CONF_FILE_OUT = "file_out"
CONF_GRAPH = "graph" CONF_GRAPH = "graph"
CONF_LABELS = "labels" CONF_LABELS = "labels"
CONF_LABEL_OFFSET = "label_offset"
CONF_LEFT = "left" CONF_LEFT = "left"
CONF_MODEL = "model" CONF_MODEL = "model"
CONF_MODEL_DIR = "model_dir" CONF_MODEL_DIR = "model_dir"
@ -58,12 +66,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]), vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
vol.Required(CONF_MODEL): vol.Schema( vol.Required(CONF_MODEL): vol.Schema(
{ {
vol.Required(CONF_GRAPH): cv.isfile, vol.Required(CONF_GRAPH): cv.isdir,
vol.Optional(CONF_AREA): AREA_SCHEMA, vol.Optional(CONF_AREA): AREA_SCHEMA,
vol.Optional(CONF_CATEGORIES, default=[]): vol.All( vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)] cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
), ),
vol.Optional(CONF_LABELS): cv.isfile, vol.Optional(CONF_LABELS): cv.isfile,
vol.Optional(CONF_LABEL_OFFSET, default=1): int,
vol.Optional(CONF_MODEL_DIR): cv.isdir, vol.Optional(CONF_MODEL_DIR): cv.isdir,
} }
), ),
@ -71,17 +80,40 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
def get_model_detection_function(model):
"""Get a tf.function for detection."""
@tf.function
def detect_fn(image):
"""Detect objects in image."""
image, shapes = model.preprocess(image)
prediction_dict = model.predict(image, shapes)
detections = model.postprocess(prediction_dict, shapes)
return detections
return detect_fn
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
"""Set up the TensorFlow image processing platform.""" """Set up the TensorFlow image processing platform."""
model_config = config.get(CONF_MODEL) model_config = config[CONF_MODEL]
model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow") model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow")
labels = model_config.get(CONF_LABELS) or hass.config.path( labels = model_config.get(CONF_LABELS) or hass.config.path(
"tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt" "tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt"
) )
checkpoint = os.path.join(model_config[CONF_GRAPH], "checkpoint")
pipeline_config = os.path.join(model_config[CONF_GRAPH], "pipeline.config")
# Make sure locations exist # Make sure locations exist
if not os.path.isdir(model_dir) or not os.path.exists(labels): if (
_LOGGER.error("Unable to locate tensorflow models or label map") not os.path.isdir(model_dir)
or not os.path.isdir(checkpoint)
or not os.path.exists(pipeline_config)
or not os.path.exists(labels)
):
_LOGGER.error("Unable to locate tensorflow model or label map")
return return
# append custom model path to sys.path # append custom model path to sys.path
@ -89,18 +121,17 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
try: try:
# Verify that the TensorFlow Object Detection API is pre-installed # Verify that the TensorFlow Object Detection API is pre-installed
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# These imports shouldn't be moved to the top, because they depend on code from the model_dir. # These imports shouldn't be moved to the top, because they depend on code from the model_dir.
# (The model_dir is created during the manual setup process. See integration docs.) # (The model_dir is created during the manual setup process. See integration docs.)
import tensorflow as tf # pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from object_detection.utils import label_map_util from object_detection.utils import config_util, label_map_util
from object_detection.builders import model_builder
except ImportError: except ImportError:
_LOGGER.error( _LOGGER.error(
"No TensorFlow Object Detection library found! Install or compile " "No TensorFlow Object Detection library found! Install or compile "
"for your system following instructions here: " "for your system following instructions here: "
"https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md" "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#installation"
) )
return return
@ -113,22 +144,45 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
"PIL at reduced resolution" "PIL at reduced resolution"
) )
# Set up Tensorflow graph, session, and label map to pass to processor hass.data[DOMAIN] = {CONF_MODEL: None}
# pylint: disable=no-member
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(model_config.get(CONF_GRAPH), "rb") as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name="")
session = tf.Session(graph=detection_graph) def tensorflow_hass_start(_event):
label_map = label_map_util.load_labelmap(labels) """Set up TensorFlow model on hass start."""
categories = label_map_util.convert_label_map_to_categories( start = time.perf_counter()
label_map, max_num_classes=90, use_display_name=True
# Load pipeline config and build a detection model
pipeline_configs = config_util.get_configs_from_pipeline_file(pipeline_config)
detection_model = model_builder.build(
model_config=pipeline_configs["model"], is_training=False
)
# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(checkpoint, "ckpt-0")).expect_partial()
_LOGGER.debug(
"Model checkpoint restore took %d seconds", time.perf_counter() - start
)
model = get_model_detection_function(detection_model)
# Preload model cache with empty image tensor
inp = np.zeros([2160, 3840, 3], dtype=np.uint8)
# The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
input_tensor = tf.convert_to_tensor(inp, dtype=tf.float32)
# The model expects a batch of images, so add an axis with `tf.newaxis`.
input_tensor = input_tensor[tf.newaxis, ...]
# Run inference
model(input_tensor)
_LOGGER.debug("Model load took %d seconds", time.perf_counter() - start)
hass.data[DOMAIN][CONF_MODEL] = model
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
category_index = label_map_util.create_category_index_from_labelmap(
labels, use_display_name=True
) )
category_index = label_map_util.create_category_index(categories)
entities = [] entities = []
@ -138,8 +192,6 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass, hass,
camera[CONF_ENTITY_ID], camera[CONF_ENTITY_ID],
camera.get(CONF_NAME), camera.get(CONF_NAME),
session,
detection_graph,
category_index, category_index,
config, config,
) )
@ -152,14 +204,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
"""Representation of an TensorFlow image processor.""" """Representation of an TensorFlow image processor."""
def __init__( def __init__(
self, self, hass, camera_entity, name, category_index, config,
hass,
camera_entity,
name,
session,
detection_graph,
category_index,
config,
): ):
"""Initialize the TensorFlow entity.""" """Initialize the TensorFlow entity."""
model_config = config.get(CONF_MODEL) model_config = config.get(CONF_MODEL)
@ -169,13 +214,12 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
self._name = name self._name = name
else: else:
self._name = "TensorFlow {}".format(split_entity_id(camera_entity)[1]) self._name = "TensorFlow {}".format(split_entity_id(camera_entity)[1])
self._session = session
self._graph = detection_graph
self._category_index = category_index self._category_index = category_index
self._min_confidence = config.get(CONF_CONFIDENCE) self._min_confidence = config.get(CONF_CONFIDENCE)
self._file_out = config.get(CONF_FILE_OUT) self._file_out = config.get(CONF_FILE_OUT)
# handle categories and specific detection areas # handle categories and specific detection areas
self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
categories = model_config.get(CONF_CATEGORIES) categories = model_config.get(CONF_CATEGORIES)
self._include_categories = [] self._include_categories = []
self._category_areas = {} self._category_areas = {}
@ -212,6 +256,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
self._matches = {} self._matches = {}
self._total_matches = 0 self._total_matches = 0
self._last_image = None self._last_image = None
self._process_time = 0
@property @property
def camera_entity(self): def camera_entity(self):
@ -237,6 +282,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
category: len(values) for category, values in self._matches.items() category: len(values) for category, values in self._matches.items()
}, },
ATTR_TOTAL_MATCHES: self._total_matches, ATTR_TOTAL_MATCHES: self._total_matches,
ATTR_PROCESS_TIME: self._process_time,
} }
def _save_image(self, image, matches, paths): def _save_image(self, image, matches, paths):
@ -281,10 +327,16 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
def process_image(self, image): def process_image(self, image):
"""Process the image.""" """Process the image."""
model = self.hass.data[DOMAIN][CONF_MODEL]
if not model:
_LOGGER.debug("Model not yet ready.")
return
start = time.perf_counter()
try: try:
import cv2 # pylint: disable=import-error, import-outside-toplevel import cv2 # pylint: disable=import-error, import-outside-toplevel
# pylint: disable=no-member
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED) img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3) inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
@ -303,15 +355,15 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
) )
inp_expanded = np.expand_dims(inp, axis=0) inp_expanded = np.expand_dims(inp, axis=0)
image_tensor = self._graph.get_tensor_by_name("image_tensor:0") # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
boxes = self._graph.get_tensor_by_name("detection_boxes:0") input_tensor = tf.convert_to_tensor(inp_expanded, dtype=tf.float32)
scores = self._graph.get_tensor_by_name("detection_scores:0")
classes = self._graph.get_tensor_by_name("detection_classes:0") detections = model(input_tensor)
boxes, scores, classes = self._session.run( boxes = detections["detection_boxes"][0].numpy()
[boxes, scores, classes], feed_dict={image_tensor: inp_expanded} scores = detections["detection_scores"][0].numpy()
) classes = (
boxes, scores, classes = map(np.squeeze, [boxes, scores, classes]) detections["detection_classes"][0].numpy() + self._label_id_offset
classes = classes.astype(int) ).astype(int)
matches = {} matches = {}
total_matches = 0 total_matches = 0
@ -367,3 +419,4 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
self._matches = matches self._matches = matches
self._total_matches = total_matches self._total_matches = total_matches
self._process_time = time.perf_counter() - start

View file

@ -3,9 +3,12 @@
"name": "TensorFlow", "name": "TensorFlow",
"documentation": "https://www.home-assistant.io/integrations/tensorflow", "documentation": "https://www.home-assistant.io/integrations/tensorflow",
"requirements": [ "requirements": [
"tensorflow==1.13.2", "tensorflow==2.2.0",
"tf-slim==1.1.0",
"tf-models-official==2.2.1",
"pycocotools==2.0.1",
"numpy==1.19.1", "numpy==1.19.1",
"protobuf==3.6.1", "protobuf==3.12.2",
"pillow==7.1.2" "pillow==7.1.2"
], ],
"codeowners": [] "codeowners": []

View file

@ -5,7 +5,7 @@ ignore=tests
jobs=2 jobs=2
load-plugins=pylint_strict_informational load-plugins=pylint_strict_informational
persistent=no persistent=no
extension-pkg-whitelist=ciso8601 extension-pkg-whitelist=ciso8601,cv2
[BASIC] [BASIC]
good-names=id,i,j,k,ex,Run,_,fp,T,ev good-names=id,i,j,k,ex,Run,_,fp,T,ev

View file

@ -1120,7 +1120,7 @@ proliphix==0.4.1
prometheus_client==0.7.1 prometheus_client==0.7.1
# homeassistant.components.tensorflow # homeassistant.components.tensorflow
protobuf==3.6.1 protobuf==3.12.2
# homeassistant.components.proxmoxve # homeassistant.components.proxmoxve
proxmoxer==1.1.1 proxmoxer==1.1.1
@ -1261,6 +1261,9 @@ pychromecast==7.2.0
# homeassistant.components.cmus # homeassistant.components.cmus
pycmus==0.1.1 pycmus==0.1.1
# homeassistant.components.tensorflow
pycocotools==2.0.1
# homeassistant.components.comfoconnect # homeassistant.components.comfoconnect
pycomfoconnect==0.3 pycomfoconnect==0.3
@ -2098,7 +2101,7 @@ temescal==0.1
temperusb==1.5.3 temperusb==1.5.3
# homeassistant.components.tensorflow # homeassistant.components.tensorflow
# tensorflow==1.13.2 # tensorflow==2.2.0
# homeassistant.components.powerwall # homeassistant.components.powerwall
tesla-powerwall==0.2.12 tesla-powerwall==0.2.12
@ -2106,6 +2109,12 @@ tesla-powerwall==0.2.12
# homeassistant.components.tesla # homeassistant.components.tesla
teslajsonpy==0.10.1 teslajsonpy==0.10.1
# homeassistant.components.tensorflow
# tf-models-official==2.2.1
# homeassistant.components.tensorflow
tf-slim==1.1.0
# homeassistant.components.thermoworks_smoke # homeassistant.components.thermoworks_smoke
thermoworks_smoke==0.1.8 thermoworks_smoke==0.1.8

View file

@ -41,6 +41,7 @@ COMMENT_REQUIREMENTS = (
"RPi.GPIO", "RPi.GPIO",
"smbus-cffi", "smbus-cffi",
"tensorflow", "tensorflow",
"tf-models-official",
"VL53L1X2", "VL53L1X2",
) )