mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-22 16:32:59 +00:00 
			
		
		
		
	Add typing for flags
This commit is contained in:
		| @@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |||||||
| DEALINGS IN THE SOFTWARE. | DEALINGS IN THE SOFTWARE. | ||||||
| """ | """ | ||||||
|  |  | ||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload | ||||||
|  |  | ||||||
| from .enums import UserFlags | from .enums import UserFlags | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = ( | ||||||
| @@ -32,17 +36,28 @@ __all__ = ( | |||||||
|     'MemberCacheFlags', |     'MemberCacheFlags', | ||||||
| ) | ) | ||||||
|  |  | ||||||
| class flag_value: | FV = TypeVar('FV', bound='flag_value') | ||||||
|     def __init__(self, func): | BF = TypeVar('BF', bound='BaseFlags') | ||||||
|  |  | ||||||
|  | class flag_value(Generic[BF]): | ||||||
|  |     def __init__(self, func: Callable[[Any], int]): | ||||||
|         self.flag = func(None) |         self.flag = func(None) | ||||||
|         self.__doc__ = func.__doc__ |         self.__doc__ = func.__doc__ | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |     @overload | ||||||
|  |     def __get__(self: FV, instance: None, owner: Type[BF]) -> FV: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     @overload | ||||||
|  |     def __get__(self, instance: BF, owner: Type[BF]) -> bool: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def __get__(self, instance: Optional[BF], owner: Type[BF]) -> Any: | ||||||
|         if instance is None: |         if instance is None: | ||||||
|             return self |             return self | ||||||
|         return instance._has_flag(self.flag) |         return instance._has_flag(self.flag) | ||||||
|  |  | ||||||
|     def __set__(self, instance, value): |     def __set__(self, instance: BF, value: bool) -> None: | ||||||
|         instance._set_flag(self.flag, value) |         instance._set_flag(self.flag, value) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
| @@ -51,8 +66,8 @@ class flag_value: | |||||||
| class alias_flag_value(flag_value): | class alias_flag_value(flag_value): | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
| def fill_with_flags(*, inverted=False): | def fill_with_flags(*, inverted: bool = False): | ||||||
|     def decorator(cls): |     def decorator(cls: Type[BF]): | ||||||
|         cls.VALID_FLAGS = { |         cls.VALID_FLAGS = { | ||||||
|             name: value.flag |             name: value.flag | ||||||
|             for name, value in cls.__dict__.items() |             for name, value in cls.__dict__.items() | ||||||
| @@ -70,9 +85,14 @@ def fill_with_flags(*, inverted=False): | |||||||
|  |  | ||||||
| # n.b. flags must inherit from this and use the decorator above | # n.b. flags must inherit from this and use the decorator above | ||||||
| class BaseFlags: | class BaseFlags: | ||||||
|  |     VALID_FLAGS: ClassVar[Dict[str, int]] | ||||||
|  |     DEFAULT_VALUE: ClassVar[int] | ||||||
|  |  | ||||||
|  |     value: int | ||||||
|  |  | ||||||
|     __slots__ = ('value',) |     __slots__ = ('value',) | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs: bool): | ||||||
|         self.value = self.DEFAULT_VALUE |         self.value = self.DEFAULT_VALUE | ||||||
|         for key, value in kwargs.items(): |         for key, value in kwargs.items(): | ||||||
|             if key not in self.VALID_FLAGS: |             if key not in self.VALID_FLAGS: | ||||||
| @@ -85,19 +105,19 @@ class BaseFlags: | |||||||
|         self.value = value |         self.value = value | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def __eq__(self, other): |     def __eq__(self, other: Any) -> bool: | ||||||
|         return isinstance(other, self.__class__) and self.value == other.value |         return isinstance(other, self.__class__) and self.value == other.value | ||||||
|  |  | ||||||
|     def __ne__(self, other): |     def __ne__(self, other: Any) -> bool: | ||||||
|         return not self.__eq__(other) |         return not self.__eq__(other) | ||||||
|  |  | ||||||
|     def __hash__(self): |     def __hash__(self) -> int: | ||||||
|         return hash(self.value) |         return hash(self.value) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self) -> str: | ||||||
|         return f'<{self.__class__.__name__} value={self.value}>' |         return f'<{self.__class__.__name__} value={self.value}>' | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self) -> Iterator[Tuple[str, bool]]: | ||||||
|         for name, value in self.__class__.__dict__.items(): |         for name, value in self.__class__.__dict__.items(): | ||||||
|             if isinstance(value, alias_flag_value): |             if isinstance(value, alias_flag_value): | ||||||
|                 continue |                 continue | ||||||
| @@ -105,10 +125,10 @@ class BaseFlags: | |||||||
|             if isinstance(value, flag_value): |             if isinstance(value, flag_value): | ||||||
|                 yield (name, self._has_flag(value.flag)) |                 yield (name, self._has_flag(value.flag)) | ||||||
|  |  | ||||||
|     def _has_flag(self, o): |     def _has_flag(self, o: int) -> bool: | ||||||
|         return (self.value & o) == o |         return (self.value & o) == o | ||||||
|  |  | ||||||
|     def _set_flag(self, o, toggle): |     def _set_flag(self, o: int, toggle: bool) -> None: | ||||||
|         if toggle is True: |         if toggle is True: | ||||||
|             self.value |= o |             self.value |= o | ||||||
|         elif toggle is False: |         elif toggle is False: | ||||||
| @@ -150,6 +170,7 @@ class SystemChannelFlags(BaseFlags): | |||||||
|         representing the currently available flags. You should query |         representing the currently available flags. You should query | ||||||
|         flags via the properties rather than using this raw value. |         flags via the properties rather than using this raw value. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     __slots__ = () |     __slots__ = () | ||||||
|  |  | ||||||
|     # For some reason the flags for system channels are "inverted" |     # For some reason the flags for system channels are "inverted" | ||||||
| @@ -157,10 +178,10 @@ class SystemChannelFlags(BaseFlags): | |||||||
|     # Since this is counter-intuitive from an API perspective and annoying |     # Since this is counter-intuitive from an API perspective and annoying | ||||||
|     # these will be inverted automatically |     # these will be inverted automatically | ||||||
|  |  | ||||||
|     def _has_flag(self, o): |     def _has_flag(self, o: int) -> bool: | ||||||
|         return (self.value & o) != o |         return (self.value & o) != o | ||||||
|  |  | ||||||
|     def _set_flag(self, o, toggle): |     def _set_flag(self, o: int, toggle: bool) -> None: | ||||||
|         if toggle is True: |         if toggle is True: | ||||||
|             self.value &= ~o |             self.value &= ~o | ||||||
|         elif toggle is False: |         elif toggle is False: | ||||||
| @@ -210,6 +231,7 @@ class MessageFlags(BaseFlags): | |||||||
|         representing the currently available flags. You should query |         representing the currently available flags. You should query | ||||||
|         flags via the properties rather than using this raw value. |         flags via the properties rather than using this raw value. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     __slots__ = () |     __slots__ = () | ||||||
|  |  | ||||||
|     @flag_value |     @flag_value | ||||||
| @@ -346,7 +368,7 @@ class PublicUserFlags(BaseFlags): | |||||||
|         """ |         """ | ||||||
|         return UserFlags.verified_bot_developer.value |         return UserFlags.verified_bot_developer.value | ||||||
|  |  | ||||||
|     def all(self): |     def all(self) -> List[UserFlags]: | ||||||
|         """List[:class:`UserFlags`]: Returns all public flags the user has.""" |         """List[:class:`UserFlags`]: Returns all public flags the user has.""" | ||||||
|         return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] |         return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] | ||||||
|  |  | ||||||
| @@ -393,7 +415,7 @@ class Intents(BaseFlags): | |||||||
|  |  | ||||||
|     __slots__ = () |     __slots__ = () | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs: bool): | ||||||
|         self.value = self.DEFAULT_VALUE |         self.value = self.DEFAULT_VALUE | ||||||
|         for key, value in kwargs.items(): |         for key, value in kwargs.items(): | ||||||
|             if key not in self.VALID_FLAGS: |             if key not in self.VALID_FLAGS: | ||||||
| @@ -401,7 +423,7 @@ class Intents(BaseFlags): | |||||||
|             setattr(self, key, value) |             setattr(self, key, value) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def all(cls): |     def all(cls: Type[Intents]) -> Intents: | ||||||
|         """A factory method that creates a :class:`Intents` with everything enabled.""" |         """A factory method that creates a :class:`Intents` with everything enabled.""" | ||||||
|         bits = max(cls.VALID_FLAGS.values()).bit_length() |         bits = max(cls.VALID_FLAGS.values()).bit_length() | ||||||
|         value = (1 << bits) - 1 |         value = (1 << bits) - 1 | ||||||
| @@ -410,14 +432,14 @@ class Intents(BaseFlags): | |||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def none(cls): |     def none(cls: Type[Intents]) -> Intents: | ||||||
|         """A factory method that creates a :class:`Intents` with everything disabled.""" |         """A factory method that creates a :class:`Intents` with everything disabled.""" | ||||||
|         self = cls.__new__(cls) |         self = cls.__new__(cls) | ||||||
|         self.value = self.DEFAULT_VALUE |         self.value = self.DEFAULT_VALUE | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def default(cls): |     def default(cls: Type[Intents]) -> Intents: | ||||||
|         """A factory method that creates a :class:`Intents` with everything enabled |         """A factory method that creates a :class:`Intents` with everything enabled | ||||||
|         except :attr:`presences` and :attr:`members`. |         except :attr:`presences` and :attr:`members`. | ||||||
|         """ |         """ | ||||||
| @@ -825,7 +847,7 @@ class MemberCacheFlags(BaseFlags): | |||||||
|  |  | ||||||
|     __slots__ = () |     __slots__ = () | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs: bool): | ||||||
|         bits = max(self.VALID_FLAGS.values()).bit_length() |         bits = max(self.VALID_FLAGS.values()).bit_length() | ||||||
|         self.value = (1 << bits) - 1 |         self.value = (1 << bits) - 1 | ||||||
|         for key, value in kwargs.items(): |         for key, value in kwargs.items(): | ||||||
| @@ -834,7 +856,7 @@ class MemberCacheFlags(BaseFlags): | |||||||
|             setattr(self, key, value) |             setattr(self, key, value) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def all(cls): |     def all(cls: Type[MemberCacheFlags]) -> MemberCacheFlags: | ||||||
|         """A factory method that creates a :class:`MemberCacheFlags` with everything enabled.""" |         """A factory method that creates a :class:`MemberCacheFlags` with everything enabled.""" | ||||||
|         bits = max(cls.VALID_FLAGS.values()).bit_length() |         bits = max(cls.VALID_FLAGS.values()).bit_length() | ||||||
|         value = (1 << bits) - 1 |         value = (1 << bits) - 1 | ||||||
| @@ -843,7 +865,7 @@ class MemberCacheFlags(BaseFlags): | |||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def none(cls): |     def none(cls: Type[MemberCacheFlags]) -> MemberCacheFlags: | ||||||
|         """A factory method that creates a :class:`MemberCacheFlags` with everything disabled.""" |         """A factory method that creates a :class:`MemberCacheFlags` with everything disabled.""" | ||||||
|         self = cls.__new__(cls) |         self = cls.__new__(cls) | ||||||
|         self.value = self.DEFAULT_VALUE |         self.value = self.DEFAULT_VALUE | ||||||
| @@ -886,7 +908,7 @@ class MemberCacheFlags(BaseFlags): | |||||||
|         return 4 |         return 4 | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def from_intents(cls, intents): |     def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags: | ||||||
|         """A factory method that creates a :class:`MemberCacheFlags` based on |         """A factory method that creates a :class:`MemberCacheFlags` based on | ||||||
|         the currently selected :class:`Intents`. |         the currently selected :class:`Intents`. | ||||||
|  |  | ||||||
| @@ -914,7 +936,7 @@ class MemberCacheFlags(BaseFlags): | |||||||
|  |  | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|     def _verify_intents(self, intents): |     def _verify_intents(self, intents: Intents): | ||||||
|         if self.online and not intents.presences: |         if self.online and not intents.presences: | ||||||
|             raise ValueError('MemberCacheFlags.online requires Intents.presences enabled') |             raise ValueError('MemberCacheFlags.online requires Intents.presences enabled') | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user