Run black on the repository, with the default configuration. #43
@ -9,13 +9,13 @@ A basic wrapper for the Discord API.
|
||||
|
||||
"""
|
||||
|
||||
__title__ = 'discord'
|
||||
__author__ = 'Rapptz'
|
||||
__license__ = 'MIT'
|
||||
__copyright__ = 'Copyright 2015-present Rapptz'
|
||||
__version__ = '2.0.0a'
|
||||
__title__ = "discord"
|
||||
__author__ = "Rapptz"
|
||||
__license__ = "MIT"
|
||||
__copyright__ = "Copyright 2015-present Rapptz"
|
||||
__version__ = "2.0.0a"
|
||||
|
||||
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
|
||||
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
|
||||
|
||||
import logging
|
||||
from typing import NamedTuple, Literal
|
||||
@ -69,6 +69,8 @@ class VersionInfo(NamedTuple):
|
||||
serial: int
|
||||
|
||||
|
||||
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
|
||||
version_info: VersionInfo = VersionInfo(
|
||||
major=2, minor=0, micro=0, releaselevel="alpha", serial=0
|
||||
)
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
@ -31,26 +31,37 @@ import pkg_resources
|
||||
import aiohttp
|
||||
import platform
|
||||
|
||||
|
||||
def show_version():
|
||||
entries = []
|
||||
|
||||
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
|
||||
entries.append(
|
||||
"- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(
|
||||
sys.version_info
|
||||
)
|
||||
)
|
||||
version_info = discord.version_info
|
||||
entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info))
|
||||
if version_info.releaselevel != 'final':
|
||||
pkg = pkg_resources.get_distribution('discord.py')
|
||||
entries.append(
|
||||
"- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(
|
||||
version_info
|
||||
)
|
||||
)
|
||||
if version_info.releaselevel != "final":
|
||||
pkg = pkg_resources.get_distribution("discord.py")
|
||||
if pkg:
|
||||
entries.append(f' - discord.py pkg_resources: v{pkg.version}')
|
||||
entries.append(f" - discord.py pkg_resources: v{pkg.version}")
|
||||
|
||||
entries.append(f'- aiohttp v{aiohttp.__version__}')
|
||||
entries.append(f"- aiohttp v{aiohttp.__version__}")
|
||||
uname = platform.uname()
|
||||
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
|
||||
print('\n'.join(entries))
|
||||
entries.append("- system info: {0.system} {0.release} {0.version}".format(uname))
|
||||
print("\n".join(entries))
|
||||
|
||||
|
||||
def core(parser, args):
|
||||
if args.version:
|
||||
show_version()
|
||||
|
||||
|
||||
_bot_template = """#!/usr/bin/env python3
|
||||
|
||||
from discord.ext import commands
|
||||
@ -120,7 +131,7 @@ def setup(bot):
|
||||
bot.add_cog({name}(bot))
|
||||
'''
|
||||
|
||||
_cog_extras = '''
|
||||
_cog_extras = """
|
||||
def cog_unload(self):
|
||||
# clean up logic goes here
|
||||
pass
|
||||
@ -149,22 +160,22 @@ _cog_extras = '''
|
||||
# called after a command is called here
|
||||
pass
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
# certain file names and directory names are forbidden
|
||||
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
|
||||
# although some of this doesn't apply to Linux, we might as well be consistent
|
||||
_base_table = {
|
||||
'<': '-',
|
||||
'>': '-',
|
||||
':': '-',
|
||||
'"': '-',
|
||||
"<": "-",
|
||||
">": "-",
|
||||
":": "-",
|
||||
'"': "-",
|
||||
# '/': '-', these are fine
|
||||
# '\\': '-',
|
||||
'|': '-',
|
||||
'?': '-',
|
||||
'*': '-',
|
||||
"|": "-",
|
||||
"?": "-",
|
||||
"*": "-",
|
||||
}
|
||||
|
||||
# NUL (0) and 1-31 are disallowed
|
||||
@ -172,21 +183,45 @@ _base_table.update((chr(i), None) for i in range(32))
|
||||
|
||||
_translation_table = str.maketrans(_base_table)
|
||||
|
||||
|
||||
def to_path(parser, name, *, replace_spaces=False):
|
||||
if isinstance(name, Path):
|
||||
return name
|
||||
|
||||
if sys.platform == 'win32':
|
||||
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \
|
||||
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9')
|
||||
if sys.platform == "win32":
|
||||
forbidden = (
|
||||
"CON",
|
||||
"PRN",
|
||||
"AUX",
|
||||
"NUL",
|
||||
"COM1",
|
||||
"COM2",
|
||||
"COM3",
|
||||
"COM4",
|
||||
"COM5",
|
||||
"COM6",
|
||||
"COM7",
|
||||
"COM8",
|
||||
"COM9",
|
||||
"LPT1",
|
||||
"LPT2",
|
||||
"LPT3",
|
||||
"LPT4",
|
||||
"LPT5",
|
||||
"LPT6",
|
||||
"LPT7",
|
||||
"LPT8",
|
||||
"LPT9",
|
||||
)
|
||||
if len(name) <= 4 and name.upper() in forbidden:
|
||||
parser.error('invalid directory name given, use a different one')
|
||||
parser.error("invalid directory name given, use a different one")
|
||||
|
||||
name = name.translate(_translation_table)
|
||||
if replace_spaces:
|
||||
name = name.replace(' ', '-')
|
||||
name = name.replace(" ", "-")
|
||||
return Path(name)
|
||||
|
||||
|
||||
def newbot(parser, args):
|
||||
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
|
||||
|
||||
@ -195,106 +230,147 @@ def newbot(parser, args):
|
||||
try:
|
||||
new_directory.mkdir(exist_ok=True, parents=True)
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create our bot directory ({exc})')
|
||||
parser.error(f"could not create our bot directory ({exc})")
|
||||
|
||||
cogs = new_directory / 'cogs'
|
||||
cogs = new_directory / "cogs"
|
||||
|
||||
try:
|
||||
cogs.mkdir(exist_ok=True)
|
||||
init = cogs / '__init__.py'
|
||||
init = cogs / "__init__.py"
|
||||
init.touch()
|
||||
except OSError as exc:
|
||||
print(f'warning: could not create cogs directory ({exc})')
|
||||
print(f"warning: could not create cogs directory ({exc})")
|
||||
|
||||
try:
|
||||
with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp:
|
||||
with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp:
|
||||
fp.write('token = "place your token here"\ncogs = []\n')
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create config file ({exc})')
|
||||
parser.error(f"could not create config file ({exc})")
|
||||
|
||||
try:
|
||||
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
|
||||
base = 'Bot' if not args.sharded else 'AutoShardedBot'
|
||||
with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp:
|
||||
base = "Bot" if not args.sharded else "AutoShardedBot"
|
||||
fp.write(_bot_template.format(base=base, prefix=args.prefix))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create bot file ({exc})')
|
||||
parser.error(f"could not create bot file ({exc})")
|
||||
|
||||
if not args.no_git:
|
||||
try:
|
||||
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
|
||||
with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp:
|
||||
fp.write(_gitignore_template)
|
||||
except OSError as exc:
|
||||
print(f'warning: could not create .gitignore file ({exc})')
|
||||
print(f"warning: could not create .gitignore file ({exc})")
|
||||
|
||||
print("successfully made bot at", new_directory)
|
||||
|
||||
print('successfully made bot at', new_directory)
|
||||
|
||||
def newcog(parser, args):
|
||||
cog_dir = to_path(parser, args.directory)
|
||||
try:
|
||||
cog_dir.mkdir(exist_ok=True)
|
||||
except OSError as exc:
|
||||
print(f'warning: could not create cogs directory ({exc})')
|
||||
print(f"warning: could not create cogs directory ({exc})")
|
||||
|
||||
directory = cog_dir / to_path(parser, args.name)
|
||||
directory = directory.with_suffix('.py')
|
||||
directory = directory.with_suffix(".py")
|
||||
try:
|
||||
with open(str(directory), 'w', encoding='utf-8') as fp:
|
||||
attrs = ''
|
||||
extra = _cog_extras if args.full else ''
|
||||
with open(str(directory), "w", encoding="utf-8") as fp:
|
||||
attrs = ""
|
||||
extra = _cog_extras if args.full else ""
|
||||
if args.class_name:
|
||||
name = args.class_name
|
||||
else:
|
||||
name = str(directory.stem)
|
||||
if '-' in name or '_' in name:
|
||||
translation = str.maketrans('-_', ' ')
|
||||
name = name.translate(translation).title().replace(' ', '')
|
||||
if "-" in name or "_" in name:
|
||||
translation = str.maketrans("-_", " ")
|
||||
name = name.translate(translation).title().replace(" ", "")
|
||||
else:
|
||||
name = name.title()
|
||||
|
||||
if args.display_name:
|
||||
attrs += f', name="{args.display_name}"'
|
||||
if args.hide_commands:
|
||||
attrs += ', command_attrs=dict(hidden=True)'
|
||||
attrs += ", command_attrs=dict(hidden=True)"
|
||||
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
|
||||
except OSError as exc:
|
||||
parser.error(f'could not create cog file ({exc})')
|
||||
parser.error(f"could not create cog file ({exc})")
|
||||
else:
|
||||
print('successfully made cog at', directory)
|
||||
print("successfully made cog at", directory)
|
||||
|
||||
|
||||
def add_newbot_args(subparser):
|
||||
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
|
||||
parser = subparser.add_parser(
|
||||
"newbot", help="creates a command bot project quickly"
|
||||
)
|
||||
parser.set_defaults(func=newbot)
|
||||
|
||||
parser.add_argument('name', help='the bot project name')
|
||||
parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd())
|
||||
parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='<prefix>')
|
||||
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true')
|
||||
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
|
||||
parser.add_argument("name", help="the bot project name")
|
||||
parser.add_argument(
|
||||
"directory",
|
||||
help="the directory to place it in (default: .)",
|
||||
nargs="?",
|
||||
default=Path.cwd(),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sharded", help="whether to use AutoShardedBot", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-git",
|
||||
help="do not create a .gitignore file",
|
||||
action="store_true",
|
||||
dest="no_git",
|
||||
)
|
||||
|
||||
|
||||
def add_newcog_args(subparser):
|
||||
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
|
||||
parser = subparser.add_parser("newcog", help="creates a new cog template quickly")
|
||||
parser.set_defaults(func=newcog)
|
||||
|
||||
parser.add_argument('name', help='the cog name')
|
||||
parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs'))
|
||||
parser.add_argument('--class-name', help='the class name of the cog (default: <name>)', dest='class_name')
|
||||
parser.add_argument('--display-name', help='the cog name (default: <name>)')
|
||||
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true')
|
||||
parser.add_argument('--full', help='add all special methods as well', action='store_true')
|
||||
parser.add_argument("name", help="the cog name")
|
||||
parser.add_argument(
|
||||
"directory",
|
||||
help="the directory to place it in (default: cogs)",
|
||||
nargs="?",
|
||||
default=Path("cogs"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class-name",
|
||||
help="the class name of the cog (default: <name>)",
|
||||
dest="class_name",
|
||||
)
|
||||
parser.add_argument("--display-name", help="the cog name (default: <name>)")
|
||||
parser.add_argument(
|
||||
"--hide-commands",
|
||||
help="whether to hide all commands in the cog",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full", help="add all special methods as well", action="store_true"
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
|
||||
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="discord", description="Tools for helping with discord.py"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", action="store_true", help="shows the library version"
|
||||
)
|
||||
parser.set_defaults(func=core)
|
||||
|
||||
subparser = parser.add_subparsers(dest='subcommand', title='subcommands')
|
||||
subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
|
||||
add_newbot_args(subparser)
|
||||
add_newcog_args(subparser)
|
||||
return parser, parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
parser, args = parse_args()
|
||||
args.func(parser, args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
250
discord/abc.py
250
discord/abc.py
@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem
|
||||
from . import utils
|
||||
|
||||
__all__ = (
|
||||
'Snowflake',
|
||||
'User',
|
||||
'PrivateChannel',
|
||||
'GuildChannel',
|
||||
'Messageable',
|
||||
'Connectable',
|
||||
"Snowflake",
|
||||
"User",
|
||||
"PrivateChannel",
|
||||
"GuildChannel",
|
||||
"Messageable",
|
||||
"Connectable",
|
||||
)
|
||||
|
||||
T = TypeVar('T', bound=VoiceProtocol)
|
||||
T = TypeVar("T", bound=VoiceProtocol)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
@ -89,7 +89,9 @@ if TYPE_CHECKING:
|
||||
OverwriteType,
|
||||
)
|
||||
|
||||
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
|
||||
PartialMessageableChannel = Union[
|
||||
TextChannel, Thread, DMChannel, PartialMessageable
|
||||
]
|
||||
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
|
||||
SnowflakeTime = Union["Snowflake", datetime]
|
||||
|
||||
@ -98,7 +100,7 @@ MISSING = utils.MISSING
|
||||
|
||||
class _Undefined:
|
||||
def __repr__(self) -> str:
|
||||
return 'see-below'
|
||||
return "see-below"
|
||||
|
||||
|
||||
_undefined: Any = _Undefined()
|
||||
@ -189,23 +191,23 @@ class PrivateChannel(Snowflake, Protocol):
|
||||
|
||||
|
||||
class _Overwrites:
|
||||
__slots__ = ('id', 'allow', 'deny', 'type')
|
||||
__slots__ = ("id", "allow", "deny", "type")
|
||||
|
||||
ROLE = 0
|
||||
MEMBER = 1
|
||||
|
||||
def __init__(self, data: PermissionOverwritePayload):
|
||||
self.id: int = int(data['id'])
|
||||
self.allow: int = int(data.get('allow', 0))
|
||||
self.deny: int = int(data.get('deny', 0))
|
||||
self.type: OverwriteType = data['type']
|
||||
self.id: int = int(data["id"])
|
||||
self.allow: int = int(data.get("allow", 0))
|
||||
self.deny: int = int(data.get("deny", 0))
|
||||
self.type: OverwriteType = data["type"]
|
||||
|
||||
def _asdict(self) -> PermissionOverwritePayload:
|
||||
return {
|
||||
'id': self.id,
|
||||
'allow': str(self.allow),
|
||||
'deny': str(self.deny),
|
||||
'type': self.type,
|
||||
"id": self.id,
|
||||
"allow": str(self.allow),
|
||||
"deny": str(self.deny),
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
def is_role(self) -> bool:
|
||||
@ -215,7 +217,7 @@ class _Overwrites:
|
||||
return self.type == 1
|
||||
|
||||
|
||||
GCH = TypeVar('GCH', bound='GuildChannel')
|
||||
GCH = TypeVar("GCH", bound="GuildChannel")
|
||||
|
||||
|
||||
class GuildChannel:
|
||||
@ -254,7 +256,9 @@ class GuildChannel:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]):
|
||||
def __init__(
|
||||
self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]
|
||||
):
|
||||
...
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -276,11 +280,13 @@ class GuildChannel:
|
||||
reason: Optional[str],
|
||||
) -> None:
|
||||
if position < 0:
|
||||
raise InvalidArgument('Channel position cannot be less than 0.')
|
||||
raise InvalidArgument("Channel position cannot be less than 0.")
|
||||
|
||||
http = self._state.http
|
||||
bucket = self._sorting_bucket
|
||||
channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]
|
||||
channels: List[GuildChannel] = [
|
||||
c for c in self.guild.channels if c._sorting_bucket == bucket
|
||||
]
|
||||
|
||||
channels.sort(key=lambda c: c.position)
|
||||
|
||||
@ -291,106 +297,124 @@ class GuildChannel:
|
||||
# not there somehow lol
|
||||
return
|
||||
else:
|
||||
index = next((i for i, c in enumerate(channels) if c.position >= position), len(channels))
|
||||
index = next(
|
||||
(i for i, c in enumerate(channels) if c.position >= position),
|
||||
len(channels),
|
||||
)
|
||||
# add ourselves at our designated position
|
||||
channels.insert(index, self)
|
||||
|
||||
payload = []
|
||||
for index, c in enumerate(channels):
|
||||
d: Dict[str, Any] = {'id': c.id, 'position': index}
|
||||
d: Dict[str, Any] = {"id": c.id, "position": index}
|
||||
if parent_id is not _undefined and c.id == self.id:
|
||||
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
|
||||
payload.append(d)
|
||||
|
||||
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
|
||||
|
||||
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
|
||||
async def _edit(
|
||||
self, options: Dict[str, Any], reason: Optional[str]
|
||||
) -> Optional[ChannelPayload]:
|
||||
try:
|
||||
parent = options.pop('category')
|
||||
parent = options.pop("category")
|
||||
except KeyError:
|
||||
parent_id = _undefined
|
||||
else:
|
||||
parent_id = parent and parent.id
|
||||
|
||||
try:
|
||||
options['rate_limit_per_user'] = options.pop('slowmode_delay')
|
||||
options["rate_limit_per_user"] = options.pop("slowmode_delay")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
rtc_region = options.pop('rtc_region')
|
||||
rtc_region = options.pop("rtc_region")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
options['rtc_region'] = None if rtc_region is None else str(rtc_region)
|
||||
options["rtc_region"] = None if rtc_region is None else str(rtc_region)
|
||||
|
||||
try:
|
||||
video_quality_mode = options.pop('video_quality_mode')
|
||||
video_quality_mode = options.pop("video_quality_mode")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
options['video_quality_mode'] = int(video_quality_mode)
|
||||
options["video_quality_mode"] = int(video_quality_mode)
|
||||
|
||||
lock_permissions = options.pop('sync_permissions', False)
|
||||
lock_permissions = options.pop("sync_permissions", False)
|
||||
|
||||
try:
|
||||
position = options.pop('position')
|
||||
position = options.pop("position")
|
||||
except KeyError:
|
||||
if parent_id is not _undefined:
|
||||
if lock_permissions:
|
||||
category = self.guild.get_channel(parent_id)
|
||||
if category:
|
||||
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
|
||||
options['parent_id'] = parent_id
|
||||
options["permission_overwrites"] = [
|
||||
c._asdict() for c in category._overwrites
|
||||
]
|
||||
options["parent_id"] = parent_id
|
||||
elif lock_permissions and self.category_id is not None:
|
||||
# if we're syncing permissions on a pre-existing channel category without changing it
|
||||
# we need to update the permissions to point to the pre-existing category
|
||||
category = self.guild.get_channel(self.category_id)
|
||||
if category:
|
||||
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
|
||||
options["permission_overwrites"] = [
|
||||
c._asdict() for c in category._overwrites
|
||||
]
|
||||
else:
|
||||
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
|
||||
await self._move(
|
||||
position,
|
||||
parent_id=parent_id,
|
||||
lock_permissions=lock_permissions,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
overwrites = options.get('overwrites', None)
|
||||
overwrites = options.get("overwrites", None)
|
||||
if overwrites is not None:
|
||||
perms = []
|
||||
for target, perm in overwrites.items():
|
||||
if not isinstance(perm, PermissionOverwrite):
|
||||
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}')
|
||||
raise InvalidArgument(
|
||||
f"Expected PermissionOverwrite received {perm.__class__.__name__}"
|
||||
)
|
||||
|
||||
allow, deny = perm.pair()
|
||||
payload = {
|
||||
'allow': allow.value,
|
||||
'deny': deny.value,
|
||||
'id': target.id,
|
||||
"allow": allow.value,
|
||||
"deny": deny.value,
|
||||
"id": target.id,
|
||||
}
|
||||
|
||||
if isinstance(target, Role):
|
||||
payload['type'] = _Overwrites.ROLE
|
||||
payload["type"] = _Overwrites.ROLE
|
||||
else:
|
||||
payload['type'] = _Overwrites.MEMBER
|
||||
payload["type"] = _Overwrites.MEMBER
|
||||
|
||||
perms.append(payload)
|
||||
options['permission_overwrites'] = perms
|
||||
options["permission_overwrites"] = perms
|
||||
|
||||
try:
|
||||
ch_type = options['type']
|
||||
ch_type = options["type"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if not isinstance(ch_type, ChannelType):
|
||||
raise InvalidArgument('type field must be of type ChannelType')
|
||||
options['type'] = ch_type.value
|
||||
raise InvalidArgument("type field must be of type ChannelType")
|
||||
options["type"] = ch_type.value
|
||||
|
||||
if options:
|
||||
return await self._state.http.edit_channel(self.id, reason=reason, **options)
|
||||
return await self._state.http.edit_channel(
|
||||
self.id, reason=reason, **options
|
||||
)
|
||||
|
||||
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
|
||||
self._overwrites = []
|
||||
everyone_index = 0
|
||||
everyone_id = self.guild.id
|
||||
|
||||
for index, overridden in enumerate(data.get('permission_overwrites', [])):
|
||||
for index, overridden in enumerate(data.get("permission_overwrites", [])):
|
||||
overwrite = _Overwrites(overridden)
|
||||
self._overwrites.append(overwrite)
|
||||
|
||||
@ -429,7 +453,7 @@ class GuildChannel:
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: The string that allows you to mention the channel."""
|
||||
return f'<#{self.id}>'
|
||||
return f"<#{self.id}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime:
|
||||
@ -589,7 +613,9 @@ class GuildChannel:
|
||||
try:
|
||||
maybe_everyone = self._overwrites[0]
|
||||
if maybe_everyone.id == self.guild.id:
|
||||
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
|
||||
base.handle_overwrite(
|
||||
allow=maybe_everyone.allow, deny=maybe_everyone.deny
|
||||
)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@ -620,7 +646,9 @@ class GuildChannel:
|
||||
try:
|
||||
maybe_everyone = self._overwrites[0]
|
||||
if maybe_everyone.id == self.guild.id:
|
||||
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
|
||||
base.handle_overwrite(
|
||||
allow=maybe_everyone.allow, deny=maybe_everyone.deny
|
||||
)
|
||||
remaining_overwrites = self._overwrites[1:]
|
||||
else:
|
||||
remaining_overwrites = self._overwrites
|
||||
@ -703,7 +731,9 @@ class GuildChannel:
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
|
||||
async def set_permissions(
|
||||
self, target, *, overwrite=_undefined, reason=None, **permissions
|
||||
):
|
||||
r"""|coro|
|
||||
|
||||
Sets the channel specific permission overwrites for a target in the
|
||||
@ -779,18 +809,18 @@ class GuildChannel:
|
||||
elif isinstance(target, Role):
|
||||
perm_type = _Overwrites.ROLE
|
||||
else:
|
||||
raise InvalidArgument('target parameter must be either Member or Role')
|
||||
raise InvalidArgument("target parameter must be either Member or Role")
|
||||
|
||||
if overwrite is _undefined:
|
||||
if len(permissions) == 0:
|
||||
raise InvalidArgument('No overwrite provided.')
|
||||
raise InvalidArgument("No overwrite provided.")
|
||||
try:
|
||||
overwrite = PermissionOverwrite(**permissions)
|
||||
except (ValueError, TypeError):
|
||||
raise InvalidArgument('Invalid permissions given to keyword arguments.')
|
||||
raise InvalidArgument("Invalid permissions given to keyword arguments.")
|
||||
else:
|
||||
if len(permissions) > 0:
|
||||
raise InvalidArgument('Cannot mix overwrite and keyword arguments.')
|
||||
raise InvalidArgument("Cannot mix overwrite and keyword arguments.")
|
||||
|
||||
# TODO: wait for event
|
||||
|
||||
@ -798,9 +828,11 @@ class GuildChannel:
|
||||
await http.delete_channel_permissions(self.id, target.id, reason=reason)
|
||||
elif isinstance(overwrite, PermissionOverwrite):
|
||||
(allow, deny) = overwrite.pair()
|
||||
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
|
||||
await http.edit_channel_permissions(
|
||||
self.id, target.id, allow.value, deny.value, perm_type, reason=reason
|
||||
)
|
||||
else:
|
||||
raise InvalidArgument('Invalid overwrite type provided.')
|
||||
raise InvalidArgument("Invalid overwrite type provided.")
|
||||
|
||||
async def _clone_impl(
|
||||
self: GCH,
|
||||
@ -809,19 +841,23 @@ class GuildChannel:
|
||||
name: Optional[str] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> GCH:
|
||||
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
|
||||
base_attrs['parent_id'] = self.category_id
|
||||
base_attrs['name'] = name or self.name
|
||||
base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites]
|
||||
base_attrs["parent_id"] = self.category_id
|
||||
base_attrs["name"] = name or self.name
|
||||
guild_id = self.guild.id
|
||||
cls = self.__class__
|
||||
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
|
||||
data = await self._state.http.create_channel(
|
||||
guild_id, self.type.value, reason=reason, **base_attrs
|
||||
)
|
||||
obj = cls(state=self._state, guild=self.guild, data=data)
|
||||
|
||||
# temporarily add it to the cache
|
||||
self.guild._channels[obj.id] = obj # type: ignore
|
||||
return obj
|
||||
|
||||
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
|
||||
async def clone(
|
||||
self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> GCH:
|
||||
"""|coro|
|
||||
|
||||
Clones this channel. This creates a channel with the same properties
|
||||
@ -964,14 +1000,16 @@ class GuildChannel:
|
||||
if not kwargs:
|
||||
return
|
||||
|
||||
beginning, end = kwargs.get('beginning'), kwargs.get('end')
|
||||
before, after = kwargs.get('before'), kwargs.get('after')
|
||||
offset = kwargs.get('offset', 0)
|
||||
beginning, end = kwargs.get("beginning"), kwargs.get("end")
|
||||
before, after = kwargs.get("before"), kwargs.get("after")
|
||||
offset = kwargs.get("offset", 0)
|
||||
if sum(bool(a) for a in (beginning, end, before, after)) > 1:
|
||||
raise InvalidArgument('Only one of [before, after, end, beginning] can be used.')
|
||||
raise InvalidArgument(
|
||||
"Only one of [before, after, end, beginning] can be used."
|
||||
)
|
||||
|
||||
bucket = self._sorting_bucket
|
||||
parent_id = kwargs.get('category', MISSING)
|
||||
parent_id = kwargs.get("category", MISSING)
|
||||
# fmt: off
|
||||
channels: List[GuildChannel]
|
||||
if parent_id not in (MISSING, None):
|
||||
@ -1008,22 +1046,26 @@ class GuildChannel:
|
||||
elif before:
|
||||
index = next((i for i, c in enumerate(channels) if c.id == before.id), None)
|
||||
elif after:
|
||||
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)
|
||||
index = next(
|
||||
(i + 1 for i, c in enumerate(channels) if c.id == after.id), None
|
||||
)
|
||||
|
||||
if index is None:
|
||||
raise InvalidArgument('Could not resolve appropriate move position')
|
||||
raise InvalidArgument("Could not resolve appropriate move position")
|
||||
|
||||
channels.insert(max((index + offset), 0), self)
|
||||
payload = []
|
||||
lock_permissions = kwargs.get('sync_permissions', False)
|
||||
reason = kwargs.get('reason')
|
||||
lock_permissions = kwargs.get("sync_permissions", False)
|
||||
reason = kwargs.get("reason")
|
||||
for index, channel in enumerate(channels):
|
||||
d = {'id': channel.id, 'position': index}
|
||||
d = {"id": channel.id, "position": index}
|
||||
if parent_id is not MISSING and channel.id == self.id:
|
||||
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
|
||||
payload.append(d)
|
||||
|
||||
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
|
||||
await self._state.http.bulk_channel_update(
|
||||
self.guild.id, payload, reason=reason
|
||||
)
|
||||
|
||||
async def create_invite(
|
||||
self,
|
||||
@ -1126,7 +1168,10 @@ class GuildChannel:
|
||||
state = self._state
|
||||
data = await state.http.invites_from_channel(self.id)
|
||||
guild = self.guild
|
||||
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]
|
||||
return [
|
||||
Invite(state=state, data=invite, channel=self, guild=guild)
|
||||
for invite in data
|
||||
]
|
||||
|
||||
|
||||
class Messageable:
|
||||
@ -1332,14 +1377,18 @@ class Messageable:
|
||||
content = str(content) if content is not None else None
|
||||
|
||||
if embed is not None and embeds is not None:
|
||||
raise InvalidArgument('cannot pass both embed and embeds parameter to send()')
|
||||
raise InvalidArgument(
|
||||
"cannot pass both embed and embeds parameter to send()"
|
||||
)
|
||||
|
||||
if embed is not None:
|
||||
embed = embed.to_dict()
|
||||
|
||||
elif embeds is not None:
|
||||
if len(embeds) > 10:
|
||||
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
|
||||
raise InvalidArgument(
|
||||
"embeds parameter must be a list of up to 10 elements"
|
||||
)
|
||||
embeds = [embed.to_dict() for embed in embeds]
|
||||
|
||||
if stickers is not None:
|
||||
@ -1347,36 +1396,44 @@ class Messageable:
|
||||
|
||||
if allowed_mentions is not None:
|
||||
if state.allowed_mentions is not None:
|
||||
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
|
||||
allowed_mentions = state.allowed_mentions.merge(
|
||||
allowed_mentions
|
||||
).to_dict()
|
||||
else:
|
||||
allowed_mentions = allowed_mentions.to_dict()
|
||||
else:
|
||||
allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict()
|
||||
allowed_mentions = (
|
||||
state.allowed_mentions and state.allowed_mentions.to_dict()
|
||||
)
|
||||
|
||||
if mention_author is not None:
|
||||
allowed_mentions = allowed_mentions or AllowedMentions().to_dict()
|
||||
allowed_mentions['replied_user'] = bool(mention_author)
|
||||
allowed_mentions["replied_user"] = bool(mention_author)
|
||||
|
||||
if reference is not None:
|
||||
try:
|
||||
reference = reference.to_message_reference_dict()
|
||||
except AttributeError:
|
||||
raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None
|
||||
raise InvalidArgument(
|
||||
"reference parameter must be Message, MessageReference, or PartialMessage"
|
||||
) from None
|
||||
|
||||
if view:
|
||||
if not hasattr(view, '__discord_ui_view__'):
|
||||
raise InvalidArgument(f'view parameter must be View not {view.__class__!r}')
|
||||
if not hasattr(view, "__discord_ui_view__"):
|
||||
raise InvalidArgument(
|
||||
f"view parameter must be View not {view.__class__!r}"
|
||||
)
|
||||
|
||||
components = view.to_components()
|
||||
else:
|
||||
components = None
|
||||
|
||||
if file is not None and files is not None:
|
||||
raise InvalidArgument('cannot pass both file and files parameter to send()')
|
||||
raise InvalidArgument("cannot pass both file and files parameter to send()")
|
||||
|
||||
if file is not None:
|
||||
if not isinstance(file, File):
|
||||
raise InvalidArgument('file parameter must be File')
|
||||
raise InvalidArgument("file parameter must be File")
|
||||
|
||||
try:
|
||||
data = await state.http.send_files(
|
||||
@ -1397,9 +1454,11 @@ class Messageable:
|
||||
|
||||
elif files is not None:
|
||||
if len(files) > 10:
|
||||
raise InvalidArgument('files parameter must be a list of up to 10 elements')
|
||||
raise InvalidArgument(
|
||||
"files parameter must be a list of up to 10 elements"
|
||||
)
|
||||
elif not all(isinstance(file, File) for file in files):
|
||||
raise InvalidArgument('files parameter must be a list of File')
|
||||
raise InvalidArgument("files parameter must be a list of File")
|
||||
|
||||
try:
|
||||
data = await state.http.send_files(
|
||||
@ -1594,7 +1653,14 @@ class Messageable:
|
||||
:class:`~discord.Message`
|
||||
The message with the message data parsed.
|
||||
"""
|
||||
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
|
||||
return HistoryIterator(
|
||||
self,
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
around=around,
|
||||
oldest_first=oldest_first,
|
||||
)
|
||||
|
||||
|
||||
class Connectable(Protocol):
|
||||
@ -1666,13 +1732,13 @@ class Connectable(Protocol):
|
||||
state = self._state
|
||||
|
||||
if state._get_voice_client(key_id):
|
||||
raise ClientException('Already connected to a voice channel.')
|
||||
raise ClientException("Already connected to a voice channel.")
|
||||
|
||||
client = state._get_client()
|
||||
voice = cls(client, self)
|
||||
|
||||
if not isinstance(voice, VoiceProtocol):
|
||||
raise TypeError('Type must meet VoiceProtocol abstract base class.')
|
||||
raise TypeError("Type must meet VoiceProtocol abstract base class.")
|
||||
|
||||
state._add_voice_client(key_id, voice)
|
||||
|
||||
|
@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji
|
||||
from .utils import _get_as_snowflake
|
||||
|
||||
__all__ = (
|
||||
'BaseActivity',
|
||||
'Activity',
|
||||
'Streaming',
|
||||
'Game',
|
||||
'Spotify',
|
||||
'CustomActivity',
|
||||
"BaseActivity",
|
||||
"Activity",
|
||||
"Streaming",
|
||||
"Game",
|
||||
"Spotify",
|
||||
"CustomActivity",
|
||||
)
|
||||
|
||||
"""If curious, this is the current schema for an activity.
|
||||
@ -119,10 +119,10 @@ class BaseActivity:
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
__slots__ = ('_created_at',)
|
||||
__slots__ = ("_created_at",)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._created_at: Optional[float] = kwargs.pop('created_at', None)
|
||||
self._created_at: Optional[float] = kwargs.pop("created_at", None)
|
||||
|
||||
@property
|
||||
def created_at(self) -> Optional[datetime.datetime]:
|
||||
@ -131,7 +131,9 @@ class BaseActivity:
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
if self._created_at is not None:
|
||||
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._created_at / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
|
||||
def to_dict(self) -> ActivityPayload:
|
||||
raise NotImplementedError
|
||||
@ -199,58 +201,62 @@ class Activity(BaseActivity):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'state',
|
||||
'details',
|
||||
'_created_at',
|
||||
'timestamps',
|
||||
'assets',
|
||||
'party',
|
||||
'flags',
|
||||
'sync_id',
|
||||
'session_id',
|
||||
'type',
|
||||
'name',
|
||||
'url',
|
||||
'application_id',
|
||||
'emoji',
|
||||
'buttons',
|
||||
"state",
|
||||
"details",
|
||||
"_created_at",
|
||||
"timestamps",
|
||||
"assets",
|
||||
"party",
|
||||
"flags",
|
||||
"sync_id",
|
||||
"session_id",
|
||||
"type",
|
||||
"name",
|
||||
"url",
|
||||
"application_id",
|
||||
"emoji",
|
||||
"buttons",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.state: Optional[str] = kwargs.pop('state', None)
|
||||
self.details: Optional[str] = kwargs.pop('details', None)
|
||||
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {})
|
||||
self.assets: ActivityAssets = kwargs.pop('assets', {})
|
||||
self.party: ActivityParty = kwargs.pop('party', {})
|
||||
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id')
|
||||
self.name: Optional[str] = kwargs.pop('name', None)
|
||||
self.url: Optional[str] = kwargs.pop('url', None)
|
||||
self.flags: int = kwargs.pop('flags', 0)
|
||||
self.sync_id: Optional[str] = kwargs.pop('sync_id', None)
|
||||
self.session_id: Optional[str] = kwargs.pop('session_id', None)
|
||||
self.buttons: List[ActivityButton] = kwargs.pop('buttons', [])
|
||||
self.state: Optional[str] = kwargs.pop("state", None)
|
||||
self.details: Optional[str] = kwargs.pop("details", None)
|
||||
self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {})
|
||||
self.assets: ActivityAssets = kwargs.pop("assets", {})
|
||||
self.party: ActivityParty = kwargs.pop("party", {})
|
||||
self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id")
|
||||
self.name: Optional[str] = kwargs.pop("name", None)
|
||||
self.url: Optional[str] = kwargs.pop("url", None)
|
||||
self.flags: int = kwargs.pop("flags", 0)
|
||||
self.sync_id: Optional[str] = kwargs.pop("sync_id", None)
|
||||
self.session_id: Optional[str] = kwargs.pop("session_id", None)
|
||||
self.buttons: List[ActivityButton] = kwargs.pop("buttons", [])
|
||||
|
||||
activity_type = kwargs.pop('type', -1)
|
||||
activity_type = kwargs.pop("type", -1)
|
||||
self.type: ActivityType = (
|
||||
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
|
||||
activity_type
|
||||
if isinstance(activity_type, ActivityType)
|
||||
else try_enum(ActivityType, activity_type)
|
||||
)
|
||||
|
||||
emoji = kwargs.pop('emoji', None)
|
||||
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
|
||||
emoji = kwargs.pop("emoji", None)
|
||||
self.emoji: Optional[PartialEmoji] = (
|
||||
PartialEmoji.from_dict(emoji) if emoji is not None else None
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = (
|
||||
('type', self.type),
|
||||
('name', self.name),
|
||||
('url', self.url),
|
||||
('details', self.details),
|
||||
('application_id', self.application_id),
|
||||
('session_id', self.session_id),
|
||||
('emoji', self.emoji),
|
||||
("type", self.type),
|
||||
("name", self.name),
|
||||
("url", self.url),
|
||||
("details", self.details),
|
||||
("application_id", self.application_id),
|
||||
("session_id", self.session_id),
|
||||
("emoji", self.emoji),
|
||||
)
|
||||
inner = ' '.join('%s=%r' % t for t in attrs)
|
||||
return f'<Activity {inner}>'
|
||||
inner = " ".join("%s=%r" % t for t in attrs)
|
||||
return f"<Activity {inner}>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
ret: Dict[str, Any] = {}
|
||||
@ -263,16 +269,16 @@ class Activity(BaseActivity):
|
||||
continue
|
||||
|
||||
ret[attr] = value
|
||||
ret['type'] = int(self.type)
|
||||
ret["type"] = int(self.type)
|
||||
if self.emoji:
|
||||
ret['emoji'] = self.emoji.to_dict()
|
||||
ret["emoji"] = self.emoji.to_dict()
|
||||
return ret
|
||||
|
||||
@property
|
||||
def start(self) -> Optional[datetime.datetime]:
|
||||
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
|
||||
try:
|
||||
timestamp = self.timestamps['start'] / 1000
|
||||
timestamp = self.timestamps["start"] / 1000
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
@ -282,7 +288,7 @@ class Activity(BaseActivity):
|
||||
def end(self) -> Optional[datetime.datetime]:
|
||||
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
|
||||
try:
|
||||
timestamp = self.timestamps['end'] / 1000
|
||||
timestamp = self.timestamps["end"] / 1000
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
@ -295,11 +301,11 @@ class Activity(BaseActivity):
|
||||
return None
|
||||
|
||||
try:
|
||||
large_image = self.assets['large_image']
|
||||
large_image = self.assets["large_image"]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png'
|
||||
return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png"
|
||||
|
||||
@property
|
||||
def small_image_url(self) -> Optional[str]:
|
||||
@ -308,21 +314,21 @@ class Activity(BaseActivity):
|
||||
return None
|
||||
|
||||
try:
|
||||
small_image = self.assets['small_image']
|
||||
small_image = self.assets["small_image"]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png'
|
||||
return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png"
|
||||
|
||||
@property
|
||||
def large_image_text(self) -> Optional[str]:
|
||||
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
|
||||
return self.assets.get('large_text', None)
|
||||
return self.assets.get("large_text", None)
|
||||
|
||||
@property
|
||||
def small_image_text(self) -> Optional[str]:
|
||||
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
|
||||
return self.assets.get('small_text', None)
|
||||
return self.assets.get("small_text", None)
|
||||
|
||||
|
||||
class Game(BaseActivity):
|
||||
@ -359,20 +365,20 @@ class Game(BaseActivity):
|
||||
The game's name.
|
||||
"""
|
||||
|
||||
__slots__ = ('name', '_end', '_start')
|
||||
__slots__ = ("name", "_end", "_start")
|
||||
|
||||
def __init__(self, name: str, **extra):
|
||||
super().__init__(**extra)
|
||||
self.name: str = name
|
||||
|
||||
try:
|
||||
timestamps: ActivityTimestamps = extra['timestamps']
|
||||
timestamps: ActivityTimestamps = extra["timestamps"]
|
||||
except KeyError:
|
||||
self._start = 0
|
||||
self._end = 0
|
||||
else:
|
||||
self._start = timestamps.get('start', 0)
|
||||
self._end = timestamps.get('end', 0)
|
||||
self._start = timestamps.get("start", 0)
|
||||
self._end = timestamps.get("end", 0)
|
||||
|
||||
@property
|
||||
def type(self) -> ActivityType:
|
||||
@ -386,29 +392,33 @@ class Game(BaseActivity):
|
||||
def start(self) -> Optional[datetime.datetime]:
|
||||
"""Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable."""
|
||||
if self._start:
|
||||
return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._start / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
return None
|
||||
|
||||
@property
|
||||
def end(self) -> Optional[datetime.datetime]:
|
||||
"""Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable."""
|
||||
if self._end:
|
||||
return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._end / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Game name={self.name!r}>'
|
||||
return f"<Game name={self.name!r}>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
timestamps: Dict[str, Any] = {}
|
||||
if self._start:
|
||||
timestamps['start'] = self._start
|
||||
timestamps["start"] = self._start
|
||||
|
||||
if self._end:
|
||||
timestamps['end'] = self._end
|
||||
timestamps["end"] = self._end
|
||||
|
||||
# fmt: off
|
||||
return {
|
||||
@ -473,16 +483,16 @@ class Streaming(BaseActivity):
|
||||
A dictionary comprising of similar keys than those in :attr:`Activity.assets`.
|
||||
"""
|
||||
|
||||
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
|
||||
__slots__ = ("platform", "name", "game", "url", "details", "assets")
|
||||
|
||||
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
|
||||
super().__init__(**extra)
|
||||
self.platform: Optional[str] = name
|
||||
self.name: Optional[str] = extra.pop('details', name)
|
||||
self.game: Optional[str] = extra.pop('state', None)
|
||||
self.name: Optional[str] = extra.pop("details", name)
|
||||
self.game: Optional[str] = extra.pop("state", None)
|
||||
self.url: str = url
|
||||
self.details: Optional[str] = extra.pop('details', self.name) # compatibility
|
||||
self.assets: ActivityAssets = extra.pop('assets', {})
|
||||
self.details: Optional[str] = extra.pop("details", self.name) # compatibility
|
||||
self.assets: ActivityAssets = extra.pop("assets", {})
|
||||
|
||||
@property
|
||||
def type(self) -> ActivityType:
|
||||
@ -496,7 +506,7 @@ class Streaming(BaseActivity):
|
||||
return str(self.name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Streaming name={self.name!r}>'
|
||||
return f"<Streaming name={self.name!r}>"
|
||||
|
||||
@property
|
||||
def twitch_name(self):
|
||||
@ -507,11 +517,11 @@ class Streaming(BaseActivity):
|
||||
"""
|
||||
|
||||
try:
|
||||
name = self.assets['large_image']
|
||||
name = self.assets["large_image"]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
return name[7:] if name[:7] == 'twitch:' else None
|
||||
return name[7:] if name[:7] == "twitch:" else None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# fmt: off
|
||||
@ -523,11 +533,15 @@ class Streaming(BaseActivity):
|
||||
}
|
||||
# fmt: on
|
||||
if self.details:
|
||||
ret['details'] = self.details
|
||||
ret["details"] = self.details
|
||||
return ret
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
|
||||
return (
|
||||
isinstance(other, Streaming)
|
||||
and other.name == self.name
|
||||
and other.url == self.url
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
@ -559,17 +573,26 @@ class Spotify:
|
||||
Returns the string 'Spotify'.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
|
||||
__slots__ = (
|
||||
"_state",
|
||||
"_details",
|
||||
"_timestamps",
|
||||
"_assets",
|
||||
"_party",
|
||||
"_sync_id",
|
||||
"_session_id",
|
||||
"_created_at",
|
||||
)
|
||||
|
||||
def __init__(self, **data):
|
||||
self._state: str = data.pop('state', '')
|
||||
self._details: str = data.pop('details', '')
|
||||
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
|
||||
self._assets: ActivityAssets = data.pop('assets', {})
|
||||
self._party: ActivityParty = data.pop('party', {})
|
||||
self._sync_id: str = data.pop('sync_id')
|
||||
self._session_id: str = data.pop('session_id')
|
||||
self._created_at: Optional[float] = data.pop('created_at', None)
|
||||
self._state: str = data.pop("state", "")
|
||||
self._details: str = data.pop("details", "")
|
||||
self._timestamps: Dict[str, int] = data.pop("timestamps", {})
|
||||
self._assets: ActivityAssets = data.pop("assets", {})
|
||||
self._party: ActivityParty = data.pop("party", {})
|
||||
self._sync_id: str = data.pop("sync_id")
|
||||
self._session_id: str = data.pop("session_id")
|
||||
self._created_at: Optional[float] = data.pop("created_at", None)
|
||||
|
||||
@property
|
||||
def type(self) -> ActivityType:
|
||||
@ -586,7 +609,9 @@ class Spotify:
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
if self._created_at is not None:
|
||||
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._created_at / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
|
||||
@property
|
||||
def colour(self) -> Colour:
|
||||
@ -604,21 +629,21 @@ class Spotify:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'flags': 48, # SYNC | PLAY
|
||||
'name': 'Spotify',
|
||||
'assets': self._assets,
|
||||
'party': self._party,
|
||||
'sync_id': self._sync_id,
|
||||
'session_id': self._session_id,
|
||||
'timestamps': self._timestamps,
|
||||
'details': self._details,
|
||||
'state': self._state,
|
||||
"flags": 48, # SYNC | PLAY
|
||||
"name": "Spotify",
|
||||
"assets": self._assets,
|
||||
"party": self._party,
|
||||
"sync_id": self._sync_id,
|
||||
"session_id": self._session_id,
|
||||
"timestamps": self._timestamps,
|
||||
"details": self._details,
|
||||
"state": self._state,
|
||||
}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
""":class:`str`: The activity's name. This will always return "Spotify"."""
|
||||
return 'Spotify'
|
||||
return "Spotify"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
@ -635,10 +660,10 @@ class Spotify:
|
||||
return hash(self._session_id)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return 'Spotify'
|
||||
return "Spotify"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>'
|
||||
return f"<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>"
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
@ -648,7 +673,7 @@ class Spotify:
|
||||
@property
|
||||
def artists(self) -> List[str]:
|
||||
"""List[:class:`str`]: The artists of the song being played."""
|
||||
return self._state.split('; ')
|
||||
return self._state.split("; ")
|
||||
|
||||
@property
|
||||
def artist(self) -> str:
|
||||
@ -662,16 +687,16 @@ class Spotify:
|
||||
@property
|
||||
def album(self) -> str:
|
||||
""":class:`str`: The album that the song being played belongs to."""
|
||||
return self._assets.get('large_text', '')
|
||||
return self._assets.get("large_text", "")
|
||||
|
||||
@property
|
||||
def album_cover_url(self) -> str:
|
||||
""":class:`str`: The album cover image URL from Spotify's CDN."""
|
||||
large_image = self._assets.get('large_image', '')
|
||||
if large_image[:8] != 'spotify:':
|
||||
return ''
|
||||
large_image = self._assets.get("large_image", "")
|
||||
if large_image[:8] != "spotify:":
|
||||
return ""
|
||||
album_image_id = large_image[8:]
|
||||
return 'https://i.scdn.co/image/' + album_image_id
|
||||
return "https://i.scdn.co/image/" + album_image_id
|
||||
|
||||
@property
|
||||
def track_id(self) -> str:
|
||||
@ -684,17 +709,21 @@ class Spotify:
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return f'https://open.spotify.com/track/{self.track_id}'
|
||||
return f"https://open.spotify.com/track/{self.track_id}"
|
||||
|
||||
@property
|
||||
def start(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._timestamps["start"] / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
|
||||
@property
|
||||
def end(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
|
||||
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
|
||||
return datetime.datetime.fromtimestamp(
|
||||
self._timestamps["end"] / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
|
||||
@property
|
||||
def duration(self) -> datetime.timedelta:
|
||||
@ -704,7 +733,7 @@ class Spotify:
|
||||
@property
|
||||
def party_id(self) -> str:
|
||||
""":class:`str`: The party ID of the listening party."""
|
||||
return self._party.get('id', '')
|
||||
return self._party.get("id", "")
|
||||
|
||||
|
||||
class CustomActivity(BaseActivity):
|
||||
@ -738,13 +767,15 @@ class CustomActivity(BaseActivity):
|
||||
The emoji to pass to the activity, if any.
|
||||
"""
|
||||
|
||||
__slots__ = ('name', 'emoji', 'state')
|
||||
__slots__ = ("name", "emoji", "state")
|
||||
|
||||
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
|
||||
def __init__(
|
||||
self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any
|
||||
):
|
||||
super().__init__(**extra)
|
||||
self.name: Optional[str] = name
|
||||
self.state: Optional[str] = extra.pop('state', None)
|
||||
if self.name == 'Custom Status':
|
||||
self.state: Optional[str] = extra.pop("state", None)
|
||||
if self.name == "Custom Status":
|
||||
self.name = self.state
|
||||
|
||||
self.emoji: Optional[PartialEmoji]
|
||||
@ -757,7 +788,9 @@ class CustomActivity(BaseActivity):
|
||||
elif isinstance(emoji, PartialEmoji):
|
||||
self.emoji = emoji
|
||||
else:
|
||||
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.')
|
||||
raise TypeError(
|
||||
f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> ActivityType:
|
||||
@ -770,22 +803,26 @@ class CustomActivity(BaseActivity):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
if self.name == self.state:
|
||||
o = {
|
||||
'type': ActivityType.custom.value,
|
||||
'state': self.name,
|
||||
'name': 'Custom Status',
|
||||
"type": ActivityType.custom.value,
|
||||
"state": self.name,
|
||||
"name": "Custom Status",
|
||||
}
|
||||
else:
|
||||
o = {
|
||||
'type': ActivityType.custom.value,
|
||||
'name': self.name,
|
||||
"type": ActivityType.custom.value,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
if self.emoji:
|
||||
o['emoji'] = self.emoji.to_dict()
|
||||
o["emoji"] = self.emoji.to_dict()
|
||||
return o
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
|
||||
return (
|
||||
isinstance(other, CustomActivity)
|
||||
and other.name == self.name
|
||||
and other.emoji == self.emoji
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
@ -796,47 +833,54 @@ class CustomActivity(BaseActivity):
|
||||
def __str__(self) -> str:
|
||||
if self.emoji:
|
||||
if self.name:
|
||||
return f'{self.emoji} {self.name}'
|
||||
return f"{self.emoji} {self.name}"
|
||||
return str(self.emoji)
|
||||
else:
|
||||
return str(self.name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
|
||||
return f"<CustomActivity name={self.name!r} emoji={self.emoji!r}>"
|
||||
|
||||
|
||||
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
|
||||
|
||||
|
||||
@overload
|
||||
def create_activity(data: ActivityPayload) -> ActivityTypes:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def create_activity(data: None) -> None:
|
||||
...
|
||||
|
||||
|
||||
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
game_type = try_enum(ActivityType, data.get('type', -1))
|
||||
game_type = try_enum(ActivityType, data.get("type", -1))
|
||||
if game_type is ActivityType.playing:
|
||||
if 'application_id' in data or 'session_id' in data:
|
||||
if "application_id" in data or "session_id" in data:
|
||||
return Activity(**data)
|
||||
return Game(**data)
|
||||
elif game_type is ActivityType.custom:
|
||||
try:
|
||||
name = data.pop('name')
|
||||
name = data.pop("name")
|
||||
except KeyError:
|
||||
return Activity(**data)
|
||||
else:
|
||||
# we removed the name key from data already
|
||||
return CustomActivity(name=name, **data) # type: ignore
|
||||
return CustomActivity(name=name, **data) # type: ignore
|
||||
elif game_type is ActivityType.streaming:
|
||||
if 'url' in data:
|
||||
if "url" in data:
|
||||
# the url won't be None here
|
||||
return Streaming(**data) # type: ignore
|
||||
return Streaming(**data) # type: ignore
|
||||
return Activity(**data)
|
||||
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
|
||||
elif (
|
||||
game_type is ActivityType.listening
|
||||
and "sync_id" in data
|
||||
and "session_id" in data
|
||||
):
|
||||
return Spotify(**data)
|
||||
return Activity(**data)
|
||||
|
@ -40,8 +40,8 @@ if TYPE_CHECKING:
|
||||
from .state import ConnectionState
|
||||
|
||||
__all__ = (
|
||||
'AppInfo',
|
||||
'PartialAppInfo',
|
||||
"AppInfo",
|
||||
"PartialAppInfo",
|
||||
)
|
||||
|
||||
|
||||
@ -115,58 +115,60 @@ class AppInfo:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_state',
|
||||
'description',
|
||||
'id',
|
||||
'name',
|
||||
'rpc_origins',
|
||||
'bot_public',
|
||||
'bot_require_code_grant',
|
||||
'owner',
|
||||
'_icon',
|
||||
'summary',
|
||||
'verify_key',
|
||||
'team',
|
||||
'guild_id',
|
||||
'primary_sku_id',
|
||||
'slug',
|
||||
'_cover_image',
|
||||
'terms_of_service_url',
|
||||
'privacy_policy_url',
|
||||
"_state",
|
||||
"description",
|
||||
"id",
|
||||
"name",
|
||||
"rpc_origins",
|
||||
"bot_public",
|
||||
"bot_require_code_grant",
|
||||
"owner",
|
||||
"_icon",
|
||||
"summary",
|
||||
"verify_key",
|
||||
"team",
|
||||
"guild_id",
|
||||
"primary_sku_id",
|
||||
"slug",
|
||||
"_cover_image",
|
||||
"terms_of_service_url",
|
||||
"privacy_policy_url",
|
||||
)
|
||||
|
||||
def __init__(self, state: ConnectionState, data: AppInfoPayload):
|
||||
from .team import Team
|
||||
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self._icon: Optional[str] = data['icon']
|
||||
self.rpc_origins: List[str] = data['rpc_origins']
|
||||
self.bot_public: bool = data['bot_public']
|
||||
self.bot_require_code_grant: bool = data['bot_require_code_grant']
|
||||
self.owner: User = state.create_user(data['owner'])
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self.description: str = data["description"]
|
||||
self._icon: Optional[str] = data["icon"]
|
||||
self.rpc_origins: List[str] = data["rpc_origins"]
|
||||
self.bot_public: bool = data["bot_public"]
|
||||
self.bot_require_code_grant: bool = data["bot_require_code_grant"]
|
||||
self.owner: User = state.create_user(data["owner"])
|
||||
|
||||
team: Optional[TeamPayload] = data.get('team')
|
||||
team: Optional[TeamPayload] = data.get("team")
|
||||
self.team: Optional[Team] = Team(state, team) if team else None
|
||||
|
||||
self.summary: str = data['summary']
|
||||
self.verify_key: str = data['verify_key']
|
||||
self.summary: str = data["summary"]
|
||||
self.verify_key: str = data["verify_key"]
|
||||
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
|
||||
|
||||
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id')
|
||||
self.slug: Optional[str] = data.get('slug')
|
||||
self._cover_image: Optional[str] = data.get('cover_image')
|
||||
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
|
||||
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
|
||||
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(
|
||||
data, "primary_sku_id"
|
||||
)
|
||||
self.slug: Optional[str] = data.get("slug")
|
||||
self._cover_image: Optional[str] = data.get("cover_image")
|
||||
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
|
||||
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
|
||||
f'description={self.description!r} public={self.bot_public} '
|
||||
f'owner={self.owner!r}>'
|
||||
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
|
||||
f"description={self.description!r} public={self.bot_public} "
|
||||
f"owner={self.owner!r}>"
|
||||
)
|
||||
|
||||
@property
|
||||
@ -174,7 +176,7 @@ class AppInfo:
|
||||
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
|
||||
if self._icon is None:
|
||||
return None
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path='app')
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path="app")
|
||||
|
||||
@property
|
||||
def cover_image(self) -> Optional[Asset]:
|
||||
@ -195,6 +197,7 @@ class AppInfo:
|
||||
"""
|
||||
return self._state._get_guild(self.guild_id)
|
||||
|
||||
|
||||
class PartialAppInfo:
|
||||
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
|
||||
|
||||
@ -222,26 +225,37 @@ class PartialAppInfo:
|
||||
The application's privacy policy URL, if set.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon')
|
||||
__slots__ = (
|
||||
"_state",
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"rpc_origins",
|
||||
"summary",
|
||||
"verify_key",
|
||||
"terms_of_service_url",
|
||||
"privacy_policy_url",
|
||||
"_icon",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self._icon: Optional[str] = data.get('icon')
|
||||
self.description: str = data['description']
|
||||
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
|
||||
self.summary: str = data['summary']
|
||||
self.verify_key: str = data['verify_key']
|
||||
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
|
||||
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self._icon: Optional[str] = data.get("icon")
|
||||
self.description: str = data["description"]
|
||||
self.rpc_origins: Optional[List[str]] = data.get("rpc_origins")
|
||||
self.summary: str = data["summary"]
|
||||
self.verify_key: str = data["verify_key"]
|
||||
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
|
||||
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
|
||||
return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
|
||||
if self._icon is None:
|
||||
return None
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path='app')
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path="app")
|
||||
|
100
discord/asset.py
100
discord/asset.py
@ -33,13 +33,11 @@ from . import utils
|
||||
|
||||
import yarl
|
||||
|
||||
__all__ = (
|
||||
'Asset',
|
||||
)
|
||||
__all__ = ("Asset",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
|
||||
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
|
||||
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
|
||||
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
|
||||
|
||||
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
|
||||
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
|
||||
@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
|
||||
|
||||
MISSING = utils.MISSING
|
||||
|
||||
|
||||
class AssetMixin:
|
||||
url: str
|
||||
_state: Optional[Any]
|
||||
@ -71,11 +70,16 @@ class AssetMixin:
|
||||
The content of the asset.
|
||||
"""
|
||||
if self._state is None:
|
||||
raise DiscordException('Invalid state (no ConnectionState provided)')
|
||||
raise DiscordException("Invalid state (no ConnectionState provided)")
|
||||
|
||||
return await self._state.http.get_from_cdn(self.url)
|
||||
|
||||
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
|
||||
async def save(
|
||||
self,
|
||||
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
|
||||
*,
|
||||
seek_begin: bool = True,
|
||||
) -> int:
|
||||
"""|coro|
|
||||
|
||||
Saves this asset into a file-like object.
|
||||
@ -112,7 +116,7 @@ class AssetMixin:
|
||||
fp.seek(0)
|
||||
return written
|
||||
else:
|
||||
with open(fp, 'wb') as f:
|
||||
with open(fp, "wb") as f:
|
||||
return f.write(data)
|
||||
|
||||
|
||||
@ -143,13 +147,13 @@ class Asset(AssetMixin):
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'_state',
|
||||
'_url',
|
||||
'_animated',
|
||||
'_key',
|
||||
"_state",
|
||||
"_url",
|
||||
"_animated",
|
||||
"_key",
|
||||
)
|
||||
|
||||
BASE = 'https://cdn.discordapp.com'
|
||||
BASE = "https://cdn.discordapp.com"
|
||||
|
||||
def __init__(self, state, *, url: str, key: str, animated: bool = False):
|
||||
self._state = state
|
||||
@ -161,26 +165,28 @@ class Asset(AssetMixin):
|
||||
def _from_default_avatar(cls, state, index: int) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/embed/avatars/{index}.png',
|
||||
url=f"{cls.BASE}/embed/avatars/{index}.png",
|
||||
key=str(index),
|
||||
animated=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
animated = avatar.startswith("a_")
|
||||
format = "gif" if animated else "png"
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024',
|
||||
url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024",
|
||||
key=avatar,
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
|
||||
animated = avatar.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
def _from_guild_avatar(
|
||||
cls, state, guild_id: int, member_id: int, avatar: str
|
||||
) -> Asset:
|
||||
animated = avatar.startswith("a_")
|
||||
format = "gif" if animated else "png"
|
||||
return cls(
|
||||
state,
|
||||
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
|
||||
@ -192,7 +198,7 @@ class Asset(AssetMixin):
|
||||
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
|
||||
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
|
||||
key=icon_hash,
|
||||
animated=False,
|
||||
)
|
||||
@ -201,7 +207,7 @@ class Asset(AssetMixin):
|
||||
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
|
||||
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
|
||||
key=cover_image_hash,
|
||||
animated=False,
|
||||
)
|
||||
@ -210,18 +216,18 @@ class Asset(AssetMixin):
|
||||
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024',
|
||||
url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
|
||||
key=image,
|
||||
animated=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
|
||||
animated = icon_hash.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
animated = icon_hash.startswith("a_")
|
||||
format = "gif" if animated else "png"
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024',
|
||||
url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024",
|
||||
key=icon_hash,
|
||||
animated=animated,
|
||||
)
|
||||
@ -230,20 +236,20 @@ class Asset(AssetMixin):
|
||||
def _from_sticker_banner(cls, state, banner: int) -> Asset:
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
|
||||
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
|
||||
key=str(banner),
|
||||
animated=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
|
||||
animated = banner_hash.startswith('a_')
|
||||
format = 'gif' if animated else 'png'
|
||||
animated = banner_hash.startswith("a_")
|
||||
format = "gif" if animated else "png"
|
||||
return cls(
|
||||
state,
|
||||
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
|
||||
url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512",
|
||||
key=banner_hash,
|
||||
animated=animated
|
||||
animated=animated,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -253,8 +259,8 @@ class Asset(AssetMixin):
|
||||
return len(self._url)
|
||||
|
||||
def __repr__(self):
|
||||
shorten = self._url.replace(self.BASE, '')
|
||||
return f'<Asset url={shorten!r}>'
|
||||
shorten = self._url.replace(self.BASE, "")
|
||||
return f"<Asset url={shorten!r}>"
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Asset) and self._url == other._url
|
||||
@ -312,21 +318,27 @@ class Asset(AssetMixin):
|
||||
if format is not MISSING:
|
||||
if self._animated:
|
||||
if format not in VALID_ASSET_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
raise InvalidArgument(
|
||||
f"format must be one of {VALID_ASSET_FORMATS}"
|
||||
)
|
||||
url = url.with_path(f"{path}.{format}")
|
||||
elif static_format is MISSING:
|
||||
if format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
|
||||
url = url.with_path(f'{path}.{format}')
|
||||
raise InvalidArgument(
|
||||
f"format must be one of {VALID_STATIC_FORMATS}"
|
||||
)
|
||||
url = url.with_path(f"{path}.{format}")
|
||||
|
||||
if static_format is not MISSING and not self._animated:
|
||||
if static_format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}')
|
||||
url = url.with_path(f'{path}.{static_format}')
|
||||
raise InvalidArgument(
|
||||
f"static_format must be one of {VALID_STATIC_FORMATS}"
|
||||
)
|
||||
url = url.with_path(f"{path}.{static_format}")
|
||||
|
||||
if size is not MISSING:
|
||||
if not utils.valid_icon_size(size):
|
||||
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
|
||||
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
|
||||
url = url.with_query(size=size)
|
||||
else:
|
||||
url = url.with_query(url.raw_query_string)
|
||||
@ -353,7 +365,7 @@ class Asset(AssetMixin):
|
||||
The new updated asset.
|
||||
"""
|
||||
if not utils.valid_icon_size(size):
|
||||
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
|
||||
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
|
||||
|
||||
url = str(yarl.URL(self._url).with_query(size=size))
|
||||
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
|
||||
@ -379,14 +391,14 @@ class Asset(AssetMixin):
|
||||
|
||||
if self._animated:
|
||||
if format not in VALID_ASSET_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
|
||||
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
|
||||
else:
|
||||
if format not in VALID_STATIC_FORMATS:
|
||||
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
|
||||
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
|
||||
|
||||
url = yarl.URL(self._url)
|
||||
path, _ = os.path.splitext(url.path)
|
||||
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
|
||||
url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string))
|
||||
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
|
||||
|
||||
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
|
||||
|
@ -24,7 +24,20 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from . import enums, utils
|
||||
from .asset import Asset
|
||||
@ -35,9 +48,9 @@ from .object import Object
|
||||
from .permissions import PermissionOverwrite, Permissions
|
||||
|
||||
__all__ = (
|
||||
'AuditLogDiff',
|
||||
'AuditLogChanges',
|
||||
'AuditLogEntry',
|
||||
"AuditLogDiff",
|
||||
"AuditLogChanges",
|
||||
"AuditLogEntry",
|
||||
)
|
||||
|
||||
|
||||
@ -74,18 +87,25 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
|
||||
return int(data)
|
||||
|
||||
|
||||
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
|
||||
def _transform_channel(
|
||||
entry: AuditLogEntry, data: Optional[Snowflake]
|
||||
) -> Optional[Union[abc.GuildChannel, Object]]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry.guild.get_channel(int(data)) or Object(id=data)
|
||||
|
||||
|
||||
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
|
||||
def _transform_member_id(
|
||||
entry: AuditLogEntry, data: Optional[Snowflake]
|
||||
) -> Union[Member, User, None]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry._get_member(int(data))
|
||||
|
||||
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
|
||||
|
||||
def _transform_guild_id(
|
||||
entry: AuditLogEntry, data: Optional[Snowflake]
|
||||
) -> Optional[Guild]:
|
||||
if data is None:
|
||||
return None
|
||||
return entry._state._get_guild(data)
|
||||
@ -96,16 +116,16 @@ def _transform_overwrites(
|
||||
) -> List[Tuple[Object, PermissionOverwrite]]:
|
||||
overwrites = []
|
||||
for elem in data:
|
||||
allow = Permissions(int(elem['allow']))
|
||||
deny = Permissions(int(elem['deny']))
|
||||
allow = Permissions(int(elem["allow"]))
|
||||
deny = Permissions(int(elem["deny"]))
|
||||
ow = PermissionOverwrite.from_pair(allow, deny)
|
||||
|
||||
ow_type = elem['type']
|
||||
ow_id = int(elem['id'])
|
||||
ow_type = elem["type"]
|
||||
ow_id = int(elem["id"])
|
||||
target = None
|
||||
if ow_type == '0':
|
||||
if ow_type == "0":
|
||||
target = entry.guild.get_role(ow_id)
|
||||
elif ow_type == '1':
|
||||
elif ow_type == "1":
|
||||
target = entry._get_member(ow_id)
|
||||
|
||||
if target is None:
|
||||
@ -128,7 +148,9 @@ def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Ass
|
||||
return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore
|
||||
|
||||
|
||||
def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]:
|
||||
def _guild_hash_transformer(
|
||||
path: str,
|
||||
) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]:
|
||||
def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]:
|
||||
if data is None:
|
||||
return None
|
||||
@ -137,7 +159,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
|
||||
return _transform
|
||||
|
||||
|
||||
T = TypeVar('T', bound=enums.Enum)
|
||||
T = TypeVar("T", bound=enums.Enum)
|
||||
|
||||
|
||||
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
|
||||
@ -146,12 +168,16 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
|
||||
|
||||
return _transform
|
||||
|
||||
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
|
||||
if entry.action.name.startswith('sticker_'):
|
||||
|
||||
def _transform_type(
|
||||
entry: AuditLogEntry, data: Union[int]
|
||||
) -> Union[enums.ChannelType, enums.StickerType]:
|
||||
if entry.action.name.startswith("sticker_"):
|
||||
return enums.try_enum(enums.StickerType, data)
|
||||
else:
|
||||
return enums.try_enum(enums.ChannelType, data)
|
||||
|
||||
|
||||
class AuditLogDiff:
|
||||
def __len__(self) -> int:
|
||||
return len(self.__dict__)
|
||||
@ -160,8 +186,8 @@ class AuditLogDiff:
|
||||
yield from self.__dict__.items()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
values = ' '.join('%s=%r' % item for item in self.__dict__.items())
|
||||
return f'<AuditLogDiff {values}>'
|
||||
values = " ".join("%s=%r" % item for item in self.__dict__.items())
|
||||
return f"<AuditLogDiff {values}>"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -217,14 +243,14 @@ class AuditLogChanges:
|
||||
self.after = AuditLogDiff()
|
||||
|
||||
for elem in data:
|
||||
attr = elem['key']
|
||||
attr = elem["key"]
|
||||
|
||||
# special cases for role add/remove
|
||||
if attr == '$add':
|
||||
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore
|
||||
if attr == "$add":
|
||||
self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore
|
||||
continue
|
||||
elif attr == '$remove':
|
||||
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore
|
||||
elif attr == "$remove":
|
||||
self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore
|
||||
continue
|
||||
|
||||
try:
|
||||
@ -238,7 +264,7 @@ class AuditLogChanges:
|
||||
transformer: Optional[Transformer]
|
||||
|
||||
try:
|
||||
before = elem['old_value']
|
||||
before = elem["old_value"]
|
||||
except KeyError:
|
||||
before = None
|
||||
else:
|
||||
@ -248,7 +274,7 @@ class AuditLogChanges:
|
||||
setattr(self.before, attr, before)
|
||||
|
||||
try:
|
||||
after = elem['new_value']
|
||||
after = elem["new_value"]
|
||||
except KeyError:
|
||||
after = None
|
||||
else:
|
||||
@ -258,34 +284,40 @@ class AuditLogChanges:
|
||||
setattr(self.after, attr, after)
|
||||
|
||||
# add an alias
|
||||
if hasattr(self.after, 'colour'):
|
||||
if hasattr(self.after, "colour"):
|
||||
self.after.color = self.after.colour
|
||||
self.before.color = self.before.colour
|
||||
if hasattr(self.after, 'expire_behavior'):
|
||||
if hasattr(self.after, "expire_behavior"):
|
||||
self.after.expire_behaviour = self.after.expire_behavior
|
||||
self.before.expire_behaviour = self.before.expire_behavior
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>'
|
||||
return f"<AuditLogChanges before={self.before!r} after={self.after!r}>"
|
||||
|
||||
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None:
|
||||
if not hasattr(first, 'roles'):
|
||||
setattr(first, 'roles', [])
|
||||
def _handle_role(
|
||||
self,
|
||||
first: AuditLogDiff,
|
||||
second: AuditLogDiff,
|
||||
entry: AuditLogEntry,
|
||||
elem: List[RolePayload],
|
||||
) -> None:
|
||||
if not hasattr(first, "roles"):
|
||||
setattr(first, "roles", [])
|
||||
|
||||
data = []
|
||||
g: Guild = entry.guild # type: ignore
|
||||
|
||||
for e in elem:
|
||||
role_id = int(e['id'])
|
||||
role_id = int(e["id"])
|
||||
role = g.get_role(role_id)
|
||||
|
||||
if role is None:
|
||||
role = Object(id=role_id)
|
||||
role.name = e['name'] # type: ignore
|
||||
role.name = e["name"] # type: ignore
|
||||
|
||||
data.append(role)
|
||||
|
||||
setattr(second, 'roles', data)
|
||||
setattr(second, "roles", data)
|
||||
|
||||
|
||||
class _AuditLogProxyMemberPrune:
|
||||
@ -358,63 +390,81 @@ class AuditLogEntry(Hashable):
|
||||
which actions have this field filled out.
|
||||
"""
|
||||
|
||||
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
|
||||
def __init__(
|
||||
self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild
|
||||
):
|
||||
self._state = guild._state
|
||||
self.guild = guild
|
||||
self._users = users
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: AuditLogEntryPayload) -> None:
|
||||
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
|
||||
self.id = int(data['id'])
|
||||
self.action = enums.try_enum(enums.AuditLogAction, data["action_type"])
|
||||
self.id = int(data["id"])
|
||||
|
||||
# this key is technically not usually present
|
||||
self.reason = data.get('reason')
|
||||
self.extra = data.get('options')
|
||||
self.reason = data.get("reason")
|
||||
self.extra = data.get("options")
|
||||
|
||||
if isinstance(self.action, enums.AuditLogAction) and self.extra:
|
||||
if self.action is enums.AuditLogAction.member_prune:
|
||||
# member prune has two keys with useful information
|
||||
self.extra: _AuditLogProxyMemberPrune = type(
|
||||
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()}
|
||||
"_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()}
|
||||
)()
|
||||
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
|
||||
channel_id = int(self.extra['channel_id'])
|
||||
elif (
|
||||
self.action is enums.AuditLogAction.member_move
|
||||
or self.action is enums.AuditLogAction.message_delete
|
||||
):
|
||||
channel_id = int(self.extra["channel_id"])
|
||||
elems = {
|
||||
'count': int(self.extra['count']),
|
||||
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
|
||||
"count": int(self.extra["count"]),
|
||||
"channel": self.guild.get_channel(channel_id)
|
||||
or Object(id=channel_id),
|
||||
}
|
||||
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)()
|
||||
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type(
|
||||
"_AuditLogProxy", (), elems
|
||||
)()
|
||||
elif self.action is enums.AuditLogAction.member_disconnect:
|
||||
# The member disconnect action has a dict with some information
|
||||
elems = {
|
||||
'count': int(self.extra['count']),
|
||||
"count": int(self.extra["count"]),
|
||||
}
|
||||
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)()
|
||||
elif self.action.name.endswith('pin'):
|
||||
self.extra: _AuditLogProxyMemberDisconnect = type(
|
||||
"_AuditLogProxy", (), elems
|
||||
)()
|
||||
elif self.action.name.endswith("pin"):
|
||||
# the pin actions have a dict with some information
|
||||
channel_id = int(self.extra['channel_id'])
|
||||
channel_id = int(self.extra["channel_id"])
|
||||
elems = {
|
||||
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
|
||||
'message_id': int(self.extra['message_id']),
|
||||
"channel": self.guild.get_channel(channel_id)
|
||||
or Object(id=channel_id),
|
||||
"message_id": int(self.extra["message_id"]),
|
||||
}
|
||||
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)()
|
||||
elif self.action.name.startswith('overwrite_'):
|
||||
self.extra: _AuditLogProxyPinAction = type(
|
||||
"_AuditLogProxy", (), elems
|
||||
)()
|
||||
elif self.action.name.startswith("overwrite_"):
|
||||
# the overwrite_ actions have a dict with some information
|
||||
instance_id = int(self.extra['id'])
|
||||
the_type = self.extra.get('type')
|
||||
if the_type == '1':
|
||||
instance_id = int(self.extra["id"])
|
||||
the_type = self.extra.get("type")
|
||||
if the_type == "1":
|
||||
self.extra = self._get_member(instance_id)
|
||||
elif the_type == '0':
|
||||
elif the_type == "0":
|
||||
role = self.guild.get_role(instance_id)
|
||||
if role is None:
|
||||
role = Object(id=instance_id)
|
||||
role.name = self.extra.get('role_name') # type: ignore
|
||||
role.name = self.extra.get("role_name") # type: ignore
|
||||
self.extra: Role = role
|
||||
elif self.action.name.startswith('stage_instance'):
|
||||
channel_id = int(self.extra['channel_id'])
|
||||
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)}
|
||||
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)()
|
||||
elif self.action.name.startswith("stage_instance"):
|
||||
channel_id = int(self.extra["channel_id"])
|
||||
elems = {
|
||||
"channel": self.guild.get_channel(channel_id)
|
||||
or Object(id=channel_id)
|
||||
}
|
||||
self.extra: _AuditLogProxyStageInstanceAction = type(
|
||||
"_AuditLogProxy", (), elems
|
||||
)()
|
||||
|
||||
# fmt: off
|
||||
self.extra: Union[
|
||||
@ -433,16 +483,16 @@ class AuditLogEntry(Hashable):
|
||||
# where new_value and old_value are not guaranteed to be there depending
|
||||
# on the action type, so let's just fetch it for now and only turn it
|
||||
# into meaningful data when requested
|
||||
self._changes = data.get('changes', [])
|
||||
self._changes = data.get("changes", [])
|
||||
|
||||
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore
|
||||
self._target_id = utils._get_as_snowflake(data, 'target_id')
|
||||
self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
|
||||
self._target_id = utils._get_as_snowflake(data, "target_id")
|
||||
|
||||
def _get_member(self, user_id: int) -> Union[Member, User, None]:
|
||||
return self.guild.get_member(user_id) or self._users.get(user_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>'
|
||||
return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
|
||||
|
||||
@utils.cached_property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
@ -450,9 +500,24 @@ class AuditLogEntry(Hashable):
|
||||
return utils.snowflake_time(self.id)
|
||||
|
||||
@utils.cached_property
|
||||
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]:
|
||||
def target(
|
||||
self,
|
||||
) -> Union[
|
||||
Guild,
|
||||
abc.GuildChannel,
|
||||
Member,
|
||||
User,
|
||||
Role,
|
||||
Invite,
|
||||
Emoji,
|
||||
StageInstance,
|
||||
GuildSticker,
|
||||
Thread,
|
||||
Object,
|
||||
None,
|
||||
]:
|
||||
try:
|
||||
converter = getattr(self, '_convert_target_' + self.action.target_type)
|
||||
converter = getattr(self, "_convert_target_" + self.action.target_type)
|
||||
except AttributeError:
|
||||
return Object(id=self._target_id)
|
||||
else:
|
||||
@ -483,7 +548,9 @@ class AuditLogEntry(Hashable):
|
||||
def _convert_target_guild(self, target_id: int) -> Guild:
|
||||
return self.guild
|
||||
|
||||
def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]:
|
||||
def _convert_target_channel(
|
||||
self, target_id: int
|
||||
) -> Union[abc.GuildChannel, Object]:
|
||||
return self.guild.get_channel(target_id) or Object(id=target_id)
|
||||
|
||||
def _convert_target_user(self, target_id: int) -> Union[Member, User, None]:
|
||||
@ -495,14 +562,18 @@ class AuditLogEntry(Hashable):
|
||||
def _convert_target_invite(self, target_id: int) -> Invite:
|
||||
# invites have target_id set to null
|
||||
# so figure out which change has the full invite data
|
||||
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
|
||||
changeset = (
|
||||
self.before
|
||||
if self.action is enums.AuditLogAction.invite_delete
|
||||
else self.after
|
||||
)
|
||||
|
||||
fake_payload = {
|
||||
'max_age': changeset.max_age,
|
||||
'max_uses': changeset.max_uses,
|
||||
'code': changeset.code,
|
||||
'temporary': changeset.temporary,
|
||||
'uses': changeset.uses,
|
||||
"max_age": changeset.max_age,
|
||||
"max_uses": changeset.max_uses,
|
||||
"code": changeset.code,
|
||||
"temporary": changeset.temporary,
|
||||
"uses": changeset.uses,
|
||||
}
|
||||
|
||||
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore
|
||||
@ -518,7 +589,9 @@ class AuditLogEntry(Hashable):
|
||||
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
|
||||
return self._get_member(target_id)
|
||||
|
||||
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
|
||||
def _convert_target_stage_instance(
|
||||
self, target_id: int
|
||||
) -> Union[StageInstance, Object]:
|
||||
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
|
||||
|
||||
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
|
||||
|
@ -29,11 +29,10 @@ import time
|
||||
import random
|
||||
from typing import Callable, Generic, Literal, TypeVar, overload, Union
|
||||
|
||||
T = TypeVar('T', bool, Literal[True], Literal[False])
|
||||
T = TypeVar("T", bool, Literal[True], Literal[False])
|
||||
|
||||
__all__ = ("ExponentialBackoff",)
|
||||
|
||||
__all__ = (
|
||||
'ExponentialBackoff',
|
||||
)
|
||||
|
||||
class ExponentialBackoff(Generic[T]):
|
||||
"""An implementation of the exponential backoff algorithm
|
||||
@ -69,7 +68,7 @@ class ExponentialBackoff(Generic[T]):
|
||||
rand = random.Random()
|
||||
rand.seed()
|
||||
|
||||
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
|
||||
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
|
||||
|
||||
@overload
|
||||
def delay(self: ExponentialBackoff[Literal[False]]) -> float:
|
||||
|
@ -45,7 +45,13 @@ import datetime
|
||||
|
||||
import discord.abc
|
||||
from .permissions import PermissionOverwrite, Permissions
|
||||
from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode
|
||||
from .enums import (
|
||||
ChannelType,
|
||||
StagePrivacyLevel,
|
||||
try_enum,
|
||||
VoiceRegion,
|
||||
VideoQualityMode,
|
||||
)
|
||||
from .mixins import Hashable
|
||||
from .object import Object
|
||||
from . import utils
|
||||
@ -57,14 +63,14 @@ from .threads import Thread
|
||||
from .iterators import ArchivedThreadIterator
|
||||
|
||||
__all__ = (
|
||||
'TextChannel',
|
||||
'VoiceChannel',
|
||||
'StageChannel',
|
||||
'DMChannel',
|
||||
'CategoryChannel',
|
||||
'StoreChannel',
|
||||
'GroupChannel',
|
||||
'PartialMessageable',
|
||||
"TextChannel",
|
||||
"VoiceChannel",
|
||||
"StageChannel",
|
||||
"DMChannel",
|
||||
"CategoryChannel",
|
||||
"StoreChannel",
|
||||
"GroupChannel",
|
||||
"PartialMessageable",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -155,51 +161,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'guild',
|
||||
'topic',
|
||||
'_state',
|
||||
'nsfw',
|
||||
'category_id',
|
||||
'position',
|
||||
'slowmode_delay',
|
||||
'_overwrites',
|
||||
'_type',
|
||||
'last_message_id',
|
||||
'default_auto_archive_duration',
|
||||
"name",
|
||||
"id",
|
||||
"guild",
|
||||
"topic",
|
||||
"_state",
|
||||
"nsfw",
|
||||
"category_id",
|
||||
"position",
|
||||
"slowmode_delay",
|
||||
"_overwrites",
|
||||
"_type",
|
||||
"last_message_id",
|
||||
"default_auto_archive_duration",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
|
||||
def __init__(
|
||||
self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self._type: int = data['type']
|
||||
self.id: int = int(data["id"])
|
||||
self._type: int = data["type"]
|
||||
self._update(guild, data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = [
|
||||
('id', self.id),
|
||||
('name', self.name),
|
||||
('position', self.position),
|
||||
('nsfw', self.nsfw),
|
||||
('news', self.is_news()),
|
||||
('category_id', self.category_id),
|
||||
("id", self.id),
|
||||
("name", self.name),
|
||||
("position", self.position),
|
||||
("nsfw", self.nsfw),
|
||||
("news", self.is_news()),
|
||||
("category_id", self.category_id),
|
||||
]
|
||||
joined = ' '.join('%s=%r' % t for t in attrs)
|
||||
return f'<{self.__class__.__name__} {joined}>'
|
||||
joined = " ".join("%s=%r" % t for t in attrs)
|
||||
return f"<{self.__class__.__name__} {joined}>"
|
||||
|
||||
def _update(self, guild: Guild, data: TextChannelPayload) -> None:
|
||||
self.guild: Guild = guild
|
||||
self.name: str = data['name']
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
|
||||
self.topic: Optional[str] = data.get('topic')
|
||||
self.position: int = data['position']
|
||||
self.nsfw: bool = data.get('nsfw', False)
|
||||
self.name: str = data["name"]
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
|
||||
self.topic: Optional[str] = data.get("topic")
|
||||
self.position: int = data["position"]
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
# Does this need coercion into `int`? No idea yet.
|
||||
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
|
||||
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
|
||||
self._type: int = data.get('type', self._type)
|
||||
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
|
||||
self.slowmode_delay: int = data.get("rate_limit_per_user", 0)
|
||||
self.default_auto_archive_duration: ThreadArchiveDuration = data.get(
|
||||
"default_auto_archive_duration", 1440
|
||||
)
|
||||
self._type: int = data.get("type", self._type)
|
||||
self.last_message_id: Optional[int] = utils._get_as_snowflake(
|
||||
data, "last_message_id"
|
||||
)
|
||||
self._fill_overwrites(data)
|
||||
|
||||
async def _get_channel(self):
|
||||
@ -234,7 +246,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id]
|
||||
return [
|
||||
thread
|
||||
for thread in self.guild._threads.values()
|
||||
if thread.parent_id == self.id
|
||||
]
|
||||
|
||||
def is_nsfw(self) -> bool:
|
||||
""":class:`bool`: Checks if the channel is NSFW."""
|
||||
@ -263,7 +279,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
Optional[:class:`Message`]
|
||||
The last message in this channel or ``None`` if not found.
|
||||
"""
|
||||
return self._state._get_message(self.last_message_id) if self.last_message_id else None
|
||||
return (
|
||||
self._state._get_message(self.last_message_id)
|
||||
if self.last_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@ -359,9 +379,17 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
|
||||
async def clone(
|
||||
self, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> TextChannel:
|
||||
return await self._clone_impl(
|
||||
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
|
||||
{
|
||||
"topic": self.topic,
|
||||
"nsfw": self.nsfw,
|
||||
"rate_limit_per_user": self.slowmode_delay,
|
||||
},
|
||||
name=name,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
|
||||
@ -408,7 +436,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
return
|
||||
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
raise ClientException("Can only bulk delete messages up to 100 messages")
|
||||
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
@ -483,11 +511,19 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
if check is MISSING:
|
||||
check = lambda m: True
|
||||
|
||||
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
|
||||
iterator = self.history(
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
oldest_first=oldest_first,
|
||||
around=around,
|
||||
)
|
||||
ret: List[Message] = []
|
||||
count = 0
|
||||
|
||||
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
|
||||
minimum_time = (
|
||||
int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
|
||||
)
|
||||
strategy = self.delete_messages if bulk else _single_delete_strategy
|
||||
|
||||
async for message in iterator:
|
||||
@ -548,7 +584,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
data = await self._state.http.channel_webhooks(self.id)
|
||||
return [Webhook.from_state(d, state=self._state) for d in data]
|
||||
|
||||
async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook:
|
||||
async def create_webhook(
|
||||
self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None
|
||||
) -> Webhook:
|
||||
"""|coro|
|
||||
|
||||
Creates a webhook for this channel.
|
||||
@ -586,10 +624,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
if avatar is not None:
|
||||
avatar = utils._bytes_to_base64_data(avatar) # type: ignore
|
||||
|
||||
data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason)
|
||||
data = await self._state.http.create_webhook(
|
||||
self.id, name=str(name), avatar=avatar, reason=reason
|
||||
)
|
||||
return Webhook.from_state(data, state=self._state)
|
||||
|
||||
async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook:
|
||||
async def follow(
|
||||
self, *, destination: TextChannel, reason: Optional[str] = None
|
||||
) -> Webhook:
|
||||
"""
|
||||
Follows a channel using a webhook.
|
||||
|
||||
@ -625,14 +667,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
|
||||
if not self.is_news():
|
||||
raise ClientException('The channel must be a news channel.')
|
||||
raise ClientException("The channel must be a news channel.")
|
||||
|
||||
if not isinstance(destination, TextChannel):
|
||||
raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}')
|
||||
raise InvalidArgument(
|
||||
f"Expected TextChannel received {destination.__class__.__name__}"
|
||||
)
|
||||
|
||||
from .webhook import Webhook
|
||||
|
||||
data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason)
|
||||
data = await self._state.http.follow_webhook(
|
||||
self.id, webhook_channel_id=destination.id, reason=reason
|
||||
)
|
||||
return Webhook._as_follower(data, channel=destination, user=self._state.user)
|
||||
|
||||
def get_partial_message(self, message_id: int, /) -> PartialMessage:
|
||||
@ -731,7 +777,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
data = await self._state.http.start_thread_without_message(
|
||||
self.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
|
||||
auto_archive_duration=auto_archive_duration
|
||||
or self.default_auto_archive_duration,
|
||||
type=type.value,
|
||||
reason=reason,
|
||||
)
|
||||
@ -740,7 +787,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
self.id,
|
||||
message.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration,
|
||||
auto_archive_duration=auto_archive_duration
|
||||
or self.default_auto_archive_duration,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
@ -787,45 +835,64 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
|
||||
:class:`Thread`
|
||||
The archived threads.
|
||||
"""
|
||||
return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before)
|
||||
return ArchivedThreadIterator(
|
||||
self.id,
|
||||
self.guild,
|
||||
limit=limit,
|
||||
joined=joined,
|
||||
private=private,
|
||||
before=before,
|
||||
)
|
||||
|
||||
|
||||
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'guild',
|
||||
'bitrate',
|
||||
'user_limit',
|
||||
'_state',
|
||||
'position',
|
||||
'_overwrites',
|
||||
'category_id',
|
||||
'rtc_region',
|
||||
'video_quality_mode',
|
||||
"name",
|
||||
"id",
|
||||
"guild",
|
||||
"bitrate",
|
||||
"user_limit",
|
||||
"_state",
|
||||
"position",
|
||||
"_overwrites",
|
||||
"category_id",
|
||||
"rtc_region",
|
||||
"video_quality_mode",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
state: ConnectionState,
|
||||
guild: Guild,
|
||||
data: Union[VoiceChannelPayload, StageChannelPayload],
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
self._update(guild, data)
|
||||
|
||||
def _get_voice_client_key(self) -> Tuple[int, str]:
|
||||
return self.guild.id, 'guild_id'
|
||||
return self.guild.id, "guild_id"
|
||||
|
||||
def _get_voice_state_pair(self) -> Tuple[int, int]:
|
||||
return self.guild.id, self.id
|
||||
|
||||
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
|
||||
def _update(
|
||||
self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]
|
||||
) -> None:
|
||||
self.guild = guild
|
||||
self.name: str = data['name']
|
||||
rtc = data.get('rtc_region')
|
||||
self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None
|
||||
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
|
||||
self.position: int = data['position']
|
||||
self.bitrate: int = data.get('bitrate')
|
||||
self.user_limit: int = data.get('user_limit')
|
||||
self.name: str = data["name"]
|
||||
rtc = data.get("rtc_region")
|
||||
self.rtc_region: Optional[VoiceRegion] = (
|
||||
try_enum(VoiceRegion, rtc) if rtc is not None else None
|
||||
)
|
||||
self.video_quality_mode: VideoQualityMode = try_enum(
|
||||
VideoQualityMode, data.get("video_quality_mode", 1)
|
||||
)
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
|
||||
self.position: int = data["position"]
|
||||
self.bitrate: int = data.get("bitrate")
|
||||
self.user_limit: int = data.get("user_limit")
|
||||
self._fill_overwrites(data)
|
||||
|
||||
@property
|
||||
@ -933,17 +1000,17 @@ class VoiceChannel(VocalGuildChannel):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = [
|
||||
('id', self.id),
|
||||
('name', self.name),
|
||||
('rtc_region', self.rtc_region),
|
||||
('position', self.position),
|
||||
('bitrate', self.bitrate),
|
||||
('video_quality_mode', self.video_quality_mode),
|
||||
('user_limit', self.user_limit),
|
||||
('category_id', self.category_id),
|
||||
("id", self.id),
|
||||
("name", self.name),
|
||||
("rtc_region", self.rtc_region),
|
||||
("position", self.position),
|
||||
("bitrate", self.bitrate),
|
||||
("video_quality_mode", self.video_quality_mode),
|
||||
("user_limit", self.user_limit),
|
||||
("category_id", self.category_id),
|
||||
]
|
||||
joined = ' '.join('%s=%r' % t for t in attrs)
|
||||
return f'<{self.__class__.__name__} {joined}>'
|
||||
joined = " ".join("%s=%r" % t for t in attrs)
|
||||
return f"<{self.__class__.__name__} {joined}>"
|
||||
|
||||
@property
|
||||
def type(self) -> ChannelType:
|
||||
@ -951,8 +1018,14 @@ class VoiceChannel(VocalGuildChannel):
|
||||
return ChannelType.voice
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
|
||||
return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason)
|
||||
async def clone(
|
||||
self, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> VoiceChannel:
|
||||
return await self._clone_impl(
|
||||
{"bitrate": self.bitrate, "user_limit": self.user_limit},
|
||||
name=name,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@ -1093,31 +1166,35 @@ class StageChannel(VocalGuildChannel):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
__slots__ = ('topic',)
|
||||
__slots__ = ("topic",)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = [
|
||||
('id', self.id),
|
||||
('name', self.name),
|
||||
('topic', self.topic),
|
||||
('rtc_region', self.rtc_region),
|
||||
('position', self.position),
|
||||
('bitrate', self.bitrate),
|
||||
('video_quality_mode', self.video_quality_mode),
|
||||
('user_limit', self.user_limit),
|
||||
('category_id', self.category_id),
|
||||
("id", self.id),
|
||||
("name", self.name),
|
||||
("topic", self.topic),
|
||||
("rtc_region", self.rtc_region),
|
||||
("position", self.position),
|
||||
("bitrate", self.bitrate),
|
||||
("video_quality_mode", self.video_quality_mode),
|
||||
("user_limit", self.user_limit),
|
||||
("category_id", self.category_id),
|
||||
]
|
||||
joined = ' '.join('%s=%r' % t for t in attrs)
|
||||
return f'<{self.__class__.__name__} {joined}>'
|
||||
joined = " ".join("%s=%r" % t for t in attrs)
|
||||
return f"<{self.__class__.__name__} {joined}>"
|
||||
|
||||
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
|
||||
super()._update(guild, data)
|
||||
self.topic = data.get('topic')
|
||||
self.topic = data.get("topic")
|
||||
|
||||
@property
|
||||
def requesting_to_speak(self) -> List[Member]:
|
||||
"""List[:class:`Member`]: A list of members who are requesting to speak in the stage channel."""
|
||||
return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None]
|
||||
return [
|
||||
member
|
||||
for member in self.members
|
||||
if member.voice and member.voice.requested_to_speak_at is not None
|
||||
]
|
||||
|
||||
@property
|
||||
def speakers(self) -> List[Member]:
|
||||
@ -1128,7 +1205,9 @@ class StageChannel(VocalGuildChannel):
|
||||
return [
|
||||
member
|
||||
for member in self.members
|
||||
if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None
|
||||
if member.voice
|
||||
and not member.voice.suppress
|
||||
and member.voice.requested_to_speak_at is None
|
||||
]
|
||||
|
||||
@property
|
||||
@ -1137,7 +1216,9 @@ class StageChannel(VocalGuildChannel):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return [member for member in self.members if member.voice and member.voice.suppress]
|
||||
return [
|
||||
member for member in self.members if member.voice and member.voice.suppress
|
||||
]
|
||||
|
||||
@property
|
||||
def moderators(self) -> List[Member]:
|
||||
@ -1146,7 +1227,11 @@ class StageChannel(VocalGuildChannel):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
required_permissions = Permissions.stage_moderator()
|
||||
return [member for member in self.members if self.permissions_for(member) >= required_permissions]
|
||||
return [
|
||||
member
|
||||
for member in self.members
|
||||
if self.permissions_for(member) >= required_permissions
|
||||
]
|
||||
|
||||
@property
|
||||
def type(self) -> ChannelType:
|
||||
@ -1154,7 +1239,9 @@ class StageChannel(VocalGuildChannel):
|
||||
return ChannelType.stage_voice
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel:
|
||||
async def clone(
|
||||
self, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> StageChannel:
|
||||
return await self._clone_impl({}, name=name, reason=reason)
|
||||
|
||||
@property
|
||||
@ -1166,7 +1253,11 @@ class StageChannel(VocalGuildChannel):
|
||||
return utils.get(self.guild.stage_instances, channel_id=self.id)
|
||||
|
||||
async def create_instance(
|
||||
self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
|
||||
self,
|
||||
*,
|
||||
topic: str,
|
||||
privacy_level: StagePrivacyLevel = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> StageInstance:
|
||||
"""|coro|
|
||||
|
||||
@ -1201,13 +1292,15 @@ class StageChannel(VocalGuildChannel):
|
||||
The newly created stage instance.
|
||||
"""
|
||||
|
||||
payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic}
|
||||
payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic}
|
||||
|
||||
if privacy_level is not MISSING:
|
||||
if not isinstance(privacy_level, StagePrivacyLevel):
|
||||
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
|
||||
raise InvalidArgument(
|
||||
"privacy_level field must be of type PrivacyLevel"
|
||||
)
|
||||
|
||||
payload['privacy_level'] = privacy_level.value
|
||||
payload["privacy_level"] = privacy_level.value
|
||||
|
||||
data = await self._state.http.create_stage_instance(**payload, reason=reason)
|
||||
return StageInstance(guild=self.guild, state=self._state, data=data)
|
||||
@ -1361,22 +1454,33 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
|
||||
"""
|
||||
|
||||
__slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id')
|
||||
__slots__ = (
|
||||
"name",
|
||||
"id",
|
||||
"guild",
|
||||
"nsfw",
|
||||
"_state",
|
||||
"position",
|
||||
"_overwrites",
|
||||
"category_id",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
|
||||
def __init__(
|
||||
self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
self._update(guild, data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
|
||||
return f"<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
|
||||
|
||||
def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
|
||||
self.guild: Guild = guild
|
||||
self.name: str = data['name']
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
|
||||
self.nsfw: bool = data.get('nsfw', False)
|
||||
self.position: int = data['position']
|
||||
self.name: str = data["name"]
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
self.position: int = data["position"]
|
||||
self._fill_overwrites(data)
|
||||
|
||||
@property
|
||||
@ -1393,8 +1497,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
return self.nsfw
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
|
||||
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
|
||||
async def clone(
|
||||
self, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> CategoryChannel:
|
||||
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@ -1463,7 +1569,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.move)
|
||||
async def move(self, **kwargs):
|
||||
kwargs.pop('category', None)
|
||||
kwargs.pop("category", None)
|
||||
await super().move(**kwargs)
|
||||
|
||||
@property
|
||||
@ -1483,14 +1589,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
@property
|
||||
def text_channels(self) -> List[TextChannel]:
|
||||
"""List[:class:`TextChannel`]: Returns the text channels that are under this category."""
|
||||
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)]
|
||||
ret = [
|
||||
c
|
||||
for c in self.guild.channels
|
||||
if c.category_id == self.id and isinstance(c, TextChannel)
|
||||
]
|
||||
ret.sort(key=lambda c: (c.position, c.id))
|
||||
return ret
|
||||
|
||||
@property
|
||||
def voice_channels(self) -> List[VoiceChannel]:
|
||||
"""List[:class:`VoiceChannel`]: Returns the voice channels that are under this category."""
|
||||
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)]
|
||||
ret = [
|
||||
c
|
||||
for c in self.guild.channels
|
||||
if c.category_id == self.id and isinstance(c, VoiceChannel)
|
||||
]
|
||||
ret.sort(key=lambda c: (c.position, c.id))
|
||||
return ret
|
||||
|
||||
@ -1500,7 +1614,11 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)]
|
||||
ret = [
|
||||
c
|
||||
for c in self.guild.channels
|
||||
if c.category_id == self.id and isinstance(c, StageChannel)
|
||||
]
|
||||
ret.sort(key=lambda c: (c.position, c.id))
|
||||
return ret
|
||||
|
||||
@ -1590,30 +1708,32 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'guild',
|
||||
'_state',
|
||||
'nsfw',
|
||||
'category_id',
|
||||
'position',
|
||||
'_overwrites',
|
||||
"name",
|
||||
"id",
|
||||
"guild",
|
||||
"_state",
|
||||
"nsfw",
|
||||
"category_id",
|
||||
"position",
|
||||
"_overwrites",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload):
|
||||
def __init__(
|
||||
self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
self._update(guild, data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
|
||||
return f"<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
|
||||
|
||||
def _update(self, guild: Guild, data: StoreChannelPayload) -> None:
|
||||
self.guild: Guild = guild
|
||||
self.name: str = data['name']
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
|
||||
self.position: int = data['position']
|
||||
self.nsfw: bool = data.get('nsfw', False)
|
||||
self.name: str = data["name"]
|
||||
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
|
||||
self.position: int = data["position"]
|
||||
self.nsfw: bool = data.get("nsfw", False)
|
||||
self._fill_overwrites(data)
|
||||
|
||||
@property
|
||||
@ -1639,8 +1759,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
return self.nsfw
|
||||
|
||||
@utils.copy_doc(discord.abc.GuildChannel.clone)
|
||||
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel:
|
||||
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
|
||||
async def clone(
|
||||
self, *, name: Optional[str] = None, reason: Optional[str] = None
|
||||
) -> StoreChannel:
|
||||
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
|
||||
|
||||
@overload
|
||||
async def edit(
|
||||
@ -1716,7 +1838,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
|
||||
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
|
||||
|
||||
|
||||
DMC = TypeVar('DMC', bound='DMChannel')
|
||||
DMC = TypeVar("DMC", bound="DMChannel")
|
||||
|
||||
|
||||
class DMChannel(discord.abc.Messageable, Hashable):
|
||||
@ -1756,24 +1878,26 @@ class DMChannel(discord.abc.Messageable, Hashable):
|
||||
The direct message channel ID.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'recipient', 'me', '_state')
|
||||
__slots__ = ("id", "recipient", "me", "_state")
|
||||
|
||||
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
|
||||
def __init__(
|
||||
self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.recipient: Optional[User] = state.store_user(data['recipients'][0])
|
||||
self.recipient: Optional[User] = state.store_user(data["recipients"][0])
|
||||
self.me: ClientUser = me
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
|
||||
async def _get_channel(self):
|
||||
return self
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.recipient:
|
||||
return f'Direct Message with {self.recipient}'
|
||||
return 'Direct Message with Unknown User'
|
||||
return f"Direct Message with {self.recipient}"
|
||||
return "Direct Message with Unknown User"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<DMChannel id={self.id} recipient={self.recipient!r}>'
|
||||
return f"<DMChannel id={self.id} recipient={self.recipient!r}>"
|
||||
|
||||
@classmethod
|
||||
def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC:
|
||||
@ -1892,19 +2016,32 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
The group channel's name if provided.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state')
|
||||
__slots__ = (
|
||||
"id",
|
||||
"recipients",
|
||||
"owner_id",
|
||||
"owner",
|
||||
"_icon",
|
||||
"name",
|
||||
"me",
|
||||
"_state",
|
||||
)
|
||||
|
||||
def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
|
||||
def __init__(
|
||||
self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
self.me: ClientUser = me
|
||||
self._update_group(data)
|
||||
|
||||
def _update_group(self, data: GroupChannelPayload) -> None:
|
||||
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id')
|
||||
self._icon: Optional[str] = data.get('icon')
|
||||
self.name: Optional[str] = data.get('name')
|
||||
self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])]
|
||||
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id")
|
||||
self._icon: Optional[str] = data.get("icon")
|
||||
self.name: Optional[str] = data.get("name")
|
||||
self.recipients: List[User] = [
|
||||
self._state.store_user(u) for u in data.get("recipients", [])
|
||||
]
|
||||
|
||||
self.owner: Optional[BaseUser]
|
||||
if self.owner_id == self.me.id:
|
||||
@ -1920,12 +2057,12 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
return self.name
|
||||
|
||||
if len(self.recipients) == 0:
|
||||
return 'Unnamed'
|
||||
return "Unnamed"
|
||||
|
||||
return ', '.join(map(lambda x: x.name, self.recipients))
|
||||
return ", ".join(map(lambda x: x.name, self.recipients))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<GroupChannel id={self.id} name={self.name!r}>'
|
||||
return f"<GroupChannel id={self.id} name={self.name!r}>"
|
||||
|
||||
@property
|
||||
def type(self) -> ChannelType:
|
||||
@ -1937,7 +2074,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
|
||||
"""Optional[:class:`Asset`]: Returns the channel's icon asset if available."""
|
||||
if self._icon is None:
|
||||
return None
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path='channel')
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path="channel")
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
@ -2032,7 +2169,9 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
|
||||
The channel type associated with this partial messageable, if given.
|
||||
"""
|
||||
|
||||
def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None):
|
||||
def __init__(
|
||||
self, state: ConnectionState, id: int, type: Optional[ChannelType] = None
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self._channel: Object = Object(id=id)
|
||||
self.id: int = id
|
||||
@ -2093,13 +2232,21 @@ def _channel_factory(channel_type: int):
|
||||
|
||||
def _threaded_channel_factory(channel_type: int):
|
||||
cls, value = _channel_factory(channel_type)
|
||||
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
|
||||
if value in (
|
||||
ChannelType.private_thread,
|
||||
ChannelType.public_thread,
|
||||
ChannelType.news_thread,
|
||||
):
|
||||
return Thread, value
|
||||
return cls, value
|
||||
|
||||
|
||||
def _threaded_guild_channel_factory(channel_type: int):
|
||||
cls, value = _guild_channel_factory(channel_type)
|
||||
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
|
||||
if value in (
|
||||
ChannelType.private_thread,
|
||||
ChannelType.public_thread,
|
||||
ChannelType.news_thread,
|
||||
):
|
||||
return Thread, value
|
||||
return cls, value
|
||||
|
@ -29,7 +29,20 @@ import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -69,46 +82,49 @@ if TYPE_CHECKING:
|
||||
from .member import Member
|
||||
from .voice_client import VoiceProtocol
|
||||
|
||||
__all__ = (
|
||||
'Client',
|
||||
)
|
||||
__all__ = ("Client",)
|
||||
|
||||
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
|
||||
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
|
||||
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
|
||||
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
_log.info('Cleaning up after %d tasks.', len(tasks))
|
||||
_log.info("Cleaning up after %d tasks.", len(tasks))
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
||||
_log.info('All tasks finished cancelling.')
|
||||
_log.info("All tasks finished cancelling.")
|
||||
|
||||
for task in tasks:
|
||||
if task.cancelled():
|
||||
continue
|
||||
if task.exception() is not None:
|
||||
loop.call_exception_handler({
|
||||
'message': 'Unhandled exception during Client.run shutdown.',
|
||||
'exception': task.exception(),
|
||||
'task': task
|
||||
})
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "Unhandled exception during Client.run shutdown.",
|
||||
"exception": task.exception(),
|
||||
"task": task,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
try:
|
||||
_cancel_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
finally:
|
||||
_log.info('Closing the event loop.')
|
||||
_log.info("Closing the event loop.")
|
||||
loop.close()
|
||||
|
||||
|
||||
class Client:
|
||||
r"""Represents a client connection that connects to Discord.
|
||||
This class is used to interact with the Discord WebSocket and API.
|
||||
@ -199,6 +215,7 @@ class Client:
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
The event loop that the client uses for asynchronous operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@ -210,26 +227,34 @@ class Client:
|
||||
|
||||
# self.ws is set in the connect method
|
||||
self.ws: DiscordWebSocket = None # type: ignore
|
||||
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
|
||||
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
|
||||
self.shard_id: Optional[int] = options.get('shard_id')
|
||||
self.shard_count: Optional[int] = options.get('shard_count')
|
||||
self.loop: asyncio.AbstractEventLoop = (
|
||||
asyncio.get_event_loop() if loop is None else loop
|
||||
)
|
||||
self._listeners: Dict[
|
||||
str, List[Tuple[asyncio.Future, Callable[..., bool]]]
|
||||
] = {}
|
||||
self.shard_id: Optional[int] = options.get("shard_id")
|
||||
self.shard_count: Optional[int] = options.get("shard_count")
|
||||
|
||||
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
|
||||
proxy: Optional[str] = options.pop('proxy', None)
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
|
||||
unsync_clock: bool = options.pop('assume_unsync_clock', True)
|
||||
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
|
||||
connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None)
|
||||
proxy: Optional[str] = options.pop("proxy", None)
|
||||
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None)
|
||||
unsync_clock: bool = options.pop("assume_unsync_clock", True)
|
||||
self.http: HTTPClient = HTTPClient(
|
||||
connector,
|
||||
proxy=proxy,
|
||||
proxy_auth=proxy_auth,
|
||||
unsync_clock=unsync_clock,
|
||||
loop=self.loop,
|
||||
)
|
||||
|
||||
self._handlers: Dict[str, Callable] = {
|
||||
'ready': self._handle_ready
|
||||
}
|
||||
self._handlers: Dict[str, Callable] = {"ready": self._handle_ready}
|
||||
|
||||
self._hooks: Dict[str, Callable] = {
|
||||
'before_identify': self._call_before_identify_hook
|
||||
"before_identify": self._call_before_identify_hook
|
||||
}
|
||||
|
||||
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
|
||||
self._enable_debug_events: bool = options.pop("enable_debug_events", False)
|
||||
self._connection: ConnectionState = self._get_state(**options)
|
||||
self._connection.shard_count = self.shard_count
|
||||
self._closed: bool = False
|
||||
@ -243,12 +268,20 @@ class Client:
|
||||
|
||||
# internals
|
||||
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
|
||||
def _get_websocket(
|
||||
self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None
|
||||
) -> DiscordWebSocket:
|
||||
return self.ws
|
||||
|
||||
def _get_state(self, **options: Any) -> ConnectionState:
|
||||
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
||||
hooks=self._hooks, http=self.http, loop=self.loop, **options)
|
||||
return ConnectionState(
|
||||
dispatch=self.dispatch,
|
||||
handlers=self._handlers,
|
||||
hooks=self._hooks,
|
||||
http=self.http,
|
||||
loop=self.loop,
|
||||
**options,
|
||||
)
|
||||
|
||||
def _handle_ready(self) -> None:
|
||||
self._ready.set()
|
||||
@ -260,7 +293,7 @@ class Client:
|
||||
This could be referred to as the Discord WebSocket protocol latency.
|
||||
"""
|
||||
ws = self.ws
|
||||
return float('nan') if not ws else ws.latency
|
||||
return float("nan") if not ws else ws.latency
|
||||
|
||||
def is_ws_ratelimited(self) -> bool:
|
||||
""":class:`bool`: Whether the websocket is currently rate limited.
|
||||
@ -331,7 +364,7 @@ class Client:
|
||||
If this is not passed via ``__init__`` then this is retrieved
|
||||
through the gateway when an event contains the data. Usually
|
||||
after :func:`~discord.on_connect` is called.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.application_id
|
||||
@ -348,7 +381,13 @@ class Client:
|
||||
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
|
||||
return self._ready.is_set()
|
||||
|
||||
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
async def _run_event(
|
||||
self,
|
||||
coro: Callable[..., Coroutine[Any, Any, Any]],
|
||||
event_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
try:
|
||||
await coro(*args, **kwargs)
|
||||
except asyncio.CancelledError:
|
||||
@ -359,14 +398,20 @@ class Client:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
|
||||
def _schedule_event(
|
||||
self,
|
||||
coro: Callable[..., Coroutine[Any, Any, Any]],
|
||||
event_name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> asyncio.Task:
|
||||
wrapped = self._run_event(coro, event_name, *args, **kwargs)
|
||||
# Schedules the task
|
||||
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
|
||||
return asyncio.create_task(wrapped, name=f"discord.py: {event_name}")
|
||||
|
||||
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
|
||||
_log.debug('Dispatching event %s', event)
|
||||
method = 'on_' + event
|
||||
_log.debug("Dispatching event %s", event)
|
||||
method = "on_" + event
|
||||
|
||||
listeners = self._listeners.get(event)
|
||||
if listeners:
|
||||
@ -413,18 +458,22 @@ class Client:
|
||||
overridden to have a different implementation.
|
||||
Check :func:`~discord.on_error` for more details.
|
||||
"""
|
||||
print(f'Ignoring exception in {event_method}', file=sys.stderr)
|
||||
print(f"Ignoring exception in {event_method}", file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
|
||||
# hooks
|
||||
|
||||
async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
|
||||
async def _call_before_identify_hook(
|
||||
self, shard_id: Optional[int], *, initial: bool = False
|
||||
) -> None:
|
||||
# This hook is an internal hook that actually calls the public one.
|
||||
# It allows the library to have its own hook without stepping on the
|
||||
# toes of those who need to override their own hook.
|
||||
await self.before_identify_hook(shard_id, initial=initial)
|
||||
|
||||
async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None:
|
||||
async def before_identify_hook(
|
||||
self, shard_id: Optional[int], *, initial: bool = False
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
A hook that is called before IDENTIFYing a session. This is useful
|
||||
@ -470,7 +519,7 @@ class Client:
|
||||
passing status code.
|
||||
"""
|
||||
|
||||
_log.info('logging in using static token')
|
||||
_log.info("logging in using static token")
|
||||
|
||||
data = await self.http.static_login(token.strip())
|
||||
self._connection.user = ClientUser(state=self._connection, data=data)
|
||||
@ -502,29 +551,35 @@ class Client:
|
||||
|
||||
backoff = ExponentialBackoff()
|
||||
ws_params = {
|
||||
'initial': True,
|
||||
'shard_id': self.shard_id,
|
||||
"initial": True,
|
||||
"shard_id": self.shard_id,
|
||||
}
|
||||
while not self.is_closed():
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, **ws_params)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=60.0)
|
||||
ws_params['initial'] = False
|
||||
ws_params["initial"] = False
|
||||
while True:
|
||||
await self.ws.poll_event()
|
||||
except ReconnectWebSocket as e:
|
||||
_log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
|
||||
_log.info("Got a request to %s the websocket.", e.op)
|
||||
self.dispatch("disconnect")
|
||||
ws_params.update(
|
||||
sequence=self.ws.sequence,
|
||||
resume=e.resume,
|
||||
session=self.ws.session_id,
|
||||
)
|
||||
continue
|
||||
except (OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError) as exc:
|
||||
except (
|
||||
OSError,
|
||||
HTTPException,
|
||||
GatewayNotFound,
|
||||
ConnectionClosed,
|
||||
aiohttp.ClientError,
|
||||
asyncio.TimeoutError,
|
||||
) as exc:
|
||||
|
||||
self.dispatch('disconnect')
|
||||
self.dispatch("disconnect")
|
||||
if not reconnect:
|
||||
await self.close()
|
||||
if isinstance(exc, ConnectionClosed) and exc.code == 1000:
|
||||
@ -537,7 +592,12 @@ class Client:
|
||||
|
||||
# If we get connection reset by peer then try to RESUME
|
||||
if isinstance(exc, OSError) and exc.errno in (54, 10054):
|
||||
ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id)
|
||||
ws_params.update(
|
||||
sequence=self.ws.sequence,
|
||||
initial=False,
|
||||
resume=True,
|
||||
session=self.ws.session_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# We should only get this when an unhandled close code happens,
|
||||
@ -557,7 +617,9 @@ class Client:
|
||||
# Always try to RESUME the connection
|
||||
# If the connection is not RESUME-able then the gateway will invalidate the session.
|
||||
# This is apparently what the official Discord client does.
|
||||
ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
|
||||
ws_params.update(
|
||||
sequence=self.ws.sequence, resume=True, session=self.ws.session_id
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""|coro|
|
||||
@ -654,10 +716,10 @@ class Client:
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
_log.info('Received signal to terminate bot and event loop.')
|
||||
_log.info("Received signal to terminate bot and event loop.")
|
||||
finally:
|
||||
future.remove_done_callback(stop_loop_on_completion)
|
||||
_log.info('Cleaning up tasks.')
|
||||
_log.info("Cleaning up tasks.")
|
||||
_cleanup_loop(loop)
|
||||
|
||||
if not future.cancelled():
|
||||
@ -686,10 +748,10 @@ class Client:
|
||||
self._connection._activity = None
|
||||
elif isinstance(value, BaseActivity):
|
||||
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
self._connection._activity = value.to_dict() # type: ignore
|
||||
else:
|
||||
raise TypeError('activity must derive from BaseActivity.')
|
||||
|
||||
raise TypeError("activity must derive from BaseActivity.")
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
""":class:`.Status`:
|
||||
@ -704,11 +766,11 @@ class Client:
|
||||
@status.setter
|
||||
def status(self, value):
|
||||
if value is Status.offline:
|
||||
self._connection._status = 'invisible'
|
||||
self._connection._status = "invisible"
|
||||
elif isinstance(value, Status):
|
||||
self._connection._status = str(value)
|
||||
else:
|
||||
raise TypeError('status must derive from Status.')
|
||||
raise TypeError("status must derive from Status.")
|
||||
|
||||
@property
|
||||
def allowed_mentions(self) -> Optional[AllowedMentions]:
|
||||
@ -723,7 +785,9 @@ class Client:
|
||||
if value is None or isinstance(value, AllowedMentions):
|
||||
self._connection.allowed_mentions = value
|
||||
else:
|
||||
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
|
||||
raise TypeError(
|
||||
f"allowed_mentions must be AllowedMentions not {value.__class__!r}"
|
||||
)
|
||||
|
||||
@property
|
||||
def intents(self) -> Intents:
|
||||
@ -740,7 +804,9 @@ class Client:
|
||||
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
|
||||
return list(self._connection._users.values())
|
||||
|
||||
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
|
||||
def get_channel(
|
||||
self, id: int, /
|
||||
) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]:
|
||||
"""Returns a channel or thread with the given ID.
|
||||
|
||||
Parameters
|
||||
@ -755,12 +821,14 @@ class Client:
|
||||
"""
|
||||
return self._connection.get_channel(id)
|
||||
|
||||
def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable:
|
||||
def get_partial_messageable(
|
||||
self, id: int, *, type: Optional[ChannelType] = None
|
||||
) -> PartialMessageable:
|
||||
"""Returns a partial messageable with the given channel ID.
|
||||
|
||||
This is useful if you have a channel_id but don't want to do an API call
|
||||
to send messages to it.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
@ -1001,8 +1069,10 @@ class Client:
|
||||
|
||||
future = self.loop.create_future()
|
||||
if check is None:
|
||||
|
||||
def _check(*args):
|
||||
return True
|
||||
|
||||
check = _check
|
||||
|
||||
ev = event.lower()
|
||||
@ -1040,10 +1110,10 @@ class Client:
|
||||
"""
|
||||
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('event registered must be a coroutine function')
|
||||
raise TypeError("event registered must be a coroutine function")
|
||||
|
||||
setattr(self, coro.__name__, coro)
|
||||
_log.debug('%s has successfully been registered as an event', coro.__name__)
|
||||
_log.debug("%s has successfully been registered as an event", coro.__name__)
|
||||
return coro
|
||||
|
||||
async def change_presence(
|
||||
@ -1082,10 +1152,10 @@ class Client:
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status_str = 'online'
|
||||
status_str = "online"
|
||||
status = Status.online
|
||||
elif status is Status.offline:
|
||||
status_str = 'invisible'
|
||||
status_str = "invisible"
|
||||
status = Status.offline
|
||||
else:
|
||||
status_str = str(status)
|
||||
@ -1111,7 +1181,7 @@ class Client:
|
||||
*,
|
||||
limit: Optional[int] = 100,
|
||||
before: SnowflakeTime = None,
|
||||
after: SnowflakeTime = None
|
||||
after: SnowflakeTime = None,
|
||||
) -> GuildIterator:
|
||||
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
|
||||
|
||||
@ -1191,7 +1261,7 @@ class Client:
|
||||
"""
|
||||
code = utils.resolve_template(code)
|
||||
data = await self.http.get_template(code)
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
return Template(data=data, state=self._connection) # type: ignore
|
||||
|
||||
async def fetch_guild(self, guild_id: int, /) -> Guild:
|
||||
"""|coro|
|
||||
@ -1277,7 +1347,9 @@ class Client:
|
||||
region_value = str(region)
|
||||
|
||||
if code:
|
||||
data = await self.http.create_from_template(code, name, region_value, icon_base64)
|
||||
data = await self.http.create_from_template(
|
||||
code, name, region_value, icon_base64
|
||||
)
|
||||
else:
|
||||
data = await self.http.create_guild(name, region_value, icon_base64)
|
||||
return Guild(data=data, state=self._connection)
|
||||
@ -1307,12 +1379,18 @@ class Client:
|
||||
The stage instance from the stage channel ID.
|
||||
"""
|
||||
data = await self.http.get_stage_instance(channel_id)
|
||||
guild = self.get_guild(int(data['guild_id']))
|
||||
guild = self.get_guild(int(data["guild_id"]))
|
||||
return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
# Invite management
|
||||
|
||||
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
|
||||
async def fetch_invite(
|
||||
self,
|
||||
url: Union[Invite, str],
|
||||
*,
|
||||
with_counts: bool = True,
|
||||
with_expiration: bool = True,
|
||||
) -> Invite:
|
||||
"""|coro|
|
||||
|
||||
Gets an :class:`.Invite` from a discord.gg URL or ID.
|
||||
@ -1351,7 +1429,9 @@ class Client:
|
||||
"""
|
||||
|
||||
invite_id = utils.resolve_invite(url)
|
||||
data = await self.http.get_invite(invite_id, with_counts=with_counts, with_expiration=with_expiration)
|
||||
data = await self.http.get_invite(
|
||||
invite_id, with_counts=with_counts, with_expiration=with_expiration
|
||||
)
|
||||
return Invite.from_incomplete(state=self._connection, data=data)
|
||||
|
||||
async def delete_invite(self, invite: Union[Invite, str]) -> None:
|
||||
@ -1428,8 +1508,8 @@ class Client:
|
||||
The bot's application information.
|
||||
"""
|
||||
data = await self.http.application_info()
|
||||
if 'rpc_origins' not in data:
|
||||
data['rpc_origins'] = None
|
||||
if "rpc_origins" not in data:
|
||||
data["rpc_origins"] = None
|
||||
return AppInfo(self._connection, data)
|
||||
|
||||
async def fetch_user(self, user_id: int, /) -> User:
|
||||
@ -1463,7 +1543,9 @@ class Client:
|
||||
data = await self.http.get_user(user_id)
|
||||
return User(state=self._connection, data=data)
|
||||
|
||||
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
async def fetch_channel(
|
||||
self, channel_id: int, /
|
||||
) -> Union[GuildChannel, PrivateChannel, Thread]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
|
||||
@ -1492,19 +1574,21 @@ class Client:
|
||||
"""
|
||||
data = await self.http.get_channel(channel_id)
|
||||
|
||||
factory, ch_type = _threaded_channel_factory(data['type'])
|
||||
factory, ch_type = _threaded_channel_factory(data["type"])
|
||||
if factory is None:
|
||||
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
|
||||
raise InvalidData(
|
||||
"Unknown channel type {type} for channel ID {id}.".format_map(data)
|
||||
)
|
||||
|
||||
if ch_type in (ChannelType.group, ChannelType.private):
|
||||
# the factory will be a DMChannel or GroupChannel here
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
|
||||
else:
|
||||
# the factory can't be a DMChannel or GroupChannel here
|
||||
guild_id = int(data['guild_id']) # type: ignore
|
||||
guild_id = int(data["guild_id"]) # type: ignore
|
||||
guild = self.get_guild(guild_id) or Object(id=guild_id)
|
||||
# GuildChannels expect a Guild, we may be passing an Object
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
|
||||
|
||||
return channel
|
||||
|
||||
@ -1530,7 +1614,9 @@ class Client:
|
||||
data = await self.http.get_webhook(webhook_id)
|
||||
return Webhook.from_state(data, state=self._connection)
|
||||
|
||||
async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]:
|
||||
async def fetch_sticker(
|
||||
self, sticker_id: int, /
|
||||
) -> Union[StandardSticker, GuildSticker]:
|
||||
"""|coro|
|
||||
|
||||
Retrieves a :class:`.Sticker` with the specified ID.
|
||||
@ -1550,8 +1636,8 @@ class Client:
|
||||
The sticker you requested.
|
||||
"""
|
||||
data = await self.http.get_sticker(sticker_id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
return cls(state=self._connection, data=data) # type: ignore
|
||||
cls, _ = _sticker_factory(data["type"]) # type: ignore
|
||||
return cls(state=self._connection, data=data) # type: ignore
|
||||
|
||||
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
|
||||
"""|coro|
|
||||
@ -1571,7 +1657,10 @@ class Client:
|
||||
All available premium sticker packs.
|
||||
"""
|
||||
data = await self.http.list_premium_sticker_packs()
|
||||
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
|
||||
return [
|
||||
StickerPack(state=self._connection, data=pack)
|
||||
for pack in data["sticker_packs"]
|
||||
]
|
||||
|
||||
async def create_dm(self, user: Snowflake) -> DMChannel:
|
||||
"""|coro|
|
||||
@ -1606,7 +1695,7 @@ class Client:
|
||||
|
||||
This method should be used for when a view is comprised of components
|
||||
that last longer than the lifecycle of the program.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Parameters
|
||||
@ -1628,17 +1717,19 @@ class Client:
|
||||
"""
|
||||
|
||||
if not isinstance(view, View):
|
||||
raise TypeError(f'expected an instance of View not {view.__class__!r}')
|
||||
raise TypeError(f"expected an instance of View not {view.__class__!r}")
|
||||
|
||||
if not view.is_persistent():
|
||||
raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout')
|
||||
raise ValueError(
|
||||
"View is not persistent. Items need to have a custom_id set and View must have no timeout"
|
||||
)
|
||||
|
||||
self._connection.store_view(view, message_id)
|
||||
|
||||
@property
|
||||
def persistent_views(self) -> Sequence[View]:
|
||||
"""Sequence[:class:`.View`]: A sequence of persistent views added to the client.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._connection.persistent_views
|
||||
|
@ -35,11 +35,11 @@ from typing import (
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'Colour',
|
||||
'Color',
|
||||
"Colour",
|
||||
"Color",
|
||||
)
|
||||
|
||||
CT = TypeVar('CT', bound='Colour')
|
||||
CT = TypeVar("CT", bound="Colour")
|
||||
|
||||
|
||||
class Colour:
|
||||
@ -76,16 +76,18 @@ class Colour:
|
||||
The raw integer colour value.
|
||||
"""
|
||||
|
||||
__slots__ = ('value',)
|
||||
__slots__ = ("value",)
|
||||
|
||||
def __init__(self, value: int):
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
|
||||
raise TypeError(
|
||||
f"Expected int parameter, received {value.__class__.__name__} instead."
|
||||
)
|
||||
|
||||
self.value: int = value
|
||||
|
||||
def _get_byte(self, byte: int) -> int:
|
||||
return (self.value >> (8 * byte)) & 0xff
|
||||
return (self.value >> (8 * byte)) & 0xFF
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Colour) and self.value == other.value
|
||||
@ -94,13 +96,13 @@ class Colour:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'#{self.value:0>6x}'
|
||||
return f"#{self.value:0>6x}"
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Colour value={self.value}>'
|
||||
return f"<Colour value={self.value}>"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
@ -141,7 +143,11 @@ class Colour:
|
||||
return cls(0)
|
||||
|
||||
@classmethod
|
||||
def random(cls: Type[CT], *, seed: Optional[Union[int, str, float, bytes, bytearray]] = None) -> CT:
|
||||
def random(
|
||||
cls: Type[CT],
|
||||
*,
|
||||
seed: Optional[Union[int, str, float, bytes, bytearray]] = None,
|
||||
) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a random hue.
|
||||
|
||||
.. note::
|
||||
@ -164,12 +170,12 @@ class Colour:
|
||||
@classmethod
|
||||
def teal(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
|
||||
return cls(0x1abc9c)
|
||||
return cls(0x1ABC9C)
|
||||
|
||||
@classmethod
|
||||
def dark_teal(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
|
||||
return cls(0x11806a)
|
||||
return cls(0x11806A)
|
||||
|
||||
@classmethod
|
||||
def brand_green(cls: Type[CT]) -> CT:
|
||||
@ -182,17 +188,17 @@ class Colour:
|
||||
@classmethod
|
||||
def green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
|
||||
return cls(0x2ecc71)
|
||||
return cls(0x2ECC71)
|
||||
|
||||
@classmethod
|
||||
def dark_green(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
|
||||
return cls(0x1f8b4c)
|
||||
return cls(0x1F8B4C)
|
||||
|
||||
@classmethod
|
||||
def blue(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
|
||||
return cls(0x3498db)
|
||||
return cls(0x3498DB)
|
||||
|
||||
@classmethod
|
||||
def dark_blue(cls: Type[CT]) -> CT:
|
||||
@ -202,42 +208,42 @@ class Colour:
|
||||
@classmethod
|
||||
def purple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
|
||||
return cls(0x9b59b6)
|
||||
return cls(0x9B59B6)
|
||||
|
||||
@classmethod
|
||||
def dark_purple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
|
||||
return cls(0x71368a)
|
||||
return cls(0x71368A)
|
||||
|
||||
@classmethod
|
||||
def magenta(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
|
||||
return cls(0xe91e63)
|
||||
return cls(0xE91E63)
|
||||
|
||||
@classmethod
|
||||
def dark_magenta(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
|
||||
return cls(0xad1457)
|
||||
return cls(0xAD1457)
|
||||
|
||||
@classmethod
|
||||
def gold(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
|
||||
return cls(0xf1c40f)
|
||||
return cls(0xF1C40F)
|
||||
|
||||
@classmethod
|
||||
def dark_gold(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
|
||||
return cls(0xc27c0e)
|
||||
return cls(0xC27C0E)
|
||||
|
||||
@classmethod
|
||||
def orange(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
|
||||
return cls(0xe67e22)
|
||||
return cls(0xE67E22)
|
||||
|
||||
@classmethod
|
||||
def dark_orange(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
|
||||
return cls(0xa84300)
|
||||
return cls(0xA84300)
|
||||
|
||||
@classmethod
|
||||
def brand_red(cls: Type[CT]) -> CT:
|
||||
@ -250,45 +256,45 @@ class Colour:
|
||||
@classmethod
|
||||
def red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
|
||||
return cls(0xe74c3c)
|
||||
return cls(0xE74C3C)
|
||||
|
||||
@classmethod
|
||||
def dark_red(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
|
||||
return cls(0x992d22)
|
||||
return cls(0x992D22)
|
||||
|
||||
@classmethod
|
||||
def lighter_grey(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
|
||||
return cls(0x95a5a6)
|
||||
return cls(0x95A5A6)
|
||||
|
||||
lighter_gray = lighter_grey
|
||||
|
||||
@classmethod
|
||||
def dark_grey(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
|
||||
return cls(0x607d8b)
|
||||
return cls(0x607D8B)
|
||||
|
||||
dark_gray = dark_grey
|
||||
|
||||
@classmethod
|
||||
def light_grey(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
|
||||
return cls(0x979c9f)
|
||||
return cls(0x979C9F)
|
||||
|
||||
light_gray = light_grey
|
||||
|
||||
@classmethod
|
||||
def darker_grey(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
|
||||
return cls(0x546e7a)
|
||||
return cls(0x546E7A)
|
||||
|
||||
darker_gray = darker_grey
|
||||
|
||||
@classmethod
|
||||
def og_blurple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
|
||||
return cls(0x7289da)
|
||||
return cls(0x7289DA)
|
||||
|
||||
@classmethod
|
||||
def blurple(cls: Type[CT]) -> CT:
|
||||
@ -298,7 +304,7 @@ class Colour:
|
||||
@classmethod
|
||||
def greyple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
|
||||
return cls(0x99aab5)
|
||||
return cls(0x99AAB5)
|
||||
|
||||
@classmethod
|
||||
def dark_theme(cls: Type[CT]) -> CT:
|
||||
@ -324,7 +330,7 @@ class Colour:
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return cls(0xFEE75C)
|
||||
|
||||
|
||||
@classmethod
|
||||
def dark_blurple(cls: Type[CT]) -> CT:
|
||||
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
|
||||
|
@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from .enums import try_enum, ComponentType, ButtonStyle
|
||||
from .utils import get_slots, MISSING
|
||||
from .partial_emoji import PartialEmoji, _EmojiTag
|
||||
@ -41,14 +52,14 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Component',
|
||||
'ActionRow',
|
||||
'Button',
|
||||
'SelectMenu',
|
||||
'SelectOption',
|
||||
"Component",
|
||||
"ActionRow",
|
||||
"Button",
|
||||
"SelectMenu",
|
||||
"SelectOption",
|
||||
)
|
||||
|
||||
C = TypeVar('C', bound='Component')
|
||||
C = TypeVar("C", bound="Component")
|
||||
|
||||
|
||||
class Component:
|
||||
@ -70,14 +81,14 @@ class Component:
|
||||
The type of component.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = ('type',)
|
||||
__slots__: Tuple[str, ...] = ("type",)
|
||||
|
||||
__repr_info__: ClassVar[Tuple[str, ...]]
|
||||
type: ComponentType
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
|
||||
return f'<{self.__class__.__name__} {attrs}>'
|
||||
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__)
|
||||
return f"<{self.__class__.__name__} {attrs}>"
|
||||
|
||||
@classmethod
|
||||
def _raw_construct(cls: Type[C], **kwargs) -> C:
|
||||
@ -112,18 +123,20 @@ class ActionRow(Component):
|
||||
The children components that this holds, if any.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = ('children',)
|
||||
__slots__: Tuple[str, ...] = ("children",)
|
||||
|
||||
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
|
||||
|
||||
def __init__(self, data: ComponentPayload):
|
||||
self.type: ComponentType = try_enum(ComponentType, data['type'])
|
||||
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
|
||||
self.type: ComponentType = try_enum(ComponentType, data["type"])
|
||||
self.children: List[Component] = [
|
||||
_component_factory(d) for d in data.get("components", [])
|
||||
]
|
||||
|
||||
def to_dict(self) -> ActionRowPayload:
|
||||
return {
|
||||
'type': int(self.type),
|
||||
'components': [child.to_dict() for child in self.children],
|
||||
"type": int(self.type),
|
||||
"components": [child.to_dict() for child in self.children],
|
||||
} # type: ignore
|
||||
|
||||
|
||||
@ -157,44 +170,44 @@ class Button(Component):
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'style',
|
||||
'custom_id',
|
||||
'url',
|
||||
'disabled',
|
||||
'label',
|
||||
'emoji',
|
||||
"style",
|
||||
"custom_id",
|
||||
"url",
|
||||
"disabled",
|
||||
"label",
|
||||
"emoji",
|
||||
)
|
||||
|
||||
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
|
||||
|
||||
def __init__(self, data: ButtonComponentPayload):
|
||||
self.type: ComponentType = try_enum(ComponentType, data['type'])
|
||||
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
|
||||
self.custom_id: Optional[str] = data.get('custom_id')
|
||||
self.url: Optional[str] = data.get('url')
|
||||
self.disabled: bool = data.get('disabled', False)
|
||||
self.label: Optional[str] = data.get('label')
|
||||
self.type: ComponentType = try_enum(ComponentType, data["type"])
|
||||
self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
|
||||
self.custom_id: Optional[str] = data.get("custom_id")
|
||||
self.url: Optional[str] = data.get("url")
|
||||
self.disabled: bool = data.get("disabled", False)
|
||||
self.label: Optional[str] = data.get("label")
|
||||
self.emoji: Optional[PartialEmoji]
|
||||
try:
|
||||
self.emoji = PartialEmoji.from_dict(data['emoji'])
|
||||
self.emoji = PartialEmoji.from_dict(data["emoji"])
|
||||
except KeyError:
|
||||
self.emoji = None
|
||||
|
||||
def to_dict(self) -> ButtonComponentPayload:
|
||||
payload = {
|
||||
'type': 2,
|
||||
'style': int(self.style),
|
||||
'label': self.label,
|
||||
'disabled': self.disabled,
|
||||
"type": 2,
|
||||
"style": int(self.style),
|
||||
"label": self.label,
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
if self.custom_id:
|
||||
payload['custom_id'] = self.custom_id
|
||||
payload["custom_id"] = self.custom_id
|
||||
|
||||
if self.url:
|
||||
payload['url'] = self.url
|
||||
payload["url"] = self.url
|
||||
|
||||
if self.emoji:
|
||||
payload['emoji'] = self.emoji.to_dict()
|
||||
payload["emoji"] = self.emoji.to_dict()
|
||||
|
||||
return payload # type: ignore
|
||||
|
||||
@ -231,37 +244,39 @@ class SelectMenu(Component):
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'custom_id',
|
||||
'placeholder',
|
||||
'min_values',
|
||||
'max_values',
|
||||
'options',
|
||||
'disabled',
|
||||
"custom_id",
|
||||
"placeholder",
|
||||
"min_values",
|
||||
"max_values",
|
||||
"options",
|
||||
"disabled",
|
||||
)
|
||||
|
||||
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
|
||||
|
||||
def __init__(self, data: SelectMenuPayload):
|
||||
self.type = ComponentType.select
|
||||
self.custom_id: str = data['custom_id']
|
||||
self.placeholder: Optional[str] = data.get('placeholder')
|
||||
self.min_values: int = data.get('min_values', 1)
|
||||
self.max_values: int = data.get('max_values', 1)
|
||||
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
|
||||
self.disabled: bool = data.get('disabled', False)
|
||||
self.custom_id: str = data["custom_id"]
|
||||
self.placeholder: Optional[str] = data.get("placeholder")
|
||||
self.min_values: int = data.get("min_values", 1)
|
||||
self.max_values: int = data.get("max_values", 1)
|
||||
self.options: List[SelectOption] = [
|
||||
SelectOption.from_dict(option) for option in data.get("options", [])
|
||||
]
|
||||
self.disabled: bool = data.get("disabled", False)
|
||||
|
||||
def to_dict(self) -> SelectMenuPayload:
|
||||
payload: SelectMenuPayload = {
|
||||
'type': self.type.value,
|
||||
'custom_id': self.custom_id,
|
||||
'min_values': self.min_values,
|
||||
'max_values': self.max_values,
|
||||
'options': [op.to_dict() for op in self.options],
|
||||
'disabled': self.disabled,
|
||||
"type": self.type.value,
|
||||
"custom_id": self.custom_id,
|
||||
"min_values": self.min_values,
|
||||
"max_values": self.max_values,
|
||||
"options": [op.to_dict() for op in self.options],
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
|
||||
if self.placeholder:
|
||||
payload['placeholder'] = self.placeholder
|
||||
payload["placeholder"] = self.placeholder
|
||||
|
||||
return payload
|
||||
|
||||
@ -292,11 +307,11 @@ class SelectOption:
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'label',
|
||||
'value',
|
||||
'description',
|
||||
'emoji',
|
||||
'default',
|
||||
"label",
|
||||
"value",
|
||||
"description",
|
||||
"emoji",
|
||||
"default",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -318,60 +333,62 @@ class SelectOption:
|
||||
elif isinstance(emoji, _EmojiTag):
|
||||
emoji = emoji._to_partial()
|
||||
else:
|
||||
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
|
||||
raise TypeError(
|
||||
f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}"
|
||||
)
|
||||
|
||||
self.emoji = emoji
|
||||
self.default = default
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} '
|
||||
f'emoji={self.emoji!r} default={self.default!r}>'
|
||||
f"<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} "
|
||||
f"emoji={self.emoji!r} default={self.default!r}>"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.emoji:
|
||||
base = f'{self.emoji} {self.label}'
|
||||
base = f"{self.emoji} {self.label}"
|
||||
else:
|
||||
base = self.label
|
||||
|
||||
if self.description:
|
||||
return f'{base}\n{self.description}'
|
||||
return f"{base}\n{self.description}"
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
|
||||
try:
|
||||
emoji = PartialEmoji.from_dict(data['emoji'])
|
||||
emoji = PartialEmoji.from_dict(data["emoji"])
|
||||
except KeyError:
|
||||
emoji = None
|
||||
|
||||
return cls(
|
||||
label=data['label'],
|
||||
value=data['value'],
|
||||
description=data.get('description'),
|
||||
label=data["label"],
|
||||
value=data["value"],
|
||||
description=data.get("description"),
|
||||
emoji=emoji,
|
||||
default=data.get('default', False),
|
||||
default=data.get("default", False),
|
||||
)
|
||||
|
||||
def to_dict(self) -> SelectOptionPayload:
|
||||
payload: SelectOptionPayload = {
|
||||
'label': self.label,
|
||||
'value': self.value,
|
||||
'default': self.default,
|
||||
"label": self.label,
|
||||
"value": self.value,
|
||||
"default": self.default,
|
||||
}
|
||||
|
||||
if self.emoji:
|
||||
payload['emoji'] = self.emoji.to_dict() # type: ignore
|
||||
payload["emoji"] = self.emoji.to_dict() # type: ignore
|
||||
|
||||
if self.description:
|
||||
payload['description'] = self.description
|
||||
payload["description"] = self.description
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _component_factory(data: ComponentPayload) -> Component:
|
||||
component_type = data['type']
|
||||
component_type = data["type"]
|
||||
if component_type == 1:
|
||||
return ActionRow(data)
|
||||
elif component_type == 2:
|
||||
|
@ -32,11 +32,10 @@ if TYPE_CHECKING:
|
||||
|
||||
from types import TracebackType
|
||||
|
||||
TypingT = TypeVar('TypingT', bound='Typing')
|
||||
TypingT = TypeVar("TypingT", bound="Typing")
|
||||
|
||||
__all__ = ("Typing",)
|
||||
|
||||
__all__ = (
|
||||
'Typing',
|
||||
)
|
||||
|
||||
def _typing_done_callback(fut: asyncio.Future) -> None:
|
||||
# just retrieve any exception and call it a day
|
||||
@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Typing:
|
||||
def __init__(self, messageable: Messageable) -> None:
|
||||
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
|
||||
@ -67,7 +67,8 @@ class Typing:
|
||||
self.task.add_done_callback(_typing_done_callback)
|
||||
return self
|
||||
|
||||
def __exit__(self,
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
@ -79,7 +80,8 @@ class Typing:
|
||||
await channel._state.http.send_typing(channel.id)
|
||||
return self.__enter__()
|
||||
|
||||
async def __aexit__(self,
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
|
@ -25,14 +25,23 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Final,
|
||||
List,
|
||||
Mapping,
|
||||
Protocol,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from . import utils
|
||||
from .colour import Colour
|
||||
|
||||
__all__ = (
|
||||
'Embed',
|
||||
)
|
||||
__all__ = ("Embed",)
|
||||
|
||||
|
||||
class _EmptyEmbed:
|
||||
@ -40,7 +49,7 @@ class _EmptyEmbed:
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'Embed.Empty'
|
||||
return "Embed.Empty"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 0
|
||||
@ -57,51 +66,47 @@ class EmbedProxy:
|
||||
return len(self.__dict__)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
|
||||
return f'EmbedProxy({inner})'
|
||||
inner = ", ".join(
|
||||
(f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_"))
|
||||
)
|
||||
return f"EmbedProxy({inner})"
|
||||
|
||||
def __getattr__(self, attr: str) -> _EmptyEmbed:
|
||||
return EmptyEmbed
|
||||
|
||||
|
||||
E = TypeVar('E', bound='Embed')
|
||||
E = TypeVar("E", bound="Embed")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.types.embed import Embed as EmbedData, EmbedType
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
MaybeEmpty = Union[T, _EmptyEmbed]
|
||||
|
||||
|
||||
class _EmbedFooterProxy(Protocol):
|
||||
text: MaybeEmpty[str]
|
||||
icon_url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedFieldProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
value: MaybeEmpty[str]
|
||||
inline: bool
|
||||
|
||||
|
||||
class _EmbedMediaProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
proxy_url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedVideoProxy(Protocol):
|
||||
url: MaybeEmpty[str]
|
||||
height: MaybeEmpty[int]
|
||||
width: MaybeEmpty[int]
|
||||
|
||||
|
||||
class _EmbedProviderProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
|
||||
|
||||
class _EmbedAuthorProxy(Protocol):
|
||||
name: MaybeEmpty[str]
|
||||
url: MaybeEmpty[str]
|
||||
@ -163,33 +168,33 @@ class Embed:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'title',
|
||||
'url',
|
||||
'type',
|
||||
'_timestamp',
|
||||
'_colour',
|
||||
'_footer',
|
||||
'_image',
|
||||
'_thumbnail',
|
||||
'_video',
|
||||
'_provider',
|
||||
'_author',
|
||||
'_fields',
|
||||
'description',
|
||||
"title",
|
||||
"url",
|
||||
"type",
|
||||
"_timestamp",
|
||||
"_colour",
|
||||
"_footer",
|
||||
"_image",
|
||||
"_thumbnail",
|
||||
"_video",
|
||||
"_provider",
|
||||
"_author",
|
||||
"_fields",
|
||||
"description",
|
||||
)
|
||||
|
||||
Empty: Final = EmptyEmbed
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = 'rich',
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
self,
|
||||
*,
|
||||
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
|
||||
title: MaybeEmpty[Any] = EmptyEmbed,
|
||||
type: EmbedType = "rich",
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
description: MaybeEmpty[Any] = EmptyEmbed,
|
||||
timestamp: datetime.datetime = None,
|
||||
):
|
||||
|
||||
self.colour = colour if colour is not EmptyEmbed else color
|
||||
@ -231,10 +236,10 @@ class Embed:
|
||||
|
||||
# fill in the basic fields
|
||||
|
||||
self.title = data.get('title', EmptyEmbed)
|
||||
self.type = data.get('type', EmptyEmbed)
|
||||
self.description = data.get('description', EmptyEmbed)
|
||||
self.url = data.get('url', EmptyEmbed)
|
||||
self.title = data.get("title", EmptyEmbed)
|
||||
self.type = data.get("type", EmptyEmbed)
|
||||
self.description = data.get("description", EmptyEmbed)
|
||||
self.url = data.get("url", EmptyEmbed)
|
||||
|
||||
if self.title is not EmptyEmbed:
|
||||
self.title = str(self.title)
|
||||
@ -248,22 +253,30 @@ class Embed:
|
||||
# try to fill in the more rich fields
|
||||
|
||||
try:
|
||||
self._colour = Colour(value=data['color'])
|
||||
self._colour = Colour(value=data["color"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self._timestamp = utils.parse_time(data['timestamp'])
|
||||
self._timestamp = utils.parse_time(data["timestamp"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'):
|
||||
for attr in (
|
||||
"thumbnail",
|
||||
"video",
|
||||
"provider",
|
||||
"author",
|
||||
"fields",
|
||||
"image",
|
||||
"footer",
|
||||
):
|
||||
try:
|
||||
value = data[attr]
|
||||
except KeyError:
|
||||
continue
|
||||
else:
|
||||
setattr(self, '_' + attr, value)
|
||||
setattr(self, "_" + attr, value)
|
||||
|
||||
return self
|
||||
|
||||
@ -273,11 +286,11 @@ class Embed:
|
||||
|
||||
def __len__(self) -> int:
|
||||
total = len(self.title) + len(self.description)
|
||||
for field in getattr(self, '_fields', []):
|
||||
total += len(field['name']) + len(field['value'])
|
||||
for field in getattr(self, "_fields", []):
|
||||
total += len(field["name"]) + len(field["value"])
|
||||
|
||||
try:
|
||||
footer_text = self._footer['text']
|
||||
footer_text = self._footer["text"]
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
else:
|
||||
@ -288,7 +301,7 @@ class Embed:
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
total += len(author['name'])
|
||||
total += len(author["name"])
|
||||
|
||||
return total
|
||||
|
||||
@ -312,7 +325,7 @@ class Embed:
|
||||
|
||||
@property
|
||||
def colour(self) -> MaybeEmpty[Colour]:
|
||||
return getattr(self, '_colour', EmptyEmbed)
|
||||
return getattr(self, "_colour", EmptyEmbed)
|
||||
|
||||
@colour.setter
|
||||
def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore
|
||||
@ -321,13 +334,15 @@ class Embed:
|
||||
elif isinstance(value, int):
|
||||
self._colour = Colour(value=value)
|
||||
else:
|
||||
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.')
|
||||
raise TypeError(
|
||||
f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead."
|
||||
)
|
||||
|
||||
color = colour
|
||||
|
||||
@property
|
||||
def timestamp(self) -> MaybeEmpty[datetime.datetime]:
|
||||
return getattr(self, '_timestamp', EmptyEmbed)
|
||||
return getattr(self, "_timestamp", EmptyEmbed)
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
|
||||
@ -338,7 +353,9 @@ class Embed:
|
||||
elif isinstance(value, _EmptyEmbed):
|
||||
self._timestamp = value
|
||||
else:
|
||||
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead")
|
||||
raise TypeError(
|
||||
f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead"
|
||||
)
|
||||
|
||||
@property
|
||||
def footer(self) -> _EmbedFooterProxy:
|
||||
@ -348,9 +365,14 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_footer", {})) # type: ignore
|
||||
|
||||
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
|
||||
def set_footer(
|
||||
self: E,
|
||||
*,
|
||||
text: MaybeEmpty[Any] = EmptyEmbed,
|
||||
icon_url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
) -> E:
|
||||
"""Sets the footer for the embed content.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@ -366,10 +388,10 @@ class Embed:
|
||||
|
||||
self._footer = {}
|
||||
if text is not EmptyEmbed:
|
||||
self._footer['text'] = str(text)
|
||||
self._footer["text"] = str(text)
|
||||
|
||||
if icon_url is not EmptyEmbed:
|
||||
self._footer['icon_url'] = str(icon_url)
|
||||
self._footer["icon_url"] = str(icon_url)
|
||||
|
||||
return self
|
||||
|
||||
@ -401,12 +423,12 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_image", {})) # type: ignore
|
||||
|
||||
@image.setter
|
||||
def image(self: E, *, url: Any):
|
||||
self._image = {
|
||||
'url': str(url),
|
||||
"url": str(url),
|
||||
}
|
||||
|
||||
@image.deleter
|
||||
@ -451,15 +473,14 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore
|
||||
|
||||
@thumbnail.setter
|
||||
def thumbnail(self: E, *, url: Any):
|
||||
"""Sets the thumbnail for the embed content.
|
||||
"""
|
||||
"""Sets the thumbnail for the embed content."""
|
||||
|
||||
self._thumbnail = {
|
||||
'url': str(url),
|
||||
"url": str(url),
|
||||
}
|
||||
|
||||
return
|
||||
@ -504,7 +525,7 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_video", {})) # type: ignore
|
||||
|
||||
@property
|
||||
def provider(self) -> _EmbedProviderProxy:
|
||||
@ -514,7 +535,7 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_provider", {})) # type: ignore
|
||||
|
||||
@property
|
||||
def author(self) -> _EmbedAuthorProxy:
|
||||
@ -524,9 +545,15 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return EmbedProxy(getattr(self, '_author', {})) # type: ignore
|
||||
return EmbedProxy(getattr(self, "_author", {})) # type: ignore
|
||||
|
||||
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
|
||||
def set_author(
|
||||
self: E,
|
||||
*,
|
||||
name: Any,
|
||||
url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
icon_url: MaybeEmpty[Any] = EmptyEmbed,
|
||||
) -> E:
|
||||
"""Sets the author for the embed content.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@ -543,14 +570,14 @@ class Embed:
|
||||
"""
|
||||
|
||||
self._author = {
|
||||
'name': str(name),
|
||||
"name": str(name),
|
||||
}
|
||||
|
||||
if url is not EmptyEmbed:
|
||||
self._author['url'] = str(url)
|
||||
self._author["url"] = str(url)
|
||||
|
||||
if icon_url is not EmptyEmbed:
|
||||
self._author['icon_url'] = str(icon_url)
|
||||
self._author["icon_url"] = str(icon_url)
|
||||
|
||||
return self
|
||||
|
||||
@ -577,7 +604,7 @@ class Embed:
|
||||
|
||||
If the attribute has no value then :attr:`Empty` is returned.
|
||||
"""
|
||||
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore
|
||||
return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore
|
||||
|
||||
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E:
|
||||
"""Adds a field to the embed object.
|
||||
@ -596,9 +623,9 @@ class Embed:
|
||||
"""
|
||||
|
||||
field = {
|
||||
'inline': inline,
|
||||
'name': str(name),
|
||||
'value': str(value),
|
||||
"inline": inline,
|
||||
"name": str(name),
|
||||
"value": str(value),
|
||||
}
|
||||
|
||||
try:
|
||||
@ -608,7 +635,9 @@ class Embed:
|
||||
|
||||
return self
|
||||
|
||||
def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E:
|
||||
def insert_field_at(
|
||||
self: E, index: int, *, name: Any, value: Any, inline: bool = True
|
||||
) -> E:
|
||||
"""Inserts a field before a specified index to the embed.
|
||||
|
||||
This function returns the class instance to allow for fluent-style
|
||||
@ -629,9 +658,9 @@ class Embed:
|
||||
"""
|
||||
|
||||
field = {
|
||||
'inline': inline,
|
||||
'name': str(name),
|
||||
'value': str(value),
|
||||
"inline": inline,
|
||||
"name": str(name),
|
||||
"value": str(value),
|
||||
}
|
||||
|
||||
try:
|
||||
@ -669,7 +698,9 @@ class Embed:
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
|
||||
def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E:
|
||||
def set_field_at(
|
||||
self: E, index: int, *, name: Any, value: Any, inline: bool = True
|
||||
) -> E:
|
||||
"""Modifies a field to the embed object.
|
||||
|
||||
The index must point to a valid pre-existing field.
|
||||
@ -697,11 +728,11 @@ class Embed:
|
||||
try:
|
||||
field = self._fields[index]
|
||||
except (TypeError, IndexError, AttributeError):
|
||||
raise IndexError('field index out of range')
|
||||
raise IndexError("field index out of range")
|
||||
|
||||
field['name'] = str(name)
|
||||
field['value'] = str(value)
|
||||
field['inline'] = inline
|
||||
field["name"] = str(name)
|
||||
field["value"] = str(value)
|
||||
field["inline"] = inline
|
||||
return self
|
||||
|
||||
def to_dict(self) -> EmbedData:
|
||||
@ -719,35 +750,39 @@ class Embed:
|
||||
# deal with basic convenience wrappers
|
||||
|
||||
try:
|
||||
colour = result.pop('colour')
|
||||
colour = result.pop("colour")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if colour:
|
||||
result['color'] = colour.value
|
||||
result["color"] = colour.value
|
||||
|
||||
try:
|
||||
timestamp = result.pop('timestamp')
|
||||
timestamp = result.pop("timestamp")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if timestamp:
|
||||
if timestamp.tzinfo:
|
||||
result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
|
||||
result["timestamp"] = timestamp.astimezone(
|
||||
tz=datetime.timezone.utc
|
||||
).isoformat()
|
||||
else:
|
||||
result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
|
||||
result["timestamp"] = timestamp.replace(
|
||||
tzinfo=datetime.timezone.utc
|
||||
).isoformat()
|
||||
|
||||
# add in the non raw attribute ones
|
||||
if self.type:
|
||||
result['type'] = self.type
|
||||
result["type"] = self.type
|
||||
|
||||
if self.description:
|
||||
result['description'] = self.description
|
||||
result["description"] = self.description
|
||||
|
||||
if self.url:
|
||||
result['url'] = self.url
|
||||
result["url"] = self.url
|
||||
|
||||
if self.title:
|
||||
result['title'] = self.title
|
||||
result["title"] = self.title
|
||||
|
||||
return result # type: ignore
|
||||
|
@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING
|
||||
from .partial_emoji import _EmojiTag, PartialEmoji
|
||||
from .user import User
|
||||
|
||||
__all__ = (
|
||||
'Emoji',
|
||||
)
|
||||
__all__ = ("Emoji",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.emoji import Emoji as EmojiPayload
|
||||
@ -98,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'require_colons',
|
||||
'animated',
|
||||
'managed',
|
||||
'id',
|
||||
'name',
|
||||
'_roles',
|
||||
'guild_id',
|
||||
'_state',
|
||||
'user',
|
||||
'available',
|
||||
"require_colons",
|
||||
"animated",
|
||||
"managed",
|
||||
"id",
|
||||
"name",
|
||||
"_roles",
|
||||
"guild_id",
|
||||
"_state",
|
||||
"user",
|
||||
"available",
|
||||
)
|
||||
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
|
||||
@ -116,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, emoji: EmojiPayload):
|
||||
self.require_colons: bool = emoji.get('require_colons', False)
|
||||
self.managed: bool = emoji.get('managed', False)
|
||||
self.id: int = int(emoji['id']) # type: ignore
|
||||
self.name: str = emoji['name'] # type: ignore
|
||||
self.animated: bool = emoji.get('animated', False)
|
||||
self.available: bool = emoji.get('available', True)
|
||||
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', [])))
|
||||
user = emoji.get('user')
|
||||
self.require_colons: bool = emoji.get("require_colons", False)
|
||||
self.managed: bool = emoji.get("managed", False)
|
||||
self.id: int = int(emoji["id"]) # type: ignore
|
||||
self.name: str = emoji["name"] # type: ignore
|
||||
self.animated: bool = emoji.get("animated", False)
|
||||
self.available: bool = emoji.get("available", True)
|
||||
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", [])))
|
||||
user = emoji.get("user")
|
||||
self.user: Optional[User] = User(state=self._state, data=user) if user else None
|
||||
|
||||
def _to_partial(self) -> PartialEmoji:
|
||||
@ -131,21 +129,21 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, Any]]:
|
||||
for attr in self.__slots__:
|
||||
if attr[0] != '_':
|
||||
if attr[0] != "_":
|
||||
value = getattr(self, attr, None)
|
||||
if value is not None:
|
||||
yield (attr, value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.animated:
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
return f"<a:{self.name}:{self.id}>"
|
||||
return f"<:{self.name}:{self.id}>"
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
|
||||
return f"<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, _EmojiTag) and self.id == other.id
|
||||
@ -164,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
@property
|
||||
def url(self) -> str:
|
||||
""":class:`str`: Returns the URL of the emoji."""
|
||||
fmt = 'gif' if self.animated else 'png'
|
||||
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
|
||||
fmt = "gif" if self.animated else "png"
|
||||
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
|
||||
|
||||
@property
|
||||
def roles(self) -> List[Role]:
|
||||
@ -217,9 +215,17 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
An error occurred deleting the emoji.
|
||||
"""
|
||||
|
||||
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
|
||||
await self._state.http.delete_custom_emoji(
|
||||
self.guild.id, self.id, reason=reason
|
||||
)
|
||||
|
||||
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
name: str = MISSING,
|
||||
roles: List[Snowflake] = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> Emoji:
|
||||
r"""|coro|
|
||||
|
||||
Edits the custom emoji.
|
||||
@ -254,9 +260,11 @@ class Emoji(_EmojiTag, AssetMixin):
|
||||
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
payload["name"] = name
|
||||
if roles is not MISSING:
|
||||
payload['roles'] = [role.id for role in roles]
|
||||
payload["roles"] = [role.id for role in roles]
|
||||
|
||||
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
|
||||
data = await self._state.http.edit_custom_emoji(
|
||||
self.guild.id, self.id, payload=payload, reason=reason
|
||||
)
|
||||
return Emoji(guild=self.guild, data=data, state=self._state)
|
||||
|
197
discord/enums.py
197
discord/enums.py
@ -27,50 +27,65 @@ from collections import namedtuple
|
||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
__all__ = (
|
||||
'Enum',
|
||||
'ChannelType',
|
||||
'MessageType',
|
||||
'VoiceRegion',
|
||||
'SpeakingState',
|
||||
'VerificationLevel',
|
||||
'ContentFilter',
|
||||
'Status',
|
||||
'DefaultAvatar',
|
||||
'AuditLogAction',
|
||||
'AuditLogActionCategory',
|
||||
'UserFlags',
|
||||
'ActivityType',
|
||||
'NotificationLevel',
|
||||
'TeamMembershipState',
|
||||
'WebhookType',
|
||||
'ExpireBehaviour',
|
||||
'ExpireBehavior',
|
||||
'StickerType',
|
||||
'StickerFormatType',
|
||||
'InviteTarget',
|
||||
'VideoQualityMode',
|
||||
'ComponentType',
|
||||
'ButtonStyle',
|
||||
'StagePrivacyLevel',
|
||||
'InteractionType',
|
||||
'InteractionResponseType',
|
||||
'NSFWLevel',
|
||||
"Enum",
|
||||
"ChannelType",
|
||||
"MessageType",
|
||||
"VoiceRegion",
|
||||
"SpeakingState",
|
||||
"VerificationLevel",
|
||||
"ContentFilter",
|
||||
"Status",
|
||||
"DefaultAvatar",
|
||||
"AuditLogAction",
|
||||
"AuditLogActionCategory",
|
||||
"UserFlags",
|
||||
"ActivityType",
|
||||
"NotificationLevel",
|
||||
"TeamMembershipState",
|
||||
"WebhookType",
|
||||
"ExpireBehaviour",
|
||||
"ExpireBehavior",
|
||||
"StickerType",
|
||||
"StickerFormatType",
|
||||
"InviteTarget",
|
||||
"VideoQualityMode",
|
||||
"ComponentType",
|
||||
"ButtonStyle",
|
||||
"StagePrivacyLevel",
|
||||
"InteractionType",
|
||||
"InteractionResponseType",
|
||||
"NSFWLevel",
|
||||
)
|
||||
|
||||
|
||||
def _create_value_cls(name, comparable):
|
||||
cls = namedtuple('_EnumValue_' + name, 'name value')
|
||||
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
|
||||
cls.__str__ = lambda self: f'{name}.{self.name}'
|
||||
cls = namedtuple("_EnumValue_" + name, "name value")
|
||||
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
|
||||
cls.__str__ = lambda self: f"{name}.{self.name}"
|
||||
if comparable:
|
||||
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
|
||||
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
|
||||
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
|
||||
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
|
||||
cls.__le__ = (
|
||||
lambda self, other: isinstance(other, self.__class__)
|
||||
and self.value <= other.value
|
||||
)
|
||||
cls.__ge__ = (
|
||||
lambda self, other: isinstance(other, self.__class__)
|
||||
and self.value >= other.value
|
||||
)
|
||||
cls.__lt__ = (
|
||||
lambda self, other: isinstance(other, self.__class__)
|
||||
and self.value < other.value
|
||||
)
|
||||
cls.__gt__ = (
|
||||
lambda self, other: isinstance(other, self.__class__)
|
||||
and self.value > other.value
|
||||
)
|
||||
return cls
|
||||
|
||||
|
||||
def _is_descriptor(obj):
|
||||
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
|
||||
return (
|
||||
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
|
||||
)
|
||||
|
||||
|
||||
class EnumMeta(type):
|
||||
@ -88,7 +103,7 @@ class EnumMeta(type):
|
||||
value_cls = _create_value_cls(name, comparable)
|
||||
for key, value in list(attrs.items()):
|
||||
is_descriptor = _is_descriptor(value)
|
||||
if key[0] == '_' and not is_descriptor:
|
||||
if key[0] == "_" and not is_descriptor:
|
||||
continue
|
||||
|
||||
# Special case classmethod to just pass through
|
||||
@ -110,10 +125,10 @@ class EnumMeta(type):
|
||||
member_mapping[key] = new_value
|
||||
attrs[key] = new_value
|
||||
|
||||
attrs['_enum_value_map_'] = value_mapping
|
||||
attrs['_enum_member_map_'] = member_mapping
|
||||
attrs['_enum_member_names_'] = member_names
|
||||
attrs['_enum_value_cls_'] = value_cls
|
||||
attrs["_enum_value_map_"] = value_mapping
|
||||
attrs["_enum_member_map_"] = member_mapping
|
||||
attrs["_enum_member_names_"] = member_names
|
||||
attrs["_enum_value_cls_"] = value_cls
|
||||
actual_cls = super().__new__(cls, name, bases, attrs)
|
||||
value_cls._actual_enum_cls_ = actual_cls # type: ignore
|
||||
return actual_cls
|
||||
@ -122,13 +137,15 @@ class EnumMeta(type):
|
||||
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
|
||||
|
||||
def __reversed__(cls):
|
||||
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
|
||||
return (
|
||||
cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)
|
||||
)
|
||||
|
||||
def __len__(cls):
|
||||
return len(cls._enum_member_names_)
|
||||
|
||||
def __repr__(cls):
|
||||
return f'<enum {cls.__name__}>'
|
||||
return f"<enum {cls.__name__}>"
|
||||
|
||||
@property
|
||||
def __members__(cls):
|
||||
@ -144,10 +161,10 @@ class EnumMeta(type):
|
||||
return cls._enum_member_map_[key]
|
||||
|
||||
def __setattr__(cls, name, value):
|
||||
raise TypeError('Enums are immutable.')
|
||||
raise TypeError("Enums are immutable.")
|
||||
|
||||
def __delattr__(cls, attr):
|
||||
raise TypeError('Enums are immutable')
|
||||
raise TypeError("Enums are immutable")
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
# isinstance(x, Y)
|
||||
@ -215,29 +232,29 @@ class MessageType(Enum):
|
||||
|
||||
|
||||
class VoiceRegion(Enum):
|
||||
us_west = 'us-west'
|
||||
us_east = 'us-east'
|
||||
us_south = 'us-south'
|
||||
us_central = 'us-central'
|
||||
eu_west = 'eu-west'
|
||||
eu_central = 'eu-central'
|
||||
singapore = 'singapore'
|
||||
london = 'london'
|
||||
sydney = 'sydney'
|
||||
amsterdam = 'amsterdam'
|
||||
frankfurt = 'frankfurt'
|
||||
brazil = 'brazil'
|
||||
hongkong = 'hongkong'
|
||||
russia = 'russia'
|
||||
japan = 'japan'
|
||||
southafrica = 'southafrica'
|
||||
south_korea = 'south-korea'
|
||||
india = 'india'
|
||||
europe = 'europe'
|
||||
dubai = 'dubai'
|
||||
vip_us_east = 'vip-us-east'
|
||||
vip_us_west = 'vip-us-west'
|
||||
vip_amsterdam = 'vip-amsterdam'
|
||||
us_west = "us-west"
|
||||
us_east = "us-east"
|
||||
us_south = "us-south"
|
||||
us_central = "us-central"
|
||||
eu_west = "eu-west"
|
||||
eu_central = "eu-central"
|
||||
singapore = "singapore"
|
||||
london = "london"
|
||||
sydney = "sydney"
|
||||
amsterdam = "amsterdam"
|
||||
frankfurt = "frankfurt"
|
||||
brazil = "brazil"
|
||||
hongkong = "hongkong"
|
||||
russia = "russia"
|
||||
japan = "japan"
|
||||
southafrica = "southafrica"
|
||||
south_korea = "south-korea"
|
||||
india = "india"
|
||||
europe = "europe"
|
||||
dubai = "dubai"
|
||||
vip_us_east = "vip-us-east"
|
||||
vip_us_west = "vip-us-west"
|
||||
vip_amsterdam = "vip-amsterdam"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
@ -277,12 +294,12 @@ class ContentFilter(Enum, comparable=True):
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
online = 'online'
|
||||
offline = 'offline'
|
||||
idle = 'idle'
|
||||
dnd = 'dnd'
|
||||
do_not_disturb = 'dnd'
|
||||
invisible = 'invisible'
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
idle = "idle"
|
||||
dnd = "dnd"
|
||||
do_not_disturb = "dnd"
|
||||
invisible = "invisible"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
@ -415,33 +432,33 @@ class AuditLogAction(Enum):
|
||||
def target_type(self) -> Optional[str]:
|
||||
v = self.value
|
||||
if v == -1:
|
||||
return 'all'
|
||||
return "all"
|
||||
elif v < 10:
|
||||
return 'guild'
|
||||
return "guild"
|
||||
elif v < 20:
|
||||
return 'channel'
|
||||
return "channel"
|
||||
elif v < 30:
|
||||
return 'user'
|
||||
return "user"
|
||||
elif v < 40:
|
||||
return 'role'
|
||||
return "role"
|
||||
elif v < 50:
|
||||
return 'invite'
|
||||
return "invite"
|
||||
elif v < 60:
|
||||
return 'webhook'
|
||||
return "webhook"
|
||||
elif v < 70:
|
||||
return 'emoji'
|
||||
return "emoji"
|
||||
elif v == 73:
|
||||
return 'channel'
|
||||
return "channel"
|
||||
elif v < 80:
|
||||
return 'message'
|
||||
return "message"
|
||||
elif v < 83:
|
||||
return 'integration'
|
||||
return "integration"
|
||||
elif v < 90:
|
||||
return 'stage_instance'
|
||||
return "stage_instance"
|
||||
elif v < 93:
|
||||
return 'sticker'
|
||||
return "sticker"
|
||||
elif v < 113:
|
||||
return 'thread'
|
||||
return "thread"
|
||||
|
||||
|
||||
class UserFlags(Enum):
|
||||
@ -589,12 +606,12 @@ class NSFWLevel(Enum, comparable=True):
|
||||
age_restricted = 3
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def create_unknown_value(cls: Type[T], val: Any) -> T:
|
||||
value_cls = cls._enum_value_cls_ # type: ignore
|
||||
name = f'unknown_{val}'
|
||||
name = f"unknown_{val}"
|
||||
return value_cls(name=name, value=val)
|
||||
|
||||
|
||||
|
@ -38,20 +38,20 @@ if TYPE_CHECKING:
|
||||
from .interactions import Interaction
|
||||
|
||||
__all__ = (
|
||||
'DiscordException',
|
||||
'ClientException',
|
||||
'NoMoreItems',
|
||||
'GatewayNotFound',
|
||||
'HTTPException',
|
||||
'Forbidden',
|
||||
'NotFound',
|
||||
'DiscordServerError',
|
||||
'InvalidData',
|
||||
'InvalidArgument',
|
||||
'LoginFailure',
|
||||
'ConnectionClosed',
|
||||
'PrivilegedIntentsRequired',
|
||||
'InteractionResponded',
|
||||
"DiscordException",
|
||||
"ClientException",
|
||||
"NoMoreItems",
|
||||
"GatewayNotFound",
|
||||
"HTTPException",
|
||||
"Forbidden",
|
||||
"NotFound",
|
||||
"DiscordServerError",
|
||||
"InvalidData",
|
||||
"InvalidArgument",
|
||||
"LoginFailure",
|
||||
"ConnectionClosed",
|
||||
"PrivilegedIntentsRequired",
|
||||
"InteractionResponded",
|
||||
)
|
||||
|
||||
|
||||
@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException):
|
||||
"""An exception that is raised when the gateway for Discord could not be found"""
|
||||
|
||||
def __init__(self):
|
||||
message = 'The gateway to connect to discord was not found.'
|
||||
message = "The gateway to connect to discord was not found."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]:
|
||||
def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]:
|
||||
items: List[Tuple[str, str]] = []
|
||||
for k, v in d.items():
|
||||
new_key = key + '.' + k if key else k
|
||||
new_key = key + "." + k if key else k
|
||||
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
_errors: List[Dict[str, Any]] = v['_errors']
|
||||
_errors: List[Dict[str, Any]] = v["_errors"]
|
||||
except KeyError:
|
||||
items.extend(_flatten_error_dict(v, new_key).items())
|
||||
else:
|
||||
items.append((new_key, ' '.join(x.get('message', '') for x in _errors)))
|
||||
items.append((new_key, " ".join(x.get("message", "") for x in _errors)))
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
|
||||
@ -123,28 +123,30 @@ class HTTPException(DiscordException):
|
||||
The Discord specific error code for the failure.
|
||||
"""
|
||||
|
||||
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
|
||||
def __init__(
|
||||
self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]
|
||||
):
|
||||
self.response: _ResponseType = response
|
||||
self.status: int = response.status # type: ignore
|
||||
self.code: int
|
||||
self.text: str
|
||||
if isinstance(message, dict):
|
||||
self.code = message.get('code', 0)
|
||||
base = message.get('message', '')
|
||||
errors = message.get('errors')
|
||||
self.code = message.get("code", 0)
|
||||
base = message.get("message", "")
|
||||
errors = message.get("errors")
|
||||
if errors:
|
||||
errors = _flatten_error_dict(errors)
|
||||
helpful = '\n'.join('In %s: %s' % t for t in errors.items())
|
||||
self.text = base + '\n' + helpful
|
||||
helpful = "\n".join("In %s: %s" % t for t in errors.items())
|
||||
self.text = base + "\n" + helpful
|
||||
else:
|
||||
self.text = base
|
||||
else:
|
||||
self.text = message or ''
|
||||
self.text = message or ""
|
||||
self.code = 0
|
||||
|
||||
fmt = '{0.status} {0.reason} (error code: {1})'
|
||||
fmt = "{0.status} {0.reason} (error code: {1})"
|
||||
if len(self.text):
|
||||
fmt += ': {2}'
|
||||
fmt += ": {2}"
|
||||
|
||||
super().__init__(fmt.format(self.response, self.code, self.text))
|
||||
|
||||
@ -221,14 +223,20 @@ class ConnectionClosed(ClientException):
|
||||
The shard ID that got closed if applicable.
|
||||
"""
|
||||
|
||||
def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
socket: ClientWebSocketResponse,
|
||||
*,
|
||||
shard_id: Optional[int],
|
||||
code: Optional[int] = None,
|
||||
):
|
||||
# This exception is just the same exception except
|
||||
# reconfigured to subclass ClientException for users
|
||||
self.code: int = code or socket.close_code or -1
|
||||
# aiohttp doesn't seem to consistently provide close reason
|
||||
self.reason: str = ''
|
||||
self.reason: str = ""
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}')
|
||||
super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}")
|
||||
|
||||
|
||||
class PrivilegedIntentsRequired(ClientException):
|
||||
@ -250,10 +258,10 @@ class PrivilegedIntentsRequired(ClientException):
|
||||
def __init__(self, shard_id: Optional[int]):
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
msg = (
|
||||
'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the '
|
||||
'developer portal. It is recommended to go to https://discord.com/developers/applications/ '
|
||||
'and explicitly enable the privileged intents within your application\'s page. If this is not '
|
||||
'possible, then consider disabling the privileged intents instead.'
|
||||
"Shard ID %s is requesting privileged intents that have not been explicitly enabled in the "
|
||||
"developer portal. It is recommended to go to https://discord.com/developers/applications/ "
|
||||
"and explicitly enable the privileged intents within your application's page. If this is not "
|
||||
"possible, then consider disabling the privileged intents instead."
|
||||
)
|
||||
super().__init__(msg % shard_id)
|
||||
|
||||
@ -274,4 +282,4 @@ class InteractionResponded(ClientException):
|
||||
|
||||
def __init__(self, interaction: Interaction):
|
||||
self.interaction: Interaction = interaction
|
||||
super().__init__('This interaction has already been responded to before')
|
||||
super().__init__("This interaction has already been responded to before")
|
||||
|
@ -31,15 +31,23 @@ if TYPE_CHECKING:
|
||||
from .cog import Cog
|
||||
from .errors import CommandError
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
Coro = Coroutine[Any, Any, T]
|
||||
MaybeCoro = Union[T, Coro[T]]
|
||||
CoroFunc = Callable[..., Coro[Any]]
|
||||
|
||||
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
|
||||
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
|
||||
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
|
||||
Check = Union[
|
||||
Callable[["Cog", "Context[Any]"], MaybeCoro[bool]],
|
||||
Callable[["Context[Any]"], MaybeCoro[bool]],
|
||||
]
|
||||
Hook = Union[
|
||||
Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]
|
||||
]
|
||||
Error = Union[
|
||||
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]],
|
||||
Callable[["Context[Any]", "CommandError"], Coro[Any]],
|
||||
]
|
||||
|
||||
|
||||
# This is merely a tag type to avoid circular import issues.
|
||||
|
@ -33,7 +33,18 @@ import importlib.util
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import discord
|
||||
|
||||
@ -54,17 +65,18 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'when_mentioned',
|
||||
'when_mentioned_or',
|
||||
'Bot',
|
||||
'AutoShardedBot',
|
||||
"when_mentioned",
|
||||
"when_mentioned_or",
|
||||
"Bot",
|
||||
"AutoShardedBot",
|
||||
)
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
T = TypeVar('T')
|
||||
CFT = TypeVar('CFT', bound='CoroFunc')
|
||||
CXT = TypeVar('CXT', bound='Context')
|
||||
T = TypeVar("T")
|
||||
CFT = TypeVar("CFT", bound="CoroFunc")
|
||||
CXT = TypeVar("CXT", bound="Context")
|
||||
|
||||
|
||||
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
"""A callable that implements a command prefix equivalent to being mentioned.
|
||||
@ -72,9 +84,12 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
"""
|
||||
# bot.user will never be None when this is called
|
||||
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
|
||||
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
|
||||
|
||||
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
|
||||
def when_mentioned_or(
|
||||
*prefixes: str,
|
||||
) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
|
||||
"""A callable that implements when mentioned or other prefixes provided.
|
||||
|
||||
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
|
||||
@ -103,6 +118,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
|
||||
----------
|
||||
:func:`.when_mentioned`
|
||||
"""
|
||||
|
||||
def inner(bot, msg):
|
||||
r = list(prefixes)
|
||||
r = when_mentioned(bot, msg) + r
|
||||
@ -110,17 +126,29 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _is_submodule(parent: str, child: str) -> bool:
|
||||
return parent == child or child.startswith(parent + ".")
|
||||
|
||||
|
||||
class _DefaultRepr:
|
||||
def __repr__(self):
|
||||
return '<default-help-command>'
|
||||
return "<default-help-command>"
|
||||
|
||||
|
||||
_default = _DefaultRepr()
|
||||
|
||||
|
||||
class BotBase(GroupMixin):
|
||||
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
|
||||
def __init__(
|
||||
self,
|
||||
command_prefix,
|
||||
help_command=_default,
|
||||
description=None,
|
||||
*,
|
||||
intents: discord.Intents,
|
||||
**options,
|
||||
):
|
||||
super().__init__(**options, intents=intents)
|
||||
self.command_prefix = command_prefix
|
||||
self.extra_events: Dict[str, List[CoroFunc]] = {}
|
||||
@ -131,16 +159,20 @@ class BotBase(GroupMixin):
|
||||
self._before_invoke = None
|
||||
self._after_invoke = None
|
||||
self._help_command = None
|
||||
self.description = inspect.cleandoc(description) if description else ''
|
||||
self.owner_id = options.get('owner_id')
|
||||
self.owner_ids = options.get('owner_ids', set())
|
||||
self.strip_after_prefix = options.get('strip_after_prefix', False)
|
||||
self.description = inspect.cleandoc(description) if description else ""
|
||||
self.owner_id = options.get("owner_id")
|
||||
self.owner_ids = options.get("owner_ids", set())
|
||||
self.strip_after_prefix = options.get("strip_after_prefix", False)
|
||||
|
||||
if self.owner_id and self.owner_ids:
|
||||
raise TypeError('Both owner_id and owner_ids are set.')
|
||||
raise TypeError("Both owner_id and owner_ids are set.")
|
||||
|
||||
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
|
||||
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
|
||||
if self.owner_ids and not isinstance(
|
||||
self.owner_ids, collections.abc.Collection
|
||||
):
|
||||
raise TypeError(
|
||||
f"owner_ids must be a collection not {self.owner_ids.__class__!r}"
|
||||
)
|
||||
|
||||
if help_command is _default:
|
||||
self.help_command = DefaultHelpCommand()
|
||||
@ -152,7 +184,7 @@ class BotBase(GroupMixin):
|
||||
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
# super() will resolve to Client
|
||||
super().dispatch(event_name, *args, **kwargs) # type: ignore
|
||||
ev = 'on_' + event_name
|
||||
ev = "on_" + event_name
|
||||
for event in self.extra_events.get(ev, []):
|
||||
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
|
||||
|
||||
@ -172,7 +204,9 @@ class BotBase(GroupMixin):
|
||||
|
||||
await super().close() # type: ignore
|
||||
|
||||
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
|
||||
async def on_command_error(
|
||||
self, context: Context, exception: errors.CommandError
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
The default command error handler provided by the bot.
|
||||
@ -182,7 +216,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
This only fires if you do not specify any listeners for command error.
|
||||
"""
|
||||
if self.extra_events.get('on_command_error', None):
|
||||
if self.extra_events.get("on_command_error", None):
|
||||
return
|
||||
|
||||
command = context.command
|
||||
@ -193,8 +227,10 @@ class BotBase(GroupMixin):
|
||||
if cog and cog.has_error_handler():
|
||||
return
|
||||
|
||||
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
|
||||
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
|
||||
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
|
||||
traceback.print_exception(
|
||||
type(exception), exception, exception.__traceback__, file=sys.stderr
|
||||
)
|
||||
|
||||
# global check registration
|
||||
|
||||
@ -380,7 +416,7 @@ class BotBase(GroupMixin):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The pre-invoke hook must be a coroutine.')
|
||||
raise TypeError("The pre-invoke hook must be a coroutine.")
|
||||
|
||||
self._before_invoke = coro
|
||||
return coro
|
||||
@ -413,7 +449,7 @@ class BotBase(GroupMixin):
|
||||
The coroutine passed is not actually a coroutine.
|
||||
"""
|
||||
if not asyncio.iscoroutinefunction(coro):
|
||||
raise TypeError('The post-invoke hook must be a coroutine.')
|
||||
raise TypeError("The post-invoke hook must be a coroutine.")
|
||||
|
||||
self._after_invoke = coro
|
||||
return coro
|
||||
@ -445,7 +481,7 @@ class BotBase(GroupMixin):
|
||||
name = func.__name__ if name is MISSING else name
|
||||
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
raise TypeError('Listeners must be coroutines')
|
||||
raise TypeError("Listeners must be coroutines")
|
||||
|
||||
if name in self.extra_events:
|
||||
self.extra_events[name].append(func)
|
||||
@ -541,14 +577,14 @@ class BotBase(GroupMixin):
|
||||
"""
|
||||
|
||||
if not isinstance(cog, Cog):
|
||||
raise TypeError('cogs must derive from Cog')
|
||||
raise TypeError("cogs must derive from Cog")
|
||||
|
||||
cog_name = cog.__cog_name__
|
||||
existing = self.__cogs.get(cog_name)
|
||||
|
||||
if existing is not None:
|
||||
if not override:
|
||||
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
|
||||
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
|
||||
self.remove_cog(cog_name)
|
||||
|
||||
cog = cog._inject(self)
|
||||
@ -628,7 +664,9 @@ class BotBase(GroupMixin):
|
||||
for event_list in self.extra_events.copy().values():
|
||||
remove = []
|
||||
for index, event in enumerate(event_list):
|
||||
if event.__module__ is not None and _is_submodule(name, event.__module__):
|
||||
if event.__module__ is not None and _is_submodule(
|
||||
name, event.__module__
|
||||
):
|
||||
remove.append(index)
|
||||
|
||||
for index in reversed(remove):
|
||||
@ -636,7 +674,7 @@ class BotBase(GroupMixin):
|
||||
|
||||
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
|
||||
try:
|
||||
func = getattr(lib, 'teardown')
|
||||
func = getattr(lib, "teardown")
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
@ -652,7 +690,9 @@ class BotBase(GroupMixin):
|
||||
if _is_submodule(name, module):
|
||||
del sys.modules[module]
|
||||
|
||||
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
|
||||
def _load_from_module_spec(
|
||||
self, spec: importlib.machinery.ModuleSpec, key: str
|
||||
) -> None:
|
||||
# precondition: key not in self.__extensions
|
||||
lib = importlib.util.module_from_spec(spec)
|
||||
sys.modules[key] = lib
|
||||
@ -663,7 +703,7 @@ class BotBase(GroupMixin):
|
||||
raise errors.ExtensionFailed(key, e) from e
|
||||
|
||||
try:
|
||||
setup = getattr(lib, 'setup')
|
||||
setup = getattr(lib, "setup")
|
||||
except AttributeError:
|
||||
del sys.modules[key]
|
||||
raise errors.NoEntryPointError(key)
|
||||
@ -850,7 +890,7 @@ class BotBase(GroupMixin):
|
||||
def help_command(self, value: Optional[HelpCommand]) -> None:
|
||||
if value is not None:
|
||||
if not isinstance(value, HelpCommand):
|
||||
raise TypeError('help_command must be a subclass of HelpCommand')
|
||||
raise TypeError("help_command must be a subclass of HelpCommand")
|
||||
if self._help_command is not None:
|
||||
self._help_command._remove_from_bot(self)
|
||||
self._help_command = value
|
||||
@ -893,11 +933,15 @@ class BotBase(GroupMixin):
|
||||
if isinstance(ret, collections.abc.Iterable):
|
||||
raise
|
||||
|
||||
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
|
||||
f"returning either of these, not {ret.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"command_prefix must be plain string, iterable of strings, or callable "
|
||||
f"returning either of these, not {ret.__class__.__name__}"
|
||||
)
|
||||
|
||||
if not ret:
|
||||
raise ValueError("Iterable command_prefix must contain at least one prefix")
|
||||
raise ValueError(
|
||||
"Iterable command_prefix must contain at least one prefix"
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
@ -954,14 +998,18 @@ class BotBase(GroupMixin):
|
||||
|
||||
except TypeError:
|
||||
if not isinstance(prefix, list):
|
||||
raise TypeError("get_prefix must return either a string or a list of string, "
|
||||
f"not {prefix.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"get_prefix must return either a string or a list of string, "
|
||||
f"not {prefix.__class__.__name__}"
|
||||
)
|
||||
|
||||
# It's possible a bad command_prefix got us here.
|
||||
for value in prefix:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
|
||||
f"contain only strings, not {value.__class__.__name__}")
|
||||
raise TypeError(
|
||||
"Iterable command_prefix or list returned from get_prefix must "
|
||||
f"contain only strings, not {value.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Getting here shouldn't happen
|
||||
raise
|
||||
@ -988,19 +1036,19 @@ class BotBase(GroupMixin):
|
||||
The invocation context to invoke.
|
||||
"""
|
||||
if ctx.command is not None:
|
||||
self.dispatch('command', ctx)
|
||||
self.dispatch("command", ctx)
|
||||
try:
|
||||
if await self.can_run(ctx, call_once=True):
|
||||
await ctx.command.invoke(ctx)
|
||||
else:
|
||||
raise errors.CheckFailure('The global check once functions failed.')
|
||||
raise errors.CheckFailure("The global check once functions failed.")
|
||||
except errors.CommandError as exc:
|
||||
await ctx.command.dispatch_error(ctx, exc)
|
||||
else:
|
||||
self.dispatch('command_completion', ctx)
|
||||
self.dispatch("command_completion", ctx)
|
||||
elif ctx.invoked_with:
|
||||
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
|
||||
self.dispatch('command_error', ctx, exc)
|
||||
self.dispatch("command_error", ctx, exc)
|
||||
|
||||
async def process_commands(self, message: Message) -> None:
|
||||
"""|coro|
|
||||
@ -1033,6 +1081,7 @@ class BotBase(GroupMixin):
|
||||
async def on_message(self, message):
|
||||
await self.process_commands(message)
|
||||
|
||||
|
||||
class Bot(BotBase, discord.Client):
|
||||
"""Represents a discord bot.
|
||||
|
||||
@ -1103,10 +1152,13 @@ class Bot(BotBase, discord.Client):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AutoShardedBot(BotBase, discord.AutoShardedClient):
|
||||
"""This is similar to :class:`.Bot` except that it is inherited from
|
||||
:class:`discord.AutoShardedClient` instead.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
@ -26,7 +26,19 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import discord.utils
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Type,
|
||||
)
|
||||
|
||||
from ._types import _BaseCommand
|
||||
|
||||
@ -36,15 +48,16 @@ if TYPE_CHECKING:
|
||||
from .core import Command
|
||||
|
||||
__all__ = (
|
||||
'CogMeta',
|
||||
'Cog',
|
||||
"CogMeta",
|
||||
"Cog",
|
||||
)
|
||||
|
||||
CogT = TypeVar('CogT', bound='Cog')
|
||||
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
|
||||
CogT = TypeVar("CogT", bound="Cog")
|
||||
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
class CogMeta(type):
|
||||
"""A metaclass for defining a cog.
|
||||
|
||||
@ -104,6 +117,7 @@ class CogMeta(type):
|
||||
async def bar(self, ctx):
|
||||
pass # hidden -> False
|
||||
"""
|
||||
|
||||
__cog_name__: str
|
||||
__cog_settings__: Dict[str, Any]
|
||||
__cog_commands__: List[Command]
|
||||
@ -111,17 +125,17 @@ class CogMeta(type):
|
||||
|
||||
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
|
||||
name, bases, attrs = args
|
||||
attrs['__cog_name__'] = kwargs.pop('name', name)
|
||||
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
|
||||
attrs["__cog_name__"] = kwargs.pop("name", name)
|
||||
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
|
||||
|
||||
description = kwargs.pop('description', None)
|
||||
description = kwargs.pop("description", None)
|
||||
if description is None:
|
||||
description = inspect.cleandoc(attrs.get('__doc__', ''))
|
||||
attrs['__cog_description__'] = description
|
||||
description = inspect.cleandoc(attrs.get("__doc__", ""))
|
||||
attrs["__cog_description__"] = description
|
||||
|
||||
commands = {}
|
||||
listeners = {}
|
||||
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
|
||||
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
|
||||
|
||||
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
for base in reversed(new_cls.__mro__):
|
||||
@ -136,21 +150,25 @@ class CogMeta(type):
|
||||
value = value.__func__
|
||||
if isinstance(value, _BaseCommand):
|
||||
if is_static_method:
|
||||
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
|
||||
if elem.startswith(('cog_', 'bot_')):
|
||||
raise TypeError(
|
||||
f"Command in method {base}.{elem!r} must not be staticmethod."
|
||||
)
|
||||
if elem.startswith(("cog_", "bot_")):
|
||||
raise TypeError(no_bot_cog.format(base, elem))
|
||||
commands[elem] = value
|
||||
elif inspect.iscoroutinefunction(value):
|
||||
try:
|
||||
getattr(value, '__cog_listener__')
|
||||
getattr(value, "__cog_listener__")
|
||||
except AttributeError:
|
||||
continue
|
||||
else:
|
||||
if elem.startswith(('cog_', 'bot_')):
|
||||
if elem.startswith(("cog_", "bot_")):
|
||||
raise TypeError(no_bot_cog.format(base, elem))
|
||||
listeners[elem] = value
|
||||
|
||||
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
|
||||
new_cls.__cog_commands__ = list(
|
||||
commands.values()
|
||||
) # this will be copied in Cog.__new__
|
||||
|
||||
listeners_as_list = []
|
||||
for listener in listeners.values():
|
||||
@ -169,10 +187,12 @@ class CogMeta(type):
|
||||
def qualified_name(cls) -> str:
|
||||
return cls.__cog_name__
|
||||
|
||||
|
||||
def _cog_special_method(func: FuncT) -> FuncT:
|
||||
func.__cog_special_method__ = None
|
||||
return func
|
||||
|
||||
|
||||
class Cog(metaclass=CogMeta):
|
||||
"""The base class that all cogs must inherit from.
|
||||
|
||||
@ -183,6 +203,7 @@ class Cog(metaclass=CogMeta):
|
||||
When inheriting from this class, the options shown in :class:`CogMeta`
|
||||
are equally valid here.
|
||||
"""
|
||||
|
||||
__cog_name__: ClassVar[str]
|
||||
__cog_settings__: ClassVar[Dict[str, Any]]
|
||||
__cog_commands__: ClassVar[List[Command]]
|
||||
@ -199,10 +220,7 @@ class Cog(metaclass=CogMeta):
|
||||
# r.e type ignore, type-checker complains about overriding a ClassVar
|
||||
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
|
||||
|
||||
lookup = {
|
||||
cmd.qualified_name: cmd
|
||||
for cmd in self.__cog_commands__
|
||||
}
|
||||
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
|
||||
|
||||
# Update the Command instances dynamically as well
|
||||
for command in self.__cog_commands__:
|
||||
@ -255,6 +273,7 @@ class Cog(metaclass=CogMeta):
|
||||
A command or group from the cog.
|
||||
"""
|
||||
from .core import GroupMixin
|
||||
|
||||
for command in self.__cog_commands__:
|
||||
if command.parent is None:
|
||||
yield command
|
||||
@ -269,12 +288,15 @@ class Cog(metaclass=CogMeta):
|
||||
List[Tuple[:class:`str`, :ref:`coroutine <coroutine>`]]
|
||||
The listeners defined in this cog.
|
||||
"""
|
||||
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
|
||||
return [
|
||||
(name, getattr(self, method_name))
|
||||
for name, method_name in self.__cog_listeners__
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
|
||||
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
|
||||
return getattr(method.__func__, '__cog_special_method__', method)
|
||||
return getattr(method.__func__, "__cog_special_method__", method)
|
||||
|
||||
@classmethod
|
||||
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
|
||||
@ -296,14 +318,16 @@ class Cog(metaclass=CogMeta):
|
||||
"""
|
||||
|
||||
if name is not MISSING and not isinstance(name, str):
|
||||
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
|
||||
raise TypeError(
|
||||
f"Cog.listener expected str but received {name.__class__.__name__!r} instead."
|
||||
)
|
||||
|
||||
def decorator(func: FuncT) -> FuncT:
|
||||
actual = func
|
||||
if isinstance(actual, staticmethod):
|
||||
actual = actual.__func__
|
||||
if not inspect.iscoroutinefunction(actual):
|
||||
raise TypeError('Listener function must be a coroutine function.')
|
||||
raise TypeError("Listener function must be a coroutine function.")
|
||||
actual.__cog_listener__ = True
|
||||
to_assign = name or actual.__name__
|
||||
try:
|
||||
@ -315,6 +339,7 @@ class Cog(metaclass=CogMeta):
|
||||
# to pick it up but the metaclass unfurls the function and
|
||||
# thus the assignments need to be on the actual function
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def has_error_handler(self) -> bool:
|
||||
@ -322,7 +347,7 @@ class Cog(metaclass=CogMeta):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
|
||||
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
|
||||
|
||||
@_cog_special_method
|
||||
def cog_unload(self) -> None:
|
||||
|
@ -49,21 +49,19 @@ if TYPE_CHECKING:
|
||||
from .help import HelpCommand
|
||||
from .view import StringView
|
||||
|
||||
__all__ = (
|
||||
'Context',
|
||||
)
|
||||
__all__ = ("Context",)
|
||||
|
||||
MISSING: Any = discord.utils.MISSING
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar('CogT', bound="Cog")
|
||||
T = TypeVar("T")
|
||||
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
|
||||
CogT = TypeVar("CogT", bound="Cog")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
P = ParamSpec("P")
|
||||
else:
|
||||
P = TypeVar('P')
|
||||
P = TypeVar("P")
|
||||
|
||||
|
||||
class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
@ -122,7 +120,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
or invoked.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: Message,
|
||||
bot: BotT,
|
||||
@ -153,7 +152,9 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
self.current_parameter: Optional[inspect.Parameter] = current_parameter
|
||||
self._state: ConnectionState = self.message._state
|
||||
|
||||
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
async def invoke(
|
||||
self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs
|
||||
) -> T:
|
||||
r"""|coro|
|
||||
|
||||
Calls a command with the arguments given.
|
||||
@ -219,7 +220,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
cmd = self.command
|
||||
view = self.view
|
||||
if cmd is None:
|
||||
raise ValueError('This context is not valid.')
|
||||
raise ValueError("This context is not valid.")
|
||||
|
||||
# some state to revert to when we're done
|
||||
index, previous = view.index, view.previous
|
||||
@ -230,10 +231,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
|
||||
if restart:
|
||||
to_call = cmd.root_parent or cmd
|
||||
view.index = len(self.prefix or '')
|
||||
view.index = len(self.prefix or "")
|
||||
view.previous = 0
|
||||
self.invoked_parents = []
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
self.invoked_with = view.get_word() # advance to get the root command
|
||||
else:
|
||||
to_call = cmd
|
||||
|
||||
@ -263,7 +264,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if self.prefix is None:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
user = self.me
|
||||
# this breaks if the prefix mention is not the bot itself but I
|
||||
@ -271,7 +272,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
# for this common use case rather than waste performance for the
|
||||
# odd one.
|
||||
pattern = re.compile(r"<@!?%s>" % user.id)
|
||||
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
|
||||
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
|
||||
|
||||
@property
|
||||
def cog(self) -> Optional[Cog]:
|
||||
@ -381,7 +382,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
|
||||
await cmd.prepare_help_command(self, entity.qualified_name)
|
||||
|
||||
try:
|
||||
if hasattr(entity, '__cog_commands__'):
|
||||
if hasattr(entity, "__cog_commands__"):
|
||||
injected = wrap_callback(cmd.send_cog_help)
|
||||
return await injected(entity)
|
||||
elif isinstance(entity, Group):
|
||||
|
@ -52,32 +52,32 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Converter',
|
||||
'ObjectConverter',
|
||||
'MemberConverter',
|
||||
'UserConverter',
|
||||
'MessageConverter',
|
||||
'PartialMessageConverter',
|
||||
'TextChannelConverter',
|
||||
'InviteConverter',
|
||||
'GuildConverter',
|
||||
'RoleConverter',
|
||||
'GameConverter',
|
||||
'ColourConverter',
|
||||
'ColorConverter',
|
||||
'VoiceChannelConverter',
|
||||
'StageChannelConverter',
|
||||
'EmojiConverter',
|
||||
'PartialEmojiConverter',
|
||||
'CategoryChannelConverter',
|
||||
'IDConverter',
|
||||
'StoreChannelConverter',
|
||||
'ThreadConverter',
|
||||
'GuildChannelConverter',
|
||||
'GuildStickerConverter',
|
||||
'clean_content',
|
||||
'Greedy',
|
||||
'run_converters',
|
||||
"Converter",
|
||||
"ObjectConverter",
|
||||
"MemberConverter",
|
||||
"UserConverter",
|
||||
"MessageConverter",
|
||||
"PartialMessageConverter",
|
||||
"TextChannelConverter",
|
||||
"InviteConverter",
|
||||
"GuildConverter",
|
||||
"RoleConverter",
|
||||
"GameConverter",
|
||||
"ColourConverter",
|
||||
"ColorConverter",
|
||||
"VoiceChannelConverter",
|
||||
"StageChannelConverter",
|
||||
"EmojiConverter",
|
||||
"PartialEmojiConverter",
|
||||
"CategoryChannelConverter",
|
||||
"IDConverter",
|
||||
"StoreChannelConverter",
|
||||
"ThreadConverter",
|
||||
"GuildChannelConverter",
|
||||
"GuildStickerConverter",
|
||||
"clean_content",
|
||||
"Greedy",
|
||||
"run_converters",
|
||||
)
|
||||
|
||||
|
||||
@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument):
|
||||
|
||||
|
||||
_utils_get = discord.utils.get
|
||||
T = TypeVar('T')
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
||||
TT = TypeVar('TT', bound=discord.Thread)
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
|
||||
TT = TypeVar("TT", bound=discord.Thread)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@ -132,10 +132,10 @@ class Converter(Protocol[T_co]):
|
||||
:exc:`.BadArgument`
|
||||
The converter failed to convert the argument.
|
||||
"""
|
||||
raise NotImplementedError('Derived classes need to implement this.')
|
||||
raise NotImplementedError("Derived classes need to implement this.")
|
||||
|
||||
|
||||
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
|
||||
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
|
||||
|
||||
|
||||
class IDConverter(Converter[T_co]):
|
||||
@ -158,7 +158,9 @@ class ObjectConverter(IDConverter[discord.Object]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Object:
|
||||
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(
|
||||
r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument
|
||||
)
|
||||
|
||||
if match is None:
|
||||
raise ObjectNotFound(argument)
|
||||
@ -192,13 +194,17 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
|
||||
async def query_member_named(self, guild, argument):
|
||||
cache = guild._state.member_cache_flags.joined
|
||||
if len(argument) > 5 and argument[-5] == '#':
|
||||
username, _, discriminator = argument.rpartition('#')
|
||||
if len(argument) > 5 and argument[-5] == "#":
|
||||
username, _, discriminator = argument.rpartition("#")
|
||||
members = await guild.query_members(username, limit=100, cache=cache)
|
||||
return discord.utils.get(members, name=username, discriminator=discriminator)
|
||||
return discord.utils.get(
|
||||
members, name=username, discriminator=discriminator
|
||||
)
|
||||
else:
|
||||
members = await guild.query_members(argument, limit=100, cache=cache)
|
||||
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
|
||||
return discord.utils.find(
|
||||
lambda m: m.name == argument or m.nick == argument, members
|
||||
)
|
||||
|
||||
async def query_member_by_id(self, bot, guild, user_id):
|
||||
ws = bot._get_websocket(shard_id=guild.shard_id)
|
||||
@ -223,7 +229,9 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Member:
|
||||
bot = ctx.bot
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(
|
||||
r"<@!?([0-9]{15,20})>$", argument
|
||||
)
|
||||
guild = ctx.guild
|
||||
result = None
|
||||
user_id = None
|
||||
@ -232,13 +240,15 @@ class MemberConverter(IDConverter[discord.Member]):
|
||||
if guild:
|
||||
result = guild.get_member_named(argument)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_member_named', argument)
|
||||
result = _get_from_guilds(bot, "get_member_named", argument)
|
||||
else:
|
||||
user_id = int(match.group(1))
|
||||
if guild:
|
||||
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
|
||||
result = guild.get_member(user_id) or _utils_get(
|
||||
ctx.message.mentions, id=user_id
|
||||
)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_member', user_id)
|
||||
result = _get_from_guilds(bot, "get_member", user_id)
|
||||
|
||||
if result is None:
|
||||
if guild is None:
|
||||
@ -276,13 +286,17 @@ class UserConverter(IDConverter[discord.User]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.User:
|
||||
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(
|
||||
r"<@!?([0-9]{15,20})>$", argument
|
||||
)
|
||||
result = None
|
||||
state = ctx._state
|
||||
|
||||
if match is not None:
|
||||
user_id = int(match.group(1))
|
||||
result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id)
|
||||
result = ctx.bot.get_user(user_id) or _utils_get(
|
||||
ctx.message.mentions, id=user_id
|
||||
)
|
||||
if result is None:
|
||||
try:
|
||||
result = await ctx.bot.fetch_user(user_id)
|
||||
@ -294,12 +308,12 @@ class UserConverter(IDConverter[discord.User]):
|
||||
arg = argument
|
||||
|
||||
# Remove the '@' character if this is the first character from the argument
|
||||
if arg[0] == '@':
|
||||
if arg[0] == "@":
|
||||
# Remove first character
|
||||
arg = arg[1:]
|
||||
|
||||
# check for discriminator if it exists,
|
||||
if len(arg) > 5 and arg[-5] == '#':
|
||||
if len(arg) > 5 and arg[-5] == "#":
|
||||
discrim = arg[-4:]
|
||||
name = arg[:-5]
|
||||
predicate = lambda u: u.name == name and u.discriminator == discrim
|
||||
@ -330,29 +344,33 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
|
||||
|
||||
@staticmethod
|
||||
def _get_id_matches(ctx, argument):
|
||||
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
|
||||
id_regex = re.compile(
|
||||
r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$"
|
||||
)
|
||||
link_regex = re.compile(
|
||||
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
|
||||
r'(?P<guild_id>[0-9]{15,20}|@me)'
|
||||
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
|
||||
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
|
||||
r"(?P<guild_id>[0-9]{15,20}|@me)"
|
||||
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
|
||||
)
|
||||
match = id_regex.match(argument) or link_regex.match(argument)
|
||||
if not match:
|
||||
raise MessageNotFound(argument)
|
||||
data = match.groupdict()
|
||||
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
|
||||
message_id = int(data['message_id'])
|
||||
guild_id = data.get('guild_id')
|
||||
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
|
||||
message_id = int(data["message_id"])
|
||||
guild_id = data.get("guild_id")
|
||||
if guild_id is None:
|
||||
guild_id = ctx.guild and ctx.guild.id
|
||||
elif guild_id == '@me':
|
||||
elif guild_id == "@me":
|
||||
guild_id = None
|
||||
else:
|
||||
guild_id = int(guild_id)
|
||||
return guild_id, message_id, channel_id
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
|
||||
def _resolve_channel(
|
||||
ctx, guild_id, channel_id
|
||||
) -> Optional[PartialMessageableChannel]:
|
||||
if guild_id is not None:
|
||||
guild = ctx.bot.get_guild(guild_id)
|
||||
if guild is not None and channel_id is not None:
|
||||
@ -386,7 +404,9 @@ class MessageConverter(IDConverter[discord.Message]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Message:
|
||||
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
|
||||
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(
|
||||
ctx, argument
|
||||
)
|
||||
message = ctx.bot._connection._get_message(message_id)
|
||||
if message:
|
||||
return message
|
||||
@ -417,13 +437,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
|
||||
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
|
||||
return self._resolve_channel(
|
||||
ctx, argument, "channels", discord.abc.GuildChannel
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
|
||||
def _resolve_channel(
|
||||
ctx: Context, argument: str, attribute: str, type: Type[CT]
|
||||
) -> CT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
match = IDConverter._get_id_match(argument) or re.match(
|
||||
r"<#([0-9]{15,20})>$", argument
|
||||
)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
@ -443,7 +469,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
if guild:
|
||||
result = guild.get_channel(channel_id)
|
||||
else:
|
||||
result = _get_from_guilds(bot, 'get_channel', channel_id)
|
||||
result = _get_from_guilds(bot, "get_channel", channel_id)
|
||||
|
||||
if not isinstance(result, type):
|
||||
raise ChannelNotFound(argument)
|
||||
@ -451,10 +477,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
|
||||
def _resolve_thread(
|
||||
ctx: Context, argument: str, attribute: str, type: Type[TT]
|
||||
) -> TT:
|
||||
bot = ctx.bot
|
||||
|
||||
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
|
||||
match = IDConverter._get_id_match(argument) or re.match(
|
||||
r"<#([0-9]{15,20})>$", argument
|
||||
)
|
||||
result = None
|
||||
guild = ctx.guild
|
||||
|
||||
@ -491,7 +521,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
|
||||
return GuildChannelConverter._resolve_channel(
|
||||
ctx, argument, "text_channels", discord.TextChannel
|
||||
)
|
||||
|
||||
|
||||
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
@ -511,7 +543,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
|
||||
return GuildChannelConverter._resolve_channel(
|
||||
ctx, argument, "voice_channels", discord.VoiceChannel
|
||||
)
|
||||
|
||||
|
||||
class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
@ -530,7 +564,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
|
||||
return GuildChannelConverter._resolve_channel(
|
||||
ctx, argument, "stage_channels", discord.StageChannel
|
||||
)
|
||||
|
||||
|
||||
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
@ -550,7 +586,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
|
||||
return GuildChannelConverter._resolve_channel(
|
||||
ctx, argument, "categories", discord.CategoryChannel
|
||||
)
|
||||
|
||||
|
||||
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
@ -569,7 +607,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
|
||||
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
|
||||
return GuildChannelConverter._resolve_channel(
|
||||
ctx, argument, "channels", discord.StoreChannel
|
||||
)
|
||||
|
||||
|
||||
class ThreadConverter(IDConverter[discord.Thread]):
|
||||
@ -587,7 +627,9 @@ class ThreadConverter(IDConverter[discord.Thread]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
|
||||
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
|
||||
return GuildChannelConverter._resolve_thread(
|
||||
ctx, argument, "threads", discord.Thread
|
||||
)
|
||||
|
||||
|
||||
class ColourConverter(Converter[discord.Colour]):
|
||||
@ -616,10 +658,12 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
Added support for ``rgb`` function and 3-digit hex shortcuts
|
||||
"""
|
||||
|
||||
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
|
||||
RGB_REGEX = re.compile(
|
||||
r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)"
|
||||
)
|
||||
|
||||
def parse_hex_number(self, argument):
|
||||
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
|
||||
try:
|
||||
value = int(arg, base=16)
|
||||
if not (0 <= value <= 0xFFFFFF):
|
||||
@ -630,7 +674,7 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
return discord.Color(value=value)
|
||||
|
||||
def parse_rgb_number(self, argument, number):
|
||||
if number[-1] == '%':
|
||||
if number[-1] == "%":
|
||||
value = int(number[:-1])
|
||||
if not (0 <= value <= 100):
|
||||
raise BadColourArgument(argument)
|
||||
@ -646,29 +690,29 @@ class ColourConverter(Converter[discord.Colour]):
|
||||
if match is None:
|
||||
raise BadColourArgument(argument)
|
||||
|
||||
red = self.parse_rgb_number(argument, match.group('r'))
|
||||
green = self.parse_rgb_number(argument, match.group('g'))
|
||||
blue = self.parse_rgb_number(argument, match.group('b'))
|
||||
red = self.parse_rgb_number(argument, match.group("r"))
|
||||
green = self.parse_rgb_number(argument, match.group("g"))
|
||||
blue = self.parse_rgb_number(argument, match.group("b"))
|
||||
return discord.Color.from_rgb(red, green, blue)
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
|
||||
if argument[0] == '#':
|
||||
if argument[0] == "#":
|
||||
return self.parse_hex_number(argument[1:])
|
||||
|
||||
if argument[0:2] == '0x':
|
||||
if argument[0:2] == "0x":
|
||||
rest = argument[2:]
|
||||
# Legacy backwards compatible syntax
|
||||
if rest.startswith('#'):
|
||||
if rest.startswith("#"):
|
||||
return self.parse_hex_number(rest[1:])
|
||||
return self.parse_hex_number(rest)
|
||||
|
||||
arg = argument.lower()
|
||||
if arg[0:3] == 'rgb':
|
||||
if arg[0:3] == "rgb":
|
||||
return self.parse_rgb(arg)
|
||||
|
||||
arg = arg.replace(' ', '_')
|
||||
arg = arg.replace(" ", "_")
|
||||
method = getattr(discord.Colour, arg, None)
|
||||
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
|
||||
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
|
||||
raise BadColourArgument(arg)
|
||||
return method()
|
||||
|
||||
@ -697,7 +741,9 @@ class RoleConverter(IDConverter[discord.Role]):
|
||||
if not guild:
|
||||
raise NoPrivateMessage()
|
||||
|
||||
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(
|
||||
r"<@&([0-9]{15,20})>$", argument
|
||||
)
|
||||
if match:
|
||||
result = guild.get_role(int(match.group(1)))
|
||||
else:
|
||||
@ -776,7 +822,9 @@ class EmojiConverter(IDConverter[discord.Emoji]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
|
||||
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
|
||||
match = self._get_id_match(argument) or re.match(
|
||||
r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument
|
||||
)
|
||||
result = None
|
||||
bot = ctx.bot
|
||||
guild = ctx.guild
|
||||
@ -810,7 +858,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
"""
|
||||
|
||||
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
|
||||
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
|
||||
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
|
||||
|
||||
if match:
|
||||
emoji_animated = bool(match.group(1))
|
||||
@ -818,7 +866,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
|
||||
emoji_id = int(match.group(3))
|
||||
|
||||
return discord.PartialEmoji.with_state(
|
||||
ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id
|
||||
ctx.bot._connection,
|
||||
animated=emoji_animated,
|
||||
name=emoji_name,
|
||||
id=emoji_id,
|
||||
)
|
||||
|
||||
raise PartialEmojiConversionFailure(argument)
|
||||
@ -903,37 +954,41 @@ class clean_content(Converter[str]):
|
||||
|
||||
def resolve_member(id: int) -> str:
|
||||
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
|
||||
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
|
||||
return (
|
||||
f"@{m.display_name if self.use_nicknames else m.name}"
|
||||
if m
|
||||
else "@deleted-user"
|
||||
)
|
||||
|
||||
def resolve_role(id: int) -> str:
|
||||
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
|
||||
return f'@{r.name}' if r else '@deleted-role'
|
||||
return f"@{r.name}" if r else "@deleted-role"
|
||||
|
||||
else:
|
||||
|
||||
def resolve_member(id: int) -> str:
|
||||
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
|
||||
return f'@{m.name}' if m else '@deleted-user'
|
||||
return f"@{m.name}" if m else "@deleted-user"
|
||||
|
||||
def resolve_role(id: int) -> str:
|
||||
return '@deleted-role'
|
||||
return "@deleted-role"
|
||||
|
||||
if self.fix_channel_mentions and ctx.guild:
|
||||
|
||||
def resolve_channel(id: int) -> str:
|
||||
c = ctx.guild.get_channel(id)
|
||||
return f'#{c.name}' if c else '#deleted-channel'
|
||||
return f"#{c.name}" if c else "#deleted-channel"
|
||||
|
||||
else:
|
||||
|
||||
def resolve_channel(id: int) -> str:
|
||||
return f'<#{id}>'
|
||||
return f"<#{id}>"
|
||||
|
||||
transforms = {
|
||||
'@': resolve_member,
|
||||
'@!': resolve_member,
|
||||
'#': resolve_channel,
|
||||
'@&': resolve_role,
|
||||
"@": resolve_member,
|
||||
"@!": resolve_member,
|
||||
"#": resolve_channel,
|
||||
"@&": resolve_role,
|
||||
}
|
||||
|
||||
def repl(match: re.Match) -> str:
|
||||
@ -942,7 +997,7 @@ class clean_content(Converter[str]):
|
||||
transformed = transforms[type](id)
|
||||
return transformed
|
||||
|
||||
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
|
||||
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
|
||||
if self.escape_markdown:
|
||||
result = discord.utils.escape_markdown(result)
|
||||
elif self.remove_markdown:
|
||||
@ -974,42 +1029,46 @@ class Greedy(List[T]):
|
||||
For more information, check :ref:`ext_commands_special_converters`.
|
||||
"""
|
||||
|
||||
__slots__ = ('converter',)
|
||||
__slots__ = ("converter",)
|
||||
|
||||
def __init__(self, *, converter: T):
|
||||
self.converter = converter
|
||||
|
||||
def __repr__(self):
|
||||
converter = getattr(self.converter, '__name__', repr(self.converter))
|
||||
return f'Greedy[{converter}]'
|
||||
converter = getattr(self.converter, "__name__", repr(self.converter))
|
||||
return f"Greedy[{converter}]"
|
||||
|
||||
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
|
||||
if not isinstance(params, tuple):
|
||||
params = (params,)
|
||||
if len(params) != 1:
|
||||
raise TypeError('Greedy[...] only takes a single argument')
|
||||
raise TypeError("Greedy[...] only takes a single argument")
|
||||
converter = params[0]
|
||||
|
||||
origin = getattr(converter, '__origin__', None)
|
||||
args = getattr(converter, '__args__', ())
|
||||
origin = getattr(converter, "__origin__", None)
|
||||
args = getattr(converter, "__args__", ())
|
||||
|
||||
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
|
||||
raise TypeError('Greedy[...] expects a type or a Converter instance.')
|
||||
if not (
|
||||
callable(converter)
|
||||
or isinstance(converter, Converter)
|
||||
or origin is not None
|
||||
):
|
||||
raise TypeError("Greedy[...] expects a type or a Converter instance.")
|
||||
|
||||
if converter in (str, type(None)) or origin is Greedy:
|
||||
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
|
||||
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
|
||||
|
||||
if origin is Union and type(None) in args:
|
||||
raise TypeError(f'Greedy[{converter!r}] is invalid.')
|
||||
raise TypeError(f"Greedy[{converter!r}] is invalid.")
|
||||
|
||||
return cls(converter=converter)
|
||||
|
||||
|
||||
def _convert_to_bool(argument: str) -> bool:
|
||||
lowered = argument.lower()
|
||||
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
|
||||
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
|
||||
return True
|
||||
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
|
||||
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
|
||||
return False
|
||||
else:
|
||||
raise BadBoolArgument(lowered)
|
||||
@ -1056,7 +1115,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
|
||||
}
|
||||
|
||||
|
||||
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def _actual_conversion(
|
||||
ctx: Context, converter, argument: str, param: inspect.Parameter
|
||||
):
|
||||
if converter is bool:
|
||||
return _convert_to_bool(argument)
|
||||
|
||||
@ -1065,7 +1126,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
|
||||
if module is not None and (
|
||||
module.startswith("discord.") and not module.endswith("converter")
|
||||
):
|
||||
converter = CONVERTER_MAPPING.get(converter, converter)
|
||||
|
||||
try:
|
||||
@ -1091,10 +1154,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
|
||||
except AttributeError:
|
||||
name = converter.__class__.__name__
|
||||
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
raise BadArgument(
|
||||
f'Converting to "{name}" failed for parameter "{param.name}".'
|
||||
) from exc
|
||||
|
||||
|
||||
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
|
||||
async def run_converters(
|
||||
ctx: Context, converter, argument: str, param: inspect.Parameter
|
||||
):
|
||||
"""|coro|
|
||||
|
||||
Runs converters for a given converter, argument, and parameter.
|
||||
@ -1124,7 +1191,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
|
||||
Any
|
||||
The resulting conversion.
|
||||
"""
|
||||
origin = getattr(converter, '__origin__', None)
|
||||
origin = getattr(converter, "__origin__", None)
|
||||
|
||||
if origin is Union:
|
||||
errors = []
|
||||
|
@ -38,24 +38,25 @@ if TYPE_CHECKING:
|
||||
from ...message import Message
|
||||
|
||||
__all__ = (
|
||||
'BucketType',
|
||||
'Cooldown',
|
||||
'CooldownMapping',
|
||||
'DynamicCooldownMapping',
|
||||
'MaxConcurrency',
|
||||
"BucketType",
|
||||
"Cooldown",
|
||||
"CooldownMapping",
|
||||
"DynamicCooldownMapping",
|
||||
"MaxConcurrency",
|
||||
)
|
||||
|
||||
C = TypeVar('C', bound='CooldownMapping')
|
||||
MC = TypeVar('MC', bound='MaxConcurrency')
|
||||
C = TypeVar("C", bound="CooldownMapping")
|
||||
MC = TypeVar("MC", bound="MaxConcurrency")
|
||||
|
||||
|
||||
class BucketType(Enum):
|
||||
default = 0
|
||||
user = 1
|
||||
guild = 2
|
||||
channel = 3
|
||||
member = 4
|
||||
default = 0
|
||||
user = 1
|
||||
guild = 2
|
||||
channel = 3
|
||||
member = 4
|
||||
category = 5
|
||||
role = 6
|
||||
role = 6
|
||||
|
||||
def get_key(self, msg: Message) -> Any:
|
||||
if self is BucketType.user:
|
||||
@ -90,7 +91,7 @@ class Cooldown:
|
||||
The length of the cooldown period in seconds.
|
||||
"""
|
||||
|
||||
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
|
||||
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
|
||||
|
||||
def __init__(self, rate: float, per: float) -> None:
|
||||
self.rate: int = int(rate)
|
||||
@ -190,7 +191,8 @@ class Cooldown:
|
||||
return Cooldown(self.rate, self.per)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
|
||||
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
|
||||
|
||||
|
||||
class CooldownMapping:
|
||||
def __init__(
|
||||
@ -199,7 +201,7 @@ class CooldownMapping:
|
||||
type: Callable[[Message], Any],
|
||||
) -> None:
|
||||
if not callable(type):
|
||||
raise TypeError('Cooldown type must be a BucketType or callable')
|
||||
raise TypeError("Cooldown type must be a BucketType or callable")
|
||||
|
||||
self._cache: Dict[Any, Cooldown] = {}
|
||||
self._cooldown: Optional[Cooldown] = original
|
||||
@ -252,16 +254,16 @@ class CooldownMapping:
|
||||
|
||||
return bucket
|
||||
|
||||
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
|
||||
def update_rate_limit(
|
||||
self, message: Message, current: Optional[float] = None
|
||||
) -> Optional[float]:
|
||||
bucket = self.get_bucket(message, current)
|
||||
return bucket.update_rate_limit(current)
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
|
||||
class DynamicCooldownMapping(CooldownMapping):
|
||||
def __init__(
|
||||
self,
|
||||
factory: Callable[[Message], Cooldown],
|
||||
type: Callable[[Message], Any]
|
||||
self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]
|
||||
) -> None:
|
||||
super().__init__(None, type)
|
||||
self._factory: Callable[[Message], Cooldown] = factory
|
||||
@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping):
|
||||
def create_bucket(self, message: Message) -> Cooldown:
|
||||
return self._factory(message)
|
||||
|
||||
|
||||
class _Semaphore:
|
||||
"""This class is a version of a semaphore.
|
||||
|
||||
@ -291,7 +294,7 @@ class _Semaphore:
|
||||
overkill for what is basically a counter.
|
||||
"""
|
||||
|
||||
__slots__ = ('value', 'loop', '_waiters')
|
||||
__slots__ = ("value", "loop", "_waiters")
|
||||
|
||||
def __init__(self, number: int) -> None:
|
||||
self.value: int = number
|
||||
@ -299,7 +302,7 @@ class _Semaphore:
|
||||
self._waiters: Deque[asyncio.Future] = deque()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
|
||||
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
|
||||
|
||||
def locked(self) -> bool:
|
||||
return self.value == 0
|
||||
@ -337,8 +340,9 @@ class _Semaphore:
|
||||
self.value += 1
|
||||
self.wake_up()
|
||||
|
||||
|
||||
class MaxConcurrency:
|
||||
__slots__ = ('number', 'per', 'wait', '_mapping')
|
||||
__slots__ = ("number", "per", "wait", "_mapping")
|
||||
|
||||
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
|
||||
self._mapping: Dict[Any, _Semaphore] = {}
|
||||
@ -347,16 +351,20 @@ class MaxConcurrency:
|
||||
self.wait: bool = wait
|
||||
|
||||
if number <= 0:
|
||||
raise ValueError('max_concurrency \'number\' cannot be less than 1')
|
||||
raise ValueError("max_concurrency 'number' cannot be less than 1")
|
||||
|
||||
if not isinstance(per, BucketType):
|
||||
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
|
||||
raise TypeError(
|
||||
f"max_concurrency 'per' must be of type BucketType not {type(per)!r}"
|
||||
)
|
||||
|
||||
def copy(self: MC) -> MC:
|
||||
return self.__class__(self.number, per=self.per, wait=self.wait)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
|
||||
return (
|
||||
f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
|
||||
)
|
||||
|
||||
def get_key(self, message: Message) -> Any:
|
||||
return self.per.get_key(message)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -41,65 +41,66 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'CommandError',
|
||||
'MissingRequiredArgument',
|
||||
'BadArgument',
|
||||
'PrivateMessageOnly',
|
||||
'NoPrivateMessage',
|
||||
'CheckFailure',
|
||||
'CheckAnyFailure',
|
||||
'CommandNotFound',
|
||||
'DisabledCommand',
|
||||
'CommandInvokeError',
|
||||
'TooManyArguments',
|
||||
'UserInputError',
|
||||
'CommandOnCooldown',
|
||||
'MaxConcurrencyReached',
|
||||
'NotOwner',
|
||||
'MessageNotFound',
|
||||
'ObjectNotFound',
|
||||
'MemberNotFound',
|
||||
'GuildNotFound',
|
||||
'UserNotFound',
|
||||
'ChannelNotFound',
|
||||
'ThreadNotFound',
|
||||
'ChannelNotReadable',
|
||||
'BadColourArgument',
|
||||
'BadColorArgument',
|
||||
'RoleNotFound',
|
||||
'BadInviteArgument',
|
||||
'EmojiNotFound',
|
||||
'GuildStickerNotFound',
|
||||
'PartialEmojiConversionFailure',
|
||||
'BadBoolArgument',
|
||||
'MissingRole',
|
||||
'BotMissingRole',
|
||||
'MissingAnyRole',
|
||||
'BotMissingAnyRole',
|
||||
'MissingPermissions',
|
||||
'BotMissingPermissions',
|
||||
'NSFWChannelRequired',
|
||||
'ConversionError',
|
||||
'BadUnionArgument',
|
||||
'BadLiteralArgument',
|
||||
'ArgumentParsingError',
|
||||
'UnexpectedQuoteError',
|
||||
'InvalidEndOfQuotedStringError',
|
||||
'ExpectedClosingQuoteError',
|
||||
'ExtensionError',
|
||||
'ExtensionAlreadyLoaded',
|
||||
'ExtensionNotLoaded',
|
||||
'NoEntryPointError',
|
||||
'ExtensionFailed',
|
||||
'ExtensionNotFound',
|
||||
'CommandRegistrationError',
|
||||
'FlagError',
|
||||
'BadFlagArgument',
|
||||
'MissingFlagArgument',
|
||||
'TooManyFlags',
|
||||
'MissingRequiredFlag',
|
||||
"CommandError",
|
||||
"MissingRequiredArgument",
|
||||
"BadArgument",
|
||||
"PrivateMessageOnly",
|
||||
"NoPrivateMessage",
|
||||
"CheckFailure",
|
||||
"CheckAnyFailure",
|
||||
"CommandNotFound",
|
||||
"DisabledCommand",
|
||||
"CommandInvokeError",
|
||||
"TooManyArguments",
|
||||
"UserInputError",
|
||||
"CommandOnCooldown",
|
||||
"MaxConcurrencyReached",
|
||||
"NotOwner",
|
||||
"MessageNotFound",
|
||||
"ObjectNotFound",
|
||||
"MemberNotFound",
|
||||
"GuildNotFound",
|
||||
"UserNotFound",
|
||||
"ChannelNotFound",
|
||||
"ThreadNotFound",
|
||||
"ChannelNotReadable",
|
||||
"BadColourArgument",
|
||||
"BadColorArgument",
|
||||
"RoleNotFound",
|
||||
"BadInviteArgument",
|
||||
"EmojiNotFound",
|
||||
"GuildStickerNotFound",
|
||||
"PartialEmojiConversionFailure",
|
||||
"BadBoolArgument",
|
||||
"MissingRole",
|
||||
"BotMissingRole",
|
||||
"MissingAnyRole",
|
||||
"BotMissingAnyRole",
|
||||
"MissingPermissions",
|
||||
"BotMissingPermissions",
|
||||
"NSFWChannelRequired",
|
||||
"ConversionError",
|
||||
"BadUnionArgument",
|
||||
"BadLiteralArgument",
|
||||
"ArgumentParsingError",
|
||||
"UnexpectedQuoteError",
|
||||
"InvalidEndOfQuotedStringError",
|
||||
"ExpectedClosingQuoteError",
|
||||
"ExtensionError",
|
||||
"ExtensionAlreadyLoaded",
|
||||
"ExtensionNotLoaded",
|
||||
"NoEntryPointError",
|
||||
"ExtensionFailed",
|
||||
"ExtensionNotFound",
|
||||
"CommandRegistrationError",
|
||||
"FlagError",
|
||||
"BadFlagArgument",
|
||||
"MissingFlagArgument",
|
||||
"TooManyFlags",
|
||||
"MissingRequiredFlag",
|
||||
)
|
||||
|
||||
|
||||
class CommandError(DiscordException):
|
||||
r"""The base exception type for all command related errors.
|
||||
|
||||
@ -109,14 +110,18 @@ class CommandError(DiscordException):
|
||||
in a special way as they are caught and passed into a special event
|
||||
from :class:`.Bot`\, :func:`.on_command_error`.
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
|
||||
if message is not None:
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
m = message.replace("@everyone", "@\u200beveryone").replace(
|
||||
"@here", "@\u200bhere"
|
||||
)
|
||||
super().__init__(m, *args)
|
||||
else:
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class ConversionError(CommandError):
|
||||
"""Exception raised when a Converter class raises non-CommandError.
|
||||
|
||||
@ -130,18 +135,22 @@ class ConversionError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, converter: Converter, original: Exception) -> None:
|
||||
self.converter: Converter = converter
|
||||
self.original: Exception = original
|
||||
|
||||
|
||||
class UserInputError(CommandError):
|
||||
"""The base exception type for errors that involve errors
|
||||
regarding user input.
|
||||
|
||||
This inherits from :exc:`CommandError`.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandNotFound(CommandError):
|
||||
"""Exception raised when a command is attempted to be invoked
|
||||
but no command under that name is found.
|
||||
@ -151,8 +160,10 @@ class CommandNotFound(CommandError):
|
||||
|
||||
This inherits from :exc:`CommandError`.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MissingRequiredArgument(UserInputError):
|
||||
"""Exception raised when parsing a command and a parameter
|
||||
that is required is not encountered.
|
||||
@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError):
|
||||
param: :class:`inspect.Parameter`
|
||||
The argument that is missing.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter) -> None:
|
||||
self.param: Parameter = param
|
||||
super().__init__(f'{param.name} is a required argument that is missing.')
|
||||
super().__init__(f"{param.name} is a required argument that is missing.")
|
||||
|
||||
|
||||
class TooManyArguments(UserInputError):
|
||||
"""Exception raised when the command was passed too many arguments and its
|
||||
@ -174,23 +187,29 @@ class TooManyArguments(UserInputError):
|
||||
|
||||
This inherits from :exc:`UserInputError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BadArgument(UserInputError):
|
||||
"""Exception raised when a parsing or conversion failure is encountered
|
||||
on an argument to pass into a command.
|
||||
|
||||
This inherits from :exc:`UserInputError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckFailure(CommandError):
|
||||
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
|
||||
|
||||
This inherits from :exc:`CommandError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CheckAnyFailure(CheckFailure):
|
||||
"""Exception raised when all predicates in :func:`check_any` fail.
|
||||
|
||||
@ -206,10 +225,13 @@ class CheckAnyFailure(CheckFailure):
|
||||
A list of check predicates that failed.
|
||||
"""
|
||||
|
||||
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
|
||||
def __init__(
|
||||
self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]
|
||||
) -> None:
|
||||
self.checks: List[CheckFailure] = checks
|
||||
self.errors: List[Callable[[Context], bool]] = errors
|
||||
super().__init__('You do not have permission to run this command.')
|
||||
super().__init__("You do not have permission to run this command.")
|
||||
|
||||
|
||||
class PrivateMessageOnly(CheckFailure):
|
||||
"""Exception raised when an operation does not work outside of private
|
||||
@ -217,8 +239,12 @@ class PrivateMessageOnly(CheckFailure):
|
||||
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command can only be used in private messages.')
|
||||
super().__init__(
|
||||
message or "This command can only be used in private messages."
|
||||
)
|
||||
|
||||
|
||||
class NoPrivateMessage(CheckFailure):
|
||||
"""Exception raised when an operation does not work in private message
|
||||
@ -228,15 +254,18 @@ class NoPrivateMessage(CheckFailure):
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
super().__init__(message or 'This command cannot be used in private messages.')
|
||||
super().__init__(message or "This command cannot be used in private messages.")
|
||||
|
||||
|
||||
class NotOwner(CheckFailure):
|
||||
"""Exception raised when the message author is not the owner of the bot.
|
||||
|
||||
This inherits from :exc:`CheckFailure`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ObjectNotFound(BadArgument):
|
||||
"""Exception raised when the argument provided did not match the format
|
||||
of an ID or a mention.
|
||||
@ -250,9 +279,11 @@ class ObjectNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The argument supplied by the caller that was not matched
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
|
||||
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
|
||||
|
||||
|
||||
class MemberNotFound(BadArgument):
|
||||
"""Exception raised when the member provided was not found in the bot's
|
||||
@ -267,10 +298,12 @@ class MemberNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The member supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Member "{argument}" not found.')
|
||||
|
||||
|
||||
class GuildNotFound(BadArgument):
|
||||
"""Exception raised when the guild provided was not found in the bot's cache.
|
||||
|
||||
@ -283,10 +316,12 @@ class GuildNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The guild supplied by the called that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Guild "{argument}" not found.')
|
||||
|
||||
|
||||
class UserNotFound(BadArgument):
|
||||
"""Exception raised when the user provided was not found in the bot's
|
||||
cache.
|
||||
@ -300,10 +335,12 @@ class UserNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The user supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'User "{argument}" not found.')
|
||||
|
||||
|
||||
class MessageNotFound(BadArgument):
|
||||
"""Exception raised when the message provided was not found in the channel.
|
||||
|
||||
@ -316,10 +353,12 @@ class MessageNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The message supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Message "{argument}" not found.')
|
||||
|
||||
|
||||
class ChannelNotReadable(BadArgument):
|
||||
"""Exception raised when the bot does not have permission to read messages
|
||||
in the channel.
|
||||
@ -333,10 +372,12 @@ class ChannelNotReadable(BadArgument):
|
||||
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel supplied by the caller that was not readable
|
||||
"""
|
||||
|
||||
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
|
||||
self.argument: Union[GuildChannel, Thread] = argument
|
||||
super().__init__(f"Can't read messages in {argument.mention}.")
|
||||
|
||||
|
||||
class ChannelNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the channel.
|
||||
|
||||
@ -349,10 +390,12 @@ class ChannelNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The channel supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Channel "{argument}" not found.')
|
||||
|
||||
|
||||
class ThreadNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the thread.
|
||||
|
||||
@ -365,10 +408,12 @@ class ThreadNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The thread supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Thread "{argument}" not found.')
|
||||
|
||||
|
||||
class BadColourArgument(BadArgument):
|
||||
"""Exception raised when the colour is not valid.
|
||||
|
||||
@ -381,12 +426,15 @@ class BadColourArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The colour supplied by the caller that was not valid
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Colour "{argument}" is invalid.')
|
||||
|
||||
|
||||
BadColorArgument = BadColourArgument
|
||||
|
||||
|
||||
class RoleNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the role.
|
||||
|
||||
@ -399,10 +447,12 @@ class RoleNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The role supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Role "{argument}" not found.')
|
||||
|
||||
|
||||
class BadInviteArgument(BadArgument):
|
||||
"""Exception raised when the invite is invalid or expired.
|
||||
|
||||
@ -410,10 +460,12 @@ class BadInviteArgument(BadArgument):
|
||||
|
||||
.. versionadded:: 1.5
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Invite "{argument}" is invalid or expired.')
|
||||
|
||||
|
||||
class EmojiNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the emoji.
|
||||
|
||||
@ -426,10 +478,12 @@ class EmojiNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Emoji "{argument}" not found.')
|
||||
|
||||
|
||||
class PartialEmojiConversionFailure(BadArgument):
|
||||
"""Exception raised when the emoji provided does not match the correct
|
||||
format.
|
||||
@ -443,10 +497,12 @@ class PartialEmojiConversionFailure(BadArgument):
|
||||
argument: :class:`str`
|
||||
The emoji supplied by the caller that did not match the regex
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
|
||||
|
||||
|
||||
class GuildStickerNotFound(BadArgument):
|
||||
"""Exception raised when the bot can not find the sticker.
|
||||
|
||||
@ -459,10 +515,12 @@ class GuildStickerNotFound(BadArgument):
|
||||
argument: :class:`str`
|
||||
The sticker supplied by the caller that was not found
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'Sticker "{argument}" not found.')
|
||||
|
||||
|
||||
class BadBoolArgument(BadArgument):
|
||||
"""Exception raised when a boolean argument was not convertable.
|
||||
|
||||
@ -475,17 +533,21 @@ class BadBoolArgument(BadArgument):
|
||||
argument: :class:`str`
|
||||
The boolean argument supplied by the caller that is not in the predefined list
|
||||
"""
|
||||
|
||||
def __init__(self, argument: str) -> None:
|
||||
self.argument: str = argument
|
||||
super().__init__(f'{argument} is not a recognised boolean option')
|
||||
super().__init__(f"{argument} is not a recognised boolean option")
|
||||
|
||||
|
||||
class DisabledCommand(CommandError):
|
||||
"""Exception raised when the command being invoked is disabled.
|
||||
|
||||
This inherits from :exc:`CommandError`
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CommandInvokeError(CommandError):
|
||||
"""Exception raised when the command being invoked raised an exception.
|
||||
|
||||
@ -497,9 +559,11 @@ class CommandInvokeError(CommandError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, e: Exception) -> None:
|
||||
self.original: Exception = e
|
||||
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
|
||||
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
|
||||
|
||||
|
||||
class CommandOnCooldown(CommandError):
|
||||
"""Exception raised when the command being invoked is on cooldown.
|
||||
@ -516,11 +580,15 @@ class CommandOnCooldown(CommandError):
|
||||
retry_after: :class:`float`
|
||||
The amount of seconds to wait before you can retry again.
|
||||
"""
|
||||
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
|
||||
|
||||
def __init__(
|
||||
self, cooldown: Cooldown, retry_after: float, type: BucketType
|
||||
) -> None:
|
||||
self.cooldown: Cooldown = cooldown
|
||||
self.retry_after: float = retry_after
|
||||
self.type: BucketType = type
|
||||
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
|
||||
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
|
||||
|
||||
|
||||
class MaxConcurrencyReached(CommandError):
|
||||
"""Exception raised when the command being invoked has reached its maximum concurrency.
|
||||
@ -539,10 +607,13 @@ class MaxConcurrencyReached(CommandError):
|
||||
self.number: int = number
|
||||
self.per: BucketType = per
|
||||
name = per.name
|
||||
suffix = 'per %s' % name if per.name != 'default' else 'globally'
|
||||
plural = '%s times %s' if number > 1 else '%s time %s'
|
||||
suffix = "per %s" % name if per.name != "default" else "globally"
|
||||
plural = "%s times %s" if number > 1 else "%s time %s"
|
||||
fmt = plural % (number, suffix)
|
||||
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
|
||||
super().__init__(
|
||||
f"Too many people are using this command. It can only be used {fmt} concurrently."
|
||||
)
|
||||
|
||||
|
||||
class MissingRole(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks a role to run a command.
|
||||
@ -557,11 +628,13 @@ class MissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Role {missing_role!r} is required to run this command.'
|
||||
message = f"Role {missing_role!r} is required to run this command."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BotMissingRole(CheckFailure):
|
||||
"""Exception raised when the bot's member lacks a role to run a command.
|
||||
|
||||
@ -575,11 +648,13 @@ class BotMissingRole(CheckFailure):
|
||||
The required role that is missing.
|
||||
This is the parameter passed to :func:`~.commands.has_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_role: Snowflake) -> None:
|
||||
self.missing_role: Snowflake = missing_role
|
||||
message = f'Bot requires the role {missing_role!r} to run this command'
|
||||
message = f"Bot requires the role {missing_role!r} to run this command"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class MissingAnyRole(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks any of
|
||||
the roles specified to run a command.
|
||||
@ -594,15 +669,16 @@ class MissingAnyRole(CheckFailure):
|
||||
The roles that the invoker is missing.
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' or '.join(missing)
|
||||
fmt = " or ".join(missing)
|
||||
|
||||
message = f"You are missing at least one of the required roles: {fmt}"
|
||||
super().__init__(message)
|
||||
@ -623,19 +699,21 @@ class BotMissingAnyRole(CheckFailure):
|
||||
These are the parameters passed to :func:`~.commands.has_any_role`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, missing_roles: SnowflakeList) -> None:
|
||||
self.missing_roles: SnowflakeList = missing_roles
|
||||
|
||||
missing = [f"'{role}'" for role in missing_roles]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' or '.join(missing)
|
||||
fmt = " or ".join(missing)
|
||||
|
||||
message = f"Bot is missing at least one of the required roles: {fmt}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class NSFWChannelRequired(CheckFailure):
|
||||
"""Exception raised when a channel does not have the required NSFW setting.
|
||||
|
||||
@ -648,9 +726,13 @@ class NSFWChannelRequired(CheckFailure):
|
||||
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
|
||||
The channel that does not have NSFW enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
|
||||
self.channel: Union[GuildChannel, Thread] = channel
|
||||
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
|
||||
super().__init__(
|
||||
f"Channel '{channel}' needs to be NSFW for this command to work."
|
||||
)
|
||||
|
||||
|
||||
class MissingPermissions(CheckFailure):
|
||||
"""Exception raised when the command invoker lacks permissions to run a
|
||||
@ -663,18 +745,23 @@ class MissingPermissions(CheckFailure):
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
missing = [
|
||||
perm.replace("_", " ").replace("guild", "server").title()
|
||||
for perm in missing_permissions
|
||||
]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' and '.join(missing)
|
||||
message = f'You are missing {fmt} permission(s) to run this command.'
|
||||
fmt = " and ".join(missing)
|
||||
message = f"You are missing {fmt} permission(s) to run this command."
|
||||
super().__init__(message, *args)
|
||||
|
||||
|
||||
class BotMissingPermissions(CheckFailure):
|
||||
"""Exception raised when the bot's member lacks permissions to run a
|
||||
command.
|
||||
@ -686,18 +773,23 @@ class BotMissingPermissions(CheckFailure):
|
||||
missing_permissions: List[:class:`str`]
|
||||
The required permissions that are missing.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
|
||||
self.missing_permissions: List[str] = missing_permissions
|
||||
|
||||
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
|
||||
missing = [
|
||||
perm.replace("_", " ").replace("guild", "server").title()
|
||||
for perm in missing_permissions
|
||||
]
|
||||
|
||||
if len(missing) > 2:
|
||||
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
|
||||
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
|
||||
else:
|
||||
fmt = ' and '.join(missing)
|
||||
message = f'Bot requires {fmt} permission(s) to run this command.'
|
||||
fmt = " and ".join(missing)
|
||||
message = f"Bot requires {fmt} permission(s) to run this command."
|
||||
super().__init__(message, *args)
|
||||
|
||||
|
||||
class BadUnionArgument(UserInputError):
|
||||
"""Exception raised when a :data:`typing.Union` converter fails for all
|
||||
its associated types.
|
||||
@ -713,7 +805,10 @@ class BadUnionArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
|
||||
|
||||
def __init__(
|
||||
self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]
|
||||
) -> None:
|
||||
self.param: Parameter = param
|
||||
self.converters: Tuple[Type, ...] = converters
|
||||
self.errors: List[CommandError] = errors
|
||||
@ -722,18 +817,19 @@ class BadUnionArgument(UserInputError):
|
||||
try:
|
||||
return x.__name__
|
||||
except AttributeError:
|
||||
if hasattr(x, '__origin__'):
|
||||
if hasattr(x, "__origin__"):
|
||||
return repr(x)
|
||||
return x.__class__.__name__
|
||||
|
||||
to_string = [_get_name(x) for x in converters]
|
||||
if len(to_string) > 2:
|
||||
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
|
||||
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
|
||||
else:
|
||||
fmt = ' or '.join(to_string)
|
||||
fmt = " or ".join(to_string)
|
||||
|
||||
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
|
||||
|
||||
|
||||
class BadLiteralArgument(UserInputError):
|
||||
"""Exception raised when a :data:`typing.Literal` converter fails for all
|
||||
its associated values.
|
||||
@ -751,19 +847,23 @@ class BadLiteralArgument(UserInputError):
|
||||
errors: List[:class:`CommandError`]
|
||||
A list of errors that were caught from failing the conversion.
|
||||
"""
|
||||
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
|
||||
|
||||
def __init__(
|
||||
self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]
|
||||
) -> None:
|
||||
self.param: Parameter = param
|
||||
self.literals: Tuple[Any, ...] = literals
|
||||
self.errors: List[CommandError] = errors
|
||||
|
||||
to_string = [repr(l) for l in literals]
|
||||
if len(to_string) > 2:
|
||||
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
|
||||
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
|
||||
else:
|
||||
fmt = ' or '.join(to_string)
|
||||
fmt = " or ".join(to_string)
|
||||
|
||||
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
|
||||
|
||||
|
||||
class ArgumentParsingError(UserInputError):
|
||||
"""An exception raised when the parser fails to parse a user's input.
|
||||
|
||||
@ -772,8 +872,10 @@ class ArgumentParsingError(UserInputError):
|
||||
There are child classes that implement more granular parsing errors for
|
||||
i18n purposes.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedQuoteError(ArgumentParsingError):
|
||||
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
|
||||
|
||||
@ -784,9 +886,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
|
||||
quote: :class:`str`
|
||||
The quote mark that was found inside the non-quoted string.
|
||||
"""
|
||||
|
||||
def __init__(self, quote: str) -> None:
|
||||
self.quote: str = quote
|
||||
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
|
||||
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
|
||||
|
||||
|
||||
class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
"""An exception raised when a space is expected after the closing quote in a string
|
||||
@ -799,9 +903,13 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
|
||||
char: :class:`str`
|
||||
The character found instead of the expected string.
|
||||
"""
|
||||
|
||||
def __init__(self, char: str) -> None:
|
||||
self.char: str = char
|
||||
super().__init__(f'Expected space after closing quotation but received {char!r}')
|
||||
super().__init__(
|
||||
f"Expected space after closing quotation but received {char!r}"
|
||||
)
|
||||
|
||||
|
||||
class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
"""An exception raised when a quote character is expected but not found.
|
||||
@ -816,7 +924,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
|
||||
|
||||
def __init__(self, close_quote: str) -> None:
|
||||
self.close_quote: str = close_quote
|
||||
super().__init__(f'Expected closing {close_quote}.')
|
||||
super().__init__(f"Expected closing {close_quote}.")
|
||||
|
||||
|
||||
class ExtensionError(DiscordException):
|
||||
"""Base exception for extension related errors.
|
||||
@ -828,37 +937,47 @@ class ExtensionError(DiscordException):
|
||||
name: :class:`str`
|
||||
The extension that had an error.
|
||||
"""
|
||||
|
||||
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
|
||||
self.name: str = name
|
||||
message = message or f'Extension {name!r} had an error.'
|
||||
message = message or f"Extension {name!r} had an error."
|
||||
# clean-up @everyone and @here mentions
|
||||
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
|
||||
m = message.replace("@everyone", "@\u200beveryone").replace(
|
||||
"@here", "@\u200bhere"
|
||||
)
|
||||
super().__init__(m, *args)
|
||||
|
||||
|
||||
class ExtensionAlreadyLoaded(ExtensionError):
|
||||
"""An exception raised when an extension has already been loaded.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} is already loaded.', name=name)
|
||||
super().__init__(f"Extension {name!r} is already loaded.", name=name)
|
||||
|
||||
|
||||
class ExtensionNotLoaded(ExtensionError):
|
||||
"""An exception raised when an extension was not loaded.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
|
||||
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
|
||||
|
||||
|
||||
class NoEntryPointError(ExtensionError):
|
||||
"""An exception raised when an extension does not have a ``setup`` entry point function.
|
||||
|
||||
This inherits from :exc:`ExtensionError`
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
|
||||
|
||||
|
||||
class ExtensionFailed(ExtensionError):
|
||||
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
|
||||
|
||||
@ -872,11 +991,13 @@ class ExtensionFailed(ExtensionError):
|
||||
The original exception that was raised. You can also get this via
|
||||
the ``__cause__`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, original: Exception) -> None:
|
||||
self.original: Exception = original
|
||||
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
|
||||
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
|
||||
class ExtensionNotFound(ExtensionError):
|
||||
"""An exception raised when an extension is not found.
|
||||
|
||||
@ -890,10 +1011,12 @@ class ExtensionNotFound(ExtensionError):
|
||||
name: :class:`str`
|
||||
The extension that had the error.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
msg = f'Extension {name!r} could not be loaded.'
|
||||
msg = f"Extension {name!r} could not be loaded."
|
||||
super().__init__(msg, name=name)
|
||||
|
||||
|
||||
class CommandRegistrationError(ClientException):
|
||||
"""An exception raised when the command can't be added
|
||||
because the name is already taken by a different command.
|
||||
@ -909,11 +1032,13 @@ class CommandRegistrationError(ClientException):
|
||||
alias_conflict: :class:`bool`
|
||||
Whether the name that conflicts is an alias of the command we try to add.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
|
||||
self.name: str = name
|
||||
self.alias_conflict: bool = alias_conflict
|
||||
type_ = 'alias' if alias_conflict else 'command'
|
||||
super().__init__(f'The {type_} {name} is already an existing command or alias.')
|
||||
type_ = "alias" if alias_conflict else "command"
|
||||
super().__init__(f"The {type_} {name} is already an existing command or alias.")
|
||||
|
||||
|
||||
class FlagError(BadArgument):
|
||||
"""The base exception type for all flag parsing related errors.
|
||||
@ -922,8 +1047,10 @@ class FlagError(BadArgument):
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TooManyFlags(FlagError):
|
||||
"""An exception raised when a flag has received too many values.
|
||||
|
||||
@ -938,10 +1065,14 @@ class TooManyFlags(FlagError):
|
||||
values: List[:class:`str`]
|
||||
The values that were passed.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag, values: List[str]) -> None:
|
||||
self.flag: Flag = flag
|
||||
self.values: List[str] = values
|
||||
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
|
||||
super().__init__(
|
||||
f"Too many flag values, expected {flag.max_args} but received {len(values)}."
|
||||
)
|
||||
|
||||
|
||||
class BadFlagArgument(FlagError):
|
||||
"""An exception raised when a flag failed to convert a value.
|
||||
@ -955,6 +1086,7 @@ class BadFlagArgument(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that failed to convert.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
try:
|
||||
@ -962,7 +1094,8 @@ class BadFlagArgument(FlagError):
|
||||
except AttributeError:
|
||||
name = flag.annotation.__class__.__name__
|
||||
|
||||
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
|
||||
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
|
||||
|
||||
|
||||
class MissingRequiredFlag(FlagError):
|
||||
"""An exception raised when a required flag was not given.
|
||||
@ -976,9 +1109,11 @@ class MissingRequiredFlag(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The required flag that was not found.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} is required and missing')
|
||||
super().__init__(f"Flag {flag.name!r} is required and missing")
|
||||
|
||||
|
||||
class MissingFlagArgument(FlagError):
|
||||
"""An exception raised when a flag did not get a value.
|
||||
@ -992,6 +1127,7 @@ class MissingFlagArgument(FlagError):
|
||||
flag: :class:`~discord.ext.commands.Flag`
|
||||
The flag that did not get a value.
|
||||
"""
|
||||
|
||||
def __init__(self, flag: Flag) -> None:
|
||||
self.flag: Flag = flag
|
||||
super().__init__(f'Flag {flag.name!r} does not have an argument')
|
||||
super().__init__(f"Flag {flag.name!r} does not have an argument")
|
||||
|
@ -59,9 +59,9 @@ import sys
|
||||
import re
|
||||
|
||||
__all__ = (
|
||||
'Flag',
|
||||
'flag',
|
||||
'FlagConverter',
|
||||
"Flag",
|
||||
"flag",
|
||||
"FlagConverter",
|
||||
)
|
||||
|
||||
|
||||
@ -143,25 +143,35 @@ def flag(
|
||||
Whether multiple given values overrides the previous value. The default
|
||||
value depends on the annotation given.
|
||||
"""
|
||||
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
|
||||
return Flag(
|
||||
name=name,
|
||||
aliases=aliases,
|
||||
default=default,
|
||||
max_args=max_args,
|
||||
override=override,
|
||||
)
|
||||
|
||||
|
||||
def validate_flag_name(name: str, forbidden: Set[str]):
|
||||
if not name:
|
||||
raise ValueError('flag names should not be empty')
|
||||
raise ValueError("flag names should not be empty")
|
||||
|
||||
for ch in name:
|
||||
if ch.isspace():
|
||||
raise ValueError(f'flag name {name!r} cannot have spaces')
|
||||
if ch == '\\':
|
||||
raise ValueError(f'flag name {name!r} cannot have backslashes')
|
||||
raise ValueError(f"flag name {name!r} cannot have spaces")
|
||||
if ch == "\\":
|
||||
raise ValueError(f"flag name {name!r} cannot have backslashes")
|
||||
if ch in forbidden:
|
||||
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
|
||||
raise ValueError(
|
||||
f"flag name {name!r} cannot have any of {forbidden!r} within them"
|
||||
)
|
||||
|
||||
|
||||
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
|
||||
annotations = namespace.get('__annotations__', {})
|
||||
case_insensitive = namespace['__commands_flag_case_insensitive__']
|
||||
def get_flags(
|
||||
namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]
|
||||
) -> Dict[str, Flag]:
|
||||
annotations = namespace.get("__annotations__", {})
|
||||
case_insensitive = namespace["__commands_flag_case_insensitive__"]
|
||||
flags: Dict[str, Flag] = {}
|
||||
cache: Dict[str, Any] = {}
|
||||
names: Set[str] = set()
|
||||
@ -176,9 +186,15 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
if flag.name is MISSING:
|
||||
flag.name = name
|
||||
|
||||
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
|
||||
annotation = flag.annotation = resolve_annotation(
|
||||
flag.annotation, globals, locals, cache
|
||||
)
|
||||
|
||||
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
|
||||
if (
|
||||
flag.default is MISSING
|
||||
and hasattr(annotation, "__commands_is_flag__")
|
||||
and annotation._can_be_constructible()
|
||||
):
|
||||
flag.default = annotation._construct_default
|
||||
|
||||
if flag.aliases is MISSING:
|
||||
@ -229,7 +245,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
if flag.max_args is MISSING:
|
||||
flag.max_args = 1
|
||||
else:
|
||||
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
|
||||
raise TypeError(
|
||||
f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag"
|
||||
)
|
||||
|
||||
if flag.override is MISSING:
|
||||
flag.override = False
|
||||
@ -237,7 +255,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
# Validate flag names are unique
|
||||
name = flag.name.casefold() if case_insensitive else flag.name
|
||||
if name in names:
|
||||
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
|
||||
raise TypeError(
|
||||
f"{flag.name!r} flag conflicts with previous flag or alias."
|
||||
)
|
||||
else:
|
||||
names.add(name)
|
||||
|
||||
@ -245,7 +265,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
# Validate alias is unique
|
||||
alias = alias.casefold() if case_insensitive else alias
|
||||
if alias in names:
|
||||
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
|
||||
raise TypeError(
|
||||
f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias."
|
||||
)
|
||||
else:
|
||||
names.add(alias)
|
||||
|
||||
@ -274,10 +296,10 @@ class FlagsMeta(type):
|
||||
delimiter: str = MISSING,
|
||||
prefix: str = MISSING,
|
||||
):
|
||||
attrs['__commands_is_flag__'] = True
|
||||
attrs["__commands_is_flag__"] = True
|
||||
|
||||
try:
|
||||
global_ns = sys.modules[attrs['__module__']].__dict__
|
||||
global_ns = sys.modules[attrs["__module__"]].__dict__
|
||||
except KeyError:
|
||||
global_ns = {}
|
||||
|
||||
@ -296,26 +318,32 @@ class FlagsMeta(type):
|
||||
flags: Dict[str, Flag] = {}
|
||||
aliases: Dict[str, str] = {}
|
||||
for base in reversed(bases):
|
||||
if base.__dict__.get('__commands_is_flag__', False):
|
||||
flags.update(base.__dict__['__commands_flags__'])
|
||||
aliases.update(base.__dict__['__commands_flag_aliases__'])
|
||||
if base.__dict__.get("__commands_is_flag__", False):
|
||||
flags.update(base.__dict__["__commands_flags__"])
|
||||
aliases.update(base.__dict__["__commands_flag_aliases__"])
|
||||
if case_insensitive is MISSING:
|
||||
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
|
||||
attrs["__commands_flag_case_insensitive__"] = base.__dict__[
|
||||
"__commands_flag_case_insensitive__"
|
||||
]
|
||||
if delimiter is MISSING:
|
||||
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
|
||||
attrs["__commands_flag_delimiter__"] = base.__dict__[
|
||||
"__commands_flag_delimiter__"
|
||||
]
|
||||
if prefix is MISSING:
|
||||
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
|
||||
attrs["__commands_flag_prefix__"] = base.__dict__[
|
||||
"__commands_flag_prefix__"
|
||||
]
|
||||
|
||||
if case_insensitive is not MISSING:
|
||||
attrs['__commands_flag_case_insensitive__'] = case_insensitive
|
||||
attrs["__commands_flag_case_insensitive__"] = case_insensitive
|
||||
if delimiter is not MISSING:
|
||||
attrs['__commands_flag_delimiter__'] = delimiter
|
||||
attrs["__commands_flag_delimiter__"] = delimiter
|
||||
if prefix is not MISSING:
|
||||
attrs['__commands_flag_prefix__'] = prefix
|
||||
attrs["__commands_flag_prefix__"] = prefix
|
||||
|
||||
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
|
||||
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
|
||||
prefix = attrs.setdefault('__commands_flag_prefix__', '')
|
||||
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
|
||||
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
|
||||
prefix = attrs.setdefault("__commands_flag_prefix__", "")
|
||||
|
||||
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
|
||||
flags[flag_name] = flag
|
||||
@ -330,23 +358,30 @@ class FlagsMeta(type):
|
||||
regex_flags = 0
|
||||
if case_insensitive:
|
||||
flags = {key.casefold(): value for key, value in flags.items()}
|
||||
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
|
||||
aliases = {
|
||||
key.casefold(): value.casefold() for key, value in aliases.items()
|
||||
}
|
||||
regex_flags = re.IGNORECASE
|
||||
|
||||
keys = list(re.escape(k) for k in flags)
|
||||
keys.extend(re.escape(a) for a in aliases)
|
||||
keys = sorted(keys, key=lambda t: len(t), reverse=True)
|
||||
|
||||
joined = '|'.join(keys)
|
||||
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
|
||||
attrs['__commands_flag_regex__'] = pattern
|
||||
attrs['__commands_flags__'] = flags
|
||||
attrs['__commands_flag_aliases__'] = aliases
|
||||
joined = "|".join(keys)
|
||||
pattern = re.compile(
|
||||
f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})",
|
||||
regex_flags,
|
||||
)
|
||||
attrs["__commands_flag_regex__"] = pattern
|
||||
attrs["__commands_flags__"] = flags
|
||||
attrs["__commands_flag_aliases__"] = aliases
|
||||
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
|
||||
async def tuple_convert_all(
|
||||
ctx: Context, argument: str, flag: Flag, converter: Any
|
||||
) -> Tuple[Any, ...]:
|
||||
view = StringView(argument)
|
||||
results = []
|
||||
param: inspect.Parameter = ctx.current_parameter # type: ignore
|
||||
@ -371,7 +406,9 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
|
||||
return tuple(results)
|
||||
|
||||
|
||||
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
|
||||
async def tuple_convert_flag(
|
||||
ctx: Context, argument: str, flag: Flag, converters: Any
|
||||
) -> Tuple[Any, ...]:
|
||||
view = StringView(argument)
|
||||
results = []
|
||||
param: inspect.Parameter = ctx.current_parameter # type: ignore
|
||||
@ -409,9 +446,13 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
|
||||
else:
|
||||
if origin is tuple:
|
||||
if annotation.__args__[-1] is Ellipsis:
|
||||
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0])
|
||||
return await tuple_convert_all(
|
||||
ctx, argument, flag, annotation.__args__[0]
|
||||
)
|
||||
else:
|
||||
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
|
||||
return await tuple_convert_flag(
|
||||
ctx, argument, flag, annotation.__args__
|
||||
)
|
||||
elif origin is list:
|
||||
# typing.List[x]
|
||||
annotation = annotation.__args__[0]
|
||||
@ -432,7 +473,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
|
||||
raise BadFlagArgument(flag) from e
|
||||
|
||||
|
||||
F = TypeVar('F', bound='FlagConverter')
|
||||
F = TypeVar("F", bound="FlagConverter")
|
||||
|
||||
|
||||
class FlagConverter(metaclass=FlagsMeta):
|
||||
@ -493,8 +534,13 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
return self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
|
||||
return f'<{self.__class__.__name__} {pairs}>'
|
||||
pairs = " ".join(
|
||||
[
|
||||
f"{flag.attribute}={getattr(self, flag.attribute)!r}"
|
||||
for flag in self.get_flags().values()
|
||||
]
|
||||
)
|
||||
return f"<{self.__class__.__name__} {pairs}>"
|
||||
|
||||
@classmethod
|
||||
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
|
||||
@ -507,7 +553,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
case_insensitive = cls.__commands_flag_case_insensitive__
|
||||
for match in cls.__commands_flag_regex__.finditer(argument):
|
||||
begin, end = match.span(0)
|
||||
key = match.group('flag')
|
||||
key = match.group("flag")
|
||||
if case_insensitive:
|
||||
key = key.casefold()
|
||||
|
||||
|
@ -39,10 +39,10 @@ if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
|
||||
__all__ = (
|
||||
'Paginator',
|
||||
'HelpCommand',
|
||||
'DefaultHelpCommand',
|
||||
'MinimalHelpCommand',
|
||||
"Paginator",
|
||||
"HelpCommand",
|
||||
"DefaultHelpCommand",
|
||||
"MinimalHelpCommand",
|
||||
)
|
||||
|
||||
# help -> shows info of bot on top/bottom and lists subcommands
|
||||
@ -89,7 +89,7 @@ class Paginator:
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
|
||||
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.max_size = max_size
|
||||
@ -118,7 +118,7 @@ class Paginator:
|
||||
def _linesep_len(self):
|
||||
return len(self.linesep)
|
||||
|
||||
def add_line(self, line='', *, empty=False):
|
||||
def add_line(self, line="", *, empty=False):
|
||||
"""Adds a line to the current page.
|
||||
|
||||
If the line exceeds the :attr:`max_size` then an exception
|
||||
@ -136,18 +136,23 @@ class Paginator:
|
||||
RuntimeError
|
||||
The line was too big for the current :attr:`max_size`.
|
||||
"""
|
||||
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
|
||||
max_page_size = (
|
||||
self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
|
||||
)
|
||||
if len(line) > max_page_size:
|
||||
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
|
||||
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
|
||||
|
||||
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
|
||||
if (
|
||||
self._count + len(line) + self._linesep_len
|
||||
> self.max_size - self._suffix_len
|
||||
):
|
||||
self.close_page()
|
||||
|
||||
self._count += len(line) + self._linesep_len
|
||||
self._current_page.append(line)
|
||||
|
||||
if empty:
|
||||
self._current_page.append('')
|
||||
self._current_page.append("")
|
||||
self._count += self._linesep_len
|
||||
|
||||
def close_page(self):
|
||||
@ -176,7 +181,7 @@ class Paginator:
|
||||
return self._pages
|
||||
|
||||
def __repr__(self):
|
||||
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
|
||||
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
|
||||
return fmt.format(self)
|
||||
|
||||
|
||||
@ -197,7 +202,7 @@ class _HelpCommandImpl(Command):
|
||||
self.callback = injected.command_callback
|
||||
|
||||
on_error = injected.on_help_command_error
|
||||
if not hasattr(on_error, '__help_command_not_overriden__'):
|
||||
if not hasattr(on_error, "__help_command_not_overriden__"):
|
||||
if self.cog is not None:
|
||||
self.on_error = self._on_error_cog_implementation
|
||||
else:
|
||||
@ -224,7 +229,7 @@ class _HelpCommandImpl(Command):
|
||||
try:
|
||||
del result[next(iter(result))]
|
||||
except StopIteration:
|
||||
raise ValueError('Missing context parameter') from None
|
||||
raise ValueError("Missing context parameter") from None
|
||||
else:
|
||||
return result
|
||||
|
||||
@ -296,13 +301,13 @@ class HelpCommand:
|
||||
"""
|
||||
|
||||
MENTION_TRANSFORMS = {
|
||||
'@everyone': '@\u200beveryone',
|
||||
'@here': '@\u200bhere',
|
||||
r'<@!?[0-9]{17,22}>': '@deleted-user',
|
||||
r'<@&[0-9]{17,22}>': '@deleted-role',
|
||||
"@everyone": "@\u200beveryone",
|
||||
"@here": "@\u200bhere",
|
||||
r"<@!?[0-9]{17,22}>": "@deleted-user",
|
||||
r"<@&[0-9]{17,22}>": "@deleted-role",
|
||||
}
|
||||
|
||||
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
|
||||
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# To prevent race conditions of a single instance while also allowing
|
||||
@ -321,11 +326,11 @@ class HelpCommand:
|
||||
return self
|
||||
|
||||
def __init__(self, **options):
|
||||
self.show_hidden = options.pop('show_hidden', False)
|
||||
self.verify_checks = options.pop('verify_checks', True)
|
||||
self.command_attrs = attrs = options.pop('command_attrs', {})
|
||||
attrs.setdefault('name', 'help')
|
||||
attrs.setdefault('help', 'Shows this message')
|
||||
self.show_hidden = options.pop("show_hidden", False)
|
||||
self.verify_checks = options.pop("verify_checks", True)
|
||||
self.command_attrs = attrs = options.pop("command_attrs", {})
|
||||
attrs.setdefault("name", "help")
|
||||
attrs.setdefault("help", "Shows this message")
|
||||
self.context: Context = discord.utils.MISSING
|
||||
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
|
||||
|
||||
@ -398,7 +403,11 @@ class HelpCommand:
|
||||
"""
|
||||
command_name = self._command_impl.name
|
||||
ctx = self.context
|
||||
if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name:
|
||||
if (
|
||||
ctx is None
|
||||
or ctx.command is None
|
||||
or ctx.command.qualified_name != command_name
|
||||
):
|
||||
return command_name
|
||||
return ctx.invoked_with
|
||||
|
||||
@ -422,20 +431,20 @@ class HelpCommand:
|
||||
if not parent.signature or parent.invoke_without_command:
|
||||
entries.append(parent.name)
|
||||
else:
|
||||
entries.append(parent.name + ' ' + parent.signature)
|
||||
entries.append(parent.name + " " + parent.signature)
|
||||
parent = parent.parent
|
||||
parent_sig = ' '.join(reversed(entries))
|
||||
parent_sig = " ".join(reversed(entries))
|
||||
|
||||
if len(command.aliases) > 0:
|
||||
aliases = '|'.join(command.aliases)
|
||||
fmt = f'[{command.name}|{aliases}]'
|
||||
aliases = "|".join(command.aliases)
|
||||
fmt = f"[{command.name}|{aliases}]"
|
||||
if parent_sig:
|
||||
fmt = parent_sig + ' ' + fmt
|
||||
fmt = parent_sig + " " + fmt
|
||||
alias = fmt
|
||||
else:
|
||||
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
|
||||
alias = command.name if not parent_sig else parent_sig + " " + command.name
|
||||
|
||||
return f'{self.context.clean_prefix}{alias} {command.signature}'
|
||||
return f"{self.context.clean_prefix}{alias} {command.signature}"
|
||||
|
||||
def remove_mentions(self, string):
|
||||
"""Removes mentions from the string to prevent abuse.
|
||||
@ -449,7 +458,7 @@ class HelpCommand:
|
||||
"""
|
||||
|
||||
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
|
||||
return transforms.get(obj.group(0), '@invalid')
|
||||
return transforms.get(obj.group(0), "@invalid")
|
||||
|
||||
return self.MENTION_PATTERN.sub(replace, string)
|
||||
|
||||
@ -527,7 +536,9 @@ class HelpCommand:
|
||||
The string to use when the command did not have the subcommand requested.
|
||||
"""
|
||||
if isinstance(command, Group) and len(command.all_commands) > 0:
|
||||
return f'Command "{command.qualified_name}" has no subcommand named {string}'
|
||||
return (
|
||||
f'Command "{command.qualified_name}" has no subcommand named {string}'
|
||||
)
|
||||
return f'Command "{command.qualified_name}" has no subcommands.'
|
||||
|
||||
async def filter_commands(self, commands, *, sort=False, key=None):
|
||||
@ -558,7 +569,9 @@ class HelpCommand:
|
||||
if sort and key is None:
|
||||
key = lambda c: c.name
|
||||
|
||||
iterator = commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
|
||||
iterator = (
|
||||
commands if self.show_hidden else filter(lambda c: not c.hidden, commands)
|
||||
)
|
||||
|
||||
if self.verify_checks is False:
|
||||
# if we do not need to verify the checks then we can just
|
||||
@ -846,21 +859,27 @@ class HelpCommand:
|
||||
# Since we want to have detailed errors when someone
|
||||
# passes an invalid subcommand, we need to walk through
|
||||
# the command group chain ourselves.
|
||||
keys = command.split(' ')
|
||||
keys = command.split(" ")
|
||||
cmd = bot.all_commands.get(keys[0])
|
||||
if cmd is None:
|
||||
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
|
||||
string = await maybe_coro(
|
||||
self.command_not_found, self.remove_mentions(keys[0])
|
||||
)
|
||||
return await self.send_error_message(string)
|
||||
|
||||
for key in keys[1:]:
|
||||
try:
|
||||
found = cmd.all_commands.get(key)
|
||||
except AttributeError:
|
||||
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
|
||||
string = await maybe_coro(
|
||||
self.subcommand_not_found, cmd, self.remove_mentions(key)
|
||||
)
|
||||
return await self.send_error_message(string)
|
||||
else:
|
||||
if found is None:
|
||||
string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key))
|
||||
string = await maybe_coro(
|
||||
self.subcommand_not_found, cmd, self.remove_mentions(key)
|
||||
)
|
||||
return await self.send_error_message(string)
|
||||
cmd = found
|
||||
|
||||
@ -907,14 +926,14 @@ class DefaultHelpCommand(HelpCommand):
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.width = options.pop('width', 80)
|
||||
self.indent = options.pop('indent', 2)
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
self.width = options.pop("width", 80)
|
||||
self.indent = options.pop("indent", 2)
|
||||
self.sort_commands = options.pop("sort_commands", True)
|
||||
self.dm_help = options.pop("dm_help", False)
|
||||
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
|
||||
self.commands_heading = options.pop("commands_heading", "Commands:")
|
||||
self.no_category = options.pop("no_category", "No Category")
|
||||
self.paginator = options.pop("paginator", None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator()
|
||||
@ -924,7 +943,7 @@ class DefaultHelpCommand(HelpCommand):
|
||||
def shorten_text(self, text):
|
||||
""":class:`str`: Shortens text to fit into the :attr:`width`."""
|
||||
if len(text) > self.width:
|
||||
return text[:self.width - 3].rstrip() + '...'
|
||||
return text[: self.width - 3].rstrip() + "..."
|
||||
return text
|
||||
|
||||
def get_ending_note(self):
|
||||
@ -1021,11 +1040,11 @@ class DefaultHelpCommand(HelpCommand):
|
||||
# <description> portion
|
||||
self.paginator.add_line(bot.description, empty=True)
|
||||
|
||||
no_category = f'\u200b{self.no_category}:'
|
||||
no_category = f"\u200b{self.no_category}:"
|
||||
|
||||
def get_category(command, *, no_category=no_category):
|
||||
cog = command.cog
|
||||
return cog.qualified_name + ':' if cog is not None else no_category
|
||||
return cog.qualified_name + ":" if cog is not None else no_category
|
||||
|
||||
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
|
||||
max_size = self.get_max_size(filtered)
|
||||
@ -1033,7 +1052,11 @@ class DefaultHelpCommand(HelpCommand):
|
||||
|
||||
# Now we can add the commands to the page.
|
||||
for category, commands in to_iterate:
|
||||
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
|
||||
commands = (
|
||||
sorted(commands, key=lambda c: c.name)
|
||||
if self.sort_commands
|
||||
else list(commands)
|
||||
)
|
||||
self.add_indented_commands(commands, heading=category, max_size=max_size)
|
||||
|
||||
note = self.get_ending_note()
|
||||
@ -1066,7 +1089,9 @@ class DefaultHelpCommand(HelpCommand):
|
||||
if cog.description:
|
||||
self.paginator.add_line(cog.description, empty=True)
|
||||
|
||||
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
|
||||
filtered = await self.filter_commands(
|
||||
cog.get_commands(), sort=self.sort_commands
|
||||
)
|
||||
self.add_indented_commands(filtered, heading=self.commands_heading)
|
||||
|
||||
note = self.get_ending_note()
|
||||
@ -1110,13 +1135,13 @@ class MinimalHelpCommand(HelpCommand):
|
||||
"""
|
||||
|
||||
def __init__(self, **options):
|
||||
self.sort_commands = options.pop('sort_commands', True)
|
||||
self.commands_heading = options.pop('commands_heading', "Commands")
|
||||
self.dm_help = options.pop('dm_help', False)
|
||||
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
|
||||
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
|
||||
self.no_category = options.pop('no_category', 'No Category')
|
||||
self.paginator = options.pop('paginator', None)
|
||||
self.sort_commands = options.pop("sort_commands", True)
|
||||
self.commands_heading = options.pop("commands_heading", "Commands")
|
||||
self.dm_help = options.pop("dm_help", False)
|
||||
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
|
||||
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
|
||||
self.no_category = options.pop("no_category", "No Category")
|
||||
self.paginator = options.pop("paginator", None)
|
||||
|
||||
if self.paginator is None:
|
||||
self.paginator = Paginator(suffix=None, prefix=None)
|
||||
@ -1149,7 +1174,9 @@ class MinimalHelpCommand(HelpCommand):
|
||||
)
|
||||
|
||||
def get_command_signature(self, command):
|
||||
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
|
||||
return (
|
||||
f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
|
||||
)
|
||||
|
||||
def get_ending_note(self):
|
||||
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
|
||||
@ -1180,8 +1207,8 @@ class MinimalHelpCommand(HelpCommand):
|
||||
"""
|
||||
if commands:
|
||||
# U+2002 Middle Dot
|
||||
joined = '\u2002'.join(c.name for c in commands)
|
||||
self.paginator.add_line(f'__**{heading}**__')
|
||||
joined = "\u2002".join(c.name for c in commands)
|
||||
self.paginator.add_line(f"__**{heading}**__")
|
||||
self.paginator.add_line(joined)
|
||||
|
||||
def add_subcommand_formatting(self, command):
|
||||
@ -1197,8 +1224,12 @@ class MinimalHelpCommand(HelpCommand):
|
||||
command: :class:`Command`
|
||||
The command to show information of.
|
||||
"""
|
||||
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
|
||||
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
|
||||
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
|
||||
self.paginator.add_line(
|
||||
fmt.format(
|
||||
self.context.clean_prefix, command.qualified_name, command.short_doc
|
||||
)
|
||||
)
|
||||
|
||||
def add_aliases_formatting(self, aliases):
|
||||
"""Adds the formatting information on a command's aliases.
|
||||
@ -1215,7 +1246,9 @@ class MinimalHelpCommand(HelpCommand):
|
||||
aliases: Sequence[:class:`str`]
|
||||
A list of aliases to format.
|
||||
"""
|
||||
self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True)
|
||||
self.paginator.add_line(
|
||||
f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True
|
||||
)
|
||||
|
||||
def add_command_formatting(self, command):
|
||||
"""A utility function to format commands and groups.
|
||||
@ -1268,7 +1301,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
if note:
|
||||
self.paginator.add_line(note, empty=True)
|
||||
|
||||
no_category = f'\u200b{self.no_category}'
|
||||
no_category = f"\u200b{self.no_category}"
|
||||
|
||||
def get_category(command, *, no_category=no_category):
|
||||
cog = command.cog
|
||||
@ -1278,7 +1311,11 @@ class MinimalHelpCommand(HelpCommand):
|
||||
to_iterate = itertools.groupby(filtered, key=get_category)
|
||||
|
||||
for category, commands in to_iterate:
|
||||
commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands)
|
||||
commands = (
|
||||
sorted(commands, key=lambda c: c.name)
|
||||
if self.sort_commands
|
||||
else list(commands)
|
||||
)
|
||||
self.add_bot_commands_formatting(commands, category)
|
||||
|
||||
note = self.get_ending_note()
|
||||
@ -1300,9 +1337,11 @@ class MinimalHelpCommand(HelpCommand):
|
||||
if cog.description:
|
||||
self.paginator.add_line(cog.description, empty=True)
|
||||
|
||||
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
|
||||
filtered = await self.filter_commands(
|
||||
cog.get_commands(), sort=self.sort_commands
|
||||
)
|
||||
if filtered:
|
||||
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
|
||||
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
|
||||
for command in filtered:
|
||||
self.add_subcommand_formatting(command)
|
||||
|
||||
@ -1322,7 +1361,7 @@ class MinimalHelpCommand(HelpCommand):
|
||||
if note:
|
||||
self.paginator.add_line(note, empty=True)
|
||||
|
||||
self.paginator.add_line(f'**{self.commands_heading}**')
|
||||
self.paginator.add_line(f"**{self.commands_heading}**")
|
||||
for command in filtered:
|
||||
self.add_subcommand_formatting(command)
|
||||
|
||||
|
@ -22,7 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
|
||||
from .errors import (
|
||||
UnexpectedQuoteError,
|
||||
InvalidEndOfQuotedStringError,
|
||||
ExpectedClosingQuoteError,
|
||||
)
|
||||
|
||||
# map from opening quotes to closing quotes
|
||||
_quotes = {
|
||||
@ -46,6 +50,7 @@ _quotes = {
|
||||
}
|
||||
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
|
||||
|
||||
|
||||
class StringView:
|
||||
def __init__(self, buffer):
|
||||
self.index = 0
|
||||
@ -81,20 +86,20 @@ class StringView:
|
||||
|
||||
def skip_string(self, string):
|
||||
strlen = len(string)
|
||||
if self.buffer[self.index:self.index + strlen] == string:
|
||||
if self.buffer[self.index : self.index + strlen] == string:
|
||||
self.previous = self.index
|
||||
self.index += strlen
|
||||
return True
|
||||
return False
|
||||
|
||||
def read_rest(self):
|
||||
result = self.buffer[self.index:]
|
||||
result = self.buffer[self.index :]
|
||||
self.previous = self.index
|
||||
self.index = self.end
|
||||
return result
|
||||
|
||||
def read(self, n):
|
||||
result = self.buffer[self.index:self.index + n]
|
||||
result = self.buffer[self.index : self.index + n]
|
||||
self.previous = self.index
|
||||
self.index += n
|
||||
return result
|
||||
@ -120,7 +125,7 @@ class StringView:
|
||||
except IndexError:
|
||||
break
|
||||
self.previous = self.index
|
||||
result = self.buffer[self.index:self.index + pos]
|
||||
result = self.buffer[self.index : self.index + pos]
|
||||
self.index += pos
|
||||
return result
|
||||
|
||||
@ -144,11 +149,11 @@ class StringView:
|
||||
if is_quoted:
|
||||
# unexpected EOF
|
||||
raise ExpectedClosingQuoteError(close_quote)
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
# currently we accept strings in the format of "hello world"
|
||||
# to embed a quote inside the string you must escape it: "a \"world\""
|
||||
if current == '\\':
|
||||
if current == "\\":
|
||||
next_char = self.get()
|
||||
if not next_char:
|
||||
# string ends with \ and no character after it
|
||||
@ -156,7 +161,7 @@ class StringView:
|
||||
# if we're quoted then we're expecting a closing quote
|
||||
raise ExpectedClosingQuoteError(close_quote)
|
||||
# if we aren't then we just let it through
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
if next_char in _escaped_quotes:
|
||||
# escaped quote
|
||||
@ -179,14 +184,13 @@ class StringView:
|
||||
raise InvalidEndOfQuotedStringError(next_char)
|
||||
|
||||
# we're quoted so it's okay
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
if current.isspace() and not is_quoted:
|
||||
# end of word found
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
result.append(current)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
|
||||
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"
|
||||
|
@ -48,21 +48,21 @@ from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
|
||||
__all__ = (
|
||||
'loop',
|
||||
)
|
||||
__all__ = ("loop",)
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
_func = Callable[..., Awaitable[Any]]
|
||||
LF = TypeVar('LF', bound=_func)
|
||||
FT = TypeVar('FT', bound=_func)
|
||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
LF = TypeVar("LF", bound=_func)
|
||||
FT = TypeVar("FT", bound=_func)
|
||||
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
|
||||
|
||||
class SleepHandle:
|
||||
__slots__ = ('future', 'loop', 'handle')
|
||||
__slots__ = ("future", "loop", "handle")
|
||||
|
||||
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
def __init__(
|
||||
self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self.loop = loop
|
||||
self.future = future = loop.create_future()
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
@ -124,7 +124,7 @@ class Loop(Generic[LF]):
|
||||
self._stop_next_iteration = False
|
||||
|
||||
if self.count is not None and self.count <= 0:
|
||||
raise ValueError('count must be greater than 0 or None.')
|
||||
raise ValueError("count must be greater than 0 or None.")
|
||||
|
||||
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
||||
self._last_iteration_failed = False
|
||||
@ -132,10 +132,12 @@ class Loop(Generic[LF]):
|
||||
self._next_iteration = None
|
||||
|
||||
if not inspect.iscoroutinefunction(self.coro):
|
||||
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
|
||||
raise TypeError(
|
||||
f"Expected coroutine function, not {type(self.coro).__name__!r}."
|
||||
)
|
||||
|
||||
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
|
||||
coro = getattr(self, '_' + name)
|
||||
coro = getattr(self, "_" + name)
|
||||
if coro is None:
|
||||
return
|
||||
|
||||
@ -150,7 +152,7 @@ class Loop(Generic[LF]):
|
||||
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
await self._call_loop_function('before_loop')
|
||||
await self._call_loop_function("before_loop")
|
||||
self._last_iteration_failed = False
|
||||
if self._time is not MISSING:
|
||||
# the time index should be prepared every time the internal loop is started
|
||||
@ -193,10 +195,10 @@ class Loop(Generic[LF]):
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._has_failed = True
|
||||
await self._call_loop_function('error', exc)
|
||||
await self._call_loop_function("error", exc)
|
||||
raise exc
|
||||
finally:
|
||||
await self._call_loop_function('after_loop')
|
||||
await self._call_loop_function("after_loop")
|
||||
self._handle.cancel()
|
||||
self._is_being_cancelled = False
|
||||
self._current_loop = 0
|
||||
@ -323,7 +325,7 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if self._task is not MISSING and not self._task.done():
|
||||
raise RuntimeError('Task is already launched and is not completed.')
|
||||
raise RuntimeError("Task is already launched and is not completed.")
|
||||
|
||||
if self._injected is not None:
|
||||
args = (self._injected, *args)
|
||||
@ -356,7 +358,9 @@ class Loop(Generic[LF]):
|
||||
self._stop_next_iteration = True
|
||||
|
||||
def _can_be_cancelled(self) -> bool:
|
||||
return bool(not self._is_being_cancelled and self._task and not self._task.done())
|
||||
return bool(
|
||||
not self._is_being_cancelled and self._task and not self._task.done()
|
||||
)
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancels the internal task, if it is running."""
|
||||
@ -379,7 +383,9 @@ class Loop(Generic[LF]):
|
||||
The keyword arguments to use.
|
||||
"""
|
||||
|
||||
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
|
||||
def restart_when_over(
|
||||
fut: Any, *, args: Any = args, kwargs: Any = kwargs
|
||||
) -> None:
|
||||
self._task.remove_done_callback(restart_when_over)
|
||||
self.start(*args, **kwargs)
|
||||
|
||||
@ -410,9 +416,9 @@ class Loop(Generic[LF]):
|
||||
|
||||
for exc in exceptions:
|
||||
if not inspect.isclass(exc):
|
||||
raise TypeError(f'{exc!r} must be a class.')
|
||||
raise TypeError(f"{exc!r} must be a class.")
|
||||
if not issubclass(exc, BaseException):
|
||||
raise TypeError(f'{exc!r} must inherit from BaseException.')
|
||||
raise TypeError(f"{exc!r} must inherit from BaseException.")
|
||||
|
||||
self._valid_exception = (*self._valid_exception, *exceptions)
|
||||
|
||||
@ -439,7 +445,9 @@ class Loop(Generic[LF]):
|
||||
Whether all exceptions were successfully removed.
|
||||
"""
|
||||
old_length = len(self._valid_exception)
|
||||
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
||||
self._valid_exception = tuple(
|
||||
x for x in self._valid_exception if x not in exceptions
|
||||
)
|
||||
return len(self._valid_exception) == old_length - len(exceptions)
|
||||
|
||||
def get_task(self) -> Optional[asyncio.Task[None]]:
|
||||
@ -466,8 +474,13 @@ class Loop(Generic[LF]):
|
||||
|
||||
async def _error(self, *args: Any) -> None:
|
||||
exception: Exception = args[-1]
|
||||
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
|
||||
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
|
||||
print(
|
||||
f"Unhandled exception in internal background task {self.coro.__name__!r}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
traceback.print_exception(
|
||||
type(exception), exception, exception.__traceback__, file=sys.stderr
|
||||
)
|
||||
|
||||
def before_loop(self, coro: FT) -> FT:
|
||||
"""A decorator that registers a coroutine to be called before the loop starts running.
|
||||
@ -489,7 +502,9 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(
|
||||
f"Expected coroutine function, received {coro.__class__.__name__!r}."
|
||||
)
|
||||
|
||||
self._before_loop = coro
|
||||
return coro
|
||||
@ -517,7 +532,9 @@ class Loop(Generic[LF]):
|
||||
"""
|
||||
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(
|
||||
f"Expected coroutine function, received {coro.__class__.__name__!r}."
|
||||
)
|
||||
|
||||
self._after_loop = coro
|
||||
return coro
|
||||
@ -543,7 +560,9 @@ class Loop(Generic[LF]):
|
||||
The function was not a coroutine.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
raise TypeError(
|
||||
f"Expected coroutine function, received {coro.__class__.__name__!r}."
|
||||
)
|
||||
|
||||
self._error = coro # type: ignore
|
||||
return coro
|
||||
@ -557,14 +576,18 @@ class Loop(Generic[LF]):
|
||||
if self._current_loop == 0:
|
||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
||||
return datetime.datetime.combine(
|
||||
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
+ datetime.timedelta(days=1),
|
||||
self._time[0],
|
||||
)
|
||||
|
||||
next_time = self._time[self._time_index]
|
||||
|
||||
if self._current_loop == 0:
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||
return datetime.datetime.combine(
|
||||
datetime.datetime.now(datetime.timezone.utc), next_time
|
||||
)
|
||||
|
||||
next_date = self._last_iteration
|
||||
if self._time_index == 0:
|
||||
@ -580,7 +603,9 @@ class Loop(Generic[LF]):
|
||||
|
||||
# pre-condition: self._time is set
|
||||
time_now = (
|
||||
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||
now
|
||||
if now is not MISSING
|
||||
else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||
).timetz()
|
||||
for idx, time in enumerate(self._time):
|
||||
if time >= time_now:
|
||||
@ -601,16 +626,16 @@ class Loop(Generic[LF]):
|
||||
return [inner]
|
||||
if not isinstance(time, Sequence):
|
||||
raise TypeError(
|
||||
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
|
||||
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
|
||||
)
|
||||
if not time:
|
||||
raise ValueError('time parameter must not be an empty sequence.')
|
||||
raise ValueError("time parameter must not be an empty sequence.")
|
||||
|
||||
ret: List[datetime.time] = []
|
||||
for index, t in enumerate(time):
|
||||
if not isinstance(t, dt):
|
||||
raise TypeError(
|
||||
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
|
||||
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
|
||||
)
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
@ -663,7 +688,7 @@ class Loop(Generic[LF]):
|
||||
hours = hours or 0
|
||||
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
|
||||
if sleep < 0:
|
||||
raise ValueError('Total number of seconds cannot be less than zero.')
|
||||
raise ValueError("Total number of seconds cannot be less than zero.")
|
||||
|
||||
self._sleep = sleep
|
||||
self._seconds = float(seconds)
|
||||
@ -672,7 +697,7 @@ class Loop(Generic[LF]):
|
||||
self._time: List[datetime.time] = MISSING
|
||||
else:
|
||||
if any((seconds, minutes, hours)):
|
||||
raise TypeError('Cannot mix explicit time with relative time')
|
||||
raise TypeError("Cannot mix explicit time with relative time")
|
||||
self._time = self._get_time_parameter(time)
|
||||
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
||||
|
||||
|
@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union
|
||||
import os
|
||||
import io
|
||||
|
||||
__all__ = (
|
||||
'File',
|
||||
)
|
||||
__all__ = ("File",)
|
||||
|
||||
|
||||
class File:
|
||||
@ -64,7 +62,7 @@ class File:
|
||||
Whether the attachment is a spoiler.
|
||||
"""
|
||||
|
||||
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer')
|
||||
__slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
fp: io.BufferedIOBase
|
||||
@ -80,12 +78,12 @@ class File:
|
||||
):
|
||||
if isinstance(fp, io.IOBase):
|
||||
if not (fp.seekable() and fp.readable()):
|
||||
raise ValueError(f'File buffer {fp!r} must be seekable and readable')
|
||||
raise ValueError(f"File buffer {fp!r} must be seekable and readable")
|
||||
self.fp = fp
|
||||
self._original_pos = fp.tell()
|
||||
self._owner = False
|
||||
else:
|
||||
self.fp = open(fp, 'rb')
|
||||
self.fp = open(fp, "rb")
|
||||
self._original_pos = 0
|
||||
self._owner = True
|
||||
|
||||
@ -100,14 +98,20 @@ class File:
|
||||
if isinstance(fp, str):
|
||||
_, self.filename = os.path.split(fp)
|
||||
else:
|
||||
self.filename = getattr(fp, 'name', None)
|
||||
self.filename = getattr(fp, "name", None)
|
||||
else:
|
||||
self.filename = filename
|
||||
|
||||
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
|
||||
self.filename = 'SPOILER_' + self.filename
|
||||
if (
|
||||
spoiler
|
||||
and self.filename is not None
|
||||
and not self.filename.startswith("SPOILER_")
|
||||
):
|
||||
self.filename = "SPOILER_" + self.filename
|
||||
|
||||
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_'))
|
||||
self.spoiler = spoiler or (
|
||||
self.filename is not None and self.filename.startswith("SPOILER_")
|
||||
)
|
||||
|
||||
def reset(self, *, seek: Union[int, bool] = True) -> None:
|
||||
# The `seek` parameter is needed because
|
||||
|
@ -24,21 +24,34 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from .enums import UserFlags
|
||||
|
||||
__all__ = (
|
||||
'SystemChannelFlags',
|
||||
'MessageFlags',
|
||||
'PublicUserFlags',
|
||||
'Intents',
|
||||
'MemberCacheFlags',
|
||||
'ApplicationFlags',
|
||||
"SystemChannelFlags",
|
||||
"MessageFlags",
|
||||
"PublicUserFlags",
|
||||
"Intents",
|
||||
"MemberCacheFlags",
|
||||
"ApplicationFlags",
|
||||
)
|
||||
|
||||
FV = TypeVar('FV', bound='flag_value')
|
||||
BF = TypeVar('BF', bound='BaseFlags')
|
||||
FV = TypeVar("FV", bound="flag_value")
|
||||
BF = TypeVar("BF", bound="BaseFlags")
|
||||
|
||||
|
||||
class flag_value:
|
||||
@ -63,7 +76,7 @@ class flag_value:
|
||||
instance._set_flag(self.flag, value)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<flag_value flag={self.flag!r}>'
|
||||
return f"<flag_value flag={self.flag!r}>"
|
||||
|
||||
|
||||
class alias_flag_value(flag_value):
|
||||
@ -98,13 +111,13 @@ class BaseFlags:
|
||||
|
||||
value: int
|
||||
|
||||
__slots__ = ('value',)
|
||||
__slots__ = ("value",)
|
||||
|
||||
def __init__(self, **kwargs: bool):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid flag name.')
|
||||
raise TypeError(f"{key!r} is not a valid flag name.")
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
@ -123,7 +136,7 @@ class BaseFlags:
|
||||
return hash(self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} value={self.value}>'
|
||||
return f"<{self.__class__.__name__} value={self.value}>"
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, bool]]:
|
||||
for name, value in self.__class__.__dict__.items():
|
||||
@ -142,7 +155,9 @@ class BaseFlags:
|
||||
elif toggle is False:
|
||||
self.value &= ~o
|
||||
else:
|
||||
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.')
|
||||
raise TypeError(
|
||||
f"Value to set for {self.__class__.__name__} must be a bool."
|
||||
)
|
||||
|
||||
|
||||
@fill_with_flags(inverted=True)
|
||||
@ -196,7 +211,7 @@ class SystemChannelFlags(BaseFlags):
|
||||
elif toggle is False:
|
||||
self.value |= o
|
||||
else:
|
||||
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
|
||||
raise TypeError("Value to set for SystemChannelFlags must be a bool.")
|
||||
|
||||
@flag_value
|
||||
def join_notifications(self):
|
||||
@ -412,7 +427,11 @@ class PublicUserFlags(BaseFlags):
|
||||
|
||||
def all(self) -> List[UserFlags]:
|
||||
"""List[:class:`UserFlags`]: Returns all public flags the user has."""
|
||||
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
|
||||
return [
|
||||
public_flag
|
||||
for public_flag in UserFlags
|
||||
if self._has_flag(public_flag.value)
|
||||
]
|
||||
|
||||
|
||||
@fill_with_flags()
|
||||
@ -461,7 +480,7 @@ class Intents(BaseFlags):
|
||||
self.value = self.DEFAULT_VALUE
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid flag name.')
|
||||
raise TypeError(f"{key!r} is not a valid flag name.")
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
@ -907,7 +926,7 @@ class MemberCacheFlags(BaseFlags):
|
||||
self.value = (1 << bits) - 1
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid flag name.')
|
||||
raise TypeError(f"{key!r} is not a valid flag name.")
|
||||
setattr(self, key, value)
|
||||
|
||||
@classmethod
|
||||
@ -977,10 +996,10 @@ class MemberCacheFlags(BaseFlags):
|
||||
|
||||
def _verify_intents(self, intents: Intents):
|
||||
if self.voice and not intents.voice_states:
|
||||
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states')
|
||||
raise ValueError("MemberCacheFlags.voice requires Intents.voice_states")
|
||||
|
||||
if self.joined and not intents.members:
|
||||
raise ValueError('MemberCacheFlags.joined requires Intents.members')
|
||||
raise ValueError("MemberCacheFlags.joined requires Intents.members")
|
||||
|
||||
@property
|
||||
def _voice_only(self):
|
||||
|
@ -43,25 +43,31 @@ from .errors import ConnectionClosed, InvalidArgument
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'DiscordWebSocket',
|
||||
'KeepAliveHandler',
|
||||
'VoiceKeepAliveHandler',
|
||||
'DiscordVoiceWebSocket',
|
||||
'ReconnectWebSocket',
|
||||
"DiscordWebSocket",
|
||||
"KeepAliveHandler",
|
||||
"VoiceKeepAliveHandler",
|
||||
"DiscordVoiceWebSocket",
|
||||
"ReconnectWebSocket",
|
||||
)
|
||||
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to safely reconnect the websocket."""
|
||||
|
||||
def __init__(self, shard_id, *, resume=True):
|
||||
self.shard_id = shard_id
|
||||
self.resume = resume
|
||||
self.op = 'RESUME' if resume else 'IDENTIFY'
|
||||
self.op = "RESUME" if resume else "IDENTIFY"
|
||||
|
||||
|
||||
class WebSocketClosure(Exception):
|
||||
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
|
||||
|
||||
pass
|
||||
|
||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||
|
||||
EventListener = namedtuple("EventListener", "predicate event result future")
|
||||
|
||||
|
||||
class GatewayRatelimiter:
|
||||
def __init__(self, count=110, per=60.0):
|
||||
@ -101,48 +107,57 @@ class GatewayRatelimiter:
|
||||
async with self.lock:
|
||||
delta = self.get_delay()
|
||||
if delta:
|
||||
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
|
||||
_log.warning(
|
||||
"WebSocket in shard ID %s is ratelimited, waiting %.2f seconds",
|
||||
self.shard_id,
|
||||
delta,
|
||||
)
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
|
||||
class KeepAliveHandler(threading.Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
ws = kwargs.pop('ws', None)
|
||||
interval = kwargs.pop('interval', None)
|
||||
shard_id = kwargs.pop('shard_id', None)
|
||||
ws = kwargs.pop("ws", None)
|
||||
interval = kwargs.pop("interval", None)
|
||||
shard_id = kwargs.pop("shard_id", None)
|
||||
threading.Thread.__init__(self, *args, **kwargs)
|
||||
self.ws = ws
|
||||
self._main_thread_id = ws.thread_id
|
||||
self.interval = interval
|
||||
self.daemon = True
|
||||
self.shard_id = shard_id
|
||||
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
|
||||
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
|
||||
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
|
||||
self.msg = "Keeping shard ID %s websocket alive with sequence %s."
|
||||
self.block_msg = "Shard ID %s heartbeat blocked for more than %s seconds."
|
||||
self.behind_msg = "Can't keep up, shard ID %s websocket is %.1fs behind."
|
||||
self._stop_ev = threading.Event()
|
||||
self._last_ack = time.perf_counter()
|
||||
self._last_send = time.perf_counter()
|
||||
self._last_recv = time.perf_counter()
|
||||
self.latency = float('inf')
|
||||
self.latency = float("inf")
|
||||
self.heartbeat_timeout = ws._max_heartbeat_timeout
|
||||
|
||||
def run(self):
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
_log.warning(
|
||||
"Shard ID %s has stopped responding to the gateway. Closing and restarting.",
|
||||
self.shard_id,
|
||||
)
|
||||
coro = self.ws.close(4000)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
|
||||
try:
|
||||
f.result()
|
||||
except Exception:
|
||||
_log.exception('An error occurred while stopping the gateway. Ignoring.')
|
||||
_log.exception(
|
||||
"An error occurred while stopping the gateway. Ignoring."
|
||||
)
|
||||
finally:
|
||||
self.stop()
|
||||
return
|
||||
|
||||
data = self.get_payload()
|
||||
_log.debug(self.msg, self.shard_id, data['d'])
|
||||
_log.debug(self.msg, self.shard_id, data["d"])
|
||||
coro = self.ws.send_heartbeat(data)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
try:
|
||||
@ -159,8 +174,8 @@ class KeepAliveHandler(threading.Thread):
|
||||
except KeyError:
|
||||
msg = self.block_msg
|
||||
else:
|
||||
stack = ''.join(traceback.format_stack(frame))
|
||||
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
|
||||
stack = "".join(traceback.format_stack(frame))
|
||||
msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}"
|
||||
_log.warning(msg, self.shard_id, total)
|
||||
|
||||
except Exception:
|
||||
@ -169,10 +184,7 @@ class KeepAliveHandler(threading.Thread):
|
||||
self._last_send = time.perf_counter()
|
||||
|
||||
def get_payload(self):
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': self.ws.sequence
|
||||
}
|
||||
return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence}
|
||||
|
||||
def stop(self):
|
||||
self._stop_ev.set()
|
||||
@ -187,19 +199,17 @@ class KeepAliveHandler(threading.Thread):
|
||||
if self.latency > 10:
|
||||
_log.warning(self.behind_msg, self.shard_id, self.latency)
|
||||
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.recent_ack_latencies = deque(maxlen=20)
|
||||
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
|
||||
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
|
||||
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
|
||||
self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s."
|
||||
self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds"
|
||||
self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind"
|
||||
|
||||
def get_payload(self):
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': int(time.time() * 1000)
|
||||
}
|
||||
return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)}
|
||||
|
||||
def ack(self):
|
||||
ack_time = time.perf_counter()
|
||||
@ -208,10 +218,12 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
self.latency = ack_time - self._last_send
|
||||
self.recent_ack_latencies.append(self.latency)
|
||||
|
||||
|
||||
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
|
||||
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
|
||||
async def close(self, *, code: int = 4000, message: bytes = b"") -> bool:
|
||||
return await super().close(code=code, message=message)
|
||||
|
||||
|
||||
class DiscordWebSocket:
|
||||
"""Implements a WebSocket for Discord's gateway v6.
|
||||
|
||||
@ -252,19 +264,19 @@ class DiscordWebSocket:
|
||||
The authentication token for discord.
|
||||
"""
|
||||
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
PRESENCE = 3
|
||||
VOICE_STATE = 4
|
||||
VOICE_PING = 5
|
||||
RESUME = 6
|
||||
RECONNECT = 7
|
||||
REQUEST_MEMBERS = 8
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
PRESENCE = 3
|
||||
VOICE_STATE = 4
|
||||
VOICE_PING = 5
|
||||
RESUME = 6
|
||||
RECONNECT = 7
|
||||
REQUEST_MEMBERS = 8
|
||||
INVALIDATE_SESSION = 9
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
GUILD_SYNC = 12
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
GUILD_SYNC = 12
|
||||
|
||||
def __init__(self, socket, *, loop):
|
||||
self.socket = socket
|
||||
@ -294,13 +306,23 @@ class DiscordWebSocket:
|
||||
return self._rate_limiter.is_ratelimited()
|
||||
|
||||
def debug_log_receive(self, data, /):
|
||||
self._dispatch('socket_raw_receive', data)
|
||||
self._dispatch("socket_raw_receive", data)
|
||||
|
||||
def log_receive(self, _, /):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||
async def from_client(
|
||||
cls,
|
||||
client,
|
||||
*,
|
||||
initial=False,
|
||||
gateway=None,
|
||||
shard_id=None,
|
||||
session=None,
|
||||
sequence=None,
|
||||
resume=False,
|
||||
):
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@ -330,7 +352,7 @@ class DiscordWebSocket:
|
||||
|
||||
client._connection._update_references(ws)
|
||||
|
||||
_log.debug('Created websocket connected to %s', gateway)
|
||||
_log.debug("Created websocket connected to %s", gateway)
|
||||
|
||||
# poll event for OP Hello
|
||||
await ws.poll_event()
|
||||
@ -363,83 +385,87 @@ class DiscordWebSocket:
|
||||
"""
|
||||
|
||||
future = self.loop.create_future()
|
||||
entry = EventListener(event=event, predicate=predicate, result=result, future=future)
|
||||
entry = EventListener(
|
||||
event=event, predicate=predicate, result=result, future=future
|
||||
)
|
||||
self._dispatch_listeners.append(entry)
|
||||
return future
|
||||
|
||||
async def identify(self):
|
||||
"""Sends the IDENTIFY packet."""
|
||||
payload = {
|
||||
'op': self.IDENTIFY,
|
||||
'd': {
|
||||
'token': self.token,
|
||||
'properties': {
|
||||
'$os': sys.platform,
|
||||
'$browser': 'discord.py',
|
||||
'$device': 'discord.py',
|
||||
'$referrer': '',
|
||||
'$referring_domain': ''
|
||||
"op": self.IDENTIFY,
|
||||
"d": {
|
||||
"token": self.token,
|
||||
"properties": {
|
||||
"$os": sys.platform,
|
||||
"$browser": "discord.py",
|
||||
"$device": "discord.py",
|
||||
"$referrer": "",
|
||||
"$referring_domain": "",
|
||||
},
|
||||
'compress': True,
|
||||
'large_threshold': 250,
|
||||
'v': 3
|
||||
}
|
||||
"compress": True,
|
||||
"large_threshold": 250,
|
||||
"v": 3,
|
||||
},
|
||||
}
|
||||
|
||||
if self.shard_id is not None and self.shard_count is not None:
|
||||
payload['d']['shard'] = [self.shard_id, self.shard_count]
|
||||
payload["d"]["shard"] = [self.shard_id, self.shard_count]
|
||||
|
||||
state = self._connection
|
||||
if state._activity is not None or state._status is not None:
|
||||
payload['d']['presence'] = {
|
||||
'status': state._status,
|
||||
'game': state._activity,
|
||||
'since': 0,
|
||||
'afk': False
|
||||
payload["d"]["presence"] = {
|
||||
"status": state._status,
|
||||
"game": state._activity,
|
||||
"since": 0,
|
||||
"afk": False,
|
||||
}
|
||||
|
||||
if state._intents is not None:
|
||||
payload['d']['intents'] = state._intents.value
|
||||
payload["d"]["intents"] = state._intents.value
|
||||
|
||||
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
|
||||
await self.call_hooks(
|
||||
"before_identify", self.shard_id, initial=self._initial_identify
|
||||
)
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
_log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id)
|
||||
|
||||
async def resume(self):
|
||||
"""Sends the RESUME packet."""
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
'd': {
|
||||
'seq': self.sequence,
|
||||
'session_id': self.session_id,
|
||||
'token': self.token
|
||||
}
|
||||
"op": self.RESUME,
|
||||
"d": {
|
||||
"seq": self.sequence,
|
||||
"session_id": self.session_id,
|
||||
"token": self.token,
|
||||
},
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
_log.info("Shard ID %s has sent the RESUME payload.", self.shard_id)
|
||||
|
||||
async def received_message(self, msg, /):
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
|
||||
if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff":
|
||||
return
|
||||
msg = self._zlib.decompress(self._buffer)
|
||||
msg = msg.decode('utf-8')
|
||||
msg = msg.decode("utf-8")
|
||||
self._buffer = bytearray()
|
||||
|
||||
self.log_receive(msg)
|
||||
msg = utils._from_json(msg)
|
||||
|
||||
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
|
||||
event = msg.get('t')
|
||||
_log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg)
|
||||
event = msg.get("t")
|
||||
if event:
|
||||
self._dispatch('socket_event_type', event)
|
||||
self._dispatch("socket_event_type", event)
|
||||
|
||||
op = msg.get('op')
|
||||
data = msg.get('d')
|
||||
seq = msg.get('s')
|
||||
op = msg.get("op")
|
||||
data = msg.get("d")
|
||||
seq = msg.get("s")
|
||||
if seq is not None:
|
||||
self.sequence = seq
|
||||
|
||||
@ -451,7 +477,7 @@ class DiscordWebSocket:
|
||||
# "reconnect" can only be handled by the Client
|
||||
# so we terminate our connection and raise an
|
||||
# internal exception signalling to reconnect.
|
||||
_log.debug('Received RECONNECT opcode.')
|
||||
_log.debug("Received RECONNECT opcode.")
|
||||
await self.close()
|
||||
raise ReconnectWebSocket(self.shard_id)
|
||||
|
||||
@ -467,8 +493,10 @@ class DiscordWebSocket:
|
||||
return
|
||||
|
||||
if op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id)
|
||||
interval = data["heartbeat_interval"] / 1000.0
|
||||
self._keep_alive = KeepAliveHandler(
|
||||
ws=self, interval=interval, shard_id=self.shard_id
|
||||
)
|
||||
# send a heartbeat immediately
|
||||
await self.send_as_json(self._keep_alive.get_payload())
|
||||
self._keep_alive.start()
|
||||
@ -481,33 +509,41 @@ class DiscordWebSocket:
|
||||
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
|
||||
_log.info("Shard ID %s session has been invalidated.", self.shard_id)
|
||||
await self.close(code=1000)
|
||||
raise ReconnectWebSocket(self.shard_id, resume=False)
|
||||
|
||||
_log.warning('Unknown OP code %s.', op)
|
||||
_log.warning("Unknown OP code %s.", op)
|
||||
return
|
||||
|
||||
if event == 'READY':
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
self.sequence = msg['s']
|
||||
self.session_id = data['session_id']
|
||||
if event == "READY":
|
||||
self._trace = trace = data.get("_trace", [])
|
||||
self.sequence = msg["s"]
|
||||
self.session_id = data["session_id"]
|
||||
# pass back shard ID to ready handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
self.shard_id, ', '.join(trace), self.session_id)
|
||||
data["__shard_id__"] = self.shard_id
|
||||
_log.info(
|
||||
"Shard ID %s has connected to Gateway: %s (Session ID: %s).",
|
||||
self.shard_id,
|
||||
", ".join(trace),
|
||||
self.session_id,
|
||||
)
|
||||
|
||||
elif event == 'RESUMED':
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
elif event == "RESUMED":
|
||||
self._trace = trace = data.get("_trace", [])
|
||||
# pass back the shard ID to the resumed handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
|
||||
self.shard_id, self.session_id, ', '.join(trace))
|
||||
data["__shard_id__"] = self.shard_id
|
||||
_log.info(
|
||||
"Shard ID %s has successfully RESUMED session %s under trace %s.",
|
||||
self.shard_id,
|
||||
self.session_id,
|
||||
", ".join(trace),
|
||||
)
|
||||
|
||||
try:
|
||||
func = self._discord_parsers[event]
|
||||
except KeyError:
|
||||
_log.debug('Unknown event %s.', event)
|
||||
_log.debug("Unknown event %s.", event)
|
||||
else:
|
||||
func(data)
|
||||
|
||||
@ -540,7 +576,7 @@ class DiscordWebSocket:
|
||||
def latency(self):
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
return float("inf") if heartbeat is None else heartbeat.latency
|
||||
|
||||
def _can_handle_close(self):
|
||||
code = self._close_code or self.socket.close_code
|
||||
@ -561,10 +597,14 @@ class DiscordWebSocket:
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
await self.received_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug("Received %s", msg)
|
||||
raise msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
|
||||
_log.debug('Received %s', msg)
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSED,
|
||||
aiohttp.WSMsgType.CLOSING,
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
):
|
||||
_log.debug("Received %s", msg)
|
||||
raise WebSocketClosure
|
||||
except (asyncio.TimeoutError, WebSocketClosure) as e:
|
||||
# Ensure the keep alive handler is closed
|
||||
@ -573,20 +613,22 @@ class DiscordWebSocket:
|
||||
self._keep_alive = None
|
||||
|
||||
if isinstance(e, asyncio.TimeoutError):
|
||||
_log.info('Timed out receiving packet. Attempting a reconnect.')
|
||||
_log.info("Timed out receiving packet. Attempting a reconnect.")
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
|
||||
code = self._close_code or self.socket.close_code
|
||||
if self._can_handle_close():
|
||||
_log.info('Websocket closed with %s, attempting a reconnect.', code)
|
||||
_log.info("Websocket closed with %s, attempting a reconnect.", code)
|
||||
raise ReconnectWebSocket(self.shard_id) from None
|
||||
else:
|
||||
_log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
|
||||
_log.info("Websocket closed with %s, cannot reconnect.", code)
|
||||
raise ConnectionClosed(
|
||||
self.socket, shard_id=self.shard_id, code=code
|
||||
) from None
|
||||
|
||||
async def debug_send(self, data, /):
|
||||
await self._rate_limiter.block()
|
||||
self._dispatch('socket_raw_send', data)
|
||||
self._dispatch("socket_raw_send", data)
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send(self, data, /):
|
||||
@ -611,62 +653,59 @@ class DiscordWebSocket:
|
||||
async def change_presence(self, *, activity=None, status=None, since=0.0):
|
||||
if activity is not None:
|
||||
if not isinstance(activity, BaseActivity):
|
||||
raise InvalidArgument('activity must derive from BaseActivity.')
|
||||
raise InvalidArgument("activity must derive from BaseActivity.")
|
||||
activity = [activity.to_dict()]
|
||||
else:
|
||||
activity = []
|
||||
|
||||
if status == 'idle':
|
||||
if status == "idle":
|
||||
since = int(time.time() * 1000)
|
||||
|
||||
payload = {
|
||||
'op': self.PRESENCE,
|
||||
'd': {
|
||||
'activities': activity,
|
||||
'afk': False,
|
||||
'since': since,
|
||||
'status': status
|
||||
}
|
||||
"op": self.PRESENCE,
|
||||
"d": {
|
||||
"activities": activity,
|
||||
"afk": False,
|
||||
"since": since,
|
||||
"status": status,
|
||||
},
|
||||
}
|
||||
|
||||
sent = utils._to_json(payload)
|
||||
_log.debug('Sending "%s" to change status', sent)
|
||||
await self.send(sent)
|
||||
|
||||
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
|
||||
async def request_chunks(
|
||||
self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None
|
||||
):
|
||||
payload = {
|
||||
'op': self.REQUEST_MEMBERS,
|
||||
'd': {
|
||||
'guild_id': guild_id,
|
||||
'presences': presences,
|
||||
'limit': limit
|
||||
}
|
||||
"op": self.REQUEST_MEMBERS,
|
||||
"d": {"guild_id": guild_id, "presences": presences, "limit": limit},
|
||||
}
|
||||
|
||||
if nonce:
|
||||
payload['d']['nonce'] = nonce
|
||||
payload["d"]["nonce"] = nonce
|
||||
|
||||
if user_ids:
|
||||
payload['d']['user_ids'] = user_ids
|
||||
payload["d"]["user_ids"] = user_ids
|
||||
|
||||
if query is not None:
|
||||
payload['d']['query'] = query
|
||||
|
||||
payload["d"]["query"] = query
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||
payload = {
|
||||
'op': self.VOICE_STATE,
|
||||
'd': {
|
||||
'guild_id': guild_id,
|
||||
'channel_id': channel_id,
|
||||
'self_mute': self_mute,
|
||||
'self_deaf': self_deaf
|
||||
}
|
||||
"op": self.VOICE_STATE,
|
||||
"d": {
|
||||
"guild_id": guild_id,
|
||||
"channel_id": channel_id,
|
||||
"self_mute": self_mute,
|
||||
"self_deaf": self_deaf,
|
||||
},
|
||||
}
|
||||
|
||||
_log.debug('Updating our voice state to %s.', payload)
|
||||
_log.debug("Updating our voice state to %s.", payload)
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def close(self, code=4000):
|
||||
@ -677,6 +716,7 @@ class DiscordWebSocket:
|
||||
self._close_code = code
|
||||
await self.socket.close(code=code)
|
||||
|
||||
|
||||
class DiscordVoiceWebSocket:
|
||||
"""Implements the websocket protocol for handling voice connections.
|
||||
|
||||
@ -708,18 +748,18 @@ class DiscordVoiceWebSocket:
|
||||
Receive only. Indicates a user has disconnected from voice.
|
||||
"""
|
||||
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
|
||||
def __init__(self, socket, loop, *, hook=None):
|
||||
self.ws = socket
|
||||
@ -734,7 +774,7 @@ class DiscordVoiceWebSocket:
|
||||
pass
|
||||
|
||||
async def send_as_json(self, data):
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
_log.debug("Sending voice websocket frame: %s.", data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
@ -742,32 +782,32 @@ class DiscordVoiceWebSocket:
|
||||
async def resume(self):
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
'd': {
|
||||
'token': state.token,
|
||||
'server_id': str(state.server_id),
|
||||
'session_id': state.session_id
|
||||
}
|
||||
"op": self.RESUME,
|
||||
"d": {
|
||||
"token": state.token,
|
||||
"server_id": str(state.server_id),
|
||||
"session_id": state.session_id,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def identify(self):
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.IDENTIFY,
|
||||
'd': {
|
||||
'server_id': str(state.server_id),
|
||||
'user_id': str(state.user.id),
|
||||
'session_id': state.session_id,
|
||||
'token': state.token
|
||||
}
|
||||
"op": self.IDENTIFY,
|
||||
"d": {
|
||||
"server_id": str(state.server_id),
|
||||
"user_id": str(state.user.id),
|
||||
"session_id": state.session_id,
|
||||
"token": state.token,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, resume=False, hook=None):
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
gateway = "wss://" + client.endpoint + "/?v=4"
|
||||
http = client._state.http
|
||||
socket = await http.ws_connect(gateway, compress=15)
|
||||
ws = cls(socket, loop=client.loop, hook=hook)
|
||||
@ -785,109 +825,101 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
async def select_protocol(self, ip, port, mode):
|
||||
payload = {
|
||||
'op': self.SELECT_PROTOCOL,
|
||||
'd': {
|
||||
'protocol': 'udp',
|
||||
'data': {
|
||||
'address': ip,
|
||||
'port': port,
|
||||
'mode': mode
|
||||
}
|
||||
}
|
||||
"op": self.SELECT_PROTOCOL,
|
||||
"d": {
|
||||
"protocol": "udp",
|
||||
"data": {"address": ip, "port": port, "mode": mode},
|
||||
},
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def client_connect(self):
|
||||
payload = {
|
||||
'op': self.CLIENT_CONNECT,
|
||||
'd': {
|
||||
'audio_ssrc': self._connection.ssrc
|
||||
}
|
||||
"op": self.CLIENT_CONNECT,
|
||||
"d": {"audio_ssrc": self._connection.ssrc},
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def speak(self, state=SpeakingState.voice):
|
||||
payload = {
|
||||
'op': self.SPEAKING,
|
||||
'd': {
|
||||
'speaking': int(state),
|
||||
'delay': 0
|
||||
}
|
||||
}
|
||||
payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg):
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
data = msg.get('d')
|
||||
_log.debug("Voice websocket frame received: %s", msg)
|
||||
op = msg["op"]
|
||||
data = msg.get("d")
|
||||
|
||||
if op == self.READY:
|
||||
await self.initial_connection(data)
|
||||
elif op == self.HEARTBEAT_ACK:
|
||||
self._keep_alive.ack()
|
||||
elif op == self.RESUMED:
|
||||
_log.info('Voice RESUME succeeded.')
|
||||
_log.info("Voice RESUME succeeded.")
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
self._connection.mode = data["mode"]
|
||||
await self.load_secret_key(data)
|
||||
elif op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
||||
interval = data["heartbeat_interval"] / 1000.0
|
||||
self._keep_alive = VoiceKeepAliveHandler(
|
||||
ws=self, interval=min(interval, 5.0)
|
||||
)
|
||||
self._keep_alive.start()
|
||||
|
||||
await self._hook(self, msg)
|
||||
|
||||
async def initial_connection(self, data):
|
||||
state = self._connection
|
||||
state.ssrc = data['ssrc']
|
||||
state.voice_port = data['port']
|
||||
state.endpoint_ip = data['ip']
|
||||
state.ssrc = data["ssrc"]
|
||||
state.voice_port = data["port"]
|
||||
state.endpoint_ip = data["ip"]
|
||||
|
||||
packet = bytearray(70)
|
||||
struct.pack_into('>H', packet, 0, 1) # 1 = Send
|
||||
struct.pack_into('>H', packet, 2, 70) # 70 = Length
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
struct.pack_into(">H", packet, 0, 1) # 1 = Send
|
||||
struct.pack_into(">H", packet, 2, 70) # 70 = Length
|
||||
struct.pack_into(">I", packet, 4, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = await self.loop.sock_recv(state.socket, 70)
|
||||
_log.debug('received packet in initial_connection: %s', recv)
|
||||
_log.debug("received packet in initial_connection: %s", recv)
|
||||
|
||||
# the ip is ascii starting at the 4th byte and ending at the first null
|
||||
ip_start = 4
|
||||
ip_end = recv.index(0, ip_start)
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
state.ip = recv[ip_start:ip_end].decode("ascii")
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0]
|
||||
_log.debug("detected ip: %s port: %s", state.ip, state.port)
|
||||
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
modes = [
|
||||
mode for mode in data["modes"] if mode in self._connection.supported_modes
|
||||
]
|
||||
_log.debug("received supported encryption modes: %s", ", ".join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
_log.info('selected the voice protocol for use (%s)', mode)
|
||||
_log.info("selected the voice protocol for use (%s)", mode)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
return float("inf") if heartbeat is None else heartbeat.latency
|
||||
|
||||
@property
|
||||
def average_latency(self):
|
||||
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
|
||||
heartbeat = self._keep_alive
|
||||
if heartbeat is None or not heartbeat.recent_ack_latencies:
|
||||
return float('inf')
|
||||
return float("inf")
|
||||
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
async def load_secret_key(self, data):
|
||||
_log.info('received secret key for voice connection')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key')
|
||||
_log.info("received secret key for voice connection")
|
||||
self.secret_key = self._connection.secret_key = data.get("secret_key")
|
||||
await self.speak()
|
||||
await self.speak(False)
|
||||
|
||||
@ -897,10 +929,14 @@ class DiscordVoiceWebSocket:
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug("Received %s", msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
_log.debug('Received %s', msg)
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSED,
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
aiohttp.WSMsgType.CLOSING,
|
||||
):
|
||||
_log.debug("Received %s", msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code=1000):
|
||||
|
641
discord/guild.py
641
discord/guild.py
File diff suppressed because it is too large
Load Diff
1451
discord/http.py
1451
discord/http.py
File diff suppressed because it is too large
Load Diff
@ -32,11 +32,11 @@ from .errors import InvalidArgument
|
||||
from .enums import try_enum, ExpireBehaviour
|
||||
|
||||
__all__ = (
|
||||
'IntegrationAccount',
|
||||
'IntegrationApplication',
|
||||
'Integration',
|
||||
'StreamIntegration',
|
||||
'BotIntegration',
|
||||
"IntegrationAccount",
|
||||
"IntegrationApplication",
|
||||
"Integration",
|
||||
"StreamIntegration",
|
||||
"BotIntegration",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -65,14 +65,14 @@ class IntegrationAccount:
|
||||
The account name.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'name')
|
||||
__slots__ = ("id", "name")
|
||||
|
||||
def __init__(self, data: IntegrationAccountPayload) -> None:
|
||||
self.id: str = data['id']
|
||||
self.name: str = data['name']
|
||||
self.id: str = data["id"]
|
||||
self.name: str = data["name"]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<IntegrationAccount id={self.id} name={self.name!r}>'
|
||||
return f"<IntegrationAccount id={self.id} name={self.name!r}>"
|
||||
|
||||
|
||||
class Integration:
|
||||
@ -99,14 +99,14 @@ class Integration:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'guild',
|
||||
'id',
|
||||
'_state',
|
||||
'type',
|
||||
'name',
|
||||
'account',
|
||||
'user',
|
||||
'enabled',
|
||||
"guild",
|
||||
"id",
|
||||
"_state",
|
||||
"type",
|
||||
"name",
|
||||
"account",
|
||||
"user",
|
||||
"enabled",
|
||||
)
|
||||
|
||||
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
|
||||
@ -118,14 +118,14 @@ class Integration:
|
||||
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
|
||||
|
||||
def _from_data(self, data: IntegrationPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.type: IntegrationType = data['type']
|
||||
self.name: str = data['name']
|
||||
self.account: IntegrationAccount = IntegrationAccount(data['account'])
|
||||
self.id: int = int(data["id"])
|
||||
self.type: IntegrationType = data["type"]
|
||||
self.name: str = data["name"]
|
||||
self.account: IntegrationAccount = IntegrationAccount(data["account"])
|
||||
|
||||
user = data.get('user')
|
||||
user = data.get("user")
|
||||
self.user = User(state=self._state, data=user) if user else None
|
||||
self.enabled: bool = data['enabled']
|
||||
self.enabled: bool = data["enabled"]
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
@ -186,26 +186,28 @@ class StreamIntegration(Integration):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'revoked',
|
||||
'expire_behaviour',
|
||||
'expire_grace_period',
|
||||
'synced_at',
|
||||
'_role_id',
|
||||
'syncing',
|
||||
'enable_emoticons',
|
||||
'subscriber_count',
|
||||
"revoked",
|
||||
"expire_behaviour",
|
||||
"expire_grace_period",
|
||||
"synced_at",
|
||||
"_role_id",
|
||||
"syncing",
|
||||
"enable_emoticons",
|
||||
"subscriber_count",
|
||||
)
|
||||
|
||||
def _from_data(self, data: StreamIntegrationPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.revoked: bool = data['revoked']
|
||||
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior'])
|
||||
self.expire_grace_period: int = data['expire_grace_period']
|
||||
self.synced_at: datetime.datetime = parse_time(data['synced_at'])
|
||||
self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id')
|
||||
self.syncing: bool = data['syncing']
|
||||
self.enable_emoticons: bool = data['enable_emoticons']
|
||||
self.subscriber_count: int = data['subscriber_count']
|
||||
self.revoked: bool = data["revoked"]
|
||||
self.expire_behaviour: ExpireBehaviour = try_enum(
|
||||
ExpireBehaviour, data["expire_behavior"]
|
||||
)
|
||||
self.expire_grace_period: int = data["expire_grace_period"]
|
||||
self.synced_at: datetime.datetime = parse_time(data["synced_at"])
|
||||
self._role_id: Optional[int] = _get_as_snowflake(data, "role_id")
|
||||
self.syncing: bool = data["syncing"]
|
||||
self.enable_emoticons: bool = data["enable_emoticons"]
|
||||
self.subscriber_count: int = data["subscriber_count"]
|
||||
|
||||
@property
|
||||
def expire_behavior(self) -> ExpireBehaviour:
|
||||
@ -252,15 +254,17 @@ class StreamIntegration(Integration):
|
||||
payload: Dict[str, Any] = {}
|
||||
if expire_behaviour is not MISSING:
|
||||
if not isinstance(expire_behaviour, ExpireBehaviour):
|
||||
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour')
|
||||
raise InvalidArgument(
|
||||
"expire_behaviour field must be of type ExpireBehaviour"
|
||||
)
|
||||
|
||||
payload['expire_behavior'] = expire_behaviour.value
|
||||
payload["expire_behavior"] = expire_behaviour.value
|
||||
|
||||
if expire_grace_period is not MISSING:
|
||||
payload['expire_grace_period'] = expire_grace_period
|
||||
payload["expire_grace_period"] = expire_grace_period
|
||||
|
||||
if enable_emoticons is not MISSING:
|
||||
payload['enable_emoticons'] = enable_emoticons
|
||||
payload["enable_emoticons"] = enable_emoticons
|
||||
|
||||
# This endpoint is undocumented.
|
||||
# Unsure if it returns the data or not as a result
|
||||
@ -307,21 +311,21 @@ class IntegrationApplication:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'id',
|
||||
'name',
|
||||
'icon',
|
||||
'description',
|
||||
'summary',
|
||||
'user',
|
||||
"id",
|
||||
"name",
|
||||
"icon",
|
||||
"description",
|
||||
"summary",
|
||||
"user",
|
||||
)
|
||||
|
||||
def __init__(self, *, data: IntegrationApplicationPayload, state):
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.icon: Optional[str] = data['icon']
|
||||
self.description: str = data['description']
|
||||
self.summary: str = data['summary']
|
||||
user = data.get('bot')
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self.icon: Optional[str] = data["icon"]
|
||||
self.description: str = data["description"]
|
||||
self.summary: str = data["summary"]
|
||||
user = data.get("bot")
|
||||
self.user: Optional[User] = User(state=state, data=user) if user else None
|
||||
|
||||
|
||||
@ -350,17 +354,19 @@ class BotIntegration(Integration):
|
||||
The application tied to this integration.
|
||||
"""
|
||||
|
||||
__slots__ = ('application',)
|
||||
__slots__ = ("application",)
|
||||
|
||||
def _from_data(self, data: BotIntegrationPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.application = IntegrationApplication(data=data['application'], state=self._state)
|
||||
self.application = IntegrationApplication(
|
||||
data=data["application"], state=self._state
|
||||
)
|
||||
|
||||
|
||||
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
|
||||
if value == 'discord':
|
||||
if value == "discord":
|
||||
return BotIntegration, value
|
||||
elif value in ('twitch', 'youtube'):
|
||||
elif value in ("twitch", "youtube"):
|
||||
return StreamIntegration, value
|
||||
else:
|
||||
return Integration, value
|
||||
|
@ -41,9 +41,9 @@ from .permissions import Permissions
|
||||
from .webhook.async_ import async_context, Webhook, handle_message_parameters
|
||||
|
||||
__all__ = (
|
||||
'Interaction',
|
||||
'InteractionMessage',
|
||||
'InteractionResponse',
|
||||
"Interaction",
|
||||
"InteractionMessage",
|
||||
"InteractionResponse",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -58,11 +58,24 @@ if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
from .embeds import Embed
|
||||
from .ui.view import View
|
||||
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
|
||||
from .channel import (
|
||||
VoiceChannel,
|
||||
StageChannel,
|
||||
TextChannel,
|
||||
CategoryChannel,
|
||||
StoreChannel,
|
||||
PartialMessageable,
|
||||
)
|
||||
from .threads import Thread
|
||||
|
||||
InteractionChannel = Union[
|
||||
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
|
||||
VoiceChannel,
|
||||
StageChannel,
|
||||
TextChannel,
|
||||
CategoryChannel,
|
||||
StoreChannel,
|
||||
Thread,
|
||||
PartialMessageable,
|
||||
]
|
||||
|
||||
MISSING: Any = utils.MISSING
|
||||
@ -100,23 +113,23 @@ class Interaction:
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'id',
|
||||
'type',
|
||||
'guild_id',
|
||||
'channel_id',
|
||||
'data',
|
||||
'application_id',
|
||||
'message',
|
||||
'user',
|
||||
'token',
|
||||
'version',
|
||||
'_permissions',
|
||||
'_state',
|
||||
'_session',
|
||||
'_original_message',
|
||||
'_cs_response',
|
||||
'_cs_followup',
|
||||
'_cs_channel',
|
||||
"id",
|
||||
"type",
|
||||
"guild_id",
|
||||
"channel_id",
|
||||
"data",
|
||||
"application_id",
|
||||
"message",
|
||||
"user",
|
||||
"token",
|
||||
"version",
|
||||
"_permissions",
|
||||
"_state",
|
||||
"_session",
|
||||
"_original_message",
|
||||
"_cs_response",
|
||||
"_cs_followup",
|
||||
"_cs_channel",
|
||||
)
|
||||
|
||||
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
|
||||
@ -126,18 +139,18 @@ class Interaction:
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: InteractionPayload):
|
||||
self.id: int = int(data['id'])
|
||||
self.type: InteractionType = try_enum(InteractionType, data['type'])
|
||||
self.data: Optional[InteractionData] = data.get('data')
|
||||
self.token: str = data['token']
|
||||
self.version: int = data['version']
|
||||
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.application_id: int = int(data['application_id'])
|
||||
self.id: int = int(data["id"])
|
||||
self.type: InteractionType = try_enum(InteractionType, data["type"])
|
||||
self.data: Optional[InteractionData] = data.get("data")
|
||||
self.token: str = data["token"]
|
||||
self.version: int = data["version"]
|
||||
self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id")
|
||||
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
|
||||
self.application_id: int = int(data["application_id"])
|
||||
|
||||
self.message: Optional[Message]
|
||||
try:
|
||||
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
|
||||
self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore
|
||||
except KeyError:
|
||||
self.message = None
|
||||
|
||||
@ -148,15 +161,15 @@ class Interaction:
|
||||
if self.guild_id:
|
||||
guild = self.guild or Object(id=self.guild_id)
|
||||
try:
|
||||
member = data['member'] # type: ignore
|
||||
member = data["member"] # type: ignore
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
|
||||
self._permissions = int(member.get('permissions', 0))
|
||||
self._permissions = int(member.get("permissions", 0))
|
||||
else:
|
||||
try:
|
||||
self.user = User(state=self._state, data=data['user'])
|
||||
self.user = User(state=self._state, data=data["user"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@ -165,7 +178,7 @@ class Interaction:
|
||||
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
|
||||
return self._state and self._state._get_guild(self.guild_id)
|
||||
|
||||
@utils.cached_slot_property('_cs_channel')
|
||||
@utils.cached_slot_property("_cs_channel")
|
||||
def channel(self) -> Optional[InteractionChannel]:
|
||||
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
|
||||
|
||||
@ -176,8 +189,14 @@ class Interaction:
|
||||
channel = guild and guild._resolve_channel(self.channel_id)
|
||||
if channel is None:
|
||||
if self.channel_id is not None:
|
||||
type = ChannelType.text if self.guild_id is not None else ChannelType.private
|
||||
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
|
||||
type = (
|
||||
ChannelType.text
|
||||
if self.guild_id is not None
|
||||
else ChannelType.private
|
||||
)
|
||||
return PartialMessageable(
|
||||
state=self._state, id=self.channel_id, type=type
|
||||
)
|
||||
return None
|
||||
return channel
|
||||
|
||||
@ -189,7 +208,7 @@ class Interaction:
|
||||
"""
|
||||
return Permissions(self._permissions)
|
||||
|
||||
@utils.cached_slot_property('_cs_response')
|
||||
@utils.cached_slot_property("_cs_response")
|
||||
def response(self) -> InteractionResponse:
|
||||
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
|
||||
|
||||
@ -198,13 +217,13 @@ class Interaction:
|
||||
"""
|
||||
return InteractionResponse(self)
|
||||
|
||||
@utils.cached_slot_property('_cs_followup')
|
||||
@utils.cached_slot_property("_cs_followup")
|
||||
def followup(self) -> Webhook:
|
||||
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
|
||||
payload = {
|
||||
'id': self.application_id,
|
||||
'type': 3,
|
||||
'token': self.token,
|
||||
"id": self.application_id,
|
||||
"type": 3,
|
||||
"token": self.token,
|
||||
}
|
||||
return Webhook.from_state(data=payload, state=self._state)
|
||||
|
||||
@ -238,7 +257,7 @@ class Interaction:
|
||||
# TODO: fix later to not raise?
|
||||
channel = self.channel
|
||||
if channel is None:
|
||||
raise ClientException('Channel for message could not be resolved')
|
||||
raise ClientException("Channel for message could not be resolved")
|
||||
|
||||
adapter = async_context.get()
|
||||
data = await adapter.get_original_interaction_response(
|
||||
@ -369,8 +388,8 @@ class InteractionResponse:
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'_responded',
|
||||
'_parent',
|
||||
"_responded",
|
||||
"_parent",
|
||||
)
|
||||
|
||||
def __init__(self, parent: Interaction):
|
||||
@ -416,12 +435,16 @@ class InteractionResponse:
|
||||
elif parent.type is InteractionType.application_command:
|
||||
defer_type = InteractionResponseType.deferred_channel_message.value
|
||||
if ephemeral:
|
||||
data = {'flags': 64}
|
||||
data = {"flags": 64}
|
||||
|
||||
if defer_type:
|
||||
adapter = async_context.get()
|
||||
await adapter.create_interaction_response(
|
||||
parent.id, parent.token, session=parent._session, type=defer_type, data=data
|
||||
parent.id,
|
||||
parent.token,
|
||||
session=parent._session,
|
||||
type=defer_type,
|
||||
data=data,
|
||||
)
|
||||
self._responded = True
|
||||
|
||||
@ -446,7 +469,10 @@ class InteractionResponse:
|
||||
if parent.type is InteractionType.ping:
|
||||
adapter = async_context.get()
|
||||
await adapter.create_interaction_response(
|
||||
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
|
||||
parent.id,
|
||||
parent.token,
|
||||
session=parent._session,
|
||||
type=InteractionResponseType.pong.value,
|
||||
)
|
||||
self._responded = True
|
||||
|
||||
@ -498,28 +524,28 @@ class InteractionResponse:
|
||||
raise InteractionResponded(self._parent)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
'tts': tts,
|
||||
"tts": tts,
|
||||
}
|
||||
|
||||
if embed is not MISSING and embeds is not MISSING:
|
||||
raise TypeError('cannot mix embed and embeds keyword arguments')
|
||||
raise TypeError("cannot mix embed and embeds keyword arguments")
|
||||
|
||||
if embed is not MISSING:
|
||||
embeds = [embed]
|
||||
|
||||
if embeds:
|
||||
if len(embeds) > 10:
|
||||
raise ValueError('embeds cannot exceed maximum of 10 elements')
|
||||
payload['embeds'] = [e.to_dict() for e in embeds]
|
||||
raise ValueError("embeds cannot exceed maximum of 10 elements")
|
||||
payload["embeds"] = [e.to_dict() for e in embeds]
|
||||
|
||||
if content is not None:
|
||||
payload['content'] = str(content)
|
||||
payload["content"] = str(content)
|
||||
|
||||
if ephemeral:
|
||||
payload['flags'] = 64
|
||||
payload["flags"] = 64
|
||||
|
||||
if view is not MISSING:
|
||||
payload['components'] = view.to_components()
|
||||
payload["components"] = view.to_components()
|
||||
|
||||
parent = self._parent
|
||||
adapter = async_context.get()
|
||||
@ -591,12 +617,12 @@ class InteractionResponse:
|
||||
payload = {}
|
||||
if content is not MISSING:
|
||||
if content is None:
|
||||
payload['content'] = None
|
||||
payload["content"] = None
|
||||
else:
|
||||
payload['content'] = str(content)
|
||||
payload["content"] = str(content)
|
||||
|
||||
if embed is not MISSING and embeds is not MISSING:
|
||||
raise TypeError('cannot mix both embed and embeds keyword arguments')
|
||||
raise TypeError("cannot mix both embed and embeds keyword arguments")
|
||||
|
||||
if embed is not MISSING:
|
||||
if embed is None:
|
||||
@ -605,17 +631,17 @@ class InteractionResponse:
|
||||
embeds = [embed]
|
||||
|
||||
if embeds is not MISSING:
|
||||
payload['embeds'] = [e.to_dict() for e in embeds]
|
||||
payload["embeds"] = [e.to_dict() for e in embeds]
|
||||
|
||||
if attachments is not MISSING:
|
||||
payload['attachments'] = [a.to_dict() for a in attachments]
|
||||
payload["attachments"] = [a.to_dict() for a in attachments]
|
||||
|
||||
if view is not MISSING:
|
||||
state.prevent_view_updates_for(message_id)
|
||||
if view is None:
|
||||
payload['components'] = []
|
||||
payload["components"] = []
|
||||
else:
|
||||
payload['components'] = view.to_components()
|
||||
payload["components"] = view.to_components()
|
||||
|
||||
adapter = async_context.get()
|
||||
await adapter.create_interaction_response(
|
||||
@ -633,7 +659,7 @@ class InteractionResponse:
|
||||
|
||||
|
||||
class _InteractionMessageState:
|
||||
__slots__ = ('_parent', '_interaction')
|
||||
__slots__ = ("_parent", "_interaction")
|
||||
|
||||
def __init__(self, interaction: Interaction, parent: ConnectionState):
|
||||
self._interaction: Interaction = interaction
|
||||
|
@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum
|
||||
from .appinfo import PartialAppInfo
|
||||
|
||||
__all__ = (
|
||||
'PartialInviteChannel',
|
||||
'PartialInviteGuild',
|
||||
'Invite',
|
||||
"PartialInviteChannel",
|
||||
"PartialInviteGuild",
|
||||
"Invite",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -52,8 +52,8 @@ if TYPE_CHECKING:
|
||||
from .abc import GuildChannel
|
||||
from .user import User
|
||||
|
||||
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
|
||||
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object]
|
||||
InviteGuildType = Union[Guild, "PartialInviteGuild", Object]
|
||||
InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object]
|
||||
|
||||
import datetime
|
||||
|
||||
@ -92,23 +92,25 @@ class PartialInviteChannel:
|
||||
The partial channel's type.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'name', 'type')
|
||||
__slots__ = ("id", "name", "type")
|
||||
|
||||
def __init__(self, data: InviteChannelPayload):
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.type: ChannelType = try_enum(ChannelType, data['type'])
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self.type: ChannelType = try_enum(ChannelType, data["type"])
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>'
|
||||
return (
|
||||
f"<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>"
|
||||
)
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: The string that allows you to mention the channel."""
|
||||
return f'<#{self.id}>'
|
||||
return f"<#{self.id}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
@ -154,26 +156,38 @@ class PartialInviteGuild:
|
||||
The partial guild's description.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description')
|
||||
__slots__ = (
|
||||
"_state",
|
||||
"features",
|
||||
"_icon",
|
||||
"_banner",
|
||||
"id",
|
||||
"name",
|
||||
"_splash",
|
||||
"verification_level",
|
||||
"description",
|
||||
)
|
||||
|
||||
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = id
|
||||
self.name: str = data['name']
|
||||
self.features: List[str] = data.get('features', [])
|
||||
self._icon: Optional[str] = data.get('icon')
|
||||
self._banner: Optional[str] = data.get('banner')
|
||||
self._splash: Optional[str] = data.get('splash')
|
||||
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
|
||||
self.description: Optional[str] = data.get('description')
|
||||
self.name: str = data["name"]
|
||||
self.features: List[str] = data.get("features", [])
|
||||
self._icon: Optional[str] = data.get("icon")
|
||||
self._banner: Optional[str] = data.get("banner")
|
||||
self._splash: Optional[str] = data.get("splash")
|
||||
self.verification_level: VerificationLevel = try_enum(
|
||||
VerificationLevel, data.get("verification_level")
|
||||
)
|
||||
self.description: Optional[str] = data.get("description")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} '
|
||||
f'description={self.description!r}>'
|
||||
f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} "
|
||||
f"description={self.description!r}>"
|
||||
)
|
||||
|
||||
@property
|
||||
@ -193,17 +207,21 @@ class PartialInviteGuild:
|
||||
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
|
||||
if self._banner is None:
|
||||
return None
|
||||
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners')
|
||||
return Asset._from_guild_image(
|
||||
self._state, self.id, self._banner, path="banners"
|
||||
)
|
||||
|
||||
@property
|
||||
def splash(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
|
||||
if self._splash is None:
|
||||
return None
|
||||
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
|
||||
return Asset._from_guild_image(
|
||||
self._state, self.id, self._splash, path="splashes"
|
||||
)
|
||||
|
||||
|
||||
I = TypeVar('I', bound='Invite')
|
||||
I = TypeVar("I", bound="Invite")
|
||||
|
||||
|
||||
class Invite(Hashable):
|
||||
@ -308,26 +326,26 @@ class Invite(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'max_age',
|
||||
'code',
|
||||
'guild',
|
||||
'revoked',
|
||||
'created_at',
|
||||
'uses',
|
||||
'temporary',
|
||||
'max_uses',
|
||||
'inviter',
|
||||
'channel',
|
||||
'target_user',
|
||||
'target_type',
|
||||
'_state',
|
||||
'approximate_member_count',
|
||||
'approximate_presence_count',
|
||||
'target_application',
|
||||
'expires_at',
|
||||
"max_age",
|
||||
"code",
|
||||
"guild",
|
||||
"revoked",
|
||||
"created_at",
|
||||
"uses",
|
||||
"temporary",
|
||||
"max_uses",
|
||||
"inviter",
|
||||
"channel",
|
||||
"target_user",
|
||||
"target_type",
|
||||
"_state",
|
||||
"approximate_member_count",
|
||||
"approximate_presence_count",
|
||||
"target_application",
|
||||
"expires_at",
|
||||
)
|
||||
|
||||
BASE = 'https://discord.gg'
|
||||
BASE = "https://discord.gg"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -338,45 +356,67 @@ class Invite(Hashable):
|
||||
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.max_age: Optional[int] = data.get('max_age')
|
||||
self.code: str = data['code']
|
||||
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
|
||||
self.revoked: Optional[bool] = data.get('revoked')
|
||||
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
|
||||
self.temporary: Optional[bool] = data.get('temporary')
|
||||
self.uses: Optional[int] = data.get('uses')
|
||||
self.max_uses: Optional[int] = data.get('max_uses')
|
||||
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
|
||||
self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
|
||||
self.max_age: Optional[int] = data.get("max_age")
|
||||
self.code: str = data["code"]
|
||||
self.guild: Optional[InviteGuildType] = self._resolve_guild(
|
||||
data.get("guild"), guild
|
||||
)
|
||||
self.revoked: Optional[bool] = data.get("revoked")
|
||||
self.created_at: Optional[datetime.datetime] = parse_time(
|
||||
data.get("created_at")
|
||||
)
|
||||
self.temporary: Optional[bool] = data.get("temporary")
|
||||
self.uses: Optional[int] = data.get("uses")
|
||||
self.max_uses: Optional[int] = data.get("max_uses")
|
||||
self.approximate_presence_count: Optional[int] = data.get(
|
||||
"approximate_presence_count"
|
||||
)
|
||||
self.approximate_member_count: Optional[int] = data.get(
|
||||
"approximate_member_count"
|
||||
)
|
||||
|
||||
expires_at = data.get('expires_at', None)
|
||||
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
|
||||
expires_at = data.get("expires_at", None)
|
||||
self.expires_at: Optional[datetime.datetime] = (
|
||||
parse_time(expires_at) if expires_at else None
|
||||
)
|
||||
|
||||
inviter_data = data.get('inviter')
|
||||
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
|
||||
inviter_data = data.get("inviter")
|
||||
self.inviter: Optional[User] = (
|
||||
None if inviter_data is None else self._state.create_user(inviter_data)
|
||||
)
|
||||
|
||||
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
|
||||
self.channel: Optional[InviteChannelType] = self._resolve_channel(
|
||||
data.get("channel"), channel
|
||||
)
|
||||
|
||||
target_user_data = data.get('target_user')
|
||||
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
|
||||
target_user_data = data.get("target_user")
|
||||
self.target_user: Optional[User] = (
|
||||
None
|
||||
if target_user_data is None
|
||||
else self._state.create_user(target_user_data)
|
||||
)
|
||||
|
||||
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
|
||||
self.target_type: InviteTarget = try_enum(
|
||||
InviteTarget, data.get("target_type", 0)
|
||||
)
|
||||
|
||||
application = data.get('target_application')
|
||||
application = data.get("target_application")
|
||||
self.target_application: Optional[PartialAppInfo] = (
|
||||
PartialAppInfo(data=application, state=state) if application else None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I:
|
||||
def from_incomplete(
|
||||
cls: Type[I], *, state: ConnectionState, data: InvitePayload
|
||||
) -> I:
|
||||
guild: Optional[Union[Guild, PartialInviteGuild]]
|
||||
try:
|
||||
guild_data = data['guild']
|
||||
guild_data = data["guild"]
|
||||
except KeyError:
|
||||
# If we're here, then this is a group DM
|
||||
guild = None
|
||||
else:
|
||||
guild_id = int(guild_data['id'])
|
||||
guild_id = int(guild_data["id"])
|
||||
guild = state._get_guild(guild_id)
|
||||
if guild is None:
|
||||
# If it's not cached, then it has to be a partial guild
|
||||
@ -384,7 +424,9 @@ class Invite(Hashable):
|
||||
|
||||
# As far as I know, invites always need a channel
|
||||
# So this should never raise.
|
||||
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel'])
|
||||
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(
|
||||
data["channel"]
|
||||
)
|
||||
if guild is not None and not isinstance(guild, PartialInviteGuild):
|
||||
# Upgrade the partial data if applicable
|
||||
channel = guild.get_channel(channel.id) or channel
|
||||
@ -392,10 +434,12 @@ class Invite(Hashable):
|
||||
return cls(state=state, data=data, guild=guild, channel=channel)
|
||||
|
||||
@classmethod
|
||||
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
|
||||
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
|
||||
def from_gateway(
|
||||
cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload
|
||||
) -> I:
|
||||
guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
|
||||
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
|
||||
channel_id = int(data['channel_id'])
|
||||
channel_id = int(data["channel_id"])
|
||||
if guild is not None:
|
||||
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
|
||||
else:
|
||||
@ -415,7 +459,7 @@ class Invite(Hashable):
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
guild_id = int(data['id'])
|
||||
guild_id = int(data["id"])
|
||||
return PartialInviteGuild(self._state, data, guild_id)
|
||||
|
||||
def _resolve_channel(
|
||||
@ -439,9 +483,9 @@ class Invite(Hashable):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Invite code={self.code!r} guild={self.guild!r} '
|
||||
f'online={self.approximate_presence_count} '
|
||||
f'members={self.approximate_member_count}>'
|
||||
f"<Invite code={self.code!r} guild={self.guild!r} "
|
||||
f"online={self.approximate_presence_count} "
|
||||
f"members={self.approximate_member_count}>"
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -455,7 +499,7 @@ class Invite(Hashable):
|
||||
@property
|
||||
def url(self) -> str:
|
||||
""":class:`str`: A property that retrieves the invite URL."""
|
||||
return self.BASE + '/' + self.code
|
||||
return self.BASE + "/" + self.code
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None):
|
||||
"""|coro|
|
||||
|
@ -26,7 +26,17 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator
|
||||
from typing import (
|
||||
Awaitable,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Optional,
|
||||
Any,
|
||||
Callable,
|
||||
Union,
|
||||
List,
|
||||
AsyncIterator,
|
||||
)
|
||||
|
||||
from .errors import NoMoreItems
|
||||
from .utils import snowflake_time, time_snowflake, maybe_coroutine
|
||||
@ -34,11 +44,11 @@ from .object import Object
|
||||
from .audit_logs import AuditLogEntry
|
||||
|
||||
__all__ = (
|
||||
'ReactionIterator',
|
||||
'HistoryIterator',
|
||||
'AuditLogIterator',
|
||||
'GuildIterator',
|
||||
'MemberIterator',
|
||||
"ReactionIterator",
|
||||
"HistoryIterator",
|
||||
"AuditLogIterator",
|
||||
"GuildIterator",
|
||||
"MemberIterator",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -67,8 +77,8 @@ if TYPE_CHECKING:
|
||||
from .threads import Thread
|
||||
from .abc import Snowflake
|
||||
|
||||
T = TypeVar('T')
|
||||
OT = TypeVar('OT')
|
||||
T = TypeVar("T")
|
||||
OT = TypeVar("OT")
|
||||
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
|
||||
|
||||
OLDEST_OBJECT = Object(id=0)
|
||||
@ -83,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
|
||||
def predicate(elem: T):
|
||||
for attr, val in attrs.items():
|
||||
nested = attr.split('__')
|
||||
nested = attr.split("__")
|
||||
obj = elem
|
||||
for attribute in nested:
|
||||
obj = getattr(obj, attribute)
|
||||
@ -107,7 +117,7 @@ class _AsyncIterator(AsyncIterator[T]):
|
||||
|
||||
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
|
||||
if max_size <= 0:
|
||||
raise ValueError('async iterator chunk sizes must be greater than 0.')
|
||||
raise ValueError("async iterator chunk sizes must be greater than 0.")
|
||||
return _ChunkedAsyncIterator(self, max_size)
|
||||
|
||||
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
|
||||
@ -182,7 +192,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
|
||||
return item
|
||||
|
||||
|
||||
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
|
||||
def __init__(self, message, emoji, limit=100, after=None):
|
||||
self.message = message
|
||||
self.limit = limit
|
||||
@ -218,14 +228,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
|
||||
if data:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[-1]['id']))
|
||||
self.after = Object(id=int(data[-1]["id"]))
|
||||
|
||||
if self.guild is None or isinstance(self.guild, Object):
|
||||
for element in reversed(data):
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
else:
|
||||
for element in reversed(data):
|
||||
member_id = int(element['id'])
|
||||
member_id = int(element["id"])
|
||||
member = self.guild.get_member(member_id)
|
||||
if member is not None:
|
||||
await self.users.put(member)
|
||||
@ -233,7 +243,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
|
||||
await self.users.put(User(state=self.state, data=element))
|
||||
|
||||
|
||||
class HistoryIterator(_AsyncIterator['Message']):
|
||||
class HistoryIterator(_AsyncIterator["Message"]):
|
||||
"""Iterator for receiving a channel's message history.
|
||||
|
||||
The messages endpoint has two behaviours we care about here:
|
||||
@ -267,7 +277,15 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
``True`` if `after` is specified, otherwise ``False``.
|
||||
"""
|
||||
|
||||
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None):
|
||||
def __init__(
|
||||
self,
|
||||
messageable,
|
||||
limit,
|
||||
before=None,
|
||||
after=None,
|
||||
around=None,
|
||||
oldest_first=None,
|
||||
):
|
||||
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
@ -295,28 +313,30 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
|
||||
if self.around:
|
||||
if self.limit is None:
|
||||
raise ValueError('history does not support around with limit=None')
|
||||
raise ValueError("history does not support around with limit=None")
|
||||
if self.limit > 101:
|
||||
raise ValueError("history max limit 101 when specifying around parameter")
|
||||
raise ValueError(
|
||||
"history max limit 101 when specifying around parameter"
|
||||
)
|
||||
elif self.limit == 101:
|
||||
self.limit = 100 # Thanks discord
|
||||
|
||||
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
|
||||
if self.before and self.after:
|
||||
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
|
||||
self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
|
||||
elif self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
self._filter = lambda m: int(m["id"]) < self.before.id
|
||||
elif self.after:
|
||||
self._filter = lambda m: self.after.id < int(m['id'])
|
||||
self._filter = lambda m: self.after.id < int(m["id"])
|
||||
else:
|
||||
if self.reverse:
|
||||
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
|
||||
if self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
self._filter = lambda m: int(m["id"]) < self.before.id
|
||||
else:
|
||||
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
|
||||
if self.after and self.after != OLDEST_OBJECT:
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
self._filter = lambda m: int(m["id"]) > self.after.id
|
||||
|
||||
async def next(self) -> Message:
|
||||
if self.messages.empty():
|
||||
@ -337,7 +357,7 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
return r > 0
|
||||
|
||||
async def fill_messages(self):
|
||||
if not hasattr(self, 'channel'):
|
||||
if not hasattr(self, "channel"):
|
||||
# do the required set up
|
||||
channel = await self.messageable._get_channel()
|
||||
self.channel = channel
|
||||
@ -354,7 +374,9 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
|
||||
channel = self.channel
|
||||
for element in data:
|
||||
await self.messages.put(self.state.create_message(channel=channel, data=element))
|
||||
await self.messages.put(
|
||||
self.state.create_message(channel=channel, data=element)
|
||||
)
|
||||
|
||||
async def _retrieve_messages(self, retrieve) -> List[Message]:
|
||||
"""Retrieve messages and update next parameters."""
|
||||
@ -363,35 +385,50 @@ class HistoryIterator(_AsyncIterator['Message']):
|
||||
async def _retrieve_messages_before_strategy(self, retrieve):
|
||||
"""Retrieve messages using before parameter."""
|
||||
before = self.before.id if self.before else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before)
|
||||
data: List[MessagePayload] = await self.logs_from(
|
||||
self.channel.id, retrieve, before=before
|
||||
)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(data[-1]['id']))
|
||||
self.before = Object(id=int(data[-1]["id"]))
|
||||
return data
|
||||
|
||||
async def _retrieve_messages_after_strategy(self, retrieve):
|
||||
"""Retrieve messages using after parameter."""
|
||||
after = self.after.id if self.after else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after)
|
||||
data: List[MessagePayload] = await self.logs_from(
|
||||
self.channel.id, retrieve, after=after
|
||||
)
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
self.after = Object(id=int(data[0]["id"]))
|
||||
return data
|
||||
|
||||
async def _retrieve_messages_around_strategy(self, retrieve):
|
||||
"""Retrieve messages using around parameter."""
|
||||
if self.around:
|
||||
around = self.around.id if self.around else None
|
||||
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around)
|
||||
data: List[MessagePayload] = await self.logs_from(
|
||||
self.channel.id, retrieve, around=around
|
||||
)
|
||||
self.around = None
|
||||
return data
|
||||
return []
|
||||
|
||||
|
||||
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
|
||||
class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
|
||||
def __init__(
|
||||
self,
|
||||
guild,
|
||||
limit=None,
|
||||
before=None,
|
||||
after=None,
|
||||
oldest_first=None,
|
||||
user_id=None,
|
||||
action_type=None,
|
||||
):
|
||||
if isinstance(before, datetime.datetime):
|
||||
before = Object(id=time_snowflake(before, high=False))
|
||||
if isinstance(after, datetime.datetime):
|
||||
@ -420,36 +457,44 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
if self.reverse:
|
||||
self._strategy = self._after_strategy
|
||||
if self.before:
|
||||
self._filter = lambda m: int(m['id']) < self.before.id
|
||||
self._filter = lambda m: int(m["id"]) < self.before.id
|
||||
else:
|
||||
self._strategy = self._before_strategy
|
||||
if self.after and self.after != OLDEST_OBJECT:
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
self._filter = lambda m: int(m["id"]) > self.after.id
|
||||
|
||||
async def _before_strategy(self, retrieve):
|
||||
before = self.before.id if self.before else None
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
|
||||
self.guild.id,
|
||||
limit=retrieve,
|
||||
user_id=self.user_id,
|
||||
action_type=self.action_type,
|
||||
before=before,
|
||||
)
|
||||
|
||||
entries = data.get('audit_log_entries', [])
|
||||
entries = data.get("audit_log_entries", [])
|
||||
if len(data) and entries:
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(entries[-1]['id']))
|
||||
return data.get('users', []), entries
|
||||
self.before = Object(id=int(entries[-1]["id"]))
|
||||
return data.get("users", []), entries
|
||||
|
||||
async def _after_strategy(self, retrieve):
|
||||
after = self.after.id if self.after else None
|
||||
data: AuditLogPayload = await self.request(
|
||||
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
|
||||
self.guild.id,
|
||||
limit=retrieve,
|
||||
user_id=self.user_id,
|
||||
action_type=self.action_type,
|
||||
after=after,
|
||||
)
|
||||
entries = data.get('audit_log_entries', [])
|
||||
entries = data.get("audit_log_entries", [])
|
||||
if len(data) and entries:
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(entries[0]['id']))
|
||||
return data.get('users', []), entries
|
||||
self.after = Object(id=int(entries[0]["id"]))
|
||||
return data.get("users", []), entries
|
||||
|
||||
async def next(self) -> AuditLogEntry:
|
||||
if self.entries.empty():
|
||||
@ -488,13 +533,15 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
|
||||
|
||||
for element in data:
|
||||
# TODO: remove this if statement later
|
||||
if element['action_type'] is None:
|
||||
if element["action_type"] is None:
|
||||
continue
|
||||
|
||||
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
|
||||
await self.entries.put(
|
||||
AuditLogEntry(data=element, users=self._users, guild=self.guild)
|
||||
)
|
||||
|
||||
|
||||
class GuildIterator(_AsyncIterator['Guild']):
|
||||
class GuildIterator(_AsyncIterator["Guild"]):
|
||||
"""Iterator for receiving the client's guilds.
|
||||
|
||||
The guilds endpoint has the same two behaviours as described
|
||||
@ -543,7 +590,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
|
||||
if self.before and self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
|
||||
self._filter = lambda m: int(m['id']) > self.after.id
|
||||
self._filter = lambda m: int(m["id"]) > self.after.id
|
||||
elif self.after:
|
||||
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
|
||||
else:
|
||||
@ -595,7 +642,7 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.before = Object(id=int(data[-1]['id']))
|
||||
self.before = Object(id=int(data[-1]["id"]))
|
||||
return data
|
||||
|
||||
async def _retrieve_guilds_after_strategy(self, retrieve):
|
||||
@ -605,11 +652,11 @@ class GuildIterator(_AsyncIterator['Guild']):
|
||||
if len(data):
|
||||
if self.limit is not None:
|
||||
self.limit -= retrieve
|
||||
self.after = Object(id=int(data[0]['id']))
|
||||
self.after = Object(id=int(data[0]["id"]))
|
||||
return data
|
||||
|
||||
|
||||
class MemberIterator(_AsyncIterator['Member']):
|
||||
class MemberIterator(_AsyncIterator["Member"]):
|
||||
def __init__(self, guild, limit=1000, after=None):
|
||||
|
||||
if isinstance(after, datetime.datetime):
|
||||
@ -652,7 +699,7 @@ class MemberIterator(_AsyncIterator['Member']):
|
||||
if len(data) < 1000:
|
||||
self.limit = 0 # terminate loop
|
||||
|
||||
self.after = Object(id=int(data[-1]['user']['id']))
|
||||
self.after = Object(id=int(data[-1]["user"]["id"]))
|
||||
|
||||
for element in reversed(data):
|
||||
await self.members.put(self.create_member(element))
|
||||
@ -663,7 +710,7 @@ class MemberIterator(_AsyncIterator['Member']):
|
||||
return Member(data=data, guild=self.guild, state=self.state)
|
||||
|
||||
|
||||
class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
class ArchivedThreadIterator(_AsyncIterator["Thread"]):
|
||||
def __init__(
|
||||
self,
|
||||
channel_id: int,
|
||||
@ -681,7 +728,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
self.http = guild._state.http
|
||||
|
||||
if joined and not private:
|
||||
raise ValueError('Cannot iterate over joined public archived threads')
|
||||
raise ValueError("Cannot iterate over joined public archived threads")
|
||||
|
||||
self.before: Optional[str]
|
||||
if before is None:
|
||||
@ -721,11 +768,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
|
||||
@staticmethod
|
||||
def get_archive_timestamp(data: ThreadPayload) -> str:
|
||||
return data['thread_metadata']['archive_timestamp']
|
||||
return data["thread_metadata"]["archive_timestamp"]
|
||||
|
||||
@staticmethod
|
||||
def get_thread_id(data: ThreadPayload) -> str:
|
||||
return data['id'] # type: ignore
|
||||
return data["id"] # type: ignore
|
||||
|
||||
async def fill_queue(self) -> None:
|
||||
if not self.has_more:
|
||||
@ -735,11 +782,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
|
||||
|
||||
# This stuff is obviously WIP because 'members' is always empty
|
||||
threads: List[ThreadPayload] = data.get('threads', [])
|
||||
threads: List[ThreadPayload] = data.get("threads", [])
|
||||
for d in reversed(threads):
|
||||
self.queue.put_nowait(self.create_thread(d))
|
||||
|
||||
self.has_more = data.get('has_more', False)
|
||||
self.has_more = data.get("has_more", False)
|
||||
if self.limit is not None:
|
||||
self.limit -= len(threads)
|
||||
if self.limit <= 0:
|
||||
@ -750,4 +797,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
|
||||
|
||||
def create_thread(self, data: ThreadPayload) -> Thread:
|
||||
from .threads import Thread
|
||||
|
||||
return Thread(guild=self.guild, state=self.guild._state, data=data)
|
||||
|
@ -29,7 +29,19 @@ import inspect
|
||||
import itertools
|
||||
import sys
|
||||
from operator import attrgetter
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import discord.abc
|
||||
|
||||
@ -44,8 +56,8 @@ from .colour import Colour
|
||||
from .object import Object
|
||||
|
||||
__all__ = (
|
||||
'VoiceState',
|
||||
'Member',
|
||||
"VoiceState",
|
||||
"Member",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -113,52 +125,58 @@ class VoiceState:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'session_id',
|
||||
'deaf',
|
||||
'mute',
|
||||
'self_mute',
|
||||
'self_stream',
|
||||
'self_video',
|
||||
'self_deaf',
|
||||
'afk',
|
||||
'channel',
|
||||
'requested_to_speak_at',
|
||||
'suppress',
|
||||
"session_id",
|
||||
"deaf",
|
||||
"mute",
|
||||
"self_mute",
|
||||
"self_stream",
|
||||
"self_video",
|
||||
"self_deaf",
|
||||
"afk",
|
||||
"channel",
|
||||
"requested_to_speak_at",
|
||||
"suppress",
|
||||
)
|
||||
|
||||
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
|
||||
self.session_id: str = data.get('session_id')
|
||||
def __init__(
|
||||
self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None
|
||||
):
|
||||
self.session_id: str = data.get("session_id")
|
||||
self._update(data, channel)
|
||||
|
||||
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
|
||||
self.self_mute: bool = data.get('self_mute', False)
|
||||
self.self_deaf: bool = data.get('self_deaf', False)
|
||||
self.self_stream: bool = data.get('self_stream', False)
|
||||
self.self_video: bool = data.get('self_video', False)
|
||||
self.afk: bool = data.get('suppress', False)
|
||||
self.mute: bool = data.get('mute', False)
|
||||
self.deaf: bool = data.get('deaf', False)
|
||||
self.suppress: bool = data.get('suppress', False)
|
||||
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp'))
|
||||
self.self_mute: bool = data.get("self_mute", False)
|
||||
self.self_deaf: bool = data.get("self_deaf", False)
|
||||
self.self_stream: bool = data.get("self_stream", False)
|
||||
self.self_video: bool = data.get("self_video", False)
|
||||
self.afk: bool = data.get("suppress", False)
|
||||
self.mute: bool = data.get("mute", False)
|
||||
self.deaf: bool = data.get("deaf", False)
|
||||
self.suppress: bool = data.get("suppress", False)
|
||||
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(
|
||||
data.get("request_to_speak_timestamp")
|
||||
)
|
||||
self.channel: Optional[VocalGuildChannel] = channel
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = [
|
||||
('self_mute', self.self_mute),
|
||||
('self_deaf', self.self_deaf),
|
||||
('self_stream', self.self_stream),
|
||||
('suppress', self.suppress),
|
||||
('requested_to_speak_at', self.requested_to_speak_at),
|
||||
('channel', self.channel),
|
||||
("self_mute", self.self_mute),
|
||||
("self_deaf", self.self_deaf),
|
||||
("self_stream", self.self_stream),
|
||||
("suppress", self.suppress),
|
||||
("requested_to_speak_at", self.requested_to_speak_at),
|
||||
("channel", self.channel),
|
||||
]
|
||||
inner = ' '.join('%s=%r' % t for t in attrs)
|
||||
return f'<{self.__class__.__name__} {inner}>'
|
||||
inner = " ".join("%s=%r" % t for t in attrs)
|
||||
return f"<{self.__class__.__name__} {inner}>"
|
||||
|
||||
|
||||
def flatten_user(cls):
|
||||
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
|
||||
for attr, value in itertools.chain(
|
||||
BaseUser.__dict__.items(), User.__dict__.items()
|
||||
):
|
||||
# ignore private/special methods
|
||||
if attr.startswith('_'):
|
||||
if attr.startswith("_"):
|
||||
continue
|
||||
|
||||
# don't override what we already have
|
||||
@ -167,9 +185,11 @@ def flatten_user(cls):
|
||||
|
||||
# if it's a slotted attribute or a property, redirect it
|
||||
# slotted members are implemented as member_descriptors in Type.__dict__
|
||||
if not hasattr(value, '__annotations__'):
|
||||
getter = attrgetter('_user.' + attr)
|
||||
setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`'))
|
||||
if not hasattr(value, "__annotations__"):
|
||||
getter = attrgetter("_user." + attr)
|
||||
setattr(
|
||||
cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`")
|
||||
)
|
||||
else:
|
||||
# Technically, this can also use attrgetter
|
||||
# However I'm not sure how I feel about "functions" returning properties
|
||||
@ -197,7 +217,7 @@ def flatten_user(cls):
|
||||
return cls
|
||||
|
||||
|
||||
M = TypeVar('M', bound='Member')
|
||||
M = TypeVar("M", bound="Member")
|
||||
|
||||
|
||||
@flatten_user
|
||||
@ -258,17 +278,17 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_roles',
|
||||
'joined_at',
|
||||
'premium_since',
|
||||
'activities',
|
||||
'guild',
|
||||
'pending',
|
||||
'nick',
|
||||
'_client_status',
|
||||
'_user',
|
||||
'_state',
|
||||
'_avatar',
|
||||
"_roles",
|
||||
"joined_at",
|
||||
"premium_since",
|
||||
"activities",
|
||||
"guild",
|
||||
"pending",
|
||||
"nick",
|
||||
"_client_status",
|
||||
"_user",
|
||||
"_state",
|
||||
"_avatar",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -288,18 +308,24 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
accent_color: Optional[Colour]
|
||||
accent_colour: Optional[Colour]
|
||||
|
||||
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
|
||||
def __init__(
|
||||
self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self._user: User = state.store_user(data['user'])
|
||||
self._user: User = state.store_user(data["user"])
|
||||
self.guild: Guild = guild
|
||||
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
|
||||
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
|
||||
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
|
||||
self._client_status: Dict[Optional[str], str] = {None: 'offline'}
|
||||
self.joined_at: Optional[datetime.datetime] = utils.parse_time(
|
||||
data.get("joined_at")
|
||||
)
|
||||
self.premium_since: Optional[datetime.datetime] = utils.parse_time(
|
||||
data.get("premium_since")
|
||||
)
|
||||
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"]))
|
||||
self._client_status: Dict[Optional[str], str] = {None: "offline"}
|
||||
self.activities: Tuple[ActivityTypes, ...] = tuple()
|
||||
self.nick: Optional[str] = data.get('nick', None)
|
||||
self.pending: bool = data.get('pending', False)
|
||||
self._avatar: Optional[str] = data.get('avatar')
|
||||
self.nick: Optional[str] = data.get("nick", None)
|
||||
self.pending: bool = data.get("pending", False)
|
||||
self._avatar: Optional[str] = data.get("avatar")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._user)
|
||||
@ -309,8 +335,8 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
|
||||
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
|
||||
f"<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}"
|
||||
f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>"
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
@ -325,25 +351,31 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
@classmethod
|
||||
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
|
||||
author = message.author
|
||||
data['user'] = author._to_minimal_user_json() # type: ignore
|
||||
data["user"] = author._to_minimal_user_json() # type: ignore
|
||||
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
|
||||
|
||||
def _update_from_message(self, data: MemberPayload) -> None:
|
||||
self.joined_at = utils.parse_time(data.get('joined_at'))
|
||||
self.premium_since = utils.parse_time(data.get('premium_since'))
|
||||
self._roles = utils.SnowflakeList(map(int, data['roles']))
|
||||
self.nick = data.get('nick', None)
|
||||
self.pending = data.get('pending', False)
|
||||
self.joined_at = utils.parse_time(data.get("joined_at"))
|
||||
self.premium_since = utils.parse_time(data.get("premium_since"))
|
||||
self._roles = utils.SnowflakeList(map(int, data["roles"]))
|
||||
self.nick = data.get("nick", None)
|
||||
self.pending = data.get("pending", False)
|
||||
|
||||
@classmethod
|
||||
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]:
|
||||
def _try_upgrade(
|
||||
cls: Type[M],
|
||||
*,
|
||||
data: UserWithMemberPayload,
|
||||
guild: Guild,
|
||||
state: ConnectionState,
|
||||
) -> Union[User, M]:
|
||||
# A User object with a 'member' key
|
||||
try:
|
||||
member_data = data.pop('member')
|
||||
member_data = data.pop("member")
|
||||
except KeyError:
|
||||
return state.create_user(data)
|
||||
else:
|
||||
member_data['user'] = data # type: ignore
|
||||
member_data["user"] = data # type: ignore
|
||||
return cls(data=member_data, guild=guild, state=state) # type: ignore
|
||||
|
||||
@classmethod
|
||||
@ -374,25 +406,27 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
# the nickname change is optional,
|
||||
# if it isn't in the payload then it didn't change
|
||||
try:
|
||||
self.nick = data['nick']
|
||||
self.nick = data["nick"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.pending = data['pending']
|
||||
self.pending = data["pending"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
self.premium_since = utils.parse_time(data.get('premium_since'))
|
||||
self._roles = utils.SnowflakeList(map(int, data['roles']))
|
||||
self._avatar = data.get('avatar')
|
||||
self.premium_since = utils.parse_time(data.get("premium_since"))
|
||||
self._roles = utils.SnowflakeList(map(int, data["roles"]))
|
||||
self._avatar = data.get("avatar")
|
||||
|
||||
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
|
||||
self.activities = tuple(map(create_activity, data['activities']))
|
||||
def _presence_update(
|
||||
self, data: PartialPresenceUpdate, user: UserPayload
|
||||
) -> Optional[Tuple[User, User]]:
|
||||
self.activities = tuple(map(create_activity, data["activities"]))
|
||||
self._client_status = {
|
||||
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore
|
||||
sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore
|
||||
}
|
||||
self._client_status[None] = sys.intern(data['status'])
|
||||
self._client_status[None] = sys.intern(data["status"])
|
||||
|
||||
if len(user) > 1:
|
||||
return self._update_inner_user(user)
|
||||
@ -402,7 +436,12 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
u = self._user
|
||||
original = (u.name, u._avatar, u.discriminator, u._public_flags)
|
||||
# These keys seem to always be available
|
||||
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
|
||||
modified = (
|
||||
user["username"],
|
||||
user["avatar"],
|
||||
user["discriminator"],
|
||||
user.get("public_flags", 0),
|
||||
)
|
||||
if original != modified:
|
||||
to_return = User._copy(self._user)
|
||||
u.name, u._avatar, u.discriminator, u._public_flags = modified
|
||||
@ -430,21 +469,21 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
@property
|
||||
def mobile_status(self) -> Status:
|
||||
""":class:`Status`: The member's status on a mobile device, if applicable."""
|
||||
return try_enum(Status, self._client_status.get('mobile', 'offline'))
|
||||
return try_enum(Status, self._client_status.get("mobile", "offline"))
|
||||
|
||||
@property
|
||||
def desktop_status(self) -> Status:
|
||||
""":class:`Status`: The member's status on the desktop client, if applicable."""
|
||||
return try_enum(Status, self._client_status.get('desktop', 'offline'))
|
||||
return try_enum(Status, self._client_status.get("desktop", "offline"))
|
||||
|
||||
@property
|
||||
def web_status(self) -> Status:
|
||||
""":class:`Status`: The member's status on the web client, if applicable."""
|
||||
return try_enum(Status, self._client_status.get('web', 'offline'))
|
||||
return try_enum(Status, self._client_status.get("web", "offline"))
|
||||
|
||||
def is_on_mobile(self) -> bool:
|
||||
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
|
||||
return 'mobile' in self._client_status
|
||||
return "mobile" in self._client_status
|
||||
|
||||
@property
|
||||
def colour(self) -> Colour:
|
||||
@ -497,8 +536,8 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: Returns a string that allows you to mention the member."""
|
||||
if self.nick:
|
||||
return f'<@!{self._user.id}>'
|
||||
return f'<@{self._user.id}>'
|
||||
return f"<@!{self._user.id}>"
|
||||
return f"<@{self._user.id}>"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
@ -531,7 +570,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
if self._avatar is None:
|
||||
return None
|
||||
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
|
||||
return Asset._from_guild_avatar(
|
||||
self._state, self.guild.id, self.id, self._avatar
|
||||
)
|
||||
|
||||
@property
|
||||
def activity(self) -> Optional[ActivityTypes]:
|
||||
@ -625,7 +666,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
|
||||
Bans this member. Equivalent to :meth:`Guild.ban`.
|
||||
"""
|
||||
await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days)
|
||||
await self.guild.ban(
|
||||
self, reason=reason, delete_message_days=delete_message_days
|
||||
)
|
||||
|
||||
async def unban(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
@ -720,39 +763,41 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
if nick is not MISSING:
|
||||
nick = nick or ''
|
||||
nick = nick or ""
|
||||
if me:
|
||||
await http.change_my_nickname(guild_id, nick, reason=reason)
|
||||
else:
|
||||
payload['nick'] = nick
|
||||
payload["nick"] = nick
|
||||
|
||||
if deafen is not MISSING:
|
||||
payload['deaf'] = deafen
|
||||
payload["deaf"] = deafen
|
||||
|
||||
if mute is not MISSING:
|
||||
payload['mute'] = mute
|
||||
payload["mute"] = mute
|
||||
|
||||
if suppress is not MISSING:
|
||||
voice_state_payload = {
|
||||
'channel_id': self.voice.channel.id,
|
||||
'suppress': suppress,
|
||||
"channel_id": self.voice.channel.id,
|
||||
"suppress": suppress,
|
||||
}
|
||||
|
||||
if suppress or self.bot:
|
||||
voice_state_payload['request_to_speak_timestamp'] = None
|
||||
voice_state_payload["request_to_speak_timestamp"] = None
|
||||
|
||||
if me:
|
||||
await http.edit_my_voice_state(guild_id, voice_state_payload)
|
||||
else:
|
||||
if not suppress:
|
||||
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat()
|
||||
voice_state_payload[
|
||||
"request_to_speak_timestamp"
|
||||
] = datetime.datetime.utcnow().isoformat()
|
||||
await http.edit_voice_state(guild_id, self.id, voice_state_payload)
|
||||
|
||||
if voice_channel is not MISSING:
|
||||
payload['channel_id'] = voice_channel and voice_channel.id
|
||||
payload["channel_id"] = voice_channel and voice_channel.id
|
||||
|
||||
if roles is not MISSING:
|
||||
payload['roles'] = tuple(r.id for r in roles)
|
||||
payload["roles"] = tuple(r.id for r in roles)
|
||||
|
||||
if payload:
|
||||
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
|
||||
@ -780,17 +825,19 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
The operation failed.
|
||||
"""
|
||||
payload = {
|
||||
'channel_id': self.voice.channel.id,
|
||||
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),
|
||||
"channel_id": self.voice.channel.id,
|
||||
"request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
if self._state.self_id != self.id:
|
||||
payload['suppress'] = False
|
||||
payload["suppress"] = False
|
||||
await self._state.http.edit_voice_state(self.guild.id, self.id, payload)
|
||||
else:
|
||||
await self._state.http.edit_my_voice_state(self.guild.id, payload)
|
||||
|
||||
async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None:
|
||||
async def move_to(
|
||||
self, channel: VocalGuildChannel, *, reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Moves a member to a new voice channel (they must be connected first).
|
||||
@ -813,7 +860,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
await self.edit(voice_channel=channel, reason=reason)
|
||||
|
||||
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
|
||||
async def add_roles(
|
||||
self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True
|
||||
) -> None:
|
||||
r"""|coro|
|
||||
|
||||
Gives the member a number of :class:`Role`\s.
|
||||
@ -843,7 +892,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
"""
|
||||
|
||||
if not atomic:
|
||||
new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s)
|
||||
new_roles = utils._unique(
|
||||
Object(id=r.id) for s in (self.roles[1:], roles) for r in s
|
||||
)
|
||||
await self.edit(roles=new_roles, reason=reason)
|
||||
else:
|
||||
req = self._state.http.add_role
|
||||
@ -852,7 +903,9 @@ class Member(discord.abc.Messageable, _UserTag):
|
||||
for role in roles:
|
||||
await req(guild_id, user_id, role.id, reason=reason)
|
||||
|
||||
async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
|
||||
async def remove_roles(
|
||||
self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True
|
||||
) -> None:
|
||||
r"""|coro|
|
||||
|
||||
Removes :class:`Role`\s from this member.
|
||||
|
@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
||||
from __future__ import annotations
|
||||
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
|
||||
|
||||
__all__ = (
|
||||
'AllowedMentions',
|
||||
)
|
||||
__all__ = ("AllowedMentions",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.message import AllowedMentions as AllowedMentionsPayload
|
||||
@ -36,7 +34,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class _FakeBool:
|
||||
def __repr__(self):
|
||||
return 'True'
|
||||
return "True"
|
||||
|
||||
def __eq__(self, other):
|
||||
return other is True
|
||||
@ -47,7 +45,7 @@ class _FakeBool:
|
||||
|
||||
default: Any = _FakeBool()
|
||||
|
||||
A = TypeVar('A', bound='AllowedMentions')
|
||||
A = TypeVar("A", bound="AllowedMentions")
|
||||
|
||||
|
||||
class AllowedMentions:
|
||||
@ -80,7 +78,7 @@ class AllowedMentions:
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
|
||||
__slots__ = ('everyone', 'users', 'roles', 'replied_user')
|
||||
__slots__ = ("everyone", "users", "roles", "replied_user")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -116,22 +114,22 @@ class AllowedMentions:
|
||||
data = {}
|
||||
|
||||
if self.everyone:
|
||||
parse.append('everyone')
|
||||
parse.append("everyone")
|
||||
|
||||
if self.users == True:
|
||||
parse.append('users')
|
||||
parse.append("users")
|
||||
elif self.users != False:
|
||||
data['users'] = [x.id for x in self.users]
|
||||
data["users"] = [x.id for x in self.users]
|
||||
|
||||
if self.roles == True:
|
||||
parse.append('roles')
|
||||
parse.append("roles")
|
||||
elif self.roles != False:
|
||||
data['roles'] = [x.id for x in self.roles]
|
||||
data["roles"] = [x.id for x in self.roles]
|
||||
|
||||
if self.replied_user:
|
||||
data['replied_user'] = True
|
||||
data["replied_user"] = True
|
||||
|
||||
data['parse'] = parse
|
||||
data["parse"] = parse
|
||||
return data # type: ignore
|
||||
|
||||
def merge(self, other: AllowedMentions) -> AllowedMentions:
|
||||
@ -141,11 +139,15 @@ class AllowedMentions:
|
||||
everyone = self.everyone if other.everyone is default else other.everyone
|
||||
users = self.users if other.users is default else other.users
|
||||
roles = self.roles if other.roles is default else other.roles
|
||||
replied_user = self.replied_user if other.replied_user is default else other.replied_user
|
||||
return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user)
|
||||
replied_user = (
|
||||
self.replied_user if other.replied_user is default else other.replied_user
|
||||
)
|
||||
return AllowedMentions(
|
||||
everyone=everyone, roles=roles, users=users, replied_user=replied_user
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(everyone={self.everyone}, '
|
||||
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})'
|
||||
f"{self.__class__.__name__}(everyone={self.everyone}, "
|
||||
f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})"
|
||||
)
|
||||
|
@ -29,7 +29,21 @@ import datetime
|
||||
import re
|
||||
import io
|
||||
from os import PathLike
|
||||
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
|
||||
from typing import (
|
||||
Dict,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
List,
|
||||
Optional,
|
||||
Any,
|
||||
Callable,
|
||||
Tuple,
|
||||
ClassVar,
|
||||
Optional,
|
||||
overload,
|
||||
TypeVar,
|
||||
Type,
|
||||
)
|
||||
|
||||
from . import utils
|
||||
from .reaction import Reaction
|
||||
@ -76,15 +90,15 @@ if TYPE_CHECKING:
|
||||
from .role import Role
|
||||
from .ui.view import View
|
||||
|
||||
MR = TypeVar('MR', bound='MessageReference')
|
||||
MR = TypeVar("MR", bound="MessageReference")
|
||||
EmojiInputType = Union[Emoji, PartialEmoji, str]
|
||||
|
||||
__all__ = (
|
||||
'Attachment',
|
||||
'Message',
|
||||
'PartialMessage',
|
||||
'MessageReference',
|
||||
'DeletedReferencedMessage',
|
||||
"Attachment",
|
||||
"Message",
|
||||
"PartialMessage",
|
||||
"MessageReference",
|
||||
"DeletedReferencedMessage",
|
||||
)
|
||||
|
||||
|
||||
@ -93,15 +107,17 @@ def convert_emoji_reaction(emoji):
|
||||
emoji = emoji.emoji
|
||||
|
||||
if isinstance(emoji, Emoji):
|
||||
return f'{emoji.name}:{emoji.id}'
|
||||
return f"{emoji.name}:{emoji.id}"
|
||||
if isinstance(emoji, PartialEmoji):
|
||||
return emoji._as_reaction()
|
||||
if isinstance(emoji, str):
|
||||
# Reactions can be in :name:id format, but not <:name:id>.
|
||||
# No existing emojis have <> in them, so this should be okay.
|
||||
return emoji.strip('<>')
|
||||
return emoji.strip("<>")
|
||||
|
||||
raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.')
|
||||
raise InvalidArgument(
|
||||
f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}."
|
||||
)
|
||||
|
||||
|
||||
class Attachment(Hashable):
|
||||
@ -157,28 +173,38 @@ class Attachment(Hashable):
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type')
|
||||
__slots__ = (
|
||||
"id",
|
||||
"size",
|
||||
"height",
|
||||
"width",
|
||||
"filename",
|
||||
"url",
|
||||
"proxy_url",
|
||||
"_http",
|
||||
"content_type",
|
||||
)
|
||||
|
||||
def __init__(self, *, data: AttachmentPayload, state: ConnectionState):
|
||||
self.id: int = int(data['id'])
|
||||
self.size: int = data['size']
|
||||
self.height: Optional[int] = data.get('height')
|
||||
self.width: Optional[int] = data.get('width')
|
||||
self.filename: str = data['filename']
|
||||
self.url: str = data.get('url')
|
||||
self.proxy_url: str = data.get('proxy_url')
|
||||
self.id: int = int(data["id"])
|
||||
self.size: int = data["size"]
|
||||
self.height: Optional[int] = data.get("height")
|
||||
self.width: Optional[int] = data.get("width")
|
||||
self.filename: str = data["filename"]
|
||||
self.url: str = data.get("url")
|
||||
self.proxy_url: str = data.get("proxy_url")
|
||||
self._http = state.http
|
||||
self.content_type: Optional[str] = data.get('content_type')
|
||||
self.content_type: Optional[str] = data.get("content_type")
|
||||
|
||||
def is_spoiler(self) -> bool:
|
||||
""":class:`bool`: Whether this attachment contains a spoiler."""
|
||||
return self.filename.startswith('SPOILER_')
|
||||
return self.filename.startswith("SPOILER_")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>'
|
||||
return f"<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.url or ''
|
||||
return self.url or ""
|
||||
|
||||
async def save(
|
||||
self,
|
||||
@ -227,7 +253,7 @@ class Attachment(Hashable):
|
||||
fp.seek(0)
|
||||
return written
|
||||
else:
|
||||
with open(fp, 'wb') as f:
|
||||
with open(fp, "wb") as f:
|
||||
return f.write(data)
|
||||
|
||||
async def read(self, *, use_cached: bool = False) -> bytes:
|
||||
@ -309,19 +335,19 @@ class Attachment(Hashable):
|
||||
|
||||
def to_dict(self) -> AttachmentPayload:
|
||||
result: AttachmentPayload = {
|
||||
'filename': self.filename,
|
||||
'id': self.id,
|
||||
'proxy_url': self.proxy_url,
|
||||
'size': self.size,
|
||||
'url': self.url,
|
||||
'spoiler': self.is_spoiler(),
|
||||
"filename": self.filename,
|
||||
"id": self.id,
|
||||
"proxy_url": self.proxy_url,
|
||||
"size": self.size,
|
||||
"url": self.url,
|
||||
"spoiler": self.is_spoiler(),
|
||||
}
|
||||
if self.height:
|
||||
result['height'] = self.height
|
||||
result["height"] = self.height
|
||||
if self.width:
|
||||
result['width'] = self.width
|
||||
result["width"] = self.width
|
||||
if self.content_type:
|
||||
result['content_type'] = self.content_type
|
||||
result["content_type"] = self.content_type
|
||||
return result
|
||||
|
||||
|
||||
@ -335,7 +361,7 @@ class DeletedReferencedMessage:
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
|
||||
__slots__ = ('_parent',)
|
||||
__slots__ = ("_parent",)
|
||||
|
||||
def __init__(self, parent: MessageReference):
|
||||
self._parent: MessageReference = parent
|
||||
@ -347,7 +373,7 @@ class DeletedReferencedMessage:
|
||||
def id(self) -> int:
|
||||
""":class:`int`: The message ID of the deleted referenced message."""
|
||||
# the parent's message id won't be None here
|
||||
return self._parent.message_id # type: ignore
|
||||
return self._parent.message_id # type: ignore
|
||||
|
||||
@property
|
||||
def channel_id(self) -> int:
|
||||
@ -394,9 +420,23 @@ class MessageReference:
|
||||
.. versionadded:: 1.6
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state')
|
||||
__slots__ = (
|
||||
"message_id",
|
||||
"channel_id",
|
||||
"guild_id",
|
||||
"fail_if_not_exists",
|
||||
"resolved",
|
||||
"_state",
|
||||
)
|
||||
|
||||
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message_id: int,
|
||||
channel_id: int,
|
||||
guild_id: Optional[int] = None,
|
||||
fail_if_not_exists: bool = True,
|
||||
):
|
||||
self._state: Optional[ConnectionState] = None
|
||||
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
|
||||
self.message_id: Optional[int] = message_id
|
||||
@ -405,18 +445,22 @@ class MessageReference:
|
||||
self.fail_if_not_exists: bool = fail_if_not_exists
|
||||
|
||||
@classmethod
|
||||
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
|
||||
def with_state(
|
||||
cls: Type[MR], state: ConnectionState, data: MessageReferencePayload
|
||||
) -> MR:
|
||||
self = cls.__new__(cls)
|
||||
self.message_id = utils._get_as_snowflake(data, 'message_id')
|
||||
self.channel_id = int(data.pop('channel_id'))
|
||||
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.fail_if_not_exists = data.get('fail_if_not_exists', True)
|
||||
self.message_id = utils._get_as_snowflake(data, "message_id")
|
||||
self.channel_id = int(data.pop("channel_id"))
|
||||
self.guild_id = utils._get_as_snowflake(data, "guild_id")
|
||||
self.fail_if_not_exists = data.get("fail_if_not_exists", True)
|
||||
self._state = state
|
||||
self.resolved = None
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR:
|
||||
def from_message(
|
||||
cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True
|
||||
) -> MR:
|
||||
"""Creates a :class:`MessageReference` from an existing :class:`~discord.Message`.
|
||||
|
||||
.. versionadded:: 1.6
|
||||
@ -439,7 +483,7 @@ class MessageReference:
|
||||
self = cls(
|
||||
message_id=message.id,
|
||||
channel_id=message.channel.id,
|
||||
guild_id=getattr(message.guild, 'id', None),
|
||||
guild_id=getattr(message.guild, "id", None),
|
||||
fail_if_not_exists=fail_if_not_exists,
|
||||
)
|
||||
self._state = message._state
|
||||
@ -456,36 +500,38 @@ class MessageReference:
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
guild_id = self.guild_id if self.guild_id is not None else '@me'
|
||||
return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}'
|
||||
guild_id = self.guild_id if self.guild_id is not None else "@me"
|
||||
return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
|
||||
return f"<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>"
|
||||
|
||||
def to_dict(self) -> MessageReferencePayload:
|
||||
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
|
||||
result['channel_id'] = self.channel_id
|
||||
result: MessageReferencePayload = (
|
||||
{"message_id": self.message_id} if self.message_id is not None else {}
|
||||
)
|
||||
result["channel_id"] = self.channel_id
|
||||
if self.guild_id is not None:
|
||||
result['guild_id'] = self.guild_id
|
||||
result["guild_id"] = self.guild_id
|
||||
if self.fail_if_not_exists is not None:
|
||||
result['fail_if_not_exists'] = self.fail_if_not_exists
|
||||
result["fail_if_not_exists"] = self.fail_if_not_exists
|
||||
return result
|
||||
|
||||
to_message_reference_dict = to_dict
|
||||
|
||||
|
||||
def flatten_handlers(cls):
|
||||
prefix = len('_handle_')
|
||||
prefix = len("_handle_")
|
||||
handlers = [
|
||||
(key[prefix:], value)
|
||||
for key, value in cls.__dict__.items()
|
||||
if key.startswith('_handle_') and key != '_handle_member'
|
||||
if key.startswith("_handle_") and key != "_handle_member"
|
||||
]
|
||||
|
||||
# store _handle_member last
|
||||
handlers.append(('member', cls._handle_member))
|
||||
handlers.append(("member", cls._handle_member))
|
||||
cls._HANDLERS = handlers
|
||||
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')]
|
||||
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")]
|
||||
return cls
|
||||
|
||||
|
||||
@ -615,36 +661,36 @@ class Message(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_state',
|
||||
'_edited_timestamp',
|
||||
'_cs_channel_mentions',
|
||||
'_cs_raw_mentions',
|
||||
'_cs_clean_content',
|
||||
'_cs_raw_channel_mentions',
|
||||
'_cs_raw_role_mentions',
|
||||
'_cs_system_content',
|
||||
'tts',
|
||||
'content',
|
||||
'channel',
|
||||
'webhook_id',
|
||||
'mention_everyone',
|
||||
'embeds',
|
||||
'id',
|
||||
'mentions',
|
||||
'author',
|
||||
'attachments',
|
||||
'nonce',
|
||||
'pinned',
|
||||
'role_mentions',
|
||||
'type',
|
||||
'flags',
|
||||
'reactions',
|
||||
'reference',
|
||||
'application',
|
||||
'activity',
|
||||
'stickers',
|
||||
'components',
|
||||
'guild',
|
||||
"_state",
|
||||
"_edited_timestamp",
|
||||
"_cs_channel_mentions",
|
||||
"_cs_raw_mentions",
|
||||
"_cs_clean_content",
|
||||
"_cs_raw_channel_mentions",
|
||||
"_cs_raw_role_mentions",
|
||||
"_cs_system_content",
|
||||
"tts",
|
||||
"content",
|
||||
"channel",
|
||||
"webhook_id",
|
||||
"mention_everyone",
|
||||
"embeds",
|
||||
"id",
|
||||
"mentions",
|
||||
"author",
|
||||
"attachments",
|
||||
"nonce",
|
||||
"pinned",
|
||||
"role_mentions",
|
||||
"type",
|
||||
"flags",
|
||||
"reactions",
|
||||
"reference",
|
||||
"application",
|
||||
"activity",
|
||||
"stickers",
|
||||
"components",
|
||||
"guild",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -664,39 +710,49 @@ class Message(Hashable):
|
||||
data: MessagePayload,
|
||||
):
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id')
|
||||
self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])]
|
||||
self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']]
|
||||
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']]
|
||||
self.application: Optional[MessageApplicationPayload] = data.get('application')
|
||||
self.activity: Optional[MessageActivityPayload] = data.get('activity')
|
||||
self.id: int = int(data["id"])
|
||||
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id")
|
||||
self.reactions: List[Reaction] = [
|
||||
Reaction(message=self, data=d) for d in data.get("reactions", [])
|
||||
]
|
||||
self.attachments: List[Attachment] = [
|
||||
Attachment(data=a, state=self._state) for a in data["attachments"]
|
||||
]
|
||||
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]]
|
||||
self.application: Optional[MessageApplicationPayload] = data.get("application")
|
||||
self.activity: Optional[MessageActivityPayload] = data.get("activity")
|
||||
self.channel: MessageableChannel = channel
|
||||
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp'])
|
||||
self.type: MessageType = try_enum(MessageType, data['type'])
|
||||
self.pinned: bool = data['pinned']
|
||||
self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0))
|
||||
self.mention_everyone: bool = data['mention_everyone']
|
||||
self.tts: bool = data['tts']
|
||||
self.content: str = data['content']
|
||||
self.nonce: Optional[Union[int, str]] = data.get('nonce')
|
||||
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
|
||||
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
|
||||
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(
|
||||
data["edited_timestamp"]
|
||||
)
|
||||
self.type: MessageType = try_enum(MessageType, data["type"])
|
||||
self.pinned: bool = data["pinned"]
|
||||
self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0))
|
||||
self.mention_everyone: bool = data["mention_everyone"]
|
||||
self.tts: bool = data["tts"]
|
||||
self.content: str = data["content"]
|
||||
self.nonce: Optional[Union[int, str]] = data.get("nonce")
|
||||
self.stickers: List[StickerItem] = [
|
||||
StickerItem(data=d, state=state) for d in data.get("sticker_items", [])
|
||||
]
|
||||
self.components: List[Component] = [
|
||||
_component_factory(d) for d in data.get("components", [])
|
||||
]
|
||||
|
||||
try:
|
||||
# if the channel doesn't have a guild attribute, we handle that
|
||||
self.guild = channel.guild # type: ignore
|
||||
except AttributeError:
|
||||
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
|
||||
self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id"))
|
||||
|
||||
try:
|
||||
ref = data['message_reference']
|
||||
ref = data["message_reference"]
|
||||
except KeyError:
|
||||
self.reference = None
|
||||
else:
|
||||
self.reference = ref = MessageReference.with_state(state, ref)
|
||||
try:
|
||||
resolved = data['referenced_message']
|
||||
resolved = data["referenced_message"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@ -712,18 +768,15 @@ class Message(Hashable):
|
||||
# the channel will be the correct type here
|
||||
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
|
||||
|
||||
for handler in ('author', 'member', 'mentions', 'mention_roles'):
|
||||
for handler in ("author", "member", "mentions", "mention_roles"):
|
||||
try:
|
||||
getattr(self, f'_handle_{handler}')(data[handler])
|
||||
getattr(self, f"_handle_{handler}")(data[handler])
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
def __repr__(self) -> str:
|
||||
name = self.__class__.__name__
|
||||
return (
|
||||
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
|
||||
)
|
||||
|
||||
return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>"
|
||||
|
||||
def __str__(self) -> Optional[str]:
|
||||
return self.content
|
||||
@ -741,7 +794,7 @@ class Message(Hashable):
|
||||
|
||||
def _add_reaction(self, data, emoji, user_id) -> Reaction:
|
||||
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
|
||||
is_me = data['me'] = user_id == self._state.self_id
|
||||
is_me = data["me"] = user_id == self._state.self_id
|
||||
|
||||
if reaction is None:
|
||||
reaction = Reaction(message=self, data=data, emoji=emoji)
|
||||
@ -753,12 +806,14 @@ class Message(Hashable):
|
||||
|
||||
return reaction
|
||||
|
||||
def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction:
|
||||
def _remove_reaction(
|
||||
self, data: ReactionPayload, emoji: EmojiInputType, user_id: int
|
||||
) -> Reaction:
|
||||
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
|
||||
|
||||
if reaction is None:
|
||||
# already removed?
|
||||
raise ValueError('Emoji already removed?')
|
||||
raise ValueError("Emoji already removed?")
|
||||
|
||||
# if reaction isn't in the list, we crash. This means discord
|
||||
# sent bad data, or we stored improperly
|
||||
@ -872,7 +927,7 @@ class Message(Hashable):
|
||||
return
|
||||
|
||||
for mention in filter(None, mentions):
|
||||
id_search = int(mention['id'])
|
||||
id_search = int(mention["id"])
|
||||
member = guild.get_member(id_search)
|
||||
if member is not None:
|
||||
r.append(member)
|
||||
@ -890,11 +945,13 @@ class Message(Hashable):
|
||||
def _handle_components(self, components: List[ComponentPayload]):
|
||||
self.components = [_component_factory(d) for d in components]
|
||||
|
||||
def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None:
|
||||
def _rebind_cached_references(
|
||||
self, new_guild: Guild, new_channel: Union[TextChannel, Thread]
|
||||
) -> None:
|
||||
self.guild = new_guild
|
||||
self.channel = new_channel
|
||||
|
||||
@utils.cached_slot_property('_cs_raw_mentions')
|
||||
@utils.cached_slot_property("_cs_raw_mentions")
|
||||
def raw_mentions(self) -> List[int]:
|
||||
"""List[:class:`int`]: A property that returns an array of user IDs matched with
|
||||
the syntax of ``<@user_id>`` in the message content.
|
||||
@ -902,30 +959,30 @@ class Message(Hashable):
|
||||
This allows you to receive the user IDs of mentioned users
|
||||
even in a private message context.
|
||||
"""
|
||||
return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)]
|
||||
return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)]
|
||||
|
||||
@utils.cached_slot_property('_cs_raw_channel_mentions')
|
||||
@utils.cached_slot_property("_cs_raw_channel_mentions")
|
||||
def raw_channel_mentions(self) -> List[int]:
|
||||
"""List[:class:`int`]: A property that returns an array of channel IDs matched with
|
||||
the syntax of ``<#channel_id>`` in the message content.
|
||||
"""
|
||||
return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)]
|
||||
return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)]
|
||||
|
||||
@utils.cached_slot_property('_cs_raw_role_mentions')
|
||||
@utils.cached_slot_property("_cs_raw_role_mentions")
|
||||
def raw_role_mentions(self) -> List[int]:
|
||||
"""List[:class:`int`]: A property that returns an array of role IDs matched with
|
||||
the syntax of ``<@&role_id>`` in the message content.
|
||||
"""
|
||||
return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)]
|
||||
return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)]
|
||||
|
||||
@utils.cached_slot_property('_cs_channel_mentions')
|
||||
@utils.cached_slot_property("_cs_channel_mentions")
|
||||
def channel_mentions(self) -> List[GuildChannel]:
|
||||
if self.guild is None:
|
||||
return []
|
||||
it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions))
|
||||
return utils._unique(it)
|
||||
|
||||
@utils.cached_slot_property('_cs_clean_content')
|
||||
@utils.cached_slot_property("_cs_clean_content")
|
||||
def clean_content(self) -> str:
|
||||
""":class:`str`: A property that returns the content in a "cleaned up"
|
||||
manner. This basically means that mentions are transformed
|
||||
@ -972,9 +1029,9 @@ class Message(Hashable):
|
||||
# fmt: on
|
||||
|
||||
def repl(obj):
|
||||
return transformations.get(re.escape(obj.group(0)), '')
|
||||
return transformations.get(re.escape(obj.group(0)), "")
|
||||
|
||||
pattern = re.compile('|'.join(transformations.keys()))
|
||||
pattern = re.compile("|".join(transformations.keys()))
|
||||
result = pattern.sub(repl, self.content)
|
||||
return escape_mentions(result)
|
||||
|
||||
@ -991,8 +1048,8 @@ class Message(Hashable):
|
||||
@property
|
||||
def jump_url(self) -> str:
|
||||
""":class:`str`: Returns a URL that allows the client to jump to this message."""
|
||||
guild_id = getattr(self.guild, 'id', '@me')
|
||||
return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}'
|
||||
guild_id = getattr(self.guild, "id", "@me")
|
||||
return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}"
|
||||
|
||||
def is_system(self) -> bool:
|
||||
""":class:`bool`: Whether the message is a system message.
|
||||
@ -1009,7 +1066,7 @@ class Message(Hashable):
|
||||
MessageType.thread_starter_message,
|
||||
)
|
||||
|
||||
@utils.cached_slot_property('_cs_system_content')
|
||||
@utils.cached_slot_property("_cs_system_content")
|
||||
def system_content(self):
|
||||
r""":class:`str`: A property that returns the content that is rendered
|
||||
regardless of the :attr:`Message.type`.
|
||||
@ -1024,24 +1081,26 @@ class Message(Hashable):
|
||||
|
||||
if self.type is MessageType.recipient_add:
|
||||
if self.channel.type is ChannelType.group:
|
||||
return f'{self.author.name} added {self.mentions[0].name} to the group.'
|
||||
return f"{self.author.name} added {self.mentions[0].name} to the group."
|
||||
else:
|
||||
return f'{self.author.name} added {self.mentions[0].name} to the thread.'
|
||||
return (
|
||||
f"{self.author.name} added {self.mentions[0].name} to the thread."
|
||||
)
|
||||
|
||||
if self.type is MessageType.recipient_remove:
|
||||
if self.channel.type is ChannelType.group:
|
||||
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
|
||||
return f"{self.author.name} removed {self.mentions[0].name} from the group."
|
||||
else:
|
||||
return f'{self.author.name} removed {self.mentions[0].name} from the thread.'
|
||||
return f"{self.author.name} removed {self.mentions[0].name} from the thread."
|
||||
|
||||
if self.type is MessageType.channel_name_change:
|
||||
return f'{self.author.name} changed the channel name: **{self.content}**'
|
||||
return f"{self.author.name} changed the channel name: **{self.content}**"
|
||||
|
||||
if self.type is MessageType.channel_icon_change:
|
||||
return f'{self.author.name} changed the channel icon.'
|
||||
return f"{self.author.name} changed the channel icon."
|
||||
|
||||
if self.type is MessageType.pins_add:
|
||||
return f'{self.author.name} pinned a message to this channel.'
|
||||
return f"{self.author.name} pinned a message to this channel."
|
||||
|
||||
if self.type is MessageType.new_member:
|
||||
formats = [
|
||||
@ -1065,64 +1124,66 @@ class Message(Hashable):
|
||||
|
||||
if self.type is MessageType.premium_guild_subscription:
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server!'
|
||||
return f"{self.author.name} just boosted the server!"
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times!'
|
||||
return f"{self.author.name} just boosted the server **{self.content}** times!"
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_1:
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
|
||||
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**"
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**'
|
||||
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**"
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_2:
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
|
||||
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**"
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**'
|
||||
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**"
|
||||
|
||||
if self.type is MessageType.premium_guild_tier_3:
|
||||
if not self.content:
|
||||
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
|
||||
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**"
|
||||
else:
|
||||
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**'
|
||||
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**"
|
||||
|
||||
if self.type is MessageType.channel_follow_add:
|
||||
return f'{self.author.name} has added {self.content} to this channel'
|
||||
return f"{self.author.name} has added {self.content} to this channel"
|
||||
|
||||
if self.type is MessageType.guild_stream:
|
||||
# the author will be a Member
|
||||
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
|
||||
return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_discovery_disqualified:
|
||||
return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.'
|
||||
return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details."
|
||||
|
||||
if self.type is MessageType.guild_discovery_requalified:
|
||||
return 'This server is eligible for Server Discovery again and has been automatically relisted!'
|
||||
return "This server is eligible for Server Discovery again and has been automatically relisted!"
|
||||
|
||||
if self.type is MessageType.guild_discovery_grace_period_initial_warning:
|
||||
return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.'
|
||||
return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery."
|
||||
|
||||
if self.type is MessageType.guild_discovery_grace_period_final_warning:
|
||||
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.'
|
||||
return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery."
|
||||
|
||||
if self.type is MessageType.thread_created:
|
||||
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.'
|
||||
return f"{self.author.name} started a thread: **{self.content}**. See all **threads**."
|
||||
|
||||
if self.type is MessageType.reply:
|
||||
return self.content
|
||||
|
||||
if self.type is MessageType.thread_starter_message:
|
||||
if self.reference is None or self.reference.resolved is None:
|
||||
return 'Sorry, we couldn\'t load the first message in this thread'
|
||||
return "Sorry, we couldn't load the first message in this thread"
|
||||
|
||||
# the resolved message for the reference will be a Message
|
||||
return self.reference.resolved.content # type: ignore
|
||||
|
||||
if self.type is MessageType.guild_invite_reminder:
|
||||
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
|
||||
return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!"
|
||||
|
||||
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
|
||||
async def delete(
|
||||
self, *, delay: Optional[float] = None, silent: bool = False
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Deletes the message.
|
||||
@ -1271,45 +1332,52 @@ class Message(Hashable):
|
||||
payload: Dict[str, Any] = {}
|
||||
if content is not MISSING:
|
||||
if content is not None:
|
||||
payload['content'] = str(content)
|
||||
payload["content"] = str(content)
|
||||
else:
|
||||
payload['content'] = None
|
||||
payload["content"] = None
|
||||
|
||||
if embed is not MISSING and embeds is not MISSING:
|
||||
raise InvalidArgument('cannot pass both embed and embeds parameter to edit()')
|
||||
raise InvalidArgument(
|
||||
"cannot pass both embed and embeds parameter to edit()"
|
||||
)
|
||||
|
||||
if embed is not MISSING:
|
||||
if embed is None:
|
||||
payload['embeds'] = []
|
||||
payload["embeds"] = []
|
||||
else:
|
||||
payload['embeds'] = [embed.to_dict()]
|
||||
payload["embeds"] = [embed.to_dict()]
|
||||
elif embeds is not MISSING:
|
||||
payload['embeds'] = [e.to_dict() for e in embeds]
|
||||
payload["embeds"] = [e.to_dict() for e in embeds]
|
||||
|
||||
if suppress is not MISSING:
|
||||
flags = MessageFlags._from_value(self.flags.value)
|
||||
flags.suppress_embeds = suppress
|
||||
payload['flags'] = flags.value
|
||||
payload["flags"] = flags.value
|
||||
|
||||
if allowed_mentions is MISSING:
|
||||
if self._state.allowed_mentions is not None and self.author.id == self._state.self_id:
|
||||
payload['allowed_mentions'] = self._state.allowed_mentions.to_dict()
|
||||
if (
|
||||
self._state.allowed_mentions is not None
|
||||
and self.author.id == self._state.self_id
|
||||
):
|
||||
payload["allowed_mentions"] = self._state.allowed_mentions.to_dict()
|
||||
else:
|
||||
if allowed_mentions is not None:
|
||||
if self._state.allowed_mentions is not None:
|
||||
payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
|
||||
payload["allowed_mentions"] = self._state.allowed_mentions.merge(
|
||||
allowed_mentions
|
||||
).to_dict()
|
||||
else:
|
||||
payload['allowed_mentions'] = allowed_mentions.to_dict()
|
||||
payload["allowed_mentions"] = allowed_mentions.to_dict()
|
||||
|
||||
if attachments is not MISSING:
|
||||
payload['attachments'] = [a.to_dict() for a in attachments]
|
||||
payload["attachments"] = [a.to_dict() for a in attachments]
|
||||
|
||||
if view is not MISSING:
|
||||
self._state.prevent_view_updates_for(self.id)
|
||||
if view:
|
||||
payload['components'] = view.to_components()
|
||||
payload["components"] = view.to_components()
|
||||
else:
|
||||
payload['components'] = []
|
||||
payload["components"] = []
|
||||
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
|
||||
message = Message(state=self._state, channel=self.channel, data=data)
|
||||
@ -1430,7 +1498,9 @@ class Message(Hashable):
|
||||
emoji = convert_emoji_reaction(emoji)
|
||||
await self._state.http.add_reaction(self.channel.id, self.id, emoji)
|
||||
|
||||
async def remove_reaction(self, emoji: Union[EmojiInputType, Reaction], member: Snowflake) -> None:
|
||||
async def remove_reaction(
|
||||
self, emoji: Union[EmojiInputType, Reaction], member: Snowflake
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Remove a reaction by the member from the message.
|
||||
@ -1467,7 +1537,9 @@ class Message(Hashable):
|
||||
if member.id == self._state.self_id:
|
||||
await self._state.http.remove_own_reaction(self.channel.id, self.id, emoji)
|
||||
else:
|
||||
await self._state.http.remove_reaction(self.channel.id, self.id, emoji, member.id)
|
||||
await self._state.http.remove_reaction(
|
||||
self.channel.id, self.id, emoji, member.id
|
||||
)
|
||||
|
||||
async def clear_reaction(self, emoji: Union[EmojiInputType, Reaction]) -> None:
|
||||
"""|coro|
|
||||
@ -1516,7 +1588,9 @@ class Message(Hashable):
|
||||
"""
|
||||
await self._state.http.clear_reactions(self.channel.id, self.id)
|
||||
|
||||
async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread:
|
||||
async def create_thread(
|
||||
self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING
|
||||
) -> Thread:
|
||||
"""|coro|
|
||||
|
||||
Creates a public thread from this message.
|
||||
@ -1551,14 +1625,17 @@ class Message(Hashable):
|
||||
The created thread.
|
||||
"""
|
||||
if self.guild is None:
|
||||
raise InvalidArgument('This message does not have guild info attached.')
|
||||
raise InvalidArgument("This message does not have guild info attached.")
|
||||
|
||||
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
|
||||
default_auto_archive_duration: ThreadArchiveDuration = getattr(
|
||||
self.channel, "default_auto_archive_duration", 1440
|
||||
)
|
||||
data = await self._state.http.start_thread_with_message(
|
||||
self.channel.id,
|
||||
self.id,
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration or default_auto_archive_duration,
|
||||
auto_archive_duration=auto_archive_duration
|
||||
or default_auto_archive_duration,
|
||||
)
|
||||
return Thread(guild=self.guild, state=self._state, data=data)
|
||||
|
||||
@ -1607,16 +1684,18 @@ class Message(Hashable):
|
||||
The reference to this message.
|
||||
"""
|
||||
|
||||
return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists)
|
||||
return MessageReference.from_message(
|
||||
self, fail_if_not_exists=fail_if_not_exists
|
||||
)
|
||||
|
||||
def to_message_reference_dict(self) -> MessageReferencePayload:
|
||||
data: MessageReferencePayload = {
|
||||
'message_id': self.id,
|
||||
'channel_id': self.channel.id,
|
||||
"message_id": self.id,
|
||||
"channel_id": self.channel.id,
|
||||
}
|
||||
|
||||
if self.guild is not None:
|
||||
data['guild_id'] = self.guild.id
|
||||
data["guild_id"] = self.guild.id
|
||||
|
||||
return data
|
||||
|
||||
@ -1662,7 +1741,7 @@ class PartialMessage(Hashable):
|
||||
The message ID.
|
||||
"""
|
||||
|
||||
__slots__ = ('channel', 'id', '_cs_guild', '_state')
|
||||
__slots__ = ("channel", "id", "_cs_guild", "_state")
|
||||
|
||||
jump_url: str = Message.jump_url # type: ignore
|
||||
delete = Message.delete
|
||||
@ -1686,7 +1765,9 @@ class PartialMessage(Hashable):
|
||||
ChannelType.public_thread,
|
||||
ChannelType.private_thread,
|
||||
):
|
||||
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}')
|
||||
raise TypeError(
|
||||
f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}"
|
||||
)
|
||||
|
||||
self.channel: PartialMessageableChannel = channel
|
||||
self._state: ConnectionState = channel._state
|
||||
@ -1702,17 +1783,17 @@ class PartialMessage(Hashable):
|
||||
pinned = property(None, lambda x, y: None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<PartialMessage id={self.id} channel={self.channel!r}>'
|
||||
return f"<PartialMessage id={self.id} channel={self.channel!r}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: The partial message's creation time in UTC."""
|
||||
return utils.snowflake_time(self.id)
|
||||
|
||||
@utils.cached_slot_property('_cs_guild')
|
||||
@utils.cached_slot_property("_cs_guild")
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable."""
|
||||
return getattr(self.channel, 'guild', None)
|
||||
return getattr(self.channel, "guild", None)
|
||||
|
||||
async def fetch(self) -> Message:
|
||||
"""|coro|
|
||||
@ -1794,58 +1875,62 @@ class PartialMessage(Hashable):
|
||||
"""
|
||||
|
||||
try:
|
||||
content = fields['content']
|
||||
content = fields["content"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if content is not None:
|
||||
fields['content'] = str(content)
|
||||
fields["content"] = str(content)
|
||||
|
||||
try:
|
||||
embed = fields['embed']
|
||||
embed = fields["embed"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if embed is not None:
|
||||
fields['embed'] = embed.to_dict()
|
||||
fields["embed"] = embed.to_dict()
|
||||
|
||||
try:
|
||||
suppress: bool = fields.pop('suppress')
|
||||
suppress: bool = fields.pop("suppress")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
flags = MessageFlags._from_value(0)
|
||||
flags.suppress_embeds = suppress
|
||||
fields['flags'] = flags.value
|
||||
fields["flags"] = flags.value
|
||||
|
||||
delete_after = fields.pop('delete_after', None)
|
||||
delete_after = fields.pop("delete_after", None)
|
||||
|
||||
try:
|
||||
allowed_mentions = fields.pop('allowed_mentions')
|
||||
allowed_mentions = fields.pop("allowed_mentions")
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if allowed_mentions is not None:
|
||||
if self._state.allowed_mentions is not None:
|
||||
allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
|
||||
allowed_mentions = self._state.allowed_mentions.merge(
|
||||
allowed_mentions
|
||||
).to_dict()
|
||||
else:
|
||||
allowed_mentions = allowed_mentions.to_dict()
|
||||
fields['allowed_mentions'] = allowed_mentions
|
||||
fields["allowed_mentions"] = allowed_mentions
|
||||
|
||||
try:
|
||||
view = fields.pop('view')
|
||||
view = fields.pop("view")
|
||||
except KeyError:
|
||||
# To check for the view afterwards
|
||||
view = None
|
||||
else:
|
||||
self._state.prevent_view_updates_for(self.id)
|
||||
if view:
|
||||
fields['components'] = view.to_components()
|
||||
fields["components"] = view.to_components()
|
||||
else:
|
||||
fields['components'] = []
|
||||
fields["components"] = []
|
||||
|
||||
if fields:
|
||||
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)
|
||||
data = await self._state.http.edit_message(
|
||||
self.channel.id, self.id, **fields
|
||||
)
|
||||
|
||||
if delete_after is not None:
|
||||
await self.delete(delay=delete_after)
|
||||
|
@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
__all__ = (
|
||||
'EqualityComparable',
|
||||
'Hashable',
|
||||
"EqualityComparable",
|
||||
"Hashable",
|
||||
)
|
||||
|
||||
|
||||
class EqualityComparable:
|
||||
__slots__ = ()
|
||||
|
||||
@ -40,6 +41,7 @@ class EqualityComparable:
|
||||
return other.id != self.id
|
||||
return True
|
||||
|
||||
|
||||
class Hashable(EqualityComparable):
|
||||
__slots__ = ()
|
||||
|
||||
|
@ -35,11 +35,11 @@ from typing import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime
|
||||
|
||||
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
|
||||
|
||||
__all__ = (
|
||||
'Object',
|
||||
)
|
||||
__all__ = ("Object",)
|
||||
|
||||
|
||||
class Object(Hashable):
|
||||
"""Represents a generic Discord object.
|
||||
@ -83,12 +83,14 @@ class Object(Hashable):
|
||||
try:
|
||||
id = int(id)
|
||||
except ValueError:
|
||||
raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None
|
||||
raise TypeError(
|
||||
f"id parameter must be convertable to int not {id.__class__!r}"
|
||||
) from None
|
||||
else:
|
||||
self.id = id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Object id={self.id!r}>'
|
||||
return f"<Object id={self.id!r}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
|
@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
|
||||
from .errors import DiscordException
|
||||
|
||||
__all__ = (
|
||||
'OggError',
|
||||
'OggPage',
|
||||
'OggStream',
|
||||
"OggError",
|
||||
"OggPage",
|
||||
"OggStream",
|
||||
)
|
||||
|
||||
|
||||
class OggError(DiscordException):
|
||||
"""An exception that is thrown for Ogg stream parsing errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# https://tools.ietf.org/html/rfc3533
|
||||
# https://tools.ietf.org/html/rfc7845
|
||||
|
||||
|
||||
class OggPage:
|
||||
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
|
||||
_header: ClassVar[struct.Struct] = struct.Struct("<xBQIIIB")
|
||||
if TYPE_CHECKING:
|
||||
flag: int
|
||||
gran_pos: int
|
||||
@ -57,14 +61,20 @@ class OggPage:
|
||||
try:
|
||||
header = stream.read(struct.calcsize(self._header.format))
|
||||
|
||||
self.flag, self.gran_pos, self.serial, \
|
||||
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
|
||||
(
|
||||
self.flag,
|
||||
self.gran_pos,
|
||||
self.serial,
|
||||
self.pagenum,
|
||||
self.crc,
|
||||
self.segnum,
|
||||
) = self._header.unpack(header)
|
||||
|
||||
self.segtable: bytes = stream.read(self.segnum)
|
||||
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
|
||||
bodylen = sum(struct.unpack("B" * self.segnum, self.segtable))
|
||||
self.data: bytes = stream.read(bodylen)
|
||||
except Exception:
|
||||
raise OggError('bad data stream') from None
|
||||
raise OggError("bad data stream") from None
|
||||
|
||||
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
|
||||
packetlen = offset = 0
|
||||
@ -76,7 +86,7 @@ class OggPage:
|
||||
partial = True
|
||||
else:
|
||||
packetlen += seg
|
||||
yield self.data[offset:offset+packetlen], True
|
||||
yield self.data[offset : offset + packetlen], True
|
||||
offset += packetlen
|
||||
packetlen = 0
|
||||
partial = False
|
||||
@ -84,18 +94,19 @@ class OggPage:
|
||||
if partial:
|
||||
yield self.data[offset:], False
|
||||
|
||||
|
||||
class OggStream:
|
||||
def __init__(self, stream: IO[bytes]) -> None:
|
||||
self.stream: IO[bytes] = stream
|
||||
|
||||
def _next_page(self) -> Optional[OggPage]:
|
||||
head = self.stream.read(4)
|
||||
if head == b'OggS':
|
||||
if head == b"OggS":
|
||||
return OggPage(self.stream)
|
||||
elif not head:
|
||||
return None
|
||||
else:
|
||||
raise OggError('invalid header magic')
|
||||
raise OggError("invalid header magic")
|
||||
|
||||
def _iter_pages(self) -> Generator[OggPage, None, None]:
|
||||
page = self._next_page()
|
||||
@ -104,10 +115,10 @@ class OggStream:
|
||||
page = self._next_page()
|
||||
|
||||
def iter_packets(self) -> Generator[bytes, None, None]:
|
||||
partial = b''
|
||||
partial = b""
|
||||
for page in self._iter_pages():
|
||||
for data, complete in page.iter_packets():
|
||||
partial += data
|
||||
if complete:
|
||||
yield partial
|
||||
partial = b''
|
||||
partial = b""
|
||||
|
271
discord/opus.py
271
discord/opus.py
@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
|
||||
from typing import (
|
||||
List,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
TypeVar,
|
||||
Literal,
|
||||
Optional,
|
||||
overload,
|
||||
)
|
||||
|
||||
import array
|
||||
import ctypes
|
||||
@ -38,9 +49,10 @@ import sys
|
||||
from .errors import DiscordException, InvalidArgument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
T = TypeVar('T')
|
||||
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
|
||||
SIGNAL_CTL = Literal['auto', 'voice', 'music']
|
||||
T = TypeVar("T")
|
||||
BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"]
|
||||
SIGNAL_CTL = Literal["auto", "voice", "music"]
|
||||
|
||||
|
||||
class BandCtl(TypedDict):
|
||||
narrow: int
|
||||
@ -49,81 +61,89 @@ class BandCtl(TypedDict):
|
||||
superwide: int
|
||||
full: int
|
||||
|
||||
|
||||
class SignalCtl(TypedDict):
|
||||
auto: int
|
||||
voice: int
|
||||
music: int
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Encoder',
|
||||
'OpusError',
|
||||
'OpusNotLoaded',
|
||||
"Encoder",
|
||||
"OpusError",
|
||||
"OpusNotLoaded",
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
c_int_ptr = ctypes.POINTER(ctypes.c_int)
|
||||
c_int_ptr = ctypes.POINTER(ctypes.c_int)
|
||||
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
|
||||
c_float_ptr = ctypes.POINTER(ctypes.c_float)
|
||||
|
||||
_lib = None
|
||||
|
||||
|
||||
class EncoderStruct(ctypes.Structure):
|
||||
pass
|
||||
|
||||
|
||||
class DecoderStruct(ctypes.Structure):
|
||||
pass
|
||||
|
||||
|
||||
EncoderStructPtr = ctypes.POINTER(EncoderStruct)
|
||||
DecoderStructPtr = ctypes.POINTER(DecoderStruct)
|
||||
|
||||
## Some constants from opus_defines.h
|
||||
# Error codes
|
||||
OK = 0
|
||||
OK = 0
|
||||
BAD_ARG = -1
|
||||
|
||||
# Encoder CTLs
|
||||
APPLICATION_AUDIO = 2049
|
||||
APPLICATION_VOIP = 2048
|
||||
APPLICATION_AUDIO = 2049
|
||||
APPLICATION_VOIP = 2048
|
||||
APPLICATION_LOWDELAY = 2051
|
||||
|
||||
CTL_SET_BITRATE = 4002
|
||||
CTL_SET_BANDWIDTH = 4008
|
||||
CTL_SET_FEC = 4012
|
||||
CTL_SET_PLP = 4014
|
||||
CTL_SET_SIGNAL = 4024
|
||||
CTL_SET_BITRATE = 4002
|
||||
CTL_SET_BANDWIDTH = 4008
|
||||
CTL_SET_FEC = 4012
|
||||
CTL_SET_PLP = 4014
|
||||
CTL_SET_SIGNAL = 4024
|
||||
|
||||
# Decoder CTLs
|
||||
CTL_SET_GAIN = 4034
|
||||
CTL_SET_GAIN = 4034
|
||||
CTL_LAST_PACKET_DURATION = 4039
|
||||
|
||||
band_ctl: BandCtl = {
|
||||
'narrow': 1101,
|
||||
'medium': 1102,
|
||||
'wide': 1103,
|
||||
'superwide': 1104,
|
||||
'full': 1105,
|
||||
"narrow": 1101,
|
||||
"medium": 1102,
|
||||
"wide": 1103,
|
||||
"superwide": 1104,
|
||||
"full": 1105,
|
||||
}
|
||||
|
||||
signal_ctl: SignalCtl = {
|
||||
'auto': -1000,
|
||||
'voice': 3001,
|
||||
'music': 3002,
|
||||
"auto": -1000,
|
||||
"voice": 3001,
|
||||
"music": 3002,
|
||||
}
|
||||
|
||||
|
||||
def _err_lt(result: int, func: Callable, args: List) -> int:
|
||||
if result < OK:
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
_log.info("error has happened in %s", func.__name__)
|
||||
raise OpusError(result)
|
||||
return result
|
||||
|
||||
|
||||
def _err_ne(result: T, func: Callable, args: List) -> T:
|
||||
ret = args[-1]._obj
|
||||
if ret.value != OK:
|
||||
_log.info('error has happened in %s', func.__name__)
|
||||
_log.info("error has happened in %s", func.__name__)
|
||||
raise OpusError(ret.value)
|
||||
return result
|
||||
|
||||
|
||||
# A list of exported functions.
|
||||
# The first argument is obviously the name.
|
||||
# The second one are the types of arguments it takes.
|
||||
@ -131,54 +151,90 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
|
||||
# The fourth is the error handler.
|
||||
exported_functions: List[Tuple[Any, ...]] = [
|
||||
# Generic
|
||||
('opus_get_version_string',
|
||||
None, ctypes.c_char_p, None),
|
||||
('opus_strerror',
|
||||
[ctypes.c_int], ctypes.c_char_p, None),
|
||||
|
||||
("opus_get_version_string", None, ctypes.c_char_p, None),
|
||||
("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None),
|
||||
# Encoder functions
|
||||
('opus_encoder_get_size',
|
||||
[ctypes.c_int], ctypes.c_int, None),
|
||||
('opus_encoder_create',
|
||||
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
|
||||
('opus_encode',
|
||||
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
|
||||
('opus_encode_float',
|
||||
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
|
||||
('opus_encoder_ctl',
|
||||
None, ctypes.c_int32, _err_lt),
|
||||
('opus_encoder_destroy',
|
||||
[EncoderStructPtr], None, None),
|
||||
|
||||
("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None),
|
||||
(
|
||||
"opus_encoder_create",
|
||||
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr],
|
||||
EncoderStructPtr,
|
||||
_err_ne,
|
||||
),
|
||||
(
|
||||
"opus_encode",
|
||||
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
|
||||
ctypes.c_int32,
|
||||
_err_lt,
|
||||
),
|
||||
(
|
||||
"opus_encode_float",
|
||||
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
|
||||
ctypes.c_int32,
|
||||
_err_lt,
|
||||
),
|
||||
("opus_encoder_ctl", None, ctypes.c_int32, _err_lt),
|
||||
("opus_encoder_destroy", [EncoderStructPtr], None, None),
|
||||
# Decoder functions
|
||||
('opus_decoder_get_size',
|
||||
[ctypes.c_int], ctypes.c_int, None),
|
||||
('opus_decoder_create',
|
||||
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
|
||||
('opus_decode',
|
||||
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
|
||||
ctypes.c_int, _err_lt),
|
||||
('opus_decode_float',
|
||||
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
|
||||
ctypes.c_int, _err_lt),
|
||||
('opus_decoder_ctl',
|
||||
None, ctypes.c_int32, _err_lt),
|
||||
('opus_decoder_destroy',
|
||||
[DecoderStructPtr], None, None),
|
||||
('opus_decoder_get_nb_samples',
|
||||
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
|
||||
|
||||
("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None),
|
||||
(
|
||||
"opus_decoder_create",
|
||||
[ctypes.c_int, ctypes.c_int, c_int_ptr],
|
||||
DecoderStructPtr,
|
||||
_err_ne,
|
||||
),
|
||||
(
|
||||
"opus_decode",
|
||||
[
|
||||
DecoderStructPtr,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int32,
|
||||
c_int16_ptr,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
],
|
||||
ctypes.c_int,
|
||||
_err_lt,
|
||||
),
|
||||
(
|
||||
"opus_decode_float",
|
||||
[
|
||||
DecoderStructPtr,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int32,
|
||||
c_float_ptr,
|
||||
ctypes.c_int,
|
||||
ctypes.c_int,
|
||||
],
|
||||
ctypes.c_int,
|
||||
_err_lt,
|
||||
),
|
||||
("opus_decoder_ctl", None, ctypes.c_int32, _err_lt),
|
||||
("opus_decoder_destroy", [DecoderStructPtr], None, None),
|
||||
(
|
||||
"opus_decoder_get_nb_samples",
|
||||
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32],
|
||||
ctypes.c_int,
|
||||
_err_lt,
|
||||
),
|
||||
# Packet functions
|
||||
('opus_packet_get_bandwidth',
|
||||
[ctypes.c_char_p], ctypes.c_int, _err_lt),
|
||||
('opus_packet_get_nb_channels',
|
||||
[ctypes.c_char_p], ctypes.c_int, _err_lt),
|
||||
('opus_packet_get_nb_frames',
|
||||
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
|
||||
('opus_packet_get_samples_per_frame',
|
||||
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
|
||||
("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt),
|
||||
("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt),
|
||||
(
|
||||
"opus_packet_get_nb_frames",
|
||||
[ctypes.c_char_p, ctypes.c_int],
|
||||
ctypes.c_int,
|
||||
_err_lt,
|
||||
),
|
||||
(
|
||||
"opus_packet_get_samples_per_frame",
|
||||
[ctypes.c_char_p, ctypes.c_int],
|
||||
ctypes.c_int,
|
||||
_err_lt,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def libopus_loader(name: str) -> Any:
|
||||
# create the library...
|
||||
lib = ctypes.cdll.LoadLibrary(name)
|
||||
@ -203,22 +259,24 @@ def libopus_loader(name: str) -> Any:
|
||||
|
||||
return lib
|
||||
|
||||
|
||||
def _load_default() -> bool:
|
||||
global _lib
|
||||
try:
|
||||
if sys.platform == 'win32':
|
||||
if sys.platform == "win32":
|
||||
_basedir = os.path.dirname(os.path.abspath(__file__))
|
||||
_bitness = struct.calcsize('P') * 8
|
||||
_target = 'x64' if _bitness > 32 else 'x86'
|
||||
_filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll')
|
||||
_bitness = struct.calcsize("P") * 8
|
||||
_target = "x64" if _bitness > 32 else "x86"
|
||||
_filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll")
|
||||
_lib = libopus_loader(_filename)
|
||||
else:
|
||||
_lib = libopus_loader(ctypes.util.find_library('opus'))
|
||||
_lib = libopus_loader(ctypes.util.find_library("opus"))
|
||||
except Exception:
|
||||
_lib = None
|
||||
|
||||
return _lib is not None
|
||||
|
||||
|
||||
def load_opus(name: str) -> None:
|
||||
"""Loads the libopus shared library for use with voice.
|
||||
|
||||
@ -257,6 +315,7 @@ def load_opus(name: str) -> None:
|
||||
global _lib
|
||||
_lib = libopus_loader(name)
|
||||
|
||||
|
||||
def is_loaded() -> bool:
|
||||
"""Function to check if opus lib is successfully loaded either
|
||||
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
|
||||
@ -271,6 +330,7 @@ def is_loaded() -> bool:
|
||||
global _lib
|
||||
return _lib is not None
|
||||
|
||||
|
||||
class OpusError(DiscordException):
|
||||
"""An exception that is thrown for libopus related errors.
|
||||
|
||||
@ -282,19 +342,22 @@ class OpusError(DiscordException):
|
||||
|
||||
def __init__(self, code: int):
|
||||
self.code: int = code
|
||||
msg = _lib.opus_strerror(self.code).decode('utf-8')
|
||||
msg = _lib.opus_strerror(self.code).decode("utf-8")
|
||||
_log.info('"%s" has happened', msg)
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class OpusNotLoaded(DiscordException):
|
||||
"""An exception that is thrown for when libopus is not loaded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _OpusStruct:
|
||||
SAMPLING_RATE = 48000
|
||||
CHANNELS = 2
|
||||
FRAME_LENGTH = 20 # in milliseconds
|
||||
SAMPLE_SIZE = struct.calcsize('h') * CHANNELS
|
||||
SAMPLE_SIZE = struct.calcsize("h") * CHANNELS
|
||||
SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH)
|
||||
|
||||
FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE
|
||||
@ -304,7 +367,8 @@ class _OpusStruct:
|
||||
if not is_loaded() and not _load_default():
|
||||
raise OpusNotLoaded()
|
||||
|
||||
return _lib.opus_get_version_string().decode('utf-8')
|
||||
return _lib.opus_get_version_string().decode("utf-8")
|
||||
|
||||
|
||||
class Encoder(_OpusStruct):
|
||||
def __init__(self, application: int = APPLICATION_AUDIO):
|
||||
@ -315,18 +379,20 @@ class Encoder(_OpusStruct):
|
||||
self.set_bitrate(128)
|
||||
self.set_fec(True)
|
||||
self.set_expected_packet_loss_percent(0.15)
|
||||
self.set_bandwidth('full')
|
||||
self.set_signal_type('auto')
|
||||
self.set_bandwidth("full")
|
||||
self.set_signal_type("auto")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
if hasattr(self, "_state"):
|
||||
_lib.opus_encoder_destroy(self._state)
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self) -> EncoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
|
||||
return _lib.opus_encoder_create(
|
||||
self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)
|
||||
)
|
||||
|
||||
def set_bitrate(self, kbps: int) -> int:
|
||||
kbps = min(512, max(16, int(kbps)))
|
||||
@ -336,14 +402,18 @@ class Encoder(_OpusStruct):
|
||||
|
||||
def set_bandwidth(self, req: BAND_CTL) -> None:
|
||||
if req not in band_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
|
||||
raise KeyError(
|
||||
f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}'
|
||||
)
|
||||
|
||||
k = band_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
|
||||
|
||||
def set_signal_type(self, req: SIGNAL_CTL) -> None:
|
||||
if req not in signal_ctl:
|
||||
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
|
||||
raise KeyError(
|
||||
f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}'
|
||||
)
|
||||
|
||||
k = signal_ctl[req]
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
|
||||
@ -352,18 +422,19 @@ class Encoder(_OpusStruct):
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
|
||||
|
||||
def set_expected_packet_loss_percent(self, percentage: float) -> None:
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
|
||||
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
|
||||
|
||||
def encode(self, pcm: bytes, frame_size: int) -> bytes:
|
||||
max_data_bytes = len(pcm)
|
||||
# bytes can be used to reference pointer
|
||||
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
|
||||
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
|
||||
data = (ctypes.c_char * max_data_bytes)()
|
||||
|
||||
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
|
||||
|
||||
# array can be initialized with bytes but mypy doesn't know
|
||||
return array.array('b', data[:ret]).tobytes() # type: ignore
|
||||
return array.array("b", data[:ret]).tobytes() # type: ignore
|
||||
|
||||
|
||||
class Decoder(_OpusStruct):
|
||||
def __init__(self):
|
||||
@ -372,14 +443,16 @@ class Decoder(_OpusStruct):
|
||||
self._state: DecoderStruct = self._create_state()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if hasattr(self, '_state'):
|
||||
if hasattr(self, "_state"):
|
||||
_lib.opus_decoder_destroy(self._state)
|
||||
# This is a destructor, so it's okay to assign None
|
||||
self._state = None # type: ignore
|
||||
self._state = None # type: ignore
|
||||
|
||||
def _create_state(self) -> DecoderStruct:
|
||||
ret = ctypes.c_int()
|
||||
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
|
||||
return _lib.opus_decoder_create(
|
||||
self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def packet_get_nb_frames(data: bytes) -> int:
|
||||
@ -411,12 +484,12 @@ class Decoder(_OpusStruct):
|
||||
def set_gain(self, dB: float) -> int:
|
||||
"""Sets the decoder gain in dB, from -128 to 128."""
|
||||
|
||||
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
|
||||
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
|
||||
return self._set_gain(dB_Q8)
|
||||
|
||||
def set_volume(self, mult: float) -> int:
|
||||
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
|
||||
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
|
||||
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
|
||||
|
||||
def _get_last_packet_duration(self) -> int:
|
||||
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
|
||||
@ -428,14 +501,16 @@ class Decoder(_OpusStruct):
|
||||
@overload
|
||||
def decode(self, data: bytes, *, fec: bool) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
|
||||
...
|
||||
|
||||
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
|
||||
if data is None and fec:
|
||||
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
|
||||
raise InvalidArgument(
|
||||
"Invalid arguments: FEC cannot be used with null data"
|
||||
)
|
||||
|
||||
if data is None:
|
||||
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME
|
||||
@ -449,6 +524,8 @@ class Decoder(_OpusStruct):
|
||||
pcm = (ctypes.c_int16 * (frame_size * channel_count))()
|
||||
pcm_ptr = ctypes.cast(pcm, c_int16_ptr)
|
||||
|
||||
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
|
||||
ret = _lib.opus_decode(
|
||||
self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec
|
||||
)
|
||||
|
||||
return array.array('h', pcm[:ret * channel_count]).tobytes()
|
||||
return array.array("h", pcm[: ret * channel_count]).tobytes()
|
||||
|
@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin
|
||||
from .errors import InvalidArgument
|
||||
from . import utils
|
||||
|
||||
__all__ = (
|
||||
'PartialEmoji',
|
||||
)
|
||||
__all__ = ("PartialEmoji",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .state import ConnectionState
|
||||
from datetime import datetime
|
||||
from .types.message import PartialEmoji as PartialEmojiPayload
|
||||
|
||||
|
||||
class _EmojiTag:
|
||||
__slots__ = ()
|
||||
|
||||
@ -49,7 +48,7 @@ class _EmojiTag:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
PE = TypeVar('PE', bound='PartialEmoji')
|
||||
PE = TypeVar("PE", bound="PartialEmoji")
|
||||
|
||||
|
||||
class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
@ -90,9 +89,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
The ID of the custom emoji, if applicable.
|
||||
"""
|
||||
|
||||
__slots__ = ('animated', 'name', 'id', '_state')
|
||||
__slots__ = ("animated", "name", "id", "_state")
|
||||
|
||||
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
|
||||
_CUSTOM_EMOJI_RE = re.compile(
|
||||
r"<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?"
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
id: Optional[int]
|
||||
@ -104,11 +105,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
self._state: Optional[ConnectionState] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE:
|
||||
def from_dict(
|
||||
cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]
|
||||
) -> PE:
|
||||
return cls(
|
||||
animated=data.get('animated', False),
|
||||
id=utils._get_as_snowflake(data, 'id'),
|
||||
name=data.get('name') or '',
|
||||
animated=data.get("animated", False),
|
||||
id=utils._get_as_snowflake(data, "id"),
|
||||
name=data.get("name") or "",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -139,19 +142,19 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
match = cls._CUSTOM_EMOJI_RE.match(value)
|
||||
if match is not None:
|
||||
groups = match.groupdict()
|
||||
animated = bool(groups['animated'])
|
||||
emoji_id = int(groups['id'])
|
||||
name = groups['name']
|
||||
animated = bool(groups["animated"])
|
||||
emoji_id = int(groups["id"])
|
||||
name = groups["name"]
|
||||
return cls(name=name, animated=animated, id=emoji_id)
|
||||
|
||||
return cls(name=value, id=None, animated=False)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
o: Dict[str, Any] = {'name': self.name}
|
||||
o: Dict[str, Any] = {"name": self.name}
|
||||
if self.id:
|
||||
o['id'] = self.id
|
||||
o["id"] = self.id
|
||||
if self.animated:
|
||||
o['animated'] = self.animated
|
||||
o["animated"] = self.animated
|
||||
return o
|
||||
|
||||
def _to_partial(self) -> PartialEmoji:
|
||||
@ -159,7 +162,12 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
|
||||
@classmethod
|
||||
def with_state(
|
||||
cls: Type[PE], state: ConnectionState, *, name: str, animated: bool = False, id: Optional[int] = None
|
||||
cls: Type[PE],
|
||||
state: ConnectionState,
|
||||
*,
|
||||
name: str,
|
||||
animated: bool = False,
|
||||
id: Optional[int] = None,
|
||||
) -> PE:
|
||||
self = cls(name=name, animated=animated, id=id)
|
||||
self._state = state
|
||||
@ -169,11 +177,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
if self.id is None:
|
||||
return self.name
|
||||
if self.animated:
|
||||
return f'<a:{self.name}:{self.id}>'
|
||||
return f'<:{self.name}:{self.id}>'
|
||||
return f"<a:{self.name}:{self.id}>"
|
||||
return f"<:{self.name}:{self.id}>"
|
||||
|
||||
def __repr__(self):
|
||||
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
|
||||
return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if self.is_unicode_emoji():
|
||||
@ -200,7 +208,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
def _as_reaction(self) -> str:
|
||||
if self.id is None:
|
||||
return self.name
|
||||
return f'{self.name}:{self.id}'
|
||||
return f"{self.name}:{self.id}"
|
||||
|
||||
@property
|
||||
def created_at(self) -> Optional[datetime]:
|
||||
@ -220,13 +228,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
|
||||
If this isn't a custom emoji then an empty string is returned
|
||||
"""
|
||||
if self.is_unicode_emoji():
|
||||
return ''
|
||||
return ""
|
||||
|
||||
fmt = 'gif' if self.animated else 'png'
|
||||
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
|
||||
fmt = "gif" if self.animated else "png"
|
||||
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
|
||||
|
||||
async def read(self) -> bytes:
|
||||
if self.is_unicode_emoji():
|
||||
raise InvalidArgument('PartialEmoji is not a custom emoji')
|
||||
raise InvalidArgument("PartialEmoji is not a custom emoji")
|
||||
|
||||
return await super().read()
|
||||
|
@ -24,12 +24,24 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, Tuple, Type, TypeVar, Optional
|
||||
from typing import (
|
||||
Callable,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterator,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Optional,
|
||||
)
|
||||
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
|
||||
|
||||
__all__ = (
|
||||
'Permissions',
|
||||
'PermissionOverwrite',
|
||||
"Permissions",
|
||||
"PermissionOverwrite",
|
||||
)
|
||||
|
||||
# A permission alias works like a regular flag but is marked
|
||||
@ -38,7 +50,9 @@ class permission_alias(alias_flag_value):
|
||||
alias: str
|
||||
|
||||
|
||||
def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permission_alias]:
|
||||
def make_permission_alias(
|
||||
alias: str,
|
||||
) -> Callable[[Callable[[Any], int]], permission_alias]:
|
||||
def decorator(func: Callable[[Any], int]) -> permission_alias:
|
||||
ret = permission_alias(func)
|
||||
ret.alias = alias
|
||||
@ -46,7 +60,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis
|
||||
|
||||
return decorator
|
||||
|
||||
P = TypeVar('P', bound='Permissions')
|
||||
|
||||
P = TypeVar("P", bound="Permissions")
|
||||
|
||||
|
||||
@fill_with_flags()
|
||||
class Permissions(BaseFlags):
|
||||
@ -101,12 +117,14 @@ class Permissions(BaseFlags):
|
||||
|
||||
def __init__(self, permissions: int = 0, **kwargs: bool):
|
||||
if not isinstance(permissions, int):
|
||||
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.')
|
||||
raise TypeError(
|
||||
f"Expected int parameter, received {permissions.__class__.__name__} instead."
|
||||
)
|
||||
|
||||
self.value = permissions
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_FLAGS:
|
||||
raise TypeError(f'{key!r} is not a valid permission name.')
|
||||
raise TypeError(f"{key!r} is not a valid permission name.")
|
||||
setattr(self, key, value)
|
||||
|
||||
def is_subset(self, other: Permissions) -> bool:
|
||||
@ -114,14 +132,18 @@ class Permissions(BaseFlags):
|
||||
if isinstance(other, Permissions):
|
||||
return (self.value & other.value) == self.value
|
||||
else:
|
||||
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
|
||||
raise TypeError(
|
||||
f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}"
|
||||
)
|
||||
|
||||
def is_superset(self, other: Permissions) -> bool:
|
||||
"""Returns ``True`` if self has the same or more permissions as other."""
|
||||
if isinstance(other, Permissions):
|
||||
return (self.value | other.value) == self.value
|
||||
else:
|
||||
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
|
||||
raise TypeError(
|
||||
f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}"
|
||||
)
|
||||
|
||||
def is_strict_subset(self, other: Permissions) -> bool:
|
||||
"""Returns ``True`` if the permissions on other are a strict subset of those on self."""
|
||||
@ -336,7 +358,7 @@ class Permissions(BaseFlags):
|
||||
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
|
||||
return 1 << 10
|
||||
|
||||
@make_permission_alias('read_messages')
|
||||
@make_permission_alias("read_messages")
|
||||
def view_channel(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`read_messages`.
|
||||
|
||||
@ -389,7 +411,7 @@ class Permissions(BaseFlags):
|
||||
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
|
||||
return 1 << 18
|
||||
|
||||
@make_permission_alias('external_emojis')
|
||||
@make_permission_alias("external_emojis")
|
||||
def use_external_emojis(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`external_emojis`.
|
||||
|
||||
@ -453,7 +475,7 @@ class Permissions(BaseFlags):
|
||||
"""
|
||||
return 1 << 28
|
||||
|
||||
@make_permission_alias('manage_roles')
|
||||
@make_permission_alias("manage_roles")
|
||||
def manage_permissions(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`manage_roles`.
|
||||
|
||||
@ -471,7 +493,7 @@ class Permissions(BaseFlags):
|
||||
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
|
||||
return 1 << 30
|
||||
|
||||
@make_permission_alias('manage_emojis')
|
||||
@make_permission_alias("manage_emojis")
|
||||
def manage_emojis_and_stickers(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`manage_emojis`.
|
||||
|
||||
@ -535,7 +557,7 @@ class Permissions(BaseFlags):
|
||||
"""
|
||||
return 1 << 37
|
||||
|
||||
@make_permission_alias('external_stickers')
|
||||
@make_permission_alias("external_stickers")
|
||||
def use_external_stickers(self) -> int:
|
||||
""":class:`bool`: An alias for :attr:`external_stickers`.
|
||||
|
||||
@ -551,7 +573,9 @@ class Permissions(BaseFlags):
|
||||
"""
|
||||
return 1 << 38
|
||||
|
||||
PO = TypeVar('PO', bound='PermissionOverwrite')
|
||||
|
||||
PO = TypeVar("PO", bound="PermissionOverwrite")
|
||||
|
||||
|
||||
def _augment_from_permissions(cls):
|
||||
cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
|
||||
@ -614,7 +638,7 @@ class PermissionOverwrite:
|
||||
Set the value of permissions by their name.
|
||||
"""
|
||||
|
||||
__slots__ = ('_values',)
|
||||
__slots__ = ("_values",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
VALID_NAMES: ClassVar[Set[str]]
|
||||
@ -670,7 +694,7 @@ class PermissionOverwrite:
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key not in self.VALID_NAMES:
|
||||
raise ValueError(f'no permission called {key}.')
|
||||
raise ValueError(f"no permission called {key}.")
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
@ -679,7 +703,9 @@ class PermissionOverwrite:
|
||||
|
||||
def _set(self, key: str, value: Optional[bool]) -> None:
|
||||
if value not in (True, None, False):
|
||||
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}')
|
||||
raise TypeError(
|
||||
f"Expected bool or NoneType, received {value.__class__.__name__}"
|
||||
)
|
||||
|
||||
if value is None:
|
||||
self._values.pop(key, None)
|
||||
|
@ -36,7 +36,18 @@ import sys
|
||||
import re
|
||||
import io
|
||||
|
||||
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
IO,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .errors import ClientException
|
||||
from .opus import Encoder as OpusEncoder
|
||||
@ -47,27 +58,28 @@ if TYPE_CHECKING:
|
||||
from .voice_client import VoiceClient
|
||||
|
||||
|
||||
AT = TypeVar('AT', bound='AudioSource')
|
||||
FT = TypeVar('FT', bound='FFmpegOpusAudio')
|
||||
AT = TypeVar("AT", bound="AudioSource")
|
||||
FT = TypeVar("FT", bound="FFmpegOpusAudio")
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
'AudioSource',
|
||||
'PCMAudio',
|
||||
'FFmpegAudio',
|
||||
'FFmpegPCMAudio',
|
||||
'FFmpegOpusAudio',
|
||||
'PCMVolumeTransformer',
|
||||
"AudioSource",
|
||||
"PCMAudio",
|
||||
"FFmpegAudio",
|
||||
"FFmpegPCMAudio",
|
||||
"FFmpegOpusAudio",
|
||||
"PCMVolumeTransformer",
|
||||
)
|
||||
|
||||
CREATE_NO_WINDOW: int
|
||||
|
||||
if sys.platform != 'win32':
|
||||
if sys.platform != "win32":
|
||||
CREATE_NO_WINDOW = 0
|
||||
else:
|
||||
CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
|
||||
class AudioSource:
|
||||
"""Represents an audio stream.
|
||||
|
||||
@ -114,6 +126,7 @@ class AudioSource:
|
||||
def __del__(self) -> None:
|
||||
self.cleanup()
|
||||
|
||||
|
||||
class PCMAudio(AudioSource):
|
||||
"""Represents raw 16-bit 48KHz stereo PCM audio source.
|
||||
|
||||
@ -122,15 +135,17 @@ class PCMAudio(AudioSource):
|
||||
stream: :term:`py:file object`
|
||||
A file-like object that reads byte data representing raw PCM.
|
||||
"""
|
||||
|
||||
def __init__(self, stream: io.BufferedIOBase) -> None:
|
||||
self.stream: io.BufferedIOBase = stream
|
||||
|
||||
def read(self) -> bytes:
|
||||
ret = self.stream.read(OpusEncoder.FRAME_SIZE)
|
||||
if len(ret) != OpusEncoder.FRAME_SIZE:
|
||||
return b''
|
||||
return b""
|
||||
return ret
|
||||
|
||||
|
||||
class FFmpegAudio(AudioSource):
|
||||
"""Represents an FFmpeg (or AVConv) based AudioSource.
|
||||
|
||||
@ -140,13 +155,22 @@ class FFmpegAudio(AudioSource):
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
|
||||
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
|
||||
def __init__(
|
||||
self,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
executable: str = "ffmpeg",
|
||||
args: Any,
|
||||
**subprocess_kwargs: Any,
|
||||
):
|
||||
piping = subprocess_kwargs.get("stdin") == subprocess.PIPE
|
||||
if piping and isinstance(source, str):
|
||||
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
|
||||
raise TypeError(
|
||||
"parameter conflict: 'source' parameter cannot be a string when piping to stdin"
|
||||
)
|
||||
|
||||
args = [executable, *args]
|
||||
kwargs = {'stdout': subprocess.PIPE}
|
||||
kwargs = {"stdout": subprocess.PIPE}
|
||||
kwargs.update(subprocess_kwargs)
|
||||
|
||||
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
|
||||
@ -155,20 +179,26 @@ class FFmpegAudio(AudioSource):
|
||||
self._pipe_thread: Optional[threading.Thread] = None
|
||||
|
||||
if piping:
|
||||
n = f'popen-stdin-writer:{id(self):#x}'
|
||||
n = f"popen-stdin-writer:{id(self):#x}"
|
||||
self._stdin = self._process.stdin
|
||||
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
|
||||
self._pipe_thread = threading.Thread(
|
||||
target=self._pipe_writer, args=(source,), daemon=True, name=n
|
||||
)
|
||||
self._pipe_thread.start()
|
||||
|
||||
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
|
||||
process = None
|
||||
try:
|
||||
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
|
||||
process = subprocess.Popen(
|
||||
args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs
|
||||
)
|
||||
except FileNotFoundError:
|
||||
executable = args.partition(' ')[0] if isinstance(args, str) else args[0]
|
||||
raise ClientException(executable + ' was not found.') from None
|
||||
executable = args.partition(" ")[0] if isinstance(args, str) else args[0]
|
||||
raise ClientException(executable + " was not found.") from None
|
||||
except subprocess.SubprocessError as exc:
|
||||
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc
|
||||
raise ClientException(
|
||||
f"Popen failed: {exc.__class__.__name__}: {exc}"
|
||||
) from exc
|
||||
else:
|
||||
return process
|
||||
|
||||
@ -177,20 +207,32 @@ class FFmpegAudio(AudioSource):
|
||||
if proc is MISSING:
|
||||
return
|
||||
|
||||
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
|
||||
_log.info("Preparing to terminate ffmpeg process %s.", proc.pid)
|
||||
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
|
||||
_log.exception(
|
||||
"Ignoring error attempting to kill ffmpeg process %s", proc.pid
|
||||
)
|
||||
|
||||
if proc.poll() is None:
|
||||
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
|
||||
_log.info(
|
||||
"ffmpeg process %s has not terminated. Waiting to terminate...",
|
||||
proc.pid,
|
||||
)
|
||||
proc.communicate()
|
||||
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
|
||||
_log.info(
|
||||
"ffmpeg process %s should have terminated with a return code of %s.",
|
||||
proc.pid,
|
||||
proc.returncode,
|
||||
)
|
||||
else:
|
||||
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
|
||||
|
||||
_log.info(
|
||||
"ffmpeg process %s successfully terminated with return code of %s.",
|
||||
proc.pid,
|
||||
proc.returncode,
|
||||
)
|
||||
|
||||
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
|
||||
while self._process:
|
||||
@ -202,7 +244,11 @@ class FFmpegAudio(AudioSource):
|
||||
try:
|
||||
self._stdin.write(data)
|
||||
except Exception:
|
||||
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
|
||||
_log.debug(
|
||||
"Write error for %s, this is probably not a problem",
|
||||
self,
|
||||
exc_info=True,
|
||||
)
|
||||
# at this point the source data is either exhausted or the process is fubar
|
||||
self._process.terminate()
|
||||
return
|
||||
@ -211,6 +257,7 @@ class FFmpegAudio(AudioSource):
|
||||
self._kill_process()
|
||||
self._process = self._stdout = self._stdin = MISSING
|
||||
|
||||
|
||||
class FFmpegPCMAudio(FFmpegAudio):
|
||||
"""An audio source from FFmpeg (or AVConv).
|
||||
|
||||
@ -250,38 +297,42 @@ class FFmpegPCMAudio(FFmpegAudio):
|
||||
self,
|
||||
source: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
executable: str = 'ffmpeg',
|
||||
executable: str = "ffmpeg",
|
||||
pipe: bool = False,
|
||||
stderr: Optional[IO[str]] = None,
|
||||
before_options: Optional[str] = None,
|
||||
options: Optional[str] = None
|
||||
options: Optional[str] = None,
|
||||
) -> None:
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {
|
||||
"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL,
|
||||
"stderr": stderr,
|
||||
}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
|
||||
args.append('-i')
|
||||
args.append('-' if pipe else source)
|
||||
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
|
||||
args.append("-i")
|
||||
args.append("-" if pipe else source)
|
||||
args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"))
|
||||
|
||||
if isinstance(options, str):
|
||||
args.extend(shlex.split(options))
|
||||
|
||||
args.append('pipe:1')
|
||||
args.append("pipe:1")
|
||||
|
||||
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
|
||||
|
||||
def read(self) -> bytes:
|
||||
ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
|
||||
if len(ret) != OpusEncoder.FRAME_SIZE:
|
||||
return b''
|
||||
return b""
|
||||
return ret
|
||||
|
||||
def is_opus(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class FFmpegOpusAudio(FFmpegAudio):
|
||||
"""An audio source from FFmpeg (or AVConv).
|
||||
|
||||
@ -349,7 +400,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
*,
|
||||
bitrate: int = 128,
|
||||
codec: Optional[str] = None,
|
||||
executable: str = 'ffmpeg',
|
||||
executable: str = "ffmpeg",
|
||||
pipe=False,
|
||||
stderr=None,
|
||||
before_options=None,
|
||||
@ -357,28 +408,42 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
) -> None:
|
||||
|
||||
args = []
|
||||
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
|
||||
subprocess_kwargs = {
|
||||
"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL,
|
||||
"stderr": stderr,
|
||||
}
|
||||
|
||||
if isinstance(before_options, str):
|
||||
args.extend(shlex.split(before_options))
|
||||
|
||||
args.append('-i')
|
||||
args.append('-' if pipe else source)
|
||||
args.append("-i")
|
||||
args.append("-" if pipe else source)
|
||||
|
||||
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
|
||||
codec = "copy" if codec in ("opus", "libopus") else "libopus"
|
||||
|
||||
args.extend(('-map_metadata', '-1',
|
||||
'-f', 'opus',
|
||||
'-c:a', codec,
|
||||
'-ar', '48000',
|
||||
'-ac', '2',
|
||||
'-b:a', f'{bitrate}k',
|
||||
'-loglevel', 'warning'))
|
||||
args.extend(
|
||||
(
|
||||
"-map_metadata",
|
||||
"-1",
|
||||
"-f",
|
||||
"opus",
|
||||
"-c:a",
|
||||
codec,
|
||||
"-ar",
|
||||
"48000",
|
||||
"-ac",
|
||||
"2",
|
||||
"-b:a",
|
||||
f"{bitrate}k",
|
||||
"-loglevel",
|
||||
"warning",
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(options, str):
|
||||
args.extend(shlex.split(options))
|
||||
|
||||
args.append('pipe:1')
|
||||
args.append("pipe:1")
|
||||
|
||||
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
|
||||
self._packet_iter = OggStream(self._stdout).iter_packets()
|
||||
@ -388,7 +453,9 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
cls: Type[FT],
|
||||
source: str,
|
||||
*,
|
||||
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
|
||||
method: Optional[
|
||||
Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> FT:
|
||||
"""|coro|
|
||||
@ -446,7 +513,7 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
An instance of this class.
|
||||
"""
|
||||
|
||||
executable = kwargs.get('executable')
|
||||
executable = kwargs.get("executable")
|
||||
codec, bitrate = await cls.probe(source, method=method, executable=executable)
|
||||
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore
|
||||
|
||||
@ -455,7 +522,9 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
cls,
|
||||
source: str,
|
||||
*,
|
||||
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
|
||||
method: Optional[
|
||||
Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]
|
||||
] = None,
|
||||
executable: Optional[str] = None,
|
||||
) -> Tuple[Optional[str], Optional[int]]:
|
||||
"""|coro|
|
||||
@ -484,12 +553,12 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
A 2-tuple with the codec and bitrate of the input source.
|
||||
"""
|
||||
|
||||
method = method or 'native'
|
||||
executable = executable or 'ffmpeg'
|
||||
method = method or "native"
|
||||
executable = executable or "ffmpeg"
|
||||
probefunc = fallback = None
|
||||
|
||||
if isinstance(method, str):
|
||||
probefunc = getattr(cls, '_probe_codec_' + method, None)
|
||||
probefunc = getattr(cls, "_probe_codec_" + method, None)
|
||||
if probefunc is None:
|
||||
raise AttributeError(f"Invalid probe method {method!r}")
|
||||
|
||||
@ -500,8 +569,10 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
probefunc = method
|
||||
fallback = cls._probe_codec_fallback
|
||||
else:
|
||||
raise TypeError("Expected str or callable for parameter 'probe', " \
|
||||
f"not '{method.__class__.__name__}'")
|
||||
raise TypeError(
|
||||
"Expected str or callable for parameter 'probe', "
|
||||
f"not '{method.__class__.__name__}'"
|
||||
)
|
||||
|
||||
codec = bitrate = None
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -512,7 +583,9 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
_log.exception("Probe '%s' using '%s' failed", method, executable)
|
||||
return # type: ignore
|
||||
|
||||
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
|
||||
_log.exception(
|
||||
"Probe '%s' using '%s' failed, trying fallback", method, executable
|
||||
)
|
||||
try:
|
||||
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
|
||||
except Exception:
|
||||
@ -525,28 +598,51 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
return codec, bitrate
|
||||
|
||||
@staticmethod
|
||||
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
|
||||
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
|
||||
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
|
||||
def _probe_codec_native(
|
||||
source, executable: str = "ffmpeg"
|
||||
) -> Tuple[Optional[str], Optional[int]]:
|
||||
exe = (
|
||||
executable[:2] + "probe"
|
||||
if executable in ("ffmpeg", "avconv")
|
||||
else executable
|
||||
)
|
||||
args = [
|
||||
exe,
|
||||
"-v",
|
||||
"quiet",
|
||||
"-print_format",
|
||||
"json",
|
||||
"-show_streams",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
source,
|
||||
]
|
||||
output = subprocess.check_output(args, timeout=20)
|
||||
codec = bitrate = None
|
||||
|
||||
if output:
|
||||
data = json.loads(output)
|
||||
streamdata = data['streams'][0]
|
||||
streamdata = data["streams"][0]
|
||||
|
||||
codec = streamdata.get('codec_name')
|
||||
bitrate = int(streamdata.get('bit_rate', 0))
|
||||
bitrate = max(round(bitrate/1000), 512)
|
||||
codec = streamdata.get("codec_name")
|
||||
bitrate = int(streamdata.get("bit_rate", 0))
|
||||
bitrate = max(round(bitrate / 1000), 512)
|
||||
|
||||
return codec, bitrate
|
||||
|
||||
@staticmethod
|
||||
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
|
||||
args = [executable, '-hide_banner', '-i', source]
|
||||
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
def _probe_codec_fallback(
|
||||
source, executable: str = "ffmpeg"
|
||||
) -> Tuple[Optional[str], Optional[int]]:
|
||||
args = [executable, "-hide_banner", "-i", source]
|
||||
proc = subprocess.Popen(
|
||||
args,
|
||||
creationflags=CREATE_NO_WINDOW,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
out, _ = proc.communicate(timeout=20)
|
||||
output = out.decode('utf8')
|
||||
output = out.decode("utf8")
|
||||
codec = bitrate = None
|
||||
|
||||
codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output)
|
||||
@ -560,11 +656,12 @@ class FFmpegOpusAudio(FFmpegAudio):
|
||||
return codec, bitrate
|
||||
|
||||
def read(self) -> bytes:
|
||||
return next(self._packet_iter, b'')
|
||||
return next(self._packet_iter, b"")
|
||||
|
||||
def is_opus(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class PCMVolumeTransformer(AudioSource, Generic[AT]):
|
||||
"""Transforms a previous :class:`AudioSource` to have volume controls.
|
||||
|
||||
@ -589,10 +686,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
|
||||
|
||||
def __init__(self, original: AT, volume: float = 1.0):
|
||||
if not isinstance(original, AudioSource):
|
||||
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.')
|
||||
raise TypeError(f"expected AudioSource not {original.__class__.__name__}.")
|
||||
|
||||
if original.is_opus():
|
||||
raise ClientException('AudioSource must not be Opus encoded.')
|
||||
raise ClientException("AudioSource must not be Opus encoded.")
|
||||
|
||||
self.original: AT = original
|
||||
self.volume = volume
|
||||
@ -613,6 +710,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
|
||||
ret = self.original.read()
|
||||
return audioop.mul(ret, 2, min(self._volume, 2.0))
|
||||
|
||||
|
||||
class AudioPlayer(threading.Thread):
|
||||
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
|
||||
|
||||
@ -625,7 +723,7 @@ class AudioPlayer(threading.Thread):
|
||||
|
||||
self._end: threading.Event = threading.Event()
|
||||
self._resumed: threading.Event = threading.Event()
|
||||
self._resumed.set() # we are not paused
|
||||
self._resumed.set() # we are not paused
|
||||
self._current_error: Optional[Exception] = None
|
||||
self._connected: threading.Event = client._connected
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
@ -685,11 +783,11 @@ class AudioPlayer(threading.Thread):
|
||||
try:
|
||||
self.after(error)
|
||||
except Exception as exc:
|
||||
_log.exception('Calling the after function failed.')
|
||||
_log.exception("Calling the after function failed.")
|
||||
exc.__context__ = error
|
||||
traceback.print_exception(type(exc), exc, exc.__traceback__)
|
||||
elif error:
|
||||
msg = f'Exception in voice thread {self.name}'
|
||||
msg = f"Exception in voice thread {self.name}"
|
||||
_log.exception(msg, exc_info=error)
|
||||
print(msg, file=sys.stderr)
|
||||
traceback.print_exception(type(error), error, error.__traceback__)
|
||||
@ -725,6 +823,8 @@ class AudioPlayer(threading.Thread):
|
||||
|
||||
def _speak(self, speaking: bool) -> None:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.client.ws.speak(speaking), self.client.loop
|
||||
)
|
||||
except Exception as e:
|
||||
_log.info("Speaking call in player failed: %s", e)
|
||||
|
@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
||||
MessageUpdateEvent,
|
||||
ReactionClearEvent,
|
||||
ReactionClearEmojiEvent,
|
||||
IntegrationDeleteEvent
|
||||
IntegrationDeleteEvent,
|
||||
)
|
||||
from .message import Message
|
||||
from .partial_emoji import PartialEmoji
|
||||
@ -42,20 +42,20 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'RawMessageDeleteEvent',
|
||||
'RawBulkMessageDeleteEvent',
|
||||
'RawMessageUpdateEvent',
|
||||
'RawReactionActionEvent',
|
||||
'RawReactionClearEvent',
|
||||
'RawReactionClearEmojiEvent',
|
||||
'RawIntegrationDeleteEvent',
|
||||
"RawMessageDeleteEvent",
|
||||
"RawBulkMessageDeleteEvent",
|
||||
"RawMessageUpdateEvent",
|
||||
"RawReactionActionEvent",
|
||||
"RawReactionClearEvent",
|
||||
"RawReactionClearEmojiEvent",
|
||||
"RawIntegrationDeleteEvent",
|
||||
)
|
||||
|
||||
|
||||
class _RawReprMixin:
|
||||
def __repr__(self) -> str:
|
||||
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
|
||||
return f'<{self.__class__.__name__} {value}>'
|
||||
value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__)
|
||||
return f"<{self.__class__.__name__} {value}>"
|
||||
|
||||
|
||||
class RawMessageDeleteEvent(_RawReprMixin):
|
||||
@ -73,14 +73,14 @@ class RawMessageDeleteEvent(_RawReprMixin):
|
||||
The cached message, if found in the internal message cache.
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
|
||||
__slots__ = ("message_id", "channel_id", "guild_id", "cached_message")
|
||||
|
||||
def __init__(self, data: MessageDeleteEvent) -> None:
|
||||
self.message_id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.message_id: int = int(data["id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
self.cached_message: Optional[Message] = None
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -100,15 +100,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
|
||||
The cached messages, if found in the internal message cache.
|
||||
"""
|
||||
|
||||
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
|
||||
__slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages")
|
||||
|
||||
def __init__(self, data: BulkMessageDeleteEvent) -> None:
|
||||
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])}
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])}
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
self.cached_messages: List[Message] = []
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -136,16 +136,16 @@ class RawMessageUpdateEvent(_RawReprMixin):
|
||||
it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
|
||||
__slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message")
|
||||
|
||||
def __init__(self, data: MessageUpdateEvent) -> None:
|
||||
self.message_id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.message_id: int = int(data["id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
self.data: MessageUpdateEvent = data
|
||||
self.cached_message: Optional[Message] = None
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -179,19 +179,28 @@ class RawReactionActionEvent(_RawReprMixin):
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
|
||||
'event_type', 'member')
|
||||
__slots__ = (
|
||||
"message_id",
|
||||
"user_id",
|
||||
"channel_id",
|
||||
"guild_id",
|
||||
"emoji",
|
||||
"event_type",
|
||||
"member",
|
||||
)
|
||||
|
||||
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.user_id: int = int(data['user_id'])
|
||||
def __init__(
|
||||
self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str
|
||||
) -> None:
|
||||
self.message_id: int = int(data["message_id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
self.user_id: int = int(data["user_id"])
|
||||
self.emoji: PartialEmoji = emoji
|
||||
self.event_type: str = event_type
|
||||
self.member: Optional[Member] = None
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -209,14 +218,14 @@ class RawReactionClearEvent(_RawReprMixin):
|
||||
The guild ID where the reactions got cleared.
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id')
|
||||
__slots__ = ("message_id", "channel_id", "guild_id")
|
||||
|
||||
def __init__(self, data: ReactionClearEvent) -> None:
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.message_id: int = int(data["message_id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -238,15 +247,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
|
||||
The custom or unicode emoji being removed.
|
||||
"""
|
||||
|
||||
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
|
||||
__slots__ = ("message_id", "channel_id", "guild_id", "emoji")
|
||||
|
||||
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
|
||||
self.emoji: PartialEmoji = emoji
|
||||
self.message_id: int = int(data['message_id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.message_id: int = int(data["message_id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
|
||||
try:
|
||||
self.guild_id: Optional[int] = int(data['guild_id'])
|
||||
self.guild_id: Optional[int] = int(data["guild_id"])
|
||||
except KeyError:
|
||||
self.guild_id: Optional[int] = None
|
||||
|
||||
@ -266,13 +275,13 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
|
||||
The guild ID where the integration got deleted.
|
||||
"""
|
||||
|
||||
__slots__ = ('integration_id', 'application_id', 'guild_id')
|
||||
__slots__ = ("integration_id", "application_id", "guild_id")
|
||||
|
||||
def __init__(self, data: IntegrationDeleteEvent) -> None:
|
||||
self.integration_id: int = int(data['id'])
|
||||
self.guild_id: int = int(data['guild_id'])
|
||||
self.integration_id: int = int(data["id"])
|
||||
self.guild_id: int = int(data["guild_id"])
|
||||
|
||||
try:
|
||||
self.application_id: Optional[int] = int(data['application_id'])
|
||||
self.application_id: Optional[int] = int(data["application_id"])
|
||||
except KeyError:
|
||||
self.application_id: Optional[int] = None
|
||||
|
@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional
|
||||
|
||||
from .iterators import ReactionIterator
|
||||
|
||||
__all__ = (
|
||||
'Reaction',
|
||||
)
|
||||
__all__ = ("Reaction",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.message import Reaction as ReactionPayload
|
||||
@ -38,6 +36,7 @@ if TYPE_CHECKING:
|
||||
from .emoji import Emoji
|
||||
from .abc import Snowflake
|
||||
|
||||
|
||||
class Reaction:
|
||||
"""Represents a reaction to a message.
|
||||
|
||||
@ -75,13 +74,22 @@ class Reaction:
|
||||
message: :class:`Message`
|
||||
Message this reaction is for.
|
||||
"""
|
||||
__slots__ = ('message', 'count', 'emoji', 'me')
|
||||
|
||||
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None):
|
||||
__slots__ = ("message", "count", "emoji", "me")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message: Message,
|
||||
data: ReactionPayload,
|
||||
emoji: Optional[Union[PartialEmoji, Emoji, str]] = None,
|
||||
):
|
||||
self.message: Message = message
|
||||
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji'])
|
||||
self.count: int = data.get('count', 1)
|
||||
self.me: bool = data.get('me')
|
||||
self.emoji: Union[
|
||||
PartialEmoji, Emoji, str
|
||||
] = emoji or message._state.get_reaction_emoji(data["emoji"])
|
||||
self.count: int = data.get("count", 1)
|
||||
self.me: bool = data.get("me")
|
||||
|
||||
# TODO: typeguard
|
||||
def is_custom_emoji(self) -> bool:
|
||||
@ -103,7 +111,7 @@ class Reaction:
|
||||
return str(self.emoji)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>'
|
||||
return f"<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>"
|
||||
|
||||
async def remove(self, user: Snowflake) -> None:
|
||||
"""|coro|
|
||||
@ -155,7 +163,9 @@ class Reaction:
|
||||
"""
|
||||
await self.message.clear_reaction(self.emoji)
|
||||
|
||||
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator:
|
||||
def users(
|
||||
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
|
||||
) -> ReactionIterator:
|
||||
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
|
||||
|
||||
The ``after`` parameter must represent a member
|
||||
@ -201,7 +211,7 @@ class Reaction:
|
||||
"""
|
||||
|
||||
if not isinstance(self.emoji, str):
|
||||
emoji = f'{self.emoji.name}:{self.emoji.id}'
|
||||
emoji = f"{self.emoji.name}:{self.emoji.id}"
|
||||
else:
|
||||
emoji = self.emoji
|
||||
|
||||
|
106
discord/role.py
106
discord/role.py
@ -32,8 +32,8 @@ from .mixins import Hashable
|
||||
from .utils import snowflake_time, _get_as_snowflake, MISSING
|
||||
|
||||
__all__ = (
|
||||
'RoleTags',
|
||||
'Role',
|
||||
"RoleTags",
|
||||
"Role",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -68,19 +68,21 @@ class RoleTags:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'bot_id',
|
||||
'integration_id',
|
||||
'_premium_subscriber',
|
||||
"bot_id",
|
||||
"integration_id",
|
||||
"_premium_subscriber",
|
||||
)
|
||||
|
||||
def __init__(self, data: RoleTagPayload):
|
||||
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id')
|
||||
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id')
|
||||
self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id")
|
||||
self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id")
|
||||
# NOTE: The API returns "null" for this if it's valid, which corresponds to None.
|
||||
# This is different from other fields where "null" means "not there".
|
||||
# So in this case, a value of None is the same as True.
|
||||
# Which means we would need a different sentinel.
|
||||
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING)
|
||||
self._premium_subscriber: Optional[Any] = data.get(
|
||||
"premium_subscriber", MISSING
|
||||
)
|
||||
|
||||
def is_bot_managed(self) -> bool:
|
||||
""":class:`bool`: Whether the role is associated with a bot."""
|
||||
@ -96,12 +98,12 @@ class RoleTags:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} '
|
||||
f'premium_subscriber={self.is_premium_subscriber()}>'
|
||||
f"<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} "
|
||||
f"premium_subscriber={self.is_premium_subscriber()}>"
|
||||
)
|
||||
|
||||
|
||||
R = TypeVar('R', bound='Role')
|
||||
R = TypeVar("R", bound="Role")
|
||||
|
||||
|
||||
class Role(Hashable):
|
||||
@ -181,23 +183,23 @@ class Role(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'id',
|
||||
'name',
|
||||
'_permissions',
|
||||
'_colour',
|
||||
'position',
|
||||
'managed',
|
||||
'mentionable',
|
||||
'hoist',
|
||||
'guild',
|
||||
'tags',
|
||||
'_state',
|
||||
"id",
|
||||
"name",
|
||||
"_permissions",
|
||||
"_colour",
|
||||
"position",
|
||||
"managed",
|
||||
"mentionable",
|
||||
"hoist",
|
||||
"guild",
|
||||
"tags",
|
||||
"_state",
|
||||
)
|
||||
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload):
|
||||
self.guild: Guild = guild
|
||||
self._state: ConnectionState = state
|
||||
self.id: int = int(data['id'])
|
||||
self.id: int = int(data["id"])
|
||||
self._update(data)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -207,14 +209,14 @@ class Role(Hashable):
|
||||
return self.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Role id={self.id} name={self.name!r}>'
|
||||
return f"<Role id={self.id} name={self.name!r}>"
|
||||
|
||||
def __lt__(self: R, other: R) -> bool:
|
||||
if not isinstance(other, Role) or not isinstance(self, Role):
|
||||
return NotImplemented
|
||||
|
||||
if self.guild != other.guild:
|
||||
raise RuntimeError('cannot compare roles from two different guilds.')
|
||||
raise RuntimeError("cannot compare roles from two different guilds.")
|
||||
|
||||
# the @everyone role is always the lowest role in hierarchy
|
||||
guild_id = self.guild.id
|
||||
@ -246,17 +248,17 @@ class Role(Hashable):
|
||||
return not r
|
||||
|
||||
def _update(self, data: RolePayload):
|
||||
self.name: str = data['name']
|
||||
self._permissions: int = int(data.get('permissions', 0))
|
||||
self.position: int = data.get('position', 0)
|
||||
self._colour: int = data.get('color', 0)
|
||||
self.hoist: bool = data.get('hoist', False)
|
||||
self.managed: bool = data.get('managed', False)
|
||||
self.mentionable: bool = data.get('mentionable', False)
|
||||
self.name: str = data["name"]
|
||||
self._permissions: int = int(data.get("permissions", 0))
|
||||
self.position: int = data.get("position", 0)
|
||||
self._colour: int = data.get("color", 0)
|
||||
self.hoist: bool = data.get("hoist", False)
|
||||
self.managed: bool = data.get("managed", False)
|
||||
self.mentionable: bool = data.get("mentionable", False)
|
||||
self.tags: Optional[RoleTags]
|
||||
|
||||
try:
|
||||
self.tags = RoleTags(data['tags'])
|
||||
self.tags = RoleTags(data["tags"])
|
||||
except KeyError:
|
||||
self.tags = None
|
||||
|
||||
@ -291,7 +293,11 @@ class Role(Hashable):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
me = self.guild.me
|
||||
return not self.is_default() and not self.managed and (me.top_role > self or me.id == self.guild.owner_id)
|
||||
return (
|
||||
not self.is_default()
|
||||
and not self.managed
|
||||
and (me.top_role > self or me.id == self.guild.owner_id)
|
||||
)
|
||||
|
||||
@property
|
||||
def permissions(self) -> Permissions:
|
||||
@ -316,7 +322,7 @@ class Role(Hashable):
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: Returns a string that allows you to mention a role."""
|
||||
return f'<@&{self.id}>'
|
||||
return f"<@&{self.id}>"
|
||||
|
||||
@property
|
||||
def members(self) -> List[Member]:
|
||||
@ -340,15 +346,23 @@ class Role(Hashable):
|
||||
|
||||
http = self._state.http
|
||||
|
||||
change_range = range(min(self.position, position), max(self.position, position) + 1)
|
||||
roles = [r.id for r in self.guild.roles[1:] if r.position in change_range and r.id != self.id]
|
||||
change_range = range(
|
||||
min(self.position, position), max(self.position, position) + 1
|
||||
)
|
||||
roles = [
|
||||
r.id
|
||||
for r in self.guild.roles[1:]
|
||||
if r.position in change_range and r.id != self.id
|
||||
]
|
||||
|
||||
if self.position > position:
|
||||
roles.insert(0, self.id)
|
||||
else:
|
||||
roles.append(self.id)
|
||||
|
||||
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
payload: List[RolePositionUpdate] = [
|
||||
{"id": z[0], "position": z[1]} for z in zip(roles, change_range)
|
||||
]
|
||||
await http.move_role_position(self.guild.id, payload, reason=reason)
|
||||
|
||||
async def edit(
|
||||
@ -420,23 +434,25 @@ class Role(Hashable):
|
||||
|
||||
if colour is not MISSING:
|
||||
if isinstance(colour, int):
|
||||
payload['color'] = colour
|
||||
payload["color"] = colour
|
||||
else:
|
||||
payload['color'] = colour.value
|
||||
payload["color"] = colour.value
|
||||
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
payload["name"] = name
|
||||
|
||||
if permissions is not MISSING:
|
||||
payload['permissions'] = permissions.value
|
||||
payload["permissions"] = permissions.value
|
||||
|
||||
if hoist is not MISSING:
|
||||
payload['hoist'] = hoist
|
||||
payload["hoist"] = hoist
|
||||
|
||||
if mentionable is not MISSING:
|
||||
payload['mentionable'] = mentionable
|
||||
payload["mentionable"] = mentionable
|
||||
|
||||
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
|
||||
data = await self._state.http.edit_role(
|
||||
self.guild.id, self.id, reason=reason, **payload
|
||||
)
|
||||
return Role(guild=self.guild, data=data, state=self._state)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
|
111
discord/shard.py
111
discord/shard.py
@ -43,18 +43,28 @@ from .errors import (
|
||||
|
||||
from .enums import Status
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Tuple,
|
||||
Type,
|
||||
Optional,
|
||||
List,
|
||||
Dict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gateway import DiscordWebSocket
|
||||
from .activity import BaseActivity
|
||||
from .enums import Status
|
||||
|
||||
EI = TypeVar('EI', bound='EventItem')
|
||||
EI = TypeVar("EI", bound="EventItem")
|
||||
|
||||
__all__ = (
|
||||
'AutoShardedClient',
|
||||
'ShardInfo',
|
||||
"AutoShardedClient",
|
||||
"ShardInfo",
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@ -70,11 +80,13 @@ class EventType:
|
||||
|
||||
|
||||
class EventItem:
|
||||
__slots__ = ('type', 'shard', 'error')
|
||||
__slots__ = ("type", "shard", "error")
|
||||
|
||||
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
|
||||
def __init__(
|
||||
self, etype: int, shard: Optional["Shard"], error: Optional[Exception]
|
||||
) -> None:
|
||||
self.type: int = etype
|
||||
self.shard: Optional['Shard'] = shard
|
||||
self.shard: Optional["Shard"] = shard
|
||||
self.error: Optional[Exception] = error
|
||||
|
||||
def __lt__(self: EI, other: EI) -> bool:
|
||||
@ -92,7 +104,12 @@ class EventItem:
|
||||
|
||||
|
||||
class Shard:
|
||||
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
ws: DiscordWebSocket,
|
||||
client: AutoShardedClient,
|
||||
queue_put: Callable[[EventItem], None],
|
||||
) -> None:
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._client: Client = client
|
||||
self._dispatch: Callable[..., None] = client.dispatch
|
||||
@ -129,11 +146,11 @@ class Shard:
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
await self.close()
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
|
||||
async def _handle_disconnect(self, e: Exception) -> None:
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
self._dispatch("disconnect")
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
if not self._reconnect:
|
||||
self._queue_put(EventItem(EventType.close, self, e))
|
||||
return
|
||||
@ -149,14 +166,23 @@ class Shard:
|
||||
|
||||
if isinstance(e, ConnectionClosed):
|
||||
if e.code == 4014:
|
||||
self._queue_put(EventItem(EventType.terminate, self, PrivilegedIntentsRequired(self.id)))
|
||||
self._queue_put(
|
||||
EventItem(
|
||||
EventType.terminate, self, PrivilegedIntentsRequired(self.id)
|
||||
)
|
||||
)
|
||||
return
|
||||
if e.code != 1000:
|
||||
self._queue_put(EventItem(EventType.close, self, e))
|
||||
return
|
||||
|
||||
retry = self._backoff.delay()
|
||||
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
|
||||
_log.error(
|
||||
"Attempting a reconnect for shard ID %s in %.2fs",
|
||||
self.id,
|
||||
retry,
|
||||
exc_info=e,
|
||||
)
|
||||
await asyncio.sleep(retry)
|
||||
self._queue_put(EventItem(EventType.reconnect, self, e))
|
||||
|
||||
@ -179,9 +205,9 @@ class Shard:
|
||||
|
||||
async def reidentify(self, exc: ReconnectWebSocket) -> None:
|
||||
self._cancel_task()
|
||||
self._dispatch('disconnect')
|
||||
self._dispatch('shard_disconnect', self.id)
|
||||
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
self._dispatch("disconnect")
|
||||
self._dispatch("shard_disconnect", self.id)
|
||||
_log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self._client,
|
||||
@ -231,7 +257,7 @@ class ShardInfo:
|
||||
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
|
||||
"""
|
||||
|
||||
__slots__ = ('_parent', 'id', 'shard_count')
|
||||
__slots__ = ("_parent", "id", "shard_count")
|
||||
|
||||
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
|
||||
self._parent: Shard = parent
|
||||
@ -320,16 +346,23 @@ class AutoShardedClient(Client):
|
||||
if TYPE_CHECKING:
|
||||
_connection: AutoShardedConnectionState
|
||||
|
||||
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
|
||||
kwargs.pop('shard_id', None)
|
||||
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
kwargs.pop("shard_id", None)
|
||||
self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None)
|
||||
super().__init__(*args, loop=loop, **kwargs)
|
||||
|
||||
if self.shard_ids is not None:
|
||||
if self.shard_count is None:
|
||||
raise ClientException('When passing manual shard_ids, you must provide a shard_count.')
|
||||
raise ClientException(
|
||||
"When passing manual shard_ids, you must provide a shard_count."
|
||||
)
|
||||
elif not isinstance(self.shard_ids, (list, tuple)):
|
||||
raise ClientException('shard_ids parameter must be a list or a tuple.')
|
||||
raise ClientException("shard_ids parameter must be a list or a tuple.")
|
||||
|
||||
# instead of a single websocket, we have multiple
|
||||
# the key is the shard_id
|
||||
@ -338,7 +371,9 @@ class AutoShardedClient(Client):
|
||||
self._connection._get_client = lambda: self
|
||||
self.__queue = asyncio.PriorityQueue()
|
||||
|
||||
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
|
||||
def _get_websocket(
|
||||
self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None
|
||||
) -> DiscordWebSocket:
|
||||
if shard_id is None:
|
||||
# guild_id won't be None if shard_id is None and shard_count won't be None here
|
||||
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
|
||||
@ -363,7 +398,7 @@ class AutoShardedClient(Client):
|
||||
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
|
||||
"""
|
||||
if not self.__shards:
|
||||
return float('nan')
|
||||
return float("nan")
|
||||
return sum(latency for _, latency in self.latencies) / len(self.__shards)
|
||||
|
||||
@property
|
||||
@ -372,7 +407,9 @@ class AutoShardedClient(Client):
|
||||
|
||||
This returns a list of tuples with elements ``(shard_id, latency)``.
|
||||
"""
|
||||
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
|
||||
return [
|
||||
(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()
|
||||
]
|
||||
|
||||
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
|
||||
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
|
||||
@ -386,14 +423,21 @@ class AutoShardedClient(Client):
|
||||
@property
|
||||
def shards(self) -> Dict[int, ShardInfo]:
|
||||
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
|
||||
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
|
||||
return {
|
||||
shard_id: ShardInfo(parent, self.shard_count)
|
||||
for shard_id, parent in self.__shards.items()
|
||||
}
|
||||
|
||||
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
|
||||
async def launch_shard(
|
||||
self, gateway: str, shard_id: int, *, initial: bool = False
|
||||
) -> None:
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
|
||||
coro = DiscordWebSocket.from_client(
|
||||
self, initial=initial, gateway=gateway, shard_id=shard_id
|
||||
)
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
except Exception:
|
||||
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
_log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
return await self.launch_shard(gateway, shard_id)
|
||||
|
||||
@ -458,7 +502,10 @@ class AutoShardedClient(Client):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
|
||||
to_close = [
|
||||
asyncio.ensure_future(shard.close(), loop=self.loop)
|
||||
for shard in self.__shards.values()
|
||||
]
|
||||
if to_close:
|
||||
await asyncio.wait(to_close)
|
||||
|
||||
@ -503,10 +550,10 @@ class AutoShardedClient(Client):
|
||||
"""
|
||||
|
||||
if status is None:
|
||||
status_value = 'online'
|
||||
status_value = "online"
|
||||
status_enum = Status.online
|
||||
elif status is Status.offline:
|
||||
status_value = 'invisible'
|
||||
status_value = "invisible"
|
||||
status_enum = Status.offline
|
||||
else:
|
||||
status_enum = status
|
||||
|
@ -31,9 +31,7 @@ from .mixins import Hashable
|
||||
from .errors import InvalidArgument
|
||||
from .enums import StagePrivacyLevel, try_enum
|
||||
|
||||
__all__ = (
|
||||
'StageInstance',
|
||||
)
|
||||
__all__ = ("StageInstance",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types.channel import StageInstance as StageInstancePayload
|
||||
@ -82,41 +80,51 @@ class StageInstance(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_state',
|
||||
'id',
|
||||
'guild',
|
||||
'channel_id',
|
||||
'topic',
|
||||
'privacy_level',
|
||||
'discoverable_disabled',
|
||||
'_cs_channel',
|
||||
"_state",
|
||||
"id",
|
||||
"guild",
|
||||
"channel_id",
|
||||
"topic",
|
||||
"privacy_level",
|
||||
"discoverable_disabled",
|
||||
"_cs_channel",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
|
||||
def __init__(
|
||||
self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload
|
||||
) -> None:
|
||||
self._state = state
|
||||
self.guild = guild
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: StageInstancePayload):
|
||||
self.id: int = int(data['id'])
|
||||
self.channel_id: int = int(data['channel_id'])
|
||||
self.topic: str = data['topic']
|
||||
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
|
||||
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
|
||||
self.id: int = int(data["id"])
|
||||
self.channel_id: int = int(data["channel_id"])
|
||||
self.topic: str = data["topic"]
|
||||
self.privacy_level: StagePrivacyLevel = try_enum(
|
||||
StagePrivacyLevel, data["privacy_level"]
|
||||
)
|
||||
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
|
||||
return f"<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>"
|
||||
|
||||
@cached_slot_property('_cs_channel')
|
||||
@cached_slot_property("_cs_channel")
|
||||
def channel(self) -> Optional[StageChannel]:
|
||||
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
|
||||
# the returned channel will always be a StageChannel or None
|
||||
return self._state.get_channel(self.channel_id) # type: ignore
|
||||
return self._state.get_channel(self.channel_id) # type: ignore
|
||||
|
||||
def is_public(self) -> bool:
|
||||
return self.privacy_level is StagePrivacyLevel.public
|
||||
|
||||
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
|
||||
async def edit(
|
||||
self,
|
||||
*,
|
||||
topic: str = MISSING,
|
||||
privacy_level: StagePrivacyLevel = MISSING,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
Edits the stage instance.
|
||||
@ -146,16 +154,20 @@ class StageInstance(Hashable):
|
||||
payload = {}
|
||||
|
||||
if topic is not MISSING:
|
||||
payload['topic'] = topic
|
||||
payload["topic"] = topic
|
||||
|
||||
if privacy_level is not MISSING:
|
||||
if not isinstance(privacy_level, StagePrivacyLevel):
|
||||
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
|
||||
raise InvalidArgument(
|
||||
"privacy_level field must be of type PrivacyLevel"
|
||||
)
|
||||
|
||||
payload['privacy_level'] = privacy_level.value
|
||||
payload["privacy_level"] = privacy_level.value
|
||||
|
||||
if payload:
|
||||
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
|
||||
await self._state.http.edit_stage_instance(
|
||||
self.channel_id, **payload, reason=reason
|
||||
)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
"""|coro|
|
||||
|
845
discord/state.py
845
discord/state.py
File diff suppressed because it is too large
Load Diff
@ -33,11 +33,11 @@ from .errors import InvalidData
|
||||
from .enums import StickerType, StickerFormatType, try_enum
|
||||
|
||||
__all__ = (
|
||||
'StickerPack',
|
||||
'StickerItem',
|
||||
'Sticker',
|
||||
'StandardSticker',
|
||||
'GuildSticker',
|
||||
"StickerPack",
|
||||
"StickerItem",
|
||||
"Sticker",
|
||||
"StandardSticker",
|
||||
"GuildSticker",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -102,15 +102,15 @@ class StickerPack(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'_state',
|
||||
'id',
|
||||
'stickers',
|
||||
'name',
|
||||
'sku_id',
|
||||
'cover_sticker_id',
|
||||
'cover_sticker',
|
||||
'description',
|
||||
'_banner',
|
||||
"_state",
|
||||
"id",
|
||||
"stickers",
|
||||
"name",
|
||||
"sku_id",
|
||||
"cover_sticker_id",
|
||||
"cover_sticker",
|
||||
"description",
|
||||
"_banner",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
|
||||
@ -118,15 +118,17 @@ class StickerPack(Hashable):
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: StickerPackPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
stickers = data['stickers']
|
||||
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
|
||||
self.name: str = data['name']
|
||||
self.sku_id: int = int(data['sku_id'])
|
||||
self.cover_sticker_id: int = int(data['cover_sticker_id'])
|
||||
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
|
||||
self.description: str = data['description']
|
||||
self._banner: int = int(data['banner_asset_id'])
|
||||
self.id: int = int(data["id"])
|
||||
stickers = data["stickers"]
|
||||
self.stickers: List[StandardSticker] = [
|
||||
StandardSticker(state=self._state, data=sticker) for sticker in stickers
|
||||
]
|
||||
self.name: str = data["name"]
|
||||
self.sku_id: int = int(data["sku_id"])
|
||||
self.cover_sticker_id: int = int(data["cover_sticker_id"])
|
||||
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
|
||||
self.description: str = data["description"]
|
||||
self._banner: int = int(data["banner_asset_id"])
|
||||
|
||||
@property
|
||||
def banner(self) -> Asset:
|
||||
@ -134,7 +136,7 @@ class StickerPack(Hashable):
|
||||
return Asset._from_sticker_banner(self._state, self._banner)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
|
||||
return f"<StickerPack id={self.id} name={self.name!r} description={self.description!r}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@ -205,17 +207,19 @@ class StickerItem(_StickerTag):
|
||||
The URL for the sticker's image.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'name', 'id', 'format', 'url')
|
||||
__slots__ = ("_state", "name", "id", "format", "url")
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
|
||||
self._state: ConnectionState = state
|
||||
self.name: str = data['name']
|
||||
self.id: int = int(data['id'])
|
||||
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
|
||||
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
|
||||
self.name: str = data["name"]
|
||||
self.id: int = int(data["id"])
|
||||
self.format: StickerFormatType = try_enum(
|
||||
StickerFormatType, data["format_type"]
|
||||
)
|
||||
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
|
||||
return f"<StickerItem id={self.id} name={self.name!r} format={self.format}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@ -236,7 +240,7 @@ class StickerItem(_StickerTag):
|
||||
The retrieved sticker.
|
||||
"""
|
||||
data: StickerPayload = await self._state.http.get_sticker(self.id)
|
||||
cls, _ = _sticker_factory(data['type']) # type: ignore
|
||||
cls, _ = _sticker_factory(data["type"]) # type: ignore
|
||||
return cls(state=self._state, data=data)
|
||||
|
||||
|
||||
@ -275,21 +279,23 @@ class Sticker(_StickerTag):
|
||||
The URL for the sticker's image.
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
|
||||
__slots__ = ("_state", "id", "name", "description", "format", "url")
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
|
||||
self._state: ConnectionState = state
|
||||
self._from_data(data)
|
||||
|
||||
def _from_data(self, data: StickerPayload) -> None:
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self.description: str = data['description']
|
||||
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
|
||||
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self.description: str = data["description"]
|
||||
self.format: StickerFormatType = try_enum(
|
||||
StickerFormatType, data["format_type"]
|
||||
)
|
||||
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Sticker id={self.id} name={self.name!r}>'
|
||||
return f"<Sticker id={self.id} name={self.name!r}>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
@ -337,21 +343,23 @@ class StandardSticker(Sticker):
|
||||
The sticker's sort order within its pack.
|
||||
"""
|
||||
|
||||
__slots__ = ('sort_value', 'pack_id', 'type', 'tags')
|
||||
__slots__ = ("sort_value", "pack_id", "type", "tags")
|
||||
|
||||
def _from_data(self, data: StandardStickerPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.sort_value: int = data['sort_value']
|
||||
self.pack_id: int = int(data['pack_id'])
|
||||
self.sort_value: int = data["sort_value"]
|
||||
self.pack_id: int = int(data["pack_id"])
|
||||
self.type: StickerType = StickerType.standard
|
||||
|
||||
try:
|
||||
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
|
||||
self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")]
|
||||
except KeyError:
|
||||
self.tags = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
|
||||
return (
|
||||
f"<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>"
|
||||
)
|
||||
|
||||
async def pack(self) -> StickerPack:
|
||||
"""|coro|
|
||||
@ -370,13 +378,15 @@ class StandardSticker(Sticker):
|
||||
:class:`StickerPack`
|
||||
The retrieved sticker pack.
|
||||
"""
|
||||
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
|
||||
packs = data['sticker_packs']
|
||||
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
|
||||
data: ListPremiumStickerPacksPayload = (
|
||||
await self._state.http.list_premium_sticker_packs()
|
||||
)
|
||||
packs = data["sticker_packs"]
|
||||
pack = find(lambda d: int(d["id"]) == self.pack_id, packs)
|
||||
|
||||
if pack:
|
||||
return StickerPack(state=self._state, data=pack)
|
||||
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
|
||||
raise InvalidData(f"Could not find corresponding sticker pack for {self!r}")
|
||||
|
||||
|
||||
class GuildSticker(Sticker):
|
||||
@ -419,21 +429,21 @@ class GuildSticker(Sticker):
|
||||
The name of a unicode emoji that represents this sticker.
|
||||
"""
|
||||
|
||||
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
|
||||
__slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild")
|
||||
|
||||
def _from_data(self, data: GuildStickerPayload) -> None:
|
||||
super()._from_data(data)
|
||||
self.available: bool = data['available']
|
||||
self.guild_id: int = int(data['guild_id'])
|
||||
user = data.get('user')
|
||||
self.available: bool = data["available"]
|
||||
self.guild_id: int = int(data["guild_id"])
|
||||
user = data.get("user")
|
||||
self.user: Optional[User] = self._state.store_user(user) if user else None
|
||||
self.emoji: str = data['tags']
|
||||
self.emoji: str = data["tags"]
|
||||
self.type: StickerType = StickerType.guild
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
|
||||
return f"<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>"
|
||||
|
||||
@cached_slot_property('_cs_guild')
|
||||
@cached_slot_property("_cs_guild")
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`Guild`]: The guild that this sticker is from.
|
||||
Could be ``None`` if the bot is not in the guild.
|
||||
@ -480,10 +490,10 @@ class GuildSticker(Sticker):
|
||||
payload: EditGuildSticker = {}
|
||||
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
payload["name"] = name
|
||||
|
||||
if description is not MISSING:
|
||||
payload['description'] = description
|
||||
payload["description"] = description
|
||||
|
||||
if emoji is not MISSING:
|
||||
try:
|
||||
@ -491,11 +501,13 @@ class GuildSticker(Sticker):
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
emoji = emoji.replace(' ', '_')
|
||||
emoji = emoji.replace(" ", "_")
|
||||
|
||||
payload['tags'] = emoji
|
||||
payload["tags"] = emoji
|
||||
|
||||
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
|
||||
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(
|
||||
self.guild_id, self.id, payload, reason
|
||||
)
|
||||
return GuildSticker(state=self._state, data=data)
|
||||
|
||||
async def delete(self, *, reason: Optional[str] = None) -> None:
|
||||
@ -521,7 +533,9 @@ class GuildSticker(Sticker):
|
||||
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
|
||||
|
||||
|
||||
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
|
||||
def _sticker_factory(
|
||||
sticker_type: Literal[1, 2]
|
||||
) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
|
||||
value = try_enum(StickerType, sticker_type)
|
||||
if value == StickerType.standard:
|
||||
return StandardSticker, value
|
||||
|
@ -40,8 +40,8 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'Team',
|
||||
'TeamMember',
|
||||
"Team",
|
||||
"TeamMember",
|
||||
)
|
||||
|
||||
|
||||
@ -62,26 +62,28 @@ class Team:
|
||||
.. versionadded:: 1.3
|
||||
"""
|
||||
|
||||
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members')
|
||||
__slots__ = ("_state", "id", "name", "_icon", "owner_id", "members")
|
||||
|
||||
def __init__(self, state: ConnectionState, data: TeamPayload):
|
||||
self._state: ConnectionState = state
|
||||
|
||||
self.id: int = int(data['id'])
|
||||
self.name: str = data['name']
|
||||
self._icon: Optional[str] = data['icon']
|
||||
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id')
|
||||
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']]
|
||||
self.id: int = int(data["id"])
|
||||
self.name: str = data["name"]
|
||||
self._icon: Optional[str] = data["icon"]
|
||||
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id")
|
||||
self.members: List[TeamMember] = [
|
||||
TeamMember(self, self._state, member) for member in data["members"]
|
||||
]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
|
||||
return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional[Asset]:
|
||||
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any."""
|
||||
if self._icon is None:
|
||||
return None
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path='team')
|
||||
return Asset._from_icon(self._state, self.id, self._icon, path="team")
|
||||
|
||||
@property
|
||||
def owner(self) -> Optional[TeamMember]:
|
||||
@ -130,16 +132,18 @@ class TeamMember(BaseUser):
|
||||
The membership state of the member (e.g. invited or accepted)
|
||||
"""
|
||||
|
||||
__slots__ = ('team', 'membership_state', 'permissions')
|
||||
__slots__ = ("team", "membership_state", "permissions")
|
||||
|
||||
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload):
|
||||
self.team: Team = team
|
||||
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
|
||||
self.permissions: List[str] = data['permissions']
|
||||
super().__init__(state=state, data=data['user'])
|
||||
self.membership_state: TeamMembershipState = try_enum(
|
||||
TeamMembershipState, data["membership_state"]
|
||||
)
|
||||
self.permissions: List[str] = data["permissions"]
|
||||
super().__init__(state=state, data=data["user"])
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
|
||||
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
|
||||
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
|
||||
f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>"
|
||||
)
|
||||
|
@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
|
||||
from .enums import VoiceRegion
|
||||
from .guild import Guild
|
||||
|
||||
__all__ = (
|
||||
'Template',
|
||||
)
|
||||
__all__ = ("Template",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import datetime
|
||||
@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper:
|
||||
__slots__ = ()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
raise AttributeError('PartialTemplateState does not support http methods.')
|
||||
raise AttributeError("PartialTemplateState does not support http methods.")
|
||||
|
||||
|
||||
class _PartialTemplateState:
|
||||
@ -84,7 +82,7 @@ class _PartialTemplateState:
|
||||
return []
|
||||
|
||||
def __getattr__(self, attr):
|
||||
raise AttributeError(f'PartialTemplateState does not support {attr!r}.')
|
||||
raise AttributeError(f"PartialTemplateState does not support {attr!r}.")
|
||||
|
||||
|
||||
class Template:
|
||||
@ -118,16 +116,16 @@ class Template:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'code',
|
||||
'uses',
|
||||
'name',
|
||||
'description',
|
||||
'creator',
|
||||
'created_at',
|
||||
'updated_at',
|
||||
'source_guild',
|
||||
'is_dirty',
|
||||
'_state',
|
||||
"code",
|
||||
"uses",
|
||||
"name",
|
||||
"description",
|
||||
"creator",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"source_guild",
|
||||
"is_dirty",
|
||||
"_state",
|
||||
)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
|
||||
@ -135,38 +133,46 @@ class Template:
|
||||
self._store(data)
|
||||
|
||||
def _store(self, data: TemplatePayload) -> None:
|
||||
self.code: str = data['code']
|
||||
self.uses: int = data['usage_count']
|
||||
self.name: str = data['name']
|
||||
self.description: Optional[str] = data['description']
|
||||
creator_data = data.get('creator')
|
||||
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
|
||||
self.code: str = data["code"]
|
||||
self.uses: int = data["usage_count"]
|
||||
self.name: str = data["name"]
|
||||
self.description: Optional[str] = data["description"]
|
||||
creator_data = data.get("creator")
|
||||
self.creator: Optional[User] = (
|
||||
None if creator_data is None else self._state.create_user(creator_data)
|
||||
)
|
||||
|
||||
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
|
||||
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
|
||||
self.created_at: Optional[datetime.datetime] = parse_time(
|
||||
data.get("created_at")
|
||||
)
|
||||
self.updated_at: Optional[datetime.datetime] = parse_time(
|
||||
data.get("updated_at")
|
||||
)
|
||||
|
||||
guild_id = int(data['source_guild_id'])
|
||||
guild_id = int(data["source_guild_id"])
|
||||
guild: Optional[Guild] = self._state._get_guild(guild_id)
|
||||
|
||||
self.source_guild: Guild
|
||||
if guild is None:
|
||||
source_serialised = data['serialized_source_guild']
|
||||
source_serialised['id'] = guild_id
|
||||
source_serialised = data["serialized_source_guild"]
|
||||
source_serialised["id"] = guild_id
|
||||
state = _PartialTemplateState(state=self._state)
|
||||
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
|
||||
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
|
||||
else:
|
||||
self.source_guild = guild
|
||||
|
||||
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
|
||||
self.is_dirty: Optional[bool] = data.get("is_dirty", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}'
|
||||
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
|
||||
f"<Template code={self.code!r} uses={self.uses} name={self.name!r}"
|
||||
f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>"
|
||||
)
|
||||
|
||||
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
|
||||
async def create_guild(
|
||||
self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None
|
||||
) -> Guild:
|
||||
"""|coro|
|
||||
|
||||
Creates a :class:`.Guild` using the template.
|
||||
@ -203,7 +209,9 @@ class Template:
|
||||
region = region or VoiceRegion.us_west
|
||||
region_value = region.value
|
||||
|
||||
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
|
||||
data = await self._state.http.create_from_template(
|
||||
self.code, name, region_value, icon
|
||||
)
|
||||
return Guild(data=data, state=self._state)
|
||||
|
||||
async def sync(self) -> Template:
|
||||
@ -279,11 +287,13 @@ class Template:
|
||||
payload = {}
|
||||
|
||||
if name is not MISSING:
|
||||
payload['name'] = name
|
||||
payload["name"] = name
|
||||
if description is not MISSING:
|
||||
payload['description'] = description
|
||||
payload["description"] = description
|
||||
|
||||
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
|
||||
data = await self._state.http.edit_template(
|
||||
self.source_guild.id, self.code, payload
|
||||
)
|
||||
return Template(state=self._state, data=data)
|
||||
|
||||
async def delete(self) -> None:
|
||||
@ -310,7 +320,7 @@ class Template:
|
||||
@property
|
||||
def url(self) -> str:
|
||||
""":class:`str`: The template url.
|
||||
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return f'https://discord.new/{self.code}'
|
||||
return f"https://discord.new/{self.code}"
|
||||
|
@ -35,8 +35,8 @@ from .errors import ClientException
|
||||
from .utils import MISSING, parse_time, _get_as_snowflake
|
||||
|
||||
__all__ = (
|
||||
'Thread',
|
||||
'ThreadMember',
|
||||
"Thread",
|
||||
"ThreadMember",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -128,25 +128,25 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'guild',
|
||||
'_type',
|
||||
'_state',
|
||||
'_members',
|
||||
'owner_id',
|
||||
'parent_id',
|
||||
'last_message_id',
|
||||
'message_count',
|
||||
'member_count',
|
||||
'slowmode_delay',
|
||||
'me',
|
||||
'locked',
|
||||
'archived',
|
||||
'invitable',
|
||||
'archiver_id',
|
||||
'auto_archive_duration',
|
||||
'archive_timestamp',
|
||||
"name",
|
||||
"id",
|
||||
"guild",
|
||||
"_type",
|
||||
"_state",
|
||||
"_members",
|
||||
"owner_id",
|
||||
"parent_id",
|
||||
"last_message_id",
|
||||
"message_count",
|
||||
"member_count",
|
||||
"slowmode_delay",
|
||||
"me",
|
||||
"locked",
|
||||
"archived",
|
||||
"invitable",
|
||||
"archiver_id",
|
||||
"auto_archive_duration",
|
||||
"archive_timestamp",
|
||||
)
|
||||
|
||||
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
|
||||
@ -160,50 +160,50 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}'
|
||||
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>'
|
||||
f"<Thread id={self.id!r} name={self.name!r} parent={self.parent}"
|
||||
f" owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def _from_data(self, data: ThreadPayload):
|
||||
self.id = int(data['id'])
|
||||
self.parent_id = int(data['parent_id'])
|
||||
self.owner_id = int(data['owner_id'])
|
||||
self.name = data['name']
|
||||
self._type = try_enum(ChannelType, data['type'])
|
||||
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
|
||||
self.slowmode_delay = data.get('rate_limit_per_user', 0)
|
||||
self.message_count = data['message_count']
|
||||
self.member_count = data['member_count']
|
||||
self._unroll_metadata(data['thread_metadata'])
|
||||
self.id = int(data["id"])
|
||||
self.parent_id = int(data["parent_id"])
|
||||
self.owner_id = int(data["owner_id"])
|
||||
self.name = data["name"]
|
||||
self._type = try_enum(ChannelType, data["type"])
|
||||
self.last_message_id = _get_as_snowflake(data, "last_message_id")
|
||||
self.slowmode_delay = data.get("rate_limit_per_user", 0)
|
||||
self.message_count = data["message_count"]
|
||||
self.member_count = data["member_count"]
|
||||
self._unroll_metadata(data["thread_metadata"])
|
||||
|
||||
try:
|
||||
member = data['member']
|
||||
member = data["member"]
|
||||
except KeyError:
|
||||
self.me = None
|
||||
else:
|
||||
self.me = ThreadMember(self, member)
|
||||
|
||||
def _unroll_metadata(self, data: ThreadMetadata):
|
||||
self.archived = data['archived']
|
||||
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
|
||||
self.auto_archive_duration = data['auto_archive_duration']
|
||||
self.archive_timestamp = parse_time(data['archive_timestamp'])
|
||||
self.locked = data.get('locked', False)
|
||||
self.invitable = data.get('invitable', True)
|
||||
self.archived = data["archived"]
|
||||
self.archiver_id = _get_as_snowflake(data, "archiver_id")
|
||||
self.auto_archive_duration = data["auto_archive_duration"]
|
||||
self.archive_timestamp = parse_time(data["archive_timestamp"])
|
||||
self.locked = data.get("locked", False)
|
||||
self.invitable = data.get("invitable", True)
|
||||
|
||||
def _update(self, data):
|
||||
try:
|
||||
self.name = data['name']
|
||||
self.name = data["name"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
self.slowmode_delay = data.get('rate_limit_per_user', 0)
|
||||
self.slowmode_delay = data.get("rate_limit_per_user", 0)
|
||||
|
||||
try:
|
||||
self._unroll_metadata(data['thread_metadata'])
|
||||
self._unroll_metadata(data["thread_metadata"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@ -225,7 +225,7 @@ class Thread(Messageable, Hashable):
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: The string that allows you to mention the thread."""
|
||||
return f'<#{self.id}>'
|
||||
return f"<#{self.id}>"
|
||||
|
||||
@property
|
||||
def members(self) -> List[ThreadMember]:
|
||||
@ -256,7 +256,11 @@ class Thread(Messageable, Hashable):
|
||||
Optional[:class:`Message`]
|
||||
The last message in this channel or ``None`` if not found.
|
||||
"""
|
||||
return self._state._get_message(self.last_message_id) if self.last_message_id else None
|
||||
return (
|
||||
self._state._get_message(self.last_message_id)
|
||||
if self.last_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def category(self) -> Optional[CategoryChannel]:
|
||||
@ -275,9 +279,9 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
parent = self.parent
|
||||
if parent is None:
|
||||
raise ClientException('Parent channel not found')
|
||||
raise ClientException("Parent channel not found")
|
||||
return parent.category
|
||||
|
||||
|
||||
@property
|
||||
def category_id(self) -> Optional[int]:
|
||||
"""The category channel ID the parent channel belongs to, if applicable.
|
||||
@ -295,7 +299,7 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
parent = self.parent
|
||||
if parent is None:
|
||||
raise ClientException('Parent channel not found')
|
||||
raise ClientException("Parent channel not found")
|
||||
return parent.category_id
|
||||
|
||||
def is_private(self) -> bool:
|
||||
@ -352,7 +356,7 @@ class Thread(Messageable, Hashable):
|
||||
|
||||
parent = self.parent
|
||||
if parent is None:
|
||||
raise ClientException('Parent channel not found')
|
||||
raise ClientException("Parent channel not found")
|
||||
return parent.permissions_for(obj)
|
||||
|
||||
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
|
||||
@ -402,7 +406,7 @@ class Thread(Messageable, Hashable):
|
||||
return
|
||||
|
||||
if len(messages) > 100:
|
||||
raise ClientException('Can only bulk delete messages up to 100 messages')
|
||||
raise ClientException("Can only bulk delete messages up to 100 messages")
|
||||
|
||||
message_ids: SnowflakeList = [m.id for m in messages]
|
||||
await self._state.http.delete_messages(self.id, message_ids)
|
||||
@ -477,11 +481,19 @@ class Thread(Messageable, Hashable):
|
||||
if check is MISSING:
|
||||
check = lambda m: True
|
||||
|
||||
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
|
||||
iterator = self.history(
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
oldest_first=oldest_first,
|
||||
around=around,
|
||||
)
|
||||
ret: List[Message] = []
|
||||
count = 0
|
||||
|
||||
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
|
||||
minimum_time = (
|
||||
int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
|
||||
)
|
||||
|
||||
async def _single_delete_strategy(messages: Iterable[Message]):
|
||||
for m in messages:
|
||||
@ -577,17 +589,17 @@ class Thread(Messageable, Hashable):
|
||||
"""
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
payload['name'] = str(name)
|
||||
payload["name"] = str(name)
|
||||
if archived is not MISSING:
|
||||
payload['archived'] = archived
|
||||
payload["archived"] = archived
|
||||
if auto_archive_duration is not MISSING:
|
||||
payload['auto_archive_duration'] = auto_archive_duration
|
||||
payload["auto_archive_duration"] = auto_archive_duration
|
||||
if locked is not MISSING:
|
||||
payload['locked'] = locked
|
||||
payload["locked"] = locked
|
||||
if invitable is not MISSING:
|
||||
payload['invitable'] = invitable
|
||||
payload["invitable"] = invitable
|
||||
if slowmode_delay is not MISSING:
|
||||
payload['rate_limit_per_user'] = slowmode_delay
|
||||
payload["rate_limit_per_user"] = slowmode_delay
|
||||
|
||||
data = await self._state.http.edit_channel(self.id, **payload)
|
||||
# The data payload will always be a Thread payload
|
||||
@ -773,12 +785,12 @@ class ThreadMember(Hashable):
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'id',
|
||||
'thread_id',
|
||||
'joined_at',
|
||||
'flags',
|
||||
'_state',
|
||||
'parent',
|
||||
"id",
|
||||
"thread_id",
|
||||
"joined_at",
|
||||
"flags",
|
||||
"_state",
|
||||
"parent",
|
||||
)
|
||||
|
||||
def __init__(self, parent: Thread, data: ThreadMemberPayload):
|
||||
@ -787,22 +799,22 @@ class ThreadMember(Hashable):
|
||||
self._from_data(data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
|
||||
return f"<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>"
|
||||
|
||||
def _from_data(self, data: ThreadMemberPayload):
|
||||
try:
|
||||
self.id = int(data['user_id'])
|
||||
self.id = int(data["user_id"])
|
||||
except KeyError:
|
||||
assert self._state.self_id is not None
|
||||
self.id = self._state.self_id
|
||||
|
||||
try:
|
||||
self.thread_id = int(data['id'])
|
||||
self.thread_id = int(data["id"])
|
||||
except KeyError:
|
||||
self.thread_id = self.parent.id
|
||||
|
||||
self.joined_at = parse_time(data['join_timestamp'])
|
||||
self.flags = data['flags']
|
||||
self.joined_at = parse_time(data["join_timestamp"])
|
||||
self.flags = data["flags"]
|
||||
|
||||
@property
|
||||
def thread(self) -> Thread:
|
||||
|
@ -29,7 +29,7 @@ from .user import PartialUser
|
||||
from .snowflake import Snowflake
|
||||
|
||||
|
||||
StatusType = Literal['idle', 'dnd', 'online', 'offline']
|
||||
StatusType = Literal["idle", "dnd", "online", "offline"]
|
||||
|
||||
|
||||
class PartialPresenceUpdate(TypedDict):
|
||||
|
@ -30,6 +30,7 @@ from .user import User
|
||||
from .team import Team
|
||||
from .snowflake import Snowflake
|
||||
|
||||
|
||||
class BaseAppInfo(TypedDict):
|
||||
id: Snowflake
|
||||
name: str
|
||||
@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict):
|
||||
summary: str
|
||||
description: str
|
||||
|
||||
|
||||
class _AppInfoOptional(TypedDict, total=False):
|
||||
team: Team
|
||||
guild_id: Snowflake
|
||||
@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False):
|
||||
hook: bool
|
||||
max_participants: int
|
||||
|
||||
|
||||
class AppInfo(BaseAppInfo, _AppInfoOptional):
|
||||
rpc_origins: List[str]
|
||||
owner: User
|
||||
bot_public: bool
|
||||
bot_require_code_grant: bool
|
||||
|
||||
|
||||
class _PartialAppInfoOptional(TypedDict, total=False):
|
||||
rpc_origins: List[str]
|
||||
cover_image: str
|
||||
@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
|
||||
max_participants: int
|
||||
flags: int
|
||||
|
||||
|
||||
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
|
||||
pass
|
||||
|
@ -26,7 +26,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, Optional, TypedDict, Union
|
||||
from .webhook import Webhook
|
||||
from .guild import MFALevel, VerificationLevel, ExplicitContentFilterLevel, DefaultMessageNotificationLevel
|
||||
from .guild import (
|
||||
MFALevel,
|
||||
VerificationLevel,
|
||||
ExplicitContentFilterLevel,
|
||||
DefaultMessageNotificationLevel,
|
||||
)
|
||||
from .integration import IntegrationExpireBehavior, PartialIntegration
|
||||
from .user import User
|
||||
from .snowflake import Snowflake
|
||||
@ -84,31 +89,47 @@ AuditLogEvent = Literal[
|
||||
|
||||
class _AuditLogChange_Str(TypedDict):
|
||||
key: Literal[
|
||||
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
|
||||
"name",
|
||||
"description",
|
||||
"preferred_locale",
|
||||
"vanity_url_code",
|
||||
"topic",
|
||||
"code",
|
||||
"allow",
|
||||
"deny",
|
||||
"permissions",
|
||||
"tags",
|
||||
]
|
||||
new_value: str
|
||||
old_value: str
|
||||
|
||||
|
||||
class _AuditLogChange_AssetHash(TypedDict):
|
||||
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
|
||||
key: Literal[
|
||||
"icon_hash",
|
||||
"splash_hash",
|
||||
"discovery_splash_hash",
|
||||
"banner_hash",
|
||||
"avatar_hash",
|
||||
"asset",
|
||||
]
|
||||
new_value: str
|
||||
old_value: str
|
||||
|
||||
|
||||
class _AuditLogChange_Snowflake(TypedDict):
|
||||
key: Literal[
|
||||
'id',
|
||||
'owner_id',
|
||||
'afk_channel_id',
|
||||
'rules_channel_id',
|
||||
'public_updates_channel_id',
|
||||
'widget_channel_id',
|
||||
'system_channel_id',
|
||||
'application_id',
|
||||
'channel_id',
|
||||
'inviter_id',
|
||||
'guild_id',
|
||||
"id",
|
||||
"owner_id",
|
||||
"afk_channel_id",
|
||||
"rules_channel_id",
|
||||
"public_updates_channel_id",
|
||||
"widget_channel_id",
|
||||
"system_channel_id",
|
||||
"application_id",
|
||||
"channel_id",
|
||||
"inviter_id",
|
||||
"guild_id",
|
||||
]
|
||||
new_value: Snowflake
|
||||
old_value: Snowflake
|
||||
@ -116,20 +137,20 @@ class _AuditLogChange_Snowflake(TypedDict):
|
||||
|
||||
class _AuditLogChange_Bool(TypedDict):
|
||||
key: Literal[
|
||||
'widget_enabled',
|
||||
'nsfw',
|
||||
'hoist',
|
||||
'mentionable',
|
||||
'temporary',
|
||||
'deaf',
|
||||
'mute',
|
||||
'nick',
|
||||
'enabled_emoticons',
|
||||
'region',
|
||||
'rtc_region',
|
||||
'available',
|
||||
'archived',
|
||||
'locked',
|
||||
"widget_enabled",
|
||||
"nsfw",
|
||||
"hoist",
|
||||
"mentionable",
|
||||
"temporary",
|
||||
"deaf",
|
||||
"mute",
|
||||
"nick",
|
||||
"enabled_emoticons",
|
||||
"region",
|
||||
"rtc_region",
|
||||
"available",
|
||||
"archived",
|
||||
"locked",
|
||||
]
|
||||
new_value: bool
|
||||
old_value: bool
|
||||
@ -137,72 +158,72 @@ class _AuditLogChange_Bool(TypedDict):
|
||||
|
||||
class _AuditLogChange_Int(TypedDict):
|
||||
key: Literal[
|
||||
'afk_timeout',
|
||||
'prune_delete_days',
|
||||
'position',
|
||||
'bitrate',
|
||||
'rate_limit_per_user',
|
||||
'color',
|
||||
'max_uses',
|
||||
'max_age',
|
||||
'user_limit',
|
||||
'auto_archive_duration',
|
||||
'default_auto_archive_duration',
|
||||
"afk_timeout",
|
||||
"prune_delete_days",
|
||||
"position",
|
||||
"bitrate",
|
||||
"rate_limit_per_user",
|
||||
"color",
|
||||
"max_uses",
|
||||
"max_age",
|
||||
"user_limit",
|
||||
"auto_archive_duration",
|
||||
"default_auto_archive_duration",
|
||||
]
|
||||
new_value: int
|
||||
old_value: int
|
||||
|
||||
|
||||
class _AuditLogChange_ListRole(TypedDict):
|
||||
key: Literal['$add', '$remove']
|
||||
key: Literal["$add", "$remove"]
|
||||
new_value: List[Role]
|
||||
old_value: List[Role]
|
||||
|
||||
|
||||
class _AuditLogChange_MFALevel(TypedDict):
|
||||
key: Literal['mfa_level']
|
||||
key: Literal["mfa_level"]
|
||||
new_value: MFALevel
|
||||
old_value: MFALevel
|
||||
|
||||
|
||||
class _AuditLogChange_VerificationLevel(TypedDict):
|
||||
key: Literal['verification_level']
|
||||
key: Literal["verification_level"]
|
||||
new_value: VerificationLevel
|
||||
old_value: VerificationLevel
|
||||
|
||||
|
||||
class _AuditLogChange_ExplicitContentFilter(TypedDict):
|
||||
key: Literal['explicit_content_filter']
|
||||
key: Literal["explicit_content_filter"]
|
||||
new_value: ExplicitContentFilterLevel
|
||||
old_value: ExplicitContentFilterLevel
|
||||
|
||||
|
||||
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
|
||||
key: Literal['default_message_notifications']
|
||||
key: Literal["default_message_notifications"]
|
||||
new_value: DefaultMessageNotificationLevel
|
||||
old_value: DefaultMessageNotificationLevel
|
||||
|
||||
|
||||
class _AuditLogChange_ChannelType(TypedDict):
|
||||
key: Literal['type']
|
||||
key: Literal["type"]
|
||||
new_value: ChannelType
|
||||
old_value: ChannelType
|
||||
|
||||
|
||||
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
|
||||
key: Literal['expire_behavior']
|
||||
key: Literal["expire_behavior"]
|
||||
new_value: IntegrationExpireBehavior
|
||||
old_value: IntegrationExpireBehavior
|
||||
|
||||
|
||||
class _AuditLogChange_VideoQualityMode(TypedDict):
|
||||
key: Literal['video_quality_mode']
|
||||
key: Literal["video_quality_mode"]
|
||||
new_value: VideoQualityMode
|
||||
old_value: VideoQualityMode
|
||||
|
||||
|
||||
class _AuditLogChange_Overwrites(TypedDict):
|
||||
key: Literal['permission_overwrites']
|
||||
key: Literal["permission_overwrites"]
|
||||
new_value: List[PermissionOverwrite]
|
||||
old_value: List[PermissionOverwrite]
|
||||
|
||||
@ -232,7 +253,7 @@ class AuditEntryInfo(TypedDict):
|
||||
message_id: Snowflake
|
||||
count: str
|
||||
id: Snowflake
|
||||
type: Literal['0', '1']
|
||||
type: Literal["0", "1"]
|
||||
role_name: str
|
||||
|
||||
|
||||
|
@ -128,7 +128,15 @@ class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
|
||||
thread_metadata: ThreadMetadata
|
||||
|
||||
|
||||
GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StoreChannel, StageChannel, ThreadChannel]
|
||||
GuildChannel = Union[
|
||||
TextChannel,
|
||||
NewsChannel,
|
||||
VoiceChannel,
|
||||
CategoryChannel,
|
||||
StoreChannel,
|
||||
StageChannel,
|
||||
ThreadChannel,
|
||||
]
|
||||
|
||||
|
||||
class DMChannel(_BaseChannel):
|
||||
|
@ -24,49 +24,60 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from typing import List, Literal, TypedDict
|
||||
|
||||
|
||||
class _EmbedFooterOptional(TypedDict, total=False):
|
||||
icon_url: str
|
||||
proxy_icon_url: str
|
||||
|
||||
|
||||
class EmbedFooter(_EmbedFooterOptional):
|
||||
text: str
|
||||
|
||||
|
||||
class _EmbedFieldOptional(TypedDict, total=False):
|
||||
inline: bool
|
||||
|
||||
|
||||
class EmbedField(_EmbedFieldOptional):
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
class EmbedThumbnail(TypedDict, total=False):
|
||||
url: str
|
||||
proxy_url: str
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class EmbedVideo(TypedDict, total=False):
|
||||
url: str
|
||||
proxy_url: str
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class EmbedImage(TypedDict, total=False):
|
||||
url: str
|
||||
proxy_url: str
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class EmbedProvider(TypedDict, total=False):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
|
||||
class EmbedAuthor(TypedDict, total=False):
|
||||
name: str
|
||||
url: str
|
||||
icon_url: str
|
||||
proxy_icon_url: str
|
||||
|
||||
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
|
||||
|
||||
EmbedType = Literal["rich", "image", "video", "gifv", "article", "link"]
|
||||
|
||||
|
||||
class Embed(TypedDict, total=False):
|
||||
title: str
|
||||
|
@ -75,28 +75,28 @@ VerificationLevel = Literal[0, 1, 2, 3, 4]
|
||||
NSFWLevel = Literal[0, 1, 2, 3]
|
||||
PremiumTier = Literal[0, 1, 2, 3]
|
||||
GuildFeature = Literal[
|
||||
'ANIMATED_ICON',
|
||||
'BANNER',
|
||||
'COMMERCE',
|
||||
'COMMUNITY',
|
||||
'DISCOVERABLE',
|
||||
'FEATURABLE',
|
||||
'INVITE_SPLASH',
|
||||
'MEMBER_VERIFICATION_GATE_ENABLED',
|
||||
'MONETIZATION_ENABLED',
|
||||
'MORE_EMOJI',
|
||||
'MORE_STICKERS',
|
||||
'NEWS',
|
||||
'PARTNERED',
|
||||
'PREVIEW_ENABLED',
|
||||
'PRIVATE_THREADS',
|
||||
'SEVEN_DAY_THREAD_ARCHIVE',
|
||||
'THREE_DAY_THREAD_ARCHIVE',
|
||||
'TICKETED_EVENTS_ENABLED',
|
||||
'VANITY_URL',
|
||||
'VERIFIED',
|
||||
'VIP_REGIONS',
|
||||
'WELCOME_SCREEN_ENABLED',
|
||||
"ANIMATED_ICON",
|
||||
"BANNER",
|
||||
"COMMERCE",
|
||||
"COMMUNITY",
|
||||
"DISCOVERABLE",
|
||||
"FEATURABLE",
|
||||
"INVITE_SPLASH",
|
||||
"MEMBER_VERIFICATION_GATE_ENABLED",
|
||||
"MONETIZATION_ENABLED",
|
||||
"MORE_EMOJI",
|
||||
"MORE_STICKERS",
|
||||
"NEWS",
|
||||
"PARTNERED",
|
||||
"PREVIEW_ENABLED",
|
||||
"PRIVATE_THREADS",
|
||||
"SEVEN_DAY_THREAD_ARCHIVE",
|
||||
"THREE_DAY_THREAD_ARCHIVE",
|
||||
"TICKETED_EVENTS_ENABLED",
|
||||
"VANITY_URL",
|
||||
"VERIFIED",
|
||||
"VIP_REGIONS",
|
||||
"WELCOME_SCREEN_ENABLED",
|
||||
]
|
||||
|
||||
|
||||
|
@ -56,7 +56,7 @@ class PartialIntegration(TypedDict):
|
||||
account: IntegrationAccount
|
||||
|
||||
|
||||
IntegrationType = Literal['twitch', 'youtube', 'discord']
|
||||
IntegrationType = Literal["twitch", "youtube", "discord"]
|
||||
|
||||
|
||||
class BaseIntegration(PartialIntegration):
|
||||
|
@ -39,6 +39,7 @@ if TYPE_CHECKING:
|
||||
|
||||
ApplicationCommandType = Literal[1, 2, 3]
|
||||
|
||||
|
||||
class _ApplicationCommandOptional(TypedDict, total=False):
|
||||
options: List[ApplicationCommandOption]
|
||||
type: ApplicationCommandType
|
||||
@ -100,32 +101,44 @@ class _ApplicationCommandInteractionDataOption(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionSubcommand(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionSubcommand(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[1, 2]
|
||||
options: List[ApplicationCommandInteractionDataOption]
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionString(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionString(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[3]
|
||||
value: str
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionInteger(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionInteger(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[4]
|
||||
value: int
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionBoolean(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionBoolean(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[5]
|
||||
value: bool
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionSnowflake(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[6, 7, 8, 9]
|
||||
value: Snowflake
|
||||
|
||||
|
||||
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
|
||||
class _ApplicationCommandInteractionDataOptionNumber(
|
||||
_ApplicationCommandInteractionDataOption
|
||||
):
|
||||
type: Literal[10]
|
||||
value: float
|
||||
|
||||
@ -222,9 +235,6 @@ class MessageInteraction(TypedDict):
|
||||
user: User
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class _EditApplicationCommandOptional(TypedDict, total=False):
|
||||
description: str
|
||||
options: Optional[List[ApplicationCommandOption]]
|
||||
|
@ -128,7 +128,7 @@ class Message(_MessageOptional):
|
||||
type: MessageType
|
||||
|
||||
|
||||
AllowedMentionType = Literal['roles', 'users', 'everyone']
|
||||
AllowedMentionType = Literal["roles", "users", "everyone"]
|
||||
|
||||
|
||||
class AllowedMentions(TypedDict):
|
||||
|
@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional
|
||||
from .user import PartialUser
|
||||
from .snowflake import Snowflake
|
||||
|
||||
|
||||
class TeamMember(TypedDict):
|
||||
user: PartialUser
|
||||
membership_state: int
|
||||
permissions: List[str]
|
||||
team_id: Snowflake
|
||||
|
||||
|
||||
class Team(TypedDict):
|
||||
id: Snowflake
|
||||
name: str
|
||||
|
@ -27,7 +27,9 @@ from .snowflake import Snowflake
|
||||
from .member import MemberWithUser
|
||||
|
||||
|
||||
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
|
||||
SupportedModes = Literal[
|
||||
"xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"
|
||||
]
|
||||
|
||||
|
||||
class _PartialVoiceStateOptional(TypedDict, total=False):
|
||||
|
@ -35,16 +35,16 @@ from ..partial_emoji import PartialEmoji, _EmojiTag
|
||||
from ..components import Button as ButtonComponent
|
||||
|
||||
__all__ = (
|
||||
'Button',
|
||||
'button',
|
||||
"Button",
|
||||
"button",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .view import View
|
||||
from ..emoji import Emoji
|
||||
|
||||
B = TypeVar('B', bound='Button')
|
||||
V = TypeVar('V', bound='View', covariant=True)
|
||||
B = TypeVar("B", bound="Button")
|
||||
V = TypeVar("V", bound="View", covariant=True)
|
||||
|
||||
|
||||
class Button(Item[V]):
|
||||
@ -76,12 +76,12 @@ class Button(Item[V]):
|
||||
"""
|
||||
|
||||
__item_repr_attributes__: Tuple[str, ...] = (
|
||||
'style',
|
||||
'url',
|
||||
'disabled',
|
||||
'label',
|
||||
'emoji',
|
||||
'row',
|
||||
"style",
|
||||
"url",
|
||||
"disabled",
|
||||
"label",
|
||||
"emoji",
|
||||
"row",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -97,7 +97,7 @@ class Button(Item[V]):
|
||||
):
|
||||
super().__init__()
|
||||
if custom_id is not None and url is not None:
|
||||
raise TypeError('cannot mix both url and custom_id with Button')
|
||||
raise TypeError("cannot mix both url and custom_id with Button")
|
||||
|
||||
self._provided_custom_id = custom_id is not None
|
||||
if url is None and custom_id is None:
|
||||
@ -112,7 +112,9 @@ class Button(Item[V]):
|
||||
elif isinstance(emoji, _EmojiTag):
|
||||
emoji = emoji._to_partial()
|
||||
else:
|
||||
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
|
||||
raise TypeError(
|
||||
f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}"
|
||||
)
|
||||
|
||||
self._underlying = ButtonComponent._raw_construct(
|
||||
type=ComponentType.button,
|
||||
@ -145,7 +147,7 @@ class Button(Item[V]):
|
||||
@custom_id.setter
|
||||
def custom_id(self, value: Optional[str]):
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('custom_id must be None or str')
|
||||
raise TypeError("custom_id must be None or str")
|
||||
|
||||
self._underlying.custom_id = value
|
||||
|
||||
@ -157,7 +159,7 @@ class Button(Item[V]):
|
||||
@url.setter
|
||||
def url(self, value: Optional[str]):
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('url must be None or str')
|
||||
raise TypeError("url must be None or str")
|
||||
self._underlying.url = value
|
||||
|
||||
@property
|
||||
@ -191,7 +193,9 @@ class Button(Item[V]):
|
||||
elif isinstance(value, _EmojiTag):
|
||||
self._underlying.emoji = value._to_partial()
|
||||
else:
|
||||
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
|
||||
raise TypeError(
|
||||
f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead"
|
||||
)
|
||||
else:
|
||||
self._underlying.emoji = None
|
||||
|
||||
@ -273,17 +277,17 @@ def button(
|
||||
|
||||
def decorator(func: ItemCallbackType) -> ItemCallbackType:
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise TypeError('button function must be a coroutine function')
|
||||
raise TypeError("button function must be a coroutine function")
|
||||
|
||||
func.__discord_ui_model_type__ = Button
|
||||
func.__discord_ui_model_kwargs__ = {
|
||||
'style': style,
|
||||
'custom_id': custom_id,
|
||||
'url': None,
|
||||
'disabled': disabled,
|
||||
'label': label,
|
||||
'emoji': emoji,
|
||||
'row': row,
|
||||
"style": style,
|
||||
"custom_id": custom_id,
|
||||
"url": None,
|
||||
"disabled": disabled,
|
||||
"label": label,
|
||||
"emoji": emoji,
|
||||
"row": row,
|
||||
}
|
||||
return func
|
||||
|
||||
|
@ -24,21 +24,30 @@ DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from ..interactions import Interaction
|
||||
|
||||
__all__ = (
|
||||
'Item',
|
||||
)
|
||||
__all__ = ("Item",)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..enums import ComponentType
|
||||
from .view import View
|
||||
from ..components import Component
|
||||
|
||||
I = TypeVar('I', bound='Item')
|
||||
V = TypeVar('V', bound='View', covariant=True)
|
||||
I = TypeVar("I", bound="Item")
|
||||
V = TypeVar("V", bound="View", covariant=True)
|
||||
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
@ -53,7 +62,7 @@ class Item(Generic[V]):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
__item_repr_attributes__: Tuple[str, ...] = ('row',)
|
||||
__item_repr_attributes__: Tuple[str, ...] = ("row",)
|
||||
|
||||
def __init__(self):
|
||||
self._view: Optional[V] = None
|
||||
@ -91,8 +100,10 @@ class Item(Generic[V]):
|
||||
return self._provided_custom_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
|
||||
return f'<{self.__class__.__name__} {attrs}>'
|
||||
attrs = " ".join(
|
||||
f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__
|
||||
)
|
||||
return f"<{self.__class__.__name__} {attrs}>"
|
||||
|
||||
@property
|
||||
def row(self) -> Optional[int]:
|
||||
@ -105,7 +116,7 @@ class Item(Generic[V]):
|
||||
elif 5 > value >= 0:
|
||||
self._row = value
|
||||
else:
|
||||
raise ValueError('row cannot be negative or greater than or equal to 5')
|
||||
raise ValueError("row cannot be negative or greater than or equal to 5")
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
|
@ -39,8 +39,8 @@ from ..components import (
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'Select',
|
||||
'select',
|
||||
"Select",
|
||||
"select",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -50,8 +50,8 @@ if TYPE_CHECKING:
|
||||
ComponentInteractionData,
|
||||
)
|
||||
|
||||
S = TypeVar('S', bound='Select')
|
||||
V = TypeVar('V', bound='View', covariant=True)
|
||||
S = TypeVar("S", bound="Select")
|
||||
V = TypeVar("V", bound="View", covariant=True)
|
||||
|
||||
|
||||
class Select(Item[V]):
|
||||
@ -89,11 +89,11 @@ class Select(Item[V]):
|
||||
"""
|
||||
|
||||
__item_repr_attributes__: Tuple[str, ...] = (
|
||||
'placeholder',
|
||||
'min_values',
|
||||
'max_values',
|
||||
'options',
|
||||
'disabled',
|
||||
"placeholder",
|
||||
"min_values",
|
||||
"max_values",
|
||||
"options",
|
||||
"disabled",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -131,7 +131,7 @@ class Select(Item[V]):
|
||||
@custom_id.setter
|
||||
def custom_id(self, value: str):
|
||||
if not isinstance(value, str):
|
||||
raise TypeError('custom_id must be None or str')
|
||||
raise TypeError("custom_id must be None or str")
|
||||
|
||||
self._underlying.custom_id = value
|
||||
|
||||
@ -143,7 +143,7 @@ class Select(Item[V]):
|
||||
@placeholder.setter
|
||||
def placeholder(self, value: Optional[str]):
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise TypeError('placeholder must be None or str')
|
||||
raise TypeError("placeholder must be None or str")
|
||||
|
||||
self._underlying.placeholder = value
|
||||
|
||||
@ -173,9 +173,9 @@ class Select(Item[V]):
|
||||
@options.setter
|
||||
def options(self, value: List[SelectOption]):
|
||||
if not isinstance(value, list):
|
||||
raise TypeError('options must be a list of SelectOption')
|
||||
raise TypeError("options must be a list of SelectOption")
|
||||
if not all(isinstance(obj, SelectOption) for obj in value):
|
||||
raise TypeError('all list items must subclass SelectOption')
|
||||
raise TypeError("all list items must subclass SelectOption")
|
||||
|
||||
self._underlying.options = value
|
||||
|
||||
@ -224,7 +224,6 @@ class Select(Item[V]):
|
||||
default=default,
|
||||
)
|
||||
|
||||
|
||||
self.append_option(option)
|
||||
|
||||
def append_option(self, option: SelectOption):
|
||||
@ -242,7 +241,7 @@ class Select(Item[V]):
|
||||
"""
|
||||
|
||||
if len(self._underlying.options) > 25:
|
||||
raise ValueError('maximum number of options already provided')
|
||||
raise ValueError("maximum number of options already provided")
|
||||
|
||||
self._underlying.options.append(option)
|
||||
|
||||
@ -272,7 +271,7 @@ class Select(Item[V]):
|
||||
|
||||
def refresh_state(self, interaction: Interaction) -> None:
|
||||
data: ComponentInteractionData = interaction.data # type: ignore
|
||||
self._selected_values = data.get('values', [])
|
||||
self._selected_values = data.get("values", [])
|
||||
|
||||
@classmethod
|
||||
def from_component(cls: Type[S], component: SelectMenu) -> S:
|
||||
@ -340,17 +339,17 @@ def select(
|
||||
|
||||
def decorator(func: ItemCallbackType) -> ItemCallbackType:
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
raise TypeError('select function must be a coroutine function')
|
||||
raise TypeError("select function must be a coroutine function")
|
||||
|
||||
func.__discord_ui_model_type__ = Select
|
||||
func.__discord_ui_model_kwargs__ = {
|
||||
'placeholder': placeholder,
|
||||
'custom_id': custom_id,
|
||||
'row': row,
|
||||
'min_values': min_values,
|
||||
'max_values': max_values,
|
||||
'options': options,
|
||||
'disabled': disabled,
|
||||
"placeholder": placeholder,
|
||||
"custom_id": custom_id,
|
||||
"row": row,
|
||||
"min_values": min_values,
|
||||
"max_values": max_values,
|
||||
"options": options,
|
||||
"disabled": disabled,
|
||||
}
|
||||
return func
|
||||
|
||||
|
@ -23,7 +23,18 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Tuple
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
)
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
|
||||
@ -41,9 +52,7 @@ from ..components import (
|
||||
SelectMenu as SelectComponent,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'View',
|
||||
)
|
||||
__all__ = ("View",)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -74,9 +83,7 @@ def _component_to_item(component: Component) -> Item:
|
||||
|
||||
|
||||
class _ViewWeights:
|
||||
__slots__ = (
|
||||
'weights',
|
||||
)
|
||||
__slots__ = ("weights",)
|
||||
|
||||
def __init__(self, children: List[Item]):
|
||||
self.weights: List[int] = [0, 0, 0, 0, 0]
|
||||
@ -92,13 +99,15 @@ class _ViewWeights:
|
||||
if weight + item.width <= 5:
|
||||
return index
|
||||
|
||||
raise ValueError('could not find open space for item')
|
||||
raise ValueError("could not find open space for item")
|
||||
|
||||
def add_item(self, item: Item) -> None:
|
||||
if item.row is not None:
|
||||
total = self.weights[item.row] + item.width
|
||||
if total > 5:
|
||||
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
|
||||
raise ValueError(
|
||||
f"item would not fit at row {item.row} ({total} > 5 width)"
|
||||
)
|
||||
self.weights[item.row] = total
|
||||
item._rendered_row = item.row
|
||||
else:
|
||||
@ -144,11 +153,11 @@ class View:
|
||||
children: List[ItemCallbackType] = []
|
||||
for base in reversed(cls.__mro__):
|
||||
for member in base.__dict__.values():
|
||||
if hasattr(member, '__discord_ui_model_type__'):
|
||||
if hasattr(member, "__discord_ui_model_type__"):
|
||||
children.append(member)
|
||||
|
||||
if len(children) > 25:
|
||||
raise TypeError('View cannot have more than 25 children')
|
||||
raise TypeError("View cannot have more than 25 children")
|
||||
|
||||
cls.__view_children_items__ = children
|
||||
|
||||
@ -156,7 +165,9 @@ class View:
|
||||
self.timeout = timeout
|
||||
self.children: List[Item] = []
|
||||
for func in self.__view_children_items__:
|
||||
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
|
||||
item: Item = func.__discord_ui_model_type__(
|
||||
**func.__discord_ui_model_kwargs__
|
||||
)
|
||||
item.callback = partial(func, self, item)
|
||||
item._view = self
|
||||
setattr(self, func.__name__, item)
|
||||
@ -171,7 +182,7 @@ class View:
|
||||
self.__stopped: asyncio.Future[bool] = loop.create_future()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
|
||||
return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>"
|
||||
|
||||
async def __timeout_task_impl(self) -> None:
|
||||
while True:
|
||||
@ -203,15 +214,17 @@ class View:
|
||||
|
||||
components.append(
|
||||
{
|
||||
'type': 1,
|
||||
'components': children,
|
||||
"type": 1,
|
||||
"components": children,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View:
|
||||
def from_message(
|
||||
cls, message: Message, /, *, timeout: Optional[float] = 180.0
|
||||
) -> View:
|
||||
"""Converts a message's components into a :class:`View`.
|
||||
|
||||
The :attr:`.Message.components` of a message are read-only
|
||||
@ -261,10 +274,10 @@ class View:
|
||||
"""
|
||||
|
||||
if len(self.children) > 25:
|
||||
raise ValueError('maximum number of children exceeded')
|
||||
raise ValueError("maximum number of children exceeded")
|
||||
|
||||
if not isinstance(item, Item):
|
||||
raise TypeError(f'expected Item not {item.__class__!r}')
|
||||
raise TypeError(f"expected Item not {item.__class__!r}")
|
||||
|
||||
self.__weights.add_item(item)
|
||||
|
||||
@ -327,7 +340,9 @@ class View:
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
|
||||
async def on_error(
|
||||
self, error: Exception, item: Item, interaction: Interaction
|
||||
) -> None:
|
||||
"""|coro|
|
||||
|
||||
A callback that is called when an item's callback or :meth:`interaction_check`
|
||||
@ -344,8 +359,10 @@ class View:
|
||||
interaction: :class:`~discord.Interaction`
|
||||
The interaction that led to the failure.
|
||||
"""
|
||||
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
|
||||
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
|
||||
print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
|
||||
traceback.print_exception(
|
||||
error.__class__, error, error.__traceback__, file=sys.stderr
|
||||
)
|
||||
|
||||
async def _scheduled_task(self, item: Item, interaction: Interaction):
|
||||
try:
|
||||
@ -377,13 +394,18 @@ class View:
|
||||
return
|
||||
|
||||
self.__stopped.set_result(True)
|
||||
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
|
||||
asyncio.create_task(
|
||||
self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}"
|
||||
)
|
||||
|
||||
def _dispatch_item(self, item: Item, interaction: Interaction):
|
||||
if self.__stopped.done():
|
||||
return
|
||||
|
||||
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
|
||||
asyncio.create_task(
|
||||
self._scheduled_task(item, interaction),
|
||||
name=f"discord-ui-view-dispatch-{self.id}",
|
||||
)
|
||||
|
||||
def refresh(self, components: List[Component]):
|
||||
# This is pretty hacky at the moment
|
||||
@ -437,7 +459,9 @@ class View:
|
||||
A persistent view has all their components with a set ``custom_id`` and
|
||||
a :attr:`timeout` set to ``None``.
|
||||
"""
|
||||
return self.timeout is None and all(item.is_persistent() for item in self.children)
|
||||
return self.timeout is None and all(
|
||||
item.is_persistent() for item in self.children
|
||||
)
|
||||
|
||||
async def wait(self) -> bool:
|
||||
"""Waits until the view has finished interacting.
|
||||
@ -509,7 +533,9 @@ class ViewStore:
|
||||
key = (component_type, message_id, custom_id)
|
||||
# Fallback to None message_id searches in case a persistent view
|
||||
# was added without an associated message_id
|
||||
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
|
||||
value = self._views.get(key) or self._views.get(
|
||||
(component_type, None, custom_id)
|
||||
)
|
||||
if value is None:
|
||||
return
|
||||
|
||||
|
@ -45,11 +45,11 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'User',
|
||||
'ClientUser',
|
||||
"User",
|
||||
"ClientUser",
|
||||
)
|
||||
|
||||
BU = TypeVar('BU', bound='BaseUser')
|
||||
BU = TypeVar("BU", bound="BaseUser")
|
||||
|
||||
|
||||
class _UserTag:
|
||||
@ -59,16 +59,16 @@ class _UserTag:
|
||||
|
||||
class BaseUser(_UserTag):
|
||||
__slots__ = (
|
||||
'name',
|
||||
'id',
|
||||
'discriminator',
|
||||
'_avatar',
|
||||
'_banner',
|
||||
'_accent_colour',
|
||||
'bot',
|
||||
'system',
|
||||
'_public_flags',
|
||||
'_state',
|
||||
"name",
|
||||
"id",
|
||||
"discriminator",
|
||||
"_avatar",
|
||||
"_banner",
|
||||
"_accent_colour",
|
||||
"bot",
|
||||
"system",
|
||||
"_public_flags",
|
||||
"_state",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -94,7 +94,7 @@ class BaseUser(_UserTag):
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name}#{self.discriminator}'
|
||||
return f"{self.name}#{self.discriminator}"
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.id
|
||||
@ -109,15 +109,15 @@ class BaseUser(_UserTag):
|
||||
return self.id >> 22
|
||||
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
self.name = data['username']
|
||||
self.id = int(data['id'])
|
||||
self.discriminator = data['discriminator']
|
||||
self._avatar = data['avatar']
|
||||
self._banner = data.get('banner', None)
|
||||
self._accent_colour = data.get('accent_color', None)
|
||||
self._public_flags = data.get('public_flags', 0)
|
||||
self.bot = data.get('bot', False)
|
||||
self.system = data.get('system', False)
|
||||
self.name = data["username"]
|
||||
self.id = int(data["id"])
|
||||
self.discriminator = data["discriminator"]
|
||||
self._avatar = data["avatar"]
|
||||
self._banner = data.get("banner", None)
|
||||
self._accent_colour = data.get("accent_color", None)
|
||||
self._public_flags = data.get("public_flags", 0)
|
||||
self.bot = data.get("bot", False)
|
||||
self.system = data.get("system", False)
|
||||
|
||||
@classmethod
|
||||
def _copy(cls: Type[BU], user: BU) -> BU:
|
||||
@ -137,11 +137,11 @@ class BaseUser(_UserTag):
|
||||
|
||||
def _to_minimal_user_json(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'username': self.name,
|
||||
'id': self.id,
|
||||
'avatar': self._avatar,
|
||||
'discriminator': self.discriminator,
|
||||
'bot': self.bot,
|
||||
"username": self.name,
|
||||
"id": self.id,
|
||||
"avatar": self._avatar,
|
||||
"discriminator": self.discriminator,
|
||||
"bot": self.bot,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -163,7 +163,9 @@ class BaseUser(_UserTag):
|
||||
@property
|
||||
def default_avatar(self) -> Asset:
|
||||
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
|
||||
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
|
||||
return Asset._from_default_avatar(
|
||||
self._state, int(self.discriminator) % len(DefaultAvatar)
|
||||
)
|
||||
|
||||
@property
|
||||
def display_avatar(self) -> Asset:
|
||||
@ -240,7 +242,7 @@ class BaseUser(_UserTag):
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: Returns a string that allows you to mention the given user."""
|
||||
return f'<@{self.id}>'
|
||||
return f"<@{self.id}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime:
|
||||
@ -324,7 +326,7 @@ class ClientUser(BaseUser):
|
||||
Specifies if the user has MFA turned on and working.
|
||||
"""
|
||||
|
||||
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
|
||||
__slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
verified: bool
|
||||
@ -337,19 +339,21 @@ class ClientUser(BaseUser):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
|
||||
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
|
||||
f"<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
|
||||
f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>"
|
||||
)
|
||||
|
||||
def _update(self, data: UserPayload) -> None:
|
||||
super()._update(data)
|
||||
# There's actually an Optional[str] phone field as well but I won't use it
|
||||
self.verified = data.get('verified', False)
|
||||
self.locale = data.get('locale')
|
||||
self._flags = data.get('flags', 0)
|
||||
self.mfa_enabled = data.get('mfa_enabled', False)
|
||||
self.verified = data.get("verified", False)
|
||||
self.locale = data.get("locale")
|
||||
self._flags = data.get("flags", 0)
|
||||
self.mfa_enabled = data.get("mfa_enabled", False)
|
||||
|
||||
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
|
||||
async def edit(
|
||||
self, *, username: str = MISSING, avatar: bytes = MISSING
|
||||
) -> ClientUser:
|
||||
"""|coro|
|
||||
|
||||
Edits the current profile of the client.
|
||||
@ -388,10 +392,10 @@ class ClientUser(BaseUser):
|
||||
"""
|
||||
payload: Dict[str, Any] = {}
|
||||
if username is not MISSING:
|
||||
payload['username'] = username
|
||||
payload["username"] = username
|
||||
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = _bytes_to_base64_data(avatar)
|
||||
payload["avatar"] = _bytes_to_base64_data(avatar)
|
||||
|
||||
data: UserPayload = await self._state.http.edit_profile(payload)
|
||||
return ClientUser(state=self._state, data=data)
|
||||
@ -436,14 +440,14 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
Specifies if the user is a system user (i.e. represents Discord officially).
|
||||
"""
|
||||
|
||||
__slots__ = ('_stored',)
|
||||
__slots__ = ("_stored",)
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
self._stored: bool = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
|
||||
return f"<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>"
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
@ -481,7 +485,9 @@ class User(BaseUser, discord.abc.Messageable):
|
||||
|
||||
.. versionadded:: 1.7
|
||||
"""
|
||||
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
|
||||
return [
|
||||
guild for guild in self._state._guilds.values() if guild.get_member(self.id)
|
||||
]
|
||||
|
||||
async def create_dm(self) -> DMChannel:
|
||||
"""|coro|
|
||||
|
185
discord/utils.py
185
discord/utils.py
@ -72,18 +72,18 @@ else:
|
||||
|
||||
|
||||
__all__ = (
|
||||
'oauth_url',
|
||||
'snowflake_time',
|
||||
'time_snowflake',
|
||||
'find',
|
||||
'get',
|
||||
'sleep_until',
|
||||
'utcnow',
|
||||
'remove_markdown',
|
||||
'escape_markdown',
|
||||
'escape_mentions',
|
||||
'as_chunks',
|
||||
'format_dt',
|
||||
"oauth_url",
|
||||
"snowflake_time",
|
||||
"time_snowflake",
|
||||
"find",
|
||||
"get",
|
||||
"sleep_until",
|
||||
"utcnow",
|
||||
"remove_markdown",
|
||||
"escape_markdown",
|
||||
"escape_mentions",
|
||||
"as_chunks",
|
||||
"format_dt",
|
||||
)
|
||||
|
||||
DISCORD_EPOCH = 1420070400000
|
||||
@ -97,7 +97,7 @@ class _MissingSentinel:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return '...'
|
||||
return "..."
|
||||
|
||||
|
||||
MISSING: Any = _MissingSentinel()
|
||||
@ -106,7 +106,7 @@ MISSING: Any = _MissingSentinel()
|
||||
class _cached_property:
|
||||
def __init__(self, function):
|
||||
self.function = function
|
||||
self.__doc__ = getattr(function, '__doc__')
|
||||
self.__doc__ = getattr(function, "__doc__")
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
@ -131,15 +131,14 @@ if TYPE_CHECKING:
|
||||
class _RequestLike(Protocol):
|
||||
headers: Mapping[str, Any]
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
P = ParamSpec("P")
|
||||
|
||||
else:
|
||||
cached_property = _cached_property
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
_Iter = Union[Iterator[T], AsyncIterator[T]]
|
||||
|
||||
|
||||
@ -147,7 +146,7 @@ class CachedSlotProperty(Generic[T, T_co]):
|
||||
def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
|
||||
self.name = name
|
||||
self.function = function
|
||||
self.__doc__ = getattr(function, '__doc__')
|
||||
self.__doc__ = getattr(function, "__doc__")
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
|
||||
@ -177,10 +176,12 @@ class classproperty(Generic[T_co]):
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance, value) -> None:
|
||||
raise AttributeError('cannot set attribute')
|
||||
raise AttributeError("cannot set attribute")
|
||||
|
||||
|
||||
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
|
||||
def cached_slot_property(
|
||||
name: str,
|
||||
) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
|
||||
def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
|
||||
return CachedSlotProperty(name, func)
|
||||
|
||||
@ -245,18 +246,22 @@ def copy_doc(original: Callable) -> Callable[[T], T]:
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
def deprecated(
|
||||
instead: Optional[str] = None,
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
warnings.simplefilter('always', DeprecationWarning) # turn off filter
|
||||
warnings.simplefilter("always", DeprecationWarning) # turn off filter
|
||||
if instead:
|
||||
fmt = "{0.__name__} is deprecated, use {1} instead."
|
||||
else:
|
||||
fmt = '{0.__name__} is deprecated.'
|
||||
fmt = "{0.__name__} is deprecated."
|
||||
|
||||
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
|
||||
warnings.simplefilter('default', DeprecationWarning) # reset filter
|
||||
warnings.warn(
|
||||
fmt.format(func, instead), stacklevel=3, category=DeprecationWarning
|
||||
)
|
||||
warnings.simplefilter("default", DeprecationWarning) # reset filter
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
@ -301,18 +306,18 @@ def oauth_url(
|
||||
:class:`str`
|
||||
The OAuth2 URL for inviting the bot into guilds.
|
||||
"""
|
||||
url = f'https://discord.com/oauth2/authorize?client_id={client_id}'
|
||||
url += '&scope=' + '+'.join(scopes or ('bot',))
|
||||
url = f"https://discord.com/oauth2/authorize?client_id={client_id}"
|
||||
url += "&scope=" + "+".join(scopes or ("bot",))
|
||||
if permissions is not MISSING:
|
||||
url += f'&permissions={permissions.value}'
|
||||
url += f"&permissions={permissions.value}"
|
||||
if guild is not MISSING:
|
||||
url += f'&guild_id={guild.id}'
|
||||
url += f"&guild_id={guild.id}"
|
||||
if redirect_uri is not MISSING:
|
||||
from urllib.parse import urlencode
|
||||
|
||||
url += '&response_type=code&' + urlencode({'redirect_uri': redirect_uri})
|
||||
url += "&response_type=code&" + urlencode({"redirect_uri": redirect_uri})
|
||||
if disable_guild_select:
|
||||
url += '&disable_guild_select=true'
|
||||
url += "&disable_guild_select=true"
|
||||
return url
|
||||
|
||||
|
||||
@ -435,13 +440,15 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
|
||||
# Special case the single element call
|
||||
if len(attrs) == 1:
|
||||
k, v = attrs.popitem()
|
||||
pred = attrget(k.replace('__', '.'))
|
||||
pred = attrget(k.replace("__", "."))
|
||||
for elem in iterable:
|
||||
if pred(elem) == v:
|
||||
return elem
|
||||
return None
|
||||
|
||||
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
|
||||
converted = [
|
||||
(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()
|
||||
]
|
||||
|
||||
for elem in iterable:
|
||||
if _all(pred(elem) == value for pred, value in converted):
|
||||
@ -463,46 +470,48 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
|
||||
|
||||
|
||||
def _get_mime_type_for_image(data: bytes):
|
||||
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
|
||||
return 'image/png'
|
||||
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
|
||||
return 'image/jpeg'
|
||||
elif data.startswith((b'\x47\x49\x46\x38\x37\x61', b'\x47\x49\x46\x38\x39\x61')):
|
||||
return 'image/gif'
|
||||
elif data.startswith(b'RIFF') and data[8:12] == b'WEBP':
|
||||
return 'image/webp'
|
||||
if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"):
|
||||
return "image/png"
|
||||
elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
|
||||
return "image/jpeg"
|
||||
elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")):
|
||||
return "image/gif"
|
||||
elif data.startswith(b"RIFF") and data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
else:
|
||||
raise InvalidArgument('Unsupported image type given')
|
||||
raise InvalidArgument("Unsupported image type given")
|
||||
|
||||
|
||||
def _bytes_to_base64_data(data: bytes) -> str:
|
||||
fmt = 'data:{mime};base64,{data}'
|
||||
fmt = "data:{mime};base64,{data}"
|
||||
mime = _get_mime_type_for_image(data)
|
||||
b64 = b64encode(data).decode('ascii')
|
||||
b64 = b64encode(data).decode("ascii")
|
||||
return fmt.format(mime=mime, data=b64)
|
||||
|
||||
|
||||
if HAS_ORJSON:
|
||||
|
||||
def _to_json(obj: Any) -> str: # type: ignore
|
||||
return orjson.dumps(obj).decode('utf-8')
|
||||
return orjson.dumps(obj).decode("utf-8")
|
||||
|
||||
_from_json = orjson.loads # type: ignore
|
||||
|
||||
else:
|
||||
|
||||
def _to_json(obj: Any) -> str:
|
||||
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
|
||||
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
|
||||
|
||||
_from_json = json.loads
|
||||
|
||||
|
||||
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
|
||||
reset_after: Optional[str] = request.headers.get("X-Ratelimit-Reset-After")
|
||||
if use_clock or not reset_after:
|
||||
utc = datetime.timezone.utc
|
||||
now = datetime.datetime.now(utc)
|
||||
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
|
||||
reset = datetime.datetime.fromtimestamp(
|
||||
float(request.headers["X-Ratelimit-Reset"]), utc
|
||||
)
|
||||
return (reset - now).total_seconds()
|
||||
else:
|
||||
return float(reset_after)
|
||||
@ -527,7 +536,9 @@ async def async_all(gen, *, check=_isawaitable):
|
||||
|
||||
async def sane_wait_for(futures, *, timeout):
|
||||
ensured = [asyncio.ensure_future(fut) for fut in futures]
|
||||
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
|
||||
done, pending = await asyncio.wait(
|
||||
ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED
|
||||
)
|
||||
|
||||
if len(pending) != 0:
|
||||
raise asyncio.TimeoutError()
|
||||
@ -550,7 +561,9 @@ def compute_timedelta(dt: datetime.datetime):
|
||||
return max((dt - now).total_seconds(), 0)
|
||||
|
||||
|
||||
async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]:
|
||||
async def sleep_until(
|
||||
when: datetime.datetime, result: Optional[T] = None
|
||||
) -> Optional[T]:
|
||||
"""|coro|
|
||||
|
||||
Sleep until a specified time.
|
||||
@ -612,7 +625,7 @@ class SnowflakeList(array.array):
|
||||
...
|
||||
|
||||
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
|
||||
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
|
||||
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
|
||||
|
||||
def add(self, element: int) -> None:
|
||||
i = bisect_left(self, element)
|
||||
@ -627,7 +640,7 @@ class SnowflakeList(array.array):
|
||||
return i != len(self) and self[i] == element
|
||||
|
||||
|
||||
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
|
||||
_IS_ASCII = re.compile(r"^[\x00-\x7f]+$")
|
||||
|
||||
|
||||
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
|
||||
@ -636,7 +649,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
|
||||
if match:
|
||||
return match.endpos
|
||||
|
||||
UNICODE_WIDE_CHAR_TYPE = 'WFA'
|
||||
UNICODE_WIDE_CHAR_TYPE = "WFA"
|
||||
func = unicodedata.east_asian_width
|
||||
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
|
||||
|
||||
@ -660,7 +673,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
|
||||
if isinstance(invite, Invite):
|
||||
return invite.code
|
||||
else:
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
|
||||
rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)"
|
||||
m = re.match(rx, invite)
|
||||
if m:
|
||||
return m.group(1)
|
||||
@ -688,22 +701,27 @@ def resolve_template(code: Union[Template, str]) -> str:
|
||||
if isinstance(code, Template):
|
||||
return code.code
|
||||
else:
|
||||
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
|
||||
rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)"
|
||||
m = re.match(rx, code)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return code
|
||||
|
||||
|
||||
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|'))
|
||||
_MARKDOWN_ESCAPE_SUBREGEX = "|".join(
|
||||
r"\{0}(?=([\s\S]*((?<!\{0})\{0})))".format(c) for c in ("*", "`", "_", "~", "|")
|
||||
)
|
||||
|
||||
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
|
||||
_MARKDOWN_ESCAPE_COMMON = r"^>(?:>>)?\s|\[.+\]\(.+\)"
|
||||
|
||||
_MARKDOWN_ESCAPE_REGEX = re.compile(fr'(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})', re.MULTILINE)
|
||||
_MARKDOWN_ESCAPE_REGEX = re.compile(
|
||||
fr"(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])'
|
||||
_URL_REGEX = r"(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])"
|
||||
|
||||
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
|
||||
_MARKDOWN_STOCK_REGEX = fr"(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})"
|
||||
|
||||
|
||||
def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
|
||||
@ -732,15 +750,17 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
|
||||
|
||||
def replacement(match):
|
||||
groupdict = match.groupdict()
|
||||
return groupdict.get('url', '')
|
||||
return groupdict.get("url", "")
|
||||
|
||||
regex = _MARKDOWN_STOCK_REGEX
|
||||
if ignore_links:
|
||||
regex = f'(?:{_URL_REGEX}|{regex})'
|
||||
regex = f"(?:{_URL_REGEX}|{regex})"
|
||||
return re.sub(regex, replacement, text, 0, re.MULTILINE)
|
||||
|
||||
|
||||
def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str:
|
||||
def escape_markdown(
|
||||
text: str, *, as_needed: bool = False, ignore_links: bool = True
|
||||
) -> str:
|
||||
r"""A helper function that escapes Discord's markdown.
|
||||
|
||||
Parameters
|
||||
@ -769,18 +789,18 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
|
||||
|
||||
def replacement(match):
|
||||
groupdict = match.groupdict()
|
||||
is_url = groupdict.get('url')
|
||||
is_url = groupdict.get("url")
|
||||
if is_url:
|
||||
return is_url
|
||||
return '\\' + groupdict['markdown']
|
||||
return "\\" + groupdict["markdown"]
|
||||
|
||||
regex = _MARKDOWN_STOCK_REGEX
|
||||
if ignore_links:
|
||||
regex = f'(?:{_URL_REGEX}|{regex})'
|
||||
regex = f"(?:{_URL_REGEX}|{regex})"
|
||||
return re.sub(regex, replacement, text, 0, re.MULTILINE)
|
||||
else:
|
||||
text = re.sub(r'\\', r'\\\\', text)
|
||||
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
|
||||
text = re.sub(r"\\", r"\\\\", text)
|
||||
return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text)
|
||||
|
||||
|
||||
def escape_mentions(text: str) -> str:
|
||||
@ -806,7 +826,7 @@ def escape_mentions(text: str) -> str:
|
||||
:class:`str`
|
||||
The text with the mentions removed.
|
||||
"""
|
||||
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text)
|
||||
return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text)
|
||||
|
||||
|
||||
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
|
||||
@ -870,7 +890,7 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
|
||||
A new iterator which yields chunks of a given size.
|
||||
"""
|
||||
if max_size <= 0:
|
||||
raise ValueError('Chunk sizes must be greater than 0.')
|
||||
raise ValueError("Chunk sizes must be greater than 0.")
|
||||
|
||||
if isinstance(iterator, AsyncIterator):
|
||||
return _achunk(iterator, max_size)
|
||||
@ -916,11 +936,11 @@ def evaluate_annotation(
|
||||
cache[tp] = evaluated
|
||||
return evaluate_annotation(evaluated, globals, locals, cache)
|
||||
|
||||
if hasattr(tp, '__args__'):
|
||||
if hasattr(tp, "__args__"):
|
||||
implicit_str = True
|
||||
is_literal = False
|
||||
args = tp.__args__
|
||||
if not hasattr(tp, '__origin__'):
|
||||
if not hasattr(tp, "__origin__"):
|
||||
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
|
||||
converted = Union[args] # type: ignore
|
||||
return evaluate_annotation(converted, globals, locals, cache)
|
||||
@ -938,10 +958,17 @@ def evaluate_annotation(
|
||||
implicit_str = False
|
||||
is_literal = True
|
||||
|
||||
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
|
||||
evaluated_args = tuple(
|
||||
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str)
|
||||
for arg in args
|
||||
)
|
||||
|
||||
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
|
||||
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.')
|
||||
if is_literal and not all(
|
||||
isinstance(x, (str, int, bool, type(None))) for x in evaluated_args
|
||||
):
|
||||
raise TypeError(
|
||||
"Literal arguments must be of type str, int, bool, or NoneType."
|
||||
)
|
||||
|
||||
if evaluated_args == args:
|
||||
return tp
|
||||
@ -971,7 +998,7 @@ def resolve_annotation(
|
||||
return evaluate_annotation(annotation, globalns, locals, cache)
|
||||
|
||||
|
||||
TimestampStyle = Literal['f', 'F', 'd', 'D', 't', 'T', 'R']
|
||||
TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
|
||||
|
||||
|
||||
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
|
||||
@ -1015,5 +1042,5 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
|
||||
The formatted string.
|
||||
"""
|
||||
if style is None:
|
||||
return f'<t:{int(dt.timestamp())}>'
|
||||
return f'<t:{int(dt.timestamp())}:{style}>'
|
||||
return f"<t:{int(dt.timestamp())}>"
|
||||
return f"<t:{int(dt.timestamp())}:{style}>"
|
||||
|
@ -66,26 +66,26 @@ if TYPE_CHECKING:
|
||||
VoiceServerUpdate as VoiceServerUpdatePayload,
|
||||
SupportedModes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
has_nacl: bool
|
||||
|
||||
try:
|
||||
import nacl.secret # type: ignore
|
||||
|
||||
has_nacl = True
|
||||
except ImportError:
|
||||
has_nacl = False
|
||||
|
||||
__all__ = (
|
||||
'VoiceProtocol',
|
||||
'VoiceClient',
|
||||
"VoiceProtocol",
|
||||
"VoiceClient",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VoiceProtocol:
|
||||
"""A class that represents the Discord voice protocol.
|
||||
|
||||
@ -195,6 +195,7 @@ class VoiceProtocol:
|
||||
key_id, _ = self.channel._get_voice_client_key()
|
||||
self.client._connection._remove_voice_client(key_id)
|
||||
|
||||
|
||||
class VoiceClient(VoiceProtocol):
|
||||
"""Represents a Discord voice connection.
|
||||
|
||||
@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol):
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
The event loop that the voice client is running on.
|
||||
"""
|
||||
|
||||
endpoint_ip: str
|
||||
voice_port: int
|
||||
secret_key: List[int]
|
||||
ssrc: int
|
||||
|
||||
|
||||
def __init__(self, client: Client, channel: abc.Connectable):
|
||||
if not has_nacl:
|
||||
raise RuntimeError("PyNaCl library needed in order to use voice")
|
||||
@ -258,15 +259,15 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
warn_nacl = not has_nacl
|
||||
supported_modes: Tuple[SupportedModes, ...] = (
|
||||
'xsalsa20_poly1305_lite',
|
||||
'xsalsa20_poly1305_suffix',
|
||||
'xsalsa20_poly1305',
|
||||
"xsalsa20_poly1305_lite",
|
||||
"xsalsa20_poly1305_suffix",
|
||||
"xsalsa20_poly1305",
|
||||
)
|
||||
|
||||
@property
|
||||
def guild(self) -> Optional[Guild]:
|
||||
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
|
||||
return getattr(self.channel, 'guild', None)
|
||||
return getattr(self.channel, "guild", None)
|
||||
|
||||
@property
|
||||
def user(self) -> ClientUser:
|
||||
@ -283,8 +284,8 @@ class VoiceClient(VoiceProtocol):
|
||||
# connection related
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
self.session_id = data['session_id']
|
||||
channel_id = data['channel_id']
|
||||
self.session_id = data["session_id"]
|
||||
channel_id = data["channel_id"]
|
||||
|
||||
if not self._handshaking or self._potentially_reconnecting:
|
||||
# If we're done handshaking then we just need to update ourselves
|
||||
@ -301,20 +302,22 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
if self._voice_server_complete.is_set():
|
||||
_log.info('Ignoring extraneous voice server update.')
|
||||
_log.info("Ignoring extraneous voice server update.")
|
||||
return
|
||||
|
||||
self.token = data.get('token')
|
||||
self.server_id = int(data['guild_id'])
|
||||
endpoint = data.get('endpoint')
|
||||
self.token = data.get("token")
|
||||
self.server_id = int(data["guild_id"])
|
||||
endpoint = data.get("endpoint")
|
||||
|
||||
if endpoint is None or self.token is None:
|
||||
_log.warning('Awaiting endpoint... This requires waiting. ' \
|
||||
'If timeout occurred considering raising the timeout and reconnecting.')
|
||||
_log.warning(
|
||||
"Awaiting endpoint... This requires waiting. "
|
||||
"If timeout occurred considering raising the timeout and reconnecting."
|
||||
)
|
||||
return
|
||||
|
||||
self.endpoint, _, _ = endpoint.rpartition(':')
|
||||
if self.endpoint.startswith('wss://'):
|
||||
self.endpoint, _, _ = endpoint.rpartition(":")
|
||||
if self.endpoint.startswith("wss://"):
|
||||
# Just in case, strip it off since we're going to add it later
|
||||
self.endpoint = self.endpoint[6:]
|
||||
|
||||
@ -335,18 +338,24 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.channel.guild.change_voice_state(channel=self.channel)
|
||||
|
||||
async def voice_disconnect(self) -> None:
|
||||
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
_log.info(
|
||||
"The voice handshake is being terminated for Channel ID %s (Guild ID %s)",
|
||||
self.channel.id,
|
||||
self.guild.id,
|
||||
)
|
||||
await self.channel.guild.change_voice_state(channel=None)
|
||||
|
||||
def prepare_handshake(self) -> None:
|
||||
self._voice_state_complete.clear()
|
||||
self._voice_server_complete.clear()
|
||||
self._handshaking = True
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
_log.info(
|
||||
"Starting voice handshake... (connection attempt %d)", self._connections + 1
|
||||
)
|
||||
self._connections += 1
|
||||
|
||||
def finish_handshake(self) -> None:
|
||||
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
_log.info("Voice handshake complete. Endpoint found %s", self.endpoint)
|
||||
self._handshaking = False
|
||||
self._voice_server_complete.clear()
|
||||
self._voice_state_complete.clear()
|
||||
@ -359,8 +368,8 @@ class VoiceClient(VoiceProtocol):
|
||||
self._connected.set()
|
||||
return ws
|
||||
|
||||
async def connect(self, *, reconnect: bool, timeout: float) ->None:
|
||||
_log.info('Connecting to voice...')
|
||||
async def connect(self, *, reconnect: bool, timeout: float) -> None:
|
||||
_log.info("Connecting to voice...")
|
||||
self.timeout = timeout
|
||||
|
||||
for i in range(5):
|
||||
@ -388,7 +397,7 @@ class VoiceClient(VoiceProtocol):
|
||||
break
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
if reconnect:
|
||||
_log.exception('Failed to connect to voice... Retrying...')
|
||||
_log.exception("Failed to connect to voice... Retrying...")
|
||||
await asyncio.sleep(1 + i * 2.0)
|
||||
await self.voice_disconnect()
|
||||
continue
|
||||
@ -405,7 +414,9 @@ class VoiceClient(VoiceProtocol):
|
||||
self._potentially_reconnecting = True
|
||||
try:
|
||||
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
|
||||
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
|
||||
await asyncio.wait_for(
|
||||
self._voice_server_complete.wait(), timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self._potentially_reconnecting = False
|
||||
await self.disconnect(force=True)
|
||||
@ -453,14 +464,21 @@ class VoiceClient(VoiceProtocol):
|
||||
# 4014 - voice channel has been deleted.
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
_log.info(
|
||||
"Disconnecting from voice normally, close code %d.",
|
||||
exc.code,
|
||||
)
|
||||
await self.disconnect()
|
||||
break
|
||||
if exc.code == 4014:
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
_log.info(
|
||||
"Disconnected from voice by force... potentially reconnecting."
|
||||
)
|
||||
successful = await self.potential_reconnect()
|
||||
if not successful:
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
_log.info(
|
||||
"Reconnect was unsuccessful, disconnecting from voice normally..."
|
||||
)
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
@ -471,7 +489,9 @@ class VoiceClient(VoiceProtocol):
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
_log.exception(
|
||||
"Disconnected from voice... Reconnecting in %.2fs.", retry
|
||||
)
|
||||
self._connected.clear()
|
||||
await asyncio.sleep(retry)
|
||||
await self.voice_disconnect()
|
||||
@ -479,7 +499,7 @@ class VoiceClient(VoiceProtocol):
|
||||
await self.connect(reconnect=True, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
_log.warning("Could not connect to voice... Retrying...")
|
||||
continue
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
@ -527,11 +547,11 @@ class VoiceClient(VoiceProtocol):
|
||||
# Formulate rtp header
|
||||
header[0] = 0x80
|
||||
header[1] = 0x78
|
||||
struct.pack_into('>H', header, 2, self.sequence)
|
||||
struct.pack_into('>I', header, 4, self.timestamp)
|
||||
struct.pack_into('>I', header, 8, self.ssrc)
|
||||
struct.pack_into(">H", header, 2, self.sequence)
|
||||
struct.pack_into(">I", header, 4, self.timestamp)
|
||||
struct.pack_into(">I", header, 8, self.ssrc)
|
||||
|
||||
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
|
||||
encrypt_packet = getattr(self, "_encrypt_" + self.mode)
|
||||
return encrypt_packet(header, data)
|
||||
|
||||
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
|
||||
@ -551,12 +571,14 @@ class VoiceClient(VoiceProtocol):
|
||||
box = nacl.secret.SecretBox(bytes(self.secret_key))
|
||||
nonce = bytearray(24)
|
||||
|
||||
nonce[:4] = struct.pack('>I', self._lite_nonce)
|
||||
self.checked_add('_lite_nonce', 1, 4294967295)
|
||||
nonce[:4] = struct.pack(">I", self._lite_nonce)
|
||||
self.checked_add("_lite_nonce", 1, 4294967295)
|
||||
|
||||
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
|
||||
|
||||
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None:
|
||||
def play(
|
||||
self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None
|
||||
) -> None:
|
||||
"""Plays an :class:`AudioSource`.
|
||||
|
||||
The finalizer, ``after`` is called after the source has been exhausted
|
||||
@ -586,13 +608,15 @@ class VoiceClient(VoiceProtocol):
|
||||
"""
|
||||
|
||||
if not self.is_connected():
|
||||
raise ClientException('Not connected to voice.')
|
||||
raise ClientException("Not connected to voice.")
|
||||
|
||||
if self.is_playing():
|
||||
raise ClientException('Already playing audio.')
|
||||
raise ClientException("Already playing audio.")
|
||||
|
||||
if not isinstance(source, AudioSource):
|
||||
raise TypeError(f'source must be an AudioSource not {source.__class__.__name__}')
|
||||
raise TypeError(
|
||||
f"source must be an AudioSource not {source.__class__.__name__}"
|
||||
)
|
||||
|
||||
if not self.encoder and not source.is_opus():
|
||||
self.encoder = opus.Encoder()
|
||||
@ -635,10 +659,10 @@ class VoiceClient(VoiceProtocol):
|
||||
@source.setter
|
||||
def source(self, value: AudioSource) -> None:
|
||||
if not isinstance(value, AudioSource):
|
||||
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
|
||||
raise TypeError(f"expected AudioSource not {value.__class__.__name__}.")
|
||||
|
||||
if self._player is None:
|
||||
raise ValueError('Not playing anything.')
|
||||
raise ValueError("Not playing anything.")
|
||||
|
||||
self._player._set_source(value)
|
||||
|
||||
@ -662,7 +686,7 @@ class VoiceClient(VoiceProtocol):
|
||||
Encoding the data failed.
|
||||
"""
|
||||
|
||||
self.checked_add('sequence', 1, 65535)
|
||||
self.checked_add("sequence", 1, 65535)
|
||||
if encode:
|
||||
encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
|
||||
else:
|
||||
@ -671,6 +695,10 @@ class VoiceClient(VoiceProtocol):
|
||||
try:
|
||||
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
|
||||
except BlockingIOError:
|
||||
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
_log.warning(
|
||||
"A packet has been dropped (seq: %s, timestamp: %s)",
|
||||
self.sequence,
|
||||
self.timestamp,
|
||||
)
|
||||
|
||||
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
@ -30,13 +30,30 @@ import json
|
||||
import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, NamedTuple, Optional, TYPE_CHECKING, Tuple, Union, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .. import utils
|
||||
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
|
||||
from ..errors import (
|
||||
InvalidArgument,
|
||||
HTTPException,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
DiscordServerError,
|
||||
)
|
||||
from ..message import Message
|
||||
from ..enums import try_enum, WebhookType
|
||||
from ..user import BaseUser, User
|
||||
@ -46,10 +63,10 @@ from ..mixins import Hashable
|
||||
from ..channel import PartialMessageable
|
||||
|
||||
__all__ = (
|
||||
'Webhook',
|
||||
'WebhookMessage',
|
||||
'PartialWebhookChannel',
|
||||
'PartialWebhookGuild',
|
||||
"Webhook",
|
||||
"WebhookMessage",
|
||||
"PartialWebhookChannel",
|
||||
"PartialWebhookGuild",
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@ -120,14 +137,14 @@ class AsyncWebhookAdapter:
|
||||
self._locks[bucket] = lock = asyncio.Lock()
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
headers["Content-Type"] = "application/json"
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
headers["Authorization"] = f"Bot {auth_token}"
|
||||
|
||||
if reason is not None:
|
||||
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
|
||||
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
|
||||
|
||||
response: Optional[aiohttp.ClientResponse] = None
|
||||
data: Optional[Union[Dict[str, Any], str]] = None
|
||||
@ -147,23 +164,30 @@ class AsyncWebhookAdapter:
|
||||
to_send = form_data
|
||||
|
||||
try:
|
||||
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
|
||||
async with session.request(
|
||||
method, url, data=to_send, headers=headers, params=params
|
||||
) as response:
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
"Webhook ID %s with %s %s has returned status code %s",
|
||||
webhook_id,
|
||||
method,
|
||||
url,
|
||||
response.status,
|
||||
)
|
||||
data = (await response.text(encoding='utf-8')) or None
|
||||
if data and response.headers['Content-Type'] == 'application/json':
|
||||
data = (await response.text(encoding="utf-8")) or None
|
||||
if (
|
||||
data
|
||||
and response.headers["Content-Type"] == "application/json"
|
||||
):
|
||||
data = json.loads(data)
|
||||
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status != 429:
|
||||
remaining = response.headers.get("X-Ratelimit-Remaining")
|
||||
if remaining == "0" and response.status != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
|
||||
webhook_id,
|
||||
delta,
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
|
||||
@ -171,11 +195,15 @@ class AsyncWebhookAdapter:
|
||||
return data
|
||||
|
||||
if response.status == 429:
|
||||
if not response.headers.get('Via'):
|
||||
if not response.headers.get("Via"):
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
retry_after: float = data["retry_after"] # type: ignore
|
||||
_log.warning(
|
||||
"Webhook ID %s is rate limited. Retrying in %.2f seconds",
|
||||
webhook_id,
|
||||
retry_after,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@ -201,7 +229,7 @@ class AsyncWebhookAdapter:
|
||||
raise DiscordServerError(response, data)
|
||||
raise HTTPException(response, data)
|
||||
|
||||
raise RuntimeError('Unreachable code in HTTP handling.')
|
||||
raise RuntimeError("Unreachable code in HTTP handling.")
|
||||
|
||||
def delete_webhook(
|
||||
self,
|
||||
@ -211,7 +239,7 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, auth_token=token)
|
||||
|
||||
def delete_webhook_with_token(
|
||||
@ -222,7 +250,12 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[None]:
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"DELETE",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session, reason=reason)
|
||||
|
||||
def edit_webhook(
|
||||
@ -234,8 +267,10 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
|
||||
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(
|
||||
route, session, reason=reason, payload=payload, auth_token=token
|
||||
)
|
||||
|
||||
def edit_webhook_with_token(
|
||||
self,
|
||||
@ -246,7 +281,12 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
reason: Optional[str] = None,
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"PATCH",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session, reason=reason, payload=payload)
|
||||
|
||||
def execute_webhook(
|
||||
@ -261,11 +301,23 @@ class AsyncWebhookAdapter:
|
||||
thread_id: Optional[int] = None,
|
||||
wait: bool = False,
|
||||
) -> Response[Optional[MessagePayload]]:
|
||||
params = {'wait': int(wait)}
|
||||
params = {"wait": int(wait)}
|
||||
if thread_id:
|
||||
params['thread_id'] = thread_id
|
||||
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
|
||||
params["thread_id"] = thread_id
|
||||
route = Route(
|
||||
"POST",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(
|
||||
route,
|
||||
session,
|
||||
payload=payload,
|
||||
multipart=multipart,
|
||||
files=files,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def get_webhook_message(
|
||||
self,
|
||||
@ -276,8 +328,8 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[MessagePayload]:
|
||||
route = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"GET",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
@ -296,13 +348,15 @@ class AsyncWebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
) -> Response[Message]:
|
||||
route = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"PATCH",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
)
|
||||
return self.request(route, session, payload=payload, multipart=multipart, files=files)
|
||||
return self.request(
|
||||
route, session, payload=payload, multipart=multipart, files=files
|
||||
)
|
||||
|
||||
def delete_webhook_message(
|
||||
self,
|
||||
@ -313,8 +367,8 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[None]:
|
||||
route = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"DELETE",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
@ -328,7 +382,7 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(route, session=session, auth_token=token)
|
||||
|
||||
def fetch_webhook_with_token(
|
||||
@ -338,7 +392,12 @@ class AsyncWebhookAdapter:
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[WebhookPayload]:
|
||||
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"GET",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session=session)
|
||||
|
||||
def create_interaction_response(
|
||||
@ -351,15 +410,15 @@ class AsyncWebhookAdapter:
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Response[None]:
|
||||
payload: Dict[str, Any] = {
|
||||
'type': type,
|
||||
"type": type,
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
payload['data'] = data
|
||||
payload["data"] = data
|
||||
|
||||
route = Route(
|
||||
'POST',
|
||||
'/interactions/{webhook_id}/{webhook_token}/callback',
|
||||
"POST",
|
||||
"/interactions/{webhook_id}/{webhook_token}/callback",
|
||||
webhook_id=interaction_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
@ -374,8 +433,8 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
"GET",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
|
||||
webhook_id=application_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
@ -392,12 +451,14 @@ class AsyncWebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
) -> Response[MessagePayload]:
|
||||
r = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
|
||||
"PATCH",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
|
||||
webhook_id=application_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(r, session, payload=payload, multipart=multipart, files=files)
|
||||
return self.request(
|
||||
r, session, payload=payload, multipart=multipart, files=files
|
||||
)
|
||||
|
||||
def delete_original_interaction_response(
|
||||
self,
|
||||
@ -407,8 +468,8 @@ class AsyncWebhookAdapter:
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Response[None]:
|
||||
r = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
|
||||
"DELETE",
|
||||
"/webhooks/{webhook_id}/{wehook_token}/messages/@original",
|
||||
webhook_id=application_id,
|
||||
wehook_token=token,
|
||||
)
|
||||
@ -437,82 +498,86 @@ def handle_message_parameters(
|
||||
previous_allowed_mentions: Optional[AllowedMentions] = None,
|
||||
) -> ExecuteWebhookParameters:
|
||||
if files is not MISSING and file is not MISSING:
|
||||
raise TypeError('Cannot mix file and files keyword arguments.')
|
||||
raise TypeError("Cannot mix file and files keyword arguments.")
|
||||
if embeds is not MISSING and embed is not MISSING:
|
||||
raise TypeError('Cannot mix embed and embeds keyword arguments.')
|
||||
raise TypeError("Cannot mix embed and embeds keyword arguments.")
|
||||
|
||||
payload = {}
|
||||
if embeds is not MISSING:
|
||||
if len(embeds) > 10:
|
||||
raise InvalidArgument('embeds has a maximum of 10 elements.')
|
||||
payload['embeds'] = [e.to_dict() for e in embeds]
|
||||
raise InvalidArgument("embeds has a maximum of 10 elements.")
|
||||
payload["embeds"] = [e.to_dict() for e in embeds]
|
||||
|
||||
if embed is not MISSING:
|
||||
if embed is None:
|
||||
payload['embeds'] = []
|
||||
payload["embeds"] = []
|
||||
else:
|
||||
payload['embeds'] = [embed.to_dict()]
|
||||
payload["embeds"] = [embed.to_dict()]
|
||||
|
||||
if content is not MISSING:
|
||||
if content is not None:
|
||||
payload['content'] = str(content)
|
||||
payload["content"] = str(content)
|
||||
else:
|
||||
payload['content'] = None
|
||||
payload["content"] = None
|
||||
|
||||
if view is not MISSING:
|
||||
if view is not None:
|
||||
payload['components'] = view.to_components()
|
||||
payload["components"] = view.to_components()
|
||||
else:
|
||||
payload['components'] = []
|
||||
payload["components"] = []
|
||||
|
||||
payload['tts'] = tts
|
||||
payload["tts"] = tts
|
||||
if avatar_url:
|
||||
payload['avatar_url'] = str(avatar_url)
|
||||
payload["avatar_url"] = str(avatar_url)
|
||||
if username:
|
||||
payload['username'] = username
|
||||
payload["username"] = username
|
||||
if ephemeral:
|
||||
payload['flags'] = 64
|
||||
payload["flags"] = 64
|
||||
|
||||
if allowed_mentions:
|
||||
if previous_allowed_mentions is not None:
|
||||
payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
|
||||
payload["allowed_mentions"] = previous_allowed_mentions.merge(
|
||||
allowed_mentions
|
||||
).to_dict()
|
||||
else:
|
||||
payload['allowed_mentions'] = allowed_mentions.to_dict()
|
||||
payload["allowed_mentions"] = allowed_mentions.to_dict()
|
||||
elif previous_allowed_mentions is not None:
|
||||
payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
|
||||
payload["allowed_mentions"] = previous_allowed_mentions.to_dict()
|
||||
|
||||
multipart = []
|
||||
if file is not MISSING:
|
||||
files = [file]
|
||||
|
||||
if files:
|
||||
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
|
||||
multipart.append({"name": "payload_json", "value": utils._to_json(payload)})
|
||||
payload = None
|
||||
if len(files) == 1:
|
||||
file = files[0]
|
||||
multipart.append(
|
||||
{
|
||||
'name': 'file',
|
||||
'value': file.fp,
|
||||
'filename': file.filename,
|
||||
'content_type': 'application/octet-stream',
|
||||
"name": "file",
|
||||
"value": file.fp,
|
||||
"filename": file.filename,
|
||||
"content_type": "application/octet-stream",
|
||||
}
|
||||
)
|
||||
else:
|
||||
for index, file in enumerate(files):
|
||||
multipart.append(
|
||||
{
|
||||
'name': f'file{index}',
|
||||
'value': file.fp,
|
||||
'filename': file.filename,
|
||||
'content_type': 'application/octet-stream',
|
||||
"name": f"file{index}",
|
||||
"value": file.fp,
|
||||
"filename": file.filename,
|
||||
"content_type": "application/octet-stream",
|
||||
}
|
||||
)
|
||||
|
||||
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
|
||||
|
||||
|
||||
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
|
||||
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar(
|
||||
"async_webhook_context", default=AsyncWebhookAdapter()
|
||||
)
|
||||
|
||||
|
||||
class PartialWebhookChannel(Hashable):
|
||||
@ -530,14 +595,14 @@ class PartialWebhookChannel(Hashable):
|
||||
The partial channel's name.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'name')
|
||||
__slots__ = ("id", "name")
|
||||
|
||||
def __init__(self, *, data):
|
||||
self.id = int(data['id'])
|
||||
self.name = data['name']
|
||||
self.id = int(data["id"])
|
||||
self.name = data["name"]
|
||||
|
||||
def __repr__(self):
|
||||
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
|
||||
return f"<PartialWebhookChannel name={self.name!r} id={self.id}>"
|
||||
|
||||
|
||||
class PartialWebhookGuild(Hashable):
|
||||
@ -555,16 +620,16 @@ class PartialWebhookGuild(Hashable):
|
||||
The partial guild's name.
|
||||
"""
|
||||
|
||||
__slots__ = ('id', 'name', '_icon', '_state')
|
||||
__slots__ = ("id", "name", "_icon", "_state")
|
||||
|
||||
def __init__(self, *, data, state):
|
||||
self._state = state
|
||||
self.id = int(data['id'])
|
||||
self.name = data['name']
|
||||
self._icon = data['icon']
|
||||
self.id = int(data["id"])
|
||||
self.name = data["name"]
|
||||
self._icon = data["icon"]
|
||||
|
||||
def __repr__(self):
|
||||
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
|
||||
return f"<PartialWebhookGuild name={self.name!r} id={self.id}>"
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional[Asset]:
|
||||
@ -578,13 +643,15 @@ class _FriendlyHttpAttributeErrorHelper:
|
||||
__slots__ = ()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
raise AttributeError('PartialWebhookState does not support http methods.')
|
||||
raise AttributeError("PartialWebhookState does not support http methods.")
|
||||
|
||||
|
||||
class _WebhookState:
|
||||
__slots__ = ('_parent', '_webhook')
|
||||
__slots__ = ("_parent", "_webhook")
|
||||
|
||||
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
|
||||
def __init__(
|
||||
self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]
|
||||
):
|
||||
self._webhook: Any = webhook
|
||||
|
||||
self._parent: Optional[ConnectionState]
|
||||
@ -621,7 +688,7 @@ class _WebhookState:
|
||||
if self._parent is not None:
|
||||
return getattr(self._parent, attr)
|
||||
|
||||
raise AttributeError(f'PartialWebhookState does not support {attr!r}.')
|
||||
raise AttributeError(f"PartialWebhookState does not support {attr!r}.")
|
||||
|
||||
|
||||
class WebhookMessage(Message):
|
||||
@ -750,47 +817,54 @@ class WebhookMessage(Message):
|
||||
|
||||
class BaseWebhook(Hashable):
|
||||
__slots__: Tuple[str, ...] = (
|
||||
'id',
|
||||
'type',
|
||||
'guild_id',
|
||||
'channel_id',
|
||||
'token',
|
||||
'auth_token',
|
||||
'user',
|
||||
'name',
|
||||
'_avatar',
|
||||
'source_channel',
|
||||
'source_guild',
|
||||
'_state',
|
||||
"id",
|
||||
"type",
|
||||
"guild_id",
|
||||
"channel_id",
|
||||
"token",
|
||||
"auth_token",
|
||||
"user",
|
||||
"name",
|
||||
"_avatar",
|
||||
"source_channel",
|
||||
"source_guild",
|
||||
"_state",
|
||||
)
|
||||
|
||||
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
token: Optional[str] = None,
|
||||
state: Optional[ConnectionState] = None,
|
||||
):
|
||||
self.auth_token: Optional[str] = token
|
||||
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
|
||||
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(
|
||||
self, parent=state
|
||||
)
|
||||
self._update(data)
|
||||
|
||||
def _update(self, data: WebhookPayload):
|
||||
self.id = int(data['id'])
|
||||
self.type = try_enum(WebhookType, int(data['type']))
|
||||
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
|
||||
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
|
||||
self.name = data.get('name')
|
||||
self._avatar = data.get('avatar')
|
||||
self.token = data.get('token')
|
||||
self.id = int(data["id"])
|
||||
self.type = try_enum(WebhookType, int(data["type"]))
|
||||
self.channel_id = utils._get_as_snowflake(data, "channel_id")
|
||||
self.guild_id = utils._get_as_snowflake(data, "guild_id")
|
||||
self.name = data.get("name")
|
||||
self._avatar = data.get("avatar")
|
||||
self.token = data.get("token")
|
||||
|
||||
user = data.get('user')
|
||||
user = data.get("user")
|
||||
self.user: Optional[Union[BaseUser, User]] = None
|
||||
if user is not None:
|
||||
# state parameter may be _WebhookState
|
||||
self.user = User(state=self._state, data=user) # type: ignore
|
||||
|
||||
source_channel = data.get('source_channel')
|
||||
source_channel = data.get("source_channel")
|
||||
if source_channel:
|
||||
source_channel = PartialWebhookChannel(data=source_channel)
|
||||
|
||||
self.source_channel: Optional[PartialWebhookChannel] = source_channel
|
||||
|
||||
source_guild = data.get('source_guild')
|
||||
source_guild = data.get("source_guild")
|
||||
if source_guild:
|
||||
source_guild = PartialWebhookGuild(data=source_guild, state=self._state)
|
||||
|
||||
@ -927,22 +1001,35 @@ class Webhook(BaseWebhook):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = ('session',)
|
||||
__slots__: Tuple[str, ...] = ("session",)
|
||||
|
||||
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
session: aiohttp.ClientSession,
|
||||
token: Optional[str] = None,
|
||||
state=None,
|
||||
):
|
||||
super().__init__(data, token, state)
|
||||
self.session = session
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
return f"<Webhook id={self.id!r}>"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
def partial(
|
||||
cls,
|
||||
id: int,
|
||||
token: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
bot_token: Optional[str] = None,
|
||||
) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@ -970,15 +1057,21 @@ class Webhook(BaseWebhook):
|
||||
A partial webhook is just a webhook object with an ID and a token.
|
||||
"""
|
||||
data: WebhookPayload = {
|
||||
'id': id,
|
||||
'type': 1,
|
||||
'token': token,
|
||||
"id": id,
|
||||
"type": 1,
|
||||
"token": token,
|
||||
}
|
||||
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
|
||||
def from_url(
|
||||
cls,
|
||||
url: str,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
bot_token: Optional[str] = None,
|
||||
) -> Webhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@ -1008,24 +1101,32 @@ class Webhook(BaseWebhook):
|
||||
A partial :class:`Webhook`.
|
||||
A partial webhook is just a webhook object with an ID and a token.
|
||||
"""
|
||||
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
|
||||
m = re.search(
|
||||
r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})",
|
||||
url,
|
||||
)
|
||||
if m is None:
|
||||
raise InvalidArgument('Invalid webhook URL given.')
|
||||
raise InvalidArgument("Invalid webhook URL given.")
|
||||
|
||||
data: Dict[str, Any] = m.groupdict()
|
||||
data['type'] = 1
|
||||
data["type"] = 1
|
||||
return cls(data, session, token=bot_token) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _as_follower(cls, data, *, channel, user) -> Webhook:
|
||||
name = f"{channel.guild} #{channel}"
|
||||
feed: WebhookPayload = {
|
||||
'id': data['webhook_id'],
|
||||
'type': 2,
|
||||
'name': name,
|
||||
'channel_id': channel.id,
|
||||
'guild_id': channel.guild.id,
|
||||
'user': {'username': user.name, 'discriminator': user.discriminator, 'id': user.id, 'avatar': user._avatar},
|
||||
"id": data["webhook_id"],
|
||||
"type": 2,
|
||||
"name": name,
|
||||
"channel_id": channel.id,
|
||||
"guild_id": channel.guild.id,
|
||||
"user": {
|
||||
"username": user.name,
|
||||
"discriminator": user.discriminator,
|
||||
"id": user.id,
|
||||
"avatar": user._avatar,
|
||||
},
|
||||
}
|
||||
|
||||
state = channel._state
|
||||
@ -1075,11 +1176,17 @@ class Webhook(BaseWebhook):
|
||||
adapter = async_context.get()
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = await adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
|
||||
data = await adapter.fetch_webhook(
|
||||
self.id, self.auth_token, session=self.session
|
||||
)
|
||||
elif self.token:
|
||||
data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
|
||||
data = await adapter.fetch_webhook_with_token(
|
||||
self.id, self.token, session=self.session
|
||||
)
|
||||
else:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
return Webhook(data, self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
@ -1112,14 +1219,20 @@ class Webhook(BaseWebhook):
|
||||
This webhook does not have a token associated with it.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter = async_context.get()
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
await adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
|
||||
await adapter.delete_webhook(
|
||||
self.id, token=self.auth_token, session=self.session, reason=reason
|
||||
)
|
||||
elif self.token:
|
||||
await adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason)
|
||||
await adapter.delete_webhook_with_token(
|
||||
self.id, self.token, session=self.session, reason=reason
|
||||
)
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
@ -1165,14 +1278,18 @@ class Webhook(BaseWebhook):
|
||||
or it tried editing a channel without authentication.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
payload['name'] = str(name) if name is not None else None
|
||||
payload["name"] = str(name) if name is not None else None
|
||||
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
|
||||
payload["avatar"] = (
|
||||
utils._bytes_to_base64_data(avatar) if avatar is not None else None
|
||||
)
|
||||
|
||||
adapter = async_context.get()
|
||||
|
||||
@ -1180,27 +1297,45 @@ class Webhook(BaseWebhook):
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
raise InvalidArgument('Editing channel requires authenticated webhook')
|
||||
raise InvalidArgument("Editing channel requires authenticated webhook")
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
payload["channel_id"] = channel.id
|
||||
data = await adapter.edit_webhook(
|
||||
self.id,
|
||||
self.auth_token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
data = await adapter.edit_webhook(
|
||||
self.id,
|
||||
self.auth_token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
elif self.token:
|
||||
data = await adapter.edit_webhook_with_token(
|
||||
self.id, self.token, payload=payload, session=self.session, reason=reason
|
||||
self.id,
|
||||
self.token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
raise RuntimeError("Unreachable code hit: data was not assigned")
|
||||
|
||||
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
return Webhook(
|
||||
data=data, session=self.session, token=self.auth_token, state=self._state
|
||||
)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
|
||||
# state is artificial
|
||||
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@ -1350,22 +1485,30 @@ class Webhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(
|
||||
self._state, "allowed_mentions", None
|
||||
)
|
||||
if content is None:
|
||||
content = MISSING
|
||||
|
||||
application_webhook = self.type is WebhookType.application
|
||||
if ephemeral and not application_webhook:
|
||||
raise InvalidArgument('ephemeral messages can only be sent from application webhooks')
|
||||
raise InvalidArgument(
|
||||
"ephemeral messages can only be sent from application webhooks"
|
||||
)
|
||||
|
||||
if application_webhook:
|
||||
wait = True
|
||||
|
||||
if view is not MISSING:
|
||||
if isinstance(self._state, _WebhookState):
|
||||
raise InvalidArgument('Webhook views require an associated state with the webhook')
|
||||
raise InvalidArgument(
|
||||
"Webhook views require an associated state with the webhook"
|
||||
)
|
||||
if ephemeral is True and view.timeout is None:
|
||||
view.timeout = 15 * 60.0
|
||||
|
||||
@ -1439,7 +1582,9 @@ class Webhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter = async_context.get()
|
||||
data = await adapter.get_webhook_message(
|
||||
@ -1525,15 +1670,21 @@ class Webhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
if view is not MISSING:
|
||||
if isinstance(self._state, _WebhookState):
|
||||
raise InvalidArgument('This webhook does not have state associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have state associated with it"
|
||||
)
|
||||
|
||||
self._state.prevent_view_updates_for(message_id)
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(
|
||||
self._state, "allowed_mentions", None
|
||||
)
|
||||
params = handle_message_parameters(
|
||||
content=content,
|
||||
file=file,
|
||||
@ -1583,7 +1734,9 @@ class Webhook(BaseWebhook):
|
||||
Deleted a message that is not yours.
|
||||
"""
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter = async_context.get()
|
||||
await adapter.delete_webhook_message(
|
||||
|
@ -37,10 +37,28 @@ import time
|
||||
import re
|
||||
|
||||
from urllib.parse import quote as urlquote
|
||||
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from .. import utils
|
||||
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
|
||||
from ..errors import (
|
||||
InvalidArgument,
|
||||
HTTPException,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
DiscordServerError,
|
||||
)
|
||||
from ..message import Message
|
||||
from ..http import Route
|
||||
from ..channel import PartialMessageable
|
||||
@ -48,8 +66,8 @@ from ..channel import PartialMessageable
|
||||
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
|
||||
|
||||
__all__ = (
|
||||
'SyncWebhook',
|
||||
'SyncWebhookMessage',
|
||||
"SyncWebhook",
|
||||
"SyncWebhookMessage",
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@ -116,14 +134,14 @@ class WebhookAdapter:
|
||||
self._locks[bucket] = lock = threading.Lock()
|
||||
|
||||
if payload is not None:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
headers["Content-Type"] = "application/json"
|
||||
to_send = utils._to_json(payload)
|
||||
|
||||
if auth_token is not None:
|
||||
headers['Authorization'] = f'Bot {auth_token}'
|
||||
headers["Authorization"] = f"Bot {auth_token}"
|
||||
|
||||
if reason is not None:
|
||||
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
|
||||
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
|
||||
|
||||
response: Optional[Response] = None
|
||||
data: Optional[Union[Dict[str, Any], str]] = None
|
||||
@ -140,36 +158,50 @@ class WebhookAdapter:
|
||||
if multipart:
|
||||
file_data = {}
|
||||
for p in multipart:
|
||||
name = p['name']
|
||||
if name == 'payload_json':
|
||||
to_send = {'payload_json': p['value']}
|
||||
name = p["name"]
|
||||
if name == "payload_json":
|
||||
to_send = {"payload_json": p["value"]}
|
||||
else:
|
||||
file_data[name] = (p['filename'], p['value'], p['content_type'])
|
||||
file_data[name] = (
|
||||
p["filename"],
|
||||
p["value"],
|
||||
p["content_type"],
|
||||
)
|
||||
|
||||
try:
|
||||
with session.request(
|
||||
method, url, data=to_send, files=file_data, headers=headers, params=params
|
||||
method,
|
||||
url,
|
||||
data=to_send,
|
||||
files=file_data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
) as response:
|
||||
_log.debug(
|
||||
'Webhook ID %s with %s %s has returned status code %s',
|
||||
"Webhook ID %s with %s %s has returned status code %s",
|
||||
webhook_id,
|
||||
method,
|
||||
url,
|
||||
response.status_code,
|
||||
)
|
||||
response.encoding = 'utf-8'
|
||||
response.encoding = "utf-8"
|
||||
# Compatibility with aiohttp
|
||||
response.status = response.status_code # type: ignore
|
||||
|
||||
data = response.text or None
|
||||
if data and response.headers['Content-Type'] == 'application/json':
|
||||
if (
|
||||
data
|
||||
and response.headers["Content-Type"] == "application/json"
|
||||
):
|
||||
data = json.loads(data)
|
||||
|
||||
remaining = response.headers.get('X-Ratelimit-Remaining')
|
||||
if remaining == '0' and response.status_code != 429:
|
||||
remaining = response.headers.get("X-Ratelimit-Remaining")
|
||||
if remaining == "0" and response.status_code != 429:
|
||||
delta = utils._parse_ratelimit_header(response)
|
||||
_log.debug(
|
||||
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
|
||||
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
|
||||
webhook_id,
|
||||
delta,
|
||||
)
|
||||
lock.delay_by(delta)
|
||||
|
||||
@ -177,11 +209,15 @@ class WebhookAdapter:
|
||||
return data
|
||||
|
||||
if response.status_code == 429:
|
||||
if not response.headers.get('Via'):
|
||||
if not response.headers.get("Via"):
|
||||
raise HTTPException(response, data)
|
||||
|
||||
retry_after: float = data['retry_after'] # type: ignore
|
||||
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
|
||||
retry_after: float = data["retry_after"] # type: ignore
|
||||
_log.warning(
|
||||
"Webhook ID %s is rate limited. Retrying in %.2f seconds",
|
||||
webhook_id,
|
||||
retry_after,
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
continue
|
||||
|
||||
@ -207,7 +243,7 @@ class WebhookAdapter:
|
||||
raise DiscordServerError(response, data)
|
||||
raise HTTPException(response, data)
|
||||
|
||||
raise RuntimeError('Unreachable code in HTTP handling.')
|
||||
raise RuntimeError("Unreachable code in HTTP handling.")
|
||||
|
||||
def delete_webhook(
|
||||
self,
|
||||
@ -217,7 +253,7 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, auth_token=token)
|
||||
|
||||
def delete_webhook_with_token(
|
||||
@ -228,7 +264,12 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"DELETE",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session, reason=reason)
|
||||
|
||||
def edit_webhook(
|
||||
@ -240,8 +281,10 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
|
||||
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(
|
||||
route, session, reason=reason, payload=payload, auth_token=token
|
||||
)
|
||||
|
||||
def edit_webhook_with_token(
|
||||
self,
|
||||
@ -252,7 +295,12 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
reason: Optional[str] = None,
|
||||
):
|
||||
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"PATCH",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session, reason=reason, payload=payload)
|
||||
|
||||
def execute_webhook(
|
||||
@ -267,11 +315,23 @@ class WebhookAdapter:
|
||||
thread_id: Optional[int] = None,
|
||||
wait: bool = False,
|
||||
):
|
||||
params = {'wait': int(wait)}
|
||||
params = {"wait": int(wait)}
|
||||
if thread_id:
|
||||
params['thread_id'] = thread_id
|
||||
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
|
||||
params["thread_id"] = thread_id
|
||||
route = Route(
|
||||
"POST",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(
|
||||
route,
|
||||
session,
|
||||
payload=payload,
|
||||
multipart=multipart,
|
||||
files=files,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def get_webhook_message(
|
||||
self,
|
||||
@ -282,8 +342,8 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
):
|
||||
route = Route(
|
||||
'GET',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"GET",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
@ -302,13 +362,15 @@ class WebhookAdapter:
|
||||
files: Optional[List[File]] = None,
|
||||
):
|
||||
route = Route(
|
||||
'PATCH',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"PATCH",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
)
|
||||
return self.request(route, session, payload=payload, multipart=multipart, files=files)
|
||||
return self.request(
|
||||
route, session, payload=payload, multipart=multipart, files=files
|
||||
)
|
||||
|
||||
def delete_webhook_message(
|
||||
self,
|
||||
@ -319,8 +381,8 @@ class WebhookAdapter:
|
||||
session: Session,
|
||||
):
|
||||
route = Route(
|
||||
'DELETE',
|
||||
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
|
||||
"DELETE",
|
||||
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
message_id=message_id,
|
||||
@ -334,7 +396,7 @@ class WebhookAdapter:
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
|
||||
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
|
||||
return self.request(route, session=session, auth_token=token)
|
||||
|
||||
def fetch_webhook_with_token(
|
||||
@ -344,7 +406,12 @@ class WebhookAdapter:
|
||||
*,
|
||||
session: Session,
|
||||
):
|
||||
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
|
||||
route = Route(
|
||||
"GET",
|
||||
"/webhooks/{webhook_id}/{webhook_token}",
|
||||
webhook_id=webhook_id,
|
||||
webhook_token=token,
|
||||
)
|
||||
return self.request(route, session=session)
|
||||
|
||||
|
||||
@ -516,22 +583,35 @@ class SyncWebhook(BaseWebhook):
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = ('session',)
|
||||
__slots__: Tuple[str, ...] = ("session",)
|
||||
|
||||
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
data: WebhookPayload,
|
||||
session: Session,
|
||||
token: Optional[str] = None,
|
||||
state=None,
|
||||
):
|
||||
super().__init__(data, token, state)
|
||||
self.session = session
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Webhook id={self.id!r}>'
|
||||
return f"<Webhook id={self.id!r}>"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
""":class:`str` : Returns the webhook's url."""
|
||||
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
|
||||
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
|
||||
|
||||
@classmethod
|
||||
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
def partial(
|
||||
cls,
|
||||
id: int,
|
||||
token: str,
|
||||
*,
|
||||
session: Session = MISSING,
|
||||
bot_token: Optional[str] = None,
|
||||
) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook`.
|
||||
|
||||
Parameters
|
||||
@ -556,21 +636,23 @@ class SyncWebhook(BaseWebhook):
|
||||
A partial webhook is just a webhook object with an ID and a token.
|
||||
"""
|
||||
data: WebhookPayload = {
|
||||
'id': id,
|
||||
'type': 1,
|
||||
'token': token,
|
||||
"id": id,
|
||||
"type": 1,
|
||||
"token": token,
|
||||
}
|
||||
import requests
|
||||
|
||||
if session is not MISSING:
|
||||
if not isinstance(session, requests.Session):
|
||||
raise TypeError(f'expected requests.Session not {session.__class__!r}')
|
||||
raise TypeError(f"expected requests.Session not {session.__class__!r}")
|
||||
else:
|
||||
session = requests # type: ignore
|
||||
return cls(data, session, token=bot_token)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
|
||||
def from_url(
|
||||
cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None
|
||||
) -> SyncWebhook:
|
||||
"""Creates a partial :class:`Webhook` from a webhook URL.
|
||||
|
||||
Parameters
|
||||
@ -597,17 +679,20 @@ class SyncWebhook(BaseWebhook):
|
||||
A partial :class:`Webhook`.
|
||||
A partial webhook is just a webhook object with an ID and a token.
|
||||
"""
|
||||
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
|
||||
m = re.search(
|
||||
r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})",
|
||||
url,
|
||||
)
|
||||
if m is None:
|
||||
raise InvalidArgument('Invalid webhook URL given.')
|
||||
raise InvalidArgument("Invalid webhook URL given.")
|
||||
|
||||
data: Dict[str, Any] = m.groupdict()
|
||||
data['type'] = 1
|
||||
data["type"] = 1
|
||||
import requests
|
||||
|
||||
if session is not MISSING:
|
||||
if not isinstance(session, requests.Session):
|
||||
raise TypeError(f'expected requests.Session not {session.__class__!r}')
|
||||
raise TypeError(f"expected requests.Session not {session.__class__!r}")
|
||||
else:
|
||||
session = requests # type: ignore
|
||||
return cls(data, session, token=bot_token) # type: ignore
|
||||
@ -648,9 +733,13 @@ class SyncWebhook(BaseWebhook):
|
||||
if prefer_auth and self.auth_token:
|
||||
data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
|
||||
elif self.token:
|
||||
data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
|
||||
data = adapter.fetch_webhook_with_token(
|
||||
self.id, self.token, session=self.session
|
||||
)
|
||||
else:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
|
||||
|
||||
@ -679,14 +768,20 @@ class SyncWebhook(BaseWebhook):
|
||||
This webhook does not have a token associated with it.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
|
||||
adapter.delete_webhook(
|
||||
self.id, token=self.auth_token, session=self.session, reason=reason
|
||||
)
|
||||
elif self.token:
|
||||
adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason)
|
||||
adapter.delete_webhook_with_token(
|
||||
self.id, self.token, session=self.session, reason=reason
|
||||
)
|
||||
|
||||
def edit(
|
||||
self,
|
||||
@ -731,14 +826,18 @@ class SyncWebhook(BaseWebhook):
|
||||
The newly edited webhook.
|
||||
"""
|
||||
if self.token is None and self.auth_token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
payload = {}
|
||||
if name is not MISSING:
|
||||
payload['name'] = str(name) if name is not None else None
|
||||
payload["name"] = str(name) if name is not None else None
|
||||
|
||||
if avatar is not MISSING:
|
||||
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
|
||||
payload["avatar"] = (
|
||||
utils._bytes_to_base64_data(avatar) if avatar is not None else None
|
||||
)
|
||||
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
|
||||
@ -746,25 +845,45 @@ class SyncWebhook(BaseWebhook):
|
||||
# If a channel is given, always use the authenticated endpoint
|
||||
if channel is not None:
|
||||
if self.auth_token is None:
|
||||
raise InvalidArgument('Editing channel requires authenticated webhook')
|
||||
raise InvalidArgument("Editing channel requires authenticated webhook")
|
||||
|
||||
payload['channel_id'] = channel.id
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
payload["channel_id"] = channel.id
|
||||
data = adapter.edit_webhook(
|
||||
self.id,
|
||||
self.auth_token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if prefer_auth and self.auth_token:
|
||||
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
|
||||
data = adapter.edit_webhook(
|
||||
self.id,
|
||||
self.auth_token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
elif self.token:
|
||||
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
|
||||
data = adapter.edit_webhook_with_token(
|
||||
self.id,
|
||||
self.token,
|
||||
payload=payload,
|
||||
session=self.session,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError('Unreachable code hit: data was not assigned')
|
||||
raise RuntimeError("Unreachable code hit: data was not assigned")
|
||||
|
||||
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
|
||||
return SyncWebhook(
|
||||
data=data, session=self.session, token=self.auth_token, state=self._state
|
||||
)
|
||||
|
||||
def _create_message(self, data):
|
||||
state = _WebhookState(self, parent=self._state)
|
||||
# state may be artificial (unlikely at this point...)
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
|
||||
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
|
||||
# state is artificial
|
||||
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
|
||||
|
||||
@ -887,9 +1006,13 @@ class SyncWebhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(
|
||||
self._state, "allowed_mentions", None
|
||||
)
|
||||
if content is None:
|
||||
content = MISSING
|
||||
|
||||
@ -951,7 +1074,9 @@ class SyncWebhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
data = adapter.get_webhook_message(
|
||||
@ -1015,9 +1140,13 @@ class SyncWebhook(BaseWebhook):
|
||||
"""
|
||||
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
|
||||
previous_mentions: Optional[AllowedMentions] = getattr(
|
||||
self._state, "allowed_mentions", None
|
||||
)
|
||||
params = handle_message_parameters(
|
||||
content=content,
|
||||
file=file,
|
||||
@ -1060,7 +1189,9 @@ class SyncWebhook(BaseWebhook):
|
||||
Deleted a message that is not yours.
|
||||
"""
|
||||
if self.token is None:
|
||||
raise InvalidArgument('This webhook does not have a token associated with it')
|
||||
raise InvalidArgument(
|
||||
"This webhook does not have a token associated with it"
|
||||
)
|
||||
|
||||
adapter: WebhookAdapter = _get_webhook_adapter()
|
||||
adapter.delete_webhook_message(
|
||||
|
@ -41,11 +41,12 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'WidgetChannel',
|
||||
'WidgetMember',
|
||||
'Widget',
|
||||
"WidgetChannel",
|
||||
"WidgetMember",
|
||||
"Widget",
|
||||
)
|
||||
|
||||
|
||||
class WidgetChannel:
|
||||
"""Represents a "partial" widget channel.
|
||||
|
||||
@ -76,7 +77,8 @@ class WidgetChannel:
|
||||
position: :class:`int`
|
||||
The channel's position
|
||||
"""
|
||||
__slots__ = ('id', 'name', 'position')
|
||||
|
||||
__slots__ = ("id", "name", "position")
|
||||
|
||||
def __init__(self, id: int, name: str, position: int) -> None:
|
||||
self.id: int = id
|
||||
@ -87,18 +89,19 @@ class WidgetChannel:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>'
|
||||
return f"<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>"
|
||||
|
||||
@property
|
||||
def mention(self) -> str:
|
||||
""":class:`str`: The string that allows you to mention the channel."""
|
||||
return f'<#{self.id}>'
|
||||
return f"<#{self.id}>"
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
|
||||
return snowflake_time(self.id)
|
||||
|
||||
|
||||
class WidgetMember(BaseUser):
|
||||
"""Represents a "partial" member of the widget's guild.
|
||||
|
||||
@ -147,9 +150,21 @@ class WidgetMember(BaseUser):
|
||||
connected_channel: Optional[:class:`WidgetChannel`]
|
||||
Which channel the member is connected to.
|
||||
"""
|
||||
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
|
||||
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted',
|
||||
'connected_channel')
|
||||
|
||||
__slots__ = (
|
||||
"name",
|
||||
"status",
|
||||
"nick",
|
||||
"avatar",
|
||||
"discriminator",
|
||||
"id",
|
||||
"bot",
|
||||
"activity",
|
||||
"deafened",
|
||||
"suppress",
|
||||
"muted",
|
||||
"connected_channel",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
activity: Optional[Union[BaseActivity, Spotify]]
|
||||
@ -159,17 +174,21 @@ class WidgetMember(BaseUser):
|
||||
*,
|
||||
state: ConnectionState,
|
||||
data: WidgetMemberPayload,
|
||||
connected_channel: Optional[WidgetChannel] = None
|
||||
connected_channel: Optional[WidgetChannel] = None,
|
||||
) -> None:
|
||||
super().__init__(state=state, data=data)
|
||||
self.nick: Optional[str] = data.get('nick')
|
||||
self.status: Status = try_enum(Status, data.get('status'))
|
||||
self.deafened: Optional[bool] = data.get('deaf', False) or data.get('self_deaf', False)
|
||||
self.muted: Optional[bool] = data.get('mute', False) or data.get('self_mute', False)
|
||||
self.suppress: Optional[bool] = data.get('suppress', False)
|
||||
self.nick: Optional[str] = data.get("nick")
|
||||
self.status: Status = try_enum(Status, data.get("status"))
|
||||
self.deafened: Optional[bool] = data.get("deaf", False) or data.get(
|
||||
"self_deaf", False
|
||||
)
|
||||
self.muted: Optional[bool] = data.get("mute", False) or data.get(
|
||||
"self_mute", False
|
||||
)
|
||||
self.suppress: Optional[bool] = data.get("suppress", False)
|
||||
|
||||
try:
|
||||
game = data['game']
|
||||
game = data["game"]
|
||||
except KeyError:
|
||||
activity = None
|
||||
else:
|
||||
@ -190,6 +209,7 @@ class WidgetMember(BaseUser):
|
||||
""":class:`str`: Returns the member's display name."""
|
||||
return self.nick or self.name
|
||||
|
||||
|
||||
class Widget:
|
||||
"""Represents a :class:`Guild` widget.
|
||||
|
||||
@ -227,27 +247,34 @@ class Widget:
|
||||
retrieved is capped.
|
||||
|
||||
"""
|
||||
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
|
||||
|
||||
__slots__ = ("_state", "channels", "_invite", "id", "members", "name")
|
||||
|
||||
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
|
||||
self._state = state
|
||||
self._invite = data['instant_invite']
|
||||
self.name: str = data['name']
|
||||
self.id: int = int(data['id'])
|
||||
self._invite = data["instant_invite"]
|
||||
self.name: str = data["name"]
|
||||
self.id: int = int(data["id"])
|
||||
|
||||
self.channels: List[WidgetChannel] = []
|
||||
for channel in data.get('channels', []):
|
||||
_id = int(channel['id'])
|
||||
self.channels.append(WidgetChannel(id=_id, name=channel['name'], position=channel['position']))
|
||||
for channel in data.get("channels", []):
|
||||
_id = int(channel["id"])
|
||||
self.channels.append(
|
||||
WidgetChannel(
|
||||
id=_id, name=channel["name"], position=channel["position"]
|
||||
)
|
||||
)
|
||||
|
||||
self.members: List[WidgetMember] = []
|
||||
channels = {channel.id: channel for channel in self.channels}
|
||||
for member in data.get('members', []):
|
||||
connected_channel = _get_as_snowflake(member, 'channel_id')
|
||||
for member in data.get("members", []):
|
||||
connected_channel = _get_as_snowflake(member, "channel_id")
|
||||
if connected_channel in channels:
|
||||
connected_channel = channels[connected_channel] # type: ignore
|
||||
elif connected_channel:
|
||||
connected_channel = WidgetChannel(id=connected_channel, name='', position=0)
|
||||
connected_channel = WidgetChannel(
|
||||
id=connected_channel, name="", position=0
|
||||
)
|
||||
|
||||
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore
|
||||
|
||||
@ -260,7 +287,9 @@ class Widget:
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>'
|
||||
return (
|
||||
f"<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>"
|
||||
)
|
||||
|
||||
@property
|
||||
def created_at(self) -> datetime.datetime:
|
||||
|
233
docs/conf.py
233
docs/conf.py
@ -18,45 +18,45 @@ import re
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
sys.path.insert(0, os.path.abspath('..'))
|
||||
sys.path.append(os.path.abspath('extensions'))
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
sys.path.append(os.path.abspath("extensions"))
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
#needs_sphinx = '1.0'
|
||||
# needs_sphinx = '1.0'
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
'builder',
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.extlinks',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.napoleon',
|
||||
'sphinxcontrib_trio',
|
||||
'details',
|
||||
'exception_hierarchy',
|
||||
'attributetable',
|
||||
'resourcelinks',
|
||||
'nitpick_file_ignorer',
|
||||
"builder",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.extlinks",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinxcontrib_trio",
|
||||
"details",
|
||||
"exception_hierarchy",
|
||||
"attributetable",
|
||||
"resourcelinks",
|
||||
"nitpick_file_ignorer",
|
||||
]
|
||||
|
||||
autodoc_member_order = 'bysource'
|
||||
autodoc_typehints = 'none'
|
||||
autodoc_member_order = "bysource"
|
||||
autodoc_typehints = "none"
|
||||
# maybe consider this?
|
||||
# napoleon_attr_annotations = False
|
||||
|
||||
extlinks = {
|
||||
'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'GH-'),
|
||||
"issue": ("https://github.com/Rapptz/discord.py/issues/%s", "GH-"),
|
||||
}
|
||||
|
||||
# Links used for cross-referencing stuff in other documentation
|
||||
intersphinx_mapping = {
|
||||
'py': ('https://docs.python.org/3', None),
|
||||
'aio': ('https://docs.aiohttp.org/en/stable/', None),
|
||||
'req': ('https://docs.python-requests.org/en/latest/', None)
|
||||
"py": ("https://docs.python.org/3", None),
|
||||
"aio": ("https://docs.aiohttp.org/en/stable/", None),
|
||||
"req": ("https://docs.python-requests.org/en/latest/", None),
|
||||
}
|
||||
|
||||
rst_prolog = """
|
||||
@ -67,20 +67,20 @@ rst_prolog = """
|
||||
"""
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# The suffix of source filenames.
|
||||
source_suffix = '.rst'
|
||||
source_suffix = ".rst"
|
||||
|
||||
# The encoding of source files.
|
||||
#source_encoding = 'utf-8-sig'
|
||||
# source_encoding = 'utf-8-sig'
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
master_doc = "index"
|
||||
|
||||
# General information about the project.
|
||||
project = 'discord.py'
|
||||
copyright = '2015-present, Rapptz'
|
||||
project = "discord.py"
|
||||
copyright = "2015-present, Rapptz"
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
@ -88,15 +88,17 @@ copyright = '2015-present, Rapptz'
|
||||
#
|
||||
# The short X.Y version.
|
||||
|
||||
version = ''
|
||||
with open('../discord/__init__.py') as f:
|
||||
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
|
||||
version = ""
|
||||
with open("../discord/__init__.py") as f:
|
||||
version = re.search(
|
||||
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE
|
||||
).group(1)
|
||||
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
release = version
|
||||
|
||||
# This assumes a tag is available for final releases
|
||||
branch = 'master' if version.endswith('a') else 'v' + version
|
||||
branch = "master" if version.endswith("a") else "v" + version
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
@ -105,49 +107,49 @@ branch = 'master' if version.endswith('a') else 'v' + version
|
||||
# Usually you set "language" from the command line for these cases.
|
||||
language = None
|
||||
|
||||
locale_dirs = ['locale/']
|
||||
locale_dirs = ["locale/"]
|
||||
gettext_compact = False
|
||||
|
||||
# There are two options for replacing |today|: either, you set today to some
|
||||
# non-false value, then it is used:
|
||||
#today = ''
|
||||
# today = ''
|
||||
# Else, today_fmt is used as the format for a strftime call.
|
||||
#today_fmt = '%B %d, %Y'
|
||||
# today_fmt = '%B %d, %Y'
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
exclude_patterns = ['_build']
|
||||
exclude_patterns = ["_build"]
|
||||
|
||||
# The reST default role (used for this markup: `text`) to use for all
|
||||
# documents.
|
||||
#default_role = None
|
||||
# default_role = None
|
||||
|
||||
# If true, '()' will be appended to :func: etc. cross-reference text.
|
||||
#add_function_parentheses = True
|
||||
# add_function_parentheses = True
|
||||
|
||||
# If true, the current module name will be prepended to all description
|
||||
# unit titles (such as .. function::).
|
||||
#add_module_names = True
|
||||
# add_module_names = True
|
||||
|
||||
# If true, sectionauthor and moduleauthor directives will be shown in the
|
||||
# output. They are ignored by default.
|
||||
#show_authors = False
|
||||
# show_authors = False
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = 'friendly'
|
||||
pygments_style = "friendly"
|
||||
|
||||
# A list of ignored prefixes for module index sorting.
|
||||
#modindex_common_prefix = []
|
||||
# modindex_common_prefix = []
|
||||
|
||||
# If true, keep warnings as "system message" paragraphs in the built documents.
|
||||
#keep_warnings = False
|
||||
# keep_warnings = False
|
||||
|
||||
|
||||
# Nitpicky mode options
|
||||
nitpick_ignore_files = [
|
||||
"migrating_to_async",
|
||||
"migrating",
|
||||
"whats_new",
|
||||
"migrating_to_async",
|
||||
"migrating",
|
||||
"whats_new",
|
||||
]
|
||||
|
||||
# -- Options for HTML output ----------------------------------------------
|
||||
@ -156,21 +158,21 @@ html_experimental_html5_writer = True
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
html_theme = 'basic'
|
||||
html_theme = "basic"
|
||||
|
||||
html_context = {
|
||||
'discord_invite': 'https://discord.gg/r3sSKJJ',
|
||||
'discord_extensions': [
|
||||
('discord.ext.commands', 'ext/commands'),
|
||||
('discord.ext.tasks', 'ext/tasks'),
|
||||
],
|
||||
"discord_invite": "https://discord.gg/r3sSKJJ",
|
||||
"discord_extensions": [
|
||||
("discord.ext.commands", "ext/commands"),
|
||||
("discord.ext.tasks", "ext/tasks"),
|
||||
],
|
||||
}
|
||||
|
||||
resource_links = {
|
||||
'discord': 'https://discord.gg/r3sSKJJ',
|
||||
'issues': 'https://github.com/Rapptz/discord.py/issues',
|
||||
'discussions': 'https://github.com/Rapptz/discord.py/discussions',
|
||||
'examples': f'https://github.com/Rapptz/discord.py/tree/{branch}/examples',
|
||||
"discord": "https://discord.gg/r3sSKJJ",
|
||||
"issues": "https://github.com/Rapptz/discord.py/issues",
|
||||
"discussions": "https://github.com/Rapptz/discord.py/discussions",
|
||||
"examples": f"https://github.com/Rapptz/discord.py/tree/{branch}/examples",
|
||||
}
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
@ -180,155 +182,143 @@ resource_links = {
|
||||
# }
|
||||
|
||||
# Add any paths that contain custom themes here, relative to this directory.
|
||||
#html_theme_path = []
|
||||
# html_theme_path = []
|
||||
|
||||
# The name for this set of Sphinx documents. If None, it defaults to
|
||||
# "<project> v<release> documentation".
|
||||
#html_title = None
|
||||
# html_title = None
|
||||
|
||||
# A shorter title for the navigation bar. Default is the same as html_title.
|
||||
#html_short_title = None
|
||||
# html_short_title = None
|
||||
|
||||
# The name of an image file (relative to this directory) to place at the top
|
||||
# of the sidebar.
|
||||
#html_logo = None
|
||||
# html_logo = None
|
||||
|
||||
# The name of an image file (within the static path) to use as favicon of the
|
||||
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
|
||||
# pixels large.
|
||||
html_favicon = './images/discord_py_logo.ico'
|
||||
html_favicon = "./images/discord_py_logo.ico"
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
# Add any extra paths that contain custom files (such as robots.txt or
|
||||
# .htaccess) here, relative to this directory. These files are copied
|
||||
# directly to the root of the documentation.
|
||||
#html_extra_path = []
|
||||
# html_extra_path = []
|
||||
|
||||
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
|
||||
# using the given strftime format.
|
||||
#html_last_updated_fmt = '%b %d, %Y'
|
||||
# html_last_updated_fmt = '%b %d, %Y'
|
||||
|
||||
# If true, SmartyPants will be used to convert quotes and dashes to
|
||||
# typographically correct entities.
|
||||
#html_use_smartypants = True
|
||||
# html_use_smartypants = True
|
||||
|
||||
# Custom sidebar templates, maps document names to template names.
|
||||
#html_sidebars = {}
|
||||
# html_sidebars = {}
|
||||
|
||||
# Additional templates that should be rendered to pages, maps page names to
|
||||
# template names.
|
||||
#html_additional_pages = {}
|
||||
# html_additional_pages = {}
|
||||
|
||||
# If false, no module index is generated.
|
||||
#html_domain_indices = True
|
||||
# html_domain_indices = True
|
||||
|
||||
# If false, no index is generated.
|
||||
#html_use_index = True
|
||||
# html_use_index = True
|
||||
|
||||
# If true, the index is split into individual pages for each letter.
|
||||
#html_split_index = False
|
||||
# html_split_index = False
|
||||
|
||||
# If true, links to the reST sources are added to the pages.
|
||||
#html_show_sourcelink = True
|
||||
# html_show_sourcelink = True
|
||||
|
||||
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
|
||||
#html_show_sphinx = True
|
||||
# html_show_sphinx = True
|
||||
|
||||
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
|
||||
#html_show_copyright = True
|
||||
# html_show_copyright = True
|
||||
|
||||
# If true, an OpenSearch description file will be output, and all pages will
|
||||
# contain a <link> tag referring to it. The value of this option must be the
|
||||
# base URL from which the finished HTML is served.
|
||||
#html_use_opensearch = ''
|
||||
# html_use_opensearch = ''
|
||||
|
||||
# This is the file name suffix for HTML files (e.g. ".xhtml").
|
||||
#html_file_suffix = None
|
||||
# html_file_suffix = None
|
||||
|
||||
# Language to be used for generating the HTML full-text search index.
|
||||
# Sphinx supports the following languages:
|
||||
# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
|
||||
# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
|
||||
#html_search_language = 'en'
|
||||
# html_search_language = 'en'
|
||||
|
||||
# A dictionary with options for the search language support, empty by default.
|
||||
# Now only 'ja' uses this config value
|
||||
#html_search_options = {'type': 'default'}
|
||||
# html_search_options = {'type': 'default'}
|
||||
|
||||
# The name of a javascript file (relative to the configuration directory) that
|
||||
# implements a search results scorer. If empty, the default will be used.
|
||||
html_search_scorer = '_static/scorer.js'
|
||||
html_search_scorer = "_static/scorer.js"
|
||||
|
||||
html_js_files = [
|
||||
'custom.js',
|
||||
'settings.js',
|
||||
'copy.js',
|
||||
'sidebar.js'
|
||||
]
|
||||
html_js_files = ["custom.js", "settings.js", "copy.js", "sidebar.js"]
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'discord.pydoc'
|
||||
htmlhelp_basename = "discord.pydoc"
|
||||
|
||||
# -- Options for LaTeX output ---------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#'figure_align': 'htbp',
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#'papersize': 'letterpaper',
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#'pointsize': '10pt',
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#'preamble': '',
|
||||
# Latex figure (float) alignment
|
||||
#'figure_align': 'htbp',
|
||||
}
|
||||
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
('index', 'discord.py.tex', 'discord.py Documentation',
|
||||
'Rapptz', 'manual'),
|
||||
("index", "discord.py.tex", "discord.py Documentation", "Rapptz", "manual"),
|
||||
]
|
||||
|
||||
# The name of an image file (relative to this directory) to place at the top of
|
||||
# the title page.
|
||||
#latex_logo = None
|
||||
# latex_logo = None
|
||||
|
||||
# For "manual" documents, if this is true, then toplevel headings are parts,
|
||||
# not chapters.
|
||||
#latex_use_parts = False
|
||||
# latex_use_parts = False
|
||||
|
||||
# If true, show page references after internal links.
|
||||
#latex_show_pagerefs = False
|
||||
# latex_show_pagerefs = False
|
||||
|
||||
# If true, show URL addresses after external links.
|
||||
#latex_show_urls = False
|
||||
# latex_show_urls = False
|
||||
|
||||
# Documents to append as an appendix to all manuals.
|
||||
#latex_appendices = []
|
||||
# latex_appendices = []
|
||||
|
||||
# If false, no module index is generated.
|
||||
#latex_domain_indices = True
|
||||
# latex_domain_indices = True
|
||||
|
||||
|
||||
# -- Options for manual page output ---------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
('index', 'discord.py', 'discord.py Documentation',
|
||||
['Rapptz'], 1)
|
||||
]
|
||||
man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
|
||||
|
||||
# If true, show URL addresses after external links.
|
||||
#man_show_urls = False
|
||||
# man_show_urls = False
|
||||
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
@ -337,25 +327,32 @@ man_pages = [
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
('index', 'discord.py', 'discord.py Documentation',
|
||||
'Rapptz', 'discord.py', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
(
|
||||
"index",
|
||||
"discord.py",
|
||||
"discord.py Documentation",
|
||||
"Rapptz",
|
||||
"discord.py",
|
||||
"One line description of project.",
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
||||
# Documents to append as an appendix to all manuals.
|
||||
#texinfo_appendices = []
|
||||
# texinfo_appendices = []
|
||||
|
||||
# If false, no module index is generated.
|
||||
#texinfo_domain_indices = True
|
||||
# texinfo_domain_indices = True
|
||||
|
||||
# How to display URL addresses: 'footnote', 'no', or 'inline'.
|
||||
#texinfo_show_urls = 'footnote'
|
||||
# texinfo_show_urls = 'footnote'
|
||||
|
||||
# If true, do not generate a @detailmenu in the "Top" node's menu.
|
||||
#texinfo_no_detailmenu = False
|
||||
# texinfo_no_detailmenu = False
|
||||
|
||||
|
||||
def setup(app):
|
||||
if app.config.language == 'ja':
|
||||
app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None)
|
||||
app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg'
|
||||
app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg'
|
||||
if app.config.language == "ja":
|
||||
app.config.intersphinx_mapping["py"] = ("https://docs.python.org/ja/3", None)
|
||||
app.config.html_context["discord_invite"] = "https://discord.gg/nXzj3dg"
|
||||
app.config.resource_links["discord"] = "https://discord.gg/nXzj3dg"
|
||||
|
@ -9,60 +9,78 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
class attributetable(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
class attributetablecolumn(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
class attributetabletitle(nodes.TextElement):
|
||||
pass
|
||||
|
||||
|
||||
class attributetableplaceholder(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
class attributetablebadge(nodes.TextElement):
|
||||
pass
|
||||
|
||||
|
||||
class attributetable_item(nodes.Part, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
def visit_attributetable_node(self, node):
|
||||
class_ = node["python-class"]
|
||||
self.body.append(f'<div class="py-attribute-table" data-move-to-id="{class_}">')
|
||||
|
||||
|
||||
def visit_attributetablecolumn_node(self, node):
|
||||
self.body.append(self.starttag(node, 'div', CLASS='py-attribute-table-column'))
|
||||
self.body.append(self.starttag(node, "div", CLASS="py-attribute-table-column"))
|
||||
|
||||
|
||||
def visit_attributetabletitle_node(self, node):
|
||||
self.body.append(self.starttag(node, 'span'))
|
||||
self.body.append(self.starttag(node, "span"))
|
||||
|
||||
|
||||
def visit_attributetablebadge_node(self, node):
|
||||
attributes = {
|
||||
'class': 'py-attribute-table-badge',
|
||||
'title': node['badge-type'],
|
||||
"class": "py-attribute-table-badge",
|
||||
"title": node["badge-type"],
|
||||
}
|
||||
self.body.append(self.starttag(node, 'span', **attributes))
|
||||
self.body.append(self.starttag(node, "span", **attributes))
|
||||
|
||||
|
||||
def visit_attributetable_item_node(self, node):
|
||||
self.body.append(self.starttag(node, 'li', CLASS='py-attribute-table-entry'))
|
||||
self.body.append(self.starttag(node, "li", CLASS="py-attribute-table-entry"))
|
||||
|
||||
|
||||
def depart_attributetable_node(self, node):
|
||||
self.body.append('</div>')
|
||||
self.body.append("</div>")
|
||||
|
||||
|
||||
def depart_attributetablecolumn_node(self, node):
|
||||
self.body.append('</div>')
|
||||
self.body.append("</div>")
|
||||
|
||||
|
||||
def depart_attributetabletitle_node(self, node):
|
||||
self.body.append('</span>')
|
||||
self.body.append("</span>")
|
||||
|
||||
|
||||
def depart_attributetablebadge_node(self, node):
|
||||
self.body.append('</span>')
|
||||
self.body.append("</span>")
|
||||
|
||||
|
||||
def depart_attributetable_item_node(self, node):
|
||||
self.body.append('</li>')
|
||||
self.body.append("</li>")
|
||||
|
||||
|
||||
_name_parser_regex = re.compile(r"(?P<module>[\w.]+\.)?(?P<name>\w+)")
|
||||
|
||||
_name_parser_regex = re.compile(r'(?P<module>[\w.]+\.)?(?P<name>\w+)')
|
||||
|
||||
class PyAttributeTable(SphinxDirective):
|
||||
has_content = False
|
||||
@ -74,13 +92,15 @@ class PyAttributeTable(SphinxDirective):
|
||||
def parse_name(self, content):
|
||||
path, name = _name_parser_regex.match(content).groups()
|
||||
if path:
|
||||
modulename = path.rstrip('.')
|
||||
modulename = path.rstrip(".")
|
||||
else:
|
||||
modulename = self.env.temp_data.get('autodoc:module')
|
||||
modulename = self.env.temp_data.get("autodoc:module")
|
||||
if not modulename:
|
||||
modulename = self.env.ref_context.get('py:module')
|
||||
modulename = self.env.ref_context.get("py:module")
|
||||
if modulename is None:
|
||||
raise RuntimeError('modulename somehow None for %s in %s.' % (content, self.env.docname))
|
||||
raise RuntimeError(
|
||||
"modulename somehow None for %s in %s." % (content, self.env.docname)
|
||||
)
|
||||
|
||||
return modulename, name
|
||||
|
||||
@ -112,29 +132,33 @@ class PyAttributeTable(SphinxDirective):
|
||||
replaced.
|
||||
"""
|
||||
content = self.arguments[0].strip()
|
||||
node = attributetableplaceholder('')
|
||||
node = attributetableplaceholder("")
|
||||
modulename, name = self.parse_name(content)
|
||||
node['python-doc'] = self.env.docname
|
||||
node['python-module'] = modulename
|
||||
node['python-class'] = name
|
||||
node['python-full-name'] = f'{modulename}.{name}'
|
||||
node["python-doc"] = self.env.docname
|
||||
node["python-module"] = modulename
|
||||
node["python-class"] = name
|
||||
node["python-full-name"] = f"{modulename}.{name}"
|
||||
return [node]
|
||||
|
||||
|
||||
def build_lookup_table(env):
|
||||
# Given an environment, load up a lookup table of
|
||||
# full-class-name: objects
|
||||
result = {}
|
||||
domain = env.domains['py']
|
||||
domain = env.domains["py"]
|
||||
|
||||
ignored = {
|
||||
'data', 'exception', 'module', 'class',
|
||||
"data",
|
||||
"exception",
|
||||
"module",
|
||||
"class",
|
||||
}
|
||||
|
||||
for (fullname, _, objtype, docname, _, _) in domain.get_objects():
|
||||
if objtype in ignored:
|
||||
continue
|
||||
|
||||
classname, _, child = fullname.rpartition('.')
|
||||
classname, _, child = fullname.rpartition(".")
|
||||
try:
|
||||
result[classname].append(child)
|
||||
except KeyError:
|
||||
@ -143,36 +167,46 @@ def build_lookup_table(env):
|
||||
return result
|
||||
|
||||
|
||||
TableElement = namedtuple('TableElement', 'fullname label badge')
|
||||
TableElement = namedtuple("TableElement", "fullname label badge")
|
||||
|
||||
|
||||
def process_attributetable(app, doctree, fromdocname):
|
||||
env = app.builder.env
|
||||
|
||||
lookup = build_lookup_table(env)
|
||||
for node in doctree.traverse(attributetableplaceholder):
|
||||
modulename, classname, fullname = node['python-module'], node['python-class'], node['python-full-name']
|
||||
modulename, classname, fullname = (
|
||||
node["python-module"],
|
||||
node["python-class"],
|
||||
node["python-full-name"],
|
||||
)
|
||||
groups = get_class_results(lookup, modulename, classname, fullname)
|
||||
table = attributetable('')
|
||||
table = attributetable("")
|
||||
for label, subitems in groups.items():
|
||||
if not subitems:
|
||||
continue
|
||||
table.append(class_results_to_node(label, sorted(subitems, key=lambda c: c.label)))
|
||||
table.append(
|
||||
class_results_to_node(label, sorted(subitems, key=lambda c: c.label))
|
||||
)
|
||||
|
||||
table['python-class'] = fullname
|
||||
table["python-class"] = fullname
|
||||
|
||||
if not table:
|
||||
node.replace_self([])
|
||||
else:
|
||||
node.replace_self([table])
|
||||
|
||||
|
||||
def get_class_results(lookup, modulename, name, fullname):
|
||||
module = importlib.import_module(modulename)
|
||||
cls = getattr(module, name)
|
||||
|
||||
groups = OrderedDict([
|
||||
(_('Attributes'), []),
|
||||
(_('Methods'), []),
|
||||
])
|
||||
groups = OrderedDict(
|
||||
[
|
||||
(_("Attributes"), []),
|
||||
(_("Methods"), []),
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
members = lookup[fullname]
|
||||
@ -180,8 +214,8 @@ def get_class_results(lookup, modulename, name, fullname):
|
||||
return groups
|
||||
|
||||
for attr in members:
|
||||
attrlookup = f'{fullname}.{attr}'
|
||||
key = _('Attributes')
|
||||
attrlookup = f"{fullname}.{attr}"
|
||||
key = _("Attributes")
|
||||
badge = None
|
||||
label = attr
|
||||
value = None
|
||||
@ -192,53 +226,73 @@ def get_class_results(lookup, modulename, name, fullname):
|
||||
break
|
||||
|
||||
if value is not None:
|
||||
doc = value.__doc__ or ''
|
||||
if inspect.iscoroutinefunction(value) or doc.startswith('|coro|'):
|
||||
key = _('Methods')
|
||||
badge = attributetablebadge('async', 'async')
|
||||
badge['badge-type'] = _('coroutine')
|
||||
doc = value.__doc__ or ""
|
||||
if inspect.iscoroutinefunction(value) or doc.startswith("|coro|"):
|
||||
key = _("Methods")
|
||||
badge = attributetablebadge("async", "async")
|
||||
badge["badge-type"] = _("coroutine")
|
||||
elif isinstance(value, classmethod):
|
||||
key = _('Methods')
|
||||
label = f'{name}.{attr}'
|
||||
badge = attributetablebadge('cls', 'cls')
|
||||
badge['badge-type'] = _('classmethod')
|
||||
key = _("Methods")
|
||||
label = f"{name}.{attr}"
|
||||
badge = attributetablebadge("cls", "cls")
|
||||
badge["badge-type"] = _("classmethod")
|
||||
elif inspect.isfunction(value):
|
||||
if doc.startswith(('A decorator', 'A shortcut decorator')):
|
||||
if doc.startswith(("A decorator", "A shortcut decorator")):
|
||||
# finicky but surprisingly consistent
|
||||
badge = attributetablebadge('@', '@')
|
||||
badge['badge-type'] = _('decorator')
|
||||
key = _('Methods')
|
||||
badge = attributetablebadge("@", "@")
|
||||
badge["badge-type"] = _("decorator")
|
||||
key = _("Methods")
|
||||
else:
|
||||
key = _('Methods')
|
||||
badge = attributetablebadge('def', 'def')
|
||||
badge['badge-type'] = _('method')
|
||||
key = _("Methods")
|
||||
badge = attributetablebadge("def", "def")
|
||||
badge["badge-type"] = _("method")
|
||||
|
||||
groups[key].append(TableElement(fullname=attrlookup, label=label, badge=badge))
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
def class_results_to_node(key, elements):
|
||||
title = attributetabletitle(key, key)
|
||||
ul = nodes.bullet_list('')
|
||||
ul = nodes.bullet_list("")
|
||||
for element in elements:
|
||||
ref = nodes.reference('', '', internal=True,
|
||||
refuri='#' + element.fullname,
|
||||
anchorname='',
|
||||
*[nodes.Text(element.label)])
|
||||
para = addnodes.compact_paragraph('', '', ref)
|
||||
ref = nodes.reference(
|
||||
"",
|
||||
"",
|
||||
internal=True,
|
||||
refuri="#" + element.fullname,
|
||||
anchorname="",
|
||||
*[nodes.Text(element.label)],
|
||||
)
|
||||
para = addnodes.compact_paragraph("", "", ref)
|
||||
if element.badge is not None:
|
||||
ul.append(attributetable_item('', element.badge, para))
|
||||
ul.append(attributetable_item("", element.badge, para))
|
||||
else:
|
||||
ul.append(attributetable_item('', para))
|
||||
ul.append(attributetable_item("", para))
|
||||
|
||||
return attributetablecolumn("", title, ul)
|
||||
|
||||
return attributetablecolumn('', title, ul)
|
||||
|
||||
def setup(app):
|
||||
app.add_directive('attributetable', PyAttributeTable)
|
||||
app.add_node(attributetable, html=(visit_attributetable_node, depart_attributetable_node))
|
||||
app.add_node(attributetablecolumn, html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node))
|
||||
app.add_node(attributetabletitle, html=(visit_attributetabletitle_node, depart_attributetabletitle_node))
|
||||
app.add_node(attributetablebadge, html=(visit_attributetablebadge_node, depart_attributetablebadge_node))
|
||||
app.add_node(attributetable_item, html=(visit_attributetable_item_node, depart_attributetable_item_node))
|
||||
app.add_directive("attributetable", PyAttributeTable)
|
||||
app.add_node(
|
||||
attributetable, html=(visit_attributetable_node, depart_attributetable_node)
|
||||
)
|
||||
app.add_node(
|
||||
attributetablecolumn,
|
||||
html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node),
|
||||
)
|
||||
app.add_node(
|
||||
attributetabletitle,
|
||||
html=(visit_attributetabletitle_node, depart_attributetabletitle_node),
|
||||
)
|
||||
app.add_node(
|
||||
attributetablebadge,
|
||||
html=(visit_attributetablebadge_node, depart_attributetablebadge_node),
|
||||
)
|
||||
app.add_node(
|
||||
attributetable_item,
|
||||
html=(visit_attributetable_item_node, depart_attributetable_item_node),
|
||||
)
|
||||
app.add_node(attributetableplaceholder)
|
||||
app.connect('doctree-resolved', process_attributetable)
|
||||
app.connect("doctree-resolved", process_attributetable)
|
||||
|
@ -2,15 +2,15 @@ from sphinx.builders.html import StandaloneHTMLBuilder
|
||||
from sphinx.environment.adapters.indexentries import IndexEntries
|
||||
from sphinx.writers.html5 import HTML5Translator
|
||||
|
||||
|
||||
class DPYHTML5Translator(HTML5Translator):
|
||||
def visit_section(self, node):
|
||||
self.section_level += 1
|
||||
self.body.append(
|
||||
self.starttag(node, 'section'))
|
||||
self.body.append(self.starttag(node, "section"))
|
||||
|
||||
def depart_section(self, node):
|
||||
self.section_level -= 1
|
||||
self.body.append('</section>\n')
|
||||
self.body.append("</section>\n")
|
||||
|
||||
def visit_table(self, node):
|
||||
self.body.append('<div class="table-wrapper">')
|
||||
@ -18,7 +18,8 @@ class DPYHTML5Translator(HTML5Translator):
|
||||
|
||||
def depart_table(self, node):
|
||||
super().depart_table(node)
|
||||
self.body.append('</div>')
|
||||
self.body.append("</div>")
|
||||
|
||||
|
||||
class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
|
||||
# This is mostly copy pasted from Sphinx.
|
||||
@ -28,50 +29,56 @@ class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
|
||||
genindex = IndexEntries(self.env).create_index(self, group_entries=False)
|
||||
indexcounts = []
|
||||
for _k, entries in genindex:
|
||||
indexcounts.append(sum(1 + len(subitems)
|
||||
for _, (_, subitems, _) in entries))
|
||||
indexcounts.append(
|
||||
sum(1 + len(subitems) for _, (_, subitems, _) in entries)
|
||||
)
|
||||
|
||||
genindexcontext = {
|
||||
'genindexentries': genindex,
|
||||
'genindexcounts': indexcounts,
|
||||
'split_index': self.config.html_split_index,
|
||||
"genindexentries": genindex,
|
||||
"genindexcounts": indexcounts,
|
||||
"split_index": self.config.html_split_index,
|
||||
}
|
||||
|
||||
if self.config.html_split_index:
|
||||
self.handle_page('genindex', genindexcontext,
|
||||
'genindex-split.html')
|
||||
self.handle_page('genindex-all', genindexcontext,
|
||||
'genindex.html')
|
||||
self.handle_page("genindex", genindexcontext, "genindex-split.html")
|
||||
self.handle_page("genindex-all", genindexcontext, "genindex.html")
|
||||
for (key, entries), count in zip(genindex, indexcounts):
|
||||
ctx = {'key': key, 'entries': entries, 'count': count,
|
||||
'genindexentries': genindex}
|
||||
self.handle_page('genindex-' + key, ctx,
|
||||
'genindex-single.html')
|
||||
ctx = {
|
||||
"key": key,
|
||||
"entries": entries,
|
||||
"count": count,
|
||||
"genindexentries": genindex,
|
||||
}
|
||||
self.handle_page("genindex-" + key, ctx, "genindex-single.html")
|
||||
else:
|
||||
self.handle_page('genindex', genindexcontext, 'genindex.html')
|
||||
self.handle_page("genindex", genindexcontext, "genindex.html")
|
||||
|
||||
|
||||
def add_custom_jinja2(app):
|
||||
env = app.builder.templates.environment
|
||||
env.tests['prefixedwith'] = str.startswith
|
||||
env.tests['suffixedwith'] = str.endswith
|
||||
env.tests["prefixedwith"] = str.startswith
|
||||
env.tests["suffixedwith"] = str.endswith
|
||||
|
||||
|
||||
def add_builders(app):
|
||||
"""This is necessary because RTD injects their own for some reason."""
|
||||
app.set_translator('html', DPYHTML5Translator, override=True)
|
||||
app.set_translator("html", DPYHTML5Translator, override=True)
|
||||
app.add_builder(DPYStandaloneHTMLBuilder, override=True)
|
||||
|
||||
try:
|
||||
original = app.registry.builders['readthedocs']
|
||||
original = app.registry.builders["readthedocs"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
injected_mro = tuple(base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder
|
||||
for base in original.mro()[1:])
|
||||
new_builder = type(original.__name__, injected_mro, {'name': 'readthedocs'})
|
||||
app.set_translator('readthedocs', DPYHTML5Translator, override=True)
|
||||
injected_mro = tuple(
|
||||
base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder
|
||||
for base in original.mro()[1:]
|
||||
)
|
||||
new_builder = type(original.__name__, injected_mro, {"name": "readthedocs"})
|
||||
app.set_translator("readthedocs", DPYHTML5Translator, override=True)
|
||||
app.add_builder(new_builder, override=True)
|
||||
|
||||
|
||||
def setup(app):
|
||||
add_builders(app)
|
||||
app.connect('builder-inited', add_custom_jinja2)
|
||||
app.connect("builder-inited", add_custom_jinja2)
|
||||
|
@ -3,32 +3,43 @@ from docutils.parsers.rst import states, directives
|
||||
from docutils.parsers.rst.roles import set_classes
|
||||
from docutils import nodes
|
||||
|
||||
|
||||
class details(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
class summary(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
def visit_details_node(self, node):
|
||||
self.body.append(self.starttag(node, 'details', CLASS=node.attributes.get('class', '')))
|
||||
self.body.append(
|
||||
self.starttag(node, "details", CLASS=node.attributes.get("class", ""))
|
||||
)
|
||||
|
||||
|
||||
def visit_summary_node(self, node):
|
||||
self.body.append(self.starttag(node, 'summary', CLASS=node.attributes.get('summary-class', '')))
|
||||
self.body.append(
|
||||
self.starttag(node, "summary", CLASS=node.attributes.get("summary-class", ""))
|
||||
)
|
||||
self.body.append(node.rawsource)
|
||||
|
||||
|
||||
def depart_details_node(self, node):
|
||||
self.body.append('</details>\n')
|
||||
self.body.append("</details>\n")
|
||||
|
||||
|
||||
def depart_summary_node(self, node):
|
||||
self.body.append('</summary>')
|
||||
self.body.append("</summary>")
|
||||
|
||||
|
||||
class DetailsDirective(Directive):
|
||||
final_argument_whitespace = True
|
||||
optional_arguments = 1
|
||||
|
||||
option_spec = {
|
||||
'class': directives.class_option,
|
||||
'summary-class': directives.class_option,
|
||||
"class": directives.class_option,
|
||||
"summary-class": directives.class_option,
|
||||
}
|
||||
|
||||
has_content = True
|
||||
@ -37,19 +48,22 @@ class DetailsDirective(Directive):
|
||||
set_classes(self.options)
|
||||
self.assert_has_content()
|
||||
|
||||
text = '\n'.join(self.content)
|
||||
text = "\n".join(self.content)
|
||||
node = details(text, **self.options)
|
||||
|
||||
if self.arguments:
|
||||
summary_node = summary(self.arguments[0], **self.options)
|
||||
summary_node.source, summary_node.line = self.state_machine.get_source_and_line(self.lineno)
|
||||
(
|
||||
summary_node.source,
|
||||
summary_node.line,
|
||||
) = self.state_machine.get_source_and_line(self.lineno)
|
||||
node += summary_node
|
||||
|
||||
self.state.nested_parse(self.content, self.content_offset, node)
|
||||
return [node]
|
||||
|
||||
|
||||
def setup(app):
|
||||
app.add_node(details, html=(visit_details_node, depart_details_node))
|
||||
app.add_node(summary, html=(visit_summary_node, depart_summary_node))
|
||||
app.add_directive('details', DetailsDirective)
|
||||
|
||||
app.add_directive("details", DetailsDirective)
|
||||
|
@ -4,24 +4,32 @@ from docutils.parsers.rst.roles import set_classes
|
||||
from docutils import nodes
|
||||
from sphinx.locale import _
|
||||
|
||||
|
||||
class exception_hierarchy(nodes.General, nodes.Element):
|
||||
pass
|
||||
|
||||
|
||||
def visit_exception_hierarchy_node(self, node):
|
||||
self.body.append(self.starttag(node, 'div', CLASS='exception-hierarchy-content'))
|
||||
self.body.append(self.starttag(node, "div", CLASS="exception-hierarchy-content"))
|
||||
|
||||
|
||||
def depart_exception_hierarchy_node(self, node):
|
||||
self.body.append('</div>\n')
|
||||
self.body.append("</div>\n")
|
||||
|
||||
|
||||
class ExceptionHierarchyDirective(Directive):
|
||||
has_content = True
|
||||
|
||||
def run(self):
|
||||
self.assert_has_content()
|
||||
node = exception_hierarchy('\n'.join(self.content))
|
||||
node = exception_hierarchy("\n".join(self.content))
|
||||
self.state.nested_parse(self.content, self.content_offset, node)
|
||||
return [node]
|
||||
|
||||
|
||||
def setup(app):
|
||||
app.add_node(exception_hierarchy, html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node))
|
||||
app.add_directive('exception_hierarchy', ExceptionHierarchyDirective)
|
||||
app.add_node(
|
||||
exception_hierarchy,
|
||||
html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node),
|
||||
)
|
||||
app.add_directive("exception_hierarchy", ExceptionHierarchyDirective)
|
||||
|
@ -5,18 +5,20 @@ from sphinx.util import logging as sphinx_logging
|
||||
|
||||
|
||||
class NitpickFileIgnorer(logging.Filter):
|
||||
|
||||
def __init__(self, app: Sphinx) -> None:
|
||||
self.app = app
|
||||
super().__init__()
|
||||
|
||||
def filter(self, record: sphinx_logging.SphinxLogRecord) -> bool:
|
||||
if getattr(record, 'type', None) == 'ref':
|
||||
return record.location.get('refdoc') not in self.app.config.nitpick_ignore_files
|
||||
if getattr(record, "type", None) == "ref":
|
||||
return (
|
||||
record.location.get("refdoc")
|
||||
not in self.app.config.nitpick_ignore_files
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def setup(app: Sphinx):
|
||||
app.add_config_value('nitpick_ignore_files', [], '')
|
||||
app.add_config_value("nitpick_ignore_files", [], "")
|
||||
f = NitpickFileIgnorer(app)
|
||||
sphinx_logging.getLogger('sphinx.transforms.post_transforms').logger.addFilter(f)
|
||||
sphinx_logging.getLogger("sphinx.transforms.post_transforms").logger.addFilter(f)
|
||||
|
@ -22,7 +22,7 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
|
||||
lineno: int,
|
||||
inliner: Inliner,
|
||||
options: Dict = {},
|
||||
content: List[str] = []
|
||||
content: List[str] = [],
|
||||
) -> Tuple[List[Node], List[system_message]]:
|
||||
|
||||
text = utils.unescape(text)
|
||||
@ -32,13 +32,15 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
|
||||
title = full_url
|
||||
pnode = nodes.reference(title, title, internal=False, refuri=full_url)
|
||||
return [pnode], []
|
||||
|
||||
return role
|
||||
|
||||
|
||||
def add_link_role(app: Sphinx) -> None:
|
||||
app.add_role('resource', make_link_role(app.config.resource_links))
|
||||
app.add_role("resource", make_link_role(app.config.resource_links))
|
||||
|
||||
|
||||
def setup(app: Sphinx) -> Dict[str, Any]:
|
||||
app.add_config_value('resource_links', {}, 'env')
|
||||
app.connect('builder-inited', add_link_role)
|
||||
return {'version': sphinx.__display_version__, 'parallel_read_safe': True}
|
||||
app.add_config_value("resource_links", {}, "env")
|
||||
app.connect("builder-inited", add_link_role)
|
||||
return {"version": sphinx.__display_version__, "parallel_read_safe": True}
|
||||
|
@ -2,6 +2,7 @@ from discord.ext import tasks
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -13,18 +14,19 @@ class MyClient(discord.Client):
|
||||
self.my_background_task.start()
|
||||
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
@tasks.loop(seconds=60) # task runs every 60 seconds
|
||||
@tasks.loop(seconds=60) # task runs every 60 seconds
|
||||
async def my_background_task(self):
|
||||
channel = self.get_channel(1234567) # channel ID goes here
|
||||
channel = self.get_channel(1234567) # channel ID goes here
|
||||
self.counter += 1
|
||||
await channel.send(self.counter)
|
||||
|
||||
@my_background_task.before_loop
|
||||
async def before_my_task(self):
|
||||
await self.wait_until_ready() # wait until the bot logs in
|
||||
await self.wait_until_ready() # wait until the bot logs in
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -1,6 +1,7 @@
|
||||
import discord
|
||||
import asyncio
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -9,18 +10,18 @@ class MyClient(discord.Client):
|
||||
self.bg_task = self.loop.create_task(self.my_background_task())
|
||||
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def my_background_task(self):
|
||||
await self.wait_until_ready()
|
||||
counter = 0
|
||||
channel = self.get_channel(1234567) # channel ID goes here
|
||||
channel = self.get_channel(1234567) # channel ID goes here
|
||||
while not self.is_closed():
|
||||
counter += 1
|
||||
await channel.send(counter)
|
||||
await asyncio.sleep(60) # task runs every 60 seconds
|
||||
await asyncio.sleep(60) # task runs every 60 seconds
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -4,51 +4,58 @@ import discord
|
||||
from discord.ext import commands
|
||||
import random
|
||||
|
||||
description = '''An example bot to showcase the discord.ext.commands extension
|
||||
description = """An example bot to showcase the discord.ext.commands extension
|
||||
module.
|
||||
|
||||
There are a number of utility commands being showcased here.'''
|
||||
There are a number of utility commands being showcased here."""
|
||||
|
||||
intents = discord.Intents(guilds=True, messages=True, members=True)
|
||||
bot = commands.Bot(command_prefix='t-', description=description, intents=intents)
|
||||
bot = commands.Bot(command_prefix="t-", description=description, intents=intents)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
print(f'Logged in as {bot.user} (ID: {bot.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {bot.user} (ID: {bot.user.id})")
|
||||
print("------")
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def add(ctx, left: int, right: int):
|
||||
"""Adds two numbers together."""
|
||||
await ctx.send(left + right)
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def roll(ctx, dice: str):
|
||||
"""Rolls a dice in NdN format."""
|
||||
try:
|
||||
rolls, limit = map(int, dice.split('d'))
|
||||
rolls, limit = map(int, dice.split("d"))
|
||||
except Exception:
|
||||
await ctx.send('Format has to be in NdN!')
|
||||
await ctx.send("Format has to be in NdN!")
|
||||
return
|
||||
|
||||
result = ', '.join(str(random.randint(1, limit)) for r in range(rolls))
|
||||
result = ", ".join(str(random.randint(1, limit)) for r in range(rolls))
|
||||
await ctx.send(result)
|
||||
|
||||
@bot.command(description='For when you wanna settle the score some other way')
|
||||
|
||||
@bot.command(description="For when you wanna settle the score some other way")
|
||||
async def choose(ctx, *choices: str):
|
||||
"""Chooses between multiple choices."""
|
||||
await ctx.send(random.choice(choices))
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def repeat(ctx, times: int, content='repeating...'):
|
||||
async def repeat(ctx, times: int, content="repeating..."):
|
||||
"""Repeats a message multiple times."""
|
||||
for i in range(times):
|
||||
await ctx.send(content)
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def joined(ctx, member: discord.Member):
|
||||
"""Says when a member joined."""
|
||||
await ctx.send(f'{member.name} joined in {member.joined_at}')
|
||||
await ctx.send(f"{member.name} joined in {member.joined_at}")
|
||||
|
||||
|
||||
@bot.group()
|
||||
async def cool(ctx):
|
||||
@ -57,11 +64,13 @@ async def cool(ctx):
|
||||
In reality this just checks if a subcommand is being invoked.
|
||||
"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await ctx.send(f'No, {ctx.subcommand_passed} is not cool')
|
||||
await ctx.send(f"No, {ctx.subcommand_passed} is not cool")
|
||||
|
||||
@cool.command(name='bot')
|
||||
|
||||
@cool.command(name="bot")
|
||||
async def _bot(ctx):
|
||||
"""Is the bot cool?"""
|
||||
await ctx.send('Yes, the bot is cool.')
|
||||
await ctx.send("Yes, the bot is cool.")
|
||||
|
||||
bot.run('token')
|
||||
|
||||
bot.run("token")
|
||||
|
@ -6,26 +6,24 @@ import youtube_dl
|
||||
from discord.ext import commands
|
||||
|
||||
# Suppress noise about console usage from errors
|
||||
youtube_dl.utils.bug_reports_message = lambda: ''
|
||||
youtube_dl.utils.bug_reports_message = lambda: ""
|
||||
|
||||
|
||||
ytdl_format_options = {
|
||||
'format': 'bestaudio/best',
|
||||
'outtmpl': '%(extractor)s-%(id)s-%(title)s.%(ext)s',
|
||||
'restrictfilenames': True,
|
||||
'noplaylist': True,
|
||||
'nocheckcertificate': True,
|
||||
'ignoreerrors': False,
|
||||
'logtostderr': False,
|
||||
'quiet': True,
|
||||
'no_warnings': True,
|
||||
'default_search': 'auto',
|
||||
'source_address': '0.0.0.0' # bind to ipv4 since ipv6 addresses cause issues sometimes
|
||||
"format": "bestaudio/best",
|
||||
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
|
||||
"restrictfilenames": True,
|
||||
"noplaylist": True,
|
||||
"nocheckcertificate": True,
|
||||
"ignoreerrors": False,
|
||||
"logtostderr": False,
|
||||
"quiet": True,
|
||||
"no_warnings": True,
|
||||
"default_search": "auto",
|
||||
"source_address": "0.0.0.0", # bind to ipv4 since ipv6 addresses cause issues sometimes
|
||||
}
|
||||
|
||||
ffmpeg_options = {
|
||||
'options': '-vn'
|
||||
}
|
||||
ffmpeg_options = {"options": "-vn"}
|
||||
|
||||
ytdl = youtube_dl.YoutubeDL(ytdl_format_options)
|
||||
|
||||
@ -36,19 +34,21 @@ class YTDLSource(discord.PCMVolumeTransformer):
|
||||
|
||||
self.data = data
|
||||
|
||||
self.title = data.get('title')
|
||||
self.url = data.get('url')
|
||||
self.title = data.get("title")
|
||||
self.url = data.get("url")
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url, *, loop=None, stream=False):
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
data = await loop.run_in_executor(None, lambda: ytdl.extract_info(url, download=not stream))
|
||||
data = await loop.run_in_executor(
|
||||
None, lambda: ytdl.extract_info(url, download=not stream)
|
||||
)
|
||||
|
||||
if 'entries' in data:
|
||||
if "entries" in data:
|
||||
# take first item from a playlist
|
||||
data = data['entries'][0]
|
||||
data = data["entries"][0]
|
||||
|
||||
filename = data['url'] if stream else ytdl.prepare_filename(data)
|
||||
filename = data["url"] if stream else ytdl.prepare_filename(data)
|
||||
return cls(discord.FFmpegPCMAudio(filename, **ffmpeg_options), data=data)
|
||||
|
||||
|
||||
@ -70,9 +70,11 @@ class Music(commands.Cog):
|
||||
"""Plays a file from the local filesystem"""
|
||||
|
||||
source = discord.PCMVolumeTransformer(discord.FFmpegPCMAudio(query))
|
||||
ctx.voice_client.play(source, after=lambda e: print(f'Player error: {e}') if e else None)
|
||||
ctx.voice_client.play(
|
||||
source, after=lambda e: print(f"Player error: {e}") if e else None
|
||||
)
|
||||
|
||||
await ctx.send(f'Now playing: {query}')
|
||||
await ctx.send(f"Now playing: {query}")
|
||||
|
||||
@commands.command()
|
||||
async def yt(self, ctx, *, url):
|
||||
@ -80,9 +82,11 @@ class Music(commands.Cog):
|
||||
|
||||
async with ctx.typing():
|
||||
player = await YTDLSource.from_url(url, loop=self.bot.loop)
|
||||
ctx.voice_client.play(player, after=lambda e: print(f'Player error: {e}') if e else None)
|
||||
ctx.voice_client.play(
|
||||
player, after=lambda e: print(f"Player error: {e}") if e else None
|
||||
)
|
||||
|
||||
await ctx.send(f'Now playing: {player.title}')
|
||||
await ctx.send(f"Now playing: {player.title}")
|
||||
|
||||
@commands.command()
|
||||
async def stream(self, ctx, *, url):
|
||||
@ -90,9 +94,11 @@ class Music(commands.Cog):
|
||||
|
||||
async with ctx.typing():
|
||||
player = await YTDLSource.from_url(url, loop=self.bot.loop, stream=True)
|
||||
ctx.voice_client.play(player, after=lambda e: print(f'Player error: {e}') if e else None)
|
||||
ctx.voice_client.play(
|
||||
player, after=lambda e: print(f"Player error: {e}") if e else None
|
||||
)
|
||||
|
||||
await ctx.send(f'Now playing: {player.title}')
|
||||
await ctx.send(f"Now playing: {player.title}")
|
||||
|
||||
@commands.command()
|
||||
async def volume(self, ctx, volume: int):
|
||||
@ -123,16 +129,19 @@ class Music(commands.Cog):
|
||||
elif ctx.voice_client.is_playing():
|
||||
ctx.voice_client.stop()
|
||||
|
||||
|
||||
bot = commands.Bot(
|
||||
command_prefix=commands.when_mentioned_or("!"),
|
||||
description='Relatively simple music bot example',
|
||||
intents=discord.Intents(guilds=True, guild_messages=True, voice_states=True)
|
||||
description="Relatively simple music bot example",
|
||||
intents=discord.Intents(guilds=True, guild_messages=True, voice_states=True),
|
||||
)
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
print(f'Logged in as {bot.user} (ID: {bot.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {bot.user} (ID: {bot.user.id})")
|
||||
print("------")
|
||||
|
||||
|
||||
bot.add_cog(Music(bot))
|
||||
bot.run('token')
|
||||
bot.run("token")
|
||||
|
@ -7,7 +7,7 @@ from discord.ext import commands
|
||||
|
||||
|
||||
intents = discord.Intents(guilds=True, messages=True, members=True)
|
||||
bot = commands.Bot('!', intents=intents)
|
||||
bot = commands.Bot("!", intents=intents)
|
||||
|
||||
|
||||
@bot.command()
|
||||
@ -28,14 +28,16 @@ async def userinfo(ctx: commands.Context, user: discord.User):
|
||||
user_id = user.id
|
||||
username = user.name
|
||||
avatar = user.avatar.url
|
||||
await ctx.send(f'User found: {user_id} -- {username}\n{avatar}')
|
||||
await ctx.send(f"User found: {user_id} -- {username}\n{avatar}")
|
||||
|
||||
|
||||
@userinfo.error
|
||||
async def userinfo_error(ctx: commands.Context, error: commands.CommandError):
|
||||
# if the conversion above fails for any reason, it will raise `commands.BadArgument`
|
||||
# so we handle this in this error handler:
|
||||
if isinstance(error, commands.BadArgument):
|
||||
return await ctx.send('Couldn\'t find that user.')
|
||||
return await ctx.send("Couldn't find that user.")
|
||||
|
||||
|
||||
# Custom Converter here
|
||||
class ChannelOrMemberConverter(commands.Converter):
|
||||
@ -69,21 +71,25 @@ class ChannelOrMemberConverter(commands.Converter):
|
||||
# If the value could not be converted we can raise an error
|
||||
# so our error handlers can deal with it in one place.
|
||||
# The error has to be CommandError derived, so BadArgument works fine here.
|
||||
raise commands.BadArgument(f'No Member or TextChannel could be converted from "{argument}"')
|
||||
|
||||
raise commands.BadArgument(
|
||||
f'No Member or TextChannel could be converted from "{argument}"'
|
||||
)
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def notify(ctx: commands.Context, target: ChannelOrMemberConverter):
|
||||
# This command signature utilises the custom converter written above
|
||||
# What will happen during command invocation is that the `target` above will be passed to
|
||||
# the `argument` parameter of the `ChannelOrMemberConverter.convert` method and
|
||||
# the `argument` parameter of the `ChannelOrMemberConverter.convert` method and
|
||||
# the conversion will go through the process defined there.
|
||||
|
||||
await target.send(f'Hello, {target.name}!')
|
||||
await target.send(f"Hello, {target.name}!")
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]):
|
||||
async def ignore(
|
||||
ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]
|
||||
):
|
||||
# This command signature utilises the `typing.Union` typehint.
|
||||
# The `commands` framework attempts a conversion of each type in this Union *in order*.
|
||||
# So, it will attempt to convert whatever is passed to `target` to a `discord.Member` instance.
|
||||
@ -94,9 +100,16 @@ async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, dis
|
||||
|
||||
# To check the resulting type, `isinstance` is used
|
||||
if isinstance(target, discord.Member):
|
||||
await ctx.send(f'Member found: {target.mention}, adding them to the ignore list.')
|
||||
elif isinstance(target, discord.TextChannel): # this could be an `else` but for completeness' sake.
|
||||
await ctx.send(f'Channel found: {target.mention}, adding it to the ignore list.')
|
||||
await ctx.send(
|
||||
f"Member found: {target.mention}, adding them to the ignore list."
|
||||
)
|
||||
elif isinstance(
|
||||
target, discord.TextChannel
|
||||
): # this could be an `else` but for completeness' sake.
|
||||
await ctx.send(
|
||||
f"Channel found: {target.mention}, adding it to the ignore list."
|
||||
)
|
||||
|
||||
|
||||
# Built-in type converters.
|
||||
@bot.command()
|
||||
@ -109,4 +122,5 @@ async def multiply(ctx: commands.Context, number: int, maybe: bool):
|
||||
return await ctx.send(number * 2)
|
||||
await ctx.send(number * 5)
|
||||
|
||||
bot.run('token')
|
||||
|
||||
bot.run("token")
|
||||
|
@ -10,7 +10,7 @@ class MyContext(commands.Context):
|
||||
# depending on whether value is True or False
|
||||
# if its True, it'll add a green check mark
|
||||
# otherwise, it'll add a red cross mark
|
||||
emoji = '\N{WHITE HEAVY CHECK MARK}' if value else '\N{CROSS MARK}'
|
||||
emoji = "\N{WHITE HEAVY CHECK MARK}" if value else "\N{CROSS MARK}"
|
||||
try:
|
||||
# this will react to the command author's message
|
||||
await self.message.add_reaction(emoji)
|
||||
@ -27,9 +27,10 @@ class MyBot(commands.Bot):
|
||||
# subclass to the super() method, which tells the bot to
|
||||
# use the new MyContext class
|
||||
return await super().get_context(message, cls=cls)
|
||||
|
||||
|
||||
bot = MyBot(command_prefix='!', intents=discord.Intents(guilds=True, messages=True))
|
||||
|
||||
bot = MyBot(command_prefix="!", intents=discord.Intents(guilds=True, messages=True))
|
||||
|
||||
|
||||
@bot.command()
|
||||
async def guess(ctx, number: int):
|
||||
@ -42,8 +43,9 @@ async def guess(ctx, number: int):
|
||||
# or a red cross mark if it wasn't
|
||||
await ctx.tick(number == value)
|
||||
|
||||
|
||||
# IMPORTANT: You shouldn't hard code your token
|
||||
# these are very important, and leaking them can
|
||||
# these are very important, and leaking them can
|
||||
# let people do very malicious things with your
|
||||
# bot. Try to use a file or something to keep
|
||||
# them private, and don't commit it to GitHub
|
||||
|
@ -1,21 +1,23 @@
|
||||
import discord
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def on_message(self, message):
|
||||
if message.content.startswith('!deleteme'):
|
||||
msg = await message.channel.send('I will delete myself now...')
|
||||
if message.content.startswith("!deleteme"):
|
||||
msg = await message.channel.send("I will delete myself now...")
|
||||
await msg.delete()
|
||||
|
||||
# this also works
|
||||
await message.channel.send('Goodbye in 3 seconds...', delete_after=3.0)
|
||||
await message.channel.send("Goodbye in 3 seconds...", delete_after=3.0)
|
||||
|
||||
async def on_message_delete(self, message):
|
||||
msg = f'{message.author} has deleted the message: {message.content}'
|
||||
msg = f"{message.author} has deleted the message: {message.content}"
|
||||
await message.channel.send(msg)
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -1,20 +1,22 @@
|
||||
import discord
|
||||
import asyncio
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def on_message(self, message):
|
||||
if message.content.startswith('!editme'):
|
||||
msg = await message.channel.send('10')
|
||||
if message.content.startswith("!editme"):
|
||||
msg = await message.channel.send("10")
|
||||
await asyncio.sleep(3.0)
|
||||
await msg.edit(content='40')
|
||||
await msg.edit(content="40")
|
||||
|
||||
async def on_message_edit(self, before, after):
|
||||
msg = f'**{before.author}** edited their message:\n{before.content} -> {after.content}'
|
||||
msg = f"**{before.author}** edited their message:\n{before.content} -> {after.content}"
|
||||
await before.channel.send(msg)
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -2,18 +2,19 @@ import discord
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def on_message(self, message):
|
||||
# we do not want the bot to reply to itself
|
||||
if message.author.id == self.user.id:
|
||||
return
|
||||
|
||||
if message.content.startswith('$guess'):
|
||||
await message.channel.send('Guess a number between 1 and 10.')
|
||||
if message.content.startswith("$guess"):
|
||||
await message.channel.send("Guess a number between 1 and 10.")
|
||||
|
||||
def is_correct(m):
|
||||
return m.author == message.author and m.content.isdigit()
|
||||
@ -21,14 +22,17 @@ class MyClient(discord.Client):
|
||||
answer = random.randint(1, 10)
|
||||
|
||||
try:
|
||||
guess = await self.wait_for('message', check=is_correct, timeout=5.0)
|
||||
guess = await self.wait_for("message", check=is_correct, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
return await message.channel.send(f'Sorry, you took too long it was {answer}.')
|
||||
return await message.channel.send(
|
||||
f"Sorry, you took too long it was {answer}."
|
||||
)
|
||||
|
||||
if int(guess.content) == answer:
|
||||
await message.channel.send('You are right!')
|
||||
await message.channel.send("You are right!")
|
||||
else:
|
||||
await message.channel.send(f'Oops. It is actually {answer}.')
|
||||
await message.channel.send(f"Oops. It is actually {answer}.")
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -2,17 +2,18 @@
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def on_member_join(self, member):
|
||||
guild = member.guild
|
||||
if guild.system_channel is not None:
|
||||
to_send = f'Welcome {member.mention} to {guild.name}!'
|
||||
to_send = f"Welcome {member.mention} to {guild.name}!"
|
||||
await guild.system_channel.send(to_send)
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True, members=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -2,15 +2,24 @@
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.role_message_id = 0 # ID of the message that can be reacted to to add/remove a role.
|
||||
self.role_message_id = (
|
||||
0 # ID of the message that can be reacted to to add/remove a role.
|
||||
)
|
||||
self.emoji_to_role = {
|
||||
discord.PartialEmoji(name='🔴'): 0, # ID of the role associated with unicode emoji '🔴'.
|
||||
discord.PartialEmoji(name='🟡'): 0, # ID of the role associated with unicode emoji '🟡'.
|
||||
discord.PartialEmoji(name='green', id=0): 0, # ID of the role associated with a partial emoji's ID.
|
||||
discord.PartialEmoji(
|
||||
name="🔴"
|
||||
): 0, # ID of the role associated with unicode emoji '🔴'.
|
||||
discord.PartialEmoji(
|
||||
name="🟡"
|
||||
): 0, # ID of the role associated with unicode emoji '🟡'.
|
||||
discord.PartialEmoji(
|
||||
name="green", id=0
|
||||
): 0, # ID of the role associated with a partial emoji's ID.
|
||||
}
|
||||
|
||||
async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent):
|
||||
@ -78,6 +87,7 @@ class MyClient(discord.Client):
|
||||
# If we want to do something in case of errors we'd do it here.
|
||||
pass
|
||||
|
||||
|
||||
intents = discord.Intents(guilds=True, members=True, guild_reactions=True)
|
||||
client = MyClient(intents=intents)
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -1,17 +1,19 @@
|
||||
import discord
|
||||
|
||||
|
||||
class MyClient(discord.Client):
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
async def on_message(self, message):
|
||||
# we do not want the bot to reply to itself
|
||||
if message.author.id == self.user.id:
|
||||
return
|
||||
|
||||
if message.content.startswith('!hello'):
|
||||
await message.reply('Hello!', mention_author=True)
|
||||
if message.content.startswith("!hello"):
|
||||
await message.reply("Hello!", mention_author=True)
|
||||
|
||||
|
||||
client = MyClient(intents=discord.Intents(guilds=True, messages=True))
|
||||
client.run('token')
|
||||
client.run("token")
|
||||
|
@ -6,18 +6,19 @@ from discord.ext import commands
|
||||
bot = commands.Bot(
|
||||
command_prefix=commands.when_mentioned,
|
||||
description="Nothing to see here!",
|
||||
intents=discord.Intents(guilds=True, messages=True)
|
||||
intents=discord.Intents(guilds=True, messages=True),
|
||||
)
|
||||
|
||||
# the `hidden` keyword argument hides it from the help command.
|
||||
# the `hidden` keyword argument hides it from the help command.
|
||||
@bot.group(hidden=True)
|
||||
async def secret(ctx: commands.Context):
|
||||
"""What is this "secret" you speak of?"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await ctx.send('Shh!', delete_after=5)
|
||||
await ctx.send("Shh!", delete_after=5)
|
||||
|
||||
|
||||
def create_overwrites(ctx, *objects):
|
||||
"""This is just a helper function that creates the overwrites for the
|
||||
"""This is just a helper function that creates the overwrites for the
|
||||
voice/text channels.
|
||||
|
||||
A `discord.PermissionOverwrite` allows you to determine the permissions
|
||||
@ -31,40 +32,51 @@ def create_overwrites(ctx, *objects):
|
||||
# a dict comprehension is being utilised here to set the same permission overwrites
|
||||
# for each `discord.Role` or `discord.Member`.
|
||||
overwrites = {
|
||||
obj: discord.PermissionOverwrite(view_channel=True)
|
||||
for obj in objects
|
||||
obj: discord.PermissionOverwrite(view_channel=True) for obj in objects
|
||||
}
|
||||
|
||||
# prevents the default role (@everyone) from viewing the channel
|
||||
# if it isn't already allowed to view the channel.
|
||||
overwrites.setdefault(ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False))
|
||||
overwrites.setdefault(
|
||||
ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False)
|
||||
)
|
||||
|
||||
# makes sure the client is always allowed to view the channel.
|
||||
overwrites[ctx.guild.me] = discord.PermissionOverwrite(view_channel=True)
|
||||
|
||||
return overwrites
|
||||
|
||||
|
||||
# since these commands rely on guild related features,
|
||||
# it is best to lock it to be guild-only.
|
||||
@secret.command()
|
||||
@commands.guild_only()
|
||||
async def text(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]):
|
||||
"""This makes a text channel with a specified name
|
||||
async def text(
|
||||
ctx: commands.Context,
|
||||
name: str,
|
||||
*objects: typing.Union[discord.Role, discord.Member]
|
||||
):
|
||||
"""This makes a text channel with a specified name
|
||||
that is only visible to roles or members that are specified.
|
||||
"""
|
||||
|
||||
|
||||
overwrites = create_overwrites(ctx, *objects)
|
||||
|
||||
await ctx.guild.create_text_channel(
|
||||
name,
|
||||
overwrites=overwrites,
|
||||
topic='Top secret text channel. Any leakage of this channel may result in serious trouble.',
|
||||
reason='Very secret business.',
|
||||
topic="Top secret text channel. Any leakage of this channel may result in serious trouble.",
|
||||
reason="Very secret business.",
|
||||
)
|
||||
|
||||
|
||||
@secret.command()
|
||||
@commands.guild_only()
|
||||
async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]):
|
||||
async def voice(
|
||||
ctx: commands.Context,
|
||||
name: str,
|
||||
*objects: typing.Union[discord.Role, discord.Member]
|
||||
):
|
||||
"""This does the same thing as the `text` subcommand
|
||||
but instead creates a voice channel.
|
||||
"""
|
||||
@ -72,14 +84,15 @@ async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord
|
||||
overwrites = create_overwrites(ctx, *objects)
|
||||
|
||||
await ctx.guild.create_voice_channel(
|
||||
name,
|
||||
overwrites=overwrites,
|
||||
reason='Very secret business.'
|
||||
name, overwrites=overwrites, reason="Very secret business."
|
||||
)
|
||||
|
||||
|
||||
@secret.command()
|
||||
@commands.guild_only()
|
||||
async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role):
|
||||
async def emoji(
|
||||
ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role
|
||||
):
|
||||
"""This clones a specified emoji that only specified roles
|
||||
are allowed to use.
|
||||
"""
|
||||
@ -90,11 +103,8 @@ async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: disc
|
||||
# the key parameter here is `roles`, which controls
|
||||
# what roles are able to use the emoji.
|
||||
await ctx.guild.create_custom_emoji(
|
||||
name=emoji.name,
|
||||
image=emoji_bytes,
|
||||
roles=roles,
|
||||
reason='Very secret business.'
|
||||
name=emoji.name, image=emoji_bytes, roles=roles, reason="Very secret business."
|
||||
)
|
||||
|
||||
|
||||
bot.run('token')
|
||||
bot.run("token")
|
||||
|
@ -6,13 +6,13 @@ import discord
|
||||
class Bot(commands.Bot):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
command_prefix=commands.when_mentioned_or('$'),
|
||||
intents=discord.Intents(guilds=True, messages=True)
|
||||
command_prefix=commands.when_mentioned_or("$"),
|
||||
intents=discord.Intents(guilds=True, messages=True),
|
||||
)
|
||||
|
||||
async def on_ready(self):
|
||||
print(f'Logged in as {self.user} (ID: {self.user.id})')
|
||||
print('------')
|
||||
print(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
print("------")
|
||||
|
||||
|
||||
# Define a simple View that gives us a confirmation menu
|
||||
@ -24,16 +24,18 @@ class Confirm(discord.ui.View):
|
||||
# When the confirm button is pressed, set the inner value to `True` and
|
||||
# stop the View from listening to more input.
|
||||
# We also send the user an ephemeral message that we're confirming their choice.
|
||||
@discord.ui.button(label='Confirm', style=discord.ButtonStyle.green)
|
||||
async def confirm(self, button: discord.ui.Button, interaction: discord.Interaction):
|
||||
await interaction.response.send_message('Confirming', ephemeral=True)
|
||||
@discord.ui.button(label="Confirm", style=discord.ButtonStyle.green)
|
||||
async def confirm(
|
||||
self, button: discord.ui.Button, interaction: discord.Interaction
|
||||
):
|
||||
await interaction.response.send_message("Confirming", ephemeral=True)
|
||||
self.value = True
|
||||
self.stop()
|
||||
|
||||
# This one is similar to the confirmation button except sets the inner value to `False`
|
||||
@discord.ui.button(label='Cancel', style=discord.ButtonStyle.grey)
|
||||
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.grey)
|
||||
async def cancel(self, button: discord.ui.Button, interaction: discord.Interaction):
|
||||
await interaction.response.send_message('Cancelling', ephemeral=True)
|
||||
await interaction.response.send_message("Cancelling", ephemeral=True)
|
||||
self.value = False
|
||||
self.stop()
|
||||
|
||||
@ -46,15 +48,15 @@ async def ask(ctx: commands.Context):
|
||||
"""Asks the user a question to confirm something."""
|
||||
# We create the view and assign it to a variable so we can wait for it later.
|
||||
view = Confirm()
|
||||
await ctx.send('Do you want to continue?', view=view)
|
||||
await ctx.send("Do you want to continue?", view=view)
|
||||
# Wait for the View to stop listening for input...
|
||||
await view.wait()
|
||||
if view.value is None:
|
||||
print('Timed out...')
|
||||
print("Timed out...")
|
||||
elif view.value:
|
||||
print('Confirmed...')
|
||||
print("Confirmed...")
|
||||
else:
|
||||
print('Cancelled...')
|
||||
print("Cancelled...")
|
||||
|
||||
|
||||
bot.run('token')
|
||||
bot.run("token")
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user