mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-21 08:17:47 +00:00
Clean up flag code significantly.
This also fixes the False setting bug.
This commit is contained in:
parent
9c6a121644
commit
f7687e0a68
168
discord/flags.py
168
discord/flags.py
@ -29,7 +29,7 @@ __all__ = (
|
||||
'MessageFlags',
|
||||
)
|
||||
|
||||
class _flag_descriptor:
|
||||
class flag_value:
|
||||
def __init__(self, func):
|
||||
self.flag = func(None)
|
||||
self.__doc__ = func.__doc__
|
||||
@ -40,19 +40,70 @@ class _flag_descriptor:
|
||||
def __set__(self, instance, value):
|
||||
instance._set_flag(self.flag, value)
|
||||
|
||||
def fill_with_flags(cls):
|
||||
cls.VALID_FLAGS = {
|
||||
name: value.flag
|
||||
for name, value in cls.__dict__.items()
|
||||
if isinstance(value, _flag_descriptor)
|
||||
}
|
||||
def fill_with_flags(*, inverted=False):
|
||||
def decorator(cls):
|
||||
cls.VALID_FLAGS = {
|
||||
name: value.flag
|
||||
for name, value in cls.__dict__.items()
|
||||
if isinstance(value, flag_value)
|
||||
}
|
||||
|
||||
max_bits = max(cls.VALID_FLAGS.values()).bit_length()
|
||||
cls.ALL_OFF_VALUE = -1 + (2 ** max_bits)
|
||||
return cls
|
||||
if inverted:
|
||||
max_bits = max(cls.VALID_FLAGS.values()).bit_length()
|
||||
cls.DEFAULT_VALUE = -1 + (2 ** max_bits)
|
||||
else:
|
||||
cls.DEFAULT_VALUE = 0
|
||||
|
||||
@fill_with_flags
|
||||
class SystemChannelFlags:
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
# n.b. flags must inherit from this and use the decorator above
|
||||
class BaseFlags:
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError('%r is not a valid flag name.' % key)
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value):
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__) and self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s value=%s>' % (self.__class__.__name__, self.value)
|
||||
|
||||
def __iter__(self):
|
||||
for name, value in self.__class__.__dict__.items():
|
||||
if isinstance(value, flag_value):
|
||||
yield (name, self._has_flag(value.flag))
|
||||
|
||||
def _has_flag(self, o):
|
||||
return (self.value & o) == o
|
||||
|
||||
def _set_flag(self, o, toggle):
|
||||
if toggle is True:
|
||||
self.value |= o
|
||||
elif toggle is False:
|
||||
self.value &= ~o
|
||||
else:
|
||||
raise TypeError('Value to set for %s must be a bool.' % self.__class__.__name__)
|
||||
|
||||
@fill_with_flags(inverted=True)
|
||||
class SystemChannelFlags(BaseFlags):
|
||||
r"""Wraps up a Discord system channel flag value.
|
||||
|
||||
Similar to :class:`Permissions`\, the properties provided are two way.
|
||||
@ -85,37 +136,7 @@ class SystemChannelFlags:
|
||||
representing the currently available flags. You should query
|
||||
flags via the properties rather than using this raw value.
|
||||
"""
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.value = self.ALL_OFF_VALUE
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError('%r is not a valid flag name.' % key)
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value):
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, SystemChannelFlags) and self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return '<SystemChannelFlags value=%s>' % self.value
|
||||
|
||||
def __iter__(self):
|
||||
for name, value in self.__class__.__dict__.items():
|
||||
if isinstance(value, _flag_descriptor):
|
||||
yield (name, self._has_flag(value.flag))
|
||||
__slots__ = ()
|
||||
|
||||
# For some reason the flags for system channels are "inverted"
|
||||
# ergo, if they're set then it means "suppress" (off in the GUI toggle)
|
||||
@ -133,19 +154,19 @@ class SystemChannelFlags:
|
||||
else:
|
||||
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
|
||||
|
||||
@_flag_descriptor
|
||||
@flag_value
|
||||
def join_notifications(self):
|
||||
""":class:`bool`: Returns ``True`` if the system channel is used for member join notifications."""
|
||||
return 1
|
||||
|
||||
@_flag_descriptor
|
||||
@flag_value
|
||||
def premium_subscriptions(self):
|
||||
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
|
||||
return 2
|
||||
|
||||
|
||||
@fill_with_flags
|
||||
class MessageFlags:
|
||||
@fill_with_flags()
|
||||
class MessageFlags(BaseFlags):
|
||||
r"""Wraps up a Discord Message flag value.
|
||||
|
||||
See :class:`SystemChannelFlags`.
|
||||
@ -173,65 +194,24 @@ class MessageFlags:
|
||||
representing the currently available flags. You should query
|
||||
flags via the properties rather than using this raw value.
|
||||
"""
|
||||
__slots__ = ('value',)
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.value = 0
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError('%r is not a valid flag name.' % key)
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
def _from_value(cls, value):
|
||||
self = cls.__new__(cls)
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, MessageFlags) and self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return '<MessageFlags value=%s>' % self.value
|
||||
|
||||
def __iter__(self):
|
||||
for name, value in self.__class__.__dict__.items():
|
||||
if isinstance(value, _flag_descriptor):
|
||||
yield (name, self._has_flag(value.flag))
|
||||
|
||||
def _has_flag(self, o):
|
||||
return (self.value & o) == o
|
||||
|
||||
def _set_flag(self, o, toggle):
|
||||
if toggle is True:
|
||||
self.value |= o
|
||||
elif toggle is False:
|
||||
self.value &= o
|
||||
else:
|
||||
raise TypeError('Value to set for MessageFlags must be a bool.')
|
||||
|
||||
@_flag_descriptor
|
||||
@flag_value
|
||||
def crossposted(self):
|
||||
""":class:`bool`: Returns ``True`` if the message is the original crossposted message."""
|
||||
return 1
|
||||
|
||||
@_flag_descriptor
|
||||
@flag_value
|
||||
def is_crossposted(self):
|
||||
""":class:`bool`: Returns ``True`` if the message was crossposted from another channel."""
|
||||
return 2
|
||||
|
||||
@_flag_descriptor
|
||||
@flag_value
|
||||
def suppress_embeds(self):
|
||||
""":class:`bool`: Returns ``True`` if the message's embeds have been suppressed."""
|
||||
return 4
|
||||
|
||||
@_flag_descriptor
|
||||
|
||||
@flag_value
|
||||
def source_message_deleted(self):
|
||||
""":class:`bool`: Returns ``True`` if the source message for this crosspost has been deleted."""
|
||||
return 8
|
||||
|
Loading…
x
Reference in New Issue
Block a user