From 542aef2fe1ceebdb20df1aeb255e242a0e444df7 Mon Sep 17 00:00:00 2001
From: Franck Nijhof <git@frenck.dev>
Date: Wed, 1 Dec 2021 00:38:45 +0100
Subject: [PATCH] Migrate switch device classes to StrEnum (#60658)

---
 homeassistant/components/demo/switch.py       |  6 ++--
 homeassistant/components/huawei_lte/switch.py |  8 ++---
 homeassistant/components/switch/__init__.py   | 36 +++++++++++++++----
 homeassistant/components/upcloud/__init__.py  |  3 --
 4 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/homeassistant/components/demo/switch.py b/homeassistant/components/demo/switch.py
index dd969010bd7..09389f7c8cd 100644
--- a/homeassistant/components/demo/switch.py
+++ b/homeassistant/components/demo/switch.py
@@ -1,7 +1,7 @@
 """Demo platform that has two fake switches."""
 from __future__ import annotations
 
-from homeassistant.components.switch import SwitchEntity
+from homeassistant.components.switch import SwitchDeviceClass, SwitchEntity
 from homeassistant.const import DEVICE_DEFAULT_NAME
 from homeassistant.helpers.entity import DeviceInfo
 
@@ -19,7 +19,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
                 False,
                 "mdi:air-conditioner",
                 False,
-                device_class="outlet",
+                device_class=SwitchDeviceClass.OUTLET,
             ),
         ]
     )
@@ -42,7 +42,7 @@ class DemoSwitch(SwitchEntity):
         state: bool,
         icon: str | None,
         assumed: bool,
-        device_class: str | None = None,
+        device_class: SwitchDeviceClass | None = None,
     ) -> None:
         """Initialize the Demo switch."""
         self._attr_assumed_state = assumed
diff --git a/homeassistant/components/huawei_lte/switch.py b/homeassistant/components/huawei_lte/switch.py
index a4fd393346c..af2a382c4db 100644
--- a/homeassistant/components/huawei_lte/switch.py
+++ b/homeassistant/components/huawei_lte/switch.py
@@ -7,8 +7,8 @@ from typing import Any
 import attr
 
 from homeassistant.components.switch import (
-    DEVICE_CLASS_SWITCH,
     DOMAIN as SWITCH_DOMAIN,
+    SwitchDeviceClass,
     SwitchEntity,
 )
 from homeassistant.config_entries import ConfigEntry
@@ -43,6 +43,7 @@ class HuaweiLteBaseSwitch(HuaweiLteBaseEntity, SwitchEntity):
 
     key: str
     item: str
+    _attr_device_class = SwitchDeviceClass.SWITCH
     _raw_state: str | None = attr.ib(init=False, default=None)
 
     def _turn(self, state: bool) -> None:
@@ -56,11 +57,6 @@ class HuaweiLteBaseSwitch(HuaweiLteBaseEntity, SwitchEntity):
         """Turn switch off."""
         self._turn(state=False)
 
-    @property
-    def device_class(self) -> str:
-        """Return device class."""
-        return DEVICE_CLASS_SWITCH
-
     async def async_added_to_hass(self) -> None:
         """Subscribe to needed data on add."""
         await super().async_added_to_hass()
diff --git a/homeassistant/components/switch/__init__.py b/homeassistant/components/switch/__init__.py
index b023b819a53..fec5a58f414 100644
--- a/homeassistant/components/switch/__init__.py
+++ b/homeassistant/components/switch/__init__.py
@@ -24,6 +24,7 @@ from homeassistant.helpers.entity import ToggleEntity, ToggleEntityDescription
 from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.typing import ConfigType
 from homeassistant.loader import bind_hass
+from homeassistant.util.enum import StrEnum
 
 DOMAIN = "switch"
 SCAN_INTERVAL = timedelta(seconds=30)
@@ -40,16 +41,25 @@ PROP_TO_ATTR = {
     "today_energy_kwh": ATTR_TODAY_ENERGY_KWH,
 }
 
-DEVICE_CLASS_OUTLET = "outlet"
-DEVICE_CLASS_SWITCH = "switch"
-
-DEVICE_CLASSES = [DEVICE_CLASS_OUTLET, DEVICE_CLASS_SWITCH]
-
-DEVICE_CLASSES_SCHEMA = vol.All(vol.Lower, vol.In(DEVICE_CLASSES))
-
 _LOGGER = logging.getLogger(__name__)
 
 
+class SwitchDeviceClass(StrEnum):
+    """Device class for switches."""
+
+    OUTLET = "outlet"
+    SWITCH = "switch"
+
+
+DEVICE_CLASSES_SCHEMA = vol.All(vol.Lower, vol.Coerce(SwitchDeviceClass))
+
+# DEVICE_CLASS* below are deprecated as of 2021.12
+# use the SwitchDeviceClass enum instead.
+DEVICE_CLASSES = [cls.value for cls in SwitchDeviceClass]
+DEVICE_CLASS_OUTLET = SwitchDeviceClass.OUTLET.value
+DEVICE_CLASS_SWITCH = SwitchDeviceClass.SWITCH.value
+
+
 @bind_hass
 def is_on(hass: HomeAssistant, entity_id: str) -> bool:
     """Return if the switch is on based on the statemachine.
@@ -89,12 +99,15 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
 class SwitchEntityDescription(ToggleEntityDescription):
     """A class that describes switch entities."""
 
+    device_class: SwitchDeviceClass | str | None = None
+
 
 class SwitchEntity(ToggleEntity):
     """Base class for switch entities."""
 
     entity_description: SwitchEntityDescription
     _attr_current_power_w: float | None = None
+    _attr_device_class: SwitchDeviceClass | str | None
     _attr_today_energy_kwh: float | None = None
 
     @property
@@ -102,6 +115,15 @@ class SwitchEntity(ToggleEntity):
         """Return the current power usage in W."""
         return self._attr_current_power_w
 
+    @property
+    def device_class(self) -> SwitchDeviceClass | str | None:
+        """Return the class of this entity."""
+        if hasattr(self, "_attr_device_class"):
+            return self._attr_device_class
+        if hasattr(self, "entity_description"):
+            return self.entity_description.device_class
+        return None
+
     @property
     def today_energy_kwh(self) -> float | None:
         """Return the today total energy usage in kWh."""
diff --git a/homeassistant/components/upcloud/__init__.py b/homeassistant/components/upcloud/__init__.py
index 82d42e28589..2ecc6ec7522 100644
--- a/homeassistant/components/upcloud/__init__.py
+++ b/homeassistant/components/upcloud/__init__.py
@@ -47,7 +47,6 @@ CONF_SERVERS = "servers"
 DATA_UPCLOUD = "data_upcloud"
 
 DEFAULT_COMPONENT_NAME = "UpCloud {}"
-DEFAULT_COMPONENT_DEVICE_CLASS = "power"
 
 CONFIG_ENTRY_DOMAINS = {BINARY_SENSOR_DOMAIN, SWITCH_DOMAIN}
 
@@ -177,8 +176,6 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
 class UpCloudServerEntity(CoordinatorEntity):
     """Entity class for UpCloud servers."""
 
-    _attr_device_class = DEFAULT_COMPONENT_DEVICE_CLASS
-
     def __init__(
         self,
         coordinator: DataUpdateCoordinator[dict[str, upcloud_api.Server]],