Run black on the repository, with the default configuration. #43

Closed
paris-ci wants to merge 1 commits from black into 2.0
107 changed files with 8671 additions and 5258 deletions

View File

@ -9,13 +9,13 @@ A basic wrapper for the Discord API.
"""
__title__ = 'discord'
__author__ = 'Rapptz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2015-present Rapptz'
__version__ = '2.0.0a'
__title__ = "discord"
__author__ = "Rapptz"
__license__ = "MIT"
__copyright__ = "Copyright 2015-present Rapptz"
__version__ = "2.0.0a"
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
import logging
from typing import NamedTuple, Literal
@ -69,6 +69,8 @@ class VersionInfo(NamedTuple):
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
version_info: VersionInfo = VersionInfo(
major=2, minor=0, micro=0, releaselevel="alpha", serial=0
)
logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@ -31,26 +31,37 @@ import pkg_resources
import aiohttp
import platform
def show_version():
entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
entries.append(
"- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(
sys.version_info
)
)
version_info = discord.version_info
entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info))
if version_info.releaselevel != 'final':
pkg = pkg_resources.get_distribution('discord.py')
entries.append(
"- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(
version_info
)
)
if version_info.releaselevel != "final":
pkg = pkg_resources.get_distribution("discord.py")
if pkg:
entries.append(f' - discord.py pkg_resources: v{pkg.version}')
entries.append(f" - discord.py pkg_resources: v{pkg.version}")
entries.append(f'- aiohttp v{aiohttp.__version__}')
entries.append(f"- aiohttp v{aiohttp.__version__}")
uname = platform.uname()
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries))
entries.append("- system info: {0.system} {0.release} {0.version}".format(uname))
print("\n".join(entries))
def core(parser, args):
if args.version:
show_version()
_bot_template = """#!/usr/bin/env python3
from discord.ext import commands
@ -120,7 +131,7 @@ def setup(bot):
bot.add_cog({name}(bot))
'''
_cog_extras = '''
_cog_extras = """
def cog_unload(self):
# clean up logic goes here
pass
@ -149,22 +160,22 @@ _cog_extras = '''
# called after a command is called here
pass
'''
"""
# certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent
_base_table = {
'<': '-',
'>': '-',
':': '-',
'"': '-',
"<": "-",
">": "-",
":": "-",
'"': "-",
# '/': '-', these are fine
# '\\': '-',
'|': '-',
'?': '-',
'*': '-',
"|": "-",
"?": "-",
"*": "-",
}
# NUL (0) and 1-31 are disallowed
@ -172,21 +183,45 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path):
return name
if sys.platform == 'win32':
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9')
if sys.platform == "win32":
forbidden = (
"CON",
"PRN",
"AUX",
"NUL",
"COM1",
"COM2",
"COM3",
"COM4",
"COM5",
"COM6",
"COM7",
"COM8",
"COM9",
"LPT1",
"LPT2",
"LPT3",
"LPT4",
"LPT5",
"LPT6",
"LPT7",
"LPT8",
"LPT9",
)
if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one')
parser.error("invalid directory name given, use a different one")
name = name.translate(_translation_table)
if replace_spaces:
name = name.replace(' ', '-')
name = name.replace(" ", "-")
return Path(name)
def newbot(parser, args):
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
@ -195,106 +230,147 @@ def newbot(parser, args):
try:
new_directory.mkdir(exist_ok=True, parents=True)
except OSError as exc:
parser.error(f'could not create our bot directory ({exc})')
parser.error(f"could not create our bot directory ({exc})")
cogs = new_directory / 'cogs'
cogs = new_directory / "cogs"
try:
cogs.mkdir(exist_ok=True)
init = cogs / '__init__.py'
init = cogs / "__init__.py"
init.touch()
except OSError as exc:
print(f'warning: could not create cogs directory ({exc})')
print(f"warning: could not create cogs directory ({exc})")
try:
with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp:
with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp:
fp.write('token = "place your token here"\ncogs = []\n')
except OSError as exc:
parser.error(f'could not create config file ({exc})')
parser.error(f"could not create config file ({exc})")
try:
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot'
with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp:
base = "Bot" if not args.sharded else "AutoShardedBot"
fp.write(_bot_template.format(base=base, prefix=args.prefix))
except OSError as exc:
parser.error(f'could not create bot file ({exc})')
parser.error(f"could not create bot file ({exc})")
if not args.no_git:
try:
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp:
fp.write(_gitignore_template)
except OSError as exc:
print(f'warning: could not create .gitignore file ({exc})')
print(f"warning: could not create .gitignore file ({exc})")
print("successfully made bot at", new_directory)
print('successfully made bot at', new_directory)
def newcog(parser, args):
cog_dir = to_path(parser, args.directory)
try:
cog_dir.mkdir(exist_ok=True)
except OSError as exc:
print(f'warning: could not create cogs directory ({exc})')
print(f"warning: could not create cogs directory ({exc})")
directory = cog_dir / to_path(parser, args.name)
directory = directory.with_suffix('.py')
directory = directory.with_suffix(".py")
try:
with open(str(directory), 'w', encoding='utf-8') as fp:
attrs = ''
extra = _cog_extras if args.full else ''
with open(str(directory), "w", encoding="utf-8") as fp:
attrs = ""
extra = _cog_extras if args.full else ""
if args.class_name:
name = args.class_name
else:
name = str(directory.stem)
if '-' in name or '_' in name:
translation = str.maketrans('-_', ' ')
name = name.translate(translation).title().replace(' ', '')
if "-" in name or "_" in name:
translation = str.maketrans("-_", " ")
name = name.translate(translation).title().replace(" ", "")
else:
name = name.title()
if args.display_name:
attrs += f', name="{args.display_name}"'
if args.hide_commands:
attrs += ', command_attrs=dict(hidden=True)'
attrs += ", command_attrs=dict(hidden=True)"
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc:
parser.error(f'could not create cog file ({exc})')
parser.error(f"could not create cog file ({exc})")
else:
print('successfully made cog at', directory)
print("successfully made cog at", directory)
def add_newbot_args(subparser):
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser = subparser.add_parser(
"newbot", help="creates a command bot project quickly"
)
parser.set_defaults(func=newbot)
parser.add_argument('name', help='the bot project name')
parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd())
parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='<prefix>')
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true')
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
parser.add_argument("name", help="the bot project name")
parser.add_argument(
"directory",
help="the directory to place it in (default: .)",
nargs="?",
default=Path.cwd(),
)
parser.add_argument(
"--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>"
)
parser.add_argument(
"--sharded", help="whether to use AutoShardedBot", action="store_true"
)
parser.add_argument(
"--no-git",
help="do not create a .gitignore file",
action="store_true",
dest="no_git",
)
def add_newcog_args(subparser):
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser = subparser.add_parser("newcog", help="creates a new cog template quickly")
parser.set_defaults(func=newcog)
parser.add_argument('name', help='the cog name')
parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs'))
parser.add_argument('--class-name', help='the class name of the cog (default: <name>)', dest='class_name')
parser.add_argument('--display-name', help='the cog name (default: <name>)')
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true')
parser.add_argument('--full', help='add all special methods as well', action='store_true')
parser.add_argument("name", help="the cog name")
parser.add_argument(
"directory",
help="the directory to place it in (default: cogs)",
nargs="?",
default=Path("cogs"),
)
parser.add_argument(
"--class-name",
help="the class name of the cog (default: <name>)",
dest="class_name",
)
parser.add_argument("--display-name", help="the cog name (default: <name>)")
parser.add_argument(
"--hide-commands",
help="whether to hide all commands in the cog",
action="store_true",
)
parser.add_argument(
"--full", help="add all special methods as well", action="store_true"
)
def parse_args():
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser = argparse.ArgumentParser(
prog="discord", description="Tools for helping with discord.py"
)
parser.add_argument(
"-v", "--version", action="store_true", help="shows the library version"
)
parser.set_defaults(func=core)
subparser = parser.add_subparsers(dest='subcommand', title='subcommands')
subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
add_newbot_args(subparser)
add_newcog_args(subparser)
return parser, parser.parse_args()
def main():
parser, args = parse_args()
args.func(parser, args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem
from . import utils
__all__ = (
'Snowflake',
'User',
'PrivateChannel',
'GuildChannel',
'Messageable',
'Connectable',
"Snowflake",
"User",
"PrivateChannel",
"GuildChannel",
"Messageable",
"Connectable",
)
T = TypeVar('T', bound=VoiceProtocol)
T = TypeVar("T", bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
@ -89,7 +89,9 @@ if TYPE_CHECKING:
OverwriteType,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
PartialMessageableChannel = Union[
TextChannel, Thread, DMChannel, PartialMessageable
]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
@ -98,7 +100,7 @@ MISSING = utils.MISSING
class _Undefined:
def __repr__(self) -> str:
return 'see-below'
return "see-below"
_undefined: Any = _Undefined()
@ -189,23 +191,23 @@ class PrivateChannel(Snowflake, Protocol):
class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type')
__slots__ = ("id", "allow", "deny", "type")
ROLE = 0
MEMBER = 1
def __init__(self, data: PermissionOverwritePayload):
self.id: int = int(data['id'])
self.allow: int = int(data.get('allow', 0))
self.deny: int = int(data.get('deny', 0))
self.type: OverwriteType = data['type']
self.id: int = int(data["id"])
self.allow: int = int(data.get("allow", 0))
self.deny: int = int(data.get("deny", 0))
self.type: OverwriteType = data["type"]
def _asdict(self) -> PermissionOverwritePayload:
return {
'id': self.id,
'allow': str(self.allow),
'deny': str(self.deny),
'type': self.type,
"id": self.id,
"allow": str(self.allow),
"deny": str(self.deny),
"type": self.type,
}
def is_role(self) -> bool:
@ -215,7 +217,7 @@ class _Overwrites:
return self.type == 1
GCH = TypeVar('GCH', bound='GuildChannel')
GCH = TypeVar("GCH", bound="GuildChannel")
class GuildChannel:
@ -254,7 +256,9 @@ class GuildChannel:
if TYPE_CHECKING:
def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]):
def __init__(
self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]
):
...
def __str__(self) -> str:
@ -276,11 +280,13 @@ class GuildChannel:
reason: Optional[str],
) -> None:
if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.')
raise InvalidArgument("Channel position cannot be less than 0.")
http = self._state.http
bucket = self._sorting_bucket
channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]
channels: List[GuildChannel] = [
c for c in self.guild.channels if c._sorting_bucket == bucket
]
channels.sort(key=lambda c: c.position)
@ -291,106 +297,124 @@ class GuildChannel:
# not there somehow lol
return
else:
index = next((i for i, c in enumerate(channels) if c.position >= position), len(channels))
index = next(
(i for i, c in enumerate(channels) if c.position >= position),
len(channels),
)
# add ourselves at our designated position
channels.insert(index, self)
payload = []
for index, c in enumerate(channels):
d: Dict[str, Any] = {'id': c.id, 'position': index}
d: Dict[str, Any] = {"id": c.id, "position": index}
if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
async def _edit(
self, options: Dict[str, Any], reason: Optional[str]
) -> Optional[ChannelPayload]:
try:
parent = options.pop('category')
parent = options.pop("category")
except KeyError:
parent_id = _undefined
else:
parent_id = parent and parent.id
try:
options['rate_limit_per_user'] = options.pop('slowmode_delay')
options["rate_limit_per_user"] = options.pop("slowmode_delay")
except KeyError:
pass
try:
rtc_region = options.pop('rtc_region')
rtc_region = options.pop("rtc_region")
except KeyError:
pass
else:
options['rtc_region'] = None if rtc_region is None else str(rtc_region)
options["rtc_region"] = None if rtc_region is None else str(rtc_region)
try:
video_quality_mode = options.pop('video_quality_mode')
video_quality_mode = options.pop("video_quality_mode")
except KeyError:
pass
else:
options['video_quality_mode'] = int(video_quality_mode)
options["video_quality_mode"] = int(video_quality_mode)
lock_permissions = options.pop('sync_permissions', False)
lock_permissions = options.pop("sync_permissions", False)
try:
position = options.pop('position')
position = options.pop("position")
except KeyError:
if parent_id is not _undefined:
if lock_permissions:
category = self.guild.get_channel(parent_id)
if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id
options["permission_overwrites"] = [
c._asdict() for c in category._overwrites
]
options["parent_id"] = parent_id
elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id)
if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options["permission_overwrites"] = [
c._asdict() for c in category._overwrites
]
else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
await self._move(
position,
parent_id=parent_id,
lock_permissions=lock_permissions,
reason=reason,
)
overwrites = options.get('overwrites', None)
overwrites = options.get("overwrites", None)
if overwrites is not None:
perms = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}')
raise InvalidArgument(
f"Expected PermissionOverwrite received {perm.__class__.__name__}"
)
allow, deny = perm.pair()
payload = {
'allow': allow.value,
'deny': deny.value,
'id': target.id,
"allow": allow.value,
"deny": deny.value,
"id": target.id,
}
if isinstance(target, Role):
payload['type'] = _Overwrites.ROLE
payload["type"] = _Overwrites.ROLE
else:
payload['type'] = _Overwrites.MEMBER
payload["type"] = _Overwrites.MEMBER
perms.append(payload)
options['permission_overwrites'] = perms
options["permission_overwrites"] = perms
try:
ch_type = options['type']
ch_type = options["type"]
except KeyError:
pass
else:
if not isinstance(ch_type, ChannelType):
raise InvalidArgument('type field must be of type ChannelType')
options['type'] = ch_type.value
raise InvalidArgument("type field must be of type ChannelType")
options["type"] = ch_type.value
if options:
return await self._state.http.edit_channel(self.id, reason=reason, **options)
return await self._state.http.edit_channel(
self.id, reason=reason, **options
)
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
self._overwrites = []
everyone_index = 0
everyone_id = self.guild.id
for index, overridden in enumerate(data.get('permission_overwrites', [])):
for index, overridden in enumerate(data.get("permission_overwrites", [])):
overwrite = _Overwrites(overridden)
self._overwrites.append(overwrite)
@ -429,7 +453,7 @@ class GuildChannel:
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime:
@ -589,7 +613,9 @@ class GuildChannel:
try:
maybe_everyone = self._overwrites[0]
if maybe_everyone.id == self.guild.id:
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
base.handle_overwrite(
allow=maybe_everyone.allow, deny=maybe_everyone.deny
)
except IndexError:
pass
@ -620,7 +646,9 @@ class GuildChannel:
try:
maybe_everyone = self._overwrites[0]
if maybe_everyone.id == self.guild.id:
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
base.handle_overwrite(
allow=maybe_everyone.allow, deny=maybe_everyone.deny
)
remaining_overwrites = self._overwrites[1:]
else:
remaining_overwrites = self._overwrites
@ -703,7 +731,9 @@ class GuildChannel:
) -> None:
...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
async def set_permissions(
self, target, *, overwrite=_undefined, reason=None, **permissions
):
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@ -779,18 +809,18 @@ class GuildChannel:
elif isinstance(target, Role):
perm_type = _Overwrites.ROLE
else:
raise InvalidArgument('target parameter must be either Member or Role')
raise InvalidArgument("target parameter must be either Member or Role")
if overwrite is _undefined:
if len(permissions) == 0:
raise InvalidArgument('No overwrite provided.')
raise InvalidArgument("No overwrite provided.")
try:
overwrite = PermissionOverwrite(**permissions)
except (ValueError, TypeError):
raise InvalidArgument('Invalid permissions given to keyword arguments.')
raise InvalidArgument("Invalid permissions given to keyword arguments.")
else:
if len(permissions) > 0:
raise InvalidArgument('Cannot mix overwrite and keyword arguments.')
raise InvalidArgument("Cannot mix overwrite and keyword arguments.")
# TODO: wait for event
@ -798,9 +828,11 @@ class GuildChannel:
await http.delete_channel_permissions(self.id, target.id, reason=reason)
elif isinstance(overwrite, PermissionOverwrite):
(allow, deny) = overwrite.pair()
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
await http.edit_channel_permissions(
self.id, target.id, allow.value, deny.value, perm_type, reason=reason
)
else:
raise InvalidArgument('Invalid overwrite type provided.')
raise InvalidArgument("Invalid overwrite type provided.")
async def _clone_impl(
self: GCH,
@ -809,19 +841,23 @@ class GuildChannel:
name: Optional[str] = None,
reason: Optional[str] = None,
) -> GCH:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name
base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites]
base_attrs["parent_id"] = self.category_id
base_attrs["name"] = name or self.name
guild_id = self.guild.id
cls = self.__class__
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
data = await self._state.http.create_channel(
guild_id, self.type.value, reason=reason, **base_attrs
)
obj = cls(state=self._state, guild=self.guild, data=data)
# temporarily add it to the cache
self.guild._channels[obj.id] = obj # type: ignore
return obj
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
async def clone(
self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None
) -> GCH:
"""|coro|
Clones this channel. This creates a channel with the same properties
@ -964,14 +1000,16 @@ class GuildChannel:
if not kwargs:
return
beginning, end = kwargs.get('beginning'), kwargs.get('end')
before, after = kwargs.get('before'), kwargs.get('after')
offset = kwargs.get('offset', 0)
beginning, end = kwargs.get("beginning"), kwargs.get("end")
before, after = kwargs.get("before"), kwargs.get("after")
offset = kwargs.get("offset", 0)
if sum(bool(a) for a in (beginning, end, before, after)) > 1:
raise InvalidArgument('Only one of [before, after, end, beginning] can be used.')
raise InvalidArgument(
"Only one of [before, after, end, beginning] can be used."
)
bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING)
parent_id = kwargs.get("category", MISSING)
# fmt: off
channels: List[GuildChannel]
if parent_id not in (MISSING, None):
@ -1008,22 +1046,26 @@ class GuildChannel:
elif before:
index = next((i for i, c in enumerate(channels) if c.id == before.id), None)
elif after:
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)
index = next(
(i + 1 for i, c in enumerate(channels) if c.id == after.id), None
)
if index is None:
raise InvalidArgument('Could not resolve appropriate move position')
raise InvalidArgument("Could not resolve appropriate move position")
channels.insert(max((index + offset), 0), self)
payload = []
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
lock_permissions = kwargs.get("sync_permissions", False)
reason = kwargs.get("reason")
for index, channel in enumerate(channels):
d = {'id': channel.id, 'position': index}
d = {"id": channel.id, "position": index}
if parent_id is not MISSING and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
await self._state.http.bulk_channel_update(
self.guild.id, payload, reason=reason
)
async def create_invite(
self,
@ -1126,7 +1168,10 @@ class GuildChannel:
state = self._state
data = await state.http.invites_from_channel(self.id)
guild = self.guild
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]
return [
Invite(state=state, data=invite, channel=self, guild=guild)
for invite in data
]
class Messageable:
@ -1332,14 +1377,18 @@ class Messageable:
content = str(content) if content is not None else None
if embed is not None and embeds is not None:
raise InvalidArgument('cannot pass both embed and embeds parameter to send()')
raise InvalidArgument(
"cannot pass both embed and embeds parameter to send()"
)
if embed is not None:
embed = embed.to_dict()
elif embeds is not None:
if len(embeds) > 10:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
raise InvalidArgument(
"embeds parameter must be a list of up to 10 elements"
)
embeds = [embed.to_dict() for embed in embeds]
if stickers is not None:
@ -1347,36 +1396,44 @@ class Messageable:
if allowed_mentions is not None:
if state.allowed_mentions is not None:
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
allowed_mentions = state.allowed_mentions.merge(
allowed_mentions
).to_dict()
else:
allowed_mentions = allowed_mentions.to_dict()
else:
allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict()
allowed_mentions = (
state.allowed_mentions and state.allowed_mentions.to_dict()
)
if mention_author is not None:
allowed_mentions = allowed_mentions or AllowedMentions().to_dict()
allowed_mentions['replied_user'] = bool(mention_author)
allowed_mentions["replied_user"] = bool(mention_author)
if reference is not None:
try:
reference = reference.to_message_reference_dict()
except AttributeError:
raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None
raise InvalidArgument(
"reference parameter must be Message, MessageReference, or PartialMessage"
) from None
if view:
if not hasattr(view, '__discord_ui_view__'):
raise InvalidArgument(f'view parameter must be View not {view.__class__!r}')
if not hasattr(view, "__discord_ui_view__"):
raise InvalidArgument(
f"view parameter must be View not {view.__class__!r}"
)
components = view.to_components()
else:
components = None
if file is not None and files is not None:
raise InvalidArgument('cannot pass both file and files parameter to send()')
raise InvalidArgument("cannot pass both file and files parameter to send()")
if file is not None:
if not isinstance(file, File):
raise InvalidArgument('file parameter must be File')
raise InvalidArgument("file parameter must be File")
try:
data = await state.http.send_files(
@ -1397,9 +1454,11 @@ class Messageable:
elif files is not None:
if len(files) > 10:
raise InvalidArgument('files parameter must be a list of up to 10 elements')
raise InvalidArgument(
"files parameter must be a list of up to 10 elements"
)
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument('files parameter must be a list of File')
raise InvalidArgument("files parameter must be a list of File")
try:
data = await state.http.send_files(
@ -1594,7 +1653,14 @@ class Messageable:
:class:`~discord.Message`
The message with the message data parsed.
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
return HistoryIterator(
self,
limit=limit,
before=before,
after=after,
around=around,
oldest_first=oldest_first,
)
class Connectable(Protocol):
@ -1666,13 +1732,13 @@ class Connectable(Protocol):
state = self._state
if state._get_voice_client(key_id):
raise ClientException('Already connected to a voice channel.')
raise ClientException("Already connected to a voice channel.")
client = state._get_client()
voice = cls(client, self)
if not isinstance(voice, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class.')
raise TypeError("Type must meet VoiceProtocol abstract base class.")
state._add_voice_client(key_id, voice)

View File

@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji
from .utils import _get_as_snowflake
__all__ = (
'BaseActivity',
'Activity',
'Streaming',
'Game',
'Spotify',
'CustomActivity',
"BaseActivity",
"Activity",
"Streaming",
"Game",
"Spotify",
"CustomActivity",
)
"""If curious, this is the current schema for an activity.
@ -119,10 +119,10 @@ class BaseActivity:
.. versionadded:: 1.3
"""
__slots__ = ('_created_at',)
__slots__ = ("_created_at",)
def __init__(self, **kwargs):
self._created_at: Optional[float] = kwargs.pop('created_at', None)
self._created_at: Optional[float] = kwargs.pop("created_at", None)
@property
def created_at(self) -> Optional[datetime.datetime]:
@ -131,7 +131,9 @@ class BaseActivity:
.. versionadded:: 1.3
"""
if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._created_at / 1000, tz=datetime.timezone.utc
)
def to_dict(self) -> ActivityPayload:
raise NotImplementedError
@ -199,58 +201,62 @@ class Activity(BaseActivity):
"""
__slots__ = (
'state',
'details',
'_created_at',
'timestamps',
'assets',
'party',
'flags',
'sync_id',
'session_id',
'type',
'name',
'url',
'application_id',
'emoji',
'buttons',
"state",
"details",
"_created_at",
"timestamps",
"assets",
"party",
"flags",
"sync_id",
"session_id",
"type",
"name",
"url",
"application_id",
"emoji",
"buttons",
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {})
self.assets: ActivityAssets = kwargs.pop('assets', {})
self.party: ActivityParty = kwargs.pop('party', {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id')
self.name: Optional[str] = kwargs.pop('name', None)
self.url: Optional[str] = kwargs.pop('url', None)
self.flags: int = kwargs.pop('flags', 0)
self.sync_id: Optional[str] = kwargs.pop('sync_id', None)
self.session_id: Optional[str] = kwargs.pop('session_id', None)
self.buttons: List[ActivityButton] = kwargs.pop('buttons', [])
self.state: Optional[str] = kwargs.pop("state", None)
self.details: Optional[str] = kwargs.pop("details", None)
self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {})
self.assets: ActivityAssets = kwargs.pop("assets", {})
self.party: ActivityParty = kwargs.pop("party", {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id")
self.name: Optional[str] = kwargs.pop("name", None)
self.url: Optional[str] = kwargs.pop("url", None)
self.flags: int = kwargs.pop("flags", 0)
self.sync_id: Optional[str] = kwargs.pop("sync_id", None)
self.session_id: Optional[str] = kwargs.pop("session_id", None)
self.buttons: List[ActivityButton] = kwargs.pop("buttons", [])
activity_type = kwargs.pop('type', -1)
activity_type = kwargs.pop("type", -1)
self.type: ActivityType = (
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
activity_type
if isinstance(activity_type, ActivityType)
else try_enum(ActivityType, activity_type)
)
emoji = kwargs.pop('emoji', None)
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
emoji = kwargs.pop("emoji", None)
self.emoji: Optional[PartialEmoji] = (
PartialEmoji.from_dict(emoji) if emoji is not None else None
)
def __repr__(self) -> str:
attrs = (
('type', self.type),
('name', self.name),
('url', self.url),
('details', self.details),
('application_id', self.application_id),
('session_id', self.session_id),
('emoji', self.emoji),
("type", self.type),
("name", self.name),
("url", self.url),
("details", self.details),
("application_id", self.application_id),
("session_id", self.session_id),
("emoji", self.emoji),
)
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Activity {inner}>'
inner = " ".join("%s=%r" % t for t in attrs)
return f"<Activity {inner}>"
def to_dict(self) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
@ -263,16 +269,16 @@ class Activity(BaseActivity):
continue
ret[attr] = value
ret['type'] = int(self.type)
ret["type"] = int(self.type)
if self.emoji:
ret['emoji'] = self.emoji.to_dict()
ret["emoji"] = self.emoji.to_dict()
return ret
@property
def start(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['start'] / 1000
timestamp = self.timestamps["start"] / 1000
except KeyError:
return None
else:
@ -282,7 +288,7 @@ class Activity(BaseActivity):
def end(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['end'] / 1000
timestamp = self.timestamps["end"] / 1000
except KeyError:
return None
else:
@ -295,11 +301,11 @@ class Activity(BaseActivity):
return None
try:
large_image = self.assets['large_image']
large_image = self.assets["large_image"]
except KeyError:
return None
else:
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png'
return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png"
@property
def small_image_url(self) -> Optional[str]:
@ -308,21 +314,21 @@ class Activity(BaseActivity):
return None
try:
small_image = self.assets['small_image']
small_image = self.assets["small_image"]
except KeyError:
return None
else:
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png'
return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png"
@property
def large_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
return self.assets.get('large_text', None)
return self.assets.get("large_text", None)
@property
def small_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
return self.assets.get('small_text', None)
return self.assets.get("small_text", None)
class Game(BaseActivity):
@ -359,20 +365,20 @@ class Game(BaseActivity):
The game's name.
"""
__slots__ = ('name', '_end', '_start')
__slots__ = ("name", "_end", "_start")
def __init__(self, name: str, **extra):
super().__init__(**extra)
self.name: str = name
try:
timestamps: ActivityTimestamps = extra['timestamps']
timestamps: ActivityTimestamps = extra["timestamps"]
except KeyError:
self._start = 0
self._end = 0
else:
self._start = timestamps.get('start', 0)
self._end = timestamps.get('end', 0)
self._start = timestamps.get("start", 0)
self._end = timestamps.get("end", 0)
@property
def type(self) -> ActivityType:
@ -386,29 +392,33 @@ class Game(BaseActivity):
def start(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable."""
if self._start:
return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._start / 1000, tz=datetime.timezone.utc
)
return None
@property
def end(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable."""
if self._end:
return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._end / 1000, tz=datetime.timezone.utc
)
return None
def __str__(self) -> str:
return str(self.name)
def __repr__(self) -> str:
return f'<Game name={self.name!r}>'
return f"<Game name={self.name!r}>"
def to_dict(self) -> Dict[str, Any]:
timestamps: Dict[str, Any] = {}
if self._start:
timestamps['start'] = self._start
timestamps["start"] = self._start
if self._end:
timestamps['end'] = self._end
timestamps["end"] = self._end
# fmt: off
return {
@ -473,16 +483,16 @@ class Streaming(BaseActivity):
A dictionary comprising of similar keys than those in :attr:`Activity.assets`.
"""
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
__slots__ = ("platform", "name", "game", "url", "details", "assets")
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
super().__init__(**extra)
self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name)
self.game: Optional[str] = extra.pop('state', None)
self.name: Optional[str] = extra.pop("details", name)
self.game: Optional[str] = extra.pop("state", None)
self.url: str = url
self.details: Optional[str] = extra.pop('details', self.name) # compatibility
self.assets: ActivityAssets = extra.pop('assets', {})
self.details: Optional[str] = extra.pop("details", self.name) # compatibility
self.assets: ActivityAssets = extra.pop("assets", {})
@property
def type(self) -> ActivityType:
@ -496,7 +506,7 @@ class Streaming(BaseActivity):
return str(self.name)
def __repr__(self) -> str:
return f'<Streaming name={self.name!r}>'
return f"<Streaming name={self.name!r}>"
@property
def twitch_name(self):
@ -507,11 +517,11 @@ class Streaming(BaseActivity):
"""
try:
name = self.assets['large_image']
name = self.assets["large_image"]
except KeyError:
return None
else:
return name[7:] if name[:7] == 'twitch:' else None
return name[7:] if name[:7] == "twitch:" else None
def to_dict(self) -> Dict[str, Any]:
# fmt: off
@ -523,11 +533,15 @@ class Streaming(BaseActivity):
}
# fmt: on
if self.details:
ret['details'] = self.details
ret["details"] = self.details
return ret
def __eq__(self, other: Any) -> bool:
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
return (
isinstance(other, Streaming)
and other.name == self.name
and other.url == self.url
)
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
@ -559,17 +573,26 @@ class Spotify:
Returns the string 'Spotify'.
"""
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
__slots__ = (
"_state",
"_details",
"_timestamps",
"_assets",
"_party",
"_sync_id",
"_session_id",
"_created_at",
)
def __init__(self, **data):
self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id')
self._session_id: str = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None)
self._state: str = data.pop("state", "")
self._details: str = data.pop("details", "")
self._timestamps: Dict[str, int] = data.pop("timestamps", {})
self._assets: ActivityAssets = data.pop("assets", {})
self._party: ActivityParty = data.pop("party", {})
self._sync_id: str = data.pop("sync_id")
self._session_id: str = data.pop("session_id")
self._created_at: Optional[float] = data.pop("created_at", None)
@property
def type(self) -> ActivityType:
@ -586,7 +609,9 @@ class Spotify:
.. versionadded:: 1.3
"""
if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._created_at / 1000, tz=datetime.timezone.utc
)
@property
def colour(self) -> Colour:
@ -604,21 +629,21 @@ class Spotify:
def to_dict(self) -> Dict[str, Any]:
return {
'flags': 48, # SYNC | PLAY
'name': 'Spotify',
'assets': self._assets,
'party': self._party,
'sync_id': self._sync_id,
'session_id': self._session_id,
'timestamps': self._timestamps,
'details': self._details,
'state': self._state,
"flags": 48, # SYNC | PLAY
"name": "Spotify",
"assets": self._assets,
"party": self._party,
"sync_id": self._sync_id,
"session_id": self._session_id,
"timestamps": self._timestamps,
"details": self._details,
"state": self._state,
}
@property
def name(self) -> str:
""":class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify'
return "Spotify"
def __eq__(self, other: Any) -> bool:
return (
@ -635,10 +660,10 @@ class Spotify:
return hash(self._session_id)
def __str__(self) -> str:
return 'Spotify'
return "Spotify"
def __repr__(self) -> str:
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>'
return f"<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>"
@property
def title(self) -> str:
@ -648,7 +673,7 @@ class Spotify:
@property
def artists(self) -> List[str]:
"""List[:class:`str`]: The artists of the song being played."""
return self._state.split('; ')
return self._state.split("; ")
@property
def artist(self) -> str:
@ -662,16 +687,16 @@ class Spotify:
@property
def album(self) -> str:
""":class:`str`: The album that the song being played belongs to."""
return self._assets.get('large_text', '')
return self._assets.get("large_text", "")
@property
def album_cover_url(self) -> str:
""":class:`str`: The album cover image URL from Spotify's CDN."""
large_image = self._assets.get('large_image', '')
if large_image[:8] != 'spotify:':
return ''
large_image = self._assets.get("large_image", "")
if large_image[:8] != "spotify:":
return ""
album_image_id = large_image[8:]
return 'https://i.scdn.co/image/' + album_image_id
return "https://i.scdn.co/image/" + album_image_id
@property
def track_id(self) -> str:
@ -684,17 +709,21 @@ class Spotify:
.. versionadded:: 2.0
"""
return f'https://open.spotify.com/track/{self.track_id}'
return f"https://open.spotify.com/track/{self.track_id}"
@property
def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._timestamps["start"] / 1000, tz=datetime.timezone.utc
)
@property
def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(
self._timestamps["end"] / 1000, tz=datetime.timezone.utc
)
@property
def duration(self) -> datetime.timedelta:
@ -704,7 +733,7 @@ class Spotify:
@property
def party_id(self) -> str:
""":class:`str`: The party ID of the listening party."""
return self._party.get('id', '')
return self._party.get("id", "")
class CustomActivity(BaseActivity):
@ -738,13 +767,15 @@ class CustomActivity(BaseActivity):
The emoji to pass to the activity, if any.
"""
__slots__ = ('name', 'emoji', 'state')
__slots__ = ("name", "emoji", "state")
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
def __init__(
self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any
):
super().__init__(**extra)
self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None)
if self.name == 'Custom Status':
self.state: Optional[str] = extra.pop("state", None)
if self.name == "Custom Status":
self.name = self.state
self.emoji: Optional[PartialEmoji]
@ -757,7 +788,9 @@ class CustomActivity(BaseActivity):
elif isinstance(emoji, PartialEmoji):
self.emoji = emoji
else:
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.')
raise TypeError(
f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead."
)
@property
def type(self) -> ActivityType:
@ -770,22 +803,26 @@ class CustomActivity(BaseActivity):
def to_dict(self) -> Dict[str, Any]:
if self.name == self.state:
o = {
'type': ActivityType.custom.value,
'state': self.name,
'name': 'Custom Status',
"type": ActivityType.custom.value,
"state": self.name,
"name": "Custom Status",
}
else:
o = {
'type': ActivityType.custom.value,
'name': self.name,
"type": ActivityType.custom.value,
"name": self.name,
}
if self.emoji:
o['emoji'] = self.emoji.to_dict()
o["emoji"] = self.emoji.to_dict()
return o
def __eq__(self, other: Any) -> bool:
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
return (
isinstance(other, CustomActivity)
and other.name == self.name
and other.emoji == self.emoji
)
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
@ -796,47 +833,54 @@ class CustomActivity(BaseActivity):
def __str__(self) -> str:
if self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return f"{self.emoji} {self.name}"
return str(self.emoji)
else:
return str(self.name)
def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
return f"<CustomActivity name={self.name!r} emoji={self.emoji!r}>"
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data:
return None
game_type = try_enum(ActivityType, data.get('type', -1))
game_type = try_enum(ActivityType, data.get("type", -1))
if game_type is ActivityType.playing:
if 'application_id' in data or 'session_id' in data:
if "application_id" in data or "session_id" in data:
return Activity(**data)
return Game(**data)
elif game_type is ActivityType.custom:
try:
name = data.pop('name')
name = data.pop("name")
except KeyError:
return Activity(**data)
else:
# we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore
return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming:
if 'url' in data:
if "url" in data:
# the url won't be None here
return Streaming(**data) # type: ignore
return Streaming(**data) # type: ignore
return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
elif (
game_type is ActivityType.listening
and "sync_id" in data
and "session_id" in data
):
return Spotify(**data)
return Activity(**data)

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
from .state import ConnectionState
__all__ = (
'AppInfo',
'PartialAppInfo',
"AppInfo",
"PartialAppInfo",
)
@ -115,58 +115,60 @@ class AppInfo:
"""
__slots__ = (
'_state',
'description',
'id',
'name',
'rpc_origins',
'bot_public',
'bot_require_code_grant',
'owner',
'_icon',
'summary',
'verify_key',
'team',
'guild_id',
'primary_sku_id',
'slug',
'_cover_image',
'terms_of_service_url',
'privacy_policy_url',
"_state",
"description",
"id",
"name",
"rpc_origins",
"bot_public",
"bot_require_code_grant",
"owner",
"_icon",
"summary",
"verify_key",
"team",
"guild_id",
"primary_sku_id",
"slug",
"_cover_image",
"terms_of_service_url",
"privacy_policy_url",
)
def __init__(self, state: ConnectionState, data: AppInfoPayload):
from .team import Team
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self._icon: Optional[str] = data['icon']
self.rpc_origins: List[str] = data['rpc_origins']
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.create_user(data['owner'])
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self._icon: Optional[str] = data["icon"]
self.rpc_origins: List[str] = data["rpc_origins"]
self.bot_public: bool = data["bot_public"]
self.bot_require_code_grant: bool = data["bot_require_code_grant"]
self.owner: User = state.create_user(data["owner"])
team: Optional[TeamPayload] = data.get('team')
team: Optional[TeamPayload] = data.get("team")
self.team: Optional[Team] = Team(state, team) if team else None
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id')
self.slug: Optional[str] = data.get('slug')
self._cover_image: Optional[str] = data.get('cover_image')
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(
data, "primary_sku_id"
)
self.slug: Optional[str] = data.get("slug")
self._cover_image: Optional[str] = data.get("cover_image")
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'description={self.description!r} public={self.bot_public} '
f'owner={self.owner!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"description={self.description!r} public={self.bot_public} "
f"owner={self.owner!r}>"
)
@property
@ -174,7 +176,7 @@ class AppInfo:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
return Asset._from_icon(self._state, self.id, self._icon, path="app")
@property
def cover_image(self) -> Optional[Asset]:
@ -195,6 +197,7 @@ class AppInfo:
"""
return self._state._get_guild(self.guild_id)
class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
@ -222,26 +225,37 @@ class PartialAppInfo:
The application's privacy policy URL, if set.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon')
__slots__ = (
"_state",
"id",
"name",
"description",
"rpc_origins",
"summary",
"verify_key",
"terms_of_service_url",
"privacy_policy_url",
"_icon",
)
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data.get('icon')
self.description: str = data['description']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data.get("icon")
self.description: str = data["description"]
self.rpc_origins: Optional[List[str]] = data.get("rpc_origins")
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>"
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
return Asset._from_icon(self._state, self.id, self._icon, path="app")

View File

@ -33,13 +33,11 @@ from . import utils
import yarl
__all__ = (
'Asset',
)
__all__ = ("Asset",)
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
MISSING = utils.MISSING
class AssetMixin:
url: str
_state: Optional[Any]
@ -71,11 +70,16 @@ class AssetMixin:
The content of the asset.
"""
if self._state is None:
raise DiscordException('Invalid state (no ConnectionState provided)')
raise DiscordException("Invalid state (no ConnectionState provided)")
return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
async def save(
self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
*,
seek_begin: bool = True,
) -> int:
"""|coro|
Saves this asset into a file-like object.
@ -112,7 +116,7 @@ class AssetMixin:
fp.seek(0)
return written
else:
with open(fp, 'wb') as f:
with open(fp, "wb") as f:
return f.write(data)
@ -143,13 +147,13 @@ class Asset(AssetMixin):
"""
__slots__: Tuple[str, ...] = (
'_state',
'_url',
'_animated',
'_key',
"_state",
"_url",
"_animated",
"_key",
)
BASE = 'https://cdn.discordapp.com'
BASE = "https://cdn.discordapp.com"
def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
@ -161,26 +165,28 @@ class Asset(AssetMixin):
def _from_default_avatar(cls, state, index: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/embed/avatars/{index}.png',
url=f"{cls.BASE}/embed/avatars/{index}.png",
key=str(index),
animated=False,
)
@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024',
url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024",
key=avatar,
animated=animated,
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
def _from_guild_avatar(
cls, state, guild_id: int, member_id: int, avatar: str
) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
@ -192,7 +198,7 @@ class Asset(AssetMixin):
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
key=icon_hash,
animated=False,
)
@ -201,7 +207,7 @@ class Asset(AssetMixin):
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
key=cover_image_hash,
animated=False,
)
@ -210,18 +216,18 @@ class Asset(AssetMixin):
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024',
url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
key=image,
animated=False,
)
@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png'
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024',
url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024",
key=icon_hash,
animated=animated,
)
@ -230,20 +236,20 @@ class Asset(AssetMixin):
def _from_sticker_banner(cls, state, banner: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
key=str(banner),
animated=False,
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
animated = banner_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512",
key=banner_hash,
animated=animated
animated=animated,
)
def __str__(self) -> str:
@ -253,8 +259,8 @@ class Asset(AssetMixin):
return len(self._url)
def __repr__(self):
shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>'
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"
def __eq__(self, other):
return isinstance(other, Asset) and self._url == other._url
@ -312,21 +318,27 @@ class Asset(AssetMixin):
if format is not MISSING:
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
url = url.with_path(f'{path}.{format}')
raise InvalidArgument(
f"format must be one of {VALID_ASSET_FORMATS}"
)
url = url.with_path(f"{path}.{format}")
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')
raise InvalidArgument(
f"format must be one of {VALID_STATIC_FORMATS}"
)
url = url.with_path(f"{path}.{format}")
if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{static_format}')
raise InvalidArgument(
f"static_format must be one of {VALID_STATIC_FORMATS}"
)
url = url.with_path(f"{path}.{static_format}")
if size is not MISSING:
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = url.with_query(size=size)
else:
url = url.with_query(url.raw_query_string)
@ -353,7 +365,7 @@ class Asset(AssetMixin):
The new updated asset.
"""
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
@ -379,14 +391,14 @@ class Asset(AssetMixin):
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
else:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:

View File

@ -24,7 +24,20 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from . import enums, utils
from .asset import Asset
@ -35,9 +48,9 @@ from .object import Object
from .permissions import PermissionOverwrite, Permissions
__all__ = (
'AuditLogDiff',
'AuditLogChanges',
'AuditLogEntry',
"AuditLogDiff",
"AuditLogChanges",
"AuditLogEntry",
)
@ -74,18 +87,25 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
return int(data)
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
def _transform_channel(
entry: AuditLogEntry, data: Optional[Snowflake]
) -> Optional[Union[abc.GuildChannel, Object]]:
if data is None:
return None
return entry.guild.get_channel(int(data)) or Object(id=data)
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
def _transform_member_id(
entry: AuditLogEntry, data: Optional[Snowflake]
) -> Union[Member, User, None]:
if data is None:
return None
return entry._get_member(int(data))
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
def _transform_guild_id(
entry: AuditLogEntry, data: Optional[Snowflake]
) -> Optional[Guild]:
if data is None:
return None
return entry._state._get_guild(data)
@ -96,16 +116,16 @@ def _transform_overwrites(
) -> List[Tuple[Object, PermissionOverwrite]]:
overwrites = []
for elem in data:
allow = Permissions(int(elem['allow']))
deny = Permissions(int(elem['deny']))
allow = Permissions(int(elem["allow"]))
deny = Permissions(int(elem["deny"]))
ow = PermissionOverwrite.from_pair(allow, deny)
ow_type = elem['type']
ow_id = int(elem['id'])
ow_type = elem["type"]
ow_id = int(elem["id"])
target = None
if ow_type == '0':
if ow_type == "0":
target = entry.guild.get_role(ow_id)
elif ow_type == '1':
elif ow_type == "1":
target = entry._get_member(ow_id)
if target is None:
@ -128,7 +148,9 @@ def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Ass
return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore
def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]:
def _guild_hash_transformer(
path: str,
) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]:
def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]:
if data is None:
return None
@ -137,7 +159,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
return _transform
T = TypeVar('T', bound=enums.Enum)
T = TypeVar("T", bound=enums.Enum)
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
@ -146,12 +168,16 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'):
def _transform_type(
entry: AuditLogEntry, data: Union[int]
) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith("sticker_"):
return enums.try_enum(enums.StickerType, data)
else:
return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff:
def __len__(self) -> int:
return len(self.__dict__)
@ -160,8 +186,8 @@ class AuditLogDiff:
yield from self.__dict__.items()
def __repr__(self) -> str:
values = ' '.join('%s=%r' % item for item in self.__dict__.items())
return f'<AuditLogDiff {values}>'
values = " ".join("%s=%r" % item for item in self.__dict__.items())
return f"<AuditLogDiff {values}>"
if TYPE_CHECKING:
@ -217,14 +243,14 @@ class AuditLogChanges:
self.after = AuditLogDiff()
for elem in data:
attr = elem['key']
attr = elem["key"]
# special cases for role add/remove
if attr == '$add':
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore
if attr == "$add":
self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore
continue
elif attr == '$remove':
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore
elif attr == "$remove":
self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore
continue
try:
@ -238,7 +264,7 @@ class AuditLogChanges:
transformer: Optional[Transformer]
try:
before = elem['old_value']
before = elem["old_value"]
except KeyError:
before = None
else:
@ -248,7 +274,7 @@ class AuditLogChanges:
setattr(self.before, attr, before)
try:
after = elem['new_value']
after = elem["new_value"]
except KeyError:
after = None
else:
@ -258,34 +284,40 @@ class AuditLogChanges:
setattr(self.after, attr, after)
# add an alias
if hasattr(self.after, 'colour'):
if hasattr(self.after, "colour"):
self.after.color = self.after.colour
self.before.color = self.before.colour
if hasattr(self.after, 'expire_behavior'):
if hasattr(self.after, "expire_behavior"):
self.after.expire_behaviour = self.after.expire_behavior
self.before.expire_behaviour = self.before.expire_behavior
def __repr__(self) -> str:
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>'
return f"<AuditLogChanges before={self.before!r} after={self.after!r}>"
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None:
if not hasattr(first, 'roles'):
setattr(first, 'roles', [])
def _handle_role(
self,
first: AuditLogDiff,
second: AuditLogDiff,
entry: AuditLogEntry,
elem: List[RolePayload],
) -> None:
if not hasattr(first, "roles"):
setattr(first, "roles", [])
data = []
g: Guild = entry.guild # type: ignore
for e in elem:
role_id = int(e['id'])
role_id = int(e["id"])
role = g.get_role(role_id)
if role is None:
role = Object(id=role_id)
role.name = e['name'] # type: ignore
role.name = e["name"] # type: ignore
data.append(role)
setattr(second, 'roles', data)
setattr(second, "roles", data)
class _AuditLogProxyMemberPrune:
@ -358,63 +390,81 @@ class AuditLogEntry(Hashable):
which actions have this field filled out.
"""
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
def __init__(
self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild
):
self._state = guild._state
self.guild = guild
self._users = users
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id'])
self.action = enums.try_enum(enums.AuditLogAction, data["action_type"])
self.id = int(data["id"])
# this key is technically not usually present
self.reason = data.get('reason')
self.extra = data.get('options')
self.reason = data.get("reason")
self.extra = data.get("options")
if isinstance(self.action, enums.AuditLogAction) and self.extra:
if self.action is enums.AuditLogAction.member_prune:
# member prune has two keys with useful information
self.extra: _AuditLogProxyMemberPrune = type(
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()}
"_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()}
)()
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
channel_id = int(self.extra['channel_id'])
elif (
self.action is enums.AuditLogAction.member_move
or self.action is enums.AuditLogAction.message_delete
):
channel_id = int(self.extra["channel_id"])
elems = {
'count': int(self.extra['count']),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
"count": int(self.extra["count"]),
"channel": self.guild.get_channel(channel_id)
or Object(id=channel_id),
}
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)()
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type(
"_AuditLogProxy", (), elems
)()
elif self.action is enums.AuditLogAction.member_disconnect:
# The member disconnect action has a dict with some information
elems = {
'count': int(self.extra['count']),
"count": int(self.extra["count"]),
}
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)()
elif self.action.name.endswith('pin'):
self.extra: _AuditLogProxyMemberDisconnect = type(
"_AuditLogProxy", (), elems
)()
elif self.action.name.endswith("pin"):
# the pin actions have a dict with some information
channel_id = int(self.extra['channel_id'])
channel_id = int(self.extra["channel_id"])
elems = {
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
'message_id': int(self.extra['message_id']),
"channel": self.guild.get_channel(channel_id)
or Object(id=channel_id),
"message_id": int(self.extra["message_id"]),
}
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith('overwrite_'):
self.extra: _AuditLogProxyPinAction = type(
"_AuditLogProxy", (), elems
)()
elif self.action.name.startswith("overwrite_"):
# the overwrite_ actions have a dict with some information
instance_id = int(self.extra['id'])
the_type = self.extra.get('type')
if the_type == '1':
instance_id = int(self.extra["id"])
the_type = self.extra.get("type")
if the_type == "1":
self.extra = self._get_member(instance_id)
elif the_type == '0':
elif the_type == "0":
role = self.guild.get_role(instance_id)
if role is None:
role = Object(id=instance_id)
role.name = self.extra.get('role_name') # type: ignore
role.name = self.extra.get("role_name") # type: ignore
self.extra: Role = role
elif self.action.name.startswith('stage_instance'):
channel_id = int(self.extra['channel_id'])
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith("stage_instance"):
channel_id = int(self.extra["channel_id"])
elems = {
"channel": self.guild.get_channel(channel_id)
or Object(id=channel_id)
}
self.extra: _AuditLogProxyStageInstanceAction = type(
"_AuditLogProxy", (), elems
)()
# fmt: off
self.extra: Union[
@ -433,16 +483,16 @@ class AuditLogEntry(Hashable):
# where new_value and old_value are not guaranteed to be there depending
# on the action type, so let's just fetch it for now and only turn it
# into meaningful data when requested
self._changes = data.get('changes', [])
self._changes = data.get("changes", [])
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore
self._target_id = utils._get_as_snowflake(data, 'target_id')
self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
self._target_id = utils._get_as_snowflake(data, "target_id")
def _get_member(self, user_id: int) -> Union[Member, User, None]:
return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str:
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>'
return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
@utils.cached_property
def created_at(self) -> datetime.datetime:
@ -450,9 +500,24 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id)
@utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]:
def target(
self,
) -> Union[
Guild,
abc.GuildChannel,
Member,
User,
Role,
Invite,
Emoji,
StageInstance,
GuildSticker,
Thread,
Object,
None,
]:
try:
converter = getattr(self, '_convert_target_' + self.action.target_type)
converter = getattr(self, "_convert_target_" + self.action.target_type)
except AttributeError:
return Object(id=self._target_id)
else:
@ -483,7 +548,9 @@ class AuditLogEntry(Hashable):
def _convert_target_guild(self, target_id: int) -> Guild:
return self.guild
def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]:
def _convert_target_channel(
self, target_id: int
) -> Union[abc.GuildChannel, Object]:
return self.guild.get_channel(target_id) or Object(id=target_id)
def _convert_target_user(self, target_id: int) -> Union[Member, User, None]:
@ -495,14 +562,18 @@ class AuditLogEntry(Hashable):
def _convert_target_invite(self, target_id: int) -> Invite:
# invites have target_id set to null
# so figure out which change has the full invite data
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
changeset = (
self.before
if self.action is enums.AuditLogAction.invite_delete
else self.after
)
fake_payload = {
'max_age': changeset.max_age,
'max_uses': changeset.max_uses,
'code': changeset.code,
'temporary': changeset.temporary,
'uses': changeset.uses,
"max_age": changeset.max_age,
"max_uses": changeset.max_uses,
"code": changeset.code,
"temporary": changeset.temporary,
"uses": changeset.uses,
}
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore
@ -518,7 +589,9 @@ class AuditLogEntry(Hashable):
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
def _convert_target_stage_instance(
self, target_id: int
) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:

View File

@ -29,11 +29,10 @@ import time
import random
from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar('T', bool, Literal[True], Literal[False])
T = TypeVar("T", bool, Literal[True], Literal[False])
__all__ = ("ExponentialBackoff",)
__all__ = (
'ExponentialBackoff',
)
class ExponentialBackoff(Generic[T]):
"""An implementation of the exponential backoff algorithm
@ -69,7 +68,7 @@ class ExponentialBackoff(Generic[T]):
rand = random.Random()
rand.seed()
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
@overload
def delay(self: ExponentialBackoff[Literal[False]]) -> float:

View File

@ -45,7 +45,13 @@ import datetime
import discord.abc
from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode
from .enums import (
ChannelType,
StagePrivacyLevel,
try_enum,
VoiceRegion,
VideoQualityMode,
)
from .mixins import Hashable
from .object import Object
from . import utils
@ -57,14 +63,14 @@ from .threads import Thread
from .iterators import ArchivedThreadIterator
__all__ = (
'TextChannel',
'VoiceChannel',
'StageChannel',
'DMChannel',
'CategoryChannel',
'StoreChannel',
'GroupChannel',
'PartialMessageable',
"TextChannel",
"VoiceChannel",
"StageChannel",
"DMChannel",
"CategoryChannel",
"StoreChannel",
"GroupChannel",
"PartialMessageable",
)
if TYPE_CHECKING:
@ -155,51 +161,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'topic',
'_state',
'nsfw',
'category_id',
'position',
'slowmode_delay',
'_overwrites',
'_type',
'last_message_id',
'default_auto_archive_duration',
"name",
"id",
"guild",
"topic",
"_state",
"nsfw",
"category_id",
"position",
"slowmode_delay",
"_overwrites",
"_type",
"last_message_id",
"default_auto_archive_duration",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
def __init__(
self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self._type: int = data['type']
self.id: int = int(data["id"])
self._type: int = data["type"]
self._update(guild, data)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('position', self.position),
('nsfw', self.nsfw),
('news', self.is_news()),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("position", self.position),
("nsfw", self.nsfw),
("news", self.is_news()),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: TextChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.topic: Optional[str] = data.get('topic')
self.position: int = data['position']
self.nsfw: bool = data.get('nsfw', False)
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.topic: Optional[str] = data.get("topic")
self.position: int = data["position"]
self.nsfw: bool = data.get("nsfw", False)
# Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
self._type: int = data.get('type', self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self.slowmode_delay: int = data.get("rate_limit_per_user", 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get(
"default_auto_archive_duration", 1440
)
self._type: int = data.get("type", self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(
data, "last_message_id"
)
self._fill_overwrites(data)
async def _get_channel(self):
@ -234,7 +246,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. versionadded:: 2.0
"""
return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id]
return [
thread
for thread in self.guild._threads.values()
if thread.parent_id == self.id
]
def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the channel is NSFW."""
@ -263,7 +279,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Optional[:class:`Message`]
The last message in this channel or ``None`` if not found.
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None
return (
self._state._get_message(self.last_message_id)
if self.last_message_id
else None
)
@overload
async def edit(
@ -359,9 +379,17 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
async def clone(
self, *, name: Optional[str] = None, reason: Optional[str] = None
) -> TextChannel:
return await self._clone_impl(
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
{
"topic": self.topic,
"nsfw": self.nsfw,
"rate_limit_per_user": self.slowmode_delay,
},
name=name,
reason=reason,
)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@ -408,7 +436,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
@ -483,11 +511,19 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
if check is MISSING:
check = lambda m: True
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
iterator = self.history(
limit=limit,
before=before,
after=after,
oldest_first=oldest_first,
around=around,
)
ret: List[Message] = []
count = 0
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
minimum_time = (
int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
)
strategy = self.delete_messages if bulk else _single_delete_strategy
async for message in iterator:
@ -548,7 +584,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
data = await self._state.http.channel_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]
async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook:
async def create_webhook(
self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None
) -> Webhook:
"""|coro|
Creates a webhook for this channel.
@ -586,10 +624,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
if avatar is not None:
avatar = utils._bytes_to_base64_data(avatar) # type: ignore
data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason)
data = await self._state.http.create_webhook(
self.id, name=str(name), avatar=avatar, reason=reason
)
return Webhook.from_state(data, state=self._state)
async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook:
async def follow(
self, *, destination: TextChannel, reason: Optional[str] = None
) -> Webhook:
"""
Follows a channel using a webhook.
@ -625,14 +667,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
if not self.is_news():
raise ClientException('The channel must be a news channel.')
raise ClientException("The channel must be a news channel.")
if not isinstance(destination, TextChannel):
raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}')
raise InvalidArgument(
f"Expected TextChannel received {destination.__class__.__name__}"
)
from .webhook import Webhook
data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason)
data = await self._state.http.follow_webhook(
self.id, webhook_channel_id=destination.id, reason=reason
)
return Webhook._as_follower(data, channel=destination, user=self._state.user)
def get_partial_message(self, message_id: int, /) -> PartialMessage:
@ -731,7 +777,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
data = await self._state.http.start_thread_without_message(
self.id,
name=name,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
auto_archive_duration=auto_archive_duration
or self.default_auto_archive_duration,
type=type.value,
reason=reason,
)
@ -740,7 +787,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.id,
message.id,
name=name,
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
auto_archive_duration=auto_archive_duration
or self.default_auto_archive_duration,
reason=reason,
)
@ -787,45 +835,64 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
:class:`Thread`
The archived threads.
"""
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
return ArchivedThreadIterator(
self.id,
self.guild,
limit=limit,
joined=joined,
private=private,
before=before,
)
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
__slots__ = (
'name',
'id',
'guild',
'bitrate',
'user_limit',
'_state',
'position',
'_overwrites',
'category_id',
'rtc_region',
'video_quality_mode',
"name",
"id",
"guild",
"bitrate",
"user_limit",
"_state",
"position",
"_overwrites",
"category_id",
"rtc_region",
"video_quality_mode",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]):
def __init__(
self,
*,
state: ConnectionState,
guild: Guild,
data: Union[VoiceChannelPayload, StageChannelPayload],
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def _get_voice_client_key(self) -> Tuple[int, str]:
return self.guild.id, 'guild_id'
return self.guild.id, "guild_id"
def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
def _update(
self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]
) -> None:
self.guild = guild
self.name: str = data['name']
rtc = data.get('rtc_region')
self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.position: int = data['position']
self.bitrate: int = data.get('bitrate')
self.user_limit: int = data.get('user_limit')
self.name: str = data["name"]
rtc = data.get("rtc_region")
self.rtc_region: Optional[VoiceRegion] = (
try_enum(VoiceRegion, rtc) if rtc is not None else None
)
self.video_quality_mode: VideoQualityMode = try_enum(
VideoQualityMode, data.get("video_quality_mode", 1)
)
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data["position"]
self.bitrate: int = data.get("bitrate")
self.user_limit: int = data.get("user_limit")
self._fill_overwrites(data)
@property
@ -933,17 +1000,17 @@ class VoiceChannel(VocalGuildChannel):
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("rtc_region", self.rtc_region),
("position", self.position),
("bitrate", self.bitrate),
("video_quality_mode", self.video_quality_mode),
("user_limit", self.user_limit),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
@property
def type(self) -> ChannelType:
@ -951,8 +1018,14 @@ class VoiceChannel(VocalGuildChannel):
return ChannelType.voice
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason)
async def clone(
self, *, name: Optional[str] = None, reason: Optional[str] = None
) -> VoiceChannel:
return await self._clone_impl(
{"bitrate": self.bitrate, "user_limit": self.user_limit},
name=name,
reason=reason,
)
@overload
async def edit(
@ -1093,31 +1166,35 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
__slots__ = ('topic',)
__slots__ = ("topic",)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('topic', self.topic),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("topic", self.topic),
("rtc_region", self.rtc_region),
("position", self.position),
("bitrate", self.bitrate),
("video_quality_mode", self.video_quality_mode),
("user_limit", self.user_limit),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data)
self.topic = data.get('topic')
self.topic = data.get("topic")
@property
def requesting_to_speak(self) -> List[Member]:
"""List[:class:`Member`]: A list of members who are requesting to speak in the stage channel."""
return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None]
return [
member
for member in self.members
if member.voice and member.voice.requested_to_speak_at is not None
]
@property
def speakers(self) -> List[Member]:
@ -1128,7 +1205,9 @@ class StageChannel(VocalGuildChannel):
return [
member
for member in self.members
if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None
if member.voice
and not member.voice.suppress
and member.voice.requested_to_speak_at is None
]
@property
@ -1137,7 +1216,9 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
return [member for member in self.members if member.voice and member.voice.suppress]
return [
member for member in self.members if member.voice and member.voice.suppress
]
@property
def moderators(self) -> List[Member]:
@ -1146,7 +1227,11 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
required_permissions = Permissions.stage_moderator()
return [member for member in self.members if self.permissions_for(member) >= required_permissions]
return [
member
for member in self.members
if self.permissions_for(member) >= required_permissions
]
@property
def type(self) -> ChannelType:
@ -1154,7 +1239,9 @@ class StageChannel(VocalGuildChannel):
return ChannelType.stage_voice
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel:
async def clone(
self, *, name: Optional[str] = None, reason: Optional[str] = None
) -> StageChannel:
return await self._clone_impl({}, name=name, reason=reason)
@property
@ -1166,7 +1253,11 @@ class StageChannel(VocalGuildChannel):
return utils.get(self.guild.stage_instances, channel_id=self.id)
async def create_instance(
self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
self,
*,
topic: str,
privacy_level: StagePrivacyLevel = MISSING,
reason: Optional[str] = None,
) -> StageInstance:
"""|coro|
@ -1201,13 +1292,15 @@ class StageChannel(VocalGuildChannel):
The newly created stage instance.
"""
payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic}
payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic}
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
raise InvalidArgument(
"privacy_level field must be of type PrivacyLevel"
)
payload['privacy_level'] = privacy_level.value
payload["privacy_level"] = privacy_level.value
data = await self._state.http.create_stage_instance(**payload, reason=reason)
return StageInstance(guild=self.guild, state=self._state, data=data)
@ -1361,22 +1454,33 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
"""
__slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id')
__slots__ = (
"name",
"id",
"guild",
"nsfw",
"_state",
"position",
"_overwrites",
"category_id",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
def __init__(
self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def __repr__(self) -> str:
return f'<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
return f"<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.nsfw: bool = data.get('nsfw', False)
self.position: int = data['position']
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.nsfw: bool = data.get("nsfw", False)
self.position: int = data["position"]
self._fill_overwrites(data)
@property
@ -1393,8 +1497,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.nsfw
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
async def clone(
self, *, name: Optional[str] = None, reason: Optional[str] = None
) -> CategoryChannel:
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload
async def edit(
@ -1463,7 +1569,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs):
kwargs.pop('category', None)
kwargs.pop("category", None)
await super().move(**kwargs)
@property
@ -1483,14 +1589,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@property
def text_channels(self) -> List[TextChannel]:
"""List[:class:`TextChannel`]: Returns the text channels that are under this category."""
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)]
ret = [
c
for c in self.guild.channels
if c.category_id == self.id and isinstance(c, TextChannel)
]
ret.sort(key=lambda c: (c.position, c.id))
return ret
@property
def voice_channels(self) -> List[VoiceChannel]:
"""List[:class:`VoiceChannel`]: Returns the voice channels that are under this category."""
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)]
ret = [
c
for c in self.guild.channels
if c.category_id == self.id and isinstance(c, VoiceChannel)
]
ret.sort(key=lambda c: (c.position, c.id))
return ret
@ -1500,7 +1614,11 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
.. versionadded:: 1.7
"""
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)]
ret = [
c
for c in self.guild.channels
if c.category_id == self.id and isinstance(c, StageChannel)
]
ret.sort(key=lambda c: (c.position, c.id))
return ret
@ -1590,30 +1708,32 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'_state',
'nsfw',
'category_id',
'position',
'_overwrites',
"name",
"id",
"guild",
"_state",
"nsfw",
"category_id",
"position",
"_overwrites",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload):
def __init__(
self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def __repr__(self) -> str:
return f'<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
return f"<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: StoreChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.position: int = data['position']
self.nsfw: bool = data.get('nsfw', False)
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data["position"]
self.nsfw: bool = data.get("nsfw", False)
self._fill_overwrites(data)
@property
@ -1639,8 +1759,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
return self.nsfw
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
async def clone(
self, *, name: Optional[str] = None, reason: Optional[str] = None
) -> StoreChannel:
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload
async def edit(
@ -1716,7 +1838,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
DMC = TypeVar('DMC', bound='DMChannel')
DMC = TypeVar("DMC", bound="DMChannel")
class DMChannel(discord.abc.Messageable, Hashable):
@ -1756,24 +1878,26 @@ class DMChannel(discord.abc.Messageable, Hashable):
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', '_state')
__slots__ = ("id", "recipient", "me", "_state")
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
def __init__(
self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload
):
self._state: ConnectionState = state
self.recipient: Optional[User] = state.store_user(data['recipients'][0])
self.recipient: Optional[User] = state.store_user(data["recipients"][0])
self.me: ClientUser = me
self.id: int = int(data['id'])
self.id: int = int(data["id"])
async def _get_channel(self):
return self
def __str__(self) -> str:
if self.recipient:
return f'Direct Message with {self.recipient}'
return 'Direct Message with Unknown User'
return f"Direct Message with {self.recipient}"
return "Direct Message with Unknown User"
def __repr__(self) -> str:
return f'<DMChannel id={self.id} recipient={self.recipient!r}>'
return f"<DMChannel id={self.id} recipient={self.recipient!r}>"
@classmethod
def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC:
@ -1892,19 +2016,32 @@ class GroupChannel(discord.abc.Messageable, Hashable):
The group channel's name if provided.
"""
__slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state')
__slots__ = (
"id",
"recipients",
"owner_id",
"owner",
"_icon",
"name",
"me",
"_state",
)
def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
def __init__(
self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self.me: ClientUser = me
self._update_group(data)
def _update_group(self, data: GroupChannelPayload) -> None:
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id')
self._icon: Optional[str] = data.get('icon')
self.name: Optional[str] = data.get('name')
self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id")
self._icon: Optional[str] = data.get("icon")
self.name: Optional[str] = data.get("name")
self.recipients: List[User] = [
self._state.store_user(u) for u in data.get("recipients", [])
]
self.owner: Optional[BaseUser]
if self.owner_id == self.me.id:
@ -1920,12 +2057,12 @@ class GroupChannel(discord.abc.Messageable, Hashable):
return self.name
if len(self.recipients) == 0:
return 'Unnamed'
return "Unnamed"
return ', '.join(map(lambda x: x.name, self.recipients))
return ", ".join(map(lambda x: x.name, self.recipients))
def __repr__(self) -> str:
return f'<GroupChannel id={self.id} name={self.name!r}>'
return f"<GroupChannel id={self.id} name={self.name!r}>"
@property
def type(self) -> ChannelType:
@ -1937,7 +2074,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
"""Optional[:class:`Asset`]: Returns the channel's icon asset if available."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='channel')
return Asset._from_icon(self._state, self.id, self._icon, path="channel")
@property
def created_at(self) -> datetime.datetime:
@ -2032,7 +2169,9 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
The channel type associated with this partial messageable, if given.
"""
def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None):
def __init__(
self, state: ConnectionState, id: int, type: Optional[ChannelType] = None
):
self._state: ConnectionState = state
self._channel: Object = Object(id=id)
self.id: int = id
@ -2093,13 +2232,21 @@ def _channel_factory(channel_type: int):
def _threaded_channel_factory(channel_type: int):
cls, value = _channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
if value in (
ChannelType.private_thread,
ChannelType.public_thread,
ChannelType.news_thread,
):
return Thread, value
return cls, value
def _threaded_guild_channel_factory(channel_type: int):
cls, value = _guild_channel_factory(channel_type)
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
if value in (
ChannelType.private_thread,
ChannelType.public_thread,
ChannelType.news_thread,
):
return Thread, value
return cls, value

View File

@ -29,7 +29,20 @@ import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
)
import aiohttp
@ -69,46 +82,49 @@ if TYPE_CHECKING:
from .member import Member
from .voice_client import VoiceProtocol
__all__ = (
'Client',
)
__all__ = ("Client",)
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
_log.info("Cleaning up after %d tasks.", len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
_log.info("All tasks finished cancelling.")
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task
})
loop.call_exception_handler(
{
"message": "Unhandled exception during Client.run shutdown.",
"exception": task.exception(),
"task": task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
_log.info("Closing the event loop.")
loop.close()
class Client:
r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
@ -199,6 +215,7 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
@ -210,26 +227,34 @@ class Client:
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
self.loop: asyncio.AbstractEventLoop = (
asyncio.get_event_loop() if loop is None else loop
)
self._listeners: Dict[
str, List[Tuple[asyncio.Future, Callable[..., bool]]]
] = {}
self.shard_id: Optional[int] = options.get("shard_id")
self.shard_count: Optional[int] = options.get("shard_count")
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None)
proxy: Optional[str] = options.pop("proxy", None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None)
unsync_clock: bool = options.pop("assume_unsync_clock", True)
self.http: HTTPClient = HTTPClient(
connector,
proxy=proxy,
proxy_auth=proxy_auth,
unsync_clock=unsync_clock,
loop=self.loop,
)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready
}
self._handlers: Dict[str, Callable] = {"ready": self._handle_ready}
self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook
"before_identify": self._call_before_identify_hook
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._enable_debug_events: bool = options.pop("enable_debug_events", False)
self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
@ -243,12 +268,20 @@ class Client:
# internals
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
def _get_websocket(
self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None
) -> DiscordWebSocket:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
return ConnectionState(
dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
def _handle_ready(self) -> None:
self._ready.set()
@ -260,7 +293,7 @@ class Client:
This could be referred to as the Discord WebSocket protocol latency.
"""
ws = self.ws
return float('nan') if not ws else ws.latency
return float("nan") if not ws else ws.latency
def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
@ -331,7 +364,7 @@ class Client:
If this is not passed via ``__init__`` then this is retrieved
through the gateway when an event contains the data. Usually
after :func:`~discord.on_connect` is called.
.. versionadded:: 2.0
"""
return self._connection.application_id
@ -348,7 +381,13 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
async def _run_event(
self,
coro: Callable[..., Coroutine[Any, Any, Any]],
event_name: str,
*args: Any,
**kwargs: Any,
) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@ -359,14 +398,20 @@ class Client:
except asyncio.CancelledError:
pass
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
def _schedule_event(
self,
coro: Callable[..., Coroutine[Any, Any, Any]],
event_name: str,
*args: Any,
**kwargs: Any,
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
return asyncio.create_task(wrapped, name=f"discord.py: {event_name}")
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s', event)
method = 'on_' + event
_log.debug("Dispatching event %s", event)
method = "on_" + event
listeners = self._listeners.get(event)
if listeners:
@ -413,18 +458,22 @@ class Client:
overridden to have a different implementation.
Check :func:`~discord.on_error` for more details.
"""
print(f'Ignoring exception in {event_method}', file=sys.stderr)
print(f"Ignoring exception in {event_method}", file=sys.stderr)
traceback.print_exc()
# hooks
async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
async def _call_before_identify_hook(
self, shard_id: Optional[int], *, initial: bool = False
) -> None:
# This hook is an internal hook that actually calls the public one.
# It allows the library to have its own hook without stepping on the
# toes of those who need to override their own hook.
await self.before_identify_hook(shard_id, initial=initial)
async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
async def before_identify_hook(
self, shard_id: Optional[int], *, initial: bool = False
) -> None:
"""|coro|
A hook that is called before IDENTIFYing a session. This is useful
@ -470,7 +519,7 @@ class Client:
passing status code.
"""
_log.info('logging in using static token')
_log.info("logging in using static token")
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@ -502,29 +551,35 @@ class Client:
backoff = ExponentialBackoff()
ws_params = {
'initial': True,
'shard_id': self.shard_id,
"initial": True,
"shard_id": self.shard_id,
}
while not self.is_closed():
try:
coro = DiscordWebSocket.from_client(self, **ws_params)
self.ws = await asyncio.wait_for(coro, timeout=60.0)
ws_params['initial'] = False
ws_params["initial"] = False
while True:
await self.ws.poll_event()
except ReconnectWebSocket as e:
_log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect')
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
_log.info("Got a request to %s the websocket.", e.op)
self.dispatch("disconnect")
ws_params.update(
sequence=self.ws.sequence,
resume=e.resume,
session=self.ws.session_id,
)
continue
except (OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError) as exc:
except (
OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect')
self.dispatch("disconnect")
if not reconnect:
await self.close()
if isinstance(exc, ConnectionClosed) and exc.code == 1000:
@ -537,7 +592,12 @@ class Client:
# If we get connection reset by peer then try to RESUME
if isinstance(exc, OSError) and exc.errno in (54, 10054):
ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id)
ws_params.update(
sequence=self.ws.sequence,
initial=False,
resume=True,
session=self.ws.session_id,
)
continue
# We should only get this when an unhandled close code happens,
@ -557,7 +617,9 @@ class Client:
# Always try to RESUME the connection
# If the connection is not RESUME-able then the gateway will invalidate the session.
# This is apparently what the official Discord client does.
ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
ws_params.update(
sequence=self.ws.sequence, resume=True, session=self.ws.session_id
)
async def close(self) -> None:
"""|coro|
@ -654,10 +716,10 @@ class Client:
try:
loop.run_forever()
except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.')
_log.info("Received signal to terminate bot and event loop.")
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.')
_log.info("Cleaning up tasks.")
_cleanup_loop(loop)
if not future.cancelled():
@ -686,10 +748,10 @@ class Client:
self._connection._activity = None
elif isinstance(value, BaseActivity):
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
self._connection._activity = value.to_dict() # type: ignore
self._connection._activity = value.to_dict() # type: ignore
else:
raise TypeError('activity must derive from BaseActivity.')
raise TypeError("activity must derive from BaseActivity.")
@property
def status(self):
""":class:`.Status`:
@ -704,11 +766,11 @@ class Client:
@status.setter
def status(self, value):
if value is Status.offline:
self._connection._status = 'invisible'
self._connection._status = "invisible"
elif isinstance(value, Status):
self._connection._status = str(value)
else:
raise TypeError('status must derive from Status.')
raise TypeError("status must derive from Status.")
@property
def allowed_mentions(self) -> Optional[AllowedMentions]:
@ -723,7 +785,9 @@ class Client:
if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
else:
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
raise TypeError(
f"allowed_mentions must be AllowedMentions not {value.__class__!r}"
)
@property
def intents(self) -> Intents:
@ -740,7 +804,9 @@ class Client:
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
return list(self._connection._users.values())
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
def get_channel(
self, id: int, /
) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
"""Returns a channel or thread with the given ID.
Parameters
@ -755,12 +821,14 @@ class Client:
"""
return self._connection.get_channel(id)
def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable:
def get_partial_messageable(
self, id: int, *, type: Optional[ChannelType] = None
) -> PartialMessageable:
"""Returns a partial messageable with the given channel ID.
This is useful if you have a channel_id but don't want to do an API call
to send messages to it.
.. versionadded:: 2.0
Parameters
@ -1001,8 +1069,10 @@ class Client:
future = self.loop.create_future()
if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
@ -1040,10 +1110,10 @@ class Client:
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('event registered must be a coroutine function')
raise TypeError("event registered must be a coroutine function")
setattr(self, coro.__name__, coro)
_log.debug('%s has successfully been registered as an event', coro.__name__)
_log.debug("%s has successfully been registered as an event", coro.__name__)
return coro
async def change_presence(
@ -1082,10 +1152,10 @@ class Client:
"""
if status is None:
status_str = 'online'
status_str = "online"
status = Status.online
elif status is Status.offline:
status_str = 'invisible'
status_str = "invisible"
status = Status.offline
else:
status_str = str(status)
@ -1111,7 +1181,7 @@ class Client:
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
after: SnowflakeTime = None,
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1191,7 +1261,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int, /) -> Guild:
"""|coro|
@ -1277,7 +1347,9 @@ class Client:
region_value = str(region)
if code:
data = await self.http.create_from_template(code, name, region_value, icon_base64)
data = await self.http.create_from_template(
code, name, region_value, icon_base64
)
else:
data = await self.http.create_guild(name, region_value, icon_base64)
return Guild(data=data, state=self._connection)
@ -1307,12 +1379,18 @@ class Client:
The stage instance from the stage channel ID.
"""
data = await self.http.get_stage_instance(channel_id)
guild = self.get_guild(int(data['guild_id']))
guild = self.get_guild(int(data["guild_id"]))
return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore
# Invite management
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
async def fetch_invite(
self,
url: Union[Invite, str],
*,
with_counts: bool = True,
with_expiration: bool = True,
) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1351,7 +1429,9 @@ class Client:
"""
invite_id = utils.resolve_invite(url)
data = await self.http.get_invite(invite_id, with_counts=with_counts, with_expiration=with_expiration)
data = await self.http.get_invite(
invite_id, with_counts=with_counts, with_expiration=with_expiration
)
return Invite.from_incomplete(state=self._connection, data=data)
async def delete_invite(self, invite: Union[Invite, str]) -> None:
@ -1428,8 +1508,8 @@ class Client:
The bot's application information.
"""
data = await self.http.application_info()
if 'rpc_origins' not in data:
data['rpc_origins'] = None
if "rpc_origins" not in data:
data["rpc_origins"] = None
return AppInfo(self._connection, data)
async def fetch_user(self, user_id: int, /) -> User:
@ -1463,7 +1543,9 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]:
async def fetch_channel(
self, channel_id: int, /
) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro|
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
@ -1492,19 +1574,21 @@ class Client:
"""
data = await self.http.get_channel(channel_id)
factory, ch_type = _threaded_channel_factory(data['type'])
factory, ch_type = _threaded_channel_factory(data["type"])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
raise InvalidData(
"Unknown channel type {type} for channel ID {id}.".format_map(data)
)
if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
else:
# the factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore
guild_id = int(data["guild_id"]) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id)
# GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel
@ -1530,7 +1614,9 @@ class Client:
data = await self.http.get_webhook(webhook_id)
return Webhook.from_state(data, state=self._connection)
async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]:
async def fetch_sticker(
self, sticker_id: int, /
) -> Union[StandardSticker, GuildSticker]:
"""|coro|
Retrieves a :class:`.Sticker` with the specified ID.
@ -1550,8 +1636,8 @@ class Client:
The sticker you requested.
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro|
@ -1571,7 +1657,10 @@ class Client:
All available premium sticker packs.
"""
data = await self.http.list_premium_sticker_packs()
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
return [
StickerPack(state=self._connection, data=pack)
for pack in data["sticker_packs"]
]
async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
@ -1606,7 +1695,7 @@ class Client:
This method should be used for when a view is comprised of components
that last longer than the lifecycle of the program.
.. versionadded:: 2.0
Parameters
@ -1628,17 +1717,19 @@ class Client:
"""
if not isinstance(view, View):
raise TypeError(f'expected an instance of View not {view.__class__!r}')
raise TypeError(f"expected an instance of View not {view.__class__!r}")
if not view.is_persistent():
raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout')
raise ValueError(
"View is not persistent. Items need to have a custom_id set and View must have no timeout"
)
self._connection.store_view(view, message_id)
@property
def persistent_views(self) -> Sequence[View]:
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client.
.. versionadded:: 2.0
"""
return self._connection.persistent_views

View File

@ -35,11 +35,11 @@ from typing import (
)
__all__ = (
'Colour',
'Color',
"Colour",
"Color",
)
CT = TypeVar('CT', bound='Colour')
CT = TypeVar("CT", bound="Colour")
class Colour:
@ -76,16 +76,18 @@ class Colour:
The raw integer colour value.
"""
__slots__ = ('value',)
__slots__ = ("value",)
def __init__(self, value: int):
if not isinstance(value, int):
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
raise TypeError(
f"Expected int parameter, received {value.__class__.__name__} instead."
)
self.value: int = value
def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xff
return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool:
return isinstance(other, Colour) and self.value == other.value
@ -94,13 +96,13 @@ class Colour:
return not self.__eq__(other)
def __str__(self) -> str:
return f'#{self.value:0>6x}'
return f"#{self.value:0>6x}"
def __int__(self) -> int:
return self.value
def __repr__(self) -> str:
return f'<Colour value={self.value}>'
return f"<Colour value={self.value}>"
def __hash__(self) -> int:
return hash(self.value)
@ -141,7 +143,11 @@ class Colour:
return cls(0)
@classmethod
def random(cls: Type[CT], *, seed: Optional[Union[int, str, float, bytes, bytearray]] = None) -> CT:
def random(
cls: Type[CT],
*,
seed: Optional[Union[int, str, float, bytes, bytearray]] = None,
) -> CT:
"""A factory method that returns a :class:`Colour` with a random hue.
.. note::
@ -164,12 +170,12 @@ class Colour:
@classmethod
def teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
return cls(0x1abc9c)
return cls(0x1ABC9C)
@classmethod
def dark_teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806a)
return cls(0x11806A)
@classmethod
def brand_green(cls: Type[CT]) -> CT:
@ -182,17 +188,17 @@ class Colour:
@classmethod
def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
return cls(0x2ecc71)
return cls(0x2ECC71)
@classmethod
def dark_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
return cls(0x1f8b4c)
return cls(0x1F8B4C)
@classmethod
def blue(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
return cls(0x3498db)
return cls(0x3498DB)
@classmethod
def dark_blue(cls: Type[CT]) -> CT:
@ -202,42 +208,42 @@ class Colour:
@classmethod
def purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
return cls(0x9b59b6)
return cls(0x9B59B6)
@classmethod
def dark_purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
return cls(0x71368a)
return cls(0x71368A)
@classmethod
def magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
return cls(0xe91e63)
return cls(0xE91E63)
@classmethod
def dark_magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
return cls(0xad1457)
return cls(0xAD1457)
@classmethod
def gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
return cls(0xf1c40f)
return cls(0xF1C40F)
@classmethod
def dark_gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
return cls(0xc27c0e)
return cls(0xC27C0E)
@classmethod
def orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
return cls(0xe67e22)
return cls(0xE67E22)
@classmethod
def dark_orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xa84300)
return cls(0xA84300)
@classmethod
def brand_red(cls: Type[CT]) -> CT:
@ -250,45 +256,45 @@ class Colour:
@classmethod
def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c)
return cls(0xE74C3C)
@classmethod
def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
return cls(0x992d22)
return cls(0x992D22)
@classmethod
def lighter_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95a5a6)
return cls(0x95A5A6)
lighter_gray = lighter_grey
@classmethod
def dark_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607d8b)
return cls(0x607D8B)
dark_gray = dark_grey
@classmethod
def light_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979c9f)
return cls(0x979C9F)
light_gray = light_grey
@classmethod
def darker_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546e7a)
return cls(0x546E7A)
darker_gray = darker_grey
@classmethod
def og_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
return cls(0x7289da)
return cls(0x7289DA)
@classmethod
def blurple(cls: Type[CT]) -> CT:
@ -298,7 +304,7 @@ class Colour:
@classmethod
def greyple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
return cls(0x99aab5)
return cls(0x99AAB5)
@classmethod
def dark_theme(cls: Type[CT]) -> CT:
@ -324,7 +330,7 @@ class Colour:
.. versionadded:: 2.0
"""
return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.

View File

@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
)
from .enums import try_enum, ComponentType, ButtonStyle
from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag
@ -41,14 +52,14 @@ if TYPE_CHECKING:
__all__ = (
'Component',
'ActionRow',
'Button',
'SelectMenu',
'SelectOption',
"Component",
"ActionRow",
"Button",
"SelectMenu",
"SelectOption",
)
C = TypeVar('C', bound='Component')
C = TypeVar("C", bound="Component")
class Component:
@ -70,14 +81,14 @@ class Component:
The type of component.
"""
__slots__: Tuple[str, ...] = ('type',)
__slots__: Tuple[str, ...] = ("type",)
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__)
return f"<{self.__class__.__name__} {attrs}>"
@classmethod
def _raw_construct(cls: Type[C], **kwargs) -> C:
@ -112,18 +123,20 @@ class ActionRow(Component):
The children components that this holds, if any.
"""
__slots__: Tuple[str, ...] = ('children',)
__slots__: Tuple[str, ...] = ("children",)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.children: List[Component] = [
_component_factory(d) for d in data.get("components", [])
]
def to_dict(self) -> ActionRowPayload:
return {
'type': int(self.type),
'components': [child.to_dict() for child in self.children],
"type": int(self.type),
"components": [child.to_dict() for child in self.children],
} # type: ignore
@ -157,44 +170,44 @@ class Button(Component):
"""
__slots__: Tuple[str, ...] = (
'style',
'custom_id',
'url',
'disabled',
'label',
'emoji',
"style",
"custom_id",
"url",
"disabled",
"label",
"emoji",
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
self.disabled: bool = data.get('disabled', False)
self.label: Optional[str] = data.get('label')
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
self.custom_id: Optional[str] = data.get("custom_id")
self.url: Optional[str] = data.get("url")
self.disabled: bool = data.get("disabled", False)
self.label: Optional[str] = data.get("label")
self.emoji: Optional[PartialEmoji]
try:
self.emoji = PartialEmoji.from_dict(data['emoji'])
self.emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError:
self.emoji = None
def to_dict(self) -> ButtonComponentPayload:
payload = {
'type': 2,
'style': int(self.style),
'label': self.label,
'disabled': self.disabled,
"type": 2,
"style": int(self.style),
"label": self.label,
"disabled": self.disabled,
}
if self.custom_id:
payload['custom_id'] = self.custom_id
payload["custom_id"] = self.custom_id
if self.url:
payload['url'] = self.url
payload["url"] = self.url
if self.emoji:
payload['emoji'] = self.emoji.to_dict()
payload["emoji"] = self.emoji.to_dict()
return payload # type: ignore
@ -231,37 +244,39 @@ class SelectMenu(Component):
"""
__slots__: Tuple[str, ...] = (
'custom_id',
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
"custom_id",
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
self.max_values: int = data.get('max_values', 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
self.custom_id: str = data["custom_id"]
self.placeholder: Optional[str] = data.get("placeholder")
self.min_values: int = data.get("min_values", 1)
self.max_values: int = data.get("max_values", 1)
self.options: List[SelectOption] = [
SelectOption.from_dict(option) for option in data.get("options", [])
]
self.disabled: bool = data.get("disabled", False)
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
'type': self.type.value,
'custom_id': self.custom_id,
'min_values': self.min_values,
'max_values': self.max_values,
'options': [op.to_dict() for op in self.options],
'disabled': self.disabled,
"type": self.type.value,
"custom_id": self.custom_id,
"min_values": self.min_values,
"max_values": self.max_values,
"options": [op.to_dict() for op in self.options],
"disabled": self.disabled,
}
if self.placeholder:
payload['placeholder'] = self.placeholder
payload["placeholder"] = self.placeholder
return payload
@ -292,11 +307,11 @@ class SelectOption:
"""
__slots__: Tuple[str, ...] = (
'label',
'value',
'description',
'emoji',
'default',
"label",
"value",
"description",
"emoji",
"default",
)
def __init__(
@ -318,60 +333,62 @@ class SelectOption:
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
raise TypeError(
f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}"
)
self.emoji = emoji
self.default = default
def __repr__(self) -> str:
return (
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} '
f'emoji={self.emoji!r} default={self.default!r}>'
f"<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} "
f"emoji={self.emoji!r} default={self.default!r}>"
)
def __str__(self) -> str:
if self.emoji:
base = f'{self.emoji} {self.label}'
base = f"{self.emoji} {self.label}"
else:
base = self.label
if self.description:
return f'{base}\n{self.description}'
return f"{base}\n{self.description}"
return base
@classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try:
emoji = PartialEmoji.from_dict(data['emoji'])
emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError:
emoji = None
return cls(
label=data['label'],
value=data['value'],
description=data.get('description'),
label=data["label"],
value=data["value"],
description=data.get("description"),
emoji=emoji,
default=data.get('default', False),
default=data.get("default", False),
)
def to_dict(self) -> SelectOptionPayload:
payload: SelectOptionPayload = {
'label': self.label,
'value': self.value,
'default': self.default,
"label": self.label,
"value": self.value,
"default": self.default,
}
if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore
payload["emoji"] = self.emoji.to_dict() # type: ignore
if self.description:
payload['description'] = self.description
payload["description"] = self.description
return payload
def _component_factory(data: ComponentPayload) -> Component:
component_type = data['type']
component_type = data["type"]
if component_type == 1:
return ActionRow(data)
elif component_type == 2:

View File

@ -32,11 +32,10 @@ if TYPE_CHECKING:
from types import TracebackType
TypingT = TypeVar('TypingT', bound='Typing')
TypingT = TypeVar("TypingT", bound="Typing")
__all__ = ("Typing",)
__all__ = (
'Typing',
)
def _typing_done_callback(fut: asyncio.Future) -> None:
# just retrieve any exception and call it a day
@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
except (asyncio.CancelledError, Exception):
pass
class Typing:
def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
@ -67,7 +67,8 @@ class Typing:
self.task.add_done_callback(_typing_done_callback)
return self
def __exit__(self,
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
@ -79,7 +80,8 @@ class Typing:
await channel._state.http.send_typing(channel.id)
return self.__enter__()
async def __aexit__(self,
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],

View File

@ -25,14 +25,23 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from typing import (
Any,
Dict,
Final,
List,
Mapping,
Protocol,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
from . import utils
from .colour import Colour
__all__ = (
'Embed',
)
__all__ = ("Embed",)
class _EmptyEmbed:
@ -40,7 +49,7 @@ class _EmptyEmbed:
return False
def __repr__(self) -> str:
return 'Embed.Empty'
return "Embed.Empty"
def __len__(self) -> int:
return 0
@ -57,51 +66,47 @@ class EmbedProxy:
return len(self.__dict__)
def __repr__(self) -> str:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
return f'EmbedProxy({inner})'
inner = ", ".join(
(f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_"))
)
return f"EmbedProxy({inner})"
def __getattr__(self, attr: str) -> _EmptyEmbed:
return EmptyEmbed
E = TypeVar('E', bound='Embed')
E = TypeVar("E", bound="Embed")
if TYPE_CHECKING:
from discord.types.embed import Embed as EmbedData, EmbedType
T = TypeVar('T')
T = TypeVar("T")
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str]
value: MaybeEmpty[str]
inline: bool
class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
@ -163,33 +168,33 @@ class Embed:
"""
__slots__ = (
'title',
'url',
'type',
'_timestamp',
'_colour',
'_footer',
'_image',
'_thumbnail',
'_video',
'_provider',
'_author',
'_fields',
'description',
"title",
"url",
"type",
"_timestamp",
"_colour",
"_footer",
"_image",
"_thumbnail",
"_video",
"_provider",
"_author",
"_fields",
"description",
)
Empty: Final = EmptyEmbed
def __init__(
self,
*,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
self,
*,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = "rich",
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
):
self.colour = colour if colour is not EmptyEmbed else color
@ -231,10 +236,10 @@ class Embed:
# fill in the basic fields
self.title = data.get('title', EmptyEmbed)
self.type = data.get('type', EmptyEmbed)
self.description = data.get('description', EmptyEmbed)
self.url = data.get('url', EmptyEmbed)
self.title = data.get("title", EmptyEmbed)
self.type = data.get("type", EmptyEmbed)
self.description = data.get("description", EmptyEmbed)
self.url = data.get("url", EmptyEmbed)
if self.title is not EmptyEmbed:
self.title = str(self.title)
@ -248,22 +253,30 @@ class Embed:
# try to fill in the more rich fields
try:
self._colour = Colour(value=data['color'])
self._colour = Colour(value=data["color"])
except KeyError:
pass
try:
self._timestamp = utils.parse_time(data['timestamp'])
self._timestamp = utils.parse_time(data["timestamp"])
except KeyError:
pass
for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'):
for attr in (
"thumbnail",
"video",
"provider",
"author",
"fields",
"image",
"footer",
):
try:
value = data[attr]
except KeyError:
continue
else:
setattr(self, '_' + attr, value)
setattr(self, "_" + attr, value)
return self
@ -273,11 +286,11 @@ class Embed:
def __len__(self) -> int:
total = len(self.title) + len(self.description)
for field in getattr(self, '_fields', []):
total += len(field['name']) + len(field['value'])
for field in getattr(self, "_fields", []):
total += len(field["name"]) + len(field["value"])
try:
footer_text = self._footer['text']
footer_text = self._footer["text"]
except (AttributeError, KeyError):
pass
else:
@ -288,7 +301,7 @@ class Embed:
except AttributeError:
pass
else:
total += len(author['name'])
total += len(author["name"])
return total
@ -312,7 +325,7 @@ class Embed:
@property
def colour(self) -> MaybeEmpty[Colour]:
return getattr(self, '_colour', EmptyEmbed)
return getattr(self, "_colour", EmptyEmbed)
@colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore
@ -321,13 +334,15 @@ class Embed:
elif isinstance(value, int):
self._colour = Colour(value=value)
else:
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.')
raise TypeError(
f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead."
)
color = colour
@property
def timestamp(self) -> MaybeEmpty[datetime.datetime]:
return getattr(self, '_timestamp', EmptyEmbed)
return getattr(self, "_timestamp", EmptyEmbed)
@timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
@ -338,7 +353,9 @@ class Embed:
elif isinstance(value, _EmptyEmbed):
self._timestamp = value
else:
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead")
raise TypeError(
f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead"
)
@property
def footer(self) -> _EmbedFooterProxy:
@ -348,9 +365,14 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
return EmbedProxy(getattr(self, "_footer", {})) # type: ignore
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
def set_footer(
self: E,
*,
text: MaybeEmpty[Any] = EmptyEmbed,
icon_url: MaybeEmpty[Any] = EmptyEmbed,
) -> E:
"""Sets the footer for the embed content.
This function returns the class instance to allow for fluent-style
@ -366,10 +388,10 @@ class Embed:
self._footer = {}
if text is not EmptyEmbed:
self._footer['text'] = str(text)
self._footer["text"] = str(text)
if icon_url is not EmptyEmbed:
self._footer['icon_url'] = str(icon_url)
self._footer["icon_url"] = str(icon_url)
return self
@ -401,12 +423,12 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
return EmbedProxy(getattr(self, "_image", {})) # type: ignore
@image.setter
def image(self: E, *, url: Any):
self._image = {
'url': str(url),
"url": str(url),
}
@image.deleter
@ -451,15 +473,14 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore
@thumbnail.setter
def thumbnail(self: E, *, url: Any):
"""Sets the thumbnail for the embed content.
"""
"""Sets the thumbnail for the embed content."""
self._thumbnail = {
'url': str(url),
"url": str(url),
}
return
@ -504,7 +525,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
return EmbedProxy(getattr(self, "_video", {})) # type: ignore
@property
def provider(self) -> _EmbedProviderProxy:
@ -514,7 +535,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
return EmbedProxy(getattr(self, "_provider", {})) # type: ignore
@property
def author(self) -> _EmbedAuthorProxy:
@ -524,9 +545,15 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_author', {})) # type: ignore
return EmbedProxy(getattr(self, "_author", {})) # type: ignore
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
def set_author(
self: E,
*,
name: Any,
url: MaybeEmpty[Any] = EmptyEmbed,
icon_url: MaybeEmpty[Any] = EmptyEmbed,
) -> E:
"""Sets the author for the embed content.
This function returns the class instance to allow for fluent-style
@ -543,14 +570,14 @@ class Embed:
"""
self._author = {
'name': str(name),
"name": str(name),
}
if url is not EmptyEmbed:
self._author['url'] = str(url)
self._author["url"] = str(url)
if icon_url is not EmptyEmbed:
self._author['icon_url'] = str(icon_url)
self._author["icon_url"] = str(icon_url)
return self
@ -577,7 +604,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore
return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E:
"""Adds a field to the embed object.
@ -596,9 +623,9 @@ class Embed:
"""
field = {
'inline': inline,
'name': str(name),
'value': str(value),
"inline": inline,
"name": str(name),
"value": str(value),
}
try:
@ -608,7 +635,9 @@ class Embed:
return self
def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E:
def insert_field_at(
self: E, index: int, *, name: Any, value: Any, inline: bool = True
) -> E:
"""Inserts a field before a specified index to the embed.
This function returns the class instance to allow for fluent-style
@ -629,9 +658,9 @@ class Embed:
"""
field = {
'inline': inline,
'name': str(name),
'value': str(value),
"inline": inline,
"name": str(name),
"value": str(value),
}
try:
@ -669,7 +698,9 @@ class Embed:
except (AttributeError, IndexError):
pass
def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E:
def set_field_at(
self: E, index: int, *, name: Any, value: Any, inline: bool = True
) -> E:
"""Modifies a field to the embed object.
The index must point to a valid pre-existing field.
@ -697,11 +728,11 @@ class Embed:
try:
field = self._fields[index]
except (TypeError, IndexError, AttributeError):
raise IndexError('field index out of range')
raise IndexError("field index out of range")
field['name'] = str(name)
field['value'] = str(value)
field['inline'] = inline
field["name"] = str(name)
field["value"] = str(value)
field["inline"] = inline
return self
def to_dict(self) -> EmbedData:
@ -719,35 +750,39 @@ class Embed:
# deal with basic convenience wrappers
try:
colour = result.pop('colour')
colour = result.pop("colour")
except KeyError:
pass
else:
if colour:
result['color'] = colour.value
result["color"] = colour.value
try:
timestamp = result.pop('timestamp')
timestamp = result.pop("timestamp")
except KeyError:
pass
else:
if timestamp:
if timestamp.tzinfo:
result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
result["timestamp"] = timestamp.astimezone(
tz=datetime.timezone.utc
).isoformat()
else:
result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
result["timestamp"] = timestamp.replace(
tzinfo=datetime.timezone.utc
).isoformat()
# add in the non raw attribute ones
if self.type:
result['type'] = self.type
result["type"] = self.type
if self.description:
result['description'] = self.description
result["description"] = self.description
if self.url:
result['url'] = self.url
result["url"] = self.url
if self.title:
result['title'] = self.title
result["title"] = self.title
return result # type: ignore

View File

@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User
__all__ = (
'Emoji',
)
__all__ = ("Emoji",)
if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload
@ -98,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin):
"""
__slots__: Tuple[str, ...] = (
'require_colons',
'animated',
'managed',
'id',
'name',
'_roles',
'guild_id',
'_state',
'user',
'available',
"require_colons",
"animated",
"managed",
"id",
"name",
"_roles",
"guild_id",
"_state",
"user",
"available",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
@ -116,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin):
self._from_data(data)
def _from_data(self, emoji: EmojiPayload):
self.require_colons: bool = emoji.get('require_colons', False)
self.managed: bool = emoji.get('managed', False)
self.id: int = int(emoji['id']) # type: ignore
self.name: str = emoji['name'] # type: ignore
self.animated: bool = emoji.get('animated', False)
self.available: bool = emoji.get('available', True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', [])))
user = emoji.get('user')
self.require_colons: bool = emoji.get("require_colons", False)
self.managed: bool = emoji.get("managed", False)
self.id: int = int(emoji["id"]) # type: ignore
self.name: str = emoji["name"] # type: ignore
self.animated: bool = emoji.get("animated", False)
self.available: bool = emoji.get("available", True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", [])))
user = emoji.get("user")
self.user: Optional[User] = User(state=self._state, data=user) if user else None
def _to_partial(self) -> PartialEmoji:
@ -131,21 +129,21 @@ class Emoji(_EmojiTag, AssetMixin):
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for attr in self.__slots__:
if attr[0] != '_':
if attr[0] != "_":
value = getattr(self, attr, None)
if value is not None:
yield (attr, value)
def __str__(self) -> str:
if self.animated:
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
return f"<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>"
def __eq__(self, other: Any) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id
@ -164,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin):
@property
def url(self) -> str:
""":class:`str`: Returns the URL of the emoji."""
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
@property
def roles(self) -> List[Role]:
@ -217,9 +215,17 @@ class Emoji(_EmojiTag, AssetMixin):
An error occurred deleting the emoji.
"""
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
await self._state.http.delete_custom_emoji(
self.guild.id, self.id, reason=reason
)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
async def edit(
self,
*,
name: str = MISSING,
roles: List[Snowflake] = MISSING,
reason: Optional[str] = None,
) -> Emoji:
r"""|coro|
Edits the custom emoji.
@ -254,9 +260,11 @@ class Emoji(_EmojiTag, AssetMixin):
payload = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if roles is not MISSING:
payload['roles'] = [role.id for role in roles]
payload["roles"] = [role.id for role in roles]
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
data = await self._state.http.edit_custom_emoji(
self.guild.id, self.id, payload=payload, reason=reason
)
return Emoji(guild=self.guild, data=data, state=self._state)

View File

@ -27,50 +27,65 @@ from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
__all__ = (
'Enum',
'ChannelType',
'MessageType',
'VoiceRegion',
'SpeakingState',
'VerificationLevel',
'ContentFilter',
'Status',
'DefaultAvatar',
'AuditLogAction',
'AuditLogActionCategory',
'UserFlags',
'ActivityType',
'NotificationLevel',
'TeamMembershipState',
'WebhookType',
'ExpireBehaviour',
'ExpireBehavior',
'StickerType',
'StickerFormatType',
'InviteTarget',
'VideoQualityMode',
'ComponentType',
'ButtonStyle',
'StagePrivacyLevel',
'InteractionType',
'InteractionResponseType',
'NSFWLevel',
"Enum",
"ChannelType",
"MessageType",
"VoiceRegion",
"SpeakingState",
"VerificationLevel",
"ContentFilter",
"Status",
"DefaultAvatar",
"AuditLogAction",
"AuditLogActionCategory",
"UserFlags",
"ActivityType",
"NotificationLevel",
"TeamMembershipState",
"WebhookType",
"ExpireBehaviour",
"ExpireBehavior",
"StickerType",
"StickerFormatType",
"InviteTarget",
"VideoQualityMode",
"ComponentType",
"ButtonStyle",
"StagePrivacyLevel",
"InteractionType",
"InteractionResponseType",
"NSFWLevel",
)
def _create_value_cls(name, comparable):
cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}'
cls = namedtuple("_EnumValue_" + name, "name value")
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
cls.__le__ = (
lambda self, other: isinstance(other, self.__class__)
and self.value <= other.value
)
cls.__ge__ = (
lambda self, other: isinstance(other, self.__class__)
and self.value >= other.value
)
cls.__lt__ = (
lambda self, other: isinstance(other, self.__class__)
and self.value < other.value
)
cls.__gt__ = (
lambda self, other: isinstance(other, self.__class__)
and self.value > other.value
)
return cls
def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)
class EnumMeta(type):
@ -88,7 +103,7 @@ class EnumMeta(type):
value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == '_' and not is_descriptor:
if key[0] == "_" and not is_descriptor:
continue
# Special case classmethod to just pass through
@ -110,10 +125,10 @@ class EnumMeta(type):
member_mapping[key] = new_value
attrs[key] = new_value
attrs['_enum_value_map_'] = value_mapping
attrs['_enum_member_map_'] = member_mapping
attrs['_enum_member_names_'] = member_names
attrs['_enum_value_cls_'] = value_cls
attrs["_enum_value_map_"] = value_mapping
attrs["_enum_member_map_"] = member_mapping
attrs["_enum_member_names_"] = member_names
attrs["_enum_value_cls_"] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore
return actual_cls
@ -122,13 +137,15 @@ class EnumMeta(type):
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
def __reversed__(cls):
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
return (
cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)
)
def __len__(cls):
return len(cls._enum_member_names_)
def __repr__(cls):
return f'<enum {cls.__name__}>'
return f"<enum {cls.__name__}>"
@property
def __members__(cls):
@ -144,10 +161,10 @@ class EnumMeta(type):
return cls._enum_member_map_[key]
def __setattr__(cls, name, value):
raise TypeError('Enums are immutable.')
raise TypeError("Enums are immutable.")
def __delattr__(cls, attr):
raise TypeError('Enums are immutable')
raise TypeError("Enums are immutable")
def __instancecheck__(self, instance):
# isinstance(x, Y)
@ -215,29 +232,29 @@ class MessageType(Enum):
class VoiceRegion(Enum):
us_west = 'us-west'
us_east = 'us-east'
us_south = 'us-south'
us_central = 'us-central'
eu_west = 'eu-west'
eu_central = 'eu-central'
singapore = 'singapore'
london = 'london'
sydney = 'sydney'
amsterdam = 'amsterdam'
frankfurt = 'frankfurt'
brazil = 'brazil'
hongkong = 'hongkong'
russia = 'russia'
japan = 'japan'
southafrica = 'southafrica'
south_korea = 'south-korea'
india = 'india'
europe = 'europe'
dubai = 'dubai'
vip_us_east = 'vip-us-east'
vip_us_west = 'vip-us-west'
vip_amsterdam = 'vip-amsterdam'
us_west = "us-west"
us_east = "us-east"
us_south = "us-south"
us_central = "us-central"
eu_west = "eu-west"
eu_central = "eu-central"
singapore = "singapore"
london = "london"
sydney = "sydney"
amsterdam = "amsterdam"
frankfurt = "frankfurt"
brazil = "brazil"
hongkong = "hongkong"
russia = "russia"
japan = "japan"
southafrica = "southafrica"
south_korea = "south-korea"
india = "india"
europe = "europe"
dubai = "dubai"
vip_us_east = "vip-us-east"
vip_us_west = "vip-us-west"
vip_amsterdam = "vip-amsterdam"
def __str__(self):
return self.value
@ -277,12 +294,12 @@ class ContentFilter(Enum, comparable=True):
class Status(Enum):
online = 'online'
offline = 'offline'
idle = 'idle'
dnd = 'dnd'
do_not_disturb = 'dnd'
invisible = 'invisible'
online = "online"
offline = "offline"
idle = "idle"
dnd = "dnd"
do_not_disturb = "dnd"
invisible = "invisible"
def __str__(self):
return self.value
@ -415,33 +432,33 @@ class AuditLogAction(Enum):
def target_type(self) -> Optional[str]:
v = self.value
if v == -1:
return 'all'
return "all"
elif v < 10:
return 'guild'
return "guild"
elif v < 20:
return 'channel'
return "channel"
elif v < 30:
return 'user'
return "user"
elif v < 40:
return 'role'
return "role"
elif v < 50:
return 'invite'
return "invite"
elif v < 60:
return 'webhook'
return "webhook"
elif v < 70:
return 'emoji'
return "emoji"
elif v == 73:
return 'channel'
return "channel"
elif v < 80:
return 'message'
return "message"
elif v < 83:
return 'integration'
return "integration"
elif v < 90:
return 'stage_instance'
return "stage_instance"
elif v < 93:
return 'sticker'
return "sticker"
elif v < 113:
return 'thread'
return "thread"
class UserFlags(Enum):
@ -589,12 +606,12 @@ class NSFWLevel(Enum, comparable=True):
age_restricted = 3
T = TypeVar('T')
T = TypeVar("T")
def create_unknown_value(cls: Type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore
name = f'unknown_{val}'
name = f"unknown_{val}"
return value_cls(name=name, value=val)

View File

@ -38,20 +38,20 @@ if TYPE_CHECKING:
from .interactions import Interaction
__all__ = (
'DiscordException',
'ClientException',
'NoMoreItems',
'GatewayNotFound',
'HTTPException',
'Forbidden',
'NotFound',
'DiscordServerError',
'InvalidData',
'InvalidArgument',
'LoginFailure',
'ConnectionClosed',
'PrivilegedIntentsRequired',
'InteractionResponded',
"DiscordException",
"ClientException",
"NoMoreItems",
"GatewayNotFound",
"HTTPException",
"Forbidden",
"NotFound",
"DiscordServerError",
"InvalidData",
"InvalidArgument",
"LoginFailure",
"ConnectionClosed",
"PrivilegedIntentsRequired",
"InteractionResponded",
)
@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException):
"""An exception that is raised when the gateway for Discord could not be found"""
def __init__(self):
message = 'The gateway to connect to discord was not found.'
message = "The gateway to connect to discord was not found."
super().__init__(message)
def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]:
def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]:
items: List[Tuple[str, str]] = []
for k, v in d.items():
new_key = key + '.' + k if key else k
new_key = key + "." + k if key else k
if isinstance(v, dict):
try:
_errors: List[Dict[str, Any]] = v['_errors']
_errors: List[Dict[str, Any]] = v["_errors"]
except KeyError:
items.extend(_flatten_error_dict(v, new_key).items())
else:
items.append((new_key, ' '.join(x.get('message', '') for x in _errors)))
items.append((new_key, " ".join(x.get("message", "") for x in _errors)))
else:
items.append((new_key, v))
@ -123,28 +123,30 @@ class HTTPException(DiscordException):
The Discord specific error code for the failure.
"""
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
def __init__(
self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]
):
self.response: _ResponseType = response
self.status: int = response.status # type: ignore
self.code: int
self.text: str
if isinstance(message, dict):
self.code = message.get('code', 0)
base = message.get('message', '')
errors = message.get('errors')
self.code = message.get("code", 0)
base = message.get("message", "")
errors = message.get("errors")
if errors:
errors = _flatten_error_dict(errors)
helpful = '\n'.join('In %s: %s' % t for t in errors.items())
self.text = base + '\n' + helpful
helpful = "\n".join("In %s: %s" % t for t in errors.items())
self.text = base + "\n" + helpful
else:
self.text = base
else:
self.text = message or ''
self.text = message or ""
self.code = 0
fmt = '{0.status} {0.reason} (error code: {1})'
fmt = "{0.status} {0.reason} (error code: {1})"
if len(self.text):
fmt += ': {2}'
fmt += ": {2}"
super().__init__(fmt.format(self.response, self.code, self.text))
@ -221,14 +223,20 @@ class ConnectionClosed(ClientException):
The shard ID that got closed if applicable.
"""
def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None):
def __init__(
self,
socket: ClientWebSocketResponse,
*,
shard_id: Optional[int],
code: Optional[int] = None,
):
# This exception is just the same exception except
# reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1
# aiohttp doesn't seem to consistently provide close reason
self.reason: str = ''
self.reason: str = ""
self.shard_id: Optional[int] = shard_id
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}')
super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}")
class PrivilegedIntentsRequired(ClientException):
@ -250,10 +258,10 @@ class PrivilegedIntentsRequired(ClientException):
def __init__(self, shard_id: Optional[int]):
self.shard_id: Optional[int] = shard_id
msg = (
'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the '
'developer portal. It is recommended to go to https://discord.com/developers/applications/ '
'and explicitly enable the privileged intents within your application\'s page. If this is not '
'possible, then consider disabling the privileged intents instead.'
"Shard ID %s is requesting privileged intents that have not been explicitly enabled in the "
"developer portal. It is recommended to go to https://discord.com/developers/applications/ "
"and explicitly enable the privileged intents within your application's page. If this is not "
"possible, then consider disabling the privileged intents instead."
)
super().__init__(msg % shard_id)
@ -274,4 +282,4 @@ class InteractionResponded(ClientException):
def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction
super().__init__('This interaction has already been responded to before')
super().__init__("This interaction has already been responded to before")

View File

@ -31,15 +31,23 @@ if TYPE_CHECKING:
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
T = TypeVar("T")
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
Check = Union[
Callable[["Cog", "Context[Any]"], MaybeCoro[bool]],
Callable[["Context[Any]"], MaybeCoro[bool]],
]
Hook = Union[
Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]
]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]],
Callable[["Context[Any]", "CommandError"], Coro[Any]],
]
# This is merely a tag type to avoid circular import issues.

View File

@ -33,7 +33,18 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord
@ -54,17 +65,18 @@ if TYPE_CHECKING:
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
'Bot',
'AutoShardedBot',
"when_mentioned",
"when_mentioned_or",
"Bot",
"AutoShardedBot",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@ -72,9 +84,12 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
def when_mentioned_or(
*prefixes: str,
) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -103,6 +118,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
----------
:func:`.when_mentioned`
"""
def inner(bot, msg):
r = list(prefixes)
r = when_mentioned(bot, msg) + r
@ -110,17 +126,29 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
def __repr__(self):
return '<default-help-command>'
return "<default-help-command>"
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
def __init__(
self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
**options,
):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
@ -131,16 +159,20 @@ class BotBase(GroupMixin):
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
if self.owner_ids and not isinstance(
self.owner_ids, collections.abc.Collection
):
raise TypeError(
f"owner_ids must be a collection not {self.owner_ids.__class__!r}"
)
if help_command is _default:
self.help_command = DefaultHelpCommand()
@ -152,7 +184,7 @@ class BotBase(GroupMixin):
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
ev = "on_" + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
@ -172,7 +204,9 @@ class BotBase(GroupMixin):
await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(
self, context: Context, exception: errors.CommandError
) -> None:
"""|coro|
The default command error handler provided by the bot.
@ -182,7 +216,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get('on_command_error', None):
if self.extra_events.get("on_command_error", None):
return
command = context.command
@ -193,8 +227,10 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler():
return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
# global check registration
@ -380,7 +416,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@ -413,7 +449,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@ -445,7 +481,7 @@ class BotBase(GroupMixin):
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
raise TypeError("Listeners must be coroutines")
if name in self.extra_events:
self.extra_events[name].append(func)
@ -541,14 +577,14 @@ class BotBase(GroupMixin):
"""
if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog')
raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name)
cog = cog._inject(self)
@ -628,7 +664,9 @@ class BotBase(GroupMixin):
for event_list in self.extra_events.copy().values():
remove = []
for index, event in enumerate(event_list):
if event.__module__ is not None and _is_submodule(name, event.__module__):
if event.__module__ is not None and _is_submodule(
name, event.__module__
):
remove.append(index)
for index in reversed(remove):
@ -636,7 +674,7 @@ class BotBase(GroupMixin):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
func = getattr(lib, "teardown")
except AttributeError:
pass
else:
@ -652,7 +690,9 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
def _load_from_module_spec(
self, spec: importlib.machinery.ModuleSpec, key: str
) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
@ -663,7 +703,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, 'setup')
setup = getattr(lib, "setup")
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
@ -850,7 +890,7 @@ class BotBase(GroupMixin):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
raise TypeError("help_command must be a subclass of HelpCommand")
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
@ -893,11 +933,15 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}")
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
raise ValueError(
"Iterable command_prefix must contain at least one prefix"
)
return ret
@ -954,14 +998,18 @@ class BotBase(GroupMixin):
except TypeError:
if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}")
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}")
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen
raise
@ -988,19 +1036,19 @@ class BotBase(GroupMixin):
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch('command', ctx)
self.dispatch("command", ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure('The global check once functions failed.')
raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch('command_completion', ctx)
self.dispatch("command_completion", ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None:
"""|coro|
@ -1033,6 +1081,7 @@ class BotBase(GroupMixin):
async def on_message(self, message):
await self.process_commands(message)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@ -1103,10 +1152,13 @@ class Bot(BotBase, discord.Client):
.. versionadded:: 1.7
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.
"""
pass

View File

@ -26,7 +26,19 @@ from __future__ import annotations
import inspect
import discord.utils
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generator,
List,
Optional,
TYPE_CHECKING,
Tuple,
TypeVar,
Type,
)
from ._types import _BaseCommand
@ -36,15 +48,16 @@ if TYPE_CHECKING:
from .core import Command
__all__ = (
'CogMeta',
'Cog',
"CogMeta",
"Cog",
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@ -104,6 +117,7 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
@ -111,17 +125,17 @@ class CogMeta(type):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
description = kwargs.pop('description', None)
description = kwargs.pop("description", None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
@ -136,21 +150,25 @@ class CogMeta(type):
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
if elem.startswith(('cog_', 'bot_')):
raise TypeError(
f"Command in method {base}.{elem!r} must not be staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
getattr(value, "__cog_listener__")
except AttributeError:
continue
else:
if elem.startswith(('cog_', 'bot_')):
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_commands__ = list(
commands.values()
) # this will be copied in Cog.__new__
listeners_as_list = []
for listener in listeners.values():
@ -169,10 +187,12 @@ class CogMeta(type):
def qualified_name(cls) -> str:
return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
@ -183,6 +203,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
@ -199,10 +220,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
@ -255,6 +273,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
@ -269,12 +288,15 @@ class Cog(metaclass=CogMeta):
List[Tuple[:class:`str`, :ref:`coroutine <coroutine>`]]
The listeners defined in this cog.
"""
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
return [
(name, getattr(self, method_name))
for name, method_name in self.__cog_listeners__
]
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
return getattr(method.__func__, "__cog_special_method__", method)
@classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
@ -296,14 +318,16 @@ class Cog(metaclass=CogMeta):
"""
if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
raise TypeError(
f"Cog.listener expected str but received {name.__class__.__name__!r} instead."
)
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
@ -315,6 +339,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
@ -322,7 +347,7 @@ class Cog(metaclass=CogMeta):
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
@_cog_special_method
def cog_unload(self) -> None:

View File

@ -49,21 +49,19 @@ if TYPE_CHECKING:
from .help import HelpCommand
from .view import StringView
__all__ = (
'Context',
)
__all__ = ("Context",)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]):
@ -122,7 +120,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
or invoked.
"""
def __init__(self,
def __init__(
self,
*,
message: Message,
bot: BotT,
@ -153,7 +152,9 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
async def invoke(
self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs
) -> T:
r"""|coro|
Calls a command with the arguments given.
@ -219,7 +220,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command
view = self.view
if cmd is None:
raise ValueError('This context is not valid.')
raise ValueError("This context is not valid.")
# some state to revert to when we're done
index, previous = view.index, view.previous
@ -230,10 +231,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '')
view.index = len(self.prefix or "")
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
@ -263,7 +264,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
if self.prefix is None:
return ''
return ""
user = self.me
# this breaks if the prefix mention is not the bot itself but I
@ -271,7 +272,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
@property
def cog(self) -> Optional[Cog]:
@ -381,7 +382,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, '__cog_commands__'):
if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):

View File

@ -52,32 +52,32 @@ if TYPE_CHECKING:
__all__ = (
'Converter',
'ObjectConverter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'clean_content',
'Greedy',
'run_converters',
"Converter",
"ObjectConverter",
"MemberConverter",
"UserConverter",
"MessageConverter",
"PartialMessageConverter",
"TextChannelConverter",
"InviteConverter",
"GuildConverter",
"RoleConverter",
"GameConverter",
"ColourConverter",
"ColorConverter",
"VoiceChannelConverter",
"StageChannelConverter",
"EmojiConverter",
"PartialEmojiConverter",
"CategoryChannelConverter",
"IDConverter",
"StoreChannelConverter",
"ThreadConverter",
"GuildChannelConverter",
"GuildStickerConverter",
"clean_content",
"Greedy",
"run_converters",
)
@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar("TT", bound=discord.Thread)
@runtime_checkable
@ -132,10 +132,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError('Derived classes need to implement this.')
raise NotImplementedError("Derived classes need to implement this.")
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
class IDConverter(Converter[T_co]):
@ -158,7 +158,9 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument
)
if match is None:
raise ObjectNotFound(argument)
@ -192,13 +194,17 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
return discord.utils.get(
members, name=username, discriminator=discriminator
)
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
return discord.utils.find(
lambda m: m.name == argument or m.nick == argument, members
)
async def query_member_by_id(self, bot, guild, user_id):
ws = bot._get_websocket(shard_id=guild.shard_id)
@ -223,7 +229,9 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
guild = ctx.guild
result = None
user_id = None
@ -232,13 +240,15 @@ class MemberConverter(IDConverter[discord.Member]):
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, 'get_member_named', argument)
result = _get_from_guilds(bot, "get_member_named", argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = guild.get_member(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
else:
result = _get_from_guilds(bot, 'get_member', user_id)
result = _get_from_guilds(bot, "get_member", user_id)
if result is None:
if guild is None:
@ -276,13 +286,17 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@!?([0-9]{15,20})>$", argument
)
result = None
state = ctx._state
if match is not None:
user_id = int(match.group(1))
result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id)
result = ctx.bot.get_user(user_id) or _utils_get(
ctx.message.mentions, id=user_id
)
if result is None:
try:
result = await ctx.bot.fetch_user(user_id)
@ -294,12 +308,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
if arg[0] == "@":
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
if len(arg) > 5 and arg[-5] == "#":
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
@ -330,29 +344,33 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
id_regex = re.compile(
r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$"
)
link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?P<guild_id>[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data["message_id"])
guild_id = data.get("guild_id")
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me':
elif guild_id == "@me":
guild_id = None
else:
guild_id = int(guild_id)
return guild_id, message_id, channel_id
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
def _resolve_channel(
ctx, guild_id, channel_id
) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
@ -386,7 +404,9 @@ class MessageConverter(IDConverter[discord.Message]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(
ctx, argument
)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
@ -417,13 +437,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
return self._resolve_channel(
ctx, argument, "channels", discord.abc.GuildChannel
)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(
ctx: Context, argument: str, attribute: str, type: Type[CT]
) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@ -443,7 +469,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
result = _get_from_guilds(bot, "get_channel", channel_id)
if not isinstance(result, type):
raise ChannelNotFound(argument)
@ -451,10 +477,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(
ctx: Context, argument: str, attribute: str, type: Type[TT]
) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(
r"<#([0-9]{15,20})>$", argument
)
result = None
guild = ctx.guild
@ -491,7 +521,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "text_channels", discord.TextChannel
)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@ -511,7 +543,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "voice_channels", discord.VoiceChannel
)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@ -530,7 +564,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "stage_channels", discord.StageChannel
)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@ -550,7 +586,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "categories", discord.CategoryChannel
)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@ -569,7 +607,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
return GuildChannelConverter._resolve_channel(
ctx, argument, "channels", discord.StoreChannel
)
class ThreadConverter(IDConverter[discord.Thread]):
@ -587,7 +627,9 @@ class ThreadConverter(IDConverter[discord.Thread]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
return GuildChannelConverter._resolve_thread(
ctx, argument, "threads", discord.Thread
)
class ColourConverter(Converter[discord.Colour]):
@ -616,10 +658,12 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
RGB_REGEX = re.compile(
r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)"
)
def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
@ -630,7 +674,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == '%':
if number[-1] == "%":
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
@ -646,29 +690,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group("b"))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == '#':
if argument[0] == "#":
return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x':
if argument[0:2] == "0x":
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith('#'):
if rest.startswith("#"):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == 'rgb':
if arg[0:3] == "rgb":
return self.parse_rgb(arg)
arg = arg.replace(' ', '_')
arg = arg.replace(" ", "_")
method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
@ -697,7 +741,9 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<@&([0-9]{15,20})>$", argument
)
if match:
result = guild.get_role(int(match.group(1)))
else:
@ -776,7 +822,9 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(
r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument
)
result = None
bot = ctx.bot
guild = ctx.guild
@ -810,7 +858,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
if match:
emoji_animated = bool(match.group(1))
@ -818,7 +866,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
ctx.bot._connection,
animated=emoji_animated,
name=emoji_name,
id=emoji_id,
)
raise PartialEmojiConversionFailure(argument)
@ -903,37 +954,41 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
return (
f"@{m.display_name if self.use_nicknames else m.name}"
if m
else "@deleted-user"
)
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role'
return f"@{r.name}" if r else "@deleted-role"
else:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user'
return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
return '@deleted-role'
return "@deleted-role"
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel'
return f"#{c.name}" if c else "#deleted-channel"
else:
def resolve_channel(id: int) -> str:
return f'<#{id}>'
return f"<#{id}>"
transforms = {
'@': resolve_member,
'@!': resolve_member,
'#': resolve_channel,
'@&': resolve_role,
"@": resolve_member,
"@!": resolve_member,
"#": resolve_channel,
"@&": resolve_role,
}
def repl(match: re.Match) -> str:
@ -942,7 +997,7 @@ class clean_content(Converter[str]):
transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
@ -974,42 +1029,46 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',)
__slots__ = ("converter",)
def __init__(self, *, converter: T):
self.converter = converter
def __repr__(self):
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
converter = getattr(self.converter, "__name__", repr(self.converter))
return f"Greedy[{converter}]"
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
raise TypeError("Greedy[...] only takes a single argument")
converter = params[0]
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
if not (
callable(converter)
or isinstance(converter, Converter)
or origin is not None
):
raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
raise TypeError(f"Greedy[{converter!r}] is invalid.")
return cls(converter=converter)
def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
return False
else:
raise BadBoolArgument(lowered)
@ -1056,7 +1115,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
if converter is bool:
return _convert_to_bool(argument)
@ -1065,7 +1126,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
if module is not None and (
module.startswith("discord.") and not module.endswith("converter")
):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@ -1091,10 +1154,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
name = converter.__class__.__name__
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
raise BadArgument(
f'Converting to "{name}" failed for parameter "{param.name}".'
) from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(
ctx: Context, converter, argument: str, param: inspect.Parameter
):
"""|coro|
Runs converters for a given converter, argument, and parameter.
@ -1124,7 +1191,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any
The resulting conversion.
"""
origin = getattr(converter, '__origin__', None)
origin = getattr(converter, "__origin__", None)
if origin is Union:
errors = []

View File

@ -38,24 +38,25 @@ if TYPE_CHECKING:
from ...message import Message
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum):
default = 0
user = 1
guild = 2
channel = 3
member = 4
default = 0
user = 1
guild = 2
channel = 3
member = 4
category = 5
role = 6
role = 6
def get_key(self, msg: Message) -> Any:
if self is BucketType.user:
@ -90,7 +91,7 @@ class Cooldown:
The length of the cooldown period in seconds.
"""
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
@ -190,7 +191,8 @@ class Cooldown:
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping:
def __init__(
@ -199,7 +201,7 @@ class CooldownMapping:
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
@ -252,16 +254,16 @@ class CooldownMapping:
return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
def update_rate_limit(
self, message: Message, current: Optional[float] = None
) -> Optional[float]:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.
@ -291,7 +294,7 @@ class _Semaphore:
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
__slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None:
self.value: int = number
@ -299,7 +302,7 @@ class _Semaphore:
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool:
return self.value == 0
@ -337,8 +340,9 @@ class _Semaphore:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
__slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
@ -347,16 +351,20 @@ class MaxConcurrency:
self.wait: bool = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
raise TypeError(
f"max_concurrency 'per' must be of type BucketType not {type(per)!r}"
)
def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
return (
f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
)
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@ -41,65 +41,66 @@ if TYPE_CHECKING:
__all__ = (
'CommandError',
'MissingRequiredArgument',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
'CheckFailure',
'CheckAnyFailure',
'CommandNotFound',
'DisabledCommand',
'CommandInvokeError',
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner',
'MessageNotFound',
'ObjectNotFound',
'MemberNotFound',
'GuildNotFound',
'UserNotFound',
'ChannelNotFound',
'ThreadNotFound',
'ChannelNotReadable',
'BadColourArgument',
'BadColorArgument',
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
'BotMissingRole',
'MissingAnyRole',
'BotMissingAnyRole',
'MissingPermissions',
'BotMissingPermissions',
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
'ExpectedClosingQuoteError',
'ExtensionError',
'ExtensionAlreadyLoaded',
'ExtensionNotLoaded',
'NoEntryPointError',
'ExtensionFailed',
'ExtensionNotFound',
'CommandRegistrationError',
'FlagError',
'BadFlagArgument',
'MissingFlagArgument',
'TooManyFlags',
'MissingRequiredFlag',
"CommandError",
"MissingRequiredArgument",
"BadArgument",
"PrivateMessageOnly",
"NoPrivateMessage",
"CheckFailure",
"CheckAnyFailure",
"CommandNotFound",
"DisabledCommand",
"CommandInvokeError",
"TooManyArguments",
"UserInputError",
"CommandOnCooldown",
"MaxConcurrencyReached",
"NotOwner",
"MessageNotFound",
"ObjectNotFound",
"MemberNotFound",
"GuildNotFound",
"UserNotFound",
"ChannelNotFound",
"ThreadNotFound",
"ChannelNotReadable",
"BadColourArgument",
"BadColorArgument",
"RoleNotFound",
"BadInviteArgument",
"EmojiNotFound",
"GuildStickerNotFound",
"PartialEmojiConversionFailure",
"BadBoolArgument",
"MissingRole",
"BotMissingRole",
"MissingAnyRole",
"BotMissingAnyRole",
"MissingPermissions",
"BotMissingPermissions",
"NSFWChannelRequired",
"ConversionError",
"BadUnionArgument",
"BadLiteralArgument",
"ArgumentParsingError",
"UnexpectedQuoteError",
"InvalidEndOfQuotedStringError",
"ExpectedClosingQuoteError",
"ExtensionError",
"ExtensionAlreadyLoaded",
"ExtensionNotLoaded",
"NoEntryPointError",
"ExtensionFailed",
"ExtensionNotFound",
"CommandRegistrationError",
"FlagError",
"BadFlagArgument",
"MissingFlagArgument",
"TooManyFlags",
"MissingRequiredFlag",
)
class CommandError(DiscordException):
r"""The base exception type for all command related errors.
@ -109,14 +110,18 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`.
"""
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace(
"@here", "@\u200bhere"
)
super().__init__(m, *args)
else:
super().__init__(*args)
class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError.
@ -130,18 +135,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
self.original: Exception = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
regarding user input.
This inherits from :exc:`CommandError`.
"""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked
but no command under that name is found.
@ -151,8 +160,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`.
"""
pass
class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter
that is required is not encountered.
@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
super().__init__(f"{param.name} is a required argument that is missing.")
class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
@ -174,23 +187,29 @@ class TooManyArguments(UserInputError):
This inherits from :exc:`UserInputError`
"""
pass
class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command.
This inherits from :exc:`UserInputError`
"""
pass
class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError`
"""
pass
class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
@ -206,10 +225,13 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(
self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]
) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.')
super().__init__("You do not have permission to run this command.")
class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private
@ -217,8 +239,12 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.')
super().__init__(
message or "This command can only be used in private messages."
)
class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message
@ -228,15 +254,18 @@ class NoPrivateMessage(CheckFailure):
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.')
super().__init__(message or "This command cannot be used in private messages.")
class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure`
"""
pass
class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format
of an ID or a mention.
@ -250,9 +279,11 @@ class ObjectNotFound(BadArgument):
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's
@ -267,10 +298,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache.
@ -283,10 +316,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's
cache.
@ -300,10 +335,12 @@ class UserNotFound(BadArgument):
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel.
@ -316,10 +353,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages
in the channel.
@ -333,10 +372,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable
"""
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel.
@ -349,10 +390,12 @@ class ChannelNotFound(BadArgument):
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread.
@ -365,10 +408,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid.
@ -381,12 +426,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role.
@ -399,10 +447,12 @@ class RoleNotFound(BadArgument):
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired.
@ -410,10 +460,12 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji.
@ -426,10 +478,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct
format.
@ -443,10 +497,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
@ -459,10 +515,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@ -475,17 +533,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option')
super().__init__(f"{argument} is not a recognised boolean option")
class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError`
"""
pass
class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception.
@ -497,9 +559,11 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e: Exception) -> None:
self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown.
@ -516,11 +580,15 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
def __init__(
self, cooldown: Cooldown, retry_after: float, type: BucketType
) -> None:
self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
@ -539,10 +607,13 @@ class MaxConcurrencyReached(CommandError):
self.number: int = number
self.per: BucketType = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
suffix = "per %s" % name if per.name != "default" else "globally"
plural = "%s times %s" if number > 1 else "%s time %s"
fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
super().__init__(
f"Too many people are using this command. It can only be used {fmt} concurrently."
)
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.
@ -557,11 +628,13 @@ class MissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.'
message = f"Role {missing_role!r} is required to run this command."
super().__init__(message)
class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command.
@ -575,11 +648,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command'
message = f"Bot requires the role {missing_role!r} to run this command"
super().__init__(message)
class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of
the roles specified to run a command.
@ -594,15 +669,16 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"You are missing at least one of the required roles: {fmt}"
super().__init__(message)
@ -623,19 +699,21 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message)
class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting.
@ -648,9 +726,13 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled.
"""
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
super().__init__(
f"Channel '{channel}' needs to be NSFW for this command to work."
)
class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a
@ -663,18 +745,23 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [
perm.replace("_", " ").replace("guild", "server").title()
for perm in missing_permissions
]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'You are missing {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"You are missing {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a
command.
@ -686,18 +773,23 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [
perm.replace("_", " ").replace("guild", "server").title()
for perm in missing_permissions
]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'Bot requires {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"Bot requires {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all
its associated types.
@ -713,7 +805,10 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
def __init__(
self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]
) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.errors: List[CommandError] = errors
@ -722,18 +817,19 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if hasattr(x, '__origin__'):
if hasattr(x, "__origin__"):
return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
@ -751,19 +847,23 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
def __init__(
self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]
) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
self.errors: List[CommandError] = errors
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.
@ -772,8 +872,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for
i18n purposes.
"""
pass
class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
@ -784,9 +886,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote: str) -> None:
self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string
@ -799,9 +903,13 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char: str) -> None:
self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}')
super().__init__(
f"Expected space after closing quotation but received {char!r}"
)
class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found.
@ -816,7 +924,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.')
super().__init__(f"Expected closing {close_quote}.")
class ExtensionError(DiscordException):
"""Base exception for extension related errors.
@ -828,37 +937,47 @@ class ExtensionError(DiscordException):
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name
message = message or f'Extension {name!r} had an error.'
message = message or f"Extension {name!r} had an error."
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace(
"@here", "@\u200bhere"
)
super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name)
super().__init__(f"Extension {name!r} is already loaded.", name=name)
class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@ -872,11 +991,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found.
@ -890,10 +1011,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str`
The extension that had the error.
"""
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
msg = f"Extension {name!r} could not be loaded."
super().__init__(msg, name=name)
class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added
because the name is already taken by a different command.
@ -909,11 +1032,13 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name
self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.')
type_ = "alias" if alias_conflict else "command"
super().__init__(f"The {type_} {name} is already an existing command or alias.")
class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors.
@ -922,8 +1047,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0
"""
pass
class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values.
@ -938,10 +1065,14 @@ class TooManyFlags(FlagError):
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
super().__init__(
f"Too many flag values, expected {flag.max_args} but received {len(values)}."
)
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
@ -955,6 +1086,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
@ -962,7 +1094,8 @@ class BadFlagArgument(FlagError):
except AttributeError:
name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given.
@ -976,9 +1109,11 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
super().__init__(f"Flag {flag.name!r} is required and missing")
class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value.
@ -992,6 +1127,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')
super().__init__(f"Flag {flag.name!r} does not have an argument")

View File

@ -59,9 +59,9 @@ import sys
import re
__all__ = (
'Flag',
'flag',
'FlagConverter',
"Flag",
"flag",
"FlagConverter",
)
@ -143,25 +143,35 @@ def flag(
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
return Flag(
name=name,
aliases=aliases,
default=default,
max_args=max_args,
override=override,
)
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError('flag names should not be empty')
raise ValueError("flag names should not be empty")
for ch in name:
if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
raise ValueError(
f"flag name {name!r} cannot have any of {forbidden!r} within them"
)
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
def get_flags(
namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]
) -> Dict[str, Flag]:
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
@ -176,9 +186,15 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.name is MISSING:
flag.name = name
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
annotation = flag.annotation = resolve_annotation(
flag.annotation, globals, locals, cache
)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default
if flag.aliases is MISSING:
@ -229,7 +245,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
raise TypeError(
f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag"
)
if flag.override is MISSING:
flag.override = False
@ -237,7 +255,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
raise TypeError(
f"{flag.name!r} flag conflicts with previous flag or alias."
)
else:
names.add(name)
@ -245,7 +265,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
raise TypeError(
f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias."
)
else:
names.add(alias)
@ -274,10 +296,10 @@ class FlagsMeta(type):
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs['__commands_is_flag__'] = True
attrs["__commands_is_flag__"] = True
try:
global_ns = sys.modules[attrs['__module__']].__dict__
global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError:
global_ns = {}
@ -296,26 +318,32 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
attrs["__commands_flag_case_insensitive__"] = base.__dict__[
"__commands_flag_case_insensitive__"
]
if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
attrs["__commands_flag_delimiter__"] = base.__dict__[
"__commands_flag_delimiter__"
]
if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
attrs["__commands_flag_prefix__"] = base.__dict__[
"__commands_flag_prefix__"
]
if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive
attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter
attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix
attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
@ -330,23 +358,30 @@ class FlagsMeta(type):
regex_flags = 0
if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
aliases = {
key.casefold(): value.casefold() for key, value in aliases.items()
}
regex_flags = re.IGNORECASE
keys = list(re.escape(k) for k in flags)
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
joined = "|".join(keys)
pattern = re.compile(
f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})",
regex_flags,
)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
async def tuple_convert_all(
ctx: Context, argument: str, flag: Flag, converter: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -371,7 +406,9 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
async def tuple_convert_flag(
ctx: Context, argument: str, flag: Flag, converters: Any
) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -409,9 +446,13 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
else:
if origin is tuple:
if annotation.__args__[-1] is Ellipsis:
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0])
return await tuple_convert_all(
ctx, argument, flag, annotation.__args__[0]
)
else:
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
return await tuple_convert_flag(
ctx, argument, flag, annotation.__args__
)
elif origin is list:
# typing.List[x]
annotation = annotation.__args__[0]
@ -432,7 +473,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
F = TypeVar("F", bound="FlagConverter")
class FlagConverter(metaclass=FlagsMeta):
@ -493,8 +534,13 @@ class FlagConverter(metaclass=FlagsMeta):
return self
def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
pairs = " ".join(
[
f"{flag.attribute}={getattr(self, flag.attribute)!r}"
for flag in self.get_flags().values()
]
)
return f"<{self.__class__.__name__} {pairs}>"
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@ -507,7 +553,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
key = match.group("flag")
if case_insensitive:
key = key.casefold()

View File

@ -39,10 +39,10 @@ if TYPE_CHECKING:
from .context import Context
__all__ = (
'Paginator',
'HelpCommand',
'DefaultHelpCommand',
'MinimalHelpCommand',
"Paginator",
"HelpCommand",
"DefaultHelpCommand",
"MinimalHelpCommand",
)
# help -> shows info of bot on top/bottom and lists subcommands
@ -89,7 +89,7 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
@ -118,7 +118,7 @@ class Paginator:
def _linesep_len(self):
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line="", *, empty=False):
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@ -136,18 +136,23 @@ class Paginator:
RuntimeError
The line was too big for the current :attr:`max_size`.
"""
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
max_page_size = (
self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
)
if len(line) > max_page_size:
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
if (
self._count + len(line) + self._linesep_len
> self.max_size - self._suffix_len
):
self.close_page()
self._count += len(line) + self._linesep_len
self._current_page.append(line)
if empty:
self._current_page.append('')
self._current_page.append("")
self._count += self._linesep_len
def close_page(self):
@ -176,7 +181,7 @@ class Paginator:
return self._pages
def __repr__(self):
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
return fmt.format(self)
@ -197,7 +202,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overriden__'):
if not hasattr(on_error, "__help_command_not_overriden__"):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
@ -224,7 +229,7 @@ class _HelpCommandImpl(Command):
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError('Missing context parameter') from None
raise ValueError("Missing context parameter") from None
else:
return result
@ -296,13 +301,13 @@ class HelpCommand:
"""
MENTION_TRANSFORMS = {
'@everyone': '@\u200beveryone',
'@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role',
"@everyone": "@\u200beveryone",
"@here": "@\u200bhere",
r"<@!?[0-9]{17,22}>": "@deleted-user",
r"<@&[0-9]{17,22}>": "@deleted-role",
}
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing
@ -321,11 +326,11 @@ class HelpCommand:
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault("name", "help")
attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
@ -398,7 +403,11 @@ class HelpCommand:
"""
command_name = self._command_impl.name
ctx = self.context
if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name:
if (
ctx is None
or ctx.command is None
or ctx.command.qualified_name != command_name
):
return command_name
return ctx.invoked_with
@ -422,20 +431,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
entries.append(parent.name + " " + parent.signature)
parent = parent.parent
parent_sig = ' '.join(reversed(entries))
parent_sig = " ".join(reversed(entries))
if len(command.aliases) > 0:
aliases = '|'.join(command.aliases)
fmt = f'[{command.name}|{aliases}]'
aliases = "|".join(command.aliases)
fmt = f"[{command.name}|{aliases}]"
if parent_sig:
fmt = parent_sig + ' ' + fmt
fmt = parent_sig + " " + fmt
alias = fmt
else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
alias = command.name if not parent_sig else parent_sig + " " + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}'
return f"{self.context.clean_prefix}{alias} {command.signature}"
def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse.
@ -449,7 +458,7 @@ class HelpCommand:
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), '@invalid')
return transforms.get(obj.group(0), "@invalid")
return self.MENTION_PATTERN.sub(replace, string)
@ -527,7 +536,9 @@ class HelpCommand:
The string to use when the command did not have the subcommand requested.
"""
if isinstance(command, Group) and len(command.all_commands) > 0:
return f'Command "{command.qualified_name}" has no subcommand named {string}'
return (
f'Command "{command.qualified_name}" has no subcommand named {string}'
)
return f'Command "{command.qualified_name}" has no subcommands.'
async def filter_commands(self, commands, *, sort=False, key=None):
@ -558,7 +569,9 @@ class HelpCommand:
if sort and key is None:
key = lambda c: c.name
iterator = commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
iterator = (
commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
)
if self.verify_checks is False:
# if we do not need to verify the checks then we can just
@ -846,21 +859,27 @@ class HelpCommand:
# Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through
# the command group chain ourselves.
keys = command.split(' ')
keys = command.split(" ")
cmd = bot.all_commands.get(keys[0])
if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
string = await maybe_coro(
self.command_not_found, self.remove_mentions(keys[0])
)
return await self.send_error_message(string)
for key in keys[1:]:
try:
found = cmd.all_commands.get(key)
except AttributeError:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
string = await maybe_coro(
self.subcommand_not_found, cmd, self.remove_mentions(key)
)
return await self.send_error_message(string)
else:
if found is None:
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
string = await maybe_coro(
self.subcommand_not_found, cmd, self.remove_mentions(key)
)
return await self.send_error_message(string)
cmd = found
@ -907,14 +926,14 @@ class DefaultHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.width = options.pop("width", 80)
self.indent = options.pop("indent", 2)
self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator()
@ -924,7 +943,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...'
return text[: self.width - 3].rstrip() + "..."
return text
def get_ending_note(self):
@ -1021,11 +1040,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion
self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:'
no_category = f"\u200b{self.no_category}:"
def get_category(command, *, no_category=no_category):
cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category
return cog.qualified_name + ":" if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered)
@ -1033,7 +1052,11 @@ class DefaultHelpCommand(HelpCommand):
# Now we can add the commands to the page.
for category, commands in to_iterate:
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
commands = (
sorted(commands, key=lambda c: c.name)
if self.sort_commands
else list(commands)
)
self.add_indented_commands(commands, heading=category, max_size=max_size)
note = self.get_ending_note()
@ -1066,7 +1089,9 @@ class DefaultHelpCommand(HelpCommand):
if cog.description:
self.paginator.add_line(cog.description, empty=True)
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
filtered = await self.filter_commands(
cog.get_commands(), sort=self.sort_commands
)
self.add_indented_commands(filtered, heading=self.commands_heading)
note = self.get_ending_note()
@ -1110,13 +1135,13 @@ class MinimalHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
@ -1149,7 +1174,9 @@ class MinimalHelpCommand(HelpCommand):
)
def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
return (
f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
)
def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
@ -1180,8 +1207,8 @@ class MinimalHelpCommand(HelpCommand):
"""
if commands:
# U+2002 Middle Dot
joined = '\u2002'.join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__')
joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f"__**{heading}**__")
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
@ -1197,8 +1224,12 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command`
The command to show information of.
"""
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
self.paginator.add_line(
fmt.format(
self.context.clean_prefix, command.qualified_name, command.short_doc
)
)
def add_aliases_formatting(self, aliases):
"""Adds the formatting information on a command's aliases.
@ -1215,7 +1246,9 @@ class MinimalHelpCommand(HelpCommand):
aliases: Sequence[:class:`str`]
A list of aliases to format.
"""
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
self.paginator.add_line(
f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True
)
def add_command_formatting(self, command):
"""A utility function to format commands and groups.
@ -1268,7 +1301,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}'
no_category = f"\u200b{self.no_category}"
def get_category(command, *, no_category=no_category):
cog = command.cog
@ -1278,7 +1311,11 @@ class MinimalHelpCommand(HelpCommand):
to_iterate = itertools.groupby(filtered, key=get_category)
for category, commands in to_iterate:
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
commands = (
sorted(commands, key=lambda c: c.name)
if self.sort_commands
else list(commands)
)
self.add_bot_commands_formatting(commands, category)
note = self.get_ending_note()
@ -1300,9 +1337,11 @@ class MinimalHelpCommand(HelpCommand):
if cog.description:
self.paginator.add_line(cog.description, empty=True)
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
filtered = await self.filter_commands(
cog.get_commands(), sort=self.sort_commands
)
if filtered:
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)
@ -1322,7 +1361,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
self.paginator.add_line(f'**{self.commands_heading}**')
self.paginator.add_line(f"**{self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)

View File

@ -22,7 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
from .errors import (
UnexpectedQuoteError,
InvalidEndOfQuotedStringError,
ExpectedClosingQuoteError,
)
# map from opening quotes to closing quotes
_quotes = {
@ -46,6 +50,7 @@ _quotes = {
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
@ -81,20 +86,20 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
result = self.buffer[self.index:]
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
result = self.buffer[self.index:self.index + n]
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
@ -120,7 +125,7 @@ class StringView:
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index:self.index + pos]
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
@ -144,11 +149,11 @@ class StringView:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return ''.join(result)
return "".join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\':
if current == "\\":
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
@ -156,7 +161,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return ''.join(result)
return "".join(result)
if next_char in _escaped_quotes:
# escaped quote
@ -179,14 +184,13 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return ''.join(result)
return "".join(result)
if current.isspace() and not is_quoted:
# end of word found
return ''.join(result)
return "".join(result)
result.append(current)
def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"

View File

@ -48,21 +48,21 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
__all__ = (
'loop',
)
__all__ = ("loop",)
T = TypeVar('T')
T = TypeVar("T")
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
__slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop
) -> None:
self.loop = loop
self.future = future = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
@ -124,7 +124,7 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
@ -132,10 +132,12 @@ class Loop(Generic[LF]):
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
raise TypeError(
f"Expected coroutine function, not {type(self.coro).__name__!r}."
)
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name)
coro = getattr(self, "_" + name)
if coro is None:
return
@ -150,7 +152,7 @@ class Loop(Generic[LF]):
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
await self._call_loop_function("before_loop")
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
@ -193,10 +195,10 @@ class Loop(Generic[LF]):
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function('error', exc)
await self._call_loop_function("error", exc)
raise exc
finally:
await self._call_loop_function('after_loop')
await self._call_loop_function("after_loop")
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
@ -323,7 +325,7 @@ class Loop(Generic[LF]):
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None:
args = (self._injected, *args)
@ -356,7 +358,9 @@ class Loop(Generic[LF]):
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
return bool(not self._is_being_cancelled and self._task and not self._task.done())
return bool(
not self._is_being_cancelled and self._task and not self._task.done()
)
def cancel(self) -> None:
"""Cancels the internal task, if it is running."""
@ -379,7 +383,9 @@ class Loop(Generic[LF]):
The keyword arguments to use.
"""
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
def restart_when_over(
fut: Any, *, args: Any = args, kwargs: Any = kwargs
) -> None:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
@ -410,9 +416,9 @@ class Loop(Generic[LF]):
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f'{exc!r} must be a class.')
raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException):
raise TypeError(f'{exc!r} must inherit from BaseException.')
raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions)
@ -439,7 +445,9 @@ class Loop(Generic[LF]):
Whether all exceptions were successfully removed.
"""
old_length = len(self._valid_exception)
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
self._valid_exception = tuple(
x for x in self._valid_exception if x not in exceptions
)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task[None]]:
@ -466,8 +474,13 @@ class Loop(Generic[LF]):
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
print(
f"Unhandled exception in internal background task {self.coro.__name__!r}.",
file=sys.stderr,
)
traceback.print_exception(
type(exception), exception, exception.__traceback__, file=sys.stderr
)
def before_loop(self, coro: FT) -> FT:
"""A decorator that registers a coroutine to be called before the loop starts running.
@ -489,7 +502,9 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._before_loop = coro
return coro
@ -517,7 +532,9 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._after_loop = coro
return coro
@ -543,7 +560,9 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(
f"Expected coroutine function, received {coro.__class__.__name__!r}."
)
self._error = coro # type: ignore
return coro
@ -557,14 +576,18 @@ class Loop(Generic[LF]):
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(days=1),
self._time[0],
)
next_time = self._time[self._time_index]
if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc), next_time
)
next_date = self._last_iteration
if self._time_index == 0:
@ -580,7 +603,9 @@ class Loop(Generic[LF]):
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
now
if now is not MISSING
else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
@ -601,16 +626,16 @@ class Loop(Generic[LF]):
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
@ -663,7 +688,7 @@ class Loop(Generic[LF]):
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep
self._seconds = float(seconds)
@ -672,7 +697,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time')
raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING

View File

@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union
import os
import io
__all__ = (
'File',
)
__all__ = ("File",)
class File:
@ -64,7 +62,7 @@ class File:
Whether the attachment is a spoiler.
"""
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer')
__slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer")
if TYPE_CHECKING:
fp: io.BufferedIOBase
@ -80,12 +78,12 @@ class File:
):
if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()):
raise ValueError(f'File buffer {fp!r} must be seekable and readable')
raise ValueError(f"File buffer {fp!r} must be seekable and readable")
self.fp = fp
self._original_pos = fp.tell()
self._owner = False
else:
self.fp = open(fp, 'rb')
self.fp = open(fp, "rb")
self._original_pos = 0
self._owner = True
@ -100,14 +98,20 @@ class File:
if isinstance(fp, str):
_, self.filename = os.path.split(fp)
else:
self.filename = getattr(fp, 'name', None)
self.filename = getattr(fp, "name", None)
else:
self.filename = filename
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
self.filename = 'SPOILER_' + self.filename
if (
spoiler
and self.filename is not None
and not self.filename.startswith("SPOILER_")
):
self.filename = "SPOILER_" + self.filename
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_'))
self.spoiler = spoiler or (
self.filename is not None and self.filename.startswith("SPOILER_")
)
def reset(self, *, seek: Union[int, bool] = True) -> None:
# The `seek` parameter is needed because

View File

@ -24,21 +24,34 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
overload,
)
from .enums import UserFlags
__all__ = (
'SystemChannelFlags',
'MessageFlags',
'PublicUserFlags',
'Intents',
'MemberCacheFlags',
'ApplicationFlags',
"SystemChannelFlags",
"MessageFlags",
"PublicUserFlags",
"Intents",
"MemberCacheFlags",
"ApplicationFlags",
)
FV = TypeVar('FV', bound='flag_value')
BF = TypeVar('BF', bound='BaseFlags')
FV = TypeVar("FV", bound="flag_value")
BF = TypeVar("BF", bound="BaseFlags")
class flag_value:
@ -63,7 +76,7 @@ class flag_value:
instance._set_flag(self.flag, value)
def __repr__(self):
return f'<flag_value flag={self.flag!r}>'
return f"<flag_value flag={self.flag!r}>"
class alias_flag_value(flag_value):
@ -98,13 +111,13 @@ class BaseFlags:
value: int
__slots__ = ('value',)
__slots__ = ("value",)
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@ -123,7 +136,7 @@ class BaseFlags:
return hash(self.value)
def __repr__(self) -> str:
return f'<{self.__class__.__name__} value={self.value}>'
return f"<{self.__class__.__name__} value={self.value}>"
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items():
@ -142,7 +155,9 @@ class BaseFlags:
elif toggle is False:
self.value &= ~o
else:
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.')
raise TypeError(
f"Value to set for {self.__class__.__name__} must be a bool."
)
@fill_with_flags(inverted=True)
@ -196,7 +211,7 @@ class SystemChannelFlags(BaseFlags):
elif toggle is False:
self.value |= o
else:
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
raise TypeError("Value to set for SystemChannelFlags must be a bool.")
@flag_value
def join_notifications(self):
@ -412,7 +427,11 @@ class PublicUserFlags(BaseFlags):
def all(self) -> List[UserFlags]:
"""List[:class:`UserFlags`]: Returns all public flags the user has."""
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
return [
public_flag
for public_flag in UserFlags
if self._has_flag(public_flag.value)
]
@fill_with_flags()
@ -461,7 +480,7 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@ -907,7 +926,7 @@ class MemberCacheFlags(BaseFlags):
self.value = (1 << bits) - 1
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@ -977,10 +996,10 @@ class MemberCacheFlags(BaseFlags):
def _verify_intents(self, intents: Intents):
if self.voice and not intents.voice_states:
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states')
raise ValueError("MemberCacheFlags.voice requires Intents.voice_states")
if self.joined and not intents.members:
raise ValueError('MemberCacheFlags.joined requires Intents.members')
raise ValueError("MemberCacheFlags.joined requires Intents.members")
@property
def _voice_only(self):

View File

@ -43,25 +43,31 @@ from .errors import ConnectionClosed, InvalidArgument
_log = logging.getLogger(__name__)
__all__ = (
'DiscordWebSocket',
'KeepAliveHandler',
'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket',
'ReconnectWebSocket',
"DiscordWebSocket",
"KeepAliveHandler",
"VoiceKeepAliveHandler",
"DiscordVoiceWebSocket",
"ReconnectWebSocket",
)
class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
self.op = "RESUME" if resume else "IDENTIFY"
class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future')
EventListener = namedtuple("EventListener", "predicate event result future")
class GatewayRatelimiter:
def __init__(self, count=110, per=60.0):
@ -101,48 +107,57 @@ class GatewayRatelimiter:
async with self.lock:
delta = self.get_delay()
if delta:
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
_log.warning(
"WebSocket in shard ID %s is ratelimited, waiting %.2f seconds",
self.shard_id,
delta,
)
await asyncio.sleep(delta)
class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs):
ws = kwargs.pop('ws', None)
interval = kwargs.pop('interval', None)
shard_id = kwargs.pop('shard_id', None)
ws = kwargs.pop("ws", None)
interval = kwargs.pop("interval", None)
shard_id = kwargs.pop("shard_id", None)
threading.Thread.__init__(self, *args, **kwargs)
self.ws = ws
self._main_thread_id = ws.thread_id
self.interval = interval
self.daemon = True
self.shard_id = shard_id
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self.msg = "Keeping shard ID %s websocket alive with sequence %s."
self.block_msg = "Shard ID %s heartbeat blocked for more than %s seconds."
self.behind_msg = "Can't keep up, shard ID %s websocket is %.1fs behind."
self._stop_ev = threading.Event()
self._last_ack = time.perf_counter()
self._last_send = time.perf_counter()
self._last_recv = time.perf_counter()
self.latency = float('inf')
self.latency = float("inf")
self.heartbeat_timeout = ws._max_heartbeat_timeout
def run(self):
while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
_log.warning(
"Shard ID %s has stopped responding to the gateway. Closing and restarting.",
self.shard_id,
)
coro = self.ws.close(4000)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
f.result()
except Exception:
_log.exception('An error occurred while stopping the gateway. Ignoring.')
_log.exception(
"An error occurred while stopping the gateway. Ignoring."
)
finally:
self.stop()
return
data = self.get_payload()
_log.debug(self.msg, self.shard_id, data['d'])
_log.debug(self.msg, self.shard_id, data["d"])
coro = self.ws.send_heartbeat(data)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
@ -159,8 +174,8 @@ class KeepAliveHandler(threading.Thread):
except KeyError:
msg = self.block_msg
else:
stack = ''.join(traceback.format_stack(frame))
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
stack = "".join(traceback.format_stack(frame))
msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}"
_log.warning(msg, self.shard_id, total)
except Exception:
@ -169,10 +184,7 @@ class KeepAliveHandler(threading.Thread):
self._last_send = time.perf_counter()
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': self.ws.sequence
}
return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence}
def stop(self):
self._stop_ev.set()
@ -187,19 +199,17 @@ class KeepAliveHandler(threading.Thread):
if self.latency > 10:
_log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s."
self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds"
self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind"
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)}
def ack(self):
ack_time = time.perf_counter()
@ -208,10 +218,12 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
async def close(self, *, code: int = 4000, message: bytes = b"") -> bool:
return await super().close(code=code, message=message)
class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6.
@ -252,19 +264,19 @@ class DiscordWebSocket:
The authentication token for discord.
"""
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE = 3
VOICE_STATE = 4
VOICE_PING = 5
RESUME = 6
RECONNECT = 7
REQUEST_MEMBERS = 8
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE = 3
VOICE_STATE = 4
VOICE_PING = 5
RESUME = 6
RECONNECT = 7
REQUEST_MEMBERS = 8
INVALIDATE_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
HELLO = 10
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
def __init__(self, socket, *, loop):
self.socket = socket
@ -294,13 +306,23 @@ class DiscordWebSocket:
return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /):
self._dispatch('socket_raw_receive', data)
self._dispatch("socket_raw_receive", data)
def log_receive(self, _, /):
pass
@classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
async def from_client(
cls,
client,
*,
initial=False,
gateway=None,
shard_id=None,
session=None,
sequence=None,
resume=False,
):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -330,7 +352,7 @@ class DiscordWebSocket:
client._connection._update_references(ws)
_log.debug('Created websocket connected to %s', gateway)
_log.debug("Created websocket connected to %s", gateway)
# poll event for OP Hello
await ws.poll_event()
@ -363,83 +385,87 @@ class DiscordWebSocket:
"""
future = self.loop.create_future()
entry = EventListener(event=event, predicate=predicate, result=result, future=future)
entry = EventListener(
event=event, predicate=predicate, result=result, future=future
)
self._dispatch_listeners.append(entry)
return future
async def identify(self):
"""Sends the IDENTIFY packet."""
payload = {
'op': self.IDENTIFY,
'd': {
'token': self.token,
'properties': {
'$os': sys.platform,
'$browser': 'discord.py',
'$device': 'discord.py',
'$referrer': '',
'$referring_domain': ''
"op": self.IDENTIFY,
"d": {
"token": self.token,
"properties": {
"$os": sys.platform,
"$browser": "discord.py",
"$device": "discord.py",
"$referrer": "",
"$referring_domain": "",
},
'compress': True,
'large_threshold': 250,
'v': 3
}
"compress": True,
"large_threshold": 250,
"v": 3,
},
}
if self.shard_id is not None and self.shard_count is not None:
payload['d']['shard'] = [self.shard_id, self.shard_count]
payload["d"]["shard"] = [self.shard_id, self.shard_count]
state = self._connection
if state._activity is not None or state._status is not None:
payload['d']['presence'] = {
'status': state._status,
'game': state._activity,
'since': 0,
'afk': False
payload["d"]["presence"] = {
"status": state._status,
"game": state._activity,
"since": 0,
"afk": False,
}
if state._intents is not None:
payload['d']['intents'] = state._intents.value
payload["d"]["intents"] = state._intents.value
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.call_hooks(
"before_identify", self.shard_id, initial=self._initial_identify
)
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
_log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id)
async def resume(self):
"""Sends the RESUME packet."""
payload = {
'op': self.RESUME,
'd': {
'seq': self.sequence,
'session_id': self.session_id,
'token': self.token
}
"op": self.RESUME,
"d": {
"seq": self.sequence,
"session_id": self.session_id,
"token": self.token,
},
}
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
_log.info("Shard ID %s has sent the RESUME payload.", self.shard_id)
async def received_message(self, msg, /):
if type(msg) is bytes:
self._buffer.extend(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff":
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
msg = msg.decode("utf-8")
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
event = msg.get('t')
_log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg)
event = msg.get("t")
if event:
self._dispatch('socket_event_type', event)
self._dispatch("socket_event_type", event)
op = msg.get('op')
data = msg.get('d')
seq = msg.get('s')
op = msg.get("op")
data = msg.get("d")
seq = msg.get("s")
if seq is not None:
self.sequence = seq
@ -451,7 +477,7 @@ class DiscordWebSocket:
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
_log.debug('Received RECONNECT opcode.')
_log.debug("Received RECONNECT opcode.")
await self.close()
raise ReconnectWebSocket(self.shard_id)
@ -467,8 +493,10 @@ class DiscordWebSocket:
return
if op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id)
interval = data["heartbeat_interval"] / 1000.0
self._keep_alive = KeepAliveHandler(
ws=self, interval=interval, shard_id=self.shard_id
)
# send a heartbeat immediately
await self.send_as_json(self._keep_alive.get_payload())
self._keep_alive.start()
@ -481,33 +509,41 @@ class DiscordWebSocket:
self.sequence = None
self.session_id = None
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
_log.info("Shard ID %s session has been invalidated.", self.shard_id)
await self.close(code=1000)
raise ReconnectWebSocket(self.shard_id, resume=False)
_log.warning('Unknown OP code %s.', op)
_log.warning("Unknown OP code %s.", op)
return
if event == 'READY':
self._trace = trace = data.get('_trace', [])
self.sequence = msg['s']
self.session_id = data['session_id']
if event == "READY":
self._trace = trace = data.get("_trace", [])
self.sequence = msg["s"]
self.session_id = data["session_id"]
# pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id, ', '.join(trace), self.session_id)
data["__shard_id__"] = self.shard_id
_log.info(
"Shard ID %s has connected to Gateway: %s (Session ID: %s).",
self.shard_id,
", ".join(trace),
self.session_id,
)
elif event == 'RESUMED':
self._trace = trace = data.get('_trace', [])
elif event == "RESUMED":
self._trace = trace = data.get("_trace", [])
# pass back the shard ID to the resumed handler
data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
self.shard_id, self.session_id, ', '.join(trace))
data["__shard_id__"] = self.shard_id
_log.info(
"Shard ID %s has successfully RESUMED session %s under trace %s.",
self.shard_id,
self.session_id,
", ".join(trace),
)
try:
func = self._discord_parsers[event]
except KeyError:
_log.debug('Unknown event %s.', event)
_log.debug("Unknown event %s.", event)
else:
func(data)
@ -540,7 +576,7 @@ class DiscordWebSocket:
def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
return float("inf") if heartbeat is None else heartbeat.latency
def _can_handle_close(self):
code = self._close_code or self.socket.close_code
@ -561,10 +597,14 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
_log.debug('Received %s', msg)
elif msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
aiohttp.WSMsgType.CLOSE,
):
_log.debug("Received %s", msg)
raise WebSocketClosure
except (asyncio.TimeoutError, WebSocketClosure) as e:
# Ensure the keep alive handler is closed
@ -573,20 +613,22 @@ class DiscordWebSocket:
self._keep_alive = None
if isinstance(e, asyncio.TimeoutError):
_log.info('Timed out receiving packet. Attempting a reconnect.')
_log.info("Timed out receiving packet. Attempting a reconnect.")
raise ReconnectWebSocket(self.shard_id) from None
code = self._close_code or self.socket.close_code
if self._can_handle_close():
_log.info('Websocket closed with %s, attempting a reconnect.', code)
_log.info("Websocket closed with %s, attempting a reconnect.", code)
raise ReconnectWebSocket(self.shard_id) from None
else:
_log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
_log.info("Websocket closed with %s, cannot reconnect.", code)
raise ConnectionClosed(
self.socket, shard_id=self.shard_id, code=code
) from None
async def debug_send(self, data, /):
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data)
self._dispatch("socket_raw_send", data)
await self.socket.send_str(data)
async def send(self, data, /):
@ -611,62 +653,59 @@ class DiscordWebSocket:
async def change_presence(self, *, activity=None, status=None, since=0.0):
if activity is not None:
if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.')
raise InvalidArgument("activity must derive from BaseActivity.")
activity = [activity.to_dict()]
else:
activity = []
if status == 'idle':
if status == "idle":
since = int(time.time() * 1000)
payload = {
'op': self.PRESENCE,
'd': {
'activities': activity,
'afk': False,
'since': since,
'status': status
}
"op": self.PRESENCE,
"d": {
"activities": activity,
"afk": False,
"since": since,
"status": status,
},
}
sent = utils._to_json(payload)
_log.debug('Sending "%s" to change status', sent)
await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
async def request_chunks(
self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None
):
payload = {
'op': self.REQUEST_MEMBERS,
'd': {
'guild_id': guild_id,
'presences': presences,
'limit': limit
}
"op": self.REQUEST_MEMBERS,
"d": {"guild_id": guild_id, "presences": presences, "limit": limit},
}
if nonce:
payload['d']['nonce'] = nonce
payload["d"]["nonce"] = nonce
if user_ids:
payload['d']['user_ids'] = user_ids
payload["d"]["user_ids"] = user_ids
if query is not None:
payload['d']['query'] = query
payload["d"]["query"] = query
await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
payload = {
'op': self.VOICE_STATE,
'd': {
'guild_id': guild_id,
'channel_id': channel_id,
'self_mute': self_mute,
'self_deaf': self_deaf
}
"op": self.VOICE_STATE,
"d": {
"guild_id": guild_id,
"channel_id": channel_id,
"self_mute": self_mute,
"self_deaf": self_deaf,
},
}
_log.debug('Updating our voice state to %s.', payload)
_log.debug("Updating our voice state to %s.", payload)
await self.send_as_json(payload)
async def close(self, code=4000):
@ -677,6 +716,7 @@ class DiscordWebSocket:
self._close_code = code
await self.socket.close(code=code)
class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections.
@ -708,18 +748,18 @@ class DiscordVoiceWebSocket:
Receive only. Indicates a user has disconnected from voice.
"""
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None):
self.ws = socket
@ -734,7 +774,7 @@ class DiscordVoiceWebSocket:
pass
async def send_as_json(self, data):
_log.debug('Sending voice websocket frame: %s.', data)
_log.debug("Sending voice websocket frame: %s.", data)
await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json
@ -742,32 +782,32 @@ class DiscordVoiceWebSocket:
async def resume(self):
state = self._connection
payload = {
'op': self.RESUME,
'd': {
'token': state.token,
'server_id': str(state.server_id),
'session_id': state.session_id
}
"op": self.RESUME,
"d": {
"token": state.token,
"server_id": str(state.server_id),
"session_id": state.session_id,
},
}
await self.send_as_json(payload)
async def identify(self):
state = self._connection
payload = {
'op': self.IDENTIFY,
'd': {
'server_id': str(state.server_id),
'user_id': str(state.user.id),
'session_id': state.session_id,
'token': state.token
}
"op": self.IDENTIFY,
"d": {
"server_id": str(state.server_id),
"user_id": str(state.user.id),
"session_id": state.session_id,
"token": state.token,
},
}
await self.send_as_json(payload)
@classmethod
async def from_client(cls, client, *, resume=False, hook=None):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = "wss://" + client.endpoint + "/?v=4"
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
@ -785,109 +825,101 @@ class DiscordVoiceWebSocket:
async def select_protocol(self, ip, port, mode):
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
'protocol': 'udp',
'data': {
'address': ip,
'port': port,
'mode': mode
}
}
"op": self.SELECT_PROTOCOL,
"d": {
"protocol": "udp",
"data": {"address": ip, "port": port, "mode": mode},
},
}
await self.send_as_json(payload)
async def client_connect(self):
payload = {
'op': self.CLIENT_CONNECT,
'd': {
'audio_ssrc': self._connection.ssrc
}
"op": self.CLIENT_CONNECT,
"d": {"audio_ssrc": self._connection.ssrc},
}
await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice):
payload = {
'op': self.SPEAKING,
'd': {
'speaking': int(state),
'delay': 0
}
}
payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}}
await self.send_as_json(payload)
async def received_message(self, msg):
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
data = msg.get('d')
_log.debug("Voice websocket frame received: %s", msg)
op = msg["op"]
data = msg.get("d")
if op == self.READY:
await self.initial_connection(data)
elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack()
elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.')
_log.info("Voice RESUME succeeded.")
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
self._connection.mode = data["mode"]
await self.load_secret_key(data)
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
interval = data["heartbeat_interval"] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(
ws=self, interval=min(interval, 5.0)
)
self._keep_alive.start()
await self._hook(self, msg)
async def initial_connection(self, data):
state = self._connection
state.ssrc = data['ssrc']
state.voice_port = data['port']
state.endpoint_ip = data['ip']
state.ssrc = data["ssrc"]
state.voice_port = data["port"]
state.endpoint_ip = data["ip"]
packet = bytearray(70)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
struct.pack_into(">H", packet, 0, 1) # 1 = Send
struct.pack_into(">H", packet, 2, 70) # 70 = Length
struct.pack_into(">I", packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70)
_log.debug('received packet in initial_connection: %s', recv)
_log.debug("received packet in initial_connection: %s", recv)
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
state.ip = recv[ip_start:ip_end].decode("ascii")
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0]
_log.debug("detected ip: %s port: %s", state.ip, state.port)
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))
modes = [
mode for mode in data["modes"] if mode in self._connection.supported_modes
]
_log.debug("received supported encryption modes: %s", ", ".join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.info('selected the voice protocol for use (%s)', mode)
_log.info("selected the voice protocol for use (%s)", mode)
@property
def latency(self):
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
return float("inf") if heartbeat is None else heartbeat.latency
@property
def average_latency(self):
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies:
return float('inf')
return float("inf")
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
_log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key')
_log.info("received secret key for voice connection")
self.secret_key = self._connection.secret_key = data.get("secret_key")
await self.speak()
await self.speak(False)
@ -897,10 +929,14 @@ class DiscordVoiceWebSocket:
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
elif msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000):

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -32,11 +32,11 @@ from .errors import InvalidArgument
from .enums import try_enum, ExpireBehaviour
__all__ = (
'IntegrationAccount',
'IntegrationApplication',
'Integration',
'StreamIntegration',
'BotIntegration',
"IntegrationAccount",
"IntegrationApplication",
"Integration",
"StreamIntegration",
"BotIntegration",
)
if TYPE_CHECKING:
@ -65,14 +65,14 @@ class IntegrationAccount:
The account name.
"""
__slots__ = ('id', 'name')
__slots__ = ("id", "name")
def __init__(self, data: IntegrationAccountPayload) -> None:
self.id: str = data['id']
self.name: str = data['name']
self.id: str = data["id"]
self.name: str = data["name"]
def __repr__(self) -> str:
return f'<IntegrationAccount id={self.id} name={self.name!r}>'
return f"<IntegrationAccount id={self.id} name={self.name!r}>"
class Integration:
@ -99,14 +99,14 @@ class Integration:
"""
__slots__ = (
'guild',
'id',
'_state',
'type',
'name',
'account',
'user',
'enabled',
"guild",
"id",
"_state",
"type",
"name",
"account",
"user",
"enabled",
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
@ -118,14 +118,14 @@ class Integration:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
self.id: int = int(data['id'])
self.type: IntegrationType = data['type']
self.name: str = data['name']
self.account: IntegrationAccount = IntegrationAccount(data['account'])
self.id: int = int(data["id"])
self.type: IntegrationType = data["type"]
self.name: str = data["name"]
self.account: IntegrationAccount = IntegrationAccount(data["account"])
user = data.get('user')
user = data.get("user")
self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
self.enabled: bool = data["enabled"]
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
@ -186,26 +186,28 @@ class StreamIntegration(Integration):
"""
__slots__ = (
'revoked',
'expire_behaviour',
'expire_grace_period',
'synced_at',
'_role_id',
'syncing',
'enable_emoticons',
'subscriber_count',
"revoked",
"expire_behaviour",
"expire_grace_period",
"synced_at",
"_role_id",
"syncing",
"enable_emoticons",
"subscriber_count",
)
def _from_data(self, data: StreamIntegrationPayload) -> None:
super()._from_data(data)
self.revoked: bool = data['revoked']
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior'])
self.expire_grace_period: int = data['expire_grace_period']
self.synced_at: datetime.datetime = parse_time(data['synced_at'])
self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id')
self.syncing: bool = data['syncing']
self.enable_emoticons: bool = data['enable_emoticons']
self.subscriber_count: int = data['subscriber_count']
self.revoked: bool = data["revoked"]
self.expire_behaviour: ExpireBehaviour = try_enum(
ExpireBehaviour, data["expire_behavior"]
)
self.expire_grace_period: int = data["expire_grace_period"]
self.synced_at: datetime.datetime = parse_time(data["synced_at"])
self._role_id: Optional[int] = _get_as_snowflake(data, "role_id")
self.syncing: bool = data["syncing"]
self.enable_emoticons: bool = data["enable_emoticons"]
self.subscriber_count: int = data["subscriber_count"]
@property
def expire_behavior(self) -> ExpireBehaviour:
@ -252,15 +254,17 @@ class StreamIntegration(Integration):
payload: Dict[str, Any] = {}
if expire_behaviour is not MISSING:
if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour')
raise InvalidArgument(
"expire_behaviour field must be of type ExpireBehaviour"
)
payload['expire_behavior'] = expire_behaviour.value
payload["expire_behavior"] = expire_behaviour.value
if expire_grace_period is not MISSING:
payload['expire_grace_period'] = expire_grace_period
payload["expire_grace_period"] = expire_grace_period
if enable_emoticons is not MISSING:
payload['enable_emoticons'] = enable_emoticons
payload["enable_emoticons"] = enable_emoticons
# This endpoint is undocumented.
# Unsure if it returns the data or not as a result
@ -307,21 +311,21 @@ class IntegrationApplication:
"""
__slots__ = (
'id',
'name',
'icon',
'description',
'summary',
'user',
"id",
"name",
"icon",
"description",
"summary",
"user",
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
self.description: str = data['description']
self.summary: str = data['summary']
user = data.get('bot')
self.id: int = int(data["id"])
self.name: str = data["name"]
self.icon: Optional[str] = data["icon"]
self.description: str = data["description"]
self.summary: str = data["summary"]
user = data.get("bot")
self.user: Optional[User] = User(state=state, data=user) if user else None
@ -350,17 +354,19 @@ class BotIntegration(Integration):
The application tied to this integration.
"""
__slots__ = ('application',)
__slots__ = ("application",)
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state)
self.application = IntegrationApplication(
data=data["application"], state=self._state
)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
if value == 'discord':
if value == "discord":
return BotIntegration, value
elif value in ('twitch', 'youtube'):
elif value in ("twitch", "youtube"):
return StreamIntegration, value
else:
return Integration, value

View File

@ -41,9 +41,9 @@ from .permissions import Permissions
from .webhook.async_ import async_context, Webhook, handle_message_parameters
__all__ = (
'Interaction',
'InteractionMessage',
'InteractionResponse',
"Interaction",
"InteractionMessage",
"InteractionResponse",
)
if TYPE_CHECKING:
@ -58,11 +58,24 @@ if TYPE_CHECKING:
from aiohttp import ClientSession
from .embeds import Embed
from .ui.view import View
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .channel import (
VoiceChannel,
StageChannel,
TextChannel,
CategoryChannel,
StoreChannel,
PartialMessageable,
)
from .threads import Thread
InteractionChannel = Union[
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
VoiceChannel,
StageChannel,
TextChannel,
CategoryChannel,
StoreChannel,
Thread,
PartialMessageable,
]
MISSING: Any = utils.MISSING
@ -100,23 +113,23 @@ class Interaction:
"""
__slots__: Tuple[str, ...] = (
'id',
'type',
'guild_id',
'channel_id',
'data',
'application_id',
'message',
'user',
'token',
'version',
'_permissions',
'_state',
'_session',
'_original_message',
'_cs_response',
'_cs_followup',
'_cs_channel',
"id",
"type",
"guild_id",
"channel_id",
"data",
"application_id",
"message",
"user",
"token",
"version",
"_permissions",
"_state",
"_session",
"_original_message",
"_cs_response",
"_cs_followup",
"_cs_channel",
)
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
@ -126,18 +139,18 @@ class Interaction:
self._from_data(data)
def _from_data(self, data: InteractionPayload):
self.id: int = int(data['id'])
self.type: InteractionType = try_enum(InteractionType, data['type'])
self.data: Optional[InteractionData] = data.get('data')
self.token: str = data['token']
self.version: int = data['version']
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.application_id: int = int(data['application_id'])
self.id: int = int(data["id"])
self.type: InteractionType = try_enum(InteractionType, data["type"])
self.data: Optional[InteractionData] = data.get("data")
self.token: str = data["token"]
self.version: int = data["version"]
self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id")
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.application_id: int = int(data["application_id"])
self.message: Optional[Message]
try:
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore
except KeyError:
self.message = None
@ -148,15 +161,15 @@ class Interaction:
if self.guild_id:
guild = self.guild or Object(id=self.guild_id)
try:
member = data['member'] # type: ignore
member = data["member"] # type: ignore
except KeyError:
pass
else:
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
self._permissions = int(member.get('permissions', 0))
self._permissions = int(member.get("permissions", 0))
else:
try:
self.user = User(state=self._state, data=data['user'])
self.user = User(state=self._state, data=data["user"])
except KeyError:
pass
@ -165,7 +178,7 @@ class Interaction:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id)
@utils.cached_slot_property('_cs_channel')
@utils.cached_slot_property("_cs_channel")
def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
@ -176,8 +189,14 @@ class Interaction:
channel = guild and guild._resolve_channel(self.channel_id)
if channel is None:
if self.channel_id is not None:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
type = (
ChannelType.text
if self.guild_id is not None
else ChannelType.private
)
return PartialMessageable(
state=self._state, id=self.channel_id, type=type
)
return None
return channel
@ -189,7 +208,7 @@ class Interaction:
"""
return Permissions(self._permissions)
@utils.cached_slot_property('_cs_response')
@utils.cached_slot_property("_cs_response")
def response(self) -> InteractionResponse:
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
@ -198,13 +217,13 @@ class Interaction:
"""
return InteractionResponse(self)
@utils.cached_slot_property('_cs_followup')
@utils.cached_slot_property("_cs_followup")
def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = {
'id': self.application_id,
'type': 3,
'token': self.token,
"id": self.application_id,
"type": 3,
"token": self.token,
}
return Webhook.from_state(data=payload, state=self._state)
@ -238,7 +257,7 @@ class Interaction:
# TODO: fix later to not raise?
channel = self.channel
if channel is None:
raise ClientException('Channel for message could not be resolved')
raise ClientException("Channel for message could not be resolved")
adapter = async_context.get()
data = await adapter.get_original_interaction_response(
@ -369,8 +388,8 @@ class InteractionResponse:
"""
__slots__: Tuple[str, ...] = (
'_responded',
'_parent',
"_responded",
"_parent",
)
def __init__(self, parent: Interaction):
@ -416,12 +435,16 @@ class InteractionResponse:
elif parent.type is InteractionType.application_command:
defer_type = InteractionResponseType.deferred_channel_message.value
if ephemeral:
data = {'flags': 64}
data = {"flags": 64}
if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=defer_type, data=data
parent.id,
parent.token,
session=parent._session,
type=defer_type,
data=data,
)
self._responded = True
@ -446,7 +469,10 @@ class InteractionResponse:
if parent.type is InteractionType.ping:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.pong.value,
)
self._responded = True
@ -498,28 +524,28 @@ class InteractionResponse:
raise InteractionResponded(self._parent)
payload: Dict[str, Any] = {
'tts': tts,
"tts": tts,
}
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix embed and embeds keyword arguments')
raise TypeError("cannot mix embed and embeds keyword arguments")
if embed is not MISSING:
embeds = [embed]
if embeds:
if len(embeds) > 10:
raise ValueError('embeds cannot exceed maximum of 10 elements')
payload['embeds'] = [e.to_dict() for e in embeds]
raise ValueError("embeds cannot exceed maximum of 10 elements")
payload["embeds"] = [e.to_dict() for e in embeds]
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
if ephemeral:
payload['flags'] = 64
payload["flags"] = 64
if view is not MISSING:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
parent = self._parent
adapter = async_context.get()
@ -591,12 +617,12 @@ class InteractionResponse:
payload = {}
if content is not MISSING:
if content is None:
payload['content'] = None
payload["content"] = None
else:
payload['content'] = str(content)
payload["content"] = str(content)
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix both embed and embeds keyword arguments')
raise TypeError("cannot mix both embed and embeds keyword arguments")
if embed is not MISSING:
if embed is None:
@ -605,17 +631,17 @@ class InteractionResponse:
embeds = [embed]
if embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds]
payload["embeds"] = [e.to_dict() for e in embeds]
if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments]
payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING:
state.prevent_view_updates_for(message_id)
if view is None:
payload['components'] = []
payload["components"] = []
else:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
@ -633,7 +659,7 @@ class InteractionResponse:
class _InteractionMessageState:
__slots__ = ('_parent', '_interaction')
__slots__ = ("_parent", "_interaction")
def __init__(self, interaction: Interaction, parent: ConnectionState):
self._interaction: Interaction = interaction

View File

@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum
from .appinfo import PartialAppInfo
__all__ = (
'PartialInviteChannel',
'PartialInviteGuild',
'Invite',
"PartialInviteChannel",
"PartialInviteGuild",
"Invite",
)
if TYPE_CHECKING:
@ -52,8 +52,8 @@ if TYPE_CHECKING:
from .abc import GuildChannel
from .user import User
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object]
InviteGuildType = Union[Guild, "PartialInviteGuild", Object]
InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object]
import datetime
@ -92,23 +92,25 @@ class PartialInviteChannel:
The partial channel's type.
"""
__slots__ = ('id', 'name', 'type')
__slots__ = ("id", "name", "type")
def __init__(self, data: InviteChannelPayload):
self.id: int = int(data['id'])
self.name: str = data['name']
self.type: ChannelType = try_enum(ChannelType, data['type'])
self.id: int = int(data["id"])
self.name: str = data["name"]
self.type: ChannelType = try_enum(ChannelType, data["type"])
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>'
return (
f"<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>"
)
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime.datetime:
@ -154,26 +156,38 @@ class PartialInviteGuild:
The partial guild's description.
"""
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description')
__slots__ = (
"_state",
"features",
"_icon",
"_banner",
"id",
"name",
"_splash",
"verification_level",
"description",
)
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state
self.id: int = id
self.name: str = data['name']
self.features: List[str] = data.get('features', [])
self._icon: Optional[str] = data.get('icon')
self._banner: Optional[str] = data.get('banner')
self._splash: Optional[str] = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description')
self.name: str = data["name"]
self.features: List[str] = data.get("features", [])
self._icon: Optional[str] = data.get("icon")
self._banner: Optional[str] = data.get("banner")
self._splash: Optional[str] = data.get("splash")
self.verification_level: VerificationLevel = try_enum(
VerificationLevel, data.get("verification_level")
)
self.description: Optional[str] = data.get("description")
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} '
f'description={self.description!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} "
f"description={self.description!r}>"
)
@property
@ -193,17 +207,21 @@ class PartialInviteGuild:
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None:
return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners')
return Asset._from_guild_image(
self._state, self.id, self._banner, path="banners"
)
@property
def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None:
return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
return Asset._from_guild_image(
self._state, self.id, self._splash, path="splashes"
)
I = TypeVar('I', bound='Invite')
I = TypeVar("I", bound="Invite")
class Invite(Hashable):
@ -308,26 +326,26 @@ class Invite(Hashable):
"""
__slots__ = (
'max_age',
'code',
'guild',
'revoked',
'created_at',
'uses',
'temporary',
'max_uses',
'inviter',
'channel',
'target_user',
'target_type',
'_state',
'approximate_member_count',
'approximate_presence_count',
'target_application',
'expires_at',
"max_age",
"code",
"guild",
"revoked",
"created_at",
"uses",
"temporary",
"max_uses",
"inviter",
"channel",
"target_user",
"target_type",
"_state",
"approximate_member_count",
"approximate_presence_count",
"target_application",
"expires_at",
)
BASE = 'https://discord.gg'
BASE = "https://discord.gg"
def __init__(
self,
@ -338,45 +356,67 @@ class Invite(Hashable):
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
):
self._state: ConnectionState = state
self.max_age: Optional[int] = data.get('max_age')
self.code: str = data['code']
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
self.revoked: Optional[bool] = data.get('revoked')
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.temporary: Optional[bool] = data.get('temporary')
self.uses: Optional[int] = data.get('uses')
self.max_uses: Optional[int] = data.get('max_uses')
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
self.max_age: Optional[int] = data.get("max_age")
self.code: str = data["code"]
self.guild: Optional[InviteGuildType] = self._resolve_guild(
data.get("guild"), guild
)
self.revoked: Optional[bool] = data.get("revoked")
self.created_at: Optional[datetime.datetime] = parse_time(
data.get("created_at")
)
self.temporary: Optional[bool] = data.get("temporary")
self.uses: Optional[int] = data.get("uses")
self.max_uses: Optional[int] = data.get("max_uses")
self.approximate_presence_count: Optional[int] = data.get(
"approximate_presence_count"
)
self.approximate_member_count: Optional[int] = data.get(
"approximate_member_count"
)
expires_at = data.get('expires_at', None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
expires_at = data.get("expires_at", None)
self.expires_at: Optional[datetime.datetime] = (
parse_time(expires_at) if expires_at else None
)
inviter_data = data.get('inviter')
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
inviter_data = data.get("inviter")
self.inviter: Optional[User] = (
None if inviter_data is None else self._state.create_user(inviter_data)
)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
self.channel: Optional[InviteChannelType] = self._resolve_channel(
data.get("channel"), channel
)
target_user_data = data.get('target_user')
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
target_user_data = data.get("target_user")
self.target_user: Optional[User] = (
None
if target_user_data is None
else self._state.create_user(target_user_data)
)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
self.target_type: InviteTarget = try_enum(
InviteTarget, data.get("target_type", 0)
)
application = data.get('target_application')
application = data.get("target_application")
self.target_application: Optional[PartialAppInfo] = (
PartialAppInfo(data=application, state=state) if application else None
)
@classmethod
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I:
def from_incomplete(
cls: Type[I], *, state: ConnectionState, data: InvitePayload
) -> I:
guild: Optional[Union[Guild, PartialInviteGuild]]
try:
guild_data = data['guild']
guild_data = data["guild"]
except KeyError:
# If we're here, then this is a group DM
guild = None
else:
guild_id = int(guild_data['id'])
guild_id = int(guild_data["id"])
guild = state._get_guild(guild_id)
if guild is None:
# If it's not cached, then it has to be a partial guild
@ -384,7 +424,9 @@ class Invite(Hashable):
# As far as I know, invites always need a channel
# So this should never raise.
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel'])
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(
data["channel"]
)
if guild is not None and not isinstance(guild, PartialInviteGuild):
# Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel
@ -392,10 +434,12 @@ class Invite(Hashable):
return cls(state=state, data=data, guild=guild, channel=channel)
@classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
def from_gateway(
cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload
) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id'])
channel_id = int(data["channel_id"])
if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
else:
@ -415,7 +459,7 @@ class Invite(Hashable):
if data is None:
return None
guild_id = int(data['id'])
guild_id = int(data["id"])
return PartialInviteGuild(self._state, data, guild_id)
def _resolve_channel(
@ -439,9 +483,9 @@ class Invite(Hashable):
def __repr__(self) -> str:
return (
f'<Invite code={self.code!r} guild={self.guild!r} '
f'online={self.approximate_presence_count} '
f'members={self.approximate_member_count}>'
f"<Invite code={self.code!r} guild={self.guild!r} "
f"online={self.approximate_presence_count} "
f"members={self.approximate_member_count}>"
)
def __hash__(self) -> int:
@ -455,7 +499,7 @@ class Invite(Hashable):
@property
def url(self) -> str:
""":class:`str`: A property that retrieves the invite URL."""
return self.BASE + '/' + self.code
return self.BASE + "/" + self.code
async def delete(self, *, reason: Optional[str] = None):
"""|coro|

View File

@ -26,7 +26,17 @@ from __future__ import annotations
import asyncio
import datetime
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
from typing import (
Awaitable,
TYPE_CHECKING,
TypeVar,
Optional,
Any,
Callable,
Union,
List,
AsyncIterator,
)
from .errors import NoMoreItems
from .utils import snowflake_time, time_snowflake, maybe_coroutine
@ -34,11 +44,11 @@ from .object import Object
from .audit_logs import AuditLogEntry
__all__ = (
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
"ReactionIterator",
"HistoryIterator",
"AuditLogIterator",
"GuildIterator",
"MemberIterator",
)
if TYPE_CHECKING:
@ -67,8 +77,8 @@ if TYPE_CHECKING:
from .threads import Thread
from .abc import Snowflake
T = TypeVar('T')
OT = TypeVar('OT')
T = TypeVar("T")
OT = TypeVar("OT")
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
@ -83,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
nested = attr.split("__")
obj = elem
for attribute in nested:
obj = getattr(obj, attribute)
@ -107,7 +117,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.')
raise ValueError("async iterator chunk sizes must be greater than 0.")
return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@ -182,7 +192,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@ -218,14 +228,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
if data:
self.limit -= retrieve
self.after = Object(id=int(data[-1]['id']))
self.after = Object(id=int(data[-1]["id"]))
if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data):
await self.users.put(User(state=self.state, data=element))
else:
for element in reversed(data):
member_id = int(element['id'])
member_id = int(element["id"])
member = self.guild.get_member(member_id)
if member is not None:
await self.users.put(member)
@ -233,7 +243,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']):
class HistoryIterator(_AsyncIterator["Message"]):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@ -267,7 +277,15 @@ class HistoryIterator(_AsyncIterator['Message']):
``True`` if `after` is specified, otherwise ``False``.
"""
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
def __init__(
self,
messageable,
limit,
before=None,
after=None,
around=None,
oldest_first=None,
):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@ -295,28 +313,30 @@ class HistoryIterator(_AsyncIterator['Message']):
if self.around:
if self.limit is None:
raise ValueError('history does not support around with limit=None')
raise ValueError("history does not support around with limit=None")
if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
raise ValueError(
"history max limit 101 when specifying around parameter"
)
elif self.limit == 101:
self.limit = 100 # Thanks discord
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
elif self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < int(m['id'])
self._filter = lambda m: self.after.id < int(m["id"])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def next(self) -> Message:
if self.messages.empty():
@ -337,7 +357,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return r > 0
async def fill_messages(self):
if not hasattr(self, 'channel'):
if not hasattr(self, "channel"):
# do the required set up
channel = await self.messageable._get_channel()
self.channel = channel
@ -354,7 +374,9 @@ class HistoryIterator(_AsyncIterator['Message']):
channel = self.channel
for element in data:
await self.messages.put(self.state.create_message(channel=channel, data=element))
await self.messages.put(
self.state.create_message(channel=channel, data=element)
)
async def _retrieve_messages(self, retrieve) -> List[Message]:
"""Retrieve messages and update next parameters."""
@ -363,35 +385,50 @@ class HistoryIterator(_AsyncIterator['Message']):
async def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter."""
before = self.before.id if self.before else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
data: List[MessagePayload] = await self.logs_from(
self.channel.id, retrieve, before=before
)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter."""
after = self.after.id if self.after else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
data: List[MessagePayload] = await self.logs_from(
self.channel.id, retrieve, after=after
)
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
async def _retrieve_messages_around_strategy(self, retrieve):
"""Retrieve messages using around parameter."""
if self.around:
around = self.around.id if self.around else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
data: List[MessagePayload] = await self.logs_from(
self.channel.id, retrieve, around=around
)
self.around = None
return data
return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
def __init__(
self,
guild,
limit=None,
before=None,
after=None,
oldest_first=None,
user_id=None,
action_type=None,
):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
@ -420,36 +457,44 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
self.guild.id,
limit=retrieve,
user_id=self.user_id,
action_type=self.action_type,
before=before,
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries
self.before = Object(id=int(entries[-1]["id"]))
return data.get("users", []), entries
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
self.guild.id,
limit=retrieve,
user_id=self.user_id,
action_type=self.action_type,
after=after,
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
self.after = Object(id=int(entries[0]["id"]))
return data.get("users", []), entries
async def next(self) -> AuditLogEntry:
if self.entries.empty():
@ -488,13 +533,15 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
for element in data:
# TODO: remove this if statement later
if element['action_type'] is None:
if element["action_type"] is None:
continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
await self.entries.put(
AuditLogEntry(data=element, users=self._users, guild=self.guild)
)
class GuildIterator(_AsyncIterator['Guild']):
class GuildIterator(_AsyncIterator["Guild"]):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
@ -543,7 +590,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
@ -595,7 +642,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_guilds_after_strategy(self, retrieve):
@ -605,11 +652,11 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
class MemberIterator(_AsyncIterator['Member']):
class MemberIterator(_AsyncIterator["Member"]):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
@ -652,7 +699,7 @@ class MemberIterator(_AsyncIterator['Member']):
if len(data) < 1000:
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id']))
self.after = Object(id=int(data[-1]["user"]["id"]))
for element in reversed(data):
await self.members.put(self.create_member(element))
@ -663,7 +710,7 @@ class MemberIterator(_AsyncIterator['Member']):
return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']):
class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def __init__(
self,
channel_id: int,
@ -681,7 +728,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
self.http = guild._state.http
if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads')
raise ValueError("Cannot iterate over joined public archived threads")
self.before: Optional[str]
if before is None:
@ -721,11 +768,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp']
return data["thread_metadata"]["archive_timestamp"]
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore
return data["id"] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
@ -735,11 +782,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', [])
threads: List[ThreadPayload] = data.get("threads", [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False)
self.has_more = data.get("has_more", False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
@ -750,4 +797,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)

View File

@ -29,7 +29,19 @@ import inspect
import itertools
import sys
from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import discord.abc
@ -44,8 +56,8 @@ from .colour import Colour
from .object import Object
__all__ = (
'VoiceState',
'Member',
"VoiceState",
"Member",
)
if TYPE_CHECKING:
@ -113,52 +125,58 @@ class VoiceState:
"""
__slots__ = (
'session_id',
'deaf',
'mute',
'self_mute',
'self_stream',
'self_video',
'self_deaf',
'afk',
'channel',
'requested_to_speak_at',
'suppress',
"session_id",
"deaf",
"mute",
"self_mute",
"self_stream",
"self_video",
"self_deaf",
"afk",
"channel",
"requested_to_speak_at",
"suppress",
)
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get('session_id')
def __init__(
self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None
):
self.session_id: str = data.get("session_id")
self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
self.self_video: bool = data.get('self_video', False)
self.afk: bool = data.get('suppress', False)
self.mute: bool = data.get('mute', False)
self.deaf: bool = data.get('deaf', False)
self.suppress: bool = data.get('suppress', False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp'))
self.self_mute: bool = data.get("self_mute", False)
self.self_deaf: bool = data.get("self_deaf", False)
self.self_stream: bool = data.get("self_stream", False)
self.self_video: bool = data.get("self_video", False)
self.afk: bool = data.get("suppress", False)
self.mute: bool = data.get("mute", False)
self.deaf: bool = data.get("deaf", False)
self.suppress: bool = data.get("suppress", False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(
data.get("request_to_speak_timestamp")
)
self.channel: Optional[VocalGuildChannel] = channel
def __repr__(self) -> str:
attrs = [
('self_mute', self.self_mute),
('self_deaf', self.self_deaf),
('self_stream', self.self_stream),
('suppress', self.suppress),
('requested_to_speak_at', self.requested_to_speak_at),
('channel', self.channel),
("self_mute", self.self_mute),
("self_deaf", self.self_deaf),
("self_stream", self.self_stream),
("suppress", self.suppress),
("requested_to_speak_at", self.requested_to_speak_at),
("channel", self.channel),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
inner = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
for attr, value in itertools.chain(
BaseUser.__dict__.items(), User.__dict__.items()
):
# ignore private/special methods
if attr.startswith('_'):
if attr.startswith("_"):
continue
# don't override what we already have
@ -167,9 +185,11 @@ def flatten_user(cls):
# if it's a slotted attribute or a property, redirect it
# slotted members are implemented as member_descriptors in Type.__dict__
if not hasattr(value, '__annotations__'):
getter = attrgetter('_user.' + attr)
setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`'))
if not hasattr(value, "__annotations__"):
getter = attrgetter("_user." + attr)
setattr(
cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`")
)
else:
# Technically, this can also use attrgetter
# However I'm not sure how I feel about "functions" returning properties
@ -197,7 +217,7 @@ def flatten_user(cls):
return cls
M = TypeVar('M', bound='Member')
M = TypeVar("M", bound="Member")
@flatten_user
@ -258,17 +278,17 @@ class Member(discord.abc.Messageable, _UserTag):
"""
__slots__ = (
'_roles',
'joined_at',
'premium_since',
'activities',
'guild',
'pending',
'nick',
'_client_status',
'_user',
'_state',
'_avatar',
"_roles",
"joined_at",
"premium_since",
"activities",
"guild",
"pending",
"nick",
"_client_status",
"_user",
"_state",
"_avatar",
)
if TYPE_CHECKING:
@ -288,18 +308,24 @@ class Member(discord.abc.Messageable, _UserTag):
accent_color: Optional[Colour]
accent_colour: Optional[Colour]
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
def __init__(
self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState
):
self._state: ConnectionState = state
self._user: User = state.store_user(data['user'])
self._user: User = state.store_user(data["user"])
self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
self._client_status: Dict[Optional[str], str] = {None: 'offline'}
self.joined_at: Optional[datetime.datetime] = utils.parse_time(
data.get("joined_at")
)
self.premium_since: Optional[datetime.datetime] = utils.parse_time(
data.get("premium_since")
)
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"]))
self._client_status: Dict[Optional[str], str] = {None: "offline"}
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
self._avatar: Optional[str] = data.get('avatar')
self.nick: Optional[str] = data.get("nick", None)
self.pending: bool = data.get("pending", False)
self._avatar: Optional[str] = data.get("avatar")
def __str__(self) -> str:
return str(self._user)
@ -309,8 +335,8 @@ class Member(discord.abc.Messageable, _UserTag):
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
f"<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}"
f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>"
)
def __eq__(self, other: Any) -> bool:
@ -325,25 +351,31 @@ class Member(discord.abc.Messageable, _UserTag):
@classmethod
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
author = message.author
data['user'] = author._to_minimal_user_json() # type: ignore
data["user"] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
self.joined_at = utils.parse_time(data.get("joined_at"))
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self.nick = data.get("nick", None)
self.pending = data.get("pending", False)
@classmethod
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]:
def _try_upgrade(
cls: Type[M],
*,
data: UserWithMemberPayload,
guild: Guild,
state: ConnectionState,
) -> Union[User, M]:
# A User object with a 'member' key
try:
member_data = data.pop('member')
member_data = data.pop("member")
except KeyError:
return state.create_user(data)
else:
member_data['user'] = data # type: ignore
member_data["user"] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore
@classmethod
@ -374,25 +406,27 @@ class Member(discord.abc.Messageable, _UserTag):
# the nickname change is optional,
# if it isn't in the payload then it didn't change
try:
self.nick = data['nick']
self.nick = data["nick"]
except KeyError:
pass
try:
self.pending = data['pending']
self.pending = data["pending"]
except KeyError:
pass
self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self._avatar = data.get('avatar')
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self._avatar = data.get("avatar")
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data['activities']))
def _presence_update(
self, data: PartialPresenceUpdate, user: UserPayload
) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data["activities"]))
self._client_status = {
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore
sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore
}
self._client_status[None] = sys.intern(data['status'])
self._client_status[None] = sys.intern(data["status"])
if len(user) > 1:
return self._update_inner_user(user)
@ -402,7 +436,12 @@ class Member(discord.abc.Messageable, _UserTag):
u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags)
# These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
modified = (
user["username"],
user["avatar"],
user["discriminator"],
user.get("public_flags", 0),
)
if original != modified:
to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified
@ -430,21 +469,21 @@ class Member(discord.abc.Messageable, _UserTag):
@property
def mobile_status(self) -> Status:
""":class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get('mobile', 'offline'))
return try_enum(Status, self._client_status.get("mobile", "offline"))
@property
def desktop_status(self) -> Status:
""":class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get('desktop', 'offline'))
return try_enum(Status, self._client_status.get("desktop", "offline"))
@property
def web_status(self) -> Status:
""":class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get('web', 'offline'))
return try_enum(Status, self._client_status.get("web", "offline"))
def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
return 'mobile' in self._client_status
return "mobile" in self._client_status
@property
def colour(self) -> Colour:
@ -497,8 +536,8 @@ class Member(discord.abc.Messageable, _UserTag):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the member."""
if self.nick:
return f'<@!{self._user.id}>'
return f'<@{self._user.id}>'
return f"<@!{self._user.id}>"
return f"<@{self._user.id}>"
@property
def display_name(self) -> str:
@ -531,7 +570,9 @@ class Member(discord.abc.Messageable, _UserTag):
"""
if self._avatar is None:
return None
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
return Asset._from_guild_avatar(
self._state, self.guild.id, self.id, self._avatar
)
@property
def activity(self) -> Optional[ActivityTypes]:
@ -625,7 +666,9 @@ class Member(discord.abc.Messageable, _UserTag):
Bans this member. Equivalent to :meth:`Guild.ban`.
"""
await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days)
await self.guild.ban(
self, reason=reason, delete_message_days=delete_message_days
)
async def unban(self, *, reason: Optional[str] = None) -> None:
"""|coro|
@ -720,39 +763,41 @@ class Member(discord.abc.Messageable, _UserTag):
payload: Dict[str, Any] = {}
if nick is not MISSING:
nick = nick or ''
nick = nick or ""
if me:
await http.change_my_nickname(guild_id, nick, reason=reason)
else:
payload['nick'] = nick
payload["nick"] = nick
if deafen is not MISSING:
payload['deaf'] = deafen
payload["deaf"] = deafen
if mute is not MISSING:
payload['mute'] = mute
payload["mute"] = mute
if suppress is not MISSING:
voice_state_payload = {
'channel_id': self.voice.channel.id,
'suppress': suppress,
"channel_id": self.voice.channel.id,
"suppress": suppress,
}
if suppress or self.bot:
voice_state_payload['request_to_speak_timestamp'] = None
voice_state_payload["request_to_speak_timestamp"] = None
if me:
await http.edit_my_voice_state(guild_id, voice_state_payload)
else:
if not suppress:
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat()
voice_state_payload[
"request_to_speak_timestamp"
] = datetime.datetime.utcnow().isoformat()
await http.edit_voice_state(guild_id, self.id, voice_state_payload)
if voice_channel is not MISSING:
payload['channel_id'] = voice_channel and voice_channel.id
payload["channel_id"] = voice_channel and voice_channel.id
if roles is not MISSING:
payload['roles'] = tuple(r.id for r in roles)
payload["roles"] = tuple(r.id for r in roles)
if payload:
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
@ -780,17 +825,19 @@ class Member(discord.abc.Messageable, _UserTag):
The operation failed.
"""
payload = {
'channel_id': self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),
"channel_id": self.voice.channel.id,
"request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(),
}
if self._state.self_id != self.id:
payload['suppress'] = False
payload["suppress"] = False
await self._state.http.edit_voice_state(self.guild.id, self.id, payload)
else:
await self._state.http.edit_my_voice_state(self.guild.id, payload)
async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None:
async def move_to(
self, channel: VocalGuildChannel, *, reason: Optional[str] = None
) -> None:
"""|coro|
Moves a member to a new voice channel (they must be connected first).
@ -813,7 +860,9 @@ class Member(discord.abc.Messageable, _UserTag):
"""
await self.edit(voice_channel=channel, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
async def add_roles(
self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True
) -> None:
r"""|coro|
Gives the member a number of :class:`Role`\s.
@ -843,7 +892,9 @@ class Member(discord.abc.Messageable, _UserTag):
"""
if not atomic:
new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s)
new_roles = utils._unique(
Object(id=r.id) for s in (self.roles[1:], roles) for r in s
)
await self.edit(roles=new_roles, reason=reason)
else:
req = self._state.http.add_role
@ -852,7 +903,9 @@ class Member(discord.abc.Messageable, _UserTag):
for role in roles:
await req(guild_id, user_id, role.id, reason=reason)
async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
async def remove_roles(
self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True
) -> None:
r"""|coro|
Removes :class:`Role`\s from this member.

View File

@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
__all__ = (
'AllowedMentions',
)
__all__ = ("AllowedMentions",)
if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload
@ -36,7 +34,7 @@ if TYPE_CHECKING:
class _FakeBool:
def __repr__(self):
return 'True'
return "True"
def __eq__(self, other):
return other is True
@ -47,7 +45,7 @@ class _FakeBool:
default: Any = _FakeBool()
A = TypeVar('A', bound='AllowedMentions')
A = TypeVar("A", bound="AllowedMentions")
class AllowedMentions:
@ -80,7 +78,7 @@ class AllowedMentions:
.. versionadded:: 1.6
"""
__slots__ = ('everyone', 'users', 'roles', 'replied_user')
__slots__ = ("everyone", "users", "roles", "replied_user")
def __init__(
self,
@ -116,22 +114,22 @@ class AllowedMentions:
data = {}
if self.everyone:
parse.append('everyone')
parse.append("everyone")
if self.users == True:
parse.append('users')
parse.append("users")
elif self.users != False:
data['users'] = [x.id for x in self.users]
data["users"] = [x.id for x in self.users]
if self.roles == True:
parse.append('roles')
parse.append("roles")
elif self.roles != False:
data['roles'] = [x.id for x in self.roles]
data["roles"] = [x.id for x in self.roles]
if self.replied_user:
data['replied_user'] = True
data["replied_user"] = True
data['parse'] = parse
data["parse"] = parse
return data # type: ignore
def merge(self, other: AllowedMentions) -> AllowedMentions:
@ -141,11 +139,15 @@ class AllowedMentions:
everyone = self.everyone if other.everyone is default else other.everyone
users = self.users if other.users is default else other.users
roles = self.roles if other.roles is default else other.roles
replied_user = self.replied_user if other.replied_user is default else other.replied_user
return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user)
replied_user = (
self.replied_user if other.replied_user is default else other.replied_user
)
return AllowedMentions(
everyone=everyone, roles=roles, users=users, replied_user=replied_user
)
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(everyone={self.everyone}, '
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})'
f"{self.__class__.__name__}(everyone={self.everyone}, "
f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})"
)

View File

@ -29,7 +29,21 @@ import datetime
import re
import io
from os import PathLike
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
from typing import (
Dict,
TYPE_CHECKING,
Union,
List,
Optional,
Any,
Callable,
Tuple,
ClassVar,
Optional,
overload,
TypeVar,
Type,
)
from . import utils
from .reaction import Reaction
@ -76,15 +90,15 @@ if TYPE_CHECKING:
from .role import Role
from .ui.view import View
MR = TypeVar('MR', bound='MessageReference')
MR = TypeVar("MR", bound="MessageReference")
EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = (
'Attachment',
'Message',
'PartialMessage',
'MessageReference',
'DeletedReferencedMessage',
"Attachment",
"Message",
"PartialMessage",
"MessageReference",
"DeletedReferencedMessage",
)
@ -93,15 +107,17 @@ def convert_emoji_reaction(emoji):
emoji = emoji.emoji
if isinstance(emoji, Emoji):
return f'{emoji.name}:{emoji.id}'
return f"{emoji.name}:{emoji.id}"
if isinstance(emoji, PartialEmoji):
return emoji._as_reaction()
if isinstance(emoji, str):
# Reactions can be in :name:id format, but not <:name:id>.
# No existing emojis have <> in them, so this should be okay.
return emoji.strip('<>')
return emoji.strip("<>")
raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.')
raise InvalidArgument(
f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}."
)
class Attachment(Hashable):
@ -157,28 +173,38 @@ class Attachment(Hashable):
.. versionadded:: 1.7
"""
__slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type')
__slots__ = (
"id",
"size",
"height",
"width",
"filename",
"url",
"proxy_url",
"_http",
"content_type",
)
def __init__(self, *, data: AttachmentPayload, state: ConnectionState):
self.id: int = int(data['id'])
self.size: int = data['size']
self.height: Optional[int] = data.get('height')
self.width: Optional[int] = data.get('width')
self.filename: str = data['filename']
self.url: str = data.get('url')
self.proxy_url: str = data.get('proxy_url')
self.id: int = int(data["id"])
self.size: int = data["size"]
self.height: Optional[int] = data.get("height")
self.width: Optional[int] = data.get("width")
self.filename: str = data["filename"]
self.url: str = data.get("url")
self.proxy_url: str = data.get("proxy_url")
self._http = state.http
self.content_type: Optional[str] = data.get('content_type')
self.content_type: Optional[str] = data.get("content_type")
def is_spoiler(self) -> bool:
""":class:`bool`: Whether this attachment contains a spoiler."""
return self.filename.startswith('SPOILER_')
return self.filename.startswith("SPOILER_")
def __repr__(self) -> str:
return f'<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>'
return f"<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>"
def __str__(self) -> str:
return self.url or ''
return self.url or ""
async def save(
self,
@ -227,7 +253,7 @@ class Attachment(Hashable):
fp.seek(0)
return written
else:
with open(fp, 'wb') as f:
with open(fp, "wb") as f:
return f.write(data)
async def read(self, *, use_cached: bool = False) -> bytes:
@ -309,19 +335,19 @@ class Attachment(Hashable):
def to_dict(self) -> AttachmentPayload:
result: AttachmentPayload = {
'filename': self.filename,
'id': self.id,
'proxy_url': self.proxy_url,
'size': self.size,
'url': self.url,
'spoiler': self.is_spoiler(),
"filename": self.filename,
"id": self.id,
"proxy_url": self.proxy_url,
"size": self.size,
"url": self.url,
"spoiler": self.is_spoiler(),
}
if self.height:
result['height'] = self.height
result["height"] = self.height
if self.width:
result['width'] = self.width
result["width"] = self.width
if self.content_type:
result['content_type'] = self.content_type
result["content_type"] = self.content_type
return result
@ -335,7 +361,7 @@ class DeletedReferencedMessage:
.. versionadded:: 1.6
"""
__slots__ = ('_parent',)
__slots__ = ("_parent",)
def __init__(self, parent: MessageReference):
self._parent: MessageReference = parent
@ -347,7 +373,7 @@ class DeletedReferencedMessage:
def id(self) -> int:
""":class:`int`: The message ID of the deleted referenced message."""
# the parent's message id won't be None here
return self._parent.message_id # type: ignore
return self._parent.message_id # type: ignore
@property
def channel_id(self) -> int:
@ -394,9 +420,23 @@ class MessageReference:
.. versionadded:: 1.6
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state')
__slots__ = (
"message_id",
"channel_id",
"guild_id",
"fail_if_not_exists",
"resolved",
"_state",
)
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
def __init__(
self,
*,
message_id: int,
channel_id: int,
guild_id: Optional[int] = None,
fail_if_not_exists: bool = True,
):
self._state: Optional[ConnectionState] = None
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
self.message_id: Optional[int] = message_id
@ -405,18 +445,22 @@ class MessageReference:
self.fail_if_not_exists: bool = fail_if_not_exists
@classmethod
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
def with_state(
cls: Type[MR], state: ConnectionState, data: MessageReferencePayload
) -> MR:
self = cls.__new__(cls)
self.message_id = utils._get_as_snowflake(data, 'message_id')
self.channel_id = int(data.pop('channel_id'))
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.fail_if_not_exists = data.get('fail_if_not_exists', True)
self.message_id = utils._get_as_snowflake(data, "message_id")
self.channel_id = int(data.pop("channel_id"))
self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.fail_if_not_exists = data.get("fail_if_not_exists", True)
self._state = state
self.resolved = None
return self
@classmethod
def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR:
def from_message(
cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True
) -> MR:
"""Creates a :class:`MessageReference` from an existing :class:`~discord.Message`.
.. versionadded:: 1.6
@ -439,7 +483,7 @@ class MessageReference:
self = cls(
message_id=message.id,
channel_id=message.channel.id,
guild_id=getattr(message.guild, 'id', None),
guild_id=getattr(message.guild, "id", None),
fail_if_not_exists=fail_if_not_exists,
)
self._state = message._state
@ -456,36 +500,38 @@ class MessageReference:
.. versionadded:: 1.7
"""
guild_id = self.guild_id if self.guild_id is not None else '@me'
return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}'
guild_id = self.guild_id if self.guild_id is not None else "@me"
return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}"
def __repr__(self) -> str:
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
return f"<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>"
def to_dict(self) -> MessageReferencePayload:
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
result['channel_id'] = self.channel_id
result: MessageReferencePayload = (
{"message_id": self.message_id} if self.message_id is not None else {}
)
result["channel_id"] = self.channel_id
if self.guild_id is not None:
result['guild_id'] = self.guild_id
result["guild_id"] = self.guild_id
if self.fail_if_not_exists is not None:
result['fail_if_not_exists'] = self.fail_if_not_exists
result["fail_if_not_exists"] = self.fail_if_not_exists
return result
to_message_reference_dict = to_dict
def flatten_handlers(cls):
prefix = len('_handle_')
prefix = len("_handle_")
handlers = [
(key[prefix:], value)
for key, value in cls.__dict__.items()
if key.startswith('_handle_') and key != '_handle_member'
if key.startswith("_handle_") and key != "_handle_member"
]
# store _handle_member last
handlers.append(('member', cls._handle_member))
handlers.append(("member", cls._handle_member))
cls._HANDLERS = handlers
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')]
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")]
return cls
@ -615,36 +661,36 @@ class Message(Hashable):
"""
__slots__ = (
'_state',
'_edited_timestamp',
'_cs_channel_mentions',
'_cs_raw_mentions',
'_cs_clean_content',
'_cs_raw_channel_mentions',
'_cs_raw_role_mentions',
'_cs_system_content',
'tts',
'content',
'channel',
'webhook_id',
'mention_everyone',
'embeds',
'id',
'mentions',
'author',
'attachments',
'nonce',
'pinned',
'role_mentions',
'type',
'flags',
'reactions',
'reference',
'application',
'activity',
'stickers',
'components',
'guild',
"_state",
"_edited_timestamp",
"_cs_channel_mentions",
"_cs_raw_mentions",
"_cs_clean_content",
"_cs_raw_channel_mentions",
"_cs_raw_role_mentions",
"_cs_system_content",
"tts",
"content",
"channel",
"webhook_id",
"mention_everyone",
"embeds",
"id",
"mentions",
"author",
"attachments",
"nonce",
"pinned",
"role_mentions",
"type",
"flags",
"reactions",
"reference",
"application",
"activity",
"stickers",
"components",
"guild",
)
if TYPE_CHECKING:
@ -664,39 +710,49 @@ class Message(Hashable):
data: MessagePayload,
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id')
self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])]
self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']]
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']]
self.application: Optional[MessageApplicationPayload] = data.get('application')
self.activity: Optional[MessageActivityPayload] = data.get('activity')
self.id: int = int(data["id"])
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id")
self.reactions: List[Reaction] = [
Reaction(message=self, data=d) for d in data.get("reactions", [])
]
self.attachments: List[Attachment] = [
Attachment(data=a, state=self._state) for a in data["attachments"]
]
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]]
self.application: Optional[MessageApplicationPayload] = data.get("application")
self.activity: Optional[MessageActivityPayload] = data.get("activity")
self.channel: MessageableChannel = channel
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp'])
self.type: MessageType = try_enum(MessageType, data['type'])
self.pinned: bool = data['pinned']
self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0))
self.mention_everyone: bool = data['mention_everyone']
self.tts: bool = data['tts']
self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(
data["edited_timestamp"]
)
self.type: MessageType = try_enum(MessageType, data["type"])
self.pinned: bool = data["pinned"]
self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0))
self.mention_everyone: bool = data["mention_everyone"]
self.tts: bool = data["tts"]
self.content: str = data["content"]
self.nonce: Optional[Union[int, str]] = data.get("nonce")
self.stickers: List[StickerItem] = [
StickerItem(data=d, state=state) for d in data.get("sticker_items", [])
]
self.components: List[Component] = [
_component_factory(d) for d in data.get("components", [])
]
try:
# if the channel doesn't have a guild attribute, we handle that
self.guild = channel.guild # type: ignore
except AttributeError:
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id"))
try:
ref = data['message_reference']
ref = data["message_reference"]
except KeyError:
self.reference = None
else:
self.reference = ref = MessageReference.with_state(state, ref)
try:
resolved = data['referenced_message']
resolved = data["referenced_message"]
except KeyError:
pass
else:
@ -712,18 +768,15 @@ class Message(Hashable):
# the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'):
for handler in ("author", "member", "mentions", "mention_roles"):
try:
getattr(self, f'_handle_{handler}')(data[handler])
getattr(self, f"_handle_{handler}")(data[handler])
except KeyError:
continue
def __repr__(self) -> str:
name = self.__class__.__name__
return (
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
)
return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>"
def __str__(self) -> Optional[str]:
return self.content
@ -741,7 +794,7 @@ class Message(Hashable):
def _add_reaction(self, data, emoji, user_id) -> Reaction:
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
is_me = data['me'] = user_id == self._state.self_id
is_me = data["me"] = user_id == self._state.self_id
if reaction is None:
reaction = Reaction(message=self, data=data, emoji=emoji)
@ -753,12 +806,14 @@ class Message(Hashable):
return reaction
def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction:
def _remove_reaction(
self, data: ReactionPayload, emoji: EmojiInputType, user_id: int
) -> Reaction:
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
if reaction is None:
# already removed?
raise ValueError('Emoji already removed?')
raise ValueError("Emoji already removed?")
# if reaction isn't in the list, we crash. This means discord
# sent bad data, or we stored improperly
@ -872,7 +927,7 @@ class Message(Hashable):
return
for mention in filter(None, mentions):
id_search = int(mention['id'])
id_search = int(mention["id"])
member = guild.get_member(id_search)
if member is not None:
r.append(member)
@ -890,11 +945,13 @@ class Message(Hashable):
def _handle_components(self, components: List[ComponentPayload]):
self.components = [_component_factory(d) for d in components]
def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None:
def _rebind_cached_references(
self, new_guild: Guild, new_channel: Union[TextChannel, Thread]
) -> None:
self.guild = new_guild
self.channel = new_channel
@utils.cached_slot_property('_cs_raw_mentions')
@utils.cached_slot_property("_cs_raw_mentions")
def raw_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of user IDs matched with
the syntax of ``<@user_id>`` in the message content.
@ -902,30 +959,30 @@ class Message(Hashable):
This allows you to receive the user IDs of mentioned users
even in a private message context.
"""
return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_channel_mentions')
@utils.cached_slot_property("_cs_raw_channel_mentions")
def raw_channel_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of channel IDs matched with
the syntax of ``<#channel_id>`` in the message content.
"""
return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_role_mentions')
@utils.cached_slot_property("_cs_raw_role_mentions")
def raw_role_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of role IDs matched with
the syntax of ``<@&role_id>`` in the message content.
"""
return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_channel_mentions')
@utils.cached_slot_property("_cs_channel_mentions")
def channel_mentions(self) -> List[GuildChannel]:
if self.guild is None:
return []
it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions))
return utils._unique(it)
@utils.cached_slot_property('_cs_clean_content')
@utils.cached_slot_property("_cs_clean_content")
def clean_content(self) -> str:
""":class:`str`: A property that returns the content in a "cleaned up"
manner. This basically means that mentions are transformed
@ -972,9 +1029,9 @@ class Message(Hashable):
# fmt: on
def repl(obj):
return transformations.get(re.escape(obj.group(0)), '')
return transformations.get(re.escape(obj.group(0)), "")
pattern = re.compile('|'.join(transformations.keys()))
pattern = re.compile("|".join(transformations.keys()))
result = pattern.sub(repl, self.content)
return escape_mentions(result)
@ -991,8 +1048,8 @@ class Message(Hashable):
@property
def jump_url(self) -> str:
""":class:`str`: Returns a URL that allows the client to jump to this message."""
guild_id = getattr(self.guild, 'id', '@me')
return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}'
guild_id = getattr(self.guild, "id", "@me")
return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}"
def is_system(self) -> bool:
""":class:`bool`: Whether the message is a system message.
@ -1009,7 +1066,7 @@ class Message(Hashable):
MessageType.thread_starter_message,
)
@utils.cached_slot_property('_cs_system_content')
@utils.cached_slot_property("_cs_system_content")
def system_content(self):
r""":class:`str`: A property that returns the content that is rendered
regardless of the :attr:`Message.type`.
@ -1024,24 +1081,26 @@ class Message(Hashable):
if self.type is MessageType.recipient_add:
if self.channel.type is ChannelType.group:
return f'{self.author.name} added {self.mentions[0].name} to the group.'
return f"{self.author.name} added {self.mentions[0].name} to the group."
else:
return f'{self.author.name} added {self.mentions[0].name} to the thread.'
return (
f"{self.author.name} added {self.mentions[0].name} to the thread."
)
if self.type is MessageType.recipient_remove:
if self.channel.type is ChannelType.group:
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
return f"{self.author.name} removed {self.mentions[0].name} from the group."
else:
return f'{self.author.name} removed {self.mentions[0].name} from the thread.'
return f"{self.author.name} removed {self.mentions[0].name} from the thread."
if self.type is MessageType.channel_name_change:
return f'{self.author.name} changed the channel name: **{self.content}**'
return f"{self.author.name} changed the channel name: **{self.content}**"
if self.type is MessageType.channel_icon_change:
return f'{self.author.name} changed the channel icon.'
return f"{self.author.name} changed the channel icon."
if self.type is MessageType.pins_add:
return f'{self.author.name} pinned a message to this channel.'
return f"{self.author.name} pinned a message to this channel."
if self.type is MessageType.new_member:
formats = [
@ -1065,64 +1124,66 @@ class Message(Hashable):
if self.type is MessageType.premium_guild_subscription:
if not self.content:
return f'{self.author.name} just boosted the server!'
return f"{self.author.name} just boosted the server!"
else:
return f'{self.author.name} just boosted the server **{self.content}** times!'
return f"{self.author.name} just boosted the server **{self.content}** times!"
if self.type is MessageType.premium_guild_tier_1:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**"
if self.type is MessageType.premium_guild_tier_2:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**"
if self.type is MessageType.premium_guild_tier_3:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**"
if self.type is MessageType.channel_follow_add:
return f'{self.author.name} has added {self.content} to this channel'
return f"{self.author.name} has added {self.content} to this channel"
if self.type is MessageType.guild_stream:
# the author will be a Member
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore
if self.type is MessageType.guild_discovery_disqualified:
return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.'
return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details."
if self.type is MessageType.guild_discovery_requalified:
return 'This server is eligible for Server Discovery again and has been automatically relisted!'
return "This server is eligible for Server Discovery again and has been automatically relisted!"
if self.type is MessageType.guild_discovery_grace_period_initial_warning:
return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.'
return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery."
if self.type is MessageType.guild_discovery_grace_period_final_warning:
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.'
return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery."
if self.type is MessageType.thread_created:
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.'
return f"{self.author.name} started a thread: **{self.content}**. See all **threads**."
if self.type is MessageType.reply:
return self.content
if self.type is MessageType.thread_starter_message:
if self.reference is None or self.reference.resolved is None:
return 'Sorry, we couldn\'t load the first message in this thread'
return "Sorry, we couldn't load the first message in this thread"
# the resolved message for the reference will be a Message
return self.reference.resolved.content # type: ignore
if self.type is MessageType.guild_invite_reminder:
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!"
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
async def delete(
self, *, delay: Optional[float] = None, silent: bool = False
) -> None:
"""|coro|
Deletes the message.
@ -1271,45 +1332,52 @@ class Message(Hashable):
payload: Dict[str, Any] = {}
if content is not MISSING:
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
else:
payload['content'] = None
payload["content"] = None
if embed is not MISSING and embeds is not MISSING:
raise InvalidArgument('cannot pass both embed and embeds parameter to edit()')
raise InvalidArgument(
"cannot pass both embed and embeds parameter to edit()"
)
if embed is not MISSING:
if embed is None:
payload['embeds'] = []
payload["embeds"] = []
else:
payload['embeds'] = [embed.to_dict()]
payload["embeds"] = [embed.to_dict()]
elif embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds]
payload["embeds"] = [e.to_dict() for e in embeds]
if suppress is not MISSING:
flags = MessageFlags._from_value(self.flags.value)
flags.suppress_embeds = suppress
payload['flags'] = flags.value
payload["flags"] = flags.value
if allowed_mentions is MISSING:
if self._state.allowed_mentions is not None and self.author.id == self._state.self_id:
payload['allowed_mentions'] = self._state.allowed_mentions.to_dict()
if (
self._state.allowed_mentions is not None
and self.author.id == self._state.self_id
):
payload["allowed_mentions"] = self._state.allowed_mentions.to_dict()
else:
if allowed_mentions is not None:
if self._state.allowed_mentions is not None:
payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
payload["allowed_mentions"] = self._state.allowed_mentions.merge(
allowed_mentions
).to_dict()
else:
payload['allowed_mentions'] = allowed_mentions.to_dict()
payload["allowed_mentions"] = allowed_mentions.to_dict()
if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments]
payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING:
self._state.prevent_view_updates_for(self.id)
if view:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
else:
payload['components'] = []
payload["components"] = []
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data)
@ -1430,7 +1498,9 @@ class Message(Hashable):
emoji = convert_emoji_reaction(emoji)
await self._state.http.add_reaction(self.channel.id, self.id, emoji)
async def remove_reaction(self, emoji: Union[EmojiInputType, Reaction], member: Snowflake) -> None:
async def remove_reaction(
self, emoji: Union[EmojiInputType, Reaction], member: Snowflake
) -> None:
"""|coro|
Remove a reaction by the member from the message.
@ -1467,7 +1537,9 @@ class Message(Hashable):
if member.id == self._state.self_id:
await self._state.http.remove_own_reaction(self.channel.id, self.id, emoji)
else:
await self._state.http.remove_reaction(self.channel.id, self.id, emoji, member.id)
await self._state.http.remove_reaction(
self.channel.id, self.id, emoji, member.id
)
async def clear_reaction(self, emoji: Union[EmojiInputType, Reaction]) -> None:
"""|coro|
@ -1516,7 +1588,9 @@ class Message(Hashable):
"""
await self._state.http.clear_reactions(self.channel.id, self.id)
async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread:
async def create_thread(
self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING
) -> Thread:
"""|coro|
Creates a public thread from this message.
@ -1551,14 +1625,17 @@ class Message(Hashable):
The created thread.
"""
if self.guild is None:
raise InvalidArgument('This message does not have guild info attached.')
raise InvalidArgument("This message does not have guild info attached.")
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
default_auto_archive_duration: ThreadArchiveDuration = getattr(
self.channel, "default_auto_archive_duration", 1440
)
data = await self._state.http.start_thread_with_message(
self.channel.id,
self.id,
name=name,
auto_archive_duration=auto_archive_duration or default_auto_archive_duration,
auto_archive_duration=auto_archive_duration
or default_auto_archive_duration,
)
return Thread(guild=self.guild, state=self._state, data=data)
@ -1607,16 +1684,18 @@ class Message(Hashable):
The reference to this message.
"""
return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists)
return MessageReference.from_message(
self, fail_if_not_exists=fail_if_not_exists
)
def to_message_reference_dict(self) -> MessageReferencePayload:
data: MessageReferencePayload = {
'message_id': self.id,
'channel_id': self.channel.id,
"message_id": self.id,
"channel_id": self.channel.id,
}
if self.guild is not None:
data['guild_id'] = self.guild.id
data["guild_id"] = self.guild.id
return data
@ -1662,7 +1741,7 @@ class PartialMessage(Hashable):
The message ID.
"""
__slots__ = ('channel', 'id', '_cs_guild', '_state')
__slots__ = ("channel", "id", "_cs_guild", "_state")
jump_url: str = Message.jump_url # type: ignore
delete = Message.delete
@ -1686,7 +1765,9 @@ class PartialMessage(Hashable):
ChannelType.public_thread,
ChannelType.private_thread,
):
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}')
raise TypeError(
f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}"
)
self.channel: PartialMessageableChannel = channel
self._state: ConnectionState = channel._state
@ -1702,17 +1783,17 @@ class PartialMessage(Hashable):
pinned = property(None, lambda x, y: None)
def __repr__(self) -> str:
return f'<PartialMessage id={self.id} channel={self.channel!r}>'
return f"<PartialMessage id={self.id} channel={self.channel!r}>"
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: The partial message's creation time in UTC."""
return utils.snowflake_time(self.id)
@utils.cached_slot_property('_cs_guild')
@utils.cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable."""
return getattr(self.channel, 'guild', None)
return getattr(self.channel, "guild", None)
async def fetch(self) -> Message:
"""|coro|
@ -1794,58 +1875,62 @@ class PartialMessage(Hashable):
"""
try:
content = fields['content']
content = fields["content"]
except KeyError:
pass
else:
if content is not None:
fields['content'] = str(content)
fields["content"] = str(content)
try:
embed = fields['embed']
embed = fields["embed"]
except KeyError:
pass
else:
if embed is not None:
fields['embed'] = embed.to_dict()
fields["embed"] = embed.to_dict()
try:
suppress: bool = fields.pop('suppress')
suppress: bool = fields.pop("suppress")
except KeyError:
pass
else:
flags = MessageFlags._from_value(0)
flags.suppress_embeds = suppress
fields['flags'] = flags.value
fields["flags"] = flags.value
delete_after = fields.pop('delete_after', None)
delete_after = fields.pop("delete_after", None)
try:
allowed_mentions = fields.pop('allowed_mentions')
allowed_mentions = fields.pop("allowed_mentions")
except KeyError:
pass
else:
if allowed_mentions is not None:
if self._state.allowed_mentions is not None:
allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
allowed_mentions = self._state.allowed_mentions.merge(
allowed_mentions
).to_dict()
else:
allowed_mentions = allowed_mentions.to_dict()
fields['allowed_mentions'] = allowed_mentions
fields["allowed_mentions"] = allowed_mentions
try:
view = fields.pop('view')
view = fields.pop("view")
except KeyError:
# To check for the view afterwards
view = None
else:
self._state.prevent_view_updates_for(self.id)
if view:
fields['components'] = view.to_components()
fields["components"] = view.to_components()
else:
fields['components'] = []
fields["components"] = []
if fields:
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
data = await self._state.http.edit_message(
self.channel.id, self.id, **fields
)
if delete_after is not None:
await self.delete(delay=delete_after)

View File

@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE.
"""
__all__ = (
'EqualityComparable',
'Hashable',
"EqualityComparable",
"Hashable",
)
class EqualityComparable:
__slots__ = ()
@ -40,6 +41,7 @@ class EqualityComparable:
return other.id != self.id
return True
class Hashable(EqualityComparable):
__slots__ = ()

View File

@ -35,11 +35,11 @@ from typing import (
if TYPE_CHECKING:
import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
__all__ = (
'Object',
)
__all__ = ("Object",)
class Object(Hashable):
"""Represents a generic Discord object.
@ -83,12 +83,14 @@ class Object(Hashable):
try:
id = int(id)
except ValueError:
raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None
raise TypeError(
f"id parameter must be convertable to int not {id.__class__!r}"
) from None
else:
self.id = id
def __repr__(self) -> str:
return f'<Object id={self.id!r}>'
return f"<Object id={self.id!r}>"
@property
def created_at(self) -> datetime.datetime:

View File

@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException
__all__ = (
'OggError',
'OggPage',
'OggStream',
"OggError",
"OggPage",
"OggStream",
)
class OggError(DiscordException):
"""An exception that is thrown for Ogg stream parsing errors."""
pass
# https://tools.ietf.org/html/rfc3533
# https://tools.ietf.org/html/rfc7845
class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
_header: ClassVar[struct.Struct] = struct.Struct("<xBQIIIB")
if TYPE_CHECKING:
flag: int
gran_pos: int
@ -57,14 +61,20 @@ class OggPage:
try:
header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
(
self.flag,
self.gran_pos,
self.serial,
self.pagenum,
self.crc,
self.segnum,
) = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
bodylen = sum(struct.unpack("B" * self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen)
except Exception:
raise OggError('bad data stream') from None
raise OggError("bad data stream") from None
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
packetlen = offset = 0
@ -76,7 +86,7 @@ class OggPage:
partial = True
else:
packetlen += seg
yield self.data[offset:offset+packetlen], True
yield self.data[offset : offset + packetlen], True
offset += packetlen
packetlen = 0
partial = False
@ -84,18 +94,19 @@ class OggPage:
if partial:
yield self.data[offset:], False
class OggStream:
def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream
def _next_page(self) -> Optional[OggPage]:
head = self.stream.read(4)
if head == b'OggS':
if head == b"OggS":
return OggPage(self.stream)
elif not head:
return None
else:
raise OggError('invalid header magic')
raise OggError("invalid header magic")
def _iter_pages(self) -> Generator[OggPage, None, None]:
page = self._next_page()
@ -104,10 +115,10 @@ class OggStream:
page = self._next_page()
def iter_packets(self) -> Generator[bytes, None, None]:
partial = b''
partial = b""
for page in self._iter_pages():
for data, complete in page.iter_packets():
partial += data
if complete:
yield partial
partial = b''
partial = b""

View File

@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
from typing import (
List,
Tuple,
TypedDict,
Any,
TYPE_CHECKING,
Callable,
TypeVar,
Literal,
Optional,
overload,
)
import array
import ctypes
@ -38,9 +49,10 @@ import sys
from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING:
T = TypeVar('T')
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
T = TypeVar("T")
BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"]
SIGNAL_CTL = Literal["auto", "voice", "music"]
class BandCtl(TypedDict):
narrow: int
@ -49,81 +61,89 @@ class BandCtl(TypedDict):
superwide: int
full: int
class SignalCtl(TypedDict):
auto: int
voice: int
music: int
__all__ = (
'Encoder',
'OpusError',
'OpusNotLoaded',
"Encoder",
"OpusError",
"OpusNotLoaded",
)
_log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float)
_lib = None
class EncoderStruct(ctypes.Structure):
pass
class DecoderStruct(ctypes.Structure):
pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct)
## Some constants from opus_defines.h
# Error codes
OK = 0
OK = 0
BAD_ARG = -1
# Encoder CTLs
APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048
APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048
APPLICATION_LOWDELAY = 2051
CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008
CTL_SET_FEC = 4012
CTL_SET_PLP = 4014
CTL_SET_SIGNAL = 4024
CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008
CTL_SET_FEC = 4012
CTL_SET_PLP = 4014
CTL_SET_SIGNAL = 4024
# Decoder CTLs
CTL_SET_GAIN = 4034
CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039
band_ctl: BandCtl = {
'narrow': 1101,
'medium': 1102,
'wide': 1103,
'superwide': 1104,
'full': 1105,
"narrow": 1101,
"medium": 1102,
"wide": 1103,
"superwide": 1104,
"full": 1105,
}
signal_ctl: SignalCtl = {
'auto': -1000,
'voice': 3001,
'music': 3002,
"auto": -1000,
"voice": 3001,
"music": 3002,
}
def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK:
_log.info('error has happened in %s', func.__name__)
_log.info("error has happened in %s", func.__name__)
raise OpusError(result)
return result
def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj
if ret.value != OK:
_log.info('error has happened in %s', func.__name__)
_log.info("error has happened in %s", func.__name__)
raise OpusError(ret.value)
return result
# A list of exported functions.
# The first argument is obviously the name.
# The second one are the types of arguments it takes.
@ -131,54 +151,90 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
# The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [
# Generic
('opus_get_version_string',
None, ctypes.c_char_p, None),
('opus_strerror',
[ctypes.c_int], ctypes.c_char_p, None),
("opus_get_version_string", None, ctypes.c_char_p, None),
("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None),
# Encoder functions
('opus_encoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_encoder_create',
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
('opus_encode',
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encode_float',
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_encoder_destroy',
[EncoderStructPtr], None, None),
("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None),
(
"opus_encoder_create",
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr],
EncoderStructPtr,
_err_ne,
),
(
"opus_encode",
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
(
"opus_encode_float",
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
("opus_encoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_encoder_destroy", [EncoderStructPtr], None, None),
# Decoder functions
('opus_decoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_decoder_create',
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
('opus_decode',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt),
('opus_decode_float',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt),
('opus_decoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_decoder_destroy',
[DecoderStructPtr], None, None),
('opus_decoder_get_nb_samples',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None),
(
"opus_decoder_create",
[ctypes.c_int, ctypes.c_int, c_int_ptr],
DecoderStructPtr,
_err_ne,
),
(
"opus_decode",
[
DecoderStructPtr,
ctypes.c_char_p,
ctypes.c_int32,
c_int16_ptr,
ctypes.c_int,
ctypes.c_int,
],
ctypes.c_int,
_err_lt,
),
(
"opus_decode_float",
[
DecoderStructPtr,
ctypes.c_char_p,
ctypes.c_int32,
c_float_ptr,
ctypes.c_int,
ctypes.c_int,
],
ctypes.c_int,
_err_lt,
),
("opus_decoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_decoder_destroy", [DecoderStructPtr], None, None),
(
"opus_decoder_get_nb_samples",
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int,
_err_lt,
),
# Packet functions
('opus_packet_get_bandwidth',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_channels',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_frames',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_samples_per_frame',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt),
("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt),
(
"opus_packet_get_nb_frames",
[ctypes.c_char_p, ctypes.c_int],
ctypes.c_int,
_err_lt,
),
(
"opus_packet_get_samples_per_frame",
[ctypes.c_char_p, ctypes.c_int],
ctypes.c_int,
_err_lt,
),
]
def libopus_loader(name: str) -> Any:
# create the library...
lib = ctypes.cdll.LoadLibrary(name)
@ -203,22 +259,24 @@ def libopus_loader(name: str) -> Any:
return lib
def _load_default() -> bool:
global _lib
try:
if sys.platform == 'win32':
if sys.platform == "win32":
_basedir = os.path.dirname(os.path.abspath(__file__))
_bitness = struct.calcsize('P') * 8
_target = 'x64' if _bitness > 32 else 'x86'
_filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll')
_bitness = struct.calcsize("P") * 8
_target = "x64" if _bitness > 32 else "x86"
_filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll")
_lib = libopus_loader(_filename)
else:
_lib = libopus_loader(ctypes.util.find_library('opus'))
_lib = libopus_loader(ctypes.util.find_library("opus"))
except Exception:
_lib = None
return _lib is not None
def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice.
@ -257,6 +315,7 @@ def load_opus(name: str) -> None:
global _lib
_lib = libopus_loader(name)
def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -271,6 +330,7 @@ def is_loaded() -> bool:
global _lib
return _lib is not None
class OpusError(DiscordException):
"""An exception that is thrown for libopus related errors.
@ -282,19 +342,22 @@ class OpusError(DiscordException):
def __init__(self, code: int):
self.code: int = code
msg = _lib.opus_strerror(self.code).decode('utf-8')
msg = _lib.opus_strerror(self.code).decode("utf-8")
_log.info('"%s" has happened', msg)
super().__init__(msg)
class OpusNotLoaded(DiscordException):
"""An exception that is thrown for when libopus is not loaded."""
pass
class _OpusStruct:
SAMPLING_RATE = 48000
CHANNELS = 2
FRAME_LENGTH = 20 # in milliseconds
SAMPLE_SIZE = struct.calcsize('h') * CHANNELS
SAMPLE_SIZE = struct.calcsize("h") * CHANNELS
SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH)
FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE
@ -304,7 +367,8 @@ class _OpusStruct:
if not is_loaded() and not _load_default():
raise OpusNotLoaded()
return _lib.opus_get_version_string().decode('utf-8')
return _lib.opus_get_version_string().decode("utf-8")
class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO):
@ -315,18 +379,20 @@ class Encoder(_OpusStruct):
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full')
self.set_signal_type('auto')
self.set_bandwidth("full")
self.set_signal_type("auto")
def __del__(self) -> None:
if hasattr(self, '_state'):
if hasattr(self, "_state"):
_lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
self._state = None # type: ignore
def _create_state(self) -> EncoderStruct:
ret = ctypes.c_int()
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
return _lib.opus_encoder_create(
self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)
)
def set_bitrate(self, kbps: int) -> int:
kbps = min(512, max(16, int(kbps)))
@ -336,14 +402,18 @@ class Encoder(_OpusStruct):
def set_bandwidth(self, req: BAND_CTL) -> None:
if req not in band_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
raise KeyError(
f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}'
)
k = band_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
def set_signal_type(self, req: SIGNAL_CTL) -> None:
if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
raise KeyError(
f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}'
)
k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
@ -352,18 +422,19 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm)
# bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
return array.array("b", data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct):
def __init__(self):
@ -372,14 +443,16 @@ class Decoder(_OpusStruct):
self._state: DecoderStruct = self._create_state()
def __del__(self) -> None:
if hasattr(self, '_state'):
if hasattr(self, "_state"):
_lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
self._state = None # type: ignore
def _create_state(self) -> DecoderStruct:
ret = ctypes.c_int()
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
return _lib.opus_decoder_create(
self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)
)
@staticmethod
def packet_get_nb_frames(data: bytes) -> int:
@ -411,12 +484,12 @@ class Decoder(_OpusStruct):
def set_gain(self, dB: float) -> int:
"""Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8)
def set_volume(self, mult: float) -> int:
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self) -> int:
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
@ -428,14 +501,16 @@ class Decoder(_OpusStruct):
@overload
def decode(self, data: bytes, *, fec: bool) -> bytes:
...
@overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
...
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
if data is None and fec:
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
raise InvalidArgument(
"Invalid arguments: FEC cannot be used with null data"
)
if data is None:
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME
@ -449,6 +524,8 @@ class Decoder(_OpusStruct):
pcm = (ctypes.c_int16 * (frame_size * channel_count))()
pcm_ptr = ctypes.cast(pcm, c_int16_ptr)
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
ret = _lib.opus_decode(
self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec
)
return array.array('h', pcm[:ret * channel_count]).tobytes()
return array.array("h", pcm[: ret * channel_count]).tobytes()

View File

@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin
from .errors import InvalidArgument
from . import utils
__all__ = (
'PartialEmoji',
)
__all__ = ("PartialEmoji",)
if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag:
__slots__ = ()
@ -49,7 +48,7 @@ class _EmojiTag:
raise NotImplementedError
PE = TypeVar('PE', bound='PartialEmoji')
PE = TypeVar("PE", bound="PartialEmoji")
class PartialEmoji(_EmojiTag, AssetMixin):
@ -90,9 +89,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
The ID of the custom emoji, if applicable.
"""
__slots__ = ('animated', 'name', 'id', '_state')
__slots__ = ("animated", "name", "id", "_state")
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
_CUSTOM_EMOJI_RE = re.compile(
r"<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?"
)
if TYPE_CHECKING:
id: Optional[int]
@ -104,11 +105,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
self._state: Optional[ConnectionState] = None
@classmethod
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE:
def from_dict(
cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]
) -> PE:
return cls(
animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'),
name=data.get('name') or '',
animated=data.get("animated", False),
id=utils._get_as_snowflake(data, "id"),
name=data.get("name") or "",
)
@classmethod
@ -139,19 +142,19 @@ class PartialEmoji(_EmojiTag, AssetMixin):
match = cls._CUSTOM_EMOJI_RE.match(value)
if match is not None:
groups = match.groupdict()
animated = bool(groups['animated'])
emoji_id = int(groups['id'])
name = groups['name']
animated = bool(groups["animated"])
emoji_id = int(groups["id"])
name = groups["name"]
return cls(name=name, animated=animated, id=emoji_id)
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name}
o: Dict[str, Any] = {"name": self.name}
if self.id:
o['id'] = self.id
o["id"] = self.id
if self.animated:
o['animated'] = self.animated
o["animated"] = self.animated
return o
def _to_partial(self) -> PartialEmoji:
@ -159,7 +162,12 @@ class PartialEmoji(_EmojiTag, AssetMixin):
@classmethod
def with_state(
cls: Type[PE], state: ConnectionState, *, name: str, animated: bool = False, id: Optional[int] = None
cls: Type[PE],
state: ConnectionState,
*,
name: str,
animated: bool = False,
id: Optional[int] = None,
) -> PE:
self = cls(name=name, animated=animated, id=id)
self._state = state
@ -169,11 +177,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
if self.id is None:
return self.name
if self.animated:
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
def __repr__(self):
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>"
def __eq__(self, other: Any) -> bool:
if self.is_unicode_emoji():
@ -200,7 +208,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
def _as_reaction(self) -> str:
if self.id is None:
return self.name
return f'{self.name}:{self.id}'
return f"{self.name}:{self.id}"
@property
def created_at(self) -> Optional[datetime]:
@ -220,13 +228,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
If this isn't a custom emoji then an empty string is returned
"""
if self.is_unicode_emoji():
return ''
return ""
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
async def read(self) -> bytes:
if self.is_unicode_emoji():
raise InvalidArgument('PartialEmoji is not a custom emoji')
raise InvalidArgument("PartialEmoji is not a custom emoji")
return await super().read()

View File

@ -24,12 +24,24 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, Tuple, Type, TypeVar, Optional
from typing import (
Callable,
Any,
ClassVar,
Dict,
Iterator,
Set,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Optional,
)
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
__all__ = (
'Permissions',
'PermissionOverwrite',
"Permissions",
"PermissionOverwrite",
)
# A permission alias works like a regular flag but is marked
@ -38,7 +50,9 @@ class permission_alias(alias_flag_value):
alias: str
def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permission_alias]:
def make_permission_alias(
alias: str,
) -> Callable[[Callable[[Any], int]], permission_alias]:
def decorator(func: Callable[[Any], int]) -> permission_alias:
ret = permission_alias(func)
ret.alias = alias
@ -46,7 +60,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis
return decorator
P = TypeVar('P', bound='Permissions')
P = TypeVar("P", bound="Permissions")
@fill_with_flags()
class Permissions(BaseFlags):
@ -101,12 +117,14 @@ class Permissions(BaseFlags):
def __init__(self, permissions: int = 0, **kwargs: bool):
if not isinstance(permissions, int):
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.')
raise TypeError(
f"Expected int parameter, received {permissions.__class__.__name__} instead."
)
self.value = permissions
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid permission name.')
raise TypeError(f"{key!r} is not a valid permission name.")
setattr(self, key, value)
def is_subset(self, other: Permissions) -> bool:
@ -114,14 +132,18 @@ class Permissions(BaseFlags):
if isinstance(other, Permissions):
return (self.value & other.value) == self.value
else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
raise TypeError(
f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}"
)
def is_superset(self, other: Permissions) -> bool:
"""Returns ``True`` if self has the same or more permissions as other."""
if isinstance(other, Permissions):
return (self.value | other.value) == self.value
else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
raise TypeError(
f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}"
)
def is_strict_subset(self, other: Permissions) -> bool:
"""Returns ``True`` if the permissions on other are a strict subset of those on self."""
@ -336,7 +358,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
return 1 << 10
@make_permission_alias('read_messages')
@make_permission_alias("read_messages")
def view_channel(self) -> int:
""":class:`bool`: An alias for :attr:`read_messages`.
@ -389,7 +411,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
return 1 << 18
@make_permission_alias('external_emojis')
@make_permission_alias("external_emojis")
def use_external_emojis(self) -> int:
""":class:`bool`: An alias for :attr:`external_emojis`.
@ -453,7 +475,7 @@ class Permissions(BaseFlags):
"""
return 1 << 28
@make_permission_alias('manage_roles')
@make_permission_alias("manage_roles")
def manage_permissions(self) -> int:
""":class:`bool`: An alias for :attr:`manage_roles`.
@ -471,7 +493,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30
@make_permission_alias('manage_emojis')
@make_permission_alias("manage_emojis")
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
@ -535,7 +557,7 @@ class Permissions(BaseFlags):
"""
return 1 << 37
@make_permission_alias('external_stickers')
@make_permission_alias("external_stickers")
def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`.
@ -551,7 +573,9 @@ class Permissions(BaseFlags):
"""
return 1 << 38
PO = TypeVar('PO', bound='PermissionOverwrite')
PO = TypeVar("PO", bound="PermissionOverwrite")
def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
@ -614,7 +638,7 @@ class PermissionOverwrite:
Set the value of permissions by their name.
"""
__slots__ = ('_values',)
__slots__ = ("_values",)
if TYPE_CHECKING:
VALID_NAMES: ClassVar[Set[str]]
@ -670,7 +694,7 @@ class PermissionOverwrite:
for key, value in kwargs.items():
if key not in self.VALID_NAMES:
raise ValueError(f'no permission called {key}.')
raise ValueError(f"no permission called {key}.")
setattr(self, key, value)
@ -679,7 +703,9 @@ class PermissionOverwrite:
def _set(self, key: str, value: Optional[bool]) -> None:
if value not in (True, None, False):
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}')
raise TypeError(
f"Expected bool or NoneType, received {value.__class__.__name__}"
)
if value is None:
self._values.pop(key, None)

View File

@ -36,7 +36,18 @@ import sys
import re
import io
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Generic,
IO,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
)
from .errors import ClientException
from .opus import Encoder as OpusEncoder
@ -47,27 +58,28 @@ if TYPE_CHECKING:
from .voice_client import VoiceClient
AT = TypeVar('AT', bound='AudioSource')
FT = TypeVar('FT', bound='FFmpegOpusAudio')
AT = TypeVar("AT", bound="AudioSource")
FT = TypeVar("FT", bound="FFmpegOpusAudio")
_log = logging.getLogger(__name__)
__all__ = (
'AudioSource',
'PCMAudio',
'FFmpegAudio',
'FFmpegPCMAudio',
'FFmpegOpusAudio',
'PCMVolumeTransformer',
"AudioSource",
"PCMAudio",
"FFmpegAudio",
"FFmpegPCMAudio",
"FFmpegOpusAudio",
"PCMVolumeTransformer",
)
CREATE_NO_WINDOW: int
if sys.platform != 'win32':
if sys.platform != "win32":
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource:
"""Represents an audio stream.
@ -114,6 +126,7 @@ class AudioSource:
def __del__(self) -> None:
self.cleanup()
class PCMAudio(AudioSource):
"""Represents raw 16-bit 48KHz stereo PCM audio source.
@ -122,15 +135,17 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM.
"""
def __init__(self, stream: io.BufferedIOBase) -> None:
self.stream: io.BufferedIOBase = stream
def read(self) -> bytes:
ret = self.stream.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b''
return b""
return ret
class FFmpegAudio(AudioSource):
"""Represents an FFmpeg (or AVConv) based AudioSource.
@ -140,13 +155,22 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3
"""
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = "ffmpeg",
args: Any,
**subprocess_kwargs: Any,
):
piping = subprocess_kwargs.get("stdin") == subprocess.PIPE
if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
raise TypeError(
"parameter conflict: 'source' parameter cannot be a string when piping to stdin"
)
args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE}
kwargs = {"stdout": subprocess.PIPE}
kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
@ -155,20 +179,26 @@ class FFmpegAudio(AudioSource):
self._pipe_thread: Optional[threading.Thread] = None
if piping:
n = f'popen-stdin-writer:{id(self):#x}'
n = f"popen-stdin-writer:{id(self):#x}"
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread = threading.Thread(
target=self._pipe_writer, args=(source,), daemon=True, name=n
)
self._pipe_thread.start()
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
process = None
try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
process = subprocess.Popen(
args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs
)
except FileNotFoundError:
executable = args.partition(' ')[0] if isinstance(args, str) else args[0]
raise ClientException(executable + ' was not found.') from None
executable = args.partition(" ")[0] if isinstance(args, str) else args[0]
raise ClientException(executable + " was not found.") from None
except subprocess.SubprocessError as exc:
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc
raise ClientException(
f"Popen failed: {exc.__class__.__name__}: {exc}"
) from exc
else:
return process
@ -177,20 +207,32 @@ class FFmpegAudio(AudioSource):
if proc is MISSING:
return
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
_log.info("Preparing to terminate ffmpeg process %s.", proc.pid)
try:
proc.kill()
except Exception:
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
_log.exception(
"Ignoring error attempting to kill ffmpeg process %s", proc.pid
)
if proc.poll() is None:
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
_log.info(
"ffmpeg process %s has not terminated. Waiting to terminate...",
proc.pid,
)
proc.communicate()
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
_log.info(
"ffmpeg process %s should have terminated with a return code of %s.",
proc.pid,
proc.returncode,
)
else:
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
_log.info(
"ffmpeg process %s successfully terminated with return code of %s.",
proc.pid,
proc.returncode,
)
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
@ -202,7 +244,11 @@ class FFmpegAudio(AudioSource):
try:
self._stdin.write(data)
except Exception:
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
_log.debug(
"Write error for %s, this is probably not a problem",
self,
exc_info=True,
)
# at this point the source data is either exhausted or the process is fubar
self._process.terminate()
return
@ -211,6 +257,7 @@ class FFmpegAudio(AudioSource):
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@ -250,38 +297,42 @@ class FFmpegPCMAudio(FFmpegAudio):
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = 'ffmpeg',
executable: str = "ffmpeg",
pipe: bool = False,
stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None
options: Optional[str] = None,
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {
"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL,
"stderr": stderr,
}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append('-i')
args.append('-' if pipe else source)
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
args.append("-i")
args.append("-" if pipe else source)
args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"))
if isinstance(options, str):
args.extend(shlex.split(options))
args.append('pipe:1')
args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
def read(self) -> bytes:
ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b''
return b""
return ret
def is_opus(self) -> bool:
return False
class FFmpegOpusAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@ -349,7 +400,7 @@ class FFmpegOpusAudio(FFmpegAudio):
*,
bitrate: int = 128,
codec: Optional[str] = None,
executable: str = 'ffmpeg',
executable: str = "ffmpeg",
pipe=False,
stderr=None,
before_options=None,
@ -357,28 +408,42 @@ class FFmpegOpusAudio(FFmpegAudio):
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {
"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL,
"stderr": stderr,
}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append('-i')
args.append('-' if pipe else source)
args.append("-i")
args.append("-" if pipe else source)
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
codec = "copy" if codec in ("opus", "libopus") else "libopus"
args.extend(('-map_metadata', '-1',
'-f', 'opus',
'-c:a', codec,
'-ar', '48000',
'-ac', '2',
'-b:a', f'{bitrate}k',
'-loglevel', 'warning'))
args.extend(
(
"-map_metadata",
"-1",
"-f",
"opus",
"-c:a",
codec,
"-ar",
"48000",
"-ac",
"2",
"-b:a",
f"{bitrate}k",
"-loglevel",
"warning",
)
)
if isinstance(options, str):
args.extend(shlex.split(options))
args.append('pipe:1')
args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
self._packet_iter = OggStream(self._stdout).iter_packets()
@ -388,7 +453,9 @@ class FFmpegOpusAudio(FFmpegAudio):
cls: Type[FT],
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
method: Optional[
Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]
] = None,
**kwargs: Any,
) -> FT:
"""|coro|
@ -446,7 +513,7 @@ class FFmpegOpusAudio(FFmpegAudio):
An instance of this class.
"""
executable = kwargs.get('executable')
executable = kwargs.get("executable")
codec, bitrate = await cls.probe(source, method=method, executable=executable)
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore
@ -455,7 +522,9 @@ class FFmpegOpusAudio(FFmpegAudio):
cls,
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
method: Optional[
Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]
] = None,
executable: Optional[str] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""|coro|
@ -484,12 +553,12 @@ class FFmpegOpusAudio(FFmpegAudio):
A 2-tuple with the codec and bitrate of the input source.
"""
method = method or 'native'
executable = executable or 'ffmpeg'
method = method or "native"
executable = executable or "ffmpeg"
probefunc = fallback = None
if isinstance(method, str):
probefunc = getattr(cls, '_probe_codec_' + method, None)
probefunc = getattr(cls, "_probe_codec_" + method, None)
if probefunc is None:
raise AttributeError(f"Invalid probe method {method!r}")
@ -500,8 +569,10 @@ class FFmpegOpusAudio(FFmpegAudio):
probefunc = method
fallback = cls._probe_codec_fallback
else:
raise TypeError("Expected str or callable for parameter 'probe', " \
f"not '{method.__class__.__name__}'")
raise TypeError(
"Expected str or callable for parameter 'probe', "
f"not '{method.__class__.__name__}'"
)
codec = bitrate = None
loop = asyncio.get_event_loop()
@ -512,7 +583,9 @@ class FFmpegOpusAudio(FFmpegAudio):
_log.exception("Probe '%s' using '%s' failed", method, executable)
return # type: ignore
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
_log.exception(
"Probe '%s' using '%s' failed, trying fallback", method, executable
)
try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
except Exception:
@ -525,28 +598,51 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate
@staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
def _probe_codec_native(
source, executable: str = "ffmpeg"
) -> Tuple[Optional[str], Optional[int]]:
exe = (
executable[:2] + "probe"
if executable in ("ffmpeg", "avconv")
else executable
)
args = [
exe,
"-v",
"quiet",
"-print_format",
"json",
"-show_streams",
"-select_streams",
"a:0",
source,
]
output = subprocess.check_output(args, timeout=20)
codec = bitrate = None
if output:
data = json.loads(output)
streamdata = data['streams'][0]
streamdata = data["streams"][0]
codec = streamdata.get('codec_name')
bitrate = int(streamdata.get('bit_rate', 0))
bitrate = max(round(bitrate/1000), 512)
codec = streamdata.get("codec_name")
bitrate = int(streamdata.get("bit_rate", 0))
bitrate = max(round(bitrate / 1000), 512)
return codec, bitrate
@staticmethod
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
args = [executable, '-hide_banner', '-i', source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
def _probe_codec_fallback(
source, executable: str = "ffmpeg"
) -> Tuple[Optional[str], Optional[int]]:
args = [executable, "-hide_banner", "-i", source]
proc = subprocess.Popen(
args,
creationflags=CREATE_NO_WINDOW,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
out, _ = proc.communicate(timeout=20)
output = out.decode('utf8')
output = out.decode("utf8")
codec = bitrate = None
codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output)
@ -560,11 +656,12 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate
def read(self) -> bytes:
return next(self._packet_iter, b'')
return next(self._packet_iter, b"")
def is_opus(self) -> bool:
return True
class PCMVolumeTransformer(AudioSource, Generic[AT]):
"""Transforms a previous :class:`AudioSource` to have volume controls.
@ -589,10 +686,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
def __init__(self, original: AT, volume: float = 1.0):
if not isinstance(original, AudioSource):
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.')
raise TypeError(f"expected AudioSource not {original.__class__.__name__}.")
if original.is_opus():
raise ClientException('AudioSource must not be Opus encoded.')
raise ClientException("AudioSource must not be Opus encoded.")
self.original: AT = original
self.volume = volume
@ -613,6 +710,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
@ -625,7 +723,7 @@ class AudioPlayer(threading.Thread):
self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock()
@ -685,11 +783,11 @@ class AudioPlayer(threading.Thread):
try:
self.after(error)
except Exception as exc:
_log.exception('Calling the after function failed.')
_log.exception("Calling the after function failed.")
exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error:
msg = f'Exception in voice thread {self.name}'
msg = f"Exception in voice thread {self.name}"
_log.exception(msg, exc_info=error)
print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__)
@ -725,6 +823,8 @@ class AudioPlayer(threading.Thread):
def _speak(self, speaking: bool) -> None:
try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
asyncio.run_coroutine_threadsafe(
self.client.ws.speak(speaking), self.client.loop
)
except Exception as e:
_log.info("Speaking call in player failed: %s", e)

View File

@ -34,7 +34,7 @@ if TYPE_CHECKING:
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent
IntegrationDeleteEvent,
)
from .message import Message
from .partial_emoji import PartialEmoji
@ -42,20 +42,20 @@ if TYPE_CHECKING:
__all__ = (
'RawMessageDeleteEvent',
'RawBulkMessageDeleteEvent',
'RawMessageUpdateEvent',
'RawReactionActionEvent',
'RawReactionClearEvent',
'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent',
"RawMessageDeleteEvent",
"RawBulkMessageDeleteEvent",
"RawMessageUpdateEvent",
"RawReactionActionEvent",
"RawReactionClearEvent",
"RawReactionClearEmojiEvent",
"RawIntegrationDeleteEvent",
)
class _RawReprMixin:
def __repr__(self) -> str:
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>'
value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__)
return f"<{self.__class__.__name__} {value}>"
class RawMessageDeleteEvent(_RawReprMixin):
@ -73,14 +73,14 @@ class RawMessageDeleteEvent(_RawReprMixin):
The cached message, if found in the internal message cache.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
__slots__ = ("message_id", "channel_id", "guild_id", "cached_message")
def __init__(self, data: MessageDeleteEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -100,15 +100,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
The cached messages, if found in the internal message cache.
"""
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
__slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages")
def __init__(self, data: BulkMessageDeleteEvent) -> None:
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])}
self.channel_id: int = int(data['channel_id'])
self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])}
self.channel_id: int = int(data["channel_id"])
self.cached_messages: List[Message] = []
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -136,16 +136,16 @@ class RawMessageUpdateEvent(_RawReprMixin):
it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
__slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message")
def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.data: MessageUpdateEvent = data
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -179,19 +179,28 @@ class RawReactionActionEvent(_RawReprMixin):
.. versionadded:: 1.3
"""
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
'event_type', 'member')
__slots__ = (
"message_id",
"user_id",
"channel_id",
"guild_id",
"emoji",
"event_type",
"member",
)
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
def __init__(
self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str
) -> None:
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.emoji: PartialEmoji = emoji
self.event_type: str = event_type
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -209,14 +218,14 @@ class RawReactionClearEvent(_RawReprMixin):
The guild ID where the reactions got cleared.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id')
__slots__ = ("message_id", "channel_id", "guild_id")
def __init__(self, data: ReactionClearEvent) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -238,15 +247,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
The custom or unicode emoji being removed.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
__slots__ = ("message_id", "channel_id", "guild_id", "emoji")
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
self.emoji: PartialEmoji = emoji
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@ -266,13 +275,13 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
The guild ID where the integration got deleted.
"""
__slots__ = ('integration_id', 'application_id', 'guild_id')
__slots__ = ("integration_id", "application_id", "guild_id")
def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data['id'])
self.guild_id: int = int(data['guild_id'])
self.integration_id: int = int(data["id"])
self.guild_id: int = int(data["guild_id"])
try:
self.application_id: Optional[int] = int(data['application_id'])
self.application_id: Optional[int] = int(data["application_id"])
except KeyError:
self.application_id: Optional[int] = None

View File

@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional
from .iterators import ReactionIterator
__all__ = (
'Reaction',
)
__all__ = ("Reaction",)
if TYPE_CHECKING:
from .types.message import Reaction as ReactionPayload
@ -38,6 +36,7 @@ if TYPE_CHECKING:
from .emoji import Emoji
from .abc import Snowflake
class Reaction:
"""Represents a reaction to a message.
@ -75,13 +74,22 @@ class Reaction:
message: :class:`Message`
Message this reaction is for.
"""
__slots__ = ('message', 'count', 'emoji', 'me')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None):
__slots__ = ("message", "count", "emoji", "me")
def __init__(
self,
*,
message: Message,
data: ReactionPayload,
emoji: Optional[Union[PartialEmoji, Emoji, str]] = None,
):
self.message: Message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji'])
self.count: int = data.get('count', 1)
self.me: bool = data.get('me')
self.emoji: Union[
PartialEmoji, Emoji, str
] = emoji or message._state.get_reaction_emoji(data["emoji"])
self.count: int = data.get("count", 1)
self.me: bool = data.get("me")
# TODO: typeguard
def is_custom_emoji(self) -> bool:
@ -103,7 +111,7 @@ class Reaction:
return str(self.emoji)
def __repr__(self) -> str:
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>'
return f"<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>"
async def remove(self, user: Snowflake) -> None:
"""|coro|
@ -155,7 +163,9 @@ class Reaction:
"""
await self.message.clear_reaction(self.emoji)
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator:
def users(
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
) -> ReactionIterator:
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
The ``after`` parameter must represent a member
@ -201,7 +211,7 @@ class Reaction:
"""
if not isinstance(self.emoji, str):
emoji = f'{self.emoji.name}:{self.emoji.id}'
emoji = f"{self.emoji.name}:{self.emoji.id}"
else:
emoji = self.emoji

View File

@ -32,8 +32,8 @@ from .mixins import Hashable
from .utils import snowflake_time, _get_as_snowflake, MISSING
__all__ = (
'RoleTags',
'Role',
"RoleTags",
"Role",
)
if TYPE_CHECKING:
@ -68,19 +68,21 @@ class RoleTags:
"""
__slots__ = (
'bot_id',
'integration_id',
'_premium_subscriber',
"bot_id",
"integration_id",
"_premium_subscriber",
)
def __init__(self, data: RoleTagPayload):
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id')
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id')
self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id")
self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id")
# NOTE: The API returns "null" for this if it's valid, which corresponds to None.
# This is different from other fields where "null" means "not there".
# So in this case, a value of None is the same as True.
# Which means we would need a different sentinel.
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING)
self._premium_subscriber: Optional[Any] = data.get(
"premium_subscriber", MISSING
)
def is_bot_managed(self) -> bool:
""":class:`bool`: Whether the role is associated with a bot."""
@ -96,12 +98,12 @@ class RoleTags:
def __repr__(self) -> str:
return (
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} '
f'premium_subscriber={self.is_premium_subscriber()}>'
f"<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} "
f"premium_subscriber={self.is_premium_subscriber()}>"
)
R = TypeVar('R', bound='Role')
R = TypeVar("R", bound="Role")
class Role(Hashable):
@ -181,23 +183,23 @@ class Role(Hashable):
"""
__slots__ = (
'id',
'name',
'_permissions',
'_colour',
'position',
'managed',
'mentionable',
'hoist',
'guild',
'tags',
'_state',
"id",
"name",
"_permissions",
"_colour",
"position",
"managed",
"mentionable",
"hoist",
"guild",
"tags",
"_state",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload):
self.guild: Guild = guild
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(data)
def __str__(self) -> str:
@ -207,14 +209,14 @@ class Role(Hashable):
return self.id
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'
return f"<Role id={self.id} name={self.name!r}>"
def __lt__(self: R, other: R) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented
if self.guild != other.guild:
raise RuntimeError('cannot compare roles from two different guilds.')
raise RuntimeError("cannot compare roles from two different guilds.")
# the @everyone role is always the lowest role in hierarchy
guild_id = self.guild.id
@ -246,17 +248,17 @@ class Role(Hashable):
return not r
def _update(self, data: RolePayload):
self.name: str = data['name']
self._permissions: int = int(data.get('permissions', 0))
self.position: int = data.get('position', 0)
self._colour: int = data.get('color', 0)
self.hoist: bool = data.get('hoist', False)
self.managed: bool = data.get('managed', False)
self.mentionable: bool = data.get('mentionable', False)
self.name: str = data["name"]
self._permissions: int = int(data.get("permissions", 0))
self.position: int = data.get("position", 0)
self._colour: int = data.get("color", 0)
self.hoist: bool = data.get("hoist", False)
self.managed: bool = data.get("managed", False)
self.mentionable: bool = data.get("mentionable", False)
self.tags: Optional[RoleTags]
try:
self.tags = RoleTags(data['tags'])
self.tags = RoleTags(data["tags"])
except KeyError:
self.tags = None
@ -291,7 +293,11 @@ class Role(Hashable):
.. versionadded:: 2.0
"""
me = self.guild.me
return not self.is_default() and not self.managed and (me.top_role > self or me.id == self.guild.owner_id)
return (
not self.is_default()
and not self.managed
and (me.top_role > self or me.id == self.guild.owner_id)
)
@property
def permissions(self) -> Permissions:
@ -316,7 +322,7 @@ class Role(Hashable):
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention a role."""
return f'<@&{self.id}>'
return f"<@&{self.id}>"
@property
def members(self) -> List[Member]:
@ -340,15 +346,23 @@ class Role(Hashable):
http = self._state.http
change_range = range(min(self.position, position), max(self.position, position) + 1)
roles = [r.id for r in self.guild.roles[1:] if r.position in change_range and r.id != self.id]
change_range = range(
min(self.position, position), max(self.position, position) + 1
)
roles = [
r.id
for r in self.guild.roles[1:]
if r.position in change_range and r.id != self.id
]
if self.position > position:
roles.insert(0, self.id)
else:
roles.append(self.id)
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
payload: List[RolePositionUpdate] = [
{"id": z[0], "position": z[1]} for z in zip(roles, change_range)
]
await http.move_role_position(self.guild.id, payload, reason=reason)
async def edit(
@ -420,23 +434,25 @@ class Role(Hashable):
if colour is not MISSING:
if isinstance(colour, int):
payload['color'] = colour
payload["color"] = colour
else:
payload['color'] = colour.value
payload["color"] = colour.value
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if permissions is not MISSING:
payload['permissions'] = permissions.value
payload["permissions"] = permissions.value
if hoist is not MISSING:
payload['hoist'] = hoist
payload["hoist"] = hoist
if mentionable is not MISSING:
payload['mentionable'] = mentionable
payload["mentionable"] = mentionable
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
data = await self._state.http.edit_role(
self.guild.id, self.id, reason=reason, **payload
)
return Role(guild=self.guild, data=data, state=self._state)
async def delete(self, *, reason: Optional[str] = None) -> None:

View File

@ -43,18 +43,28 @@ from .errors import (
from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Tuple,
Type,
Optional,
List,
Dict,
TypeVar,
)
if TYPE_CHECKING:
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
EI = TypeVar('EI', bound='EventItem')
EI = TypeVar("EI", bound="EventItem")
__all__ = (
'AutoShardedClient',
'ShardInfo',
"AutoShardedClient",
"ShardInfo",
)
_log = logging.getLogger(__name__)
@ -70,11 +80,13 @@ class EventType:
class EventItem:
__slots__ = ('type', 'shard', 'error')
__slots__ = ("type", "shard", "error")
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
def __init__(
self, etype: int, shard: Optional["Shard"], error: Optional[Exception]
) -> None:
self.type: int = etype
self.shard: Optional['Shard'] = shard
self.shard: Optional["Shard"] = shard
self.error: Optional[Exception] = error
def __lt__(self: EI, other: EI) -> bool:
@ -92,7 +104,12 @@ class EventItem:
class Shard:
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
def __init__(
self,
ws: DiscordWebSocket,
client: AutoShardedClient,
queue_put: Callable[[EventItem], None],
) -> None:
self.ws: DiscordWebSocket = ws
self._client: Client = client
self._dispatch: Callable[..., None] = client.dispatch
@ -129,11 +146,11 @@ class Shard:
async def disconnect(self) -> None:
await self.close()
self._dispatch('shard_disconnect', self.id)
self._dispatch("shard_disconnect", self.id)
async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
if not self._reconnect:
self._queue_put(EventItem(EventType.close, self, e))
return
@ -149,14 +166,23 @@ class Shard:
if isinstance(e, ConnectionClosed):
if e.code == 4014:
self._queue_put(EventItem(EventType.terminate, self, PrivilegedIntentsRequired(self.id)))
self._queue_put(
EventItem(
EventType.terminate, self, PrivilegedIntentsRequired(self.id)
)
)
return
if e.code != 1000:
self._queue_put(EventItem(EventType.close, self, e))
return
retry = self._backoff.delay()
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
_log.error(
"Attempting a reconnect for shard ID %s in %.2fs",
self.id,
retry,
exc_info=e,
)
await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e))
@ -179,9 +205,9 @@ class Shard:
async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task()
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
_log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
try:
coro = DiscordWebSocket.from_client(
self._client,
@ -231,7 +257,7 @@ class ShardInfo:
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
"""
__slots__ = ('_parent', 'id', 'shard_count')
__slots__ = ("_parent", "id", "shard_count")
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent: Shard = parent
@ -320,16 +346,23 @@ class AutoShardedClient(Client):
if TYPE_CHECKING:
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
def __init__(
self,
*args: Any,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any,
) -> None:
kwargs.pop("shard_id", None)
self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
if self.shard_count is None:
raise ClientException('When passing manual shard_ids, you must provide a shard_count.')
raise ClientException(
"When passing manual shard_ids, you must provide a shard_count."
)
elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.')
raise ClientException("shard_ids parameter must be a list or a tuple.")
# instead of a single websocket, we have multiple
# the key is the shard_id
@ -338,7 +371,9 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
def _get_websocket(
self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None
) -> DiscordWebSocket:
if shard_id is None:
# guild_id won't be None if shard_id is None and shard_count won't be None here
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
@ -363,7 +398,7 @@ class AutoShardedClient(Client):
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
"""
if not self.__shards:
return float('nan')
return float("nan")
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
@ -372,7 +407,9 @@ class AutoShardedClient(Client):
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
return [
(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()
]
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
@ -386,14 +423,21 @@ class AutoShardedClient(Client):
@property
def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
return {
shard_id: ShardInfo(parent, self.shard_count)
for shard_id, parent in self.__shards.items()
}
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
async def launch_shard(
self, gateway: str, shard_id: int, *, initial: bool = False
) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
coro = DiscordWebSocket.from_client(
self, initial=initial, gateway=gateway, shard_id=shard_id
)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
_log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
@ -458,7 +502,10 @@ class AutoShardedClient(Client):
except Exception:
pass
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
to_close = [
asyncio.ensure_future(shard.close(), loop=self.loop)
for shard in self.__shards.values()
]
if to_close:
await asyncio.wait(to_close)
@ -503,10 +550,10 @@ class AutoShardedClient(Client):
"""
if status is None:
status_value = 'online'
status_value = "online"
status_enum = Status.online
elif status is Status.offline:
status_value = 'invisible'
status_value = "invisible"
status_enum = Status.offline
else:
status_enum = status

View File

@ -31,9 +31,7 @@ from .mixins import Hashable
from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum
__all__ = (
'StageInstance',
)
__all__ = ("StageInstance",)
if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload
@ -82,41 +80,51 @@ class StageInstance(Hashable):
"""
__slots__ = (
'_state',
'id',
'guild',
'channel_id',
'topic',
'privacy_level',
'discoverable_disabled',
'_cs_channel',
"_state",
"id",
"guild",
"channel_id",
"topic",
"privacy_level",
"discoverable_disabled",
"_cs_channel",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
def __init__(
self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload
) -> None:
self._state = state
self.guild = guild
self._update(data)
def _update(self, data: StageInstancePayload):
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
self.id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.topic: str = data["topic"]
self.privacy_level: StagePrivacyLevel = try_enum(
StagePrivacyLevel, data["privacy_level"]
)
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
return f"<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>"
@cached_slot_property('_cs_channel')
@cached_slot_property("_cs_channel")
def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore
return self._state.get_channel(self.channel_id) # type: ignore
def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
async def edit(
self,
*,
topic: str = MISSING,
privacy_level: StagePrivacyLevel = MISSING,
reason: Optional[str] = None,
) -> None:
"""|coro|
Edits the stage instance.
@ -146,16 +154,20 @@ class StageInstance(Hashable):
payload = {}
if topic is not MISSING:
payload['topic'] = topic
payload["topic"] = topic
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
raise InvalidArgument(
"privacy_level field must be of type PrivacyLevel"
)
payload['privacy_level'] = privacy_level.value
payload["privacy_level"] = privacy_level.value
if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
await self._state.http.edit_stage_instance(
self.channel_id, **payload, reason=reason
)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|

File diff suppressed because it is too large Load Diff

View File

@ -33,11 +33,11 @@ from .errors import InvalidData
from .enums import StickerType, StickerFormatType, try_enum
__all__ = (
'StickerPack',
'StickerItem',
'Sticker',
'StandardSticker',
'GuildSticker',
"StickerPack",
"StickerItem",
"Sticker",
"StandardSticker",
"GuildSticker",
)
if TYPE_CHECKING:
@ -102,15 +102,15 @@ class StickerPack(Hashable):
"""
__slots__ = (
'_state',
'id',
'stickers',
'name',
'sku_id',
'cover_sticker_id',
'cover_sticker',
'description',
'_banner',
"_state",
"id",
"stickers",
"name",
"sku_id",
"cover_sticker_id",
"cover_sticker",
"description",
"_banner",
)
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
@ -118,15 +118,17 @@ class StickerPack(Hashable):
self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data['id'])
stickers = data['stickers']
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
self.name: str = data['name']
self.sku_id: int = int(data['sku_id'])
self.cover_sticker_id: int = int(data['cover_sticker_id'])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data['description']
self._banner: int = int(data['banner_asset_id'])
self.id: int = int(data["id"])
stickers = data["stickers"]
self.stickers: List[StandardSticker] = [
StandardSticker(state=self._state, data=sticker) for sticker in stickers
]
self.name: str = data["name"]
self.sku_id: int = int(data["sku_id"])
self.cover_sticker_id: int = int(data["cover_sticker_id"])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data["description"]
self._banner: int = int(data["banner_asset_id"])
@property
def banner(self) -> Asset:
@ -134,7 +136,7 @@ class StickerPack(Hashable):
return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str:
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
return f"<StickerPack id={self.id} name={self.name!r} description={self.description!r}>"
def __str__(self) -> str:
return self.name
@ -205,17 +207,19 @@ class StickerItem(_StickerTag):
The URL for the sticker's image.
"""
__slots__ = ('_state', 'name', 'id', 'format', 'url')
__slots__ = ("_state", "name", "id", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state
self.name: str = data['name']
self.id: int = int(data['id'])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
self.name: str = data["name"]
self.id: int = int(data["id"])
self.format: StickerFormatType = try_enum(
StickerFormatType, data["format_type"]
)
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
return f"<StickerItem id={self.id} name={self.name!r} format={self.format}>"
def __str__(self) -> str:
return self.name
@ -236,7 +240,7 @@ class StickerItem(_StickerTag):
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore
cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._state, data=data)
@ -275,21 +279,23 @@ class Sticker(_StickerTag):
The URL for the sticker's image.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
__slots__ = ("_state", "id", "name", "description", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self.format: StickerFormatType = try_enum(
StickerFormatType, data["format_type"]
)
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str:
return f'<Sticker id={self.id} name={self.name!r}>'
return f"<Sticker id={self.id} name={self.name!r}>"
def __str__(self) -> str:
return self.name
@ -337,21 +343,23 @@ class StandardSticker(Sticker):
The sticker's sort order within its pack.
"""
__slots__ = ('sort_value', 'pack_id', 'type', 'tags')
__slots__ = ("sort_value", "pack_id", "type", "tags")
def _from_data(self, data: StandardStickerPayload) -> None:
super()._from_data(data)
self.sort_value: int = data['sort_value']
self.pack_id: int = int(data['pack_id'])
self.sort_value: int = data["sort_value"]
self.pack_id: int = int(data["pack_id"])
self.type: StickerType = StickerType.standard
try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")]
except KeyError:
self.tags = []
def __repr__(self) -> str:
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
return (
f"<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>"
)
async def pack(self) -> StickerPack:
"""|coro|
@ -370,13 +378,15 @@ class StandardSticker(Sticker):
:class:`StickerPack`
The retrieved sticker pack.
"""
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data['sticker_packs']
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
data: ListPremiumStickerPacksPayload = (
await self._state.http.list_premium_sticker_packs()
)
packs = data["sticker_packs"]
pack = find(lambda d: int(d["id"]) == self.pack_id, packs)
if pack:
return StickerPack(state=self._state, data=pack)
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
raise InvalidData(f"Could not find corresponding sticker pack for {self!r}")
class GuildSticker(Sticker):
@ -419,21 +429,21 @@ class GuildSticker(Sticker):
The name of a unicode emoji that represents this sticker.
"""
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
__slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild")
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data['available']
self.guild_id: int = int(data['guild_id'])
user = data.get('user')
self.available: bool = data["available"]
self.guild_id: int = int(data["guild_id"])
user = data.get("user")
self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data['tags']
self.emoji: str = data["tags"]
self.type: StickerType = StickerType.guild
def __repr__(self) -> str:
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
return f"<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>"
@cached_slot_property('_cs_guild')
@cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild.
@ -480,10 +490,10 @@ class GuildSticker(Sticker):
payload: EditGuildSticker = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if description is not MISSING:
payload['description'] = description
payload["description"] = description
if emoji is not MISSING:
try:
@ -491,11 +501,13 @@ class GuildSticker(Sticker):
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
emoji = emoji.replace(" ", "_")
payload['tags'] = emoji
payload["tags"] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(
self.guild_id, self.id, payload, reason
)
return GuildSticker(state=self._state, data=data)
async def delete(self, *, reason: Optional[str] = None) -> None:
@ -521,7 +533,9 @@ class GuildSticker(Sticker):
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
def _sticker_factory(
sticker_type: Literal[1, 2]
) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type)
if value == StickerType.standard:
return StandardSticker, value

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
)
__all__ = (
'Team',
'TeamMember',
"Team",
"TeamMember",
)
@ -62,26 +62,28 @@ class Team:
.. versionadded:: 1.3
"""
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members')
__slots__ = ("_state", "id", "name", "_icon", "owner_id", "members")
def __init__(self, state: ConnectionState, data: TeamPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data['icon']
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id')
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']]
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data["icon"]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id")
self.members: List[TeamMember] = [
TeamMember(self, self._state, member) for member in data["members"]
]
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='team')
return Asset._from_icon(self._state, self.id, self._icon, path="team")
@property
def owner(self) -> Optional[TeamMember]:
@ -130,16 +132,18 @@ class TeamMember(BaseUser):
The membership state of the member (e.g. invited or accepted)
"""
__slots__ = ('team', 'membership_state', 'permissions')
__slots__ = ("team", "membership_state", "permissions")
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload):
self.team: Team = team
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
self.permissions: List[str] = data['permissions']
super().__init__(state=state, data=data['user'])
self.membership_state: TeamMembershipState = try_enum(
TeamMembershipState, data["membership_state"]
)
self.permissions: List[str] = data["permissions"]
super().__init__(state=state, data=data["user"])
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>"
)

View File

@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion
from .guild import Guild
__all__ = (
'Template',
)
__all__ = ("Template",)
if TYPE_CHECKING:
import datetime
@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
raise AttributeError('PartialTemplateState does not support http methods.')
raise AttributeError("PartialTemplateState does not support http methods.")
class _PartialTemplateState:
@ -84,7 +82,7 @@ class _PartialTemplateState:
return []
def __getattr__(self, attr):
raise AttributeError(f'PartialTemplateState does not support {attr!r}.')
raise AttributeError(f"PartialTemplateState does not support {attr!r}.")
class Template:
@ -118,16 +116,16 @@ class Template:
"""
__slots__ = (
'code',
'uses',
'name',
'description',
'creator',
'created_at',
'updated_at',
'source_guild',
'is_dirty',
'_state',
"code",
"uses",
"name",
"description",
"creator",
"created_at",
"updated_at",
"source_guild",
"is_dirty",
"_state",
)
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
@ -135,38 +133,46 @@ class Template:
self._store(data)
def _store(self, data: TemplatePayload) -> None:
self.code: str = data['code']
self.uses: int = data['usage_count']
self.name: str = data['name']
self.description: Optional[str] = data['description']
creator_data = data.get('creator')
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
self.code: str = data["code"]
self.uses: int = data["usage_count"]
self.name: str = data["name"]
self.description: Optional[str] = data["description"]
creator_data = data.get("creator")
self.creator: Optional[User] = (
None if creator_data is None else self._state.create_user(creator_data)
)
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
self.created_at: Optional[datetime.datetime] = parse_time(
data.get("created_at")
)
self.updated_at: Optional[datetime.datetime] = parse_time(
data.get("updated_at")
)
guild_id = int(data['source_guild_id'])
guild_id = int(data["source_guild_id"])
guild: Optional[Guild] = self._state._get_guild(guild_id)
self.source_guild: Guild
if guild is None:
source_serialised = data['serialized_source_guild']
source_serialised['id'] = guild_id
source_serialised = data["serialized_source_guild"]
source_serialised["id"] = guild_id
state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
self.is_dirty: Optional[bool] = data.get("is_dirty", None)
def __repr__(self) -> str:
return (
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}'
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
f"<Template code={self.code!r} uses={self.uses} name={self.name!r}"
f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>"
)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
async def create_guild(
self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None
) -> Guild:
"""|coro|
Creates a :class:`.Guild` using the template.
@ -203,7 +209,9 @@ class Template:
region = region or VoiceRegion.us_west
region_value = region.value
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
data = await self._state.http.create_from_template(
self.code, name, region_value, icon
)
return Guild(data=data, state=self._state)
async def sync(self) -> Template:
@ -279,11 +287,13 @@ class Template:
payload = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if description is not MISSING:
payload['description'] = description
payload["description"] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
data = await self._state.http.edit_template(
self.source_guild.id, self.code, payload
)
return Template(state=self._state, data=data)
async def delete(self) -> None:
@ -310,7 +320,7 @@ class Template:
@property
def url(self) -> str:
""":class:`str`: The template url.
.. versionadded:: 2.0
"""
return f'https://discord.new/{self.code}'
return f"https://discord.new/{self.code}"

View File

@ -35,8 +35,8 @@ from .errors import ClientException
from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = (
'Thread',
'ThreadMember',
"Thread",
"ThreadMember",
)
if TYPE_CHECKING:
@ -128,25 +128,25 @@ class Thread(Messageable, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'_type',
'_state',
'_members',
'owner_id',
'parent_id',
'last_message_id',
'message_count',
'member_count',
'slowmode_delay',
'me',
'locked',
'archived',
'invitable',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
"name",
"id",
"guild",
"_type",
"_state",
"_members",
"owner_id",
"parent_id",
"last_message_id",
"message_count",
"member_count",
"slowmode_delay",
"me",
"locked",
"archived",
"invitable",
"archiver_id",
"auto_archive_duration",
"archive_timestamp",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
@ -160,50 +160,50 @@ class Thread(Messageable, Hashable):
def __repr__(self) -> str:
return (
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}'
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>'
f"<Thread id={self.id!r} name={self.name!r} parent={self.parent}"
f" owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>"
)
def __str__(self) -> str:
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self._unroll_metadata(data['thread_metadata'])
self.id = int(data["id"])
self.parent_id = int(data["parent_id"])
self.owner_id = int(data["owner_id"])
self.name = data["name"]
self._type = try_enum(ChannelType, data["type"])
self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self._unroll_metadata(data["thread_metadata"])
try:
member = data['member']
member = data["member"]
except KeyError:
self.me = None
else:
self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
self.archived = data["archived"]
self.archiver_id = _get_as_snowflake(data, "archiver_id")
self.auto_archive_duration = data["auto_archive_duration"]
self.archive_timestamp = parse_time(data["archive_timestamp"])
self.locked = data.get("locked", False)
self.invitable = data.get("invitable", True)
def _update(self, data):
try:
self.name = data['name']
self.name = data["name"]
except KeyError:
pass
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.slowmode_delay = data.get("rate_limit_per_user", 0)
try:
self._unroll_metadata(data['thread_metadata'])
self._unroll_metadata(data["thread_metadata"])
except KeyError:
pass
@ -225,7 +225,7 @@ class Thread(Messageable, Hashable):
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def members(self) -> List[ThreadMember]:
@ -256,7 +256,11 @@ class Thread(Messageable, Hashable):
Optional[:class:`Message`]
The last message in this channel or ``None`` if not found.
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None
return (
self._state._get_message(self.last_message_id)
if self.last_message_id
else None
)
@property
def category(self) -> Optional[CategoryChannel]:
@ -275,9 +279,9 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.category
@property
def category_id(self) -> Optional[int]:
"""The category channel ID the parent channel belongs to, if applicable.
@ -295,7 +299,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.category_id
def is_private(self) -> bool:
@ -352,7 +356,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.permissions_for(obj)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@ -402,7 +406,7 @@ class Thread(Messageable, Hashable):
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
@ -477,11 +481,19 @@ class Thread(Messageable, Hashable):
if check is MISSING:
check = lambda m: True
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
iterator = self.history(
limit=limit,
before=before,
after=after,
oldest_first=oldest_first,
around=around,
)
ret: List[Message] = []
count = 0
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
minimum_time = (
int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
)
async def _single_delete_strategy(messages: Iterable[Message]):
for m in messages:
@ -577,17 +589,17 @@ class Thread(Messageable, Hashable):
"""
payload = {}
if name is not MISSING:
payload['name'] = str(name)
payload["name"] = str(name)
if archived is not MISSING:
payload['archived'] = archived
payload["archived"] = archived
if auto_archive_duration is not MISSING:
payload['auto_archive_duration'] = auto_archive_duration
payload["auto_archive_duration"] = auto_archive_duration
if locked is not MISSING:
payload['locked'] = locked
payload["locked"] = locked
if invitable is not MISSING:
payload['invitable'] = invitable
payload["invitable"] = invitable
if slowmode_delay is not MISSING:
payload['rate_limit_per_user'] = slowmode_delay
payload["rate_limit_per_user"] = slowmode_delay
data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload
@ -773,12 +785,12 @@ class ThreadMember(Hashable):
"""
__slots__ = (
'id',
'thread_id',
'joined_at',
'flags',
'_state',
'parent',
"id",
"thread_id",
"joined_at",
"flags",
"_state",
"parent",
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
@ -787,22 +799,22 @@ class ThreadMember(Hashable):
self._from_data(data)
def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
return f"<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>"
def _from_data(self, data: ThreadMemberPayload):
try:
self.id = int(data['user_id'])
self.id = int(data["user_id"])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
try:
self.thread_id = int(data['id'])
self.thread_id = int(data["id"])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
self.joined_at = parse_time(data["join_timestamp"])
self.flags = data["flags"]
@property
def thread(self) -> Thread:

View File

@ -29,7 +29,7 @@ from .user import PartialUser
from .snowflake import Snowflake
StatusType = Literal['idle', 'dnd', 'online', 'offline']
StatusType = Literal["idle", "dnd", "online", "offline"]
class PartialPresenceUpdate(TypedDict):

View File

@ -30,6 +30,7 @@ from .user import User
from .team import Team
from .snowflake import Snowflake
class BaseAppInfo(TypedDict):
id: Snowflake
name: str
@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict):
summary: str
description: str
class _AppInfoOptional(TypedDict, total=False):
team: Team
guild_id: Snowflake
@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False):
hook: bool
max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str]
owner: User
bot_public: bool
bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str]
cover_image: str
@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
max_participants: int
flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass

View File

@ -26,7 +26,12 @@ from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
from .webhook import Webhook
from .guild import MFALevel, VerificationLevel, ExplicitContentFilterLevel, DefaultMessageNotificationLevel
from .guild import (
MFALevel,
VerificationLevel,
ExplicitContentFilterLevel,
DefaultMessageNotificationLevel,
)
from .integration import IntegrationExpireBehavior, PartialIntegration
from .user import User
from .snowflake import Snowflake
@ -84,31 +89,47 @@ AuditLogEvent = Literal[
class _AuditLogChange_Str(TypedDict):
key: Literal[
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
"name",
"description",
"preferred_locale",
"vanity_url_code",
"topic",
"code",
"allow",
"deny",
"permissions",
"tags",
]
new_value: str
old_value: str
class _AuditLogChange_AssetHash(TypedDict):
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
key: Literal[
"icon_hash",
"splash_hash",
"discovery_splash_hash",
"banner_hash",
"avatar_hash",
"asset",
]
new_value: str
old_value: str
class _AuditLogChange_Snowflake(TypedDict):
key: Literal[
'id',
'owner_id',
'afk_channel_id',
'rules_channel_id',
'public_updates_channel_id',
'widget_channel_id',
'system_channel_id',
'application_id',
'channel_id',
'inviter_id',
'guild_id',
"id",
"owner_id",
"afk_channel_id",
"rules_channel_id",
"public_updates_channel_id",
"widget_channel_id",
"system_channel_id",
"application_id",
"channel_id",
"inviter_id",
"guild_id",
]
new_value: Snowflake
old_value: Snowflake
@ -116,20 +137,20 @@ class _AuditLogChange_Snowflake(TypedDict):
class _AuditLogChange_Bool(TypedDict):
key: Literal[
'widget_enabled',
'nsfw',
'hoist',
'mentionable',
'temporary',
'deaf',
'mute',
'nick',
'enabled_emoticons',
'region',
'rtc_region',
'available',
'archived',
'locked',
"widget_enabled",
"nsfw",
"hoist",
"mentionable",
"temporary",
"deaf",
"mute",
"nick",
"enabled_emoticons",
"region",
"rtc_region",
"available",
"archived",
"locked",
]
new_value: bool
old_value: bool
@ -137,72 +158,72 @@ class _AuditLogChange_Bool(TypedDict):
class _AuditLogChange_Int(TypedDict):
key: Literal[
'afk_timeout',
'prune_delete_days',
'position',
'bitrate',
'rate_limit_per_user',
'color',
'max_uses',
'max_age',
'user_limit',
'auto_archive_duration',
'default_auto_archive_duration',
"afk_timeout",
"prune_delete_days",
"position",
"bitrate",
"rate_limit_per_user",
"color",
"max_uses",
"max_age",
"user_limit",
"auto_archive_duration",
"default_auto_archive_duration",
]
new_value: int
old_value: int
class _AuditLogChange_ListRole(TypedDict):
key: Literal['$add', '$remove']
key: Literal["$add", "$remove"]
new_value: List[Role]
old_value: List[Role]
class _AuditLogChange_MFALevel(TypedDict):
key: Literal['mfa_level']
key: Literal["mfa_level"]
new_value: MFALevel
old_value: MFALevel
class _AuditLogChange_VerificationLevel(TypedDict):
key: Literal['verification_level']
key: Literal["verification_level"]
new_value: VerificationLevel
old_value: VerificationLevel
class _AuditLogChange_ExplicitContentFilter(TypedDict):
key: Literal['explicit_content_filter']
key: Literal["explicit_content_filter"]
new_value: ExplicitContentFilterLevel
old_value: ExplicitContentFilterLevel
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
key: Literal['default_message_notifications']
key: Literal["default_message_notifications"]
new_value: DefaultMessageNotificationLevel
old_value: DefaultMessageNotificationLevel
class _AuditLogChange_ChannelType(TypedDict):
key: Literal['type']
key: Literal["type"]
new_value: ChannelType
old_value: ChannelType
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
key: Literal['expire_behavior']
key: Literal["expire_behavior"]
new_value: IntegrationExpireBehavior
old_value: IntegrationExpireBehavior
class _AuditLogChange_VideoQualityMode(TypedDict):
key: Literal['video_quality_mode']
key: Literal["video_quality_mode"]
new_value: VideoQualityMode
old_value: VideoQualityMode
class _AuditLogChange_Overwrites(TypedDict):
key: Literal['permission_overwrites']
key: Literal["permission_overwrites"]
new_value: List[PermissionOverwrite]
old_value: List[PermissionOverwrite]
@ -232,7 +253,7 @@ class AuditEntryInfo(TypedDict):
message_id: Snowflake
count: str
id: Snowflake
type: Literal['0', '1']
type: Literal["0", "1"]
role_name: str

View File

@ -128,7 +128,15 @@ class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
thread_metadata: ThreadMetadata
GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StoreChannel, StageChannel, ThreadChannel]
GuildChannel = Union[
TextChannel,
NewsChannel,
VoiceChannel,
CategoryChannel,
StoreChannel,
StageChannel,
ThreadChannel,
]
class DMChannel(_BaseChannel):

View File

@ -24,49 +24,60 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, TypedDict
class _EmbedFooterOptional(TypedDict, total=False):
icon_url: str
proxy_icon_url: str
class EmbedFooter(_EmbedFooterOptional):
text: str
class _EmbedFieldOptional(TypedDict, total=False):
inline: bool
class EmbedField(_EmbedFieldOptional):
name: str
value: str
class EmbedThumbnail(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedVideo(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedImage(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedProvider(TypedDict, total=False):
name: str
url: str
class EmbedAuthor(TypedDict, total=False):
name: str
url: str
icon_url: str
proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal["rich", "image", "video", "gifv", "article", "link"]
class Embed(TypedDict, total=False):
title: str

View File

@ -75,28 +75,28 @@ VerificationLevel = Literal[0, 1, 2, 3, 4]
NSFWLevel = Literal[0, 1, 2, 3]
PremiumTier = Literal[0, 1, 2, 3]
GuildFeature = Literal[
'ANIMATED_ICON',
'BANNER',
'COMMERCE',
'COMMUNITY',
'DISCOVERABLE',
'FEATURABLE',
'INVITE_SPLASH',
'MEMBER_VERIFICATION_GATE_ENABLED',
'MONETIZATION_ENABLED',
'MORE_EMOJI',
'MORE_STICKERS',
'NEWS',
'PARTNERED',
'PREVIEW_ENABLED',
'PRIVATE_THREADS',
'SEVEN_DAY_THREAD_ARCHIVE',
'THREE_DAY_THREAD_ARCHIVE',
'TICKETED_EVENTS_ENABLED',
'VANITY_URL',
'VERIFIED',
'VIP_REGIONS',
'WELCOME_SCREEN_ENABLED',
"ANIMATED_ICON",
"BANNER",
"COMMERCE",
"COMMUNITY",
"DISCOVERABLE",
"FEATURABLE",
"INVITE_SPLASH",
"MEMBER_VERIFICATION_GATE_ENABLED",
"MONETIZATION_ENABLED",
"MORE_EMOJI",
"MORE_STICKERS",
"NEWS",
"PARTNERED",
"PREVIEW_ENABLED",
"PRIVATE_THREADS",
"SEVEN_DAY_THREAD_ARCHIVE",
"THREE_DAY_THREAD_ARCHIVE",
"TICKETED_EVENTS_ENABLED",
"VANITY_URL",
"VERIFIED",
"VIP_REGIONS",
"WELCOME_SCREEN_ENABLED",
]

View File

@ -56,7 +56,7 @@ class PartialIntegration(TypedDict):
account: IntegrationAccount
IntegrationType = Literal['twitch', 'youtube', 'discord']
IntegrationType = Literal["twitch", "youtube", "discord"]
class BaseIntegration(PartialIntegration):

View File

@ -39,6 +39,7 @@ if TYPE_CHECKING:
ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
type: ApplicationCommandType
@ -100,32 +101,44 @@ class _ApplicationCommandInteractionDataOption(TypedDict):
name: str
class _ApplicationCommandInteractionDataOptionSubcommand(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionSubcommand(
_ApplicationCommandInteractionDataOption
):
type: Literal[1, 2]
options: List[ApplicationCommandInteractionDataOption]
class _ApplicationCommandInteractionDataOptionString(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionString(
_ApplicationCommandInteractionDataOption
):
type: Literal[3]
value: str
class _ApplicationCommandInteractionDataOptionInteger(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionInteger(
_ApplicationCommandInteractionDataOption
):
type: Literal[4]
value: int
class _ApplicationCommandInteractionDataOptionBoolean(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionBoolean(
_ApplicationCommandInteractionDataOption
):
type: Literal[5]
value: bool
class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionSnowflake(
_ApplicationCommandInteractionDataOption
):
type: Literal[6, 7, 8, 9]
value: Snowflake
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
class _ApplicationCommandInteractionDataOptionNumber(
_ApplicationCommandInteractionDataOption
):
type: Literal[10]
value: float
@ -222,9 +235,6 @@ class MessageInteraction(TypedDict):
user: User
class _EditApplicationCommandOptional(TypedDict, total=False):
description: str
options: Optional[List[ApplicationCommandOption]]

View File

@ -128,7 +128,7 @@ class Message(_MessageOptional):
type: MessageType
AllowedMentionType = Literal['roles', 'users', 'everyone']
AllowedMentionType = Literal["roles", "users", "everyone"]
class AllowedMentions(TypedDict):

View File

@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional
from .user import PartialUser
from .snowflake import Snowflake
class TeamMember(TypedDict):
user: PartialUser
membership_state: int
permissions: List[str]
team_id: Snowflake
class Team(TypedDict):
id: Snowflake
name: str

View File

@ -27,7 +27,9 @@ from .snowflake import Snowflake
from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
SupportedModes = Literal[
"xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"
]
class _PartialVoiceStateOptional(TypedDict, total=False):

View File

@ -35,16 +35,16 @@ from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent
__all__ = (
'Button',
'button',
"Button",
"button",
)
if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
B = TypeVar('B', bound='Button')
V = TypeVar('V', bound='View', covariant=True)
B = TypeVar("B", bound="Button")
V = TypeVar("V", bound="View", covariant=True)
class Button(Item[V]):
@ -76,12 +76,12 @@ class Button(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
'style',
'url',
'disabled',
'label',
'emoji',
'row',
"style",
"url",
"disabled",
"label",
"emoji",
"row",
)
def __init__(
@ -97,7 +97,7 @@ class Button(Item[V]):
):
super().__init__()
if custom_id is not None and url is not None:
raise TypeError('cannot mix both url and custom_id with Button')
raise TypeError("cannot mix both url and custom_id with Button")
self._provided_custom_id = custom_id is not None
if url is None and custom_id is None:
@ -112,7 +112,9 @@ class Button(Item[V]):
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
raise TypeError(
f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}"
)
self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
@ -145,7 +147,7 @@ class Button(Item[V]):
@custom_id.setter
def custom_id(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str')
raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value
@ -157,7 +159,7 @@ class Button(Item[V]):
@url.setter
def url(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str')
raise TypeError("url must be None or str")
self._underlying.url = value
@property
@ -191,7 +193,9 @@ class Button(Item[V]):
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
raise TypeError(
f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead"
)
else:
self._underlying.emoji = None
@ -273,17 +277,17 @@ def button(
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
raise TypeError("button function must be a coroutine function")
func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = {
'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'row': row,
"style": style,
"custom_id": custom_id,
"url": None,
"disabled": disabled,
"label": label,
"emoji": emoji,
"row": row,
}
return func

View File

@ -24,21 +24,30 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generic,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
)
from ..interactions import Interaction
__all__ = (
'Item',
)
__all__ = ("Item",)
if TYPE_CHECKING:
from ..enums import ComponentType
from .view import View
from ..components import Component
I = TypeVar('I', bound='Item')
V = TypeVar('V', bound='View', covariant=True)
I = TypeVar("I", bound="Item")
V = TypeVar("V", bound="View", covariant=True)
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
@ -53,7 +62,7 @@ class Item(Generic[V]):
.. versionadded:: 2.0
"""
__item_repr_attributes__: Tuple[str, ...] = ('row',)
__item_repr_attributes__: Tuple[str, ...] = ("row",)
def __init__(self):
self._view: Optional[V] = None
@ -91,8 +100,10 @@ class Item(Generic[V]):
return self._provided_custom_id
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
return f'<{self.__class__.__name__} {attrs}>'
attrs = " ".join(
f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__
)
return f"<{self.__class__.__name__} {attrs}>"
@property
def row(self) -> Optional[int]:
@ -105,7 +116,7 @@ class Item(Generic[V]):
elif 5 > value >= 0:
self._row = value
else:
raise ValueError('row cannot be negative or greater than or equal to 5')
raise ValueError("row cannot be negative or greater than or equal to 5")
@property
def width(self) -> int:

View File

@ -39,8 +39,8 @@ from ..components import (
)
__all__ = (
'Select',
'select',
"Select",
"select",
)
if TYPE_CHECKING:
@ -50,8 +50,8 @@ if TYPE_CHECKING:
ComponentInteractionData,
)
S = TypeVar('S', bound='Select')
V = TypeVar('V', bound='View', covariant=True)
S = TypeVar("S", bound="Select")
V = TypeVar("V", bound="View", covariant=True)
class Select(Item[V]):
@ -89,11 +89,11 @@ class Select(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
)
def __init__(
@ -131,7 +131,7 @@ class Select(Item[V]):
@custom_id.setter
def custom_id(self, value: str):
if not isinstance(value, str):
raise TypeError('custom_id must be None or str')
raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value
@ -143,7 +143,7 @@ class Select(Item[V]):
@placeholder.setter
def placeholder(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str')
raise TypeError("placeholder must be None or str")
self._underlying.placeholder = value
@ -173,9 +173,9 @@ class Select(Item[V]):
@options.setter
def options(self, value: List[SelectOption]):
if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption')
raise TypeError("options must be a list of SelectOption")
if not all(isinstance(obj, SelectOption) for obj in value):
raise TypeError('all list items must subclass SelectOption')
raise TypeError("all list items must subclass SelectOption")
self._underlying.options = value
@ -224,7 +224,6 @@ class Select(Item[V]):
default=default,
)
self.append_option(option)
def append_option(self, option: SelectOption):
@ -242,7 +241,7 @@ class Select(Item[V]):
"""
if len(self._underlying.options) > 25:
raise ValueError('maximum number of options already provided')
raise ValueError("maximum number of options already provided")
self._underlying.options.append(option)
@ -272,7 +271,7 @@ class Select(Item[V]):
def refresh_state(self, interaction: Interaction) -> None:
data: ComponentInteractionData = interaction.data # type: ignore
self._selected_values = data.get('values', [])
self._selected_values = data.get("values", [])
@classmethod
def from_component(cls: Type[S], component: SelectMenu) -> S:
@ -340,17 +339,17 @@ def select(
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
raise TypeError("select function must be a coroutine function")
func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder,
'custom_id': custom_id,
'row': row,
'min_values': min_values,
'max_values': max_values,
'options': options,
'disabled': disabled,
"placeholder": placeholder,
"custom_id": custom_id,
"row": row,
"min_values": min_values,
"max_values": max_values,
"options": options,
"disabled": disabled,
}
return func

View File

@ -23,7 +23,18 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Tuple
from typing import (
Any,
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
)
from functools import partial
from itertools import groupby
@ -41,9 +52,7 @@ from ..components import (
SelectMenu as SelectComponent,
)
__all__ = (
'View',
)
__all__ = ("View",)
if TYPE_CHECKING:
@ -74,9 +83,7 @@ def _component_to_item(component: Component) -> Item:
class _ViewWeights:
__slots__ = (
'weights',
)
__slots__ = ("weights",)
def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0]
@ -92,13 +99,15 @@ class _ViewWeights:
if weight + item.width <= 5:
return index
raise ValueError('could not find open space for item')
raise ValueError("could not find open space for item")
def add_item(self, item: Item) -> None:
if item.row is not None:
total = self.weights[item.row] + item.width
if total > 5:
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
raise ValueError(
f"item would not fit at row {item.row} ({total} > 5 width)"
)
self.weights[item.row] = total
item._rendered_row = item.row
else:
@ -144,11 +153,11 @@ class View:
children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
if hasattr(member, '__discord_ui_model_type__'):
if hasattr(member, "__discord_ui_model_type__"):
children.append(member)
if len(children) > 25:
raise TypeError('View cannot have more than 25 children')
raise TypeError("View cannot have more than 25 children")
cls.__view_children_items__ = children
@ -156,7 +165,9 @@ class View:
self.timeout = timeout
self.children: List[Item] = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
item: Item = func.__discord_ui_model_type__(
**func.__discord_ui_model_kwargs__
)
item.callback = partial(func, self, item)
item._view = self
setattr(self, func.__name__, item)
@ -171,7 +182,7 @@ class View:
self.__stopped: asyncio.Future[bool] = loop.create_future()
def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>"
async def __timeout_task_impl(self) -> None:
while True:
@ -203,15 +214,17 @@ class View:
components.append(
{
'type': 1,
'components': children,
"type": 1,
"components": children,
}
)
return components
@classmethod
def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View:
def from_message(
cls, message: Message, /, *, timeout: Optional[float] = 180.0
) -> View:
"""Converts a message's components into a :class:`View`.
The :attr:`.Message.components` of a message are read-only
@ -261,10 +274,10 @@ class View:
"""
if len(self.children) > 25:
raise ValueError('maximum number of children exceeded')
raise ValueError("maximum number of children exceeded")
if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}')
raise TypeError(f"expected Item not {item.__class__!r}")
self.__weights.add_item(item)
@ -327,7 +340,9 @@ class View:
"""
pass
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
async def on_error(
self, error: Exception, item: Item, interaction: Interaction
) -> None:
"""|coro|
A callback that is called when an item's callback or :meth:`interaction_check`
@ -344,8 +359,10 @@ class View:
interaction: :class:`~discord.Interaction`
The interaction that led to the failure.
"""
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
traceback.print_exception(
error.__class__, error, error.__traceback__, file=sys.stderr
)
async def _scheduled_task(self, item: Item, interaction: Interaction):
try:
@ -377,13 +394,18 @@ class View:
return
self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
asyncio.create_task(
self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}"
)
def _dispatch_item(self, item: Item, interaction: Interaction):
if self.__stopped.done():
return
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
asyncio.create_task(
self._scheduled_task(item, interaction),
name=f"discord-ui-view-dispatch-{self.id}",
)
def refresh(self, components: List[Component]):
# This is pretty hacky at the moment
@ -437,7 +459,9 @@ class View:
A persistent view has all their components with a set ``custom_id`` and
a :attr:`timeout` set to ``None``.
"""
return self.timeout is None and all(item.is_persistent() for item in self.children)
return self.timeout is None and all(
item.is_persistent() for item in self.children
)
async def wait(self) -> bool:
"""Waits until the view has finished interacting.
@ -509,7 +533,9 @@ class ViewStore:
key = (component_type, message_id, custom_id)
# Fallback to None message_id searches in case a persistent view
# was added without an associated message_id
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
value = self._views.get(key) or self._views.get(
(component_type, None, custom_id)
)
if value is None:
return

View File

@ -45,11 +45,11 @@ if TYPE_CHECKING:
__all__ = (
'User',
'ClientUser',
"User",
"ClientUser",
)
BU = TypeVar('BU', bound='BaseUser')
BU = TypeVar("BU", bound="BaseUser")
class _UserTag:
@ -59,16 +59,16 @@ class _UserTag:
class BaseUser(_UserTag):
__slots__ = (
'name',
'id',
'discriminator',
'_avatar',
'_banner',
'_accent_colour',
'bot',
'system',
'_public_flags',
'_state',
"name",
"id",
"discriminator",
"_avatar",
"_banner",
"_accent_colour",
"bot",
"system",
"_public_flags",
"_state",
)
if TYPE_CHECKING:
@ -94,7 +94,7 @@ class BaseUser(_UserTag):
)
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
return f"{self.name}#{self.discriminator}"
def __int__(self) -> int:
return self.id
@ -109,15 +109,15 @@ class BaseUser(_UserTag):
return self.id >> 22
def _update(self, data: UserPayload) -> None:
self.name = data['username']
self.id = int(data['id'])
self.discriminator = data['discriminator']
self._avatar = data['avatar']
self._banner = data.get('banner', None)
self._accent_colour = data.get('accent_color', None)
self._public_flags = data.get('public_flags', 0)
self.bot = data.get('bot', False)
self.system = data.get('system', False)
self.name = data["username"]
self.id = int(data["id"])
self.discriminator = data["discriminator"]
self._avatar = data["avatar"]
self._banner = data.get("banner", None)
self._accent_colour = data.get("accent_color", None)
self._public_flags = data.get("public_flags", 0)
self.bot = data.get("bot", False)
self.system = data.get("system", False)
@classmethod
def _copy(cls: Type[BU], user: BU) -> BU:
@ -137,11 +137,11 @@ class BaseUser(_UserTag):
def _to_minimal_user_json(self) -> Dict[str, Any]:
return {
'username': self.name,
'id': self.id,
'avatar': self._avatar,
'discriminator': self.discriminator,
'bot': self.bot,
"username": self.name,
"id": self.id,
"avatar": self._avatar,
"discriminator": self.discriminator,
"bot": self.bot,
}
@property
@ -163,7 +163,9 @@ class BaseUser(_UserTag):
@property
def default_avatar(self) -> Asset:
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
return Asset._from_default_avatar(
self._state, int(self.discriminator) % len(DefaultAvatar)
)
@property
def display_avatar(self) -> Asset:
@ -240,7 +242,7 @@ class BaseUser(_UserTag):
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>'
return f"<@{self.id}>"
@property
def created_at(self) -> datetime:
@ -324,7 +326,7 @@ class ClientUser(BaseUser):
Specifies if the user has MFA turned on and working.
"""
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
__slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__")
if TYPE_CHECKING:
verified: bool
@ -337,19 +339,21 @@ class ClientUser(BaseUser):
def __repr__(self) -> str:
return (
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
f"<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>"
)
def _update(self, data: UserPayload) -> None:
super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False)
self.locale = data.get('locale')
self._flags = data.get('flags', 0)
self.mfa_enabled = data.get('mfa_enabled', False)
self.verified = data.get("verified", False)
self.locale = data.get("locale")
self._flags = data.get("flags", 0)
self.mfa_enabled = data.get("mfa_enabled", False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
async def edit(
self, *, username: str = MISSING, avatar: bytes = MISSING
) -> ClientUser:
"""|coro|
Edits the current profile of the client.
@ -388,10 +392,10 @@ class ClientUser(BaseUser):
"""
payload: Dict[str, Any] = {}
if username is not MISSING:
payload['username'] = username
payload["username"] = username
if avatar is not MISSING:
payload['avatar'] = _bytes_to_base64_data(avatar)
payload["avatar"] = _bytes_to_base64_data(avatar)
data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data)
@ -436,14 +440,14 @@ class User(BaseUser, discord.abc.Messageable):
Specifies if the user is a system user (i.e. represents Discord officially).
"""
__slots__ = ('_stored',)
__slots__ = ("_stored",)
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
self._stored: bool = False
def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
return f"<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>"
def __del__(self) -> None:
try:
@ -481,7 +485,9 @@ class User(BaseUser, discord.abc.Messageable):
.. versionadded:: 1.7
"""
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
return [
guild for guild in self._state._guilds.values() if guild.get_member(self.id)
]
async def create_dm(self) -> DMChannel:
"""|coro|

View File

@ -72,18 +72,18 @@ else:
__all__ = (
'oauth_url',
'snowflake_time',
'time_snowflake',
'find',
'get',
'sleep_until',
'utcnow',
'remove_markdown',
'escape_markdown',
'escape_mentions',
'as_chunks',
'format_dt',
"oauth_url",
"snowflake_time",
"time_snowflake",
"find",
"get",
"sleep_until",
"utcnow",
"remove_markdown",
"escape_markdown",
"escape_mentions",
"as_chunks",
"format_dt",
)
DISCORD_EPOCH = 1420070400000
@ -97,7 +97,7 @@ class _MissingSentinel:
return False
def __repr__(self):
return '...'
return "..."
MISSING: Any = _MissingSentinel()
@ -106,7 +106,7 @@ MISSING: Any = _MissingSentinel()
class _cached_property:
def __init__(self, function):
self.function = function
self.__doc__ = getattr(function, '__doc__')
self.__doc__ = getattr(function, "__doc__")
def __get__(self, instance, owner):
if instance is None:
@ -131,15 +131,14 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Mapping[str, Any]
P = ParamSpec('P')
P = ParamSpec("P")
else:
cached_property = _cached_property
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
_Iter = Union[Iterator[T], AsyncIterator[T]]
@ -147,7 +146,7 @@ class CachedSlotProperty(Generic[T, T_co]):
def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.name = name
self.function = function
self.__doc__ = getattr(function, '__doc__')
self.__doc__ = getattr(function, "__doc__")
@overload
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
@ -177,10 +176,12 @@ class classproperty(Generic[T_co]):
return self.fget(owner)
def __set__(self, instance, value) -> None:
raise AttributeError('cannot set attribute')
raise AttributeError("cannot set attribute")
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
def cached_slot_property(
name: str,
) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
return CachedSlotProperty(name, func)
@ -245,18 +246,22 @@ def copy_doc(original: Callable) -> Callable[[T], T]:
return decorator
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
def deprecated(
instead: Optional[str] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.simplefilter("always", DeprecationWarning) # turn off filter
if instead:
fmt = "{0.__name__} is deprecated, use {1} instead."
else:
fmt = '{0.__name__} is deprecated.'
fmt = "{0.__name__} is deprecated."
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
warnings.simplefilter('default', DeprecationWarning) # reset filter
warnings.warn(
fmt.format(func, instead), stacklevel=3, category=DeprecationWarning
)
warnings.simplefilter("default", DeprecationWarning) # reset filter
return func(*args, **kwargs)
return decorated
@ -301,18 +306,18 @@ def oauth_url(
:class:`str`
The OAuth2 URL for inviting the bot into guilds.
"""
url = f'https://discord.com/oauth2/authorize?client_id={client_id}'
url += '&scope=' + '+'.join(scopes or ('bot',))
url = f"https://discord.com/oauth2/authorize?client_id={client_id}"
url += "&scope=" + "+".join(scopes or ("bot",))
if permissions is not MISSING:
url += f'&permissions={permissions.value}'
url += f"&permissions={permissions.value}"
if guild is not MISSING:
url += f'&guild_id={guild.id}'
url += f"&guild_id={guild.id}"
if redirect_uri is not MISSING:
from urllib.parse import urlencode
url += '&response_type=code&' + urlencode({'redirect_uri': redirect_uri})
url += "&response_type=code&" + urlencode({"redirect_uri": redirect_uri})
if disable_guild_select:
url += '&disable_guild_select=true'
url += "&disable_guild_select=true"
return url
@ -435,13 +440,15 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace('__', '.'))
pred = attrget(k.replace("__", "."))
for elem in iterable:
if pred(elem) == v:
return elem
return None
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
converted = [
(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()
]
for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
@ -463,46 +470,48 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
def _get_mime_type_for_image(data: bytes):
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
return 'image/png'
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
return 'image/jpeg'
elif data.startswith((b'\x47\x49\x46\x38\x37\x61', b'\x47\x49\x46\x38\x39\x61')):
return 'image/gif'
elif data.startswith(b'RIFF') and data[8:12] == b'WEBP':
return 'image/webp'
if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"):
return "image/png"
elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
return "image/jpeg"
elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")):
return "image/gif"
elif data.startswith(b"RIFF") and data[8:12] == b"WEBP":
return "image/webp"
else:
raise InvalidArgument('Unsupported image type given')
raise InvalidArgument("Unsupported image type given")
def _bytes_to_base64_data(data: bytes) -> str:
fmt = 'data:{mime};base64,{data}'
fmt = "data:{mime};base64,{data}"
mime = _get_mime_type_for_image(data)
b64 = b64encode(data).decode('ascii')
b64 = b64encode(data).decode("ascii")
return fmt.format(mime=mime, data=b64)
if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode('utf-8')
return orjson.dumps(obj).decode("utf-8")
_from_json = orjson.loads # type: ignore
else:
def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
_from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
reset_after: Optional[str] = request.headers.get("X-Ratelimit-Reset-After")
if use_clock or not reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
reset = datetime.datetime.fromtimestamp(
float(request.headers["X-Ratelimit-Reset"]), utc
)
return (reset - now).total_seconds()
else:
return float(reset_after)
@ -527,7 +536,9 @@ async def async_all(gen, *, check=_isawaitable):
async def sane_wait_for(futures, *, timeout):
ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
done, pending = await asyncio.wait(
ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED
)
if len(pending) != 0:
raise asyncio.TimeoutError()
@ -550,7 +561,9 @@ def compute_timedelta(dt: datetime.datetime):
return max((dt - now).total_seconds(), 0)
async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]:
async def sleep_until(
when: datetime.datetime, result: Optional[T] = None
) -> Optional[T]:
"""|coro|
Sleep until a specified time.
@ -612,7 +625,7 @@ class SnowflakeList(array.array):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
i = bisect_left(self, element)
@ -627,7 +640,7 @@ class SnowflakeList(array.array):
return i != len(self) and self[i] == element
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
_IS_ASCII = re.compile(r"^[\x00-\x7f]+$")
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
@ -636,7 +649,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
if match:
return match.endpos
UNICODE_WIDE_CHAR_TYPE = 'WFA'
UNICODE_WIDE_CHAR_TYPE = "WFA"
func = unicodedata.east_asian_width
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
@ -660,7 +673,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite):
return invite.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)"
m = re.match(rx, invite)
if m:
return m.group(1)
@ -688,22 +701,27 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template):
return code.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)"
m = re.match(rx, code)
if m:
return m.group(1)
return code
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|'))
_MARKDOWN_ESCAPE_SUBREGEX = "|".join(
r"\{0}(?=([\s\S]*((?<!\{0})\{0})))".format(c) for c in ("*", "`", "_", "~", "|")
)
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
_MARKDOWN_ESCAPE_COMMON = r"^>(?:>>)?\s|\[.+\]\(.+\)"
_MARKDOWN_ESCAPE_REGEX = re.compile(fr'(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})', re.MULTILINE)
_MARKDOWN_ESCAPE_REGEX = re.compile(
fr"(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})",
re.MULTILINE,
)
_URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])'
_URL_REGEX = r"(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])"
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
_MARKDOWN_STOCK_REGEX = fr"(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})"
def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
@ -732,15 +750,17 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
def replacement(match):
groupdict = match.groupdict()
return groupdict.get('url', '')
return groupdict.get("url", "")
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})'
regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE)
def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str:
def escape_markdown(
text: str, *, as_needed: bool = False, ignore_links: bool = True
) -> str:
r"""A helper function that escapes Discord's markdown.
Parameters
@ -769,18 +789,18 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
def replacement(match):
groupdict = match.groupdict()
is_url = groupdict.get('url')
is_url = groupdict.get("url")
if is_url:
return is_url
return '\\' + groupdict['markdown']
return "\\" + groupdict["markdown"]
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})'
regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE)
else:
text = re.sub(r'\\', r'\\\\', text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
text = re.sub(r"\\", r"\\\\", text)
return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text)
def escape_mentions(text: str) -> str:
@ -806,7 +826,7 @@ def escape_mentions(text: str) -> str:
:class:`str`
The text with the mentions removed.
"""
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text)
return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text)
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
@ -870,7 +890,7 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
A new iterator which yields chunks of a given size.
"""
if max_size <= 0:
raise ValueError('Chunk sizes must be greater than 0.')
raise ValueError("Chunk sizes must be greater than 0.")
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
@ -916,11 +936,11 @@ def evaluate_annotation(
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
if hasattr(tp, "__args__"):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
@ -938,10 +958,17 @@ def evaluate_annotation(
implicit_str = False
is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
evaluated_args = tuple(
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str)
for arg in args
)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.')
if is_literal and not all(
isinstance(x, (str, int, bool, type(None))) for x in evaluated_args
):
raise TypeError(
"Literal arguments must be of type str, int, bool, or NoneType."
)
if evaluated_args == args:
return tp
@ -971,7 +998,7 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)
TimestampStyle = Literal['f', 'F', 'd', 'D', 't', 'T', 'R']
TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
@ -1015,5 +1042,5 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
The formatted string.
"""
if style is None:
return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>'
return f"<t:{int(dt.timestamp())}>"
return f"<t:{int(dt.timestamp())}:{style}>"

View File

@ -66,26 +66,26 @@ if TYPE_CHECKING:
VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes,
)
has_nacl: bool
try:
import nacl.secret # type: ignore
has_nacl = True
except ImportError:
has_nacl = False
__all__ = (
'VoiceProtocol',
'VoiceClient',
"VoiceProtocol",
"VoiceClient",
)
_log = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@ -195,6 +195,7 @@ class VoiceProtocol:
key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
endpoint_ip: str
voice_port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
@ -258,15 +259,15 @@ class VoiceClient(VoiceProtocol):
warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
"xsalsa20_poly1305_lite",
"xsalsa20_poly1305_suffix",
"xsalsa20_poly1305",
)
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, 'guild', None)
return getattr(self.channel, "guild", None)
@property
def user(self) -> ClientUser:
@ -283,8 +284,8 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
channel_id = data['channel_id']
self.session_id = data["session_id"]
channel_id = data["channel_id"]
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
@ -301,20 +302,22 @@ class VoiceClient(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
_log.info('Ignoring extraneous voice server update.')
_log.info("Ignoring extraneous voice server update.")
return
self.token = data.get('token')
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
self.token = data.get("token")
self.server_id = int(data["guild_id"])
endpoint = data.get("endpoint")
if endpoint is None or self.token is None:
_log.warning('Awaiting endpoint... This requires waiting. ' \
'If timeout occurred considering raising the timeout and reconnecting.')
_log.warning(
"Awaiting endpoint... This requires waiting. "
"If timeout occurred considering raising the timeout and reconnecting."
)
return
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
self.endpoint, _, _ = endpoint.rpartition(":")
if self.endpoint.startswith("wss://"):
# Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:]
@ -335,18 +338,24 @@ class VoiceClient(VoiceProtocol):
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self) -> None:
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
_log.info(
"The voice handshake is being terminated for Channel ID %s (Guild ID %s)",
self.channel.id,
self.guild.id,
)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
_log.info(
"Starting voice handshake... (connection attempt %d)", self._connections + 1
)
self._connections += 1
def finish_handshake(self) -> None:
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
_log.info("Voice handshake complete. Endpoint found %s", self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
@ -359,8 +368,8 @@ class VoiceClient(VoiceProtocol):
self._connected.set()
return ws
async def connect(self, *, reconnect: bool, timeout: float) ->None:
_log.info('Connecting to voice...')
async def connect(self, *, reconnect: bool, timeout: float) -> None:
_log.info("Connecting to voice...")
self.timeout = timeout
for i in range(5):
@ -388,7 +397,7 @@ class VoiceClient(VoiceProtocol):
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
_log.exception('Failed to connect to voice... Retrying...')
_log.exception("Failed to connect to voice... Retrying...")
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
@ -405,7 +414,9 @@ class VoiceClient(VoiceProtocol):
self._potentially_reconnecting = True
try:
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
await asyncio.wait_for(
self._voice_server_complete.wait(), timeout=self.timeout
)
except asyncio.TimeoutError:
self._potentially_reconnecting = False
await self.disconnect(force=True)
@ -453,14 +464,21 @@ class VoiceClient(VoiceProtocol):
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
_log.info(
"Disconnecting from voice normally, close code %d.",
exc.code,
)
await self.disconnect()
break
if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.')
_log.info(
"Disconnected from voice by force... potentially reconnecting."
)
successful = await self.potential_reconnect()
if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
_log.info(
"Reconnect was unsuccessful, disconnecting from voice normally..."
)
await self.disconnect()
break
else:
@ -471,7 +489,9 @@ class VoiceClient(VoiceProtocol):
raise
retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
_log.exception(
"Disconnected from voice... Reconnecting in %.2fs.", retry
)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
@ -479,7 +499,7 @@ class VoiceClient(VoiceProtocol):
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...')
_log.warning("Could not connect to voice... Retrying...")
continue
async def disconnect(self, *, force: bool = False) -> None:
@ -527,11 +547,11 @@ class VoiceClient(VoiceProtocol):
# Formulate rtp header
header[0] = 0x80
header[1] = 0x78
struct.pack_into('>H', header, 2, self.sequence)
struct.pack_into('>I', header, 4, self.timestamp)
struct.pack_into('>I', header, 8, self.ssrc)
struct.pack_into(">H", header, 2, self.sequence)
struct.pack_into(">I", header, 4, self.timestamp)
struct.pack_into(">I", header, 8, self.ssrc)
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
encrypt_packet = getattr(self, "_encrypt_" + self.mode)
return encrypt_packet(header, data)
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
@ -551,12 +571,14 @@ class VoiceClient(VoiceProtocol):
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:4] = struct.pack('>I', self._lite_nonce)
self.checked_add('_lite_nonce', 1, 4294967295)
nonce[:4] = struct.pack(">I", self._lite_nonce)
self.checked_add("_lite_nonce", 1, 4294967295)
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None:
def play(
self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None
) -> None:
"""Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted
@ -586,13 +608,15 @@ class VoiceClient(VoiceProtocol):
"""
if not self.is_connected():
raise ClientException('Not connected to voice.')
raise ClientException("Not connected to voice.")
if self.is_playing():
raise ClientException('Already playing audio.')
raise ClientException("Already playing audio.")
if not isinstance(source, AudioSource):
raise TypeError(f'source must be an AudioSource not {source.__class__.__name__}')
raise TypeError(
f"source must be an AudioSource not {source.__class__.__name__}"
)
if not self.encoder and not source.is_opus():
self.encoder = opus.Encoder()
@ -635,10 +659,10 @@ class VoiceClient(VoiceProtocol):
@source.setter
def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource):
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
raise TypeError(f"expected AudioSource not {value.__class__.__name__}.")
if self._player is None:
raise ValueError('Not playing anything.')
raise ValueError("Not playing anything.")
self._player._set_source(value)
@ -662,7 +686,7 @@ class VoiceClient(VoiceProtocol):
Encoding the data failed.
"""
self.checked_add('sequence', 1, 65535)
self.checked_add("sequence", 1, 65535)
if encode:
encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
else:
@ -671,6 +695,10 @@ class VoiceClient(VoiceProtocol):
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
_log.warning(
"A packet has been dropped (seq: %s, timestamp: %s)",
self.sequence,
self.timestamp,
)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

View File

@ -30,13 +30,30 @@ import json
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, NamedTuple, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
TYPE_CHECKING,
Tuple,
Union,
overload,
)
from contextvars import ContextVar
import aiohttp
from .. import utils
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
from ..errors import (
InvalidArgument,
HTTPException,
Forbidden,
NotFound,
DiscordServerError,
)
from ..message import Message
from ..enums import try_enum, WebhookType
from ..user import BaseUser, User
@ -46,10 +63,10 @@ from ..mixins import Hashable
from ..channel import PartialMessageable
__all__ = (
'Webhook',
'WebhookMessage',
'PartialWebhookChannel',
'PartialWebhookGuild',
"Webhook",
"WebhookMessage",
"PartialWebhookChannel",
"PartialWebhookGuild",
)
_log = logging.getLogger(__name__)
@ -120,14 +137,14 @@ class AsyncWebhookAdapter:
self._locks[bucket] = lock = asyncio.Lock()
if payload is not None:
headers['Content-Type'] = 'application/json'
headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
headers["Authorization"] = f"Bot {auth_token}"
if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[aiohttp.ClientResponse] = None
data: Optional[Union[Dict[str, Any], str]] = None
@ -147,23 +164,30 @@ class AsyncWebhookAdapter:
to_send = form_data
try:
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
async with session.request(
method, url, data=to_send, headers=headers, params=params
) as response:
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
"Webhook ID %s with %s %s has returned status code %s",
webhook_id,
method,
url,
response.status,
)
data = (await response.text(encoding='utf-8')) or None
if data and response.headers['Content-Type'] == 'application/json':
data = (await response.text(encoding="utf-8")) or None
if (
data
and response.headers["Content-Type"] == "application/json"
):
data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status != 429:
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status != 429:
delta = utils._parse_ratelimit_header(response)
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
)
lock.delay_by(delta)
@ -171,11 +195,15 @@ class AsyncWebhookAdapter:
return data
if response.status == 429:
if not response.headers.get('Via'):
if not response.headers.get("Via"):
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
retry_after: float = data["retry_after"] # type: ignore
_log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds",
webhook_id,
retry_after,
)
await asyncio.sleep(retry_after)
continue
@ -201,7 +229,7 @@ class AsyncWebhookAdapter:
raise DiscordServerError(response, data)
raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.')
raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook(
self,
@ -211,7 +239,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token(
@ -222,7 +250,12 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session, reason=reason)
def edit_webhook(
@ -234,8 +267,10 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(
route, session, reason=reason, payload=payload, auth_token=token
)
def edit_webhook_with_token(
self,
@ -246,7 +281,12 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session, reason=reason, payload=payload)
def execute_webhook(
@ -261,11 +301,23 @@ class AsyncWebhookAdapter:
thread_id: Optional[int] = None,
wait: bool = False,
) -> Response[Optional[MessagePayload]]:
params = {'wait': int(wait)}
params = {"wait": int(wait)}
if thread_id:
params['thread_id'] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
params["thread_id"] = thread_id
route = Route(
"POST",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(
route,
session,
payload=payload,
multipart=multipart,
files=files,
params=params,
)
def get_webhook_message(
self,
@ -276,8 +328,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[MessagePayload]:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -296,13 +348,15 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None,
) -> Response[Message]:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
return self.request(
route, session, payload=payload, multipart=multipart, files=files
)
def delete_webhook_message(
self,
@ -313,8 +367,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[None]:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -328,7 +382,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token(
@ -338,7 +392,12 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"GET",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session=session)
def create_interaction_response(
@ -351,15 +410,15 @@ class AsyncWebhookAdapter:
data: Optional[Dict[str, Any]] = None,
) -> Response[None]:
payload: Dict[str, Any] = {
'type': type,
"type": type,
}
if data is not None:
payload['data'] = data
payload["data"] = data
route = Route(
'POST',
'/interactions/{webhook_id}/{webhook_token}/callback',
"POST",
"/interactions/{webhook_id}/{webhook_token}/callback",
webhook_id=interaction_id,
webhook_token=token,
)
@ -374,8 +433,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[MessagePayload]:
r = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id,
webhook_token=token,
)
@ -392,12 +451,14 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None,
) -> Response[MessagePayload]:
r = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id,
webhook_token=token,
)
return self.request(r, session, payload=payload, multipart=multipart, files=files)
return self.request(
r, session, payload=payload, multipart=multipart, files=files
)
def delete_original_interaction_response(
self,
@ -407,8 +468,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[None]:
r = Route(
'DELETE',
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
"DELETE",
"/webhooks/{webhook_id}/{wehook_token}/messages/@original",
webhook_id=application_id,
wehook_token=token,
)
@ -437,82 +498,86 @@ def handle_message_parameters(
previous_allowed_mentions: Optional[AllowedMentions] = None,
) -> ExecuteWebhookParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
raise TypeError("Cannot mix file and files keyword arguments.")
if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.')
raise TypeError("Cannot mix embed and embeds keyword arguments.")
payload = {}
if embeds is not MISSING:
if len(embeds) > 10:
raise InvalidArgument('embeds has a maximum of 10 elements.')
payload['embeds'] = [e.to_dict() for e in embeds]
raise InvalidArgument("embeds has a maximum of 10 elements.")
payload["embeds"] = [e.to_dict() for e in embeds]
if embed is not MISSING:
if embed is None:
payload['embeds'] = []
payload["embeds"] = []
else:
payload['embeds'] = [embed.to_dict()]
payload["embeds"] = [embed.to_dict()]
if content is not MISSING:
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
else:
payload['content'] = None
payload["content"] = None
if view is not MISSING:
if view is not None:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
else:
payload['components'] = []
payload["components"] = []
payload['tts'] = tts
payload["tts"] = tts
if avatar_url:
payload['avatar_url'] = str(avatar_url)
payload["avatar_url"] = str(avatar_url)
if username:
payload['username'] = username
payload["username"] = username
if ephemeral:
payload['flags'] = 64
payload["flags"] = 64
if allowed_mentions:
if previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
payload["allowed_mentions"] = previous_allowed_mentions.merge(
allowed_mentions
).to_dict()
else:
payload['allowed_mentions'] = allowed_mentions.to_dict()
payload["allowed_mentions"] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
payload["allowed_mentions"] = previous_allowed_mentions.to_dict()
multipart = []
if file is not MISSING:
files = [file]
if files:
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
multipart.append({"name": "payload_json", "value": utils._to_json(payload)})
payload = None
if len(files) == 1:
file = files[0]
multipart.append(
{
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
"name": "file",
"value": file.fp,
"filename": file.filename,
"content_type": "application/octet-stream",
}
)
else:
for index, file in enumerate(files):
multipart.append(
{
'name': f'file{index}',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
"name": f"file{index}",
"value": file.fp,
"filename": file.filename,
"content_type": "application/octet-stream",
}
)
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar(
"async_webhook_context", default=AsyncWebhookAdapter()
)
class PartialWebhookChannel(Hashable):
@ -530,14 +595,14 @@ class PartialWebhookChannel(Hashable):
The partial channel's name.
"""
__slots__ = ('id', 'name')
__slots__ = ("id", "name")
def __init__(self, *, data):
self.id = int(data['id'])
self.name = data['name']
self.id = int(data["id"])
self.name = data["name"]
def __repr__(self):
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
return f"<PartialWebhookChannel name={self.name!r} id={self.id}>"
class PartialWebhookGuild(Hashable):
@ -555,16 +620,16 @@ class PartialWebhookGuild(Hashable):
The partial guild's name.
"""
__slots__ = ('id', 'name', '_icon', '_state')
__slots__ = ("id", "name", "_icon", "_state")
def __init__(self, *, data, state):
self._state = state
self.id = int(data['id'])
self.name = data['name']
self._icon = data['icon']
self.id = int(data["id"])
self.name = data["name"]
self._icon = data["icon"]
def __repr__(self):
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
return f"<PartialWebhookGuild name={self.name!r} id={self.id}>"
@property
def icon(self) -> Optional[Asset]:
@ -578,13 +643,15 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
raise AttributeError('PartialWebhookState does not support http methods.')
raise AttributeError("PartialWebhookState does not support http methods.")
class _WebhookState:
__slots__ = ('_parent', '_webhook')
__slots__ = ("_parent", "_webhook")
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
def __init__(
self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]
):
self._webhook: Any = webhook
self._parent: Optional[ConnectionState]
@ -621,7 +688,7 @@ class _WebhookState:
if self._parent is not None:
return getattr(self._parent, attr)
raise AttributeError(f'PartialWebhookState does not support {attr!r}.')
raise AttributeError(f"PartialWebhookState does not support {attr!r}.")
class WebhookMessage(Message):
@ -750,47 +817,54 @@ class WebhookMessage(Message):
class BaseWebhook(Hashable):
__slots__: Tuple[str, ...] = (
'id',
'type',
'guild_id',
'channel_id',
'token',
'auth_token',
'user',
'name',
'_avatar',
'source_channel',
'source_guild',
'_state',
"id",
"type",
"guild_id",
"channel_id",
"token",
"auth_token",
"user",
"name",
"_avatar",
"source_channel",
"source_guild",
"_state",
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[ConnectionState] = None,
):
self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(
self, parent=state
)
self._update(data)
def _update(self, data: WebhookPayload):
self.id = int(data['id'])
self.type = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name')
self._avatar = data.get('avatar')
self.token = data.get('token')
self.id = int(data["id"])
self.type = try_enum(WebhookType, int(data["type"]))
self.channel_id = utils._get_as_snowflake(data, "channel_id")
self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.name = data.get("name")
self._avatar = data.get("avatar")
self.token = data.get("token")
user = data.get('user')
user = data.get("user")
self.user: Optional[Union[BaseUser, User]] = None
if user is not None:
# state parameter may be _WebhookState
self.user = User(state=self._state, data=user) # type: ignore
source_channel = data.get('source_channel')
source_channel = data.get("source_channel")
if source_channel:
source_channel = PartialWebhookChannel(data=source_channel)
self.source_channel: Optional[PartialWebhookChannel] = source_channel
source_guild = data.get('source_guild')
source_guild = data.get("source_guild")
if source_guild:
source_guild = PartialWebhookGuild(data=source_guild, state=self._state)
@ -927,22 +1001,35 @@ class Webhook(BaseWebhook):
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = ('session',)
__slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state=None,
):
super().__init__(data, token, state)
self.session = session
def __repr__(self):
return f'<Webhook id={self.id!r}>'
return f"<Webhook id={self.id!r}>"
@property
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def partial(
cls,
id: int,
token: str,
*,
session: aiohttp.ClientSession,
bot_token: Optional[str] = None,
) -> Webhook:
"""Creates a partial :class:`Webhook`.
Parameters
@ -970,15 +1057,21 @@ class Webhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token.
"""
data: WebhookPayload = {
'id': id,
'type': 1,
'token': token,
"id": id,
"type": 1,
"token": token,
}
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def from_url(
cls,
url: str,
*,
session: aiohttp.ClientSession,
bot_token: Optional[str] = None,
) -> Webhook:
"""Creates a partial :class:`Webhook` from a webhook URL.
Parameters
@ -1008,24 +1101,32 @@ class Webhook(BaseWebhook):
A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token.
"""
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
m = re.search(
r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})",
url,
)
if m is None:
raise InvalidArgument('Invalid webhook URL given.')
raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict()
data['type'] = 1
data["type"] = 1
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
'type': 2,
'name': name,
'channel_id': channel.id,
'guild_id': channel.guild.id,
'user': {'username': user.name, 'discriminator': user.discriminator, 'id': user.id, 'avatar': user._avatar},
"id": data["webhook_id"],
"type": 2,
"name": name,
"channel_id": channel.id,
"guild_id": channel.guild.id,
"user": {
"username": user.name,
"discriminator": user.discriminator,
"id": user.id,
"avatar": user._avatar,
},
}
state = channel._state
@ -1075,11 +1176,17 @@ class Webhook(BaseWebhook):
adapter = async_context.get()
if prefer_auth and self.auth_token:
data = await adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
data = await adapter.fetch_webhook(
self.id, self.auth_token, session=self.session
)
elif self.token:
data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
data = await adapter.fetch_webhook_with_token(
self.id, self.token, session=self.session
)
else:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
return Webhook(data, self.session, token=self.auth_token, state=self._state)
@ -1112,14 +1219,20 @@ class Webhook(BaseWebhook):
This webhook does not have a token associated with it.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter = async_context.get()
if prefer_auth and self.auth_token:
await adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
await adapter.delete_webhook(
self.id, token=self.auth_token, session=self.session, reason=reason
)
elif self.token:
await adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason)
await adapter.delete_webhook_with_token(
self.id, self.token, session=self.session, reason=reason
)
async def edit(
self,
@ -1165,14 +1278,18 @@ class Webhook(BaseWebhook):
or it tried editing a channel without authentication.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
payload = {}
if name is not MISSING:
payload['name'] = str(name) if name is not None else None
payload["name"] = str(name) if name is not None else None
if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
payload["avatar"] = (
utils._bytes_to_base64_data(avatar) if avatar is not None else None
)
adapter = async_context.get()
@ -1180,27 +1297,45 @@ class Webhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook')
raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
payload["channel_id"] = channel.id
data = await adapter.edit_webhook(
self.id,
self.auth_token,
payload=payload,
session=self.session,
reason=reason,
)
if prefer_auth and self.auth_token:
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
data = await adapter.edit_webhook(
self.id,
self.auth_token,
payload=payload,
session=self.session,
reason=reason,
)
elif self.token:
data = await adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
self.id,
self.token,
payload=payload,
session=self.session,
reason=reason,
)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
raise RuntimeError("Unreachable code hit: data was not assigned")
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
return Webhook(
data=data, session=self.session, token=self.auth_token, state=self._state
)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
@ -1350,22 +1485,30 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
)
if content is None:
content = MISSING
application_webhook = self.type is WebhookType.application
if ephemeral and not application_webhook:
raise InvalidArgument('ephemeral messages can only be sent from application webhooks')
raise InvalidArgument(
"ephemeral messages can only be sent from application webhooks"
)
if application_webhook:
wait = True
if view is not MISSING:
if isinstance(self._state, _WebhookState):
raise InvalidArgument('Webhook views require an associated state with the webhook')
raise InvalidArgument(
"Webhook views require an associated state with the webhook"
)
if ephemeral is True and view.timeout is None:
view.timeout = 15 * 60.0
@ -1439,7 +1582,9 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter = async_context.get()
data = await adapter.get_webhook_message(
@ -1525,15 +1670,21 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
if view is not MISSING:
if isinstance(self._state, _WebhookState):
raise InvalidArgument('This webhook does not have state associated with it')
raise InvalidArgument(
"This webhook does not have state associated with it"
)
self._state.prevent_view_updates_for(message_id)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
)
params = handle_message_parameters(
content=content,
file=file,
@ -1583,7 +1734,9 @@ class Webhook(BaseWebhook):
Deleted a message that is not yours.
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter = async_context.get()
await adapter.delete_webhook_message(

View File

@ -37,10 +37,28 @@ import time
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from .. import utils
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
from ..errors import (
InvalidArgument,
HTTPException,
Forbidden,
NotFound,
DiscordServerError,
)
from ..message import Message
from ..http import Route
from ..channel import PartialMessageable
@ -48,8 +66,8 @@ from ..channel import PartialMessageable
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
__all__ = (
'SyncWebhook',
'SyncWebhookMessage',
"SyncWebhook",
"SyncWebhookMessage",
)
_log = logging.getLogger(__name__)
@ -116,14 +134,14 @@ class WebhookAdapter:
self._locks[bucket] = lock = threading.Lock()
if payload is not None:
headers['Content-Type'] = 'application/json'
headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
headers["Authorization"] = f"Bot {auth_token}"
if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[Response] = None
data: Optional[Union[Dict[str, Any], str]] = None
@ -140,36 +158,50 @@ class WebhookAdapter:
if multipart:
file_data = {}
for p in multipart:
name = p['name']
if name == 'payload_json':
to_send = {'payload_json': p['value']}
name = p["name"]
if name == "payload_json":
to_send = {"payload_json": p["value"]}
else:
file_data[name] = (p['filename'], p['value'], p['content_type'])
file_data[name] = (
p["filename"],
p["value"],
p["content_type"],
)
try:
with session.request(
method, url, data=to_send, files=file_data, headers=headers, params=params
method,
url,
data=to_send,
files=file_data,
headers=headers,
params=params,
) as response:
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
"Webhook ID %s with %s %s has returned status code %s",
webhook_id,
method,
url,
response.status_code,
)
response.encoding = 'utf-8'
response.encoding = "utf-8"
# Compatibility with aiohttp
response.status = response.status_code # type: ignore
data = response.text or None
if data and response.headers['Content-Type'] == 'application/json':
if (
data
and response.headers["Content-Type"] == "application/json"
):
data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status_code != 429:
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status_code != 429:
delta = utils._parse_ratelimit_header(response)
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
)
lock.delay_by(delta)
@ -177,11 +209,15 @@ class WebhookAdapter:
return data
if response.status_code == 429:
if not response.headers.get('Via'):
if not response.headers.get("Via"):
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
retry_after: float = data["retry_after"] # type: ignore
_log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds",
webhook_id,
retry_after,
)
time.sleep(retry_after)
continue
@ -207,7 +243,7 @@ class WebhookAdapter:
raise DiscordServerError(response, data)
raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.')
raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook(
self,
@ -217,7 +253,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token(
@ -228,7 +264,12 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session, reason=reason)
def edit_webhook(
@ -240,8 +281,10 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(
route, session, reason=reason, payload=payload, auth_token=token
)
def edit_webhook_with_token(
self,
@ -252,7 +295,12 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session, reason=reason, payload=payload)
def execute_webhook(
@ -267,11 +315,23 @@ class WebhookAdapter:
thread_id: Optional[int] = None,
wait: bool = False,
):
params = {'wait': int(wait)}
params = {"wait": int(wait)}
if thread_id:
params['thread_id'] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
params["thread_id"] = thread_id
route = Route(
"POST",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(
route,
session,
payload=payload,
multipart=multipart,
files=files,
params=params,
)
def get_webhook_message(
self,
@ -282,8 +342,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -302,13 +362,15 @@ class WebhookAdapter:
files: Optional[List[File]] = None,
):
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
return self.request(
route, session, payload=payload, multipart=multipart, files=files
)
def delete_webhook_message(
self,
@ -319,8 +381,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -334,7 +396,7 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token(
@ -344,7 +406,12 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route(
"GET",
"/webhooks/{webhook_id}/{webhook_token}",
webhook_id=webhook_id,
webhook_token=token,
)
return self.request(route, session=session)
@ -516,22 +583,35 @@ class SyncWebhook(BaseWebhook):
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = ('session',)
__slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state=None,
):
super().__init__(data, token, state)
self.session = session
def __repr__(self):
return f'<Webhook id={self.id!r}>'
return f"<Webhook id={self.id!r}>"
@property
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
def partial(
cls,
id: int,
token: str,
*,
session: Session = MISSING,
bot_token: Optional[str] = None,
) -> SyncWebhook:
"""Creates a partial :class:`Webhook`.
Parameters
@ -556,21 +636,23 @@ class SyncWebhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token.
"""
data: WebhookPayload = {
'id': id,
'type': 1,
'token': token,
"id": id,
"type": 1,
"token": token,
}
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}')
raise TypeError(f"expected requests.Session not {session.__class__!r}")
else:
session = requests # type: ignore
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
def from_url(
cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None
) -> SyncWebhook:
"""Creates a partial :class:`Webhook` from a webhook URL.
Parameters
@ -597,17 +679,20 @@ class SyncWebhook(BaseWebhook):
A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token.
"""
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
m = re.search(
r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})",
url,
)
if m is None:
raise InvalidArgument('Invalid webhook URL given.')
raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict()
data['type'] = 1
data["type"] = 1
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}')
raise TypeError(f"expected requests.Session not {session.__class__!r}")
else:
session = requests # type: ignore
return cls(data, session, token=bot_token) # type: ignore
@ -648,9 +733,13 @@ class SyncWebhook(BaseWebhook):
if prefer_auth and self.auth_token:
data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
elif self.token:
data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
data = adapter.fetch_webhook_with_token(
self.id, self.token, session=self.session
)
else:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
@ -679,14 +768,20 @@ class SyncWebhook(BaseWebhook):
This webhook does not have a token associated with it.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter: WebhookAdapter = _get_webhook_adapter()
if prefer_auth and self.auth_token:
adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
adapter.delete_webhook(
self.id, token=self.auth_token, session=self.session, reason=reason
)
elif self.token:
adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason)
adapter.delete_webhook_with_token(
self.id, self.token, session=self.session, reason=reason
)
def edit(
self,
@ -731,14 +826,18 @@ class SyncWebhook(BaseWebhook):
The newly edited webhook.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
payload = {}
if name is not MISSING:
payload['name'] = str(name) if name is not None else None
payload["name"] = str(name) if name is not None else None
if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
payload["avatar"] = (
utils._bytes_to_base64_data(avatar) if avatar is not None else None
)
adapter: WebhookAdapter = _get_webhook_adapter()
@ -746,25 +845,45 @@ class SyncWebhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook')
raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
payload["channel_id"] = channel.id
data = adapter.edit_webhook(
self.id,
self.auth_token,
payload=payload,
session=self.session,
reason=reason,
)
if prefer_auth and self.auth_token:
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
data = adapter.edit_webhook(
self.id,
self.auth_token,
payload=payload,
session=self.session,
reason=reason,
)
elif self.token:
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
data = adapter.edit_webhook_with_token(
self.id,
self.token,
payload=payload,
session=self.session,
reason=reason,
)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
raise RuntimeError("Unreachable code hit: data was not assigned")
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
return SyncWebhook(
data=data, session=self.session, token=self.auth_token, state=self._state
)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
@ -887,9 +1006,13 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
)
if content is None:
content = MISSING
@ -951,7 +1074,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message(
@ -1015,9 +1140,13 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(
self._state, "allowed_mentions", None
)
params = handle_message_parameters(
content=content,
file=file,
@ -1060,7 +1189,9 @@ class SyncWebhook(BaseWebhook):
Deleted a message that is not yours.
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument(
"This webhook does not have a token associated with it"
)
adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message(

View File

@ -41,11 +41,12 @@ if TYPE_CHECKING:
)
__all__ = (
'WidgetChannel',
'WidgetMember',
'Widget',
"WidgetChannel",
"WidgetMember",
"Widget",
)
class WidgetChannel:
"""Represents a "partial" widget channel.
@ -76,7 +77,8 @@ class WidgetChannel:
position: :class:`int`
The channel's position
"""
__slots__ = ('id', 'name', 'position')
__slots__ = ("id", "name", "position")
def __init__(self, id: int, name: str, position: int) -> None:
self.id: int = id
@ -87,18 +89,19 @@ class WidgetChannel:
return self.name
def __repr__(self) -> str:
return f'<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>'
return f"<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>"
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id)
class WidgetMember(BaseUser):
"""Represents a "partial" member of the widget's guild.
@ -147,9 +150,21 @@ class WidgetMember(BaseUser):
connected_channel: Optional[:class:`WidgetChannel`]
Which channel the member is connected to.
"""
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted',
'connected_channel')
__slots__ = (
"name",
"status",
"nick",
"avatar",
"discriminator",
"id",
"bot",
"activity",
"deafened",
"suppress",
"muted",
"connected_channel",
)
if TYPE_CHECKING:
activity: Optional[Union[BaseActivity, Spotify]]
@ -159,17 +174,21 @@ class WidgetMember(BaseUser):
*,
state: ConnectionState,
data: WidgetMemberPayload,
connected_channel: Optional[WidgetChannel] = None
connected_channel: Optional[WidgetChannel] = None,
) -> None:
super().__init__(state=state, data=data)
self.nick: Optional[str] = data.get('nick')
self.status: Status = try_enum(Status, data.get('status'))
self.deafened: Optional[bool] = data.get('deaf', False) or data.get('self_deaf', False)
self.muted: Optional[bool] = data.get('mute', False) or data.get('self_mute', False)
self.suppress: Optional[bool] = data.get('suppress', False)
self.nick: Optional[str] = data.get("nick")
self.status: Status = try_enum(Status, data.get("status"))
self.deafened: Optional[bool] = data.get("deaf", False) or data.get(
"self_deaf", False
)
self.muted: Optional[bool] = data.get("mute", False) or data.get(
"self_mute", False
)
self.suppress: Optional[bool] = data.get("suppress", False)
try:
game = data['game']
game = data["game"]
except KeyError:
activity = None
else:
@ -190,6 +209,7 @@ class WidgetMember(BaseUser):
""":class:`str`: Returns the member's display name."""
return self.nick or self.name
class Widget:
"""Represents a :class:`Guild` widget.
@ -227,27 +247,34 @@ class Widget:
retrieved is capped.
"""
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
__slots__ = ("_state", "channels", "_invite", "id", "members", "name")
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state
self._invite = data['instant_invite']
self.name: str = data['name']
self.id: int = int(data['id'])
self._invite = data["instant_invite"]
self.name: str = data["name"]
self.id: int = int(data["id"])
self.channels: List[WidgetChannel] = []
for channel in data.get('channels', []):
_id = int(channel['id'])
self.channels.append(WidgetChannel(id=_id, name=channel['name'], position=channel['position']))
for channel in data.get("channels", []):
_id = int(channel["id"])
self.channels.append(
WidgetChannel(
id=_id, name=channel["name"], position=channel["position"]
)
)
self.members: List[WidgetMember] = []
channels = {channel.id: channel for channel in self.channels}
for member in data.get('members', []):
connected_channel = _get_as_snowflake(member, 'channel_id')
for member in data.get("members", []):
connected_channel = _get_as_snowflake(member, "channel_id")
if connected_channel in channels:
connected_channel = channels[connected_channel] # type: ignore
elif connected_channel:
connected_channel = WidgetChannel(id=connected_channel, name='', position=0)
connected_channel = WidgetChannel(
id=connected_channel, name="", position=0
)
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore
@ -260,7 +287,9 @@ class Widget:
return False
def __repr__(self) -> str:
return f'<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>'
return (
f"<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>"
)
@property
def created_at(self) -> datetime.datetime:

View File

@ -18,45 +18,45 @@ import re
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))
sys.path.append(os.path.abspath('extensions'))
sys.path.insert(0, os.path.abspath(".."))
sys.path.append(os.path.abspath("extensions"))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'builder',
'sphinx.ext.autodoc',
'sphinx.ext.extlinks',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinxcontrib_trio',
'details',
'exception_hierarchy',
'attributetable',
'resourcelinks',
'nitpick_file_ignorer',
"builder",
"sphinx.ext.autodoc",
"sphinx.ext.extlinks",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinxcontrib_trio",
"details",
"exception_hierarchy",
"attributetable",
"resourcelinks",
"nitpick_file_ignorer",
]
autodoc_member_order = 'bysource'
autodoc_typehints = 'none'
autodoc_member_order = "bysource"
autodoc_typehints = "none"
# maybe consider this?
# napoleon_attr_annotations = False
extlinks = {
'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'GH-'),
"issue": ("https://github.com/Rapptz/discord.py/issues/%s", "GH-"),
}
# Links used for cross-referencing stuff in other documentation
intersphinx_mapping = {
'py': ('https://docs.python.org/3', None),
'aio': ('https://docs.aiohttp.org/en/stable/', None),
'req': ('https://docs.python-requests.org/en/latest/', None)
"py": ("https://docs.python.org/3", None),
"aio": ("https://docs.aiohttp.org/en/stable/", None),
"req": ("https://docs.python-requests.org/en/latest/", None),
}
rst_prolog = """
@ -67,20 +67,20 @@ rst_prolog = """
"""
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix of source filenames.
source_suffix = '.rst'
source_suffix = ".rst"
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'discord.py'
copyright = '2015-present, Rapptz'
project = "discord.py"
copyright = "2015-present, Rapptz"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@ -88,15 +88,17 @@ copyright = '2015-present, Rapptz'
#
# The short X.Y version.
version = ''
with open('../discord/__init__.py') as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
version = ""
with open("../discord/__init__.py") as f:
version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE
).group(1)
# The full version, including alpha/beta/rc tags.
release = version
# This assumes a tag is available for final releases
branch = 'master' if version.endswith('a') else 'v' + version
branch = "master" if version.endswith("a") else "v" + version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@ -105,49 +107,49 @@ branch = 'master' if version.endswith('a') else 'v' + version
# Usually you set "language" from the command line for these cases.
language = None
locale_dirs = ['locale/']
locale_dirs = ["locale/"]
gettext_compact = False
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all
# documents.
#default_role = None
# default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'friendly'
pygments_style = "friendly"
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False
# keep_warnings = False
# Nitpicky mode options
nitpick_ignore_files = [
"migrating_to_async",
"migrating",
"whats_new",
"migrating_to_async",
"migrating",
"whats_new",
]
# -- Options for HTML output ----------------------------------------------
@ -156,21 +158,21 @@ html_experimental_html5_writer = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'basic'
html_theme = "basic"
html_context = {
'discord_invite': 'https://discord.gg/r3sSKJJ',
'discord_extensions': [
('discord.ext.commands', 'ext/commands'),
('discord.ext.tasks', 'ext/tasks'),
],
"discord_invite": "https://discord.gg/r3sSKJJ",
"discord_extensions": [
("discord.ext.commands", "ext/commands"),
("discord.ext.tasks", "ext/tasks"),
],
}
resource_links = {
'discord': 'https://discord.gg/r3sSKJJ',
'issues': 'https://github.com/Rapptz/discord.py/issues',
'discussions': 'https://github.com/Rapptz/discord.py/discussions',
'examples': f'https://github.com/Rapptz/discord.py/tree/{branch}/examples',
"discord": "https://discord.gg/r3sSKJJ",
"issues": "https://github.com/Rapptz/discord.py/issues",
"discussions": "https://github.com/Rapptz/discord.py/discussions",
"examples": f"https://github.com/Rapptz/discord.py/tree/{branch}/examples",
}
# Theme options are theme-specific and customize the look and feel of a theme
@ -180,155 +182,143 @@ resource_links = {
# }
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
#html_title = None
# html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None
# html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
#html_logo = None
# html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
html_favicon = './images/discord_py_logo.ico'
html_favicon = "./images/discord_py_logo.ico"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
#html_extra_path = []
# html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'
# html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True
# html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}
# html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}
# html_additional_pages = {}
# If false, no module index is generated.
#html_domain_indices = True
# html_domain_indices = True
# If false, no index is generated.
#html_use_index = True
# html_use_index = True
# If true, the index is split into individual pages for each letter.
#html_split_index = False
# html_split_index = False
# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True
# html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True
# html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True
# html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''
# html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None
# html_file_suffix = None
# Language to be used for generating the HTML full-text search index.
# Sphinx supports the following languages:
# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
#html_search_language = 'en'
# html_search_language = 'en'
# A dictionary with options for the search language support, empty by default.
# Now only 'ja' uses this config value
#html_search_options = {'type': 'default'}
# html_search_options = {'type': 'default'}
# The name of a javascript file (relative to the configuration directory) that
# implements a search results scorer. If empty, the default will be used.
html_search_scorer = '_static/scorer.js'
html_search_scorer = "_static/scorer.js"
html_js_files = [
'custom.js',
'settings.js',
'copy.js',
'sidebar.js'
]
html_js_files = ["custom.js", "settings.js", "copy.js", "sidebar.js"]
# Output file base name for HTML help builder.
htmlhelp_basename = 'discord.pydoc'
htmlhelp_basename = "discord.pydoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
('index', 'discord.py.tex', 'discord.py Documentation',
'Rapptz', 'manual'),
("index", "discord.py.tex", "discord.py Documentation", "Rapptz", "manual"),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None
# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False
# latex_use_parts = False
# If true, show page references after internal links.
#latex_show_pagerefs = False
# latex_show_pagerefs = False
# If true, show URL addresses after external links.
#latex_show_urls = False
# latex_show_urls = False
# Documents to append as an appendix to all manuals.
#latex_appendices = []
# latex_appendices = []
# If false, no module index is generated.
#latex_domain_indices = True
# latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'discord.py', 'discord.py Documentation',
['Rapptz'], 1)
]
man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
# If true, show URL addresses after external links.
#man_show_urls = False
# man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
@ -337,25 +327,32 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'discord.py', 'discord.py Documentation',
'Rapptz', 'discord.py', 'One line description of project.',
'Miscellaneous'),
(
"index",
"discord.py",
"discord.py Documentation",
"Rapptz",
"discord.py",
"One line description of project.",
"Miscellaneous",
),
]
# Documents to append as an appendix to all manuals.
#texinfo_appendices = []
# texinfo_appendices = []
# If false, no module index is generated.
#texinfo_domain_indices = True
# texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'
# texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
# texinfo_no_detailmenu = False
def setup(app):
if app.config.language == 'ja':
app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None)
app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg'
app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg'
if app.config.language == "ja":
app.config.intersphinx_mapping["py"] = ("https://docs.python.org/ja/3", None)
app.config.html_context["discord_invite"] = "https://discord.gg/nXzj3dg"
app.config.resource_links["discord"] = "https://discord.gg/nXzj3dg"

View File

@ -9,60 +9,78 @@ import inspect
import os
import re
class attributetable(nodes.General, nodes.Element):
pass
class attributetablecolumn(nodes.General, nodes.Element):
pass
class attributetabletitle(nodes.TextElement):
pass
class attributetableplaceholder(nodes.General, nodes.Element):
pass
class attributetablebadge(nodes.TextElement):
pass
class attributetable_item(nodes.Part, nodes.Element):
pass
def visit_attributetable_node(self, node):
class_ = node["python-class"]
self.body.append(f'<div class="py-attribute-table" data-move-to-id="{class_}">')
def visit_attributetablecolumn_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='py-attribute-table-column'))
self.body.append(self.starttag(node, "div", CLASS="py-attribute-table-column"))
def visit_attributetabletitle_node(self, node):
self.body.append(self.starttag(node, 'span'))
self.body.append(self.starttag(node, "span"))
def visit_attributetablebadge_node(self, node):
attributes = {
'class': 'py-attribute-table-badge',
'title': node['badge-type'],
"class": "py-attribute-table-badge",
"title": node["badge-type"],
}
self.body.append(self.starttag(node, 'span', **attributes))
self.body.append(self.starttag(node, "span", **attributes))
def visit_attributetable_item_node(self, node):
self.body.append(self.starttag(node, 'li', CLASS='py-attribute-table-entry'))
self.body.append(self.starttag(node, "li", CLASS="py-attribute-table-entry"))
def depart_attributetable_node(self, node):
self.body.append('</div>')
self.body.append("</div>")
def depart_attributetablecolumn_node(self, node):
self.body.append('</div>')
self.body.append("</div>")
def depart_attributetabletitle_node(self, node):
self.body.append('</span>')
self.body.append("</span>")
def depart_attributetablebadge_node(self, node):
self.body.append('</span>')
self.body.append("</span>")
def depart_attributetable_item_node(self, node):
self.body.append('</li>')
self.body.append("</li>")
_name_parser_regex = re.compile(r"(?P<module>[\w.]+\.)?(?P<name>\w+)")
_name_parser_regex = re.compile(r'(?P<module>[\w.]+\.)?(?P<name>\w+)')
class PyAttributeTable(SphinxDirective):
has_content = False
@ -74,13 +92,15 @@ class PyAttributeTable(SphinxDirective):
def parse_name(self, content):
path, name = _name_parser_regex.match(content).groups()
if path:
modulename = path.rstrip('.')
modulename = path.rstrip(".")
else:
modulename = self.env.temp_data.get('autodoc:module')
modulename = self.env.temp_data.get("autodoc:module")
if not modulename:
modulename = self.env.ref_context.get('py:module')
modulename = self.env.ref_context.get("py:module")
if modulename is None:
raise RuntimeError('modulename somehow None for %s in %s.' % (content, self.env.docname))
raise RuntimeError(
"modulename somehow None for %s in %s." % (content, self.env.docname)
)
return modulename, name
@ -112,29 +132,33 @@ class PyAttributeTable(SphinxDirective):
replaced.
"""
content = self.arguments[0].strip()
node = attributetableplaceholder('')
node = attributetableplaceholder("")
modulename, name = self.parse_name(content)
node['python-doc'] = self.env.docname
node['python-module'] = modulename
node['python-class'] = name
node['python-full-name'] = f'{modulename}.{name}'
node["python-doc"] = self.env.docname
node["python-module"] = modulename
node["python-class"] = name
node["python-full-name"] = f"{modulename}.{name}"
return [node]
def build_lookup_table(env):
# Given an environment, load up a lookup table of
# full-class-name: objects
result = {}
domain = env.domains['py']
domain = env.domains["py"]
ignored = {
'data', 'exception', 'module', 'class',
"data",
"exception",
"module",
"class",
}
for (fullname, _, objtype, docname, _, _) in domain.get_objects():
if objtype in ignored:
continue
classname, _, child = fullname.rpartition('.')
classname, _, child = fullname.rpartition(".")
try:
result[classname].append(child)
except KeyError:
@ -143,36 +167,46 @@ def build_lookup_table(env):
return result
TableElement = namedtuple('TableElement', 'fullname label badge')
TableElement = namedtuple("TableElement", "fullname label badge")
def process_attributetable(app, doctree, fromdocname):
env = app.builder.env
lookup = build_lookup_table(env)
for node in doctree.traverse(attributetableplaceholder):
modulename, classname, fullname = node['python-module'], node['python-class'], node['python-full-name']
modulename, classname, fullname = (
node["python-module"],
node["python-class"],
node["python-full-name"],
)
groups = get_class_results(lookup, modulename, classname, fullname)
table = attributetable('')
table = attributetable("")
for label, subitems in groups.items():
if not subitems:
continue
table.append(class_results_to_node(label, sorted(subitems, key=lambda c: c.label)))
table.append(
class_results_to_node(label, sorted(subitems, key=lambda c: c.label))
)
table['python-class'] = fullname
table["python-class"] = fullname
if not table:
node.replace_self([])
else:
node.replace_self([table])
def get_class_results(lookup, modulename, name, fullname):
module = importlib.import_module(modulename)
cls = getattr(module, name)
groups = OrderedDict([
(_('Attributes'), []),
(_('Methods'), []),
])
groups = OrderedDict(
[
(_("Attributes"), []),
(_("Methods"), []),
]
)
try:
members = lookup[fullname]
@ -180,8 +214,8 @@ def get_class_results(lookup, modulename, name, fullname):
return groups
for attr in members:
attrlookup = f'{fullname}.{attr}'
key = _('Attributes')
attrlookup = f"{fullname}.{attr}"
key = _("Attributes")
badge = None
label = attr
value = None
@ -192,53 +226,73 @@ def get_class_results(lookup, modulename, name, fullname):
break
if value is not None:
doc = value.__doc__ or ''
if inspect.iscoroutinefunction(value) or doc.startswith('|coro|'):
key = _('Methods')
badge = attributetablebadge('async', 'async')
badge['badge-type'] = _('coroutine')
doc = value.__doc__ or ""
if inspect.iscoroutinefunction(value) or doc.startswith("|coro|"):
key = _("Methods")
badge = attributetablebadge("async", "async")
badge["badge-type"] = _("coroutine")
elif isinstance(value, classmethod):
key = _('Methods')
label = f'{name}.{attr}'
badge = attributetablebadge('cls', 'cls')
badge['badge-type'] = _('classmethod')
key = _("Methods")
label = f"{name}.{attr}"
badge = attributetablebadge("cls", "cls")
badge["badge-type"] = _("classmethod")
elif inspect.isfunction(value):
if doc.startswith(('A decorator', 'A shortcut decorator')):
if doc.startswith(("A decorator", "A shortcut decorator")):
# finicky but surprisingly consistent
badge = attributetablebadge('@', '@')
badge['badge-type'] = _('decorator')
key = _('Methods')
badge = attributetablebadge("@", "@")
badge["badge-type"] = _("decorator")
key = _("Methods")
else:
key = _('Methods')
badge = attributetablebadge('def', 'def')
badge['badge-type'] = _('method')
key = _("Methods")
badge = attributetablebadge("def", "def")
badge["badge-type"] = _("method")
groups[key].append(TableElement(fullname=attrlookup, label=label, badge=badge))
return groups
def class_results_to_node(key, elements):
title = attributetabletitle(key, key)
ul = nodes.bullet_list('')
ul = nodes.bullet_list("")
for element in elements:
ref = nodes.reference('', '', internal=True,
refuri='#' + element.fullname,
anchorname='',
*[nodes.Text(element.label)])
para = addnodes.compact_paragraph('', '', ref)
ref = nodes.reference(
"",
"",
internal=True,
refuri="#" + element.fullname,
anchorname="",
*[nodes.Text(element.label)],
)
para = addnodes.compact_paragraph("", "", ref)
if element.badge is not None:
ul.append(attributetable_item('', element.badge, para))
ul.append(attributetable_item("", element.badge, para))
else:
ul.append(attributetable_item('', para))
ul.append(attributetable_item("", para))
return attributetablecolumn("", title, ul)
return attributetablecolumn('', title, ul)
def setup(app):
app.add_directive('attributetable', PyAttributeTable)
app.add_node(attributetable, html=(visit_attributetable_node, depart_attributetable_node))
app.add_node(attributetablecolumn, html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node))
app.add_node(attributetabletitle, html=(visit_attributetabletitle_node, depart_attributetabletitle_node))
app.add_node(attributetablebadge, html=(visit_attributetablebadge_node, depart_attributetablebadge_node))
app.add_node(attributetable_item, html=(visit_attributetable_item_node, depart_attributetable_item_node))
app.add_directive("attributetable", PyAttributeTable)
app.add_node(
attributetable, html=(visit_attributetable_node, depart_attributetable_node)
)
app.add_node(
attributetablecolumn,
html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node),
)
app.add_node(
attributetabletitle,
html=(visit_attributetabletitle_node, depart_attributetabletitle_node),
)
app.add_node(
attributetablebadge,
html=(visit_attributetablebadge_node, depart_attributetablebadge_node),
)
app.add_node(
attributetable_item,
html=(visit_attributetable_item_node, depart_attributetable_item_node),
)
app.add_node(attributetableplaceholder)
app.connect('doctree-resolved', process_attributetable)
app.connect("doctree-resolved", process_attributetable)

View File

@ -2,15 +2,15 @@ from sphinx.builders.html import StandaloneHTMLBuilder
from sphinx.environment.adapters.indexentries import IndexEntries
from sphinx.writers.html5 import HTML5Translator
class DPYHTML5Translator(HTML5Translator):
def visit_section(self, node):
self.section_level += 1
self.body.append(
self.starttag(node, 'section'))
self.body.append(self.starttag(node, "section"))
def depart_section(self, node):
self.section_level -= 1
self.body.append('</section>\n')
self.body.append("</section>\n")
def visit_table(self, node):
self.body.append('<div class="table-wrapper">')
@ -18,7 +18,8 @@ class DPYHTML5Translator(HTML5Translator):
def depart_table(self, node):
super().depart_table(node)
self.body.append('</div>')
self.body.append("</div>")
class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
# This is mostly copy pasted from Sphinx.
@ -28,50 +29,56 @@ class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
genindex = IndexEntries(self.env).create_index(self, group_entries=False)
indexcounts = []
for _k, entries in genindex:
indexcounts.append(sum(1 + len(subitems)
for _, (_, subitems, _) in entries))
indexcounts.append(
sum(1 + len(subitems) for _, (_, subitems, _) in entries)
)
genindexcontext = {
'genindexentries': genindex,
'genindexcounts': indexcounts,
'split_index': self.config.html_split_index,
"genindexentries": genindex,
"genindexcounts": indexcounts,
"split_index": self.config.html_split_index,
}
if self.config.html_split_index:
self.handle_page('genindex', genindexcontext,
'genindex-split.html')
self.handle_page('genindex-all', genindexcontext,
'genindex.html')
self.handle_page("genindex", genindexcontext, "genindex-split.html")
self.handle_page("genindex-all", genindexcontext, "genindex.html")
for (key, entries), count in zip(genindex, indexcounts):
ctx = {'key': key, 'entries': entries, 'count': count,
'genindexentries': genindex}
self.handle_page('genindex-' + key, ctx,
'genindex-single.html')
ctx = {
"key": key,
"entries": entries,
"count": count,
"genindexentries": genindex,
}
self.handle_page("genindex-" + key, ctx, "genindex-single.html")
else:
self.handle_page('genindex', genindexcontext, 'genindex.html')
self.handle_page("genindex", genindexcontext, "genindex.html")
def add_custom_jinja2(app):
env = app.builder.templates.environment
env.tests['prefixedwith'] = str.startswith
env.tests['suffixedwith'] = str.endswith
env.tests["prefixedwith"] = str.startswith
env.tests["suffixedwith"] = str.endswith
def add_builders(app):
"""This is necessary because RTD injects their own for some reason."""
app.set_translator('html', DPYHTML5Translator, override=True)
app.set_translator("html", DPYHTML5Translator, override=True)
app.add_builder(DPYStandaloneHTMLBuilder, override=True)
try:
original = app.registry.builders['readthedocs']
original = app.registry.builders["readthedocs"]
except KeyError:
pass
else:
injected_mro = tuple(base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder
for base in original.mro()[1:])
new_builder = type(original.__name__, injected_mro, {'name': 'readthedocs'})
app.set_translator('readthedocs', DPYHTML5Translator, override=True)
injected_mro = tuple(
base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder
for base in original.mro()[1:]
)
new_builder = type(original.__name__, injected_mro, {"name": "readthedocs"})
app.set_translator("readthedocs", DPYHTML5Translator, override=True)
app.add_builder(new_builder, override=True)
def setup(app):
add_builders(app)
app.connect('builder-inited', add_custom_jinja2)
app.connect("builder-inited", add_custom_jinja2)

View File

@ -3,32 +3,43 @@ from docutils.parsers.rst import states, directives
from docutils.parsers.rst.roles import set_classes
from docutils import nodes
class details(nodes.General, nodes.Element):
pass
class summary(nodes.General, nodes.Element):
pass
def visit_details_node(self, node):
self.body.append(self.starttag(node, 'details', CLASS=node.attributes.get('class', '')))
self.body.append(
self.starttag(node, "details", CLASS=node.attributes.get("class", ""))
)
def visit_summary_node(self, node):
self.body.append(self.starttag(node, 'summary', CLASS=node.attributes.get('summary-class', '')))
self.body.append(
self.starttag(node, "summary", CLASS=node.attributes.get("summary-class", ""))
)
self.body.append(node.rawsource)
def depart_details_node(self, node):
self.body.append('</details>\n')
self.body.append("</details>\n")
def depart_summary_node(self, node):
self.body.append('</summary>')
self.body.append("</summary>")
class DetailsDirective(Directive):
final_argument_whitespace = True
optional_arguments = 1
option_spec = {
'class': directives.class_option,
'summary-class': directives.class_option,
"class": directives.class_option,
"summary-class": directives.class_option,
}
has_content = True
@ -37,19 +48,22 @@ class DetailsDirective(Directive):
set_classes(self.options)
self.assert_has_content()
text = '\n'.join(self.content)
text = "\n".join(self.content)
node = details(text, **self.options)
if self.arguments:
summary_node = summary(self.arguments[0], **self.options)
summary_node.source, summary_node.line = self.state_machine.get_source_and_line(self.lineno)
(
summary_node.source,
summary_node.line,
) = self.state_machine.get_source_and_line(self.lineno)
node += summary_node
self.state.nested_parse(self.content, self.content_offset, node)
return [node]
def setup(app):
app.add_node(details, html=(visit_details_node, depart_details_node))
app.add_node(summary, html=(visit_summary_node, depart_summary_node))
app.add_directive('details', DetailsDirective)
app.add_directive("details", DetailsDirective)

View File

@ -4,24 +4,32 @@ from docutils.parsers.rst.roles import set_classes
from docutils import nodes
from sphinx.locale import _
class exception_hierarchy(nodes.General, nodes.Element):
pass
def visit_exception_hierarchy_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='exception-hierarchy-content'))
self.body.append(self.starttag(node, "div", CLASS="exception-hierarchy-content"))
def depart_exception_hierarchy_node(self, node):
self.body.append('</div>\n')
self.body.append("</div>\n")
class ExceptionHierarchyDirective(Directive):
has_content = True
def run(self):
self.assert_has_content()
node = exception_hierarchy('\n'.join(self.content))
node = exception_hierarchy("\n".join(self.content))
self.state.nested_parse(self.content, self.content_offset, node)
return [node]
def setup(app):
app.add_node(exception_hierarchy, html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node))
app.add_directive('exception_hierarchy', ExceptionHierarchyDirective)
app.add_node(
exception_hierarchy,
html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node),
)
app.add_directive("exception_hierarchy", ExceptionHierarchyDirective)

View File

@ -5,18 +5,20 @@ from sphinx.util import logging as sphinx_logging
class NitpickFileIgnorer(logging.Filter):
def __init__(self, app: Sphinx) -> None:
self.app = app
super().__init__()
def filter(self, record: sphinx_logging.SphinxLogRecord) -> bool:
if getattr(record, 'type', None) == 'ref':
return record.location.get('refdoc') not in self.app.config.nitpick_ignore_files
if getattr(record, "type", None) == "ref":
return (
record.location.get("refdoc")
not in self.app.config.nitpick_ignore_files
)
return True
def setup(app: Sphinx):
app.add_config_value('nitpick_ignore_files', [], '')
app.add_config_value("nitpick_ignore_files", [], "")
f = NitpickFileIgnorer(app)
sphinx_logging.getLogger('sphinx.transforms.post_transforms').logger.addFilter(f)
sphinx_logging.getLogger("sphinx.transforms.post_transforms").logger.addFilter(f)

View File

@ -22,7 +22,7 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
lineno: int,
inliner: Inliner,
options: Dict = {},
content: List[str] = []
content: List[str] = [],
) -> Tuple[List[Node], List[system_message]]:
text = utils.unescape(text)
@ -32,13 +32,15 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
title = full_url
pnode = nodes.reference(title, title, internal=False, refuri=full_url)
return [pnode], []
return role
def add_link_role(app: Sphinx) -> None:
app.add_role('resource', make_link_role(app.config.resource_links))
app.add_role("resource", make_link_role(app.config.resource_links))
def setup(app: Sphinx) -> Dict[str, Any]:
app.add_config_value('resource_links', {}, 'env')
app.connect('builder-inited', add_link_role)
return {'version': sphinx.__display_version__, 'parallel_read_safe': True}
app.add_config_value("resource_links", {}, "env")
app.connect("builder-inited", add_link_role)
return {"version": sphinx.__display_version__, "parallel_read_safe": True}

View File

@ -2,6 +2,7 @@ from discord.ext import tasks
import discord
class MyClient(discord.Client):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -13,18 +14,19 @@ class MyClient(discord.Client):
self.my_background_task.start()
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
@tasks.loop(seconds=60) # task runs every 60 seconds
@tasks.loop(seconds=60) # task runs every 60 seconds
async def my_background_task(self):
channel = self.get_channel(1234567) # channel ID goes here
channel = self.get_channel(1234567) # channel ID goes here
self.counter += 1
await channel.send(self.counter)
@my_background_task.before_loop
async def before_my_task(self):
await self.wait_until_ready() # wait until the bot logs in
await self.wait_until_ready() # wait until the bot logs in
client = MyClient(intents=discord.Intents(guilds=True))
client.run('token')
client.run("token")

View File

@ -1,6 +1,7 @@
import discord
import asyncio
class MyClient(discord.Client):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -9,18 +10,18 @@ class MyClient(discord.Client):
self.bg_task = self.loop.create_task(self.my_background_task())
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def my_background_task(self):
await self.wait_until_ready()
counter = 0
channel = self.get_channel(1234567) # channel ID goes here
channel = self.get_channel(1234567) # channel ID goes here
while not self.is_closed():
counter += 1
await channel.send(counter)
await asyncio.sleep(60) # task runs every 60 seconds
await asyncio.sleep(60) # task runs every 60 seconds
client = MyClient(intents=discord.Intents(guilds=True))
client.run('token')
client.run("token")

View File

@ -4,51 +4,58 @@ import discord
from discord.ext import commands
import random
description = '''An example bot to showcase the discord.ext.commands extension
description = """An example bot to showcase the discord.ext.commands extension
module.
There are a number of utility commands being showcased here.'''
There are a number of utility commands being showcased here."""
intents = discord.Intents(guilds=True, messages=True, members=True)
bot = commands.Bot(command_prefix='t-', description=description, intents=intents)
bot = commands.Bot(command_prefix="t-", description=description, intents=intents)
@bot.event
async def on_ready():
print(f'Logged in as {bot.user} (ID: {bot.user.id})')
print('------')
print(f"Logged in as {bot.user} (ID: {bot.user.id})")
print("------")
@bot.command()
async def add(ctx, left: int, right: int):
"""Adds two numbers together."""
await ctx.send(left + right)
@bot.command()
async def roll(ctx, dice: str):
"""Rolls a dice in NdN format."""
try:
rolls, limit = map(int, dice.split('d'))
rolls, limit = map(int, dice.split("d"))
except Exception:
await ctx.send('Format has to be in NdN!')
await ctx.send("Format has to be in NdN!")
return
result = ', '.join(str(random.randint(1, limit)) for r in range(rolls))
result = ", ".join(str(random.randint(1, limit)) for r in range(rolls))
await ctx.send(result)
@bot.command(description='For when you wanna settle the score some other way')
@bot.command(description="For when you wanna settle the score some other way")
async def choose(ctx, *choices: str):
"""Chooses between multiple choices."""
await ctx.send(random.choice(choices))
@bot.command()
async def repeat(ctx, times: int, content='repeating...'):
async def repeat(ctx, times: int, content="repeating..."):
"""Repeats a message multiple times."""
for i in range(times):
await ctx.send(content)
@bot.command()
async def joined(ctx, member: discord.Member):
"""Says when a member joined."""
await ctx.send(f'{member.name} joined in {member.joined_at}')
await ctx.send(f"{member.name} joined in {member.joined_at}")
@bot.group()
async def cool(ctx):
@ -57,11 +64,13 @@ async def cool(ctx):
In reality this just checks if a subcommand is being invoked.
"""
if ctx.invoked_subcommand is None:
await ctx.send(f'No, {ctx.subcommand_passed} is not cool')
await ctx.send(f"No, {ctx.subcommand_passed} is not cool")
@cool.command(name='bot')
@cool.command(name="bot")
async def _bot(ctx):
"""Is the bot cool?"""
await ctx.send('Yes, the bot is cool.')
await ctx.send("Yes, the bot is cool.")
bot.run('token')
bot.run("token")

View File

@ -6,26 +6,24 @@ import youtube_dl
from discord.ext import commands
# Suppress noise about console usage from errors
youtube_dl.utils.bug_reports_message = lambda: ''
youtube_dl.utils.bug_reports_message = lambda: ""
ytdl_format_options = {
'format': 'bestaudio/best',
'outtmpl': '%(extractor)s-%(id)s-%(title)s.%(ext)s',
'restrictfilenames': True,
'noplaylist': True,
'nocheckcertificate': True,
'ignoreerrors': False,
'logtostderr': False,
'quiet': True,
'no_warnings': True,
'default_search': 'auto',
'source_address': '0.0.0.0' # bind to ipv4 since ipv6 addresses cause issues sometimes
"format": "bestaudio/best",
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
"restrictfilenames": True,
"noplaylist": True,
"nocheckcertificate": True,
"ignoreerrors": False,
"logtostderr": False,
"quiet": True,
"no_warnings": True,
"default_search": "auto",
"source_address": "0.0.0.0", # bind to ipv4 since ipv6 addresses cause issues sometimes
}
ffmpeg_options = {
'options': '-vn'
}
ffmpeg_options = {"options": "-vn"}
ytdl = youtube_dl.YoutubeDL(ytdl_format_options)
@ -36,19 +34,21 @@ class YTDLSource(discord.PCMVolumeTransformer):
self.data = data
self.title = data.get('title')
self.url = data.get('url')
self.title = data.get("title")
self.url = data.get("url")
@classmethod
async def from_url(cls, url, *, loop=None, stream=False):
loop = loop or asyncio.get_event_loop()
data = await loop.run_in_executor(None, lambda: ytdl.extract_info(url, download=not stream))
data = await loop.run_in_executor(
None, lambda: ytdl.extract_info(url, download=not stream)
)
if 'entries' in data:
if "entries" in data:
# take first item from a playlist
data = data['entries'][0]
data = data["entries"][0]
filename = data['url'] if stream else ytdl.prepare_filename(data)
filename = data["url"] if stream else ytdl.prepare_filename(data)
return cls(discord.FFmpegPCMAudio(filename, **ffmpeg_options), data=data)
@ -70,9 +70,11 @@ class Music(commands.Cog):
"""Plays a file from the local filesystem"""
source = discord.PCMVolumeTransformer(discord.FFmpegPCMAudio(query))
ctx.voice_client.play(source, after=lambda e: print(f'Player error: {e}') if e else None)
ctx.voice_client.play(
source, after=lambda e: print(f"Player error: {e}") if e else None
)
await ctx.send(f'Now playing: {query}')
await ctx.send(f"Now playing: {query}")
@commands.command()
async def yt(self, ctx, *, url):
@ -80,9 +82,11 @@ class Music(commands.Cog):
async with ctx.typing():
player = await YTDLSource.from_url(url, loop=self.bot.loop)
ctx.voice_client.play(player, after=lambda e: print(f'Player error: {e}') if e else None)
ctx.voice_client.play(
player, after=lambda e: print(f"Player error: {e}") if e else None
)
await ctx.send(f'Now playing: {player.title}')
await ctx.send(f"Now playing: {player.title}")
@commands.command()
async def stream(self, ctx, *, url):
@ -90,9 +94,11 @@ class Music(commands.Cog):
async with ctx.typing():
player = await YTDLSource.from_url(url, loop=self.bot.loop, stream=True)
ctx.voice_client.play(player, after=lambda e: print(f'Player error: {e}') if e else None)
ctx.voice_client.play(
player, after=lambda e: print(f"Player error: {e}") if e else None
)
await ctx.send(f'Now playing: {player.title}')
await ctx.send(f"Now playing: {player.title}")
@commands.command()
async def volume(self, ctx, volume: int):
@ -123,16 +129,19 @@ class Music(commands.Cog):
elif ctx.voice_client.is_playing():
ctx.voice_client.stop()
bot = commands.Bot(
command_prefix=commands.when_mentioned_or("!"),
description='Relatively simple music bot example',
intents=discord.Intents(guilds=True, guild_messages=True, voice_states=True)
description="Relatively simple music bot example",
intents=discord.Intents(guilds=True, guild_messages=True, voice_states=True),
)
@bot.event
async def on_ready():
print(f'Logged in as {bot.user} (ID: {bot.user.id})')
print('------')
print(f"Logged in as {bot.user} (ID: {bot.user.id})")
print("------")
bot.add_cog(Music(bot))
bot.run('token')
bot.run("token")

View File

@ -7,7 +7,7 @@ from discord.ext import commands
intents = discord.Intents(guilds=True, messages=True, members=True)
bot = commands.Bot('!', intents=intents)
bot = commands.Bot("!", intents=intents)
@bot.command()
@ -28,14 +28,16 @@ async def userinfo(ctx: commands.Context, user: discord.User):
user_id = user.id
username = user.name
avatar = user.avatar.url
await ctx.send(f'User found: {user_id} -- {username}\n{avatar}')
await ctx.send(f"User found: {user_id} -- {username}\n{avatar}")
@userinfo.error
async def userinfo_error(ctx: commands.Context, error: commands.CommandError):
# if the conversion above fails for any reason, it will raise `commands.BadArgument`
# so we handle this in this error handler:
if isinstance(error, commands.BadArgument):
return await ctx.send('Couldn\'t find that user.')
return await ctx.send("Couldn't find that user.")
# Custom Converter here
class ChannelOrMemberConverter(commands.Converter):
@ -69,21 +71,25 @@ class ChannelOrMemberConverter(commands.Converter):
# If the value could not be converted we can raise an error
# so our error handlers can deal with it in one place.
# The error has to be CommandError derived, so BadArgument works fine here.
raise commands.BadArgument(f'No Member or TextChannel could be converted from "{argument}"')
raise commands.BadArgument(
f'No Member or TextChannel could be converted from "{argument}"'
)
@bot.command()
async def notify(ctx: commands.Context, target: ChannelOrMemberConverter):
# This command signature utilises the custom converter written above
# What will happen during command invocation is that the `target` above will be passed to
# the `argument` parameter of the `ChannelOrMemberConverter.convert` method and
# the `argument` parameter of the `ChannelOrMemberConverter.convert` method and
# the conversion will go through the process defined there.
await target.send(f'Hello, {target.name}!')
await target.send(f"Hello, {target.name}!")
@bot.command()
async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]):
async def ignore(
ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]
):
# This command signature utilises the `typing.Union` typehint.
# The `commands` framework attempts a conversion of each type in this Union *in order*.
# So, it will attempt to convert whatever is passed to `target` to a `discord.Member` instance.
@ -94,9 +100,16 @@ async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, dis
# To check the resulting type, `isinstance` is used
if isinstance(target, discord.Member):
await ctx.send(f'Member found: {target.mention}, adding them to the ignore list.')
elif isinstance(target, discord.TextChannel): # this could be an `else` but for completeness' sake.
await ctx.send(f'Channel found: {target.mention}, adding it to the ignore list.')
await ctx.send(
f"Member found: {target.mention}, adding them to the ignore list."
)
elif isinstance(
target, discord.TextChannel
): # this could be an `else` but for completeness' sake.
await ctx.send(
f"Channel found: {target.mention}, adding it to the ignore list."
)
# Built-in type converters.
@bot.command()
@ -109,4 +122,5 @@ async def multiply(ctx: commands.Context, number: int, maybe: bool):
return await ctx.send(number * 2)
await ctx.send(number * 5)
bot.run('token')
bot.run("token")

View File

@ -10,7 +10,7 @@ class MyContext(commands.Context):
# depending on whether value is True or False
# if its True, it'll add a green check mark
# otherwise, it'll add a red cross mark
emoji = '\N{WHITE HEAVY CHECK MARK}' if value else '\N{CROSS MARK}'
emoji = "\N{WHITE HEAVY CHECK MARK}" if value else "\N{CROSS MARK}"
try:
# this will react to the command author's message
await self.message.add_reaction(emoji)
@ -27,9 +27,10 @@ class MyBot(commands.Bot):
# subclass to the super() method, which tells the bot to
# use the new MyContext class
return await super().get_context(message, cls=cls)
bot = MyBot(command_prefix='!', intents=discord.Intents(guilds=True, messages=True))
bot = MyBot(command_prefix="!", intents=discord.Intents(guilds=True, messages=True))
@bot.command()
async def guess(ctx, number: int):
@ -42,8 +43,9 @@ async def guess(ctx, number: int):
# or a red cross mark if it wasn't
await ctx.tick(number == value)
# IMPORTANT: You shouldn't hard code your token
# these are very important, and leaking them can
# these are very important, and leaking them can
# let people do very malicious things with your
# bot. Try to use a file or something to keep
# them private, and don't commit it to GitHub

View File

@ -1,21 +1,23 @@
import discord
class MyClient(discord.Client):
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def on_message(self, message):
if message.content.startswith('!deleteme'):
msg = await message.channel.send('I will delete myself now...')
if message.content.startswith("!deleteme"):
msg = await message.channel.send("I will delete myself now...")
await msg.delete()
# this also works
await message.channel.send('Goodbye in 3 seconds...', delete_after=3.0)
await message.channel.send("Goodbye in 3 seconds...", delete_after=3.0)
async def on_message_delete(self, message):
msg = f'{message.author} has deleted the message: {message.content}'
msg = f"{message.author} has deleted the message: {message.content}"
await message.channel.send(msg)
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
client.run('token')
client.run("token")

View File

@ -1,20 +1,22 @@
import discord
import asyncio
class MyClient(discord.Client):
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def on_message(self, message):
if message.content.startswith('!editme'):
msg = await message.channel.send('10')
if message.content.startswith("!editme"):
msg = await message.channel.send("10")
await asyncio.sleep(3.0)
await msg.edit(content='40')
await msg.edit(content="40")
async def on_message_edit(self, before, after):
msg = f'**{before.author}** edited their message:\n{before.content} -> {after.content}'
msg = f"**{before.author}** edited their message:\n{before.content} -> {after.content}"
await before.channel.send(msg)
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
client.run('token')
client.run("token")

View File

@ -2,18 +2,19 @@ import discord
import random
import asyncio
class MyClient(discord.Client):
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def on_message(self, message):
# we do not want the bot to reply to itself
if message.author.id == self.user.id:
return
if message.content.startswith('$guess'):
await message.channel.send('Guess a number between 1 and 10.')
if message.content.startswith("$guess"):
await message.channel.send("Guess a number between 1 and 10.")
def is_correct(m):
return m.author == message.author and m.content.isdigit()
@ -21,14 +22,17 @@ class MyClient(discord.Client):
answer = random.randint(1, 10)
try:
guess = await self.wait_for('message', check=is_correct, timeout=5.0)
guess = await self.wait_for("message", check=is_correct, timeout=5.0)
except asyncio.TimeoutError:
return await message.channel.send(f'Sorry, you took too long it was {answer}.')
return await message.channel.send(
f"Sorry, you took too long it was {answer}."
)
if int(guess.content) == answer:
await message.channel.send('You are right!')
await message.channel.send("You are right!")
else:
await message.channel.send(f'Oops. It is actually {answer}.')
await message.channel.send(f"Oops. It is actually {answer}.")
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
client.run('token')
client.run("token")

View File

@ -2,17 +2,18 @@
import discord
class MyClient(discord.Client):
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def on_member_join(self, member):
guild = member.guild
if guild.system_channel is not None:
to_send = f'Welcome {member.mention} to {guild.name}!'
to_send = f"Welcome {member.mention} to {guild.name}!"
await guild.system_channel.send(to_send)
client = MyClient(intents=discord.Intents(guilds=True, members=True))
client.run('token')
client.run("token")

View File

@ -2,15 +2,24 @@
import discord
class MyClient(discord.Client):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.role_message_id = 0 # ID of the message that can be reacted to to add/remove a role.
self.role_message_id = (
0 # ID of the message that can be reacted to to add/remove a role.
)
self.emoji_to_role = {
discord.PartialEmoji(name='🔴'): 0, # ID of the role associated with unicode emoji '🔴'.
discord.PartialEmoji(name='🟡'): 0, # ID of the role associated with unicode emoji '🟡'.
discord.PartialEmoji(name='green', id=0): 0, # ID of the role associated with a partial emoji's ID.
discord.PartialEmoji(
name="🔴"
): 0, # ID of the role associated with unicode emoji '🔴'.
discord.PartialEmoji(
name="🟡"
): 0, # ID of the role associated with unicode emoji '🟡'.
discord.PartialEmoji(
name="green", id=0
): 0, # ID of the role associated with a partial emoji's ID.
}
async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent):
@ -78,6 +87,7 @@ class MyClient(discord.Client):
# If we want to do something in case of errors we'd do it here.
pass
intents = discord.Intents(guilds=True, members=True, guild_reactions=True)
client = MyClient(intents=intents)
client.run('token')
client.run("token")

View File

@ -1,17 +1,19 @@
import discord
class MyClient(discord.Client):
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
async def on_message(self, message):
# we do not want the bot to reply to itself
if message.author.id == self.user.id:
return
if message.content.startswith('!hello'):
await message.reply('Hello!', mention_author=True)
if message.content.startswith("!hello"):
await message.reply("Hello!", mention_author=True)
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
client.run('token')
client.run("token")

View File

@ -6,18 +6,19 @@ from discord.ext import commands
bot = commands.Bot(
command_prefix=commands.when_mentioned,
description="Nothing to see here!",
intents=discord.Intents(guilds=True, messages=True)
intents=discord.Intents(guilds=True, messages=True),
)
# the `hidden` keyword argument hides it from the help command.
# the `hidden` keyword argument hides it from the help command.
@bot.group(hidden=True)
async def secret(ctx: commands.Context):
"""What is this "secret" you speak of?"""
if ctx.invoked_subcommand is None:
await ctx.send('Shh!', delete_after=5)
await ctx.send("Shh!", delete_after=5)
def create_overwrites(ctx, *objects):
"""This is just a helper function that creates the overwrites for the
"""This is just a helper function that creates the overwrites for the
voice/text channels.
A `discord.PermissionOverwrite` allows you to determine the permissions
@ -31,40 +32,51 @@ def create_overwrites(ctx, *objects):
# a dict comprehension is being utilised here to set the same permission overwrites
# for each `discord.Role` or `discord.Member`.
overwrites = {
obj: discord.PermissionOverwrite(view_channel=True)
for obj in objects
obj: discord.PermissionOverwrite(view_channel=True) for obj in objects
}
# prevents the default role (@everyone) from viewing the channel
# if it isn't already allowed to view the channel.
overwrites.setdefault(ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False))
overwrites.setdefault(
ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False)
)
# makes sure the client is always allowed to view the channel.
overwrites[ctx.guild.me] = discord.PermissionOverwrite(view_channel=True)
return overwrites
# since these commands rely on guild related features,
# it is best to lock it to be guild-only.
@secret.command()
@commands.guild_only()
async def text(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]):
"""This makes a text channel with a specified name
async def text(
ctx: commands.Context,
name: str,
*objects: typing.Union[discord.Role, discord.Member]
):
"""This makes a text channel with a specified name
that is only visible to roles or members that are specified.
"""
overwrites = create_overwrites(ctx, *objects)
await ctx.guild.create_text_channel(
name,
overwrites=overwrites,
topic='Top secret text channel. Any leakage of this channel may result in serious trouble.',
reason='Very secret business.',
topic="Top secret text channel. Any leakage of this channel may result in serious trouble.",
reason="Very secret business.",
)
@secret.command()
@commands.guild_only()
async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]):
async def voice(
ctx: commands.Context,
name: str,
*objects: typing.Union[discord.Role, discord.Member]
):
"""This does the same thing as the `text` subcommand
but instead creates a voice channel.
"""
@ -72,14 +84,15 @@ async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord
overwrites = create_overwrites(ctx, *objects)
await ctx.guild.create_voice_channel(
name,
overwrites=overwrites,
reason='Very secret business.'
name, overwrites=overwrites, reason="Very secret business."
)
@secret.command()
@commands.guild_only()
async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role):
async def emoji(
ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role
):
"""This clones a specified emoji that only specified roles
are allowed to use.
"""
@ -90,11 +103,8 @@ async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: disc
# the key parameter here is `roles`, which controls
# what roles are able to use the emoji.
await ctx.guild.create_custom_emoji(
name=emoji.name,
image=emoji_bytes,
roles=roles,
reason='Very secret business.'
name=emoji.name, image=emoji_bytes, roles=roles, reason="Very secret business."
)
bot.run('token')
bot.run("token")

View File

@ -6,13 +6,13 @@ import discord
class Bot(commands.Bot):
def __init__(self):
super().__init__(
command_prefix=commands.when_mentioned_or('$'),
intents=discord.Intents(guilds=True, messages=True)
command_prefix=commands.when_mentioned_or("$"),
intents=discord.Intents(guilds=True, messages=True),
)
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
# Define a simple View that gives us a confirmation menu
@ -24,16 +24,18 @@ class Confirm(discord.ui.View):
# When the confirm button is pressed, set the inner value to `True` and
# stop the View from listening to more input.
# We also send the user an ephemeral message that we're confirming their choice.
@discord.ui.button(label='Confirm', style=discord.ButtonStyle.green)
async def confirm(self, button: discord.ui.Button, interaction: discord.Interaction):
await interaction.response.send_message('Confirming', ephemeral=True)
@discord.ui.button(label="Confirm", style=discord.ButtonStyle.green)
async def confirm(
self, button: discord.ui.Button, interaction: discord.Interaction
):
await interaction.response.send_message("Confirming", ephemeral=True)
self.value = True
self.stop()
# This one is similar to the confirmation button except sets the inner value to `False`
@discord.ui.button(label='Cancel', style=discord.ButtonStyle.grey)
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.grey)
async def cancel(self, button: discord.ui.Button, interaction: discord.Interaction):
await interaction.response.send_message('Cancelling', ephemeral=True)
await interaction.response.send_message("Cancelling", ephemeral=True)
self.value = False
self.stop()
@ -46,15 +48,15 @@ async def ask(ctx: commands.Context):
"""Asks the user a question to confirm something."""
# We create the view and assign it to a variable so we can wait for it later.
view = Confirm()
await ctx.send('Do you want to continue?', view=view)
await ctx.send("Do you want to continue?", view=view)
# Wait for the View to stop listening for input...
await view.wait()
if view.value is None:
print('Timed out...')
print("Timed out...")
elif view.value:
print('Confirmed...')
print("Confirmed...")
else:
print('Cancelled...')
print("Cancelled...")
bot.run('token')
bot.run("token")

Some files were not shown because too many files have changed in this diff Show More