Refactor from_components for Select classes

Co-authored-by: Danny <1695103+Rapptz@users.noreply.github.com>
This commit is contained in:
Soheab_ 2023-10-01 03:21:29 +02:00 committed by GitHub
parent 5c5ccc4e82
commit 698363e76b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 29 deletions

View File

@ -216,6 +216,13 @@ class BaseSelect(Item[V]):
'max_values', 'max_values',
'disabled', 'disabled',
) )
__component_attributes__: Tuple[str, ...] = (
'custom_id',
'placeholder',
'min_values',
'max_values',
'disabled',
)
def __init__( def __init__(
self, self,
@ -336,11 +343,16 @@ class BaseSelect(Item[V]):
@classmethod @classmethod
def from_component(cls, component: SelectMenu) -> Self: def from_component(cls, component: SelectMenu) -> Self:
return cls( type_to_cls: Dict[ComponentType, Type[BaseSelect[Any]]] = {
**{k: getattr(component, k) for k in cls.__item_repr_attributes__}, ComponentType.string_select: Select,
custom_id=component.custom_id, ComponentType.user_select: UserSelect,
row=None, ComponentType.role_select: RoleSelect,
) ComponentType.channel_select: ChannelSelect,
ComponentType.mentionable_select: MentionableSelect,
}
constructor = type_to_cls.get(component.type, Select)
kwrgs = {key: getattr(component, key) for key in constructor.__component_attributes__}
return constructor(**kwrgs)
class Select(BaseSelect[V]): class Select(BaseSelect[V]):
@ -374,7 +386,7 @@ class Select(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
""" """
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('options',) __component_attributes__ = BaseSelect.__component_attributes__ + ('options',)
def __init__( def __init__(
self, self,
@ -525,7 +537,7 @@ class UserSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
""" """
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',)
def __init__( def __init__(
self, self,
@ -614,7 +626,7 @@ class RoleSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
""" """
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',)
def __init__( def __init__(
self, self,
@ -699,7 +711,7 @@ class MentionableSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
""" """
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ('default_values',) __component_attributes__ = BaseSelect.__component_attributes__ + ('default_values',)
def __init__( def __init__(
self, self,
@ -790,7 +802,7 @@ class ChannelSelect(BaseSelect[V]):
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
""" """
__item_repr_attributes__ = BaseSelect.__item_repr_attributes__ + ( __component_attributes__ = BaseSelect.__component_attributes__ + (
'channel_types', 'channel_types',
'default_values', 'default_values',
) )

View File

@ -34,7 +34,6 @@ import time
import os import os
from .item import Item, ItemCallbackType from .item import Item, ItemCallbackType
from .dynamic import DynamicItem from .dynamic import DynamicItem
from ..enums import ComponentType
from ..components import ( from ..components import (
Component, Component,
ActionRow as ActionRowComponent, ActionRow as ActionRowComponent,
@ -79,26 +78,10 @@ def _component_to_item(component: Component) -> Item:
return Button.from_component(component) return Button.from_component(component)
if isinstance(component, SelectComponent): if isinstance(component, SelectComponent):
if component.type is ComponentType.select: from .select import BaseSelect
from .select import Select
return Select.from_component(component) return BaseSelect.from_component(component)
elif component.type is ComponentType.user_select:
from .select import UserSelect
return UserSelect.from_component(component)
elif component.type is ComponentType.mentionable_select:
from .select import MentionableSelect
return MentionableSelect.from_component(component)
elif component.type is ComponentType.channel_select:
from .select import ChannelSelect
return ChannelSelect().from_component(component)
elif component.type is ComponentType.role_select:
from .select import RoleSelect
return RoleSelect.from_component(component)
return Item.from_component(component) return Item.from_component(component)