hass-core/homeassistant/scripts/db_migrator.py
Nathan Henrie aa079625d4 Don't overwrite the config directory (#2570)
Closes #2566

The `else` seems to have been an error and was overwriting a non-default config directory with the default location.
2016-07-19 21:51:38 -07:00

189 lines
6 KiB
Python

"""Script to convert an old-format home-assistant.db to a new format one."""
import argparse
import os.path
import sqlite3
import sys
# SQLAlchemy is imported up front; if it is not installed, print an install
# hint and exit with status 1 at import time instead of surfacing a raw
# ImportError traceback.
# NOTE(review): presumably SQLAlchemy is not guaranteed to be installed
# alongside Home Assistant itself — confirm against the requirements.
try:
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
except ImportError:
    print('Fatal Error: SQLAlchemy is missing. Install it with '
          '"pip3 install SQLAlchemy" before running this script')
    sys.exit(1)
from homeassistant.components.recorder import models
import homeassistant.config as config_util
import homeassistant.util.dt as dt_util
def ts_to_dt(timestamp):
    """Convert a legacy numeric timestamp into a UTC datetime.

    Columns read from the old-format database are passed through this
    before being stored on the new models. ``None`` (a NULL column) is
    returned unchanged; every other value is handed to
    ``dt_util.utc_from_timestamp``.
    """
    if timestamp is None:
        return None
    return dt_util.utc_from_timestamp(timestamp)
# Based on code at
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
# pylint: disable=too-many-arguments
def print_progress(iteration, total, prefix='', suffix='', decimals=2,
                   bar_length=68):
    """Print a single-line terminal progress bar.

    Call repeatedly in a loop. Each call ends with a carriage return, so the
    next call overwrites the bar in place. When ``iteration`` equals
    ``total``, ``print("\\n")`` is issued so later output starts below the
    bar.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int); 0 is drawn as a
                                  complete bar instead of raising
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : number of decimals in percent complete (Int)
        bar_length  - Optional  : character length of bar (Int)
    """
    if total:
        filled_length = int(round(bar_length * iteration / float(total)))
        percents = round(100.00 * (iteration / float(total)), decimals)
    else:
        # Nothing to process (e.g. an empty table was converted): draw a
        # full bar rather than dividing by zero.
        filled_length = bar_length
        percents = round(100.00, decimals)
    line = '#' * filled_length + '-' * (bar_length - filled_length)
    sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, line,
                                            percents, '%', suffix))
    sys.stdout.flush()
    if iteration == total:
        print("\n")
def _count_rows(conn, table):
    """Return the number of rows in ``table`` of the legacy SQLite DB."""
    cursor = conn.cursor()
    try:
        # ``table`` is only ever one of the hard-coded names used in run().
        cursor.execute("SELECT count(*) FROM {}".format(table))
        return cursor.fetchone()[0]
    finally:
        cursor.close()


def _convert_table(conn, session, table, convert_row):
    """Pass every row of legacy ``table`` to ``convert_row``.

    ``convert_row`` is expected to add the converted row to ``session``.
    The session is committed (and progress printed) every 1000 rows and
    once more after the last row.
    """
    num_rows = _count_rows(conn, table)
    print("Converting {} {}".format(num_rows, table))
    cursor = conn.cursor()
    try:
        num_done = 0
        # ``table`` is only ever one of the hard-coded names used in run().
        rows = cursor.execute("SELECT * FROM {}".format(table))
        for num_done, row in enumerate(rows, 1):
            convert_row(row)
            if num_done % 1000 == 0:
                session.commit()
                print_progress(num_done, num_rows)
        print_progress(num_done, num_rows)
        session.commit()
    finally:
        cursor.close()


def run(args):
    """Migrate a legacy ``home-assistant.db`` to the SQLAlchemy format.

    Options are parsed from ``sys.argv`` by ``parser.parse_args()``; the
    ``args`` parameter is immediately rebound and not otherwise used.
    Rows of the ``recorder_runs``, ``events`` and ``states`` tables of
    ``<config>/home-assistant.db`` are copied into
    ``<config>/home-assistant_v2.db``, or into the database at ``--uri``.

    Returns 0 on success, 1 on a fatal error (missing non-default config
    directory, missing source DB, or existing destination DB without
    ``--append``/``--uri``).
    """
    # pylint: disable=too-many-locals,invalid-name,too-many-statements
    parser = argparse.ArgumentParser(
        description="Migrate legacy DB to SQLAlchemy format.")
    parser.add_argument(
        '-c', '--config',
        metavar='path_to_config_dir',
        default=config_util.get_default_config_dir(),
        help="Directory that contains the Home Assistant configuration")
    parser.add_argument(
        '-a', '--append',
        action='store_true',
        default=False,
        help="Append to existing new format SQLite database")
    parser.add_argument(
        '--uri',
        type=str,
        help="Connect to URI and import (implies --append) "
             "eg: mysql://localhost/homeassistant")
    parser.add_argument(
        '--script',
        choices=['db_migrator'])

    args = parser.parse_args()

    config_dir = os.path.join(os.getcwd(), args.config)

    # Only a non-default config dir is required to exist here; a missing
    # default dir is reported below as a missing source database.
    if not os.path.isdir(config_dir):
        if config_dir != config_util.get_default_config_dir():
            print(('Fatal Error: Specified configuration directory does '
                   'not exist {} ').format(config_dir))
            return 1

    src_db = '{}/home-assistant.db'.format(config_dir)
    dst_db = '{}/home-assistant_v2.db'.format(config_dir)

    if not os.path.exists(src_db):
        print("Fatal Error: Old format database '{}' does not exist".format(
            src_db))
        return 1
    if not args.uri and (os.path.exists(dst_db) and not args.append):
        print("Fatal Error: New format database '{}' exists already - "
              "Remove it or use --append".format(dst_db))
        print("Note: --append must maintain an ID mapping and is much slower "
              "and requires sufficient memory to track all event IDs")
        return 1

    conn = sqlite3.connect(src_db)
    uri = args.uri or "sqlite:///{}".format(dst_db)
    session = None
    try:
        engine = create_engine(uri, echo=False)
        models.Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)
        session = session_factory()

        # --uri implies appending. When appending, each new event id is
        # recorded so states can be re-pointed at it.
        append = args.append or args.uri
        id_mapping = {}

        def convert_recorder_run(row):
            """Add one legacy recorder_runs row to the session."""
            session.add(models.RecorderRuns(
                start=ts_to_dt(row[1]),
                end=ts_to_dt(row[2]),
                closed_incorrect=row[3],
                created=ts_to_dt(row[4])
            ))

        def convert_event(row):
            """Add one legacy events row; record its new id if appending."""
            event = models.Events(
                event_type=row[1],
                event_data=row[2],
                origin=row[3],
                created=ts_to_dt(row[4]),
                time_fired=ts_to_dt(row[5]),
            )
            session.add(event)
            if append:
                # Flush so the new primary key is available for the mapping.
                session.flush()
                id_mapping[row[0]] = event.event_id

        def convert_state(row):
            """Add one legacy states row, translating its event id."""
            # NOTE(review): without --append the mapping is empty and the
            # old event id is kept as-is; this assumes the freshly created
            # events table assigned the same ids (no gaps in the old
            # table) — confirm.
            session.add(models.States(
                entity_id=row[1],
                state=row[2],
                attributes=row[3],
                last_changed=ts_to_dt(row[4]),
                last_updated=ts_to_dt(row[5]),
                event_id=id_mapping.get(row[6], row[6]),
                domain=row[7]
            ))

        _convert_table(conn, session, 'recorder_runs', convert_recorder_run)
        _convert_table(conn, session, 'events', convert_event)
        _convert_table(conn, session, 'states', convert_state)
    finally:
        # Release both database handles even if a conversion step fails.
        if session is not None:
            session.close()
        conn.close()

    return 0