46 Commits

Author SHA1 Message Date
Sami Altamimi
e48493b36c Actually block out Python 3.9.7
reference PR https://github.com/iDevision/enhanced-discord.py/pull/93
2021-10-05 21:12:40 -05:00
Sami Altamimi
4931100b44 Advise version incompatibility with Python 3.9.7
Due to a bug in Python 3.9.7, code situations where we call an __init__ function in a subclass will fail with a TypeError. 

This is only a bug within the Python language and was patched out with Python 3.10 and this doesn't affect Python 3.9.6. 

Here, we are advising the incompatibility.
2021-10-05 21:03:05 -05:00
Ian Webster
ec1e2add21 Update user-agent (#92) 2021-10-04 21:11:10 +01:00
Gnome
4277f65051 Implement _FakeSlashMessage.clean_content
Closes #83
2021-10-03 21:05:00 +01:00
Gnome!
3260ec6643 Add improved docs for slash commands (#77)
* Fix command checks actually working

* Current progress on slash command docs

* Improve docs for slash commands further
2021-09-27 01:14:07 -07:00
Chiggy-Playz
d16d2d856f Sort subcommand names (#68) 2021-09-25 22:43:23 -07:00
Gnome!
456d71d228 Add better support for MENTIONABLE (#74) 2021-09-25 22:41:43 -07:00
Gnome!
093a38527d Fix slash command prefix to / (#75) 2021-09-25 22:40:35 -07:00
NORXND
163d8e6586 Merge pull request #76
* Fix docs in BadInviteArgument class
2021-09-25 22:39:09 -07:00
Tom
0637a628ca update workflows (#73)
* modify workflows to fit into one file, fix pyright workflow

* remove redundant pip install

* add check flag to black

* use psf/black for black checker
2021-09-21 14:51:46 -07:00
Gnome!
02c6b07834 Merge pull request #72
* Fix command checks actually working
2021-09-21 14:34:54 -07:00
Gnome!
b810848273 Merge pull request #70
* Fix embed image/thumbnail property
2021-09-21 12:10:16 -07:00
Astrea
cd4bb296f3 Merge pull request #58
* FIxed `userinfo` command not returning an avatar...

* Quick merge conflict fix

* Merge branch '2.0' into converter-example-fix

* Fix code style issues with Black
2021-09-21 11:52:55 -07:00
Arnav Jindal
6a63ce2ed7 Add typechecking for PRS/Commits (#59)
* Create ci.yml

* Create .python-black

* Remove linting
2021-09-21 11:52:03 -07:00
Gnome!
fba7ca420c Merge pull request #63
* Add ephemeral attachment field

* I did not miss a comma
2021-09-21 11:51:23 -07:00
Gnome!
e65415d3c8 Merge pull request #60
* Rework how checks add attributes to Commmand

* Merge remote-tracking branch 'upstream/2.0' into command-attrs-checks
2021-09-21 11:47:28 -07:00
Astrea
2ecf755372 Merge pull request #57
* FIx _accent_colour being improperly typehinted
2021-09-21 11:37:28 -07:00
Gnome!
00ae8bb18c Fix all invites to devision server invite (#69) 2021-09-20 21:25:48 +02:00
iDutchy
0638bda719 Fix docs invite
Invite link on docs was still set to dpy, this changes it to edpy
2021-09-19 02:42:43 +02:00
Gnome!
1957fa6011 Implement a least breaking approach to slash commands (#39)
* Most slash command support completed, needs some debugging (and reindent)

* Implement a ctx.send helper for slash commands

* Add group command support

* Add Option converter, fix default optional, fix help command

* Add client.setup and move readying commands to that

* Implement _FakeSlashMessage.from_interaction

* Rename normmal_command to message_command

* Add docs for added params

* Add slash_command_guilds to bot and decos

* Fix merge conflict

* Remove name from commands.Option, wasn't used

* Move slash command processing to BotBase.process_slash_commands

* Create slash_only.py

Basic example for slash commands

* Create slash_and_message.py

Basic example for mixed commands

* Fix slash_command and normal_command bools

* Add some basic error handling for registration

* Fixed converter upload errors

* Fix some logic and make an actual example

* Thanks Safety Jim

* docstrings, *args, and error changes

* Add proper literal support

* Add basic documentation on slash commands

* Fix non-slash command interactions

* Fix ctx.reply in slash command context

* Fix typing on Context.reply

* Fix multiple optional argument sorting

* Update ctx.message docs to mention error instead of warning

* Move slash command creation to BotBase

* Fix code style issues with Black

* Rearrange some stuff and add flag support

* Change some errors and fix interaction.channel fixing

* Fix slash command quoting for *args

Co-authored-by: iDutchy <42503862+iDutchy@users.noreply.github.com>
Co-authored-by: Lint Action <lint-action@samuelmeuli.com>
2021-09-19 01:28:11 +02:00
Astrea
75a23351c4 Revert #42 (#61) 2021-09-09 00:02:02 +02:00
Lint Action
7513c2138f Fix code style issues with Black 2021-09-05 21:34:20 +00:00
IAmTomahawkx
a23dae8604 Merge branch '2.0' of https://github.com/IDevision/enhanced-discord.py into 2.0 2021-09-05 14:33:00 -07:00
IAmTomahawkx
1833e984ce add black workflow, change our code formats. closes #43 2021-09-05 14:32:51 -07:00
Gnome!
53a6b2cb45 Revert "Merge pull request #12" (#56)
This reverts commit 42c0a8d8a5.
2021-09-05 10:37:51 -07:00
iDutchy
65640ddfc7 Merge pull request #55 from TheMoksej/patch-2
remove unnecessary await
2021-09-05 15:24:45 +02:00
Moksej
14b3188bb8 remove unnecessary await 2021-09-05 13:58:10 +02:00
Arthur
3ffe134895 Merge pull request #44
* Typehint gateway.py

* Add relevant typehints to gateway.py to voice_client.py

* Change EventListener to subclass NamedTuple

* Add return type for DiscordWebSocket.wait_for

* Correct deque typehint

* Remove unnecessary typehints for literals

* Use type aliases

* Merge branch '2.0' into pr7422
2021-09-02 13:50:19 -07:00
Arthur
1032728311 Merge pull request #32
* Add get/fetch_member to ThreadMember objects
2021-09-02 13:43:19 -07:00
Arthur
33470ff196 Merge pull request #31
* Add bots and humans to TextChannel
2021-09-02 13:41:26 -07:00
Arthur
47e42d1648 Merge pull request #42
* implement WelcomeScreen

* copy over the kwargs issue.

* readable variable names

* modernise code

* modernise pt2

* Update discord/welcome_screen.py

* make pylance not cry from my onions

* type http.py

* remove extraneous import
2021-09-02 13:40:11 -07:00
Astrea
4055bafaa5 Merge pull request #47
* Added `on_raw_typing` event
2021-09-02 13:34:39 -07:00
IAmTomahawkx
152b61aabb fix recursionerror caused by a Pull Request 2021-09-02 12:49:38 -07:00
Ahmad Ansori Palembani
f37be7961a Merge pull request #41
* Fixed `TypeError`

* Handles `EmptyEmbed` inside setter instead of set_

* Remove return and setter docstring
2021-09-02 12:46:56 -07:00
NightSlasher35
0f6db99c59 Merge pull request #22
* add nitro booster color

* Update discord/colour.py
2021-09-02 12:34:41 -07:00
chillymosh
42c0a8d8a5 Merge pull request #12
* Clean up python

* Clean up bot python

* revert lists

* revert commands.bot completely

* extract raise_expected_coro further

* add new lines

* removed erroneous import

* remove hashed line
2021-09-02 12:32:46 -07:00
Arthur
092fbca08f Merge pull request #21
* [BREAKING] Make case_insensitive default to True on groups and commands
2021-09-02 12:28:03 -07:00
Arthur
13834d1147 Merge pull request #7
* Add try_user to get a user from cache or from the gateway.

* Extract populate_owners into a new coroutine.

* Add a try_owners coroutine to get a list of owners of the bot.

* Fix coding-style.

* Fix a bug where None would be returned in try_owners if the cache was…

* Fix docstring

* Add spacing in the code
2021-09-02 12:24:52 -07:00
Arthur
5d10384576 Merge pull request #27
* Add author_permissions to the Context object as a shortcut to return …
2021-09-02 12:18:26 -07:00
Daud
fc0188d7bc Merge pull request #49
* Change README title to enhanced-discord.py
2021-09-02 12:17:19 -07:00
iDutchy
a62c0ff0d1 Merge pull request #16 from paris-ci/alias_administrator_to_admin
Alias admin to administrators in permissions. This needs to be tested…
2021-09-02 03:06:30 +02:00
iDutchy
b75be64044 Update permissions.py
A better implementation :)
2021-09-02 03:05:49 +02:00
NightSlasher35
630a842556 Update CONTRIBUTING.md correctly (#29)
* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

Co-authored-by: Tom <47765953+IAmTomahawkx@users.noreply.github.com>
2021-09-01 17:49:41 -07:00
Arthur
c485e08ea0 Add try_member to guild. (#14)
* Add try_member to guild.

This also fix an omission in the fetch_member docs. fetch_member raises NotFound if the given user isn't in the guild.

* Optimize imports.
2021-09-01 17:47:15 -07:00
classerase
dba9a8abb9 Update README.rst (#48) 2021-09-01 15:30:15 -07:00
Arthur Jovart
8cdc1f4ad9 Alias admin to administrators in permissions. This needs to be tested, but should be working. 2021-08-29 00:26:26 +02:00
122 changed files with 6713 additions and 5052 deletions

View File

@@ -1,5 +1,7 @@
## Contributing to discord.py
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
@@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a
Generally speaking questions are better suited in our resources below.
- The official support server: https://discord.gg/r3sSKJJ
- The official support server: https://discord.gg/TvqYBrGXEm
- The Discord API server under #python_discord-py: https://discord.gg/discord-api
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html)
- [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.
@@ -32,13 +34,13 @@ If the bug report is missing this information then it'll take us longer to fix t
## Submitting a Pull Request
Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows PEP-8 guidelines (mostly) with a column limit of 125.
Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep, and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows the black code format, with a line length limit of `120`
### Git Commit Guidelines
- Use present tense (e.g. "Add feature" not "Added feature")
- Limit all lines to 72 characters or less.
- Reference issues or pull requests outside of the first line.
- Limit all lines to 120 characters or fewer.
- Reference issues or pull requests outside the first line.
- Please use the shorthand `#123` and not the full URL.
- Commits regarding the commands extension must be prefixed with `[commands]`

View File

@@ -6,7 +6,7 @@ body:
attributes:
value: >
Thanks for taking the time to fill out a bug.
If you want real-time support, consider joining our Discord at https://discord.gg/r3sSKJJ instead.
If you want real-time support, consider joining our Discord at https://discord.gg/TvqYBrGXEm instead.
Please note that this form is for bugs only!
- type: input

View File

@@ -5,4 +5,4 @@ contact_links:
url: https://github.com/Rapptz/discord.py/discussions
- name: Discord Server
about: Use our official Discord server to ask for help and questions as well.
url: https://discord.gg/r3sSKJJ
url: https://discord.gg/TvqYBrGXEm

View File

@@ -1,3 +1,5 @@
<!-- Pull requests that do not fill this information in will likely be closed -->
## Summary
<!-- What is this pull request for? Does it fix any issues? -->

38
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: CI
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Setup node.js (for pyright)
uses: actions/setup-node@v1
with:
node-version: "14"
- name: Run type checking
run: |
npm install -g pyright
pip install .
pyright --lib --verifytypes discord --ignoreexternal
black:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Run linter
uses: psf/black@stable
with:
options: "--line-length 120 --check"
src: "./discord"

1
.python-black Normal file
View File

@@ -0,0 +1 @@

View File

@@ -1,8 +1,8 @@
discord.py
==========
enhanced-discord.py
===================
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png
:target: https://discord.gg/PYAfZzpsjG
:target: https://discord.gg/TvqYBrGXEm
:alt: Discord server invite
.. image:: https://img.shields.io/pypi/v/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy
@@ -41,7 +41,9 @@ Key Features
Installing
----------
**Python 3.8 or higher is required**
**Python 3.8 or higher is required***
***Do not use 3.9.7 as Python 3.9.7 has a bug that causes a TypeError with __init__ methods in subclasses.** View more `here <https://bugs.python.org/issue45121/>`_.
To install the library without full voice support, you can just run the following command:
@@ -59,7 +61,7 @@ To install the development version, do the following:
.. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py
$ cd discord.py
$ cd enhanced-discord.py
$ python3 -m pip install -U .[voice]
@@ -117,5 +119,5 @@ Links
------
- `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_
- `Official Discord Server <https://discord.gg/PYAfZzpsjG>`_
- `Official Discord Server <https://discord.gg/TvqYBrGXEm>`_
- `Discord API <https://discord.gg/discord-api>`_

View File

@@ -9,13 +9,13 @@ A basic wrapper for the Discord API.
"""
__title__ = 'discord'
__author__ = 'Rapptz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2015-present Rapptz'
__version__ = '2.0.0a'
__title__ = "discord"
__author__ = "Rapptz"
__license__ = "MIT"
__copyright__ = "Copyright 2015-present Rapptz"
__version__ = "2.0.0a"
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
import logging
from typing import NamedTuple, Literal
@@ -69,6 +69,6 @@ class VersionInfo(NamedTuple):
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="alpha", serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@@ -31,26 +31,29 @@ import pkg_resources
import aiohttp
import platform
def show_version():
entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
entries.append("- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info))
version_info = discord.version_info
entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info))
if version_info.releaselevel != 'final':
pkg = pkg_resources.get_distribution('discord.py')
entries.append("- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info))
if version_info.releaselevel != "final":
pkg = pkg_resources.get_distribution("discord.py")
if pkg:
entries.append(f' - discord.py pkg_resources: v{pkg.version}')
entries.append(f" - discord.py pkg_resources: v{pkg.version}")
entries.append(f'- aiohttp v{aiohttp.__version__}')
entries.append(f"- aiohttp v{aiohttp.__version__}")
uname = platform.uname()
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries))
entries.append("- system info: {0.system} {0.release} {0.version}".format(uname))
print("\n".join(entries))
def core(parser, args):
if args.version:
show_version()
_bot_template = """#!/usr/bin/env python3
from discord.ext import commands
@@ -120,7 +123,7 @@ def setup(bot):
bot.add_cog({name}(bot))
'''
_cog_extras = '''
_cog_extras = """
def cog_unload(self):
# clean up logic goes here
pass
@@ -149,22 +152,22 @@ _cog_extras = '''
# called after a command is called here
pass
'''
"""
# certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent
_base_table = {
'<': '-',
'>': '-',
':': '-',
'"': '-',
"<": "-",
">": "-",
":": "-",
'"': "-",
# '/': '-', these are fine
# '\\': '-',
'|': '-',
'?': '-',
'*': '-',
"|": "-",
"?": "-",
"*": "-",
}
# NUL (0) and 1-31 are disallowed
@@ -172,21 +175,45 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path):
return name
if sys.platform == 'win32':
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9')
if sys.platform == "win32":
forbidden = (
"CON",
"PRN",
"AUX",
"NUL",
"COM1",
"COM2",
"COM3",
"COM4",
"COM5",
"COM6",
"COM7",
"COM8",
"COM9",
"LPT1",
"LPT2",
"LPT3",
"LPT4",
"LPT5",
"LPT6",
"LPT7",
"LPT8",
"LPT9",
)
if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one')
parser.error("invalid directory name given, use a different one")
name = name.translate(_translation_table)
if replace_spaces:
name = name.replace(' ', '-')
name = name.replace(" ", "-")
return Path(name)
def newbot(parser, args):
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
@@ -195,106 +222,114 @@ def newbot(parser, args):
try:
new_directory.mkdir(exist_ok=True, parents=True)
except OSError as exc:
parser.error(f'could not create our bot directory ({exc})')
parser.error(f"could not create our bot directory ({exc})")
cogs = new_directory / 'cogs'
cogs = new_directory / "cogs"
try:
cogs.mkdir(exist_ok=True)
init = cogs / '__init__.py'
init = cogs / "__init__.py"
init.touch()
except OSError as exc:
print(f'warning: could not create cogs directory ({exc})')
print(f"warning: could not create cogs directory ({exc})")
try:
with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp:
with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp:
fp.write('token = "place your token here"\ncogs = []\n')
except OSError as exc:
parser.error(f'could not create config file ({exc})')
parser.error(f"could not create config file ({exc})")
try:
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot'
with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp:
base = "Bot" if not args.sharded else "AutoShardedBot"
fp.write(_bot_template.format(base=base, prefix=args.prefix))
except OSError as exc:
parser.error(f'could not create bot file ({exc})')
parser.error(f"could not create bot file ({exc})")
if not args.no_git:
try:
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp:
fp.write(_gitignore_template)
except OSError as exc:
print(f'warning: could not create .gitignore file ({exc})')
print(f"warning: could not create .gitignore file ({exc})")
print("successfully made bot at", new_directory)
print('successfully made bot at', new_directory)
def newcog(parser, args):
cog_dir = to_path(parser, args.directory)
try:
cog_dir.mkdir(exist_ok=True)
except OSError as exc:
print(f'warning: could not create cogs directory ({exc})')
print(f"warning: could not create cogs directory ({exc})")
directory = cog_dir / to_path(parser, args.name)
directory = directory.with_suffix('.py')
directory = directory.with_suffix(".py")
try:
with open(str(directory), 'w', encoding='utf-8') as fp:
attrs = ''
extra = _cog_extras if args.full else ''
with open(str(directory), "w", encoding="utf-8") as fp:
attrs = ""
extra = _cog_extras if args.full else ""
if args.class_name:
name = args.class_name
else:
name = str(directory.stem)
if '-' in name or '_' in name:
translation = str.maketrans('-_', ' ')
name = name.translate(translation).title().replace(' ', '')
if "-" in name or "_" in name:
translation = str.maketrans("-_", " ")
name = name.translate(translation).title().replace(" ", "")
else:
name = name.title()
if args.display_name:
attrs += f', name="{args.display_name}"'
if args.hide_commands:
attrs += ', command_attrs=dict(hidden=True)'
attrs += ", command_attrs=dict(hidden=True)"
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc:
parser.error(f'could not create cog file ({exc})')
parser.error(f"could not create cog file ({exc})")
else:
print('successfully made cog at', directory)
print("successfully made cog at", directory)
def add_newbot_args(subparser):
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser = subparser.add_parser("newbot", help="creates a command bot project quickly")
parser.set_defaults(func=newbot)
parser.add_argument('name', help='the bot project name')
parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd())
parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='<prefix>')
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true')
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
parser.add_argument("name", help="the bot project name")
parser.add_argument("directory", help="the directory to place it in (default: .)", nargs="?", default=Path.cwd())
parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>")
parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true")
parser.add_argument("--no-git", help="do not create a .gitignore file", action="store_true", dest="no_git")
def add_newcog_args(subparser):
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser = subparser.add_parser("newcog", help="creates a new cog template quickly")
parser.set_defaults(func=newcog)
parser.add_argument('name', help='the cog name')
parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs'))
parser.add_argument('--class-name', help='the class name of the cog (default: <name>)', dest='class_name')
parser.add_argument('--display-name', help='the cog name (default: <name>)')
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true')
parser.add_argument('--full', help='add all special methods as well', action='store_true')
parser.add_argument("name", help="the cog name")
parser.add_argument(
"directory", help="the directory to place it in (default: cogs)", nargs="?", default=Path("cogs")
)
parser.add_argument("--class-name", help="the class name of the cog (default: <name>)", dest="class_name")
parser.add_argument("--display-name", help="the cog name (default: <name>)")
parser.add_argument("--hide-commands", help="whether to hide all commands in the cog", action="store_true")
parser.add_argument("--full", help="add all special methods as well", action="store_true")
def parse_args():
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with discord.py")
parser.add_argument("-v", "--version", action="store_true", help="shows the library version")
parser.set_defaults(func=core)
subparser = parser.add_subparsers(dest='subcommand', title='subcommands')
subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
add_newbot_args(subparser)
add_newcog_args(subparser)
return parser, parser.parse_args()
def main():
parser, args = parse_args()
args.func(parser, args)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem
from . import utils
__all__ = (
'Snowflake',
'User',
'PrivateChannel',
'GuildChannel',
'Messageable',
'Connectable',
"Snowflake",
"User",
"PrivateChannel",
"GuildChannel",
"Messageable",
"Connectable",
)
T = TypeVar('T', bound=VoiceProtocol)
T = TypeVar("T", bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
@@ -98,7 +98,7 @@ MISSING = utils.MISSING
class _Undefined:
def __repr__(self) -> str:
return 'see-below'
return "see-below"
_undefined: Any = _Undefined()
@@ -189,23 +189,23 @@ class PrivateChannel(Snowflake, Protocol):
class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type')
__slots__ = ("id", "allow", "deny", "type")
ROLE = 0
MEMBER = 1
def __init__(self, data: PermissionOverwritePayload):
self.id: int = int(data['id'])
self.allow: int = int(data.get('allow', 0))
self.deny: int = int(data.get('deny', 0))
self.type: OverwriteType = data['type']
self.id: int = int(data["id"])
self.allow: int = int(data.get("allow", 0))
self.deny: int = int(data.get("deny", 0))
self.type: OverwriteType = data["type"]
def _asdict(self) -> PermissionOverwritePayload:
return {
'id': self.id,
'allow': str(self.allow),
'deny': str(self.deny),
'type': self.type,
"id": self.id,
"allow": str(self.allow),
"deny": str(self.deny),
"type": self.type,
}
def is_role(self) -> bool:
@@ -215,7 +215,7 @@ class _Overwrites:
return self.type == 1
GCH = TypeVar('GCH', bound='GuildChannel')
GCH = TypeVar("GCH", bound="GuildChannel")
class GuildChannel:
@@ -276,7 +276,7 @@ class GuildChannel:
reason: Optional[str],
) -> None:
if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.')
raise InvalidArgument("Channel position cannot be less than 0.")
http = self._state.http
bucket = self._sorting_bucket
@@ -297,7 +297,7 @@ class GuildChannel:
payload = []
for index, c in enumerate(channels):
d: Dict[str, Any] = {'id': c.id, 'position': index}
d: Dict[str, Any] = {"id": c.id, "position": index}
if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
@@ -306,81 +306,81 @@ class GuildChannel:
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
try:
parent = options.pop('category')
parent = options.pop("category")
except KeyError:
parent_id = _undefined
else:
parent_id = parent and parent.id
try:
options['rate_limit_per_user'] = options.pop('slowmode_delay')
options["rate_limit_per_user"] = options.pop("slowmode_delay")
except KeyError:
pass
try:
rtc_region = options.pop('rtc_region')
rtc_region = options.pop("rtc_region")
except KeyError:
pass
else:
options['rtc_region'] = None if rtc_region is None else str(rtc_region)
options["rtc_region"] = None if rtc_region is None else str(rtc_region)
try:
video_quality_mode = options.pop('video_quality_mode')
video_quality_mode = options.pop("video_quality_mode")
except KeyError:
pass
else:
options['video_quality_mode'] = int(video_quality_mode)
options["video_quality_mode"] = int(video_quality_mode)
lock_permissions = options.pop('sync_permissions', False)
lock_permissions = options.pop("sync_permissions", False)
try:
position = options.pop('position')
position = options.pop("position")
except KeyError:
if parent_id is not _undefined:
if lock_permissions:
category = self.guild.get_channel(parent_id)
if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
options["parent_id"] = parent_id
elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id)
if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
overwrites = options.get('overwrites', None)
overwrites = options.get("overwrites", None)
if overwrites is not None:
perms = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}')
raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")
allow, deny = perm.pair()
payload = {
'allow': allow.value,
'deny': deny.value,
'id': target.id,
"allow": allow.value,
"deny": deny.value,
"id": target.id,
}
if isinstance(target, Role):
payload['type'] = _Overwrites.ROLE
payload["type"] = _Overwrites.ROLE
else:
payload['type'] = _Overwrites.MEMBER
payload["type"] = _Overwrites.MEMBER
perms.append(payload)
options['permission_overwrites'] = perms
options["permission_overwrites"] = perms
try:
ch_type = options['type']
ch_type = options["type"]
except KeyError:
pass
else:
if not isinstance(ch_type, ChannelType):
raise InvalidArgument('type field must be of type ChannelType')
options['type'] = ch_type.value
raise InvalidArgument("type field must be of type ChannelType")
options["type"] = ch_type.value
if options:
return await self._state.http.edit_channel(self.id, reason=reason, **options)
@@ -390,7 +390,7 @@ class GuildChannel:
everyone_index = 0
everyone_id = self.guild.id
for index, overridden in enumerate(data.get('permission_overwrites', [])):
for index, overridden in enumerate(data.get("permission_overwrites", [])):
overwrite = _Overwrites(overridden)
self._overwrites.append(overwrite)
@@ -429,7 +429,7 @@ class GuildChannel:
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime:
@@ -779,18 +779,18 @@ class GuildChannel:
elif isinstance(target, Role):
perm_type = _Overwrites.ROLE
else:
raise InvalidArgument('target parameter must be either Member or Role')
raise InvalidArgument("target parameter must be either Member or Role")
if overwrite is _undefined:
if len(permissions) == 0:
raise InvalidArgument('No overwrite provided.')
raise InvalidArgument("No overwrite provided.")
try:
overwrite = PermissionOverwrite(**permissions)
except (ValueError, TypeError):
raise InvalidArgument('Invalid permissions given to keyword arguments.')
raise InvalidArgument("Invalid permissions given to keyword arguments.")
else:
if len(permissions) > 0:
raise InvalidArgument('Cannot mix overwrite and keyword arguments.')
raise InvalidArgument("Cannot mix overwrite and keyword arguments.")
# TODO: wait for event
@@ -800,7 +800,7 @@ class GuildChannel:
(allow, deny) = overwrite.pair()
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
else:
raise InvalidArgument('Invalid overwrite type provided.')
raise InvalidArgument("Invalid overwrite type provided.")
async def _clone_impl(
self: GCH,
@@ -809,9 +809,9 @@ class GuildChannel:
name: Optional[str] = None,
reason: Optional[str] = None,
) -> GCH:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name
base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites]
base_attrs["parent_id"] = self.category_id
base_attrs["name"] = name or self.name
guild_id = self.guild.id
cls = self.__class__
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
@@ -964,14 +964,14 @@ class GuildChannel:
if not kwargs:
return
beginning, end = kwargs.get('beginning'), kwargs.get('end')
before, after = kwargs.get('before'), kwargs.get('after')
offset = kwargs.get('offset', 0)
beginning, end = kwargs.get("beginning"), kwargs.get("end")
before, after = kwargs.get("before"), kwargs.get("after")
offset = kwargs.get("offset", 0)
if sum(bool(a) for a in (beginning, end, before, after)) > 1:
raise InvalidArgument('Only one of [before, after, end, beginning] can be used.')
raise InvalidArgument("Only one of [before, after, end, beginning] can be used.")
bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING)
parent_id = kwargs.get("category", MISSING)
# fmt: off
channels: List[GuildChannel]
if parent_id not in (MISSING, None):
@@ -1011,14 +1011,14 @@ class GuildChannel:
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)
if index is None:
raise InvalidArgument('Could not resolve appropriate move position')
raise InvalidArgument("Could not resolve appropriate move position")
channels.insert(max((index + offset), 0), self)
payload = []
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
lock_permissions = kwargs.get("sync_permissions", False)
reason = kwargs.get("reason")
for index, channel in enumerate(channels):
d = {'id': channel.id, 'position': index}
d = {"id": channel.id, "position": index}
if parent_id is not MISSING and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
@@ -1332,14 +1332,14 @@ class Messageable:
content = str(content) if content is not None else None
if embed is not None and embeds is not None:
raise InvalidArgument('cannot pass both embed and embeds parameter to send()')
raise InvalidArgument("cannot pass both embed and embeds parameter to send()")
if embed is not None:
embed = embed.to_dict()
elif embeds is not None:
if len(embeds) > 10:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
raise InvalidArgument("embeds parameter must be a list of up to 10 elements")
embeds = [embed.to_dict() for embed in embeds]
if stickers is not None:
@@ -1355,28 +1355,30 @@ class Messageable:
if mention_author is not None:
allowed_mentions = allowed_mentions or AllowedMentions().to_dict()
allowed_mentions['replied_user'] = bool(mention_author)
allowed_mentions["replied_user"] = bool(mention_author)
if reference is not None:
try:
reference = reference.to_message_reference_dict()
except AttributeError:
raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None
raise InvalidArgument(
"reference parameter must be Message, MessageReference, or PartialMessage"
) from None
if view:
if not hasattr(view, '__discord_ui_view__'):
raise InvalidArgument(f'view parameter must be View not {view.__class__!r}')
if not hasattr(view, "__discord_ui_view__"):
raise InvalidArgument(f"view parameter must be View not {view.__class__!r}")
components = view.to_components()
else:
components = None
if file is not None and files is not None:
raise InvalidArgument('cannot pass both file and files parameter to send()')
raise InvalidArgument("cannot pass both file and files parameter to send()")
if file is not None:
if not isinstance(file, File):
raise InvalidArgument('file parameter must be File')
raise InvalidArgument("file parameter must be File")
try:
data = await state.http.send_files(
@@ -1397,9 +1399,9 @@ class Messageable:
elif files is not None:
if len(files) > 10:
raise InvalidArgument('files parameter must be a list of up to 10 elements')
raise InvalidArgument("files parameter must be a list of up to 10 elements")
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument('files parameter must be a list of File')
raise InvalidArgument("files parameter must be a list of File")
try:
data = await state.http.send_files(
@@ -1666,13 +1668,13 @@ class Connectable(Protocol):
state = self._state
if state._get_voice_client(key_id):
raise ClientException('Already connected to a voice channel.')
raise ClientException("Already connected to a voice channel.")
client = state._get_client()
voice = cls(client, self)
if not isinstance(voice, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class.')
raise TypeError("Type must meet VoiceProtocol abstract base class.")
state._add_voice_client(key_id, voice)

View File

@@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji
from .utils import _get_as_snowflake
__all__ = (
'BaseActivity',
'Activity',
'Streaming',
'Game',
'Spotify',
'CustomActivity',
"BaseActivity",
"Activity",
"Streaming",
"Game",
"Spotify",
"CustomActivity",
)
"""If curious, this is the current schema for an activity.
@@ -119,10 +119,10 @@ class BaseActivity:
.. versionadded:: 1.3
"""
__slots__ = ('_created_at',)
__slots__ = ("_created_at",)
def __init__(self, **kwargs):
self._created_at: Optional[float] = kwargs.pop('created_at', None)
self._created_at: Optional[float] = kwargs.pop("created_at", None)
@property
def created_at(self) -> Optional[datetime.datetime]:
@@ -199,58 +199,58 @@ class Activity(BaseActivity):
"""
__slots__ = (
'state',
'details',
'_created_at',
'timestamps',
'assets',
'party',
'flags',
'sync_id',
'session_id',
'type',
'name',
'url',
'application_id',
'emoji',
'buttons',
"state",
"details",
"_created_at",
"timestamps",
"assets",
"party",
"flags",
"sync_id",
"session_id",
"type",
"name",
"url",
"application_id",
"emoji",
"buttons",
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {})
self.assets: ActivityAssets = kwargs.pop('assets', {})
self.party: ActivityParty = kwargs.pop('party', {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id')
self.name: Optional[str] = kwargs.pop('name', None)
self.url: Optional[str] = kwargs.pop('url', None)
self.flags: int = kwargs.pop('flags', 0)
self.sync_id: Optional[str] = kwargs.pop('sync_id', None)
self.session_id: Optional[str] = kwargs.pop('session_id', None)
self.buttons: List[ActivityButton] = kwargs.pop('buttons', [])
self.state: Optional[str] = kwargs.pop("state", None)
self.details: Optional[str] = kwargs.pop("details", None)
self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {})
self.assets: ActivityAssets = kwargs.pop("assets", {})
self.party: ActivityParty = kwargs.pop("party", {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id")
self.name: Optional[str] = kwargs.pop("name", None)
self.url: Optional[str] = kwargs.pop("url", None)
self.flags: int = kwargs.pop("flags", 0)
self.sync_id: Optional[str] = kwargs.pop("sync_id", None)
self.session_id: Optional[str] = kwargs.pop("session_id", None)
self.buttons: List[ActivityButton] = kwargs.pop("buttons", [])
activity_type = kwargs.pop('type', -1)
activity_type = kwargs.pop("type", -1)
self.type: ActivityType = (
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
)
emoji = kwargs.pop('emoji', None)
emoji = kwargs.pop("emoji", None)
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
def __repr__(self) -> str:
attrs = (
('type', self.type),
('name', self.name),
('url', self.url),
('details', self.details),
('application_id', self.application_id),
('session_id', self.session_id),
('emoji', self.emoji),
("type", self.type),
("name", self.name),
("url", self.url),
("details", self.details),
("application_id", self.application_id),
("session_id", self.session_id),
("emoji", self.emoji),
)
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Activity {inner}>'
inner = " ".join("%s=%r" % t for t in attrs)
return f"<Activity {inner}>"
def to_dict(self) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
@@ -263,16 +263,16 @@ class Activity(BaseActivity):
continue
ret[attr] = value
ret['type'] = int(self.type)
ret["type"] = int(self.type)
if self.emoji:
ret['emoji'] = self.emoji.to_dict()
ret["emoji"] = self.emoji.to_dict()
return ret
@property
def start(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['start'] / 1000
timestamp = self.timestamps["start"] / 1000
except KeyError:
return None
else:
@@ -282,7 +282,7 @@ class Activity(BaseActivity):
def end(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps['end'] / 1000
timestamp = self.timestamps["end"] / 1000
except KeyError:
return None
else:
@@ -295,11 +295,11 @@ class Activity(BaseActivity):
return None
try:
large_image = self.assets['large_image']
large_image = self.assets["large_image"]
except KeyError:
return None
else:
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png'
return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png"
@property
def small_image_url(self) -> Optional[str]:
@@ -308,21 +308,21 @@ class Activity(BaseActivity):
return None
try:
small_image = self.assets['small_image']
small_image = self.assets["small_image"]
except KeyError:
return None
else:
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png'
return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png"
@property
def large_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
return self.assets.get('large_text', None)
return self.assets.get("large_text", None)
@property
def small_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
return self.assets.get('small_text', None)
return self.assets.get("small_text", None)
class Game(BaseActivity):
@@ -359,20 +359,20 @@ class Game(BaseActivity):
The game's name.
"""
__slots__ = ('name', '_end', '_start')
__slots__ = ("name", "_end", "_start")
def __init__(self, name: str, **extra):
super().__init__(**extra)
self.name: str = name
try:
timestamps: ActivityTimestamps = extra['timestamps']
timestamps: ActivityTimestamps = extra["timestamps"]
except KeyError:
self._start = 0
self._end = 0
else:
self._start = timestamps.get('start', 0)
self._end = timestamps.get('end', 0)
self._start = timestamps.get("start", 0)
self._end = timestamps.get("end", 0)
@property
def type(self) -> ActivityType:
@@ -400,15 +400,15 @@ class Game(BaseActivity):
return str(self.name)
def __repr__(self) -> str:
return f'<Game name={self.name!r}>'
return f"<Game name={self.name!r}>"
def to_dict(self) -> Dict[str, Any]:
timestamps: Dict[str, Any] = {}
if self._start:
timestamps['start'] = self._start
timestamps["start"] = self._start
if self._end:
timestamps['end'] = self._end
timestamps["end"] = self._end
# fmt: off
return {
@@ -473,16 +473,16 @@ class Streaming(BaseActivity):
A dictionary comprising of similar keys than those in :attr:`Activity.assets`.
"""
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
__slots__ = ("platform", "name", "game", "url", "details", "assets")
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
super().__init__(**extra)
self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name)
self.game: Optional[str] = extra.pop('state', None)
self.name: Optional[str] = extra.pop("details", name)
self.game: Optional[str] = extra.pop("state", None)
self.url: str = url
self.details: Optional[str] = extra.pop('details', self.name) # compatibility
self.assets: ActivityAssets = extra.pop('assets', {})
self.details: Optional[str] = extra.pop("details", self.name) # compatibility
self.assets: ActivityAssets = extra.pop("assets", {})
@property
def type(self) -> ActivityType:
@@ -496,7 +496,7 @@ class Streaming(BaseActivity):
return str(self.name)
def __repr__(self) -> str:
return f'<Streaming name={self.name!r}>'
return f"<Streaming name={self.name!r}>"
@property
def twitch_name(self):
@@ -507,11 +507,11 @@ class Streaming(BaseActivity):
"""
try:
name = self.assets['large_image']
name = self.assets["large_image"]
except KeyError:
return None
else:
return name[7:] if name[:7] == 'twitch:' else None
return name[7:] if name[:7] == "twitch:" else None
def to_dict(self) -> Dict[str, Any]:
# fmt: off
@@ -523,7 +523,7 @@ class Streaming(BaseActivity):
}
# fmt: on
if self.details:
ret['details'] = self.details
ret["details"] = self.details
return ret
def __eq__(self, other: Any) -> bool:
@@ -559,17 +559,17 @@ class Spotify:
Returns the string 'Spotify'.
"""
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
__slots__ = ("_state", "_details", "_timestamps", "_assets", "_party", "_sync_id", "_session_id", "_created_at")
def __init__(self, **data):
self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id')
self._session_id: str = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None)
self._state: str = data.pop("state", "")
self._details: str = data.pop("details", "")
self._timestamps: Dict[str, int] = data.pop("timestamps", {})
self._assets: ActivityAssets = data.pop("assets", {})
self._party: ActivityParty = data.pop("party", {})
self._sync_id: str = data.pop("sync_id")
self._session_id: str = data.pop("session_id")
self._created_at: Optional[float] = data.pop("created_at", None)
@property
def type(self) -> ActivityType:
@@ -604,21 +604,21 @@ class Spotify:
def to_dict(self) -> Dict[str, Any]:
return {
'flags': 48, # SYNC | PLAY
'name': 'Spotify',
'assets': self._assets,
'party': self._party,
'sync_id': self._sync_id,
'session_id': self._session_id,
'timestamps': self._timestamps,
'details': self._details,
'state': self._state,
"flags": 48, # SYNC | PLAY
"name": "Spotify",
"assets": self._assets,
"party": self._party,
"sync_id": self._sync_id,
"session_id": self._session_id,
"timestamps": self._timestamps,
"details": self._details,
"state": self._state,
}
@property
def name(self) -> str:
""":class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify'
return "Spotify"
def __eq__(self, other: Any) -> bool:
return (
@@ -635,10 +635,10 @@ class Spotify:
return hash(self._session_id)
def __str__(self) -> str:
return 'Spotify'
return "Spotify"
def __repr__(self) -> str:
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>'
return f"<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>"
@property
def title(self) -> str:
@@ -648,7 +648,7 @@ class Spotify:
@property
def artists(self) -> List[str]:
"""List[:class:`str`]: The artists of the song being played."""
return self._state.split('; ')
return self._state.split("; ")
@property
def artist(self) -> str:
@@ -662,16 +662,16 @@ class Spotify:
@property
def album(self) -> str:
""":class:`str`: The album that the song being played belongs to."""
return self._assets.get('large_text', '')
return self._assets.get("large_text", "")
@property
def album_cover_url(self) -> str:
""":class:`str`: The album cover image URL from Spotify's CDN."""
large_image = self._assets.get('large_image', '')
if large_image[:8] != 'spotify:':
return ''
large_image = self._assets.get("large_image", "")
if large_image[:8] != "spotify:":
return ""
album_image_id = large_image[8:]
return 'https://i.scdn.co/image/' + album_image_id
return "https://i.scdn.co/image/" + album_image_id
@property
def track_id(self) -> str:
@@ -684,17 +684,17 @@ class Spotify:
.. versionadded:: 2.0
"""
return f'https://open.spotify.com/track/{self.track_id}'
return f"https://open.spotify.com/track/{self.track_id}"
@property
def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc)
@property
def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc)
@property
def duration(self) -> datetime.timedelta:
@@ -704,7 +704,7 @@ class Spotify:
@property
def party_id(self) -> str:
""":class:`str`: The party ID of the listening party."""
return self._party.get('id', '')
return self._party.get("id", "")
class CustomActivity(BaseActivity):
@@ -738,13 +738,13 @@ class CustomActivity(BaseActivity):
The emoji to pass to the activity, if any.
"""
__slots__ = ('name', 'emoji', 'state')
__slots__ = ("name", "emoji", "state")
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
super().__init__(**extra)
self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None)
if self.name == 'Custom Status':
self.state: Optional[str] = extra.pop("state", None)
if self.name == "Custom Status":
self.name = self.state
self.emoji: Optional[PartialEmoji]
@@ -757,7 +757,7 @@ class CustomActivity(BaseActivity):
elif isinstance(emoji, PartialEmoji):
self.emoji = emoji
else:
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.')
raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.")
@property
def type(self) -> ActivityType:
@@ -770,18 +770,18 @@ class CustomActivity(BaseActivity):
def to_dict(self) -> Dict[str, Any]:
if self.name == self.state:
o = {
'type': ActivityType.custom.value,
'state': self.name,
'name': 'Custom Status',
"type": ActivityType.custom.value,
"state": self.name,
"name": "Custom Status",
}
else:
o = {
'type': ActivityType.custom.value,
'name': self.name,
"type": ActivityType.custom.value,
"name": self.name,
}
if self.emoji:
o['emoji'] = self.emoji.to_dict()
o["emoji"] = self.emoji.to_dict()
return o
def __eq__(self, other: Any) -> bool:
@@ -796,47 +796,50 @@ class CustomActivity(BaseActivity):
def __str__(self) -> str:
if self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return f"{self.emoji} {self.name}"
return str(self.emoji)
else:
return str(self.name)
def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
return f"<CustomActivity name={self.name!r} emoji={self.emoji!r}>"
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data:
return None
game_type = try_enum(ActivityType, data.get('type', -1))
game_type = try_enum(ActivityType, data.get("type", -1))
if game_type is ActivityType.playing:
if 'application_id' in data or 'session_id' in data:
if "application_id" in data or "session_id" in data:
return Activity(**data)
return Game(**data)
elif game_type is ActivityType.custom:
try:
name = data.pop('name')
name = data.pop("name")
except KeyError:
return Activity(**data)
else:
# we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming:
if 'url' in data:
if "url" in data:
# the url won't be None here
return Streaming(**data) # type: ignore
return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data:
return Spotify(**data)
return Activity(**data)

View File

@@ -40,8 +40,8 @@ if TYPE_CHECKING:
from .state import ConnectionState
__all__ = (
'AppInfo',
'PartialAppInfo',
"AppInfo",
"PartialAppInfo",
)
@@ -115,58 +115,58 @@ class AppInfo:
"""
__slots__ = (
'_state',
'description',
'id',
'name',
'rpc_origins',
'bot_public',
'bot_require_code_grant',
'owner',
'_icon',
'summary',
'verify_key',
'team',
'guild_id',
'primary_sku_id',
'slug',
'_cover_image',
'terms_of_service_url',
'privacy_policy_url',
"_state",
"description",
"id",
"name",
"rpc_origins",
"bot_public",
"bot_require_code_grant",
"owner",
"_icon",
"summary",
"verify_key",
"team",
"guild_id",
"primary_sku_id",
"slug",
"_cover_image",
"terms_of_service_url",
"privacy_policy_url",
)
def __init__(self, state: ConnectionState, data: AppInfoPayload):
from .team import Team
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self._icon: Optional[str] = data['icon']
self.rpc_origins: List[str] = data['rpc_origins']
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.create_user(data['owner'])
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self._icon: Optional[str] = data["icon"]
self.rpc_origins: List[str] = data["rpc_origins"]
self.bot_public: bool = data["bot_public"]
self.bot_require_code_grant: bool = data["bot_require_code_grant"]
self.owner: User = state.create_user(data["owner"])
team: Optional[TeamPayload] = data.get('team')
team: Optional[TeamPayload] = data.get("team")
self.team: Optional[Team] = Team(state, team) if team else None
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id')
self.slug: Optional[str] = data.get('slug')
self._cover_image: Optional[str] = data.get('cover_image')
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, "primary_sku_id")
self.slug: Optional[str] = data.get("slug")
self._cover_image: Optional[str] = data.get("cover_image")
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'description={self.description!r} public={self.bot_public} '
f'owner={self.owner!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"description={self.description!r} public={self.bot_public} "
f"owner={self.owner!r}>"
)
@property
@@ -174,7 +174,7 @@ class AppInfo:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
return Asset._from_icon(self._state, self.id, self._icon, path="app")
@property
def cover_image(self) -> Optional[Asset]:
@@ -195,6 +195,7 @@ class AppInfo:
"""
return self._state._get_guild(self.guild_id)
class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
@@ -222,26 +223,37 @@ class PartialAppInfo:
The application's privacy policy URL, if set.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon')
__slots__ = (
"_state",
"id",
"name",
"description",
"rpc_origins",
"summary",
"verify_key",
"terms_of_service_url",
"privacy_policy_url",
"_icon",
)
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data.get('icon')
self.description: str = data['description']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data.get("icon")
self.description: str = data["description"]
self.rpc_origins: Optional[List[str]] = data.get("rpc_origins")
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>"
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
return Asset._from_icon(self._state, self.id, self._icon, path="app")

View File

@@ -33,13 +33,11 @@ from . import utils
import yarl
__all__ = (
'Asset',
)
__all__ = ("Asset",)
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
@@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
MISSING = utils.MISSING
class AssetMixin:
url: str
_state: Optional[Any]
@@ -71,7 +70,7 @@ class AssetMixin:
The content of the asset.
"""
if self._state is None:
raise DiscordException('Invalid state (no ConnectionState provided)')
raise DiscordException("Invalid state (no ConnectionState provided)")
return await self._state.http.get_from_cdn(self.url)
@@ -112,7 +111,7 @@ class AssetMixin:
fp.seek(0)
return written
else:
with open(fp, 'wb') as f:
with open(fp, "wb") as f:
return f.write(data)
@@ -143,13 +142,13 @@ class Asset(AssetMixin):
"""
__slots__: Tuple[str, ...] = (
'_state',
'_url',
'_animated',
'_key',
"_state",
"_url",
"_animated",
"_key",
)
BASE = 'https://cdn.discordapp.com'
BASE = "https://cdn.discordapp.com"
def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
@@ -161,26 +160,26 @@ class Asset(AssetMixin):
def _from_default_avatar(cls, state, index: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/embed/avatars/{index}.png',
url=f"{cls.BASE}/embed/avatars/{index}.png",
key=str(index),
animated=False,
)
@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024',
url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024",
key=avatar,
animated=animated,
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
@@ -192,7 +191,7 @@ class Asset(AssetMixin):
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
key=icon_hash,
animated=False,
)
@@ -201,7 +200,7 @@ class Asset(AssetMixin):
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
key=cover_image_hash,
animated=False,
)
@@ -210,18 +209,18 @@ class Asset(AssetMixin):
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
return cls(
state,
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024',
url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
key=image,
animated=False,
)
@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png'
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024',
url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024",
key=icon_hash,
animated=animated,
)
@@ -230,20 +229,20 @@ class Asset(AssetMixin):
def _from_sticker_banner(cls, state, banner: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
key=str(banner),
animated=False,
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
animated = banner_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512",
key=banner_hash,
animated=animated
animated=animated,
)
def __str__(self) -> str:
@@ -253,8 +252,8 @@ class Asset(AssetMixin):
return len(self._url)
def __repr__(self):
shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>'
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"
def __eq__(self, other):
return isinstance(other, Asset) and self._url == other._url
@@ -312,21 +311,21 @@ class Asset(AssetMixin):
if format is not MISSING:
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
url = url.with_path(f'{path}.{format}')
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
url = url.with_path(f"{path}.{format}")
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f"{path}.{format}")
if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{static_format}')
raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f"{path}.{static_format}")
if size is not MISSING:
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = url.with_query(size=size)
else:
url = url.with_query(url.raw_query_string)
@@ -353,7 +352,7 @@ class Asset(AssetMixin):
The new updated asset.
"""
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
@@ -379,14 +378,14 @@ class Asset(AssetMixin):
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
else:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:

View File

@@ -35,9 +35,9 @@ from .object import Object
from .permissions import PermissionOverwrite, Permissions
__all__ = (
'AuditLogDiff',
'AuditLogChanges',
'AuditLogEntry',
"AuditLogDiff",
"AuditLogChanges",
"AuditLogEntry",
)
@@ -85,6 +85,7 @@ def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Uni
return None
return entry._get_member(int(data))
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
if data is None:
return None
@@ -96,16 +97,16 @@ def _transform_overwrites(
) -> List[Tuple[Object, PermissionOverwrite]]:
overwrites = []
for elem in data:
allow = Permissions(int(elem['allow']))
deny = Permissions(int(elem['deny']))
allow = Permissions(int(elem["allow"]))
deny = Permissions(int(elem["deny"]))
ow = PermissionOverwrite.from_pair(allow, deny)
ow_type = elem['type']
ow_id = int(elem['id'])
ow_type = elem["type"]
ow_id = int(elem["id"])
target = None
if ow_type == '0':
if ow_type == "0":
target = entry.guild.get_role(ow_id)
elif ow_type == '1':
elif ow_type == "1":
target = entry._get_member(ow_id)
if target is None:
@@ -137,7 +138,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
return _transform
T = TypeVar('T', bound=enums.Enum)
T = TypeVar("T", bound=enums.Enum)
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
@@ -146,12 +147,14 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'):
if entry.action.name.startswith("sticker_"):
return enums.try_enum(enums.StickerType, data)
else:
return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff:
def __len__(self) -> int:
return len(self.__dict__)
@@ -160,8 +163,8 @@ class AuditLogDiff:
yield from self.__dict__.items()
def __repr__(self) -> str:
values = ' '.join('%s=%r' % item for item in self.__dict__.items())
return f'<AuditLogDiff {values}>'
values = " ".join("%s=%r" % item for item in self.__dict__.items())
return f"<AuditLogDiff {values}>"
if TYPE_CHECKING:
@@ -217,14 +220,14 @@ class AuditLogChanges:
self.after = AuditLogDiff()
for elem in data:
attr = elem['key']
attr = elem["key"]
# special cases for role add/remove
if attr == '$add':
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore
if attr == "$add":
self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore
continue
elif attr == '$remove':
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore
elif attr == "$remove":
self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore
continue
try:
@@ -238,7 +241,7 @@ class AuditLogChanges:
transformer: Optional[Transformer]
try:
before = elem['old_value']
before = elem["old_value"]
except KeyError:
before = None
else:
@@ -248,7 +251,7 @@ class AuditLogChanges:
setattr(self.before, attr, before)
try:
after = elem['new_value']
after = elem["new_value"]
except KeyError:
after = None
else:
@@ -258,34 +261,36 @@ class AuditLogChanges:
setattr(self.after, attr, after)
# add an alias
if hasattr(self.after, 'colour'):
if hasattr(self.after, "colour"):
self.after.color = self.after.colour
self.before.color = self.before.colour
if hasattr(self.after, 'expire_behavior'):
if hasattr(self.after, "expire_behavior"):
self.after.expire_behaviour = self.after.expire_behavior
self.before.expire_behaviour = self.before.expire_behavior
def __repr__(self) -> str:
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>'
return f"<AuditLogChanges before={self.before!r} after={self.after!r}>"
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None:
if not hasattr(first, 'roles'):
setattr(first, 'roles', [])
def _handle_role(
self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]
) -> None:
if not hasattr(first, "roles"):
setattr(first, "roles", [])
data = []
g: Guild = entry.guild # type: ignore
for e in elem:
role_id = int(e['id'])
role_id = int(e["id"])
role = g.get_role(role_id)
if role is None:
role = Object(id=role_id)
role.name = e['name'] # type: ignore
role.name = e["name"] # type: ignore
data.append(role)
setattr(second, 'roles', data)
setattr(second, "roles", data)
class _AuditLogProxyMemberPrune:
@@ -365,56 +370,56 @@ class AuditLogEntry(Hashable):
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id'])
self.action = enums.try_enum(enums.AuditLogAction, data["action_type"])
self.id = int(data["id"])
# this key is technically not usually present
self.reason = data.get('reason')
self.extra = data.get('options')
self.reason = data.get("reason")
self.extra = data.get("options")
if isinstance(self.action, enums.AuditLogAction) and self.extra:
if self.action is enums.AuditLogAction.member_prune:
# member prune has two keys with useful information
self.extra: _AuditLogProxyMemberPrune = type(
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()}
"_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()}
)()
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
channel_id = int(self.extra['channel_id'])
channel_id = int(self.extra["channel_id"])
elems = {
'count': int(self.extra['count']),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
"count": int(self.extra["count"]),
"channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
}
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)()
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)()
elif self.action is enums.AuditLogAction.member_disconnect:
# The member disconnect action has a dict with some information
elems = {
'count': int(self.extra['count']),
"count": int(self.extra["count"]),
}
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)()
elif self.action.name.endswith('pin'):
self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)()
elif self.action.name.endswith("pin"):
# the pin actions have a dict with some information
channel_id = int(self.extra['channel_id'])
channel_id = int(self.extra["channel_id"])
elems = {
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
'message_id': int(self.extra['message_id']),
"channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
"message_id": int(self.extra["message_id"]),
}
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith('overwrite_'):
self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)()
elif self.action.name.startswith("overwrite_"):
# the overwrite_ actions have a dict with some information
instance_id = int(self.extra['id'])
the_type = self.extra.get('type')
if the_type == '1':
instance_id = int(self.extra["id"])
the_type = self.extra.get("type")
if the_type == "1":
self.extra = self._get_member(instance_id)
elif the_type == '0':
elif the_type == "0":
role = self.guild.get_role(instance_id)
if role is None:
role = Object(id=instance_id)
role.name = self.extra.get('role_name') # type: ignore
role.name = self.extra.get("role_name") # type: ignore
self.extra: Role = role
elif self.action.name.startswith('stage_instance'):
channel_id = int(self.extra['channel_id'])
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith("stage_instance"):
channel_id = int(self.extra["channel_id"])
elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)()
# fmt: off
self.extra: Union[
@@ -433,16 +438,16 @@ class AuditLogEntry(Hashable):
# where new_value and old_value are not guaranteed to be there depending
# on the action type, so let's just fetch it for now and only turn it
# into meaningful data when requested
self._changes = data.get('changes', [])
self._changes = data.get("changes", [])
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore
self._target_id = utils._get_as_snowflake(data, 'target_id')
self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
self._target_id = utils._get_as_snowflake(data, "target_id")
def _get_member(self, user_id: int) -> Union[Member, User, None]:
return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str:
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>'
return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
@utils.cached_property
def created_at(self) -> datetime.datetime:
@@ -450,9 +455,13 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id)
@utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]:
def target(
self,
) -> Union[
Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None
]:
try:
converter = getattr(self, '_convert_target_' + self.action.target_type)
converter = getattr(self, "_convert_target_" + self.action.target_type)
except AttributeError:
return Object(id=self._target_id)
else:
@@ -498,11 +507,11 @@ class AuditLogEntry(Hashable):
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
fake_payload = {
'max_age': changeset.max_age,
'max_uses': changeset.max_uses,
'code': changeset.code,
'temporary': changeset.temporary,
'uses': changeset.uses,
"max_age": changeset.max_age,
"max_uses": changeset.max_uses,
"code": changeset.code,
"temporary": changeset.temporary,
"uses": changeset.uses,
}
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore

View File

@@ -29,11 +29,10 @@ import time
import random
from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar('T', bool, Literal[True], Literal[False])
T = TypeVar("T", bool, Literal[True], Literal[False])
__all__ = ("ExponentialBackoff",)
__all__ = (
'ExponentialBackoff',
)
class ExponentialBackoff(Generic[T]):
"""An implementation of the exponential backoff algorithm

View File

@@ -57,14 +57,14 @@ from .threads import Thread
from .iterators import ArchivedThreadIterator
__all__ = (
'TextChannel',
'VoiceChannel',
'StageChannel',
'DMChannel',
'CategoryChannel',
'StoreChannel',
'GroupChannel',
'PartialMessageable',
"TextChannel",
"VoiceChannel",
"StageChannel",
"DMChannel",
"CategoryChannel",
"StoreChannel",
"GroupChannel",
"PartialMessageable",
)
if TYPE_CHECKING:
@@ -155,51 +155,51 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'topic',
'_state',
'nsfw',
'category_id',
'position',
'slowmode_delay',
'_overwrites',
'_type',
'last_message_id',
'default_auto_archive_duration',
"name",
"id",
"guild",
"topic",
"_state",
"nsfw",
"category_id",
"position",
"slowmode_delay",
"_overwrites",
"_type",
"last_message_id",
"default_auto_archive_duration",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self._type: int = data['type']
self.id: int = int(data["id"])
self._type: int = data["type"]
self._update(guild, data)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('position', self.position),
('nsfw', self.nsfw),
('news', self.is_news()),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("position", self.position),
("nsfw", self.nsfw),
("news", self.is_news()),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: TextChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.topic: Optional[str] = data.get('topic')
self.position: int = data['position']
self.nsfw: bool = data.get('nsfw', False)
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.topic: Optional[str] = data.get("topic")
self.position: int = data["position"]
self.nsfw: bool = data.get("nsfw", False)
# Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440)
self._type: int = data.get('type', self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self.slowmode_delay: int = data.get("rate_limit_per_user", 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440)
self._type: int = data.get("type", self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id")
self._fill_overwrites(data)
async def _get_channel(self):
@@ -228,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""List[:class:`Member`]: Returns all members that can see this channel."""
return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: Returns all bots that can see this channel."""
return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages]
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: Returns all human members that can see this channel."""
return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages]
@property
def threads(self) -> List[Thread]:
"""List[:class:`Thread`]: Returns all the threads that you can see.
@@ -361,7 +371,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
return await self._clone_impl(
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
{"topic": self.topic, "nsfw": self.nsfw, "rate_limit_per_user": self.slowmode_delay},
name=name,
reason=reason,
)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@@ -408,7 +420,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
@@ -548,7 +560,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
data = await self._state.http.channel_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]
async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook:
async def create_webhook(
self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None
) -> Webhook:
"""|coro|
Creates a webhook for this channel.
@@ -625,10 +639,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
if not self.is_news():
raise ClientException('The channel must be a news channel.')
raise ClientException("The channel must be a news channel.")
if not isinstance(destination, TextChannel):
raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}')
raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}")
from .webhook import Webhook
@@ -792,40 +806,40 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
__slots__ = (
'name',
'id',
'guild',
'bitrate',
'user_limit',
'_state',
'position',
'_overwrites',
'category_id',
'rtc_region',
'video_quality_mode',
"name",
"id",
"guild",
"bitrate",
"user_limit",
"_state",
"position",
"_overwrites",
"category_id",
"rtc_region",
"video_quality_mode",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def _get_voice_client_key(self) -> Tuple[int, str]:
return self.guild.id, 'guild_id'
return self.guild.id, "guild_id"
def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild
self.name: str = data['name']
rtc = data.get('rtc_region')
self.name: str = data["name"]
rtc = data.get("rtc_region")
self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.position: int = data['position']
self.bitrate: int = data.get('bitrate')
self.user_limit: int = data.get('user_limit')
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1))
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data["position"]
self.bitrate: int = data.get("bitrate")
self.user_limit: int = data.get("user_limit")
self._fill_overwrites(data)
@property
@@ -933,17 +947,17 @@ class VoiceChannel(VocalGuildChannel):
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("rtc_region", self.rtc_region),
("position", self.position),
("bitrate", self.bitrate),
("video_quality_mode", self.video_quality_mode),
("user_limit", self.user_limit),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
@property
def type(self) -> ChannelType:
@@ -952,7 +966,9 @@ class VoiceChannel(VocalGuildChannel):
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason)
return await self._clone_impl(
{"bitrate": self.bitrate, "user_limit": self.user_limit}, name=name, reason=reason
)
@overload
async def edit(
@@ -1093,26 +1109,26 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
__slots__ = ('topic',)
__slots__ = ("topic",)
def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('topic', self.topic),
('rtc_region', self.rtc_region),
('position', self.position),
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
('category_id', self.category_id),
("id", self.id),
("name", self.name),
("topic", self.topic),
("rtc_region", self.rtc_region),
("position", self.position),
("bitrate", self.bitrate),
("video_quality_mode", self.video_quality_mode),
("user_limit", self.user_limit),
("category_id", self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
joined = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data)
self.topic = data.get('topic')
self.topic = data.get("topic")
@property
def requesting_to_speak(self) -> List[Member]:
@@ -1201,13 +1217,13 @@ class StageChannel(VocalGuildChannel):
The newly created stage instance.
"""
payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic}
payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic}
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
raise InvalidArgument("privacy_level field must be of type PrivacyLevel")
payload['privacy_level'] = privacy_level.value
payload["privacy_level"] = privacy_level.value
data = await self._state.http.create_stage_instance(**payload, reason=reason)
return StageInstance(guild=self.guild, state=self._state, data=data)
@@ -1361,22 +1377,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
"""
__slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id')
__slots__ = ("name", "id", "guild", "nsfw", "_state", "position", "_overwrites", "category_id")
def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def __repr__(self) -> str:
return f'<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
return f"<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.nsfw: bool = data.get('nsfw', False)
self.position: int = data['position']
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.nsfw: bool = data.get("nsfw", False)
self.position: int = data["position"]
self._fill_overwrites(data)
@property
@@ -1394,7 +1410,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload
async def edit(
@@ -1463,7 +1479,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs):
kwargs.pop('category', None)
kwargs.pop("category", None)
await super().move(**kwargs)
@property
@@ -1590,30 +1606,30 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'_state',
'nsfw',
'category_id',
'position',
'_overwrites',
"name",
"id",
"guild",
"_state",
"nsfw",
"category_id",
"position",
"_overwrites",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(guild, data)
def __repr__(self) -> str:
return f'<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
return f"<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: StoreChannelPayload) -> None:
self.guild: Guild = guild
self.name: str = data['name']
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
self.position: int = data['position']
self.nsfw: bool = data.get('nsfw', False)
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data["position"]
self.nsfw: bool = data.get("nsfw", False)
self._fill_overwrites(data)
@property
@@ -1640,7 +1656,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload
async def edit(
@@ -1716,7 +1732,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
DMC = TypeVar('DMC', bound='DMChannel')
DMC = TypeVar("DMC", bound="DMChannel")
class DMChannel(discord.abc.Messageable, Hashable):
@@ -1756,24 +1772,24 @@ class DMChannel(discord.abc.Messageable, Hashable):
The direct message channel ID.
"""
__slots__ = ('id', 'recipient', 'me', '_state')
__slots__ = ("id", "recipient", "me", "_state")
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
self._state: ConnectionState = state
self.recipient: Optional[User] = state.store_user(data['recipients'][0])
self.recipient: Optional[User] = state.store_user(data["recipients"][0])
self.me: ClientUser = me
self.id: int = int(data['id'])
self.id: int = int(data["id"])
async def _get_channel(self):
return self
def __str__(self) -> str:
if self.recipient:
return f'Direct Message with {self.recipient}'
return 'Direct Message with Unknown User'
return f"Direct Message with {self.recipient}"
return "Direct Message with Unknown User"
def __repr__(self) -> str:
return f'<DMChannel id={self.id} recipient={self.recipient!r}>'
return f"<DMChannel id={self.id} recipient={self.recipient!r}>"
@classmethod
def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC:
@@ -1892,19 +1908,19 @@ class GroupChannel(discord.abc.Messageable, Hashable):
The group channel's name if provided.
"""
__slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state')
__slots__ = ("id", "recipients", "owner_id", "owner", "_icon", "name", "me", "_state")
def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self.me: ClientUser = me
self._update_group(data)
def _update_group(self, data: GroupChannelPayload) -> None:
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id')
self._icon: Optional[str] = data.get('icon')
self.name: Optional[str] = data.get('name')
self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id")
self._icon: Optional[str] = data.get("icon")
self.name: Optional[str] = data.get("name")
self.recipients: List[User] = [self._state.store_user(u) for u in data.get("recipients", [])]
self.owner: Optional[BaseUser]
if self.owner_id == self.me.id:
@@ -1920,12 +1936,12 @@ class GroupChannel(discord.abc.Messageable, Hashable):
return self.name
if len(self.recipients) == 0:
return 'Unnamed'
return "Unnamed"
return ', '.join(map(lambda x: x.name, self.recipients))
return ", ".join(map(lambda x: x.name, self.recipients))
def __repr__(self) -> str:
return f'<GroupChannel id={self.id} name={self.name!r}>'
return f"<GroupChannel id={self.id} name={self.name!r}>"
@property
def type(self) -> ChannelType:
@@ -1937,7 +1953,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
"""Optional[:class:`Asset`]: Returns the channel's icon asset if available."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='channel')
return Asset._from_icon(self._state, self.id, self._icon, path="channel")
@property
def created_at(self) -> datetime.datetime:

View File

@@ -29,7 +29,20 @@ import logging
import signal
import sys
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
)
import aiohttp
@@ -69,46 +82,49 @@ if TYPE_CHECKING:
from .member import Member
from .voice_client import VoiceProtocol
__all__ = (
'Client',
)
__all__ = ("Client",)
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
_log.info("Cleaning up after %d tasks.", len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
_log.info("All tasks finished cancelling.")
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task
})
loop.call_exception_handler(
{
"message": "Unhandled exception during Client.run shutdown.",
"exception": task.exception(),
"task": task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
_log.info("Closing the event loop.")
loop.close()
class Client:
r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
@@ -199,6 +215,7 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
@@ -212,24 +229,22 @@ class Client:
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id')
self.shard_count: Optional[int] = options.get('shard_count')
self.shard_id: Optional[int] = options.get("shard_id")
self.shard_count: Optional[int] = options.get("shard_count")
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop)
connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None)
proxy: Optional[str] = options.pop("proxy", None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None)
unsync_clock: bool = options.pop("assume_unsync_clock", True)
self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
)
self._handlers: Dict[str, Callable] = {
'ready': self._handle_ready
}
self._handlers: Dict[str, Callable] = {"ready": self._handle_ready}
self._hooks: Dict[str, Callable] = {
'before_identify': self._call_before_identify_hook
}
self._hooks: Dict[str, Callable] = {"before_identify": self._call_before_identify_hook}
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._enable_debug_events: bool = options.pop("enable_debug_events", False)
self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
@@ -247,8 +262,14 @@ class Client:
return self.ws
def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
return ConnectionState(
dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
def _handle_ready(self) -> None:
self._ready.set()
@@ -260,7 +281,7 @@ class Client:
This could be referred to as the Discord WebSocket protocol latency.
"""
ws = self.ws
return float('nan') if not ws else ws.latency
return float("nan") if not ws else ws.latency
def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited.
@@ -348,7 +369,9 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None:
async def _run_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> None:
try:
await coro(*args, **kwargs)
except asyncio.CancelledError:
@@ -359,14 +382,16 @@ class Client:
except asyncio.CancelledError:
pass
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task:
def _schedule_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
return asyncio.create_task(wrapped, name=f"discord.py: {event_name}")
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s', event)
method = 'on_' + event
_log.debug("Dispatching event %s", event)
method = "on_" + event
listeners = self._listeners.get(event)
if listeners:
@@ -413,7 +438,7 @@ class Client:
overridden to have a different implementation.
Check :func:`~discord.on_error` for more details.
"""
print(f'Ignoring exception in {event_method}', file=sys.stderr)
print(f"Ignoring exception in {event_method}", file=sys.stderr)
traceback.print_exc()
# hooks
@@ -470,7 +495,7 @@ class Client:
passing status code.
"""
_log.info('logging in using static token')
_log.info("logging in using static token")
data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)
@@ -502,29 +527,31 @@ class Client:
backoff = ExponentialBackoff()
ws_params = {
'initial': True,
'shard_id': self.shard_id,
"initial": True,
"shard_id": self.shard_id,
}
while not self.is_closed():
try:
coro = DiscordWebSocket.from_client(self, **ws_params)
self.ws = await asyncio.wait_for(coro, timeout=60.0)
ws_params['initial'] = False
ws_params["initial"] = False
while True:
await self.ws.poll_event()
except ReconnectWebSocket as e:
_log.info('Got a request to %s the websocket.', e.op)
self.dispatch('disconnect')
_log.info("Got a request to %s the websocket.", e.op)
self.dispatch("disconnect")
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue
except (OSError,
except (
OSError,
HTTPException,
GatewayNotFound,
ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError) as exc:
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect')
self.dispatch("disconnect")
if not reconnect:
await self.close()
if isinstance(exc, ConnectionClosed) and exc.code == 1000:
@@ -597,7 +624,7 @@ class Client:
async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro|
A shorthand coroutine for :meth:`login` + :meth:`connect`.
A shorthand coroutine for :meth:`login` + :meth:`setup` + :meth:`connect`.
Raises
-------
@@ -605,8 +632,21 @@ class Client:
An unexpected keyword argument was received.
"""
await self.login(token)
await self.setup()
await self.connect(reconnect=reconnect)
async def setup(self) -> Any:
"""|coro|
A coroutine to be called to setup the bot, by default this is blank.
To perform asynchronous setup after the bot is logged in but before
it has connected to the Websocket, overwrite this coroutine.
.. versionadded:: 2.0
"""
pass
def run(self, *args: Any, **kwargs: Any) -> None:
"""A blocking call that abstracts away the event loop
initialisation from you.
@@ -654,10 +694,10 @@ class Client:
try:
loop.run_forever()
except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.')
_log.info("Received signal to terminate bot and event loop.")
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.')
_log.info("Cleaning up tasks.")
_cleanup_loop(loop)
if not future.cancelled():
@@ -688,14 +728,14 @@ class Client:
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
self._connection._activity = value.to_dict() # type: ignore
else:
raise TypeError('activity must derive from BaseActivity.')
raise TypeError("activity must derive from BaseActivity.")
@property
def status(self):
""":class:`.Status`:
The status being used upon logging on to Discord.
.. versionadded: 2.0
.. versionadded:: 2.0
"""
if self._connection._status in set(state.value for state in Status):
return Status(self._connection._status)
@@ -704,11 +744,11 @@ class Client:
@status.setter
def status(self, value):
if value is Status.offline:
self._connection._status = 'invisible'
self._connection._status = "invisible"
elif isinstance(value, Status):
self._connection._status = str(value)
else:
raise TypeError('status must derive from Status.')
raise TypeError("status must derive from Status.")
@property
def allowed_mentions(self) -> Optional[AllowedMentions]:
@@ -723,7 +763,7 @@ class Client:
if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
else:
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}')
raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}")
@property
def intents(self) -> Intents:
@@ -829,6 +869,38 @@ class Client:
"""
return self._connection.get_user(id)
async def try_user(self, id: int, /) -> Optional[User]:
"""|coro|
Returns a user with the given ID. If not from cache, the user will be requested from the API.
You do not have to share any guilds with the user to get this information from the API,
however many operations do require that you do.
.. note::
This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead.
.. versionadded:: 2.0
Parameters
-----------
id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""
maybe_user = self.get_user(id)
if maybe_user is not None:
return maybe_user
try:
return await self.fetch_user(id)
except NotFound:
return None
def get_emoji(self, id: int, /) -> Optional[Emoji]:
"""Returns an emoji with the given ID.
@@ -1001,8 +1073,10 @@ class Client:
future = self.loop.create_future()
if check is None:
def _check(*args):
return True
check = _check
ev = event.lower()
@@ -1040,10 +1114,10 @@ class Client:
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('event registered must be a coroutine function')
raise TypeError("event registered must be a coroutine function")
setattr(self, coro.__name__, coro)
_log.debug('%s has successfully been registered as an event', coro.__name__)
_log.debug("%s has successfully been registered as an event", coro.__name__)
return coro
async def change_presence(
@@ -1082,10 +1156,10 @@ class Client:
"""
if status is None:
status_str = 'online'
status_str = "online"
status = Status.online
elif status is Status.offline:
status_str = 'invisible'
status_str = "invisible"
status = Status.offline
else:
status_str = str(status)
@@ -1107,11 +1181,7 @@ class Client:
# Guild stuff
def fetch_guilds(
self,
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
self, *, limit: Optional[int] = 100, before: SnowflakeTime = None, after: SnowflakeTime = None
) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@@ -1307,12 +1377,14 @@ class Client:
The stage instance from the stage channel ID.
"""
data = await self.http.get_stage_instance(channel_id)
guild = self.get_guild(int(data['guild_id']))
guild = self.get_guild(int(data["guild_id"]))
return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore
# Invite management
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite:
async def fetch_invite(
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
"""|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID.
@@ -1428,8 +1500,8 @@ class Client:
The bot's application information.
"""
data = await self.http.application_info()
if 'rpc_origins' not in data:
data['rpc_origins'] = None
if "rpc_origins" not in data:
data["rpc_origins"] = None
return AppInfo(self._connection, data)
async def fetch_user(self, user_id: int, /) -> User:
@@ -1492,16 +1564,16 @@ class Client:
"""
data = await self.http.get_channel(channel_id)
factory, ch_type = _threaded_channel_factory(data['type'])
factory, ch_type = _threaded_channel_factory(data["type"])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
else:
# the factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore
guild_id = int(data["guild_id"]) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id)
# GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
@@ -1550,7 +1622,7 @@ class Client:
The sticker you requested.
"""
data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore
cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
@@ -1571,7 +1643,7 @@ class Client:
All available premium sticker packs.
"""
data = await self.http.list_premium_sticker_packs()
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']]
return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]]
async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
@@ -1628,10 +1700,10 @@ class Client:
"""
if not isinstance(view, View):
raise TypeError(f'expected an instance of View not {view.__class__!r}')
raise TypeError(f"expected an instance of View not {view.__class__!r}")
if not view.is_persistent():
raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout')
raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout")
self._connection.store_view(view, message_id)

View File

@@ -35,11 +35,11 @@ from typing import (
)
__all__ = (
'Colour',
'Color',
"Colour",
"Color",
)
CT = TypeVar('CT', bound='Colour')
CT = TypeVar("CT", bound="Colour")
class Colour:
@@ -76,16 +76,16 @@ class Colour:
The raw integer colour value.
"""
__slots__ = ('value',)
__slots__ = ("value",)
def __init__(self, value: int):
if not isinstance(value, int):
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.")
self.value: int = value
def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xff
return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool:
return isinstance(other, Colour) and self.value == other.value
@@ -94,13 +94,13 @@ class Colour:
return not self.__eq__(other)
def __str__(self) -> str:
return f'#{self.value:0>6x}'
return f"#{self.value:0>6x}"
def __int__(self) -> int:
return self.value
def __repr__(self) -> str:
return f'<Colour value={self.value}>'
return f"<Colour value={self.value}>"
def __hash__(self) -> int:
return hash(self.value)
@@ -164,12 +164,12 @@ class Colour:
@classmethod
def teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
return cls(0x1abc9c)
return cls(0x1ABC9C)
@classmethod
def dark_teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806a)
return cls(0x11806A)
@classmethod
def brand_green(cls: Type[CT]) -> CT:
@@ -182,17 +182,17 @@ class Colour:
@classmethod
def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
return cls(0x2ecc71)
return cls(0x2ECC71)
@classmethod
def dark_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
return cls(0x1f8b4c)
return cls(0x1F8B4C)
@classmethod
def blue(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
return cls(0x3498db)
return cls(0x3498DB)
@classmethod
def dark_blue(cls: Type[CT]) -> CT:
@@ -202,42 +202,42 @@ class Colour:
@classmethod
def purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
return cls(0x9b59b6)
return cls(0x9B59B6)
@classmethod
def dark_purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
return cls(0x71368a)
return cls(0x71368A)
@classmethod
def magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
return cls(0xe91e63)
return cls(0xE91E63)
@classmethod
def dark_magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
return cls(0xad1457)
return cls(0xAD1457)
@classmethod
def gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
return cls(0xf1c40f)
return cls(0xF1C40F)
@classmethod
def dark_gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
return cls(0xc27c0e)
return cls(0xC27C0E)
@classmethod
def orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
return cls(0xe67e22)
return cls(0xE67E22)
@classmethod
def dark_orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xa84300)
return cls(0xA84300)
@classmethod
def brand_red(cls: Type[CT]) -> CT:
@@ -250,45 +250,52 @@ class Colour:
@classmethod
def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c)
return cls(0xE74C3C)
@classmethod
def nitro_booster(cls):
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
.. versionadded:: 2.0"""
return cls(0xF47FFF)
@classmethod
def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
return cls(0x992d22)
return cls(0x992D22)
@classmethod
def lighter_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95a5a6)
return cls(0x95A5A6)
lighter_gray = lighter_grey
@classmethod
def dark_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607d8b)
return cls(0x607D8B)
dark_gray = dark_grey
@classmethod
def light_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979c9f)
return cls(0x979C9F)
light_gray = light_grey
@classmethod
def darker_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546e7a)
return cls(0x546E7A)
darker_gray = darker_grey
@classmethod
def og_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
return cls(0x7289da)
return cls(0x7289DA)
@classmethod
def blurple(cls: Type[CT]) -> CT:
@@ -298,7 +305,7 @@ class Colour:
@classmethod
def greyple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
return cls(0x99aab5)
return cls(0x99AAB5)
@classmethod
def dark_theme(cls: Type[CT]) -> CT:

View File

@@ -41,14 +41,14 @@ if TYPE_CHECKING:
__all__ = (
'Component',
'ActionRow',
'Button',
'SelectMenu',
'SelectOption',
"Component",
"ActionRow",
"Button",
"SelectMenu",
"SelectOption",
)
C = TypeVar('C', bound='Component')
C = TypeVar("C", bound="Component")
class Component:
@@ -70,14 +70,14 @@ class Component:
The type of component.
"""
__slots__: Tuple[str, ...] = ('type',)
__slots__: Tuple[str, ...] = ("type",)
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__)
return f"<{self.__class__.__name__} {attrs}>"
@classmethod
def _raw_construct(cls: Type[C], **kwargs) -> C:
@@ -112,18 +112,18 @@ class ActionRow(Component):
The children components that this holds, if any.
"""
__slots__: Tuple[str, ...] = ('children',)
__slots__: Tuple[str, ...] = ("children",)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.children: List[Component] = [_component_factory(d) for d in data.get("components", [])]
def to_dict(self) -> ActionRowPayload:
return {
'type': int(self.type),
'components': [child.to_dict() for child in self.children],
"type": int(self.type),
"components": [child.to_dict() for child in self.children],
} # type: ignore
@@ -157,44 +157,44 @@ class Button(Component):
"""
__slots__: Tuple[str, ...] = (
'style',
'custom_id',
'url',
'disabled',
'label',
'emoji',
"style",
"custom_id",
"url",
"disabled",
"label",
"emoji",
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
self.disabled: bool = data.get('disabled', False)
self.label: Optional[str] = data.get('label')
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
self.custom_id: Optional[str] = data.get("custom_id")
self.url: Optional[str] = data.get("url")
self.disabled: bool = data.get("disabled", False)
self.label: Optional[str] = data.get("label")
self.emoji: Optional[PartialEmoji]
try:
self.emoji = PartialEmoji.from_dict(data['emoji'])
self.emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError:
self.emoji = None
def to_dict(self) -> ButtonComponentPayload:
payload = {
'type': 2,
'style': int(self.style),
'label': self.label,
'disabled': self.disabled,
"type": 2,
"style": int(self.style),
"label": self.label,
"disabled": self.disabled,
}
if self.custom_id:
payload['custom_id'] = self.custom_id
payload["custom_id"] = self.custom_id
if self.url:
payload['url'] = self.url
payload["url"] = self.url
if self.emoji:
payload['emoji'] = self.emoji.to_dict()
payload["emoji"] = self.emoji.to_dict()
return payload # type: ignore
@@ -231,37 +231,37 @@ class SelectMenu(Component):
"""
__slots__: Tuple[str, ...] = (
'custom_id',
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
"custom_id",
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
self.max_values: int = data.get('max_values', 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
self.custom_id: str = data["custom_id"]
self.placeholder: Optional[str] = data.get("placeholder")
self.min_values: int = data.get("min_values", 1)
self.max_values: int = data.get("max_values", 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])]
self.disabled: bool = data.get("disabled", False)
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
'type': self.type.value,
'custom_id': self.custom_id,
'min_values': self.min_values,
'max_values': self.max_values,
'options': [op.to_dict() for op in self.options],
'disabled': self.disabled,
"type": self.type.value,
"custom_id": self.custom_id,
"min_values": self.min_values,
"max_values": self.max_values,
"options": [op.to_dict() for op in self.options],
"disabled": self.disabled,
}
if self.placeholder:
payload['placeholder'] = self.placeholder
payload["placeholder"] = self.placeholder
return payload
@@ -292,11 +292,11 @@ class SelectOption:
"""
__slots__: Tuple[str, ...] = (
'label',
'value',
'description',
'emoji',
'default',
"label",
"value",
"description",
"emoji",
"default",
)
def __init__(
@@ -318,60 +318,60 @@ class SelectOption:
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
self.emoji = emoji
self.default = default
def __repr__(self) -> str:
return (
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} '
f'emoji={self.emoji!r} default={self.default!r}>'
f"<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} "
f"emoji={self.emoji!r} default={self.default!r}>"
)
def __str__(self) -> str:
if self.emoji:
base = f'{self.emoji} {self.label}'
base = f"{self.emoji} {self.label}"
else:
base = self.label
if self.description:
return f'{base}\n{self.description}'
return f"{base}\n{self.description}"
return base
@classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try:
emoji = PartialEmoji.from_dict(data['emoji'])
emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError:
emoji = None
return cls(
label=data['label'],
value=data['value'],
description=data.get('description'),
label=data["label"],
value=data["value"],
description=data.get("description"),
emoji=emoji,
default=data.get('default', False),
default=data.get("default", False),
)
def to_dict(self) -> SelectOptionPayload:
payload: SelectOptionPayload = {
'label': self.label,
'value': self.value,
'default': self.default,
"label": self.label,
"value": self.value,
"default": self.default,
}
if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore
payload["emoji"] = self.emoji.to_dict() # type: ignore
if self.description:
payload['description'] = self.description
payload["description"] = self.description
return payload
def _component_factory(data: ComponentPayload) -> Component:
component_type = data['type']
component_type = data["type"]
if component_type == 1:
return ActionRow(data)
elif component_type == 2:

View File

@@ -32,11 +32,10 @@ if TYPE_CHECKING:
from types import TracebackType
TypingT = TypeVar('TypingT', bound='Typing')
TypingT = TypeVar("TypingT", bound="Typing")
__all__ = ("Typing",)
__all__ = (
'Typing',
)
def _typing_done_callback(fut: asyncio.Future) -> None:
# just retrieve any exception and call it a day
@@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
except (asyncio.CancelledError, Exception):
pass
class Typing:
def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
@@ -67,7 +67,8 @@ class Typing:
self.task.add_done_callback(_typing_done_callback)
return self
def __exit__(self,
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
@@ -79,7 +80,8 @@ class Typing:
await channel._state.http.send_typing(channel.id)
return self.__enter__()
async def __aexit__(self,
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],

View File

@@ -30,9 +30,7 @@ from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Typ
from . import utils
from .colour import Colour
__all__ = (
'Embed',
)
__all__ = ("Embed",)
class _EmptyEmbed:
@@ -40,7 +38,7 @@ class _EmptyEmbed:
return False
def __repr__(self) -> str:
return 'Embed.Empty'
return "Embed.Empty"
def __len__(self) -> int:
return 0
@@ -57,51 +55,45 @@ class EmbedProxy:
return len(self.__dict__)
def __repr__(self) -> str:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
return f'EmbedProxy({inner})'
inner = ", ".join((f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")))
return f"EmbedProxy({inner})"
def __getattr__(self, attr: str) -> _EmptyEmbed:
return EmptyEmbed
E = TypeVar('E', bound='Embed')
E = TypeVar("E", bound="Embed")
if TYPE_CHECKING:
from discord.types.embed import Embed as EmbedData, EmbedType
T = TypeVar('T')
T = TypeVar("T")
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str]
value: MaybeEmpty[str]
inline: bool
class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
@@ -163,19 +155,19 @@ class Embed:
"""
__slots__ = (
'title',
'url',
'type',
'_timestamp',
'_colour',
'_footer',
'_image',
'_thumbnail',
'_video',
'_provider',
'_author',
'_fields',
'description',
"title",
"url",
"type",
"_timestamp",
"_colour",
"_footer",
"_image",
"_thumbnail",
"_video",
"_provider",
"_author",
"_fields",
"description",
)
Empty: Final = EmptyEmbed
@@ -186,7 +178,7 @@ class Embed:
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich',
type: EmbedType = "rich",
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
@@ -231,10 +223,10 @@ class Embed:
# fill in the basic fields
self.title = data.get('title', EmptyEmbed)
self.type = data.get('type', EmptyEmbed)
self.description = data.get('description', EmptyEmbed)
self.url = data.get('url', EmptyEmbed)
self.title = data.get("title", EmptyEmbed)
self.type = data.get("type", EmptyEmbed)
self.description = data.get("description", EmptyEmbed)
self.url = data.get("url", EmptyEmbed)
if self.title is not EmptyEmbed:
self.title = str(self.title)
@@ -248,22 +240,22 @@ class Embed:
# try to fill in the more rich fields
try:
self._colour = Colour(value=data['color'])
self._colour = Colour(value=data["color"])
except KeyError:
pass
try:
self._timestamp = utils.parse_time(data['timestamp'])
self._timestamp = utils.parse_time(data["timestamp"])
except KeyError:
pass
for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'):
for attr in ("thumbnail", "video", "provider", "author", "fields", "image", "footer"):
try:
value = data[attr]
except KeyError:
continue
else:
setattr(self, '_' + attr, value)
setattr(self, "_" + attr, value)
return self
@@ -273,11 +265,11 @@ class Embed:
def __len__(self) -> int:
total = len(self.title) + len(self.description)
for field in getattr(self, '_fields', []):
total += len(field['name']) + len(field['value'])
for field in getattr(self, "_fields", []):
total += len(field["name"]) + len(field["value"])
try:
footer_text = self._footer['text']
footer_text = self._footer["text"]
except (AttributeError, KeyError):
pass
else:
@@ -288,7 +280,7 @@ class Embed:
except AttributeError:
pass
else:
total += len(author['name'])
total += len(author["name"])
return total
@@ -312,7 +304,7 @@ class Embed:
@property
def colour(self) -> MaybeEmpty[Colour]:
return getattr(self, '_colour', EmptyEmbed)
return getattr(self, "_colour", EmptyEmbed)
@colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore
@@ -321,13 +313,15 @@ class Embed:
elif isinstance(value, int):
self._colour = Colour(value=value)
else:
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.')
raise TypeError(
f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead."
)
color = colour
@property
def timestamp(self) -> MaybeEmpty[datetime.datetime]:
return getattr(self, '_timestamp', EmptyEmbed)
return getattr(self, "_timestamp", EmptyEmbed)
@timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
@@ -348,7 +342,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
return EmbedProxy(getattr(self, "_footer", {})) # type: ignore
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
"""Sets the footer for the embed content.
@@ -366,10 +360,10 @@ class Embed:
self._footer = {}
if text is not EmptyEmbed:
self._footer['text'] = str(text)
self._footer["text"] = str(text)
if icon_url is not EmptyEmbed:
self._footer['icon_url'] = str(icon_url)
self._footer["icon_url"] = str(icon_url)
return self
@@ -401,16 +395,17 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
return EmbedProxy(getattr(self, "_image", {})) # type: ignore
@image.setter
def image(self: E, *, url: Any):
self._image = {
'url': str(url),
}
def image(self, url: Any):
if url is EmptyEmbed:
del self.image
else:
self._image = {"url": str(url)}
@image.deleter
def image(self: E):
def image(self):
try:
del self._image
except AttributeError:
@@ -431,11 +426,7 @@ class Embed:
The source URL for the image. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
del self.image
else:
self.image = url
return self
@property
@@ -451,27 +442,23 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore
@thumbnail.setter
def thumbnail(self: E, *, url: Any):
"""Sets the thumbnail for the embed content.
"""
self._thumbnail = {
'url': str(url),
}
return
def thumbnail(self, url: Any):
if url is EmptyEmbed:
del self.thumbnail
else:
self._thumbnail = {"url": str(url)}
@thumbnail.deleter
def thumbnail(self):
try:
del self.thumbnail
del self._thumbnail
except AttributeError:
pass
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]):
def set_thumbnail(self, *, url: MaybeEmpty[Any]):
"""Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style
@@ -485,11 +472,8 @@ class Embed:
url: :class:`str`
The source URL for the thumbnail. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
del self.thumbnail
else:
self.thumbnail = url
self.thumbnail = url
return self
@property
@@ -504,7 +488,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
return EmbedProxy(getattr(self, "_video", {})) # type: ignore
@property
def provider(self) -> _EmbedProviderProxy:
@@ -514,7 +498,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
return EmbedProxy(getattr(self, "_provider", {})) # type: ignore
@property
def author(self) -> _EmbedAuthorProxy:
@@ -524,9 +508,11 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, '_author', {})) # type: ignore
return EmbedProxy(getattr(self, "_author", {})) # type: ignore
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
def set_author(
self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed
) -> E:
"""Sets the author for the embed content.
This function returns the class instance to allow for fluent-style
@@ -543,14 +529,14 @@ class Embed:
"""
self._author = {
'name': str(name),
"name": str(name),
}
if url is not EmptyEmbed:
self._author['url'] = str(url)
self._author["url"] = str(url)
if icon_url is not EmptyEmbed:
self._author['icon_url'] = str(icon_url)
self._author["icon_url"] = str(icon_url)
return self
@@ -577,7 +563,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore
return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E:
"""Adds a field to the embed object.
@@ -596,9 +582,9 @@ class Embed:
"""
field = {
'inline': inline,
'name': str(name),
'value': str(value),
"inline": inline,
"name": str(name),
"value": str(value),
}
try:
@@ -629,9 +615,9 @@ class Embed:
"""
field = {
'inline': inline,
'name': str(name),
'value': str(value),
"inline": inline,
"name": str(name),
"value": str(value),
}
try:
@@ -697,11 +683,11 @@ class Embed:
try:
field = self._fields[index]
except (TypeError, IndexError, AttributeError):
raise IndexError('field index out of range')
raise IndexError("field index out of range")
field['name'] = str(name)
field['value'] = str(value)
field['inline'] = inline
field["name"] = str(name)
field["value"] = str(value)
field["inline"] = inline
return self
def to_dict(self) -> EmbedData:
@@ -719,35 +705,35 @@ class Embed:
# deal with basic convenience wrappers
try:
colour = result.pop('colour')
colour = result.pop("colour")
except KeyError:
pass
else:
if colour:
result['color'] = colour.value
result["color"] = colour.value
try:
timestamp = result.pop('timestamp')
timestamp = result.pop("timestamp")
except KeyError:
pass
else:
if timestamp:
if timestamp.tzinfo:
result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
else:
result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
# add in the non raw attribute ones
if self.type:
result['type'] = self.type
result["type"] = self.type
if self.description:
result['description'] = self.description
result["description"] = self.description
if self.url:
result['url'] = self.url
result["url"] = self.url
if self.title:
result['title'] = self.title
result["title"] = self.title
return result # type: ignore

View File

@@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User
__all__ = (
'Emoji',
)
__all__ = ("Emoji",)
if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload
@@ -98,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin):
"""
__slots__: Tuple[str, ...] = (
'require_colons',
'animated',
'managed',
'id',
'name',
'_roles',
'guild_id',
'_state',
'user',
'available',
"require_colons",
"animated",
"managed",
"id",
"name",
"_roles",
"guild_id",
"_state",
"user",
"available",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
@@ -116,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin):
self._from_data(data)
def _from_data(self, emoji: EmojiPayload):
self.require_colons: bool = emoji.get('require_colons', False)
self.managed: bool = emoji.get('managed', False)
self.id: int = int(emoji['id']) # type: ignore
self.name: str = emoji['name'] # type: ignore
self.animated: bool = emoji.get('animated', False)
self.available: bool = emoji.get('available', True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', [])))
user = emoji.get('user')
self.require_colons: bool = emoji.get("require_colons", False)
self.managed: bool = emoji.get("managed", False)
self.id: int = int(emoji["id"]) # type: ignore
self.name: str = emoji["name"] # type: ignore
self.animated: bool = emoji.get("animated", False)
self.available: bool = emoji.get("available", True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", [])))
user = emoji.get("user")
self.user: Optional[User] = User(state=self._state, data=user) if user else None
def _to_partial(self) -> PartialEmoji:
@@ -131,21 +129,21 @@ class Emoji(_EmojiTag, AssetMixin):
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for attr in self.__slots__:
if attr[0] != '_':
if attr[0] != "_":
value = getattr(self, attr, None)
if value is not None:
yield (attr, value)
def __str__(self) -> str:
if self.animated:
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
return f"<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>"
def __eq__(self, other: Any) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id
@@ -164,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin):
@property
def url(self) -> str:
""":class:`str`: Returns the URL of the emoji."""
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
@property
def roles(self) -> List[Role]:
@@ -219,7 +217,9 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
async def edit(
self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None
) -> Emoji:
r"""|coro|
Edits the custom emoji.
@@ -254,9 +254,9 @@ class Emoji(_EmojiTag, AssetMixin):
payload = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if roles is not MISSING:
payload['roles'] = [role.id for role in roles]
payload["roles"] = [role.id for role in roles]
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state)

View File

@@ -27,41 +27,41 @@ from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
__all__ = (
'Enum',
'ChannelType',
'MessageType',
'VoiceRegion',
'SpeakingState',
'VerificationLevel',
'ContentFilter',
'Status',
'DefaultAvatar',
'AuditLogAction',
'AuditLogActionCategory',
'UserFlags',
'ActivityType',
'NotificationLevel',
'TeamMembershipState',
'WebhookType',
'ExpireBehaviour',
'ExpireBehavior',
'StickerType',
'StickerFormatType',
'InviteTarget',
'VideoQualityMode',
'ComponentType',
'ButtonStyle',
'StagePrivacyLevel',
'InteractionType',
'InteractionResponseType',
'NSFWLevel',
"Enum",
"ChannelType",
"MessageType",
"VoiceRegion",
"SpeakingState",
"VerificationLevel",
"ContentFilter",
"Status",
"DefaultAvatar",
"AuditLogAction",
"AuditLogActionCategory",
"UserFlags",
"ActivityType",
"NotificationLevel",
"TeamMembershipState",
"WebhookType",
"ExpireBehaviour",
"ExpireBehavior",
"StickerType",
"StickerFormatType",
"InviteTarget",
"VideoQualityMode",
"ComponentType",
"ButtonStyle",
"StagePrivacyLevel",
"InteractionType",
"InteractionResponseType",
"NSFWLevel",
)
def _create_value_cls(name, comparable):
cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}'
cls = namedtuple("_EnumValue_" + name, "name value")
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
@@ -69,8 +69,9 @@ def _create_value_cls(name, comparable):
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls
def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
class EnumMeta(type):
@@ -88,7 +89,7 @@ class EnumMeta(type):
value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == '_' and not is_descriptor:
if key[0] == "_" and not is_descriptor:
continue
# Special case classmethod to just pass through
@@ -110,10 +111,10 @@ class EnumMeta(type):
member_mapping[key] = new_value
attrs[key] = new_value
attrs['_enum_value_map_'] = value_mapping
attrs['_enum_member_map_'] = member_mapping
attrs['_enum_member_names_'] = member_names
attrs['_enum_value_cls_'] = value_cls
attrs["_enum_value_map_"] = value_mapping
attrs["_enum_member_map_"] = member_mapping
attrs["_enum_member_names_"] = member_names
attrs["_enum_value_cls_"] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore
return actual_cls
@@ -128,7 +129,7 @@ class EnumMeta(type):
return len(cls._enum_member_names_)
def __repr__(cls):
return f'<enum {cls.__name__}>'
return f"<enum {cls.__name__}>"
@property
def __members__(cls):
@@ -144,10 +145,10 @@ class EnumMeta(type):
return cls._enum_member_map_[key]
def __setattr__(cls, name, value):
raise TypeError('Enums are immutable.')
raise TypeError("Enums are immutable.")
def __delattr__(cls, attr):
raise TypeError('Enums are immutable')
raise TypeError("Enums are immutable")
def __instancecheck__(self, instance):
# isinstance(x, Y)
@@ -215,29 +216,29 @@ class MessageType(Enum):
class VoiceRegion(Enum):
us_west = 'us-west'
us_east = 'us-east'
us_south = 'us-south'
us_central = 'us-central'
eu_west = 'eu-west'
eu_central = 'eu-central'
singapore = 'singapore'
london = 'london'
sydney = 'sydney'
amsterdam = 'amsterdam'
frankfurt = 'frankfurt'
brazil = 'brazil'
hongkong = 'hongkong'
russia = 'russia'
japan = 'japan'
southafrica = 'southafrica'
south_korea = 'south-korea'
india = 'india'
europe = 'europe'
dubai = 'dubai'
vip_us_east = 'vip-us-east'
vip_us_west = 'vip-us-west'
vip_amsterdam = 'vip-amsterdam'
us_west = "us-west"
us_east = "us-east"
us_south = "us-south"
us_central = "us-central"
eu_west = "eu-west"
eu_central = "eu-central"
singapore = "singapore"
london = "london"
sydney = "sydney"
amsterdam = "amsterdam"
frankfurt = "frankfurt"
brazil = "brazil"
hongkong = "hongkong"
russia = "russia"
japan = "japan"
southafrica = "southafrica"
south_korea = "south-korea"
india = "india"
europe = "europe"
dubai = "dubai"
vip_us_east = "vip-us-east"
vip_us_west = "vip-us-west"
vip_amsterdam = "vip-amsterdam"
def __str__(self):
return self.value
@@ -277,12 +278,12 @@ class ContentFilter(Enum, comparable=True):
class Status(Enum):
online = 'online'
offline = 'offline'
idle = 'idle'
dnd = 'dnd'
do_not_disturb = 'dnd'
invisible = 'invisible'
online = "online"
offline = "offline"
idle = "idle"
dnd = "dnd"
do_not_disturb = "dnd"
invisible = "invisible"
def __str__(self):
return self.value
@@ -415,33 +416,33 @@ class AuditLogAction(Enum):
def target_type(self) -> Optional[str]:
v = self.value
if v == -1:
return 'all'
return "all"
elif v < 10:
return 'guild'
return "guild"
elif v < 20:
return 'channel'
return "channel"
elif v < 30:
return 'user'
return "user"
elif v < 40:
return 'role'
return "role"
elif v < 50:
return 'invite'
return "invite"
elif v < 60:
return 'webhook'
return "webhook"
elif v < 70:
return 'emoji'
return "emoji"
elif v == 73:
return 'channel'
return "channel"
elif v < 80:
return 'message'
return "message"
elif v < 83:
return 'integration'
return "integration"
elif v < 90:
return 'stage_instance'
return "stage_instance"
elif v < 93:
return 'sticker'
return "sticker"
elif v < 113:
return 'thread'
return "thread"
class UserFlags(Enum):
@@ -589,12 +590,12 @@ class NSFWLevel(Enum, comparable=True):
age_restricted = 3
T = TypeVar('T')
T = TypeVar("T")
def create_unknown_value(cls: Type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore
name = f'unknown_{val}'
name = f"unknown_{val}"
return value_cls(name=name, value=val)

View File

@@ -38,20 +38,20 @@ if TYPE_CHECKING:
from .interactions import Interaction
__all__ = (
'DiscordException',
'ClientException',
'NoMoreItems',
'GatewayNotFound',
'HTTPException',
'Forbidden',
'NotFound',
'DiscordServerError',
'InvalidData',
'InvalidArgument',
'LoginFailure',
'ConnectionClosed',
'PrivilegedIntentsRequired',
'InteractionResponded',
"DiscordException",
"ClientException",
"NoMoreItems",
"GatewayNotFound",
"HTTPException",
"Forbidden",
"NotFound",
"DiscordServerError",
"InvalidData",
"InvalidArgument",
"LoginFailure",
"ConnectionClosed",
"PrivilegedIntentsRequired",
"InteractionResponded",
)
@@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException):
"""An exception that is raised when the gateway for Discord could not be found"""
def __init__(self):
message = 'The gateway to connect to discord was not found.'
message = "The gateway to connect to discord was not found."
super().__init__(message)
def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]:
def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]:
items: List[Tuple[str, str]] = []
for k, v in d.items():
new_key = key + '.' + k if key else k
new_key = key + "." + k if key else k
if isinstance(v, dict):
try:
_errors: List[Dict[str, Any]] = v['_errors']
_errors: List[Dict[str, Any]] = v["_errors"]
except KeyError:
items.extend(_flatten_error_dict(v, new_key).items())
else:
items.append((new_key, ' '.join(x.get('message', '') for x in _errors)))
items.append((new_key, " ".join(x.get("message", "") for x in _errors)))
else:
items.append((new_key, v))
@@ -129,22 +129,22 @@ class HTTPException(DiscordException):
self.code: int
self.text: str
if isinstance(message, dict):
self.code = message.get('code', 0)
base = message.get('message', '')
errors = message.get('errors')
self.code = message.get("code", 0)
base = message.get("message", "")
errors = message.get("errors")
if errors:
errors = _flatten_error_dict(errors)
helpful = '\n'.join('In %s: %s' % t for t in errors.items())
self.text = base + '\n' + helpful
helpful = "\n".join("In %s: %s" % t for t in errors.items())
self.text = base + "\n" + helpful
else:
self.text = base
else:
self.text = message or ''
self.text = message or ""
self.code = 0
fmt = '{0.status} {0.reason} (error code: {1})'
fmt = "{0.status} {0.reason} (error code: {1})"
if len(self.text):
fmt += ': {2}'
fmt += ": {2}"
super().__init__(fmt.format(self.response, self.code, self.text))
@@ -226,9 +226,9 @@ class ConnectionClosed(ClientException):
# reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1
# aiohttp doesn't seem to consistently provide close reason
self.reason: str = ''
self.reason: str = ""
self.shard_id: Optional[int] = shard_id
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}')
super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}")
class PrivilegedIntentsRequired(ClientException):
@@ -250,10 +250,10 @@ class PrivilegedIntentsRequired(ClientException):
def __init__(self, shard_id: Optional[int]):
self.shard_id: Optional[int] = shard_id
msg = (
'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the '
'developer portal. It is recommended to go to https://discord.com/developers/applications/ '
'and explicitly enable the privileged intents within your application\'s page. If this is not '
'possible, then consider disabling the privileged intents instead.'
"Shard ID %s is requesting privileged intents that have not been explicitly enabled in the "
"developer portal. It is recommended to go to https://discord.com/developers/applications/ "
"and explicitly enable the privileged intents within your application's page. If this is not "
"possible, then consider disabling the privileged intents instead."
)
super().__init__(msg % shard_id)
@@ -274,4 +274,4 @@ class InteractionResponded(ClientException):
def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction
super().__init__('This interaction has already been responded to before')
super().__init__("This interaction has already been responded to before")

View File

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
T = TypeVar("T")
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
@@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
]
# This is merely a tag type to avoid circular import issues.

View File

@@ -28,18 +28,44 @@ from __future__ import annotations
import asyncio
import collections
import collections.abc
from functools import cached_property
import inspect
import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
from collections import defaultdict
from discord.http import HTTPClient
from typing import (
Any,
Callable,
Iterable,
Tuple,
cast,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord
from discord.types.interactions import (
ApplicationCommandInteractionData,
ApplicationCommandInteractionDataOption,
EditApplicationCommand,
_ApplicationCommandInteractionDataOptionString,
)
from .core import GroupMixin
from .view import StringView
from .converter import Greedy
from .view import StringView, supported_quotes
from .context import Context
from .flags import FlagConverter
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
@@ -47,24 +73,67 @@ from .cog import Cog
if TYPE_CHECKING:
import importlib.machinery
from discord.role import Role
from discord.message import Message
from discord.abc import PartialMessageableChannel
from ._types import (
Check,
CoroFunc,
)
__all__ = (
'when_mentioned',
'when_mentioned_or',
'Bot',
'AutoShardedBot',
"when_mentioned",
"when_mentioned_or",
"Bot",
"AutoShardedBot",
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
class _FakeSlashMessage(discord.PartialMessage):
activity = application = edited_at = reference = webhook_id = None
attachments = components = reactions = stickers = []
tts = False
raw_mentions = discord.Message.raw_mentions
clean_content = discord.Message.clean_content
channel_mentions = discord.Message.channel_mentions
raw_role_mentions = discord.Message.raw_role_mentions
raw_channel_mentions = discord.Message.raw_channel_mentions
author: Union[discord.User, discord.Member]
@classmethod
def from_interaction(
cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread]
):
self = cls(channel=channel, id=interaction.id)
assert interaction.user is not None
self.author = interaction.user
return self
@cached_property
def mentions(self) -> List[Union[discord.Member, discord.User]]:
client = self._state._get_client()
if self.guild:
ensure_user = lambda id: self.guild.get_member(id) or client.get_user(id) # type: ignore
else:
ensure_user = client.get_user
return discord.utils._unique(filter(None, map(ensure_user, self.raw_mentions)))
@cached_property
def role_mentions(self) -> List[Role]:
if self.guild is None:
return []
return discord.utils._unique(filter(None, map(self.guild.get_role, self.raw_role_mentions)))
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
@@ -72,7 +141,8 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
@@ -103,6 +173,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
----------
:func:`.when_mentioned`
"""
def inner(bot, msg):
r = list(prefixes)
r = when_mentioned(bot, msg) + r
@@ -110,19 +181,65 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
def _unwrap_slash_groups(
data: ApplicationCommandInteractionData,
) -> Tuple[str, List[ApplicationCommandInteractionDataOption]]:
command_name = data["name"]
command_options = data.get("options") or []
while any(o["type"] in {1, 2} for o in command_options): # type: ignore
for option in command_options: # type: ignore
if option["type"] in {1, 2}: # type: ignore
command_name += f' {option["name"]}' # type: ignore
command_options = option.get("options") or []
return command_name, command_options
def _quote_string_safe(string: str) -> str:
# we need to quote this string otherwise we may spill into
# other parameters and cause all kinds of trouble, as many
# quotes are supported and some may be in the option, we
# loop through all supported quotes and if neither open or
# close are in the string, we add them
for open, close in supported_quotes.items():
if open not in string and close not in string:
return f"{open}{string}{close}"
# all supported quotes are in the message and we cannot add any
# safely, very unlikely but still got to be covered
raise errors.UnexpectedQuoteError(string)
class _DefaultRepr:
def __repr__(self):
return '<default-help-command>'
return "<default-help-command>"
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options):
def __init__(
self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
message_commands: bool = True,
slash_commands: bool = False,
**options,
):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix
self.slash_commands = slash_commands
self.message_commands = message_commands
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
@@ -131,16 +248,20 @@ class BotBase(GroupMixin):
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.')
raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
if not (message_commands or slash_commands):
raise ValueError("Both message_commands and slash_commands are disabled.")
if help_command is _default:
self.help_command = DefaultHelpCommand()
@@ -152,10 +273,59 @@ class BotBase(GroupMixin):
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
ev = "on_" + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def setup(self):
await self.create_slash_commands()
async def create_slash_commands(self):
commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list)
for command in self.commands:
if command.hidden or (command.slash_command is None and not self.slash_commands):
continue
try:
payload = command.to_application_command()
except Exception:
raise errors.ApplicationCommandRegistrationError(command)
if payload is None:
continue
guilds = command.slash_command_guilds or self.slash_command_guilds
if guilds is None:
commands[None].append(payload)
else:
for guild in guilds:
commands[guild].append(payload)
http: HTTPClient = self.http # type: ignore
global_commands = commands.pop(None, None)
application_id = self.application_id or (await self.application_info()).id # type: ignore
if global_commands is not None:
if self.slash_command_guilds is None:
await http.bulk_upsert_global_commands(
payload=global_commands,
application_id=application_id,
)
else:
for guild in self.slash_command_guilds:
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=global_commands,
application_id=application_id,
)
for guild, guild_commands in commands.items():
assert guild is not None
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=guild_commands,
application_id=application_id,
)
@discord.utils.copy_doc(discord.Client.close)
async def close(self) -> None:
for extension in tuple(self.__extensions):
@@ -182,7 +352,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get('on_command_error', None):
if self.extra_events.get("on_command_error", None):
return
command = context.command
@@ -193,7 +363,7 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler():
return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
# global check registration
@@ -344,14 +514,59 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
await self.populate_owners()
return await self.is_owner(user)
async def try_owners(self) -> List[discord.User]:
"""|coro|
Returns a list of :class:`~discord.User` representing the owners of the bot.
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
.. versionadded:: 2.0
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
Returns
--------
List[:class:`~discord.User`]
List of owners of the bot.
"""
if self.owner_id:
owner = await self.try_user(self.owner_id)
if owner:
return [owner]
else:
return []
elif self.owner_ids:
owners = []
for owner_id in self.owner_ids:
owner = await self.try_user(owner_id)
if owner:
owners.append(owner)
return owners
else:
# We didn't have owners cached yet, cache them and retry.
await self.populate_owners()
return await self.try_owners()
async def populate_owners(self):
"""|coro|
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
.. versionadded:: 2.0
"""
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
self.owner_ids = {m.id for m in app.team.members}
else:
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
self.owner_id = app.owner.id
def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
@@ -380,7 +595,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro
return coro
@@ -413,7 +628,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro
return coro
@@ -445,7 +660,7 @@ class BotBase(GroupMixin):
name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
raise TypeError("Listeners must be coroutines")
if name in self.extra_events:
self.extra_events[name].append(func)
@@ -541,14 +756,14 @@ class BotBase(GroupMixin):
"""
if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog')
raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name)
cog = cog._inject(self)
@@ -636,7 +851,7 @@ class BotBase(GroupMixin):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
func = getattr(lib, "teardown")
except AttributeError:
pass
else:
@@ -663,7 +878,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, 'setup')
setup = getattr(lib, "setup")
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
@@ -813,11 +1028,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules
modules = {
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)}
try:
# Unload and then load the module...
@@ -850,7 +1061,7 @@ class BotBase(GroupMixin):
def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
raise TypeError("help_command must be a subclass of HelpCommand")
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
@@ -880,6 +1091,9 @@ class BotBase(GroupMixin):
A list of prefixes or a single prefix that the bot is
listening for.
"""
if isinstance(message, _FakeSlashMessage):
return "/"
prefix = ret = self.command_prefix
if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message)
@@ -893,8 +1107,10 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}")
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
@@ -954,14 +1170,18 @@ class BotBase(GroupMixin):
except TypeError:
if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}")
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}")
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen
raise
@@ -988,19 +1208,19 @@ class BotBase(GroupMixin):
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch('command', ctx)
self.dispatch("command", ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure('The global check once functions failed.')
raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch('command_completion', ctx)
self.dispatch("command_completion", ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None:
"""|coro|
@@ -1030,9 +1250,95 @@ class BotBase(GroupMixin):
ctx = await self.get_context(message)
await self.invoke(ctx)
async def process_slash_commands(self, interaction: discord.Interaction):
"""|coro|
This function processes a slash command interaction into a usable
message and calls :meth:`.process_commands` based on it. Without this
coroutine slash commands will not be triggered.
By default, this coroutine is called inside the :func:`.on_interaction`
event. If you choose to override the :func:`.on_interaction` event,
then you should invoke this coroutine as well.
.. versionadded:: 2.0
Parameters
-----------
interaction: :class:`discord.Interaction`
The interaction to process slash commands for.
"""
if interaction.type != discord.InteractionType.application_command:
return
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
command_name, command_options = _unwrap_slash_groups(interaction.data)
command = self.get_command(command_name)
if command is None:
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
# Ensure the interaction channel is usable
channel = interaction.channel
if channel is None or isinstance(channel, discord.PartialMessageable):
if interaction.guild is None:
assert interaction.user is not None
channel = await interaction.user.create_dm()
elif interaction.channel_id is not None:
channel = await interaction.guild.fetch_channel(interaction.channel_id)
else:
return # cannot do anything without stable channel
# Make our fake message so we can pass it to ext.commands
message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore
message.content = f"/{command_name} "
# Add arguments to fake message content, in the right order
ignore_params: List[inspect.Parameter] = []
for name, param in command.clean_params.items():
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
for name, flag in param.annotation.get_flags().items():
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if flag.required:
raise errors.MissingRequiredFlag(flag)
else:
prefix = param.annotation.__commands_flag_prefix__
delimiter = param.annotation.__commands_flag_delimiter__
message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore
continue
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
else:
ignore_params.append(param)
elif (
option["type"] == 3
and not isinstance(param.annotation, Greedy)
and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}
):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
message.content += f"{_quote_string_safe(option['value'])} "
else:
message.content += f'{option.get("value", "")} '
ctx = await self.get_context(message)
ctx._ignored_params = ignore_params
ctx.interaction = interaction
await self.invoke(ctx)
async def on_message(self, message):
await self.process_commands(message)
async def on_interaction(self, interaction: discord.Interaction):
await self.process_slash_commands(interaction)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@@ -1075,7 +1381,7 @@ class Bot(BotBase, discord.Client):
when passing an empty string, it should always be last as no prefix
after it will be matched.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This
Whether the commands should be case insensitive. Defaults to ``True``. This
attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well.
description: :class:`str`
@@ -1102,11 +1408,36 @@ class Bot(BotBase, discord.Client):
the ``command_prefix`` is set to ``!``. Defaults to ``False``.
.. versionadded:: 1.7
message_commands: Optional[:class:`bool`]
Whether to process commands based on messages.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``message_command`` parameter
.. versionadded:: 2.0
slash_commands: Optional[:class:`bool`]
Whether to upload and process slash commands.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command`` parameter
.. versionadded:: 2.0
slash_command_guilds: Optional[:class:`List[int]`]
If this is set, only upload slash commands to these guild IDs.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command_guilds`` parameter
.. versionadded:: 2.0
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.
"""
pass

View File

@@ -36,15 +36,16 @@ if TYPE_CHECKING:
from .core import Command
__all__ = (
'CogMeta',
'Cog',
"CogMeta",
"Cog",
)
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@@ -104,6 +105,7 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
@@ -111,17 +113,17 @@ class CogMeta(type):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
description = kwargs.pop('description', None)
description = kwargs.pop("description", None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
@@ -136,17 +138,17 @@ class CogMeta(type):
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
if elem.startswith(('cog_', 'bot_')):
raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.")
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
getattr(value, "__cog_listener__")
except AttributeError:
continue
else:
if elem.startswith(('cog_', 'bot_')):
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
@@ -169,10 +171,12 @@ class CogMeta(type):
def qualified_name(cls) -> str:
return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
@@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
@@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
@@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
@@ -274,7 +277,7 @@ class Cog(metaclass=CogMeta):
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
return getattr(method.__func__, "__cog_special_method__", method)
@classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
@@ -296,14 +299,14 @@ class Cog(metaclass=CogMeta):
"""
if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.")
def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
@@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
@@ -322,7 +326,7 @@ class Cog(metaclass=CogMeta):
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
@_cog_special_method
def cog_unload(self) -> None:

View File

@@ -25,13 +25,14 @@ from __future__ import annotations
import inspect
import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from datetime import timedelta
from typing import Any, Dict, Generic, List, Literal, Optional, TYPE_CHECKING, TypeVar, Union, overload
import discord.abc
import discord.utils
from discord.message import Message
from discord import Permissions
if TYPE_CHECKING:
from typing_extensions import ParamSpec
@@ -41,6 +42,8 @@ if TYPE_CHECKING:
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.webhook import WebhookMessage
from discord.interactions import Interaction
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
@@ -49,21 +52,19 @@ if TYPE_CHECKING:
from .help import HelpCommand
from .view import StringView
__all__ = (
'Context',
)
__all__ = ("Context",)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
P = ParamSpec("P")
else:
P = TypeVar('P')
P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]):
@@ -121,8 +122,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
A boolean that indicates if the command failed to be parsed, checked,
or invoked.
"""
interaction: Optional[Interaction] = None
def __init__(self,
def __init__(
self,
*,
message: Message,
bot: BotT,
@@ -151,6 +154,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._ignored_params: List[inspect.Parameter] = []
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
@@ -219,7 +223,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command
view = self.view
if cmd is None:
raise ValueError('This context is not valid.')
raise ValueError("This context is not valid.")
# some state to revert to when we're done
index, previous = view.index, view.previous
@@ -230,7 +234,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '')
view.index = len(self.prefix or "")
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
@@ -263,7 +267,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0
"""
if self.prefix is None:
return ''
return ""
user = self.me
# this breaks if the prefix mention is not the bot itself but I
@@ -271,7 +275,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
@property
def cog(self) -> Optional[Cog]:
@@ -314,6 +318,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
g = self.guild
return g.voice_client if g else None
def author_permissions(self) -> Permissions:
"""Returns the author permissions in the given channel.
.. versionadded:: 2.0
"""
return self.channel.permissions_for(self.author)
async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>)
@@ -381,7 +392,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, '__cog_commands__'):
if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):
@@ -395,6 +406,97 @@ class Context(discord.abc.Messageable, Generic[BotT]):
except CommandError as e:
await cmd.on_help_command_error(self, e)
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[False] = False,
ephemeral: bool = False,
**kwargs: Any,
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[True] = True,
ephemeral: bool = False,
**kwargs: Any,
) -> Union[Message, WebhookMessage]:
...
async def send(
self, content: Optional[str] = None, return_message: bool = True, ephemeral: bool = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
"""
|coro|
A shortcut method to :meth:`.abc.Messageable.send` with interaction helpers.
This function takes all the parameters of :meth:`.abc.Messageable.send` plus the following:
Parameters
------------
return_message: :class:`bool`
Ignored if not in a slash command context.
If this is set to False more native interaction methods will be used.
ephemeral: :class:`bool`
Ignored if not in a slash command context.
Indicates if the message should only be visible to the user who started the interaction.
If a view is sent with an ephemeral message and it has no timeout set then the timeout
is set to 15 minutes.
Returns
--------
Optional[Union[:class:`.Message`, :class:`.WebhookMessage`]]
In a slash command context, the message that was sent if return_message is True.
In a normal context, it always returns a :class:`.Message`
"""
if self.interaction is None or (
self.interaction.response.responded_at is not None
and discord.utils.utcnow() - self.interaction.response.responded_at >= timedelta(minutes=15)
):
return await super().send(content, **kwargs)
# Remove unsupported arguments from kwargs
kwargs.pop("nonce", None)
kwargs.pop("stickers", None)
kwargs.pop("reference", None)
kwargs.pop("delete_after", None)
kwargs.pop("mention_author", None)
if not (
return_message
or self.interaction.response.is_done()
or any(arg in kwargs for arg in ("file", "files", "allowed_mentions"))
):
send = self.interaction.response.send_message
else:
# We have to defer in order to use the followup webhook
if not self.interaction.response.is_done():
await self.interaction.response.defer(ephemeral=ephemeral)
send = self.interaction.followup.send
return await send(content, ephemeral=ephemeral, **kwargs) # type: ignore
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[False] = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[True] = True, **kwargs: Any
) -> Union[Message, WebhookMessage]:
...
@discord.utils.copy_doc(Message.reply)
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
return await self.message.reply(content, **kwargs)
async def reply(
self, content: Optional[str] = None, return_message: bool = True, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
return await self.send(content, return_message=return_message, reference=self.message, **kwargs) # type: ignore

View File

@@ -52,32 +52,33 @@ if TYPE_CHECKING:
__all__ = (
'Converter',
'ObjectConverter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'clean_content',
'Greedy',
'run_converters',
"Converter",
"ObjectConverter",
"MemberConverter",
"UserConverter",
"MessageConverter",
"PartialMessageConverter",
"TextChannelConverter",
"InviteConverter",
"GuildConverter",
"RoleConverter",
"GameConverter",
"ColourConverter",
"ColorConverter",
"VoiceChannelConverter",
"StageChannelConverter",
"EmojiConverter",
"PartialEmojiConverter",
"CategoryChannelConverter",
"IDConverter",
"StoreChannelConverter",
"ThreadConverter",
"GuildChannelConverter",
"GuildStickerConverter",
"clean_content",
"Greedy",
"Option",
"run_converters",
)
@@ -91,10 +92,12 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar("TT", bound=discord.Thread)
DT = TypeVar("DT", bound=str)
@runtime_checkable
@@ -132,10 +135,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError('Derived classes need to implement this.')
raise NotImplementedError("Derived classes need to implement this.")
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
class IDConverter(Converter[T_co]):
@@ -158,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
if match is None:
raise ObjectNotFound(argument)
@@ -192,8 +195,8 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
else:
@@ -223,7 +226,7 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
guild = ctx.guild
result = None
user_id = None
@@ -232,13 +235,13 @@ class MemberConverter(IDConverter[discord.Member]):
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, 'get_member_named', argument)
result = _get_from_guilds(bot, "get_member_named", argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
else:
result = _get_from_guilds(bot, 'get_member', user_id)
result = _get_from_guilds(bot, "get_member", user_id)
if result is None:
if guild is None:
@@ -276,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
result = None
state = ctx._state
@@ -294,12 +297,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
if arg[0] == "@":
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
if len(arg) > 5 and arg[-5] == "#":
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
@@ -330,22 +333,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?P<guild_id>[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data["message_id"])
guild_id = data.get("guild_id")
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me':
elif guild_id == "@me":
guild_id = None
else:
guild_id = int(guild_id)
@@ -417,13 +420,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None
guild = ctx.guild
@@ -443,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
result = _get_from_guilds(bot, "get_channel", channel_id)
if not isinstance(result, type):
raise ChannelNotFound(argument)
@@ -454,7 +457,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None
guild = ctx.guild
@@ -491,7 +494,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@@ -511,7 +514,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@@ -530,7 +533,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@@ -550,7 +553,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@@ -569,7 +572,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
class ThreadConverter(IDConverter[discord.Thread]):
@@ -583,11 +586,11 @@ class ThreadConverter(IDConverter[discord.Thread]):
2. Lookup by mention.
3. Lookup by name.
.. versionadded: 2.0
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
class ColourConverter(Converter[discord.Colour]):
@@ -616,10 +619,10 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
@@ -630,7 +633,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == '%':
if number[-1] == "%":
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
@@ -646,29 +649,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group("b"))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == '#':
if argument[0] == "#":
return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x':
if argument[0:2] == "0x":
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith('#'):
if rest.startswith("#"):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == 'rgb':
if arg[0:3] == "rgb":
return self.parse_rgb(arg)
arg = arg.replace(' ', '_')
arg = arg.replace(" ", "_")
method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
@@ -697,7 +700,7 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
if match:
result = guild.get_role(int(match.group(1)))
else:
@@ -776,7 +779,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
result = None
bot = ctx.bot
guild = ctx.guild
@@ -810,7 +813,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
if match:
emoji_animated = bool(match.group(1))
@@ -903,37 +906,37 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user'
return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role'
return f"@{r.name}" if r else "@deleted-role"
else:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user'
return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
return '@deleted-role'
return "@deleted-role"
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel'
return f"#{c.name}" if c else "#deleted-channel"
else:
def resolve_channel(id: int) -> str:
return f'<#{id}>'
return f"<#{id}>"
transforms = {
'@': resolve_member,
'@!': resolve_member,
'#': resolve_channel,
'@&': resolve_role,
"@": resolve_member,
"@!": resolve_member,
"#": resolve_channel,
"@&": resolve_role,
}
def repl(match: re.Match) -> str:
@@ -942,7 +945,7 @@ class clean_content(Converter[str]):
transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
@@ -974,42 +977,86 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ('converter',)
__slots__ = ("converter",)
def __init__(self, *, converter: T):
self.converter = converter
def __repr__(self):
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
converter = getattr(self.converter, "__name__", repr(self.converter))
return f"Greedy[{converter}]"
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
raise TypeError("Greedy[...] only takes a single argument")
converter = params[0]
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
raise TypeError(f"Greedy[{converter!r}] is invalid.")
return cls(converter=converter)
class Option(Generic[T, DT]): # type: ignore
"""A special 'converter' to apply a description to slash command options.
For example in the following code:
.. code-block:: python3
@bot.command()
async def ban(ctx,
member: discord.Member, *,
reason: str = commands.Option('no reason', description='the reason to ban this member')
):
await member.ban(reason=reason)
The description would be ``the reason to ban this member`` and the default would be ``no reason``
.. versionadded:: 2.0
Attributes
------------
default: Optional[Any]
The default for this option, overwrites Option during parsing.
description: :class:`str`
The description for this option, is unpacked to :attr:`.Command.option_descriptions`
"""
description: DT
default: Union[T, inspect._empty]
__slots__ = (
"default",
"description",
)
def __init__(self, default: T = inspect.Parameter.empty, *, description: DT) -> None:
self.description = description
self.default = default
if TYPE_CHECKING:
# Terrible workaround for type checking reasons
def Option(default: T = inspect.Parameter.empty, *, description: str) -> T:
...
def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
return False
else:
raise BadBoolArgument(lowered)
@@ -1065,7 +1112,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@@ -1124,7 +1171,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any
The resulting conversion.
"""
origin = getattr(converter, '__origin__', None)
origin = getattr(converter, "__origin__", None)
if origin is Union:
errors = []

View File

@@ -38,15 +38,16 @@ if TYPE_CHECKING:
from ...message import Message
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
)
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum):
default = 0
@@ -90,7 +91,7 @@ class Cooldown:
The length of the cooldown period in seconds.
"""
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
@@ -190,7 +191,8 @@ class Cooldown:
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping:
def __init__(
@@ -199,7 +201,7 @@ class CooldownMapping:
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
@@ -256,13 +258,9 @@ class CooldownMapping:
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
@@ -278,6 +276,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.
@@ -291,7 +290,7 @@ class _Semaphore:
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
__slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None:
self.value: int = number
@@ -299,7 +298,7 @@ class _Semaphore:
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool:
return self.value == 0
@@ -337,8 +336,9 @@ class _Semaphore:
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
__slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
@@ -347,16 +347,16 @@ class MaxConcurrency:
self.wait: bool = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
def get_key(self, message: Message) -> Any:
return self.per.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@@ -33,6 +33,7 @@ if TYPE_CHECKING:
from .converter import Converter
from .context import Context
from .core import Command
from .cooldowns import Cooldown, BucketType
from .flags import Flag
from discord.abc import GuildChannel
@@ -41,65 +42,67 @@ if TYPE_CHECKING:
__all__ = (
'CommandError',
'MissingRequiredArgument',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
'CheckFailure',
'CheckAnyFailure',
'CommandNotFound',
'DisabledCommand',
'CommandInvokeError',
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner',
'MessageNotFound',
'ObjectNotFound',
'MemberNotFound',
'GuildNotFound',
'UserNotFound',
'ChannelNotFound',
'ThreadNotFound',
'ChannelNotReadable',
'BadColourArgument',
'BadColorArgument',
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
'BotMissingRole',
'MissingAnyRole',
'BotMissingAnyRole',
'MissingPermissions',
'BotMissingPermissions',
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
'ExpectedClosingQuoteError',
'ExtensionError',
'ExtensionAlreadyLoaded',
'ExtensionNotLoaded',
'NoEntryPointError',
'ExtensionFailed',
'ExtensionNotFound',
'CommandRegistrationError',
'FlagError',
'BadFlagArgument',
'MissingFlagArgument',
'TooManyFlags',
'MissingRequiredFlag',
"CommandError",
"MissingRequiredArgument",
"BadArgument",
"PrivateMessageOnly",
"NoPrivateMessage",
"CheckFailure",
"CheckAnyFailure",
"CommandNotFound",
"DisabledCommand",
"CommandInvokeError",
"TooManyArguments",
"UserInputError",
"CommandOnCooldown",
"MaxConcurrencyReached",
"NotOwner",
"MessageNotFound",
"ObjectNotFound",
"MemberNotFound",
"GuildNotFound",
"UserNotFound",
"ChannelNotFound",
"ThreadNotFound",
"ChannelNotReadable",
"BadColourArgument",
"BadColorArgument",
"RoleNotFound",
"BadInviteArgument",
"EmojiNotFound",
"GuildStickerNotFound",
"PartialEmojiConversionFailure",
"BadBoolArgument",
"MissingRole",
"BotMissingRole",
"MissingAnyRole",
"BotMissingAnyRole",
"MissingPermissions",
"BotMissingPermissions",
"NSFWChannelRequired",
"ConversionError",
"BadUnionArgument",
"BadLiteralArgument",
"ArgumentParsingError",
"UnexpectedQuoteError",
"InvalidEndOfQuotedStringError",
"ExpectedClosingQuoteError",
"ExtensionError",
"ExtensionAlreadyLoaded",
"ExtensionNotLoaded",
"NoEntryPointError",
"ExtensionFailed",
"ExtensionNotFound",
"CommandRegistrationError",
"ApplicationCommandRegistrationError",
"FlagError",
"BadFlagArgument",
"MissingFlagArgument",
"TooManyFlags",
"MissingRequiredFlag",
)
class CommandError(DiscordException):
r"""The base exception type for all command related errors.
@@ -109,14 +112,16 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`.
"""
def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args)
else:
super().__init__(*args)
class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError.
@@ -130,18 +135,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
self.original: Exception = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
regarding user input.
This inherits from :exc:`CommandError`.
"""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked
but no command under that name is found.
@@ -151,8 +160,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`.
"""
pass
class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter
that is required is not encountered.
@@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param: Parameter) -> None:
self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.')
super().__init__(f"{param.name} is a required argument that is missing.")
class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
@@ -174,23 +187,29 @@ class TooManyArguments(UserInputError):
This inherits from :exc:`UserInputError`
"""
pass
class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command.
This inherits from :exc:`UserInputError`
"""
pass
class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError`
"""
pass
class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
@@ -209,7 +228,8 @@ class CheckAnyFailure(CheckFailure):
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.')
super().__init__("You do not have permission to run this command.")
class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private
@@ -217,8 +237,10 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.')
super().__init__(message or "This command can only be used in private messages.")
class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message
@@ -228,15 +250,18 @@ class NoPrivateMessage(CheckFailure):
"""
def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.')
super().__init__(message or "This command cannot be used in private messages.")
class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure`
"""
pass
class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format
of an ID or a mention.
@@ -250,9 +275,11 @@ class ObjectNotFound(BadArgument):
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's
@@ -267,10 +294,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache.
@@ -283,10 +312,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's
cache.
@@ -300,10 +331,12 @@ class UserNotFound(BadArgument):
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel.
@@ -316,10 +349,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages
in the channel.
@@ -333,10 +368,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable
"""
def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel.
@@ -349,10 +386,12 @@ class ChannelNotFound(BadArgument):
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread.
@@ -365,10 +404,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid.
@@ -381,12 +422,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role.
@@ -399,21 +443,30 @@ class RoleNotFound(BadArgument):
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The invite supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji.
@@ -426,10 +479,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct
format.
@@ -443,10 +498,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
@@ -459,10 +516,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@@ -475,17 +534,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option')
super().__init__(f"{argument} is not a recognised boolean option")
class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError`
"""
pass
class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception.
@@ -497,9 +560,11 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e: Exception) -> None:
self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}')
super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown.
@@ -516,11 +581,13 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
@@ -539,10 +606,11 @@ class MaxConcurrencyReached(CommandError):
self.number: int = number
self.per: BucketType = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
suffix = "per %s" % name if per.name != "default" else "globally"
plural = "%s times %s" if number > 1 else "%s time %s"
fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.')
super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.")
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.
@@ -557,11 +625,13 @@ class MissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.'
message = f"Role {missing_role!r} is required to run this command."
super().__init__(message)
class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command.
@@ -575,11 +645,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command'
message = f"Bot requires the role {missing_role!r} to run this command"
super().__init__(message)
class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of
the roles specified to run a command.
@@ -594,15 +666,16 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"You are missing at least one of the required roles: {fmt}"
super().__init__(message)
@@ -623,19 +696,21 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
fmt = " or ".join(missing)
message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message)
class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting.
@@ -648,10 +723,12 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled.
"""
def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a
command.
@@ -663,18 +740,20 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'You are missing {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"You are missing {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a
command.
@@ -686,18 +765,20 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`]
The required permissions that are missing.
"""
def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions]
missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = f'Bot requires {fmt} permission(s) to run this command.'
fmt = " and ".join(missing)
message = f"Bot requires {fmt} permission(s) to run this command."
super().__init__(message, *args)
class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all
its associated types.
@@ -713,6 +794,7 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
@@ -722,18 +804,19 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
if hasattr(x, '__origin__'):
if hasattr(x, "__origin__"):
return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
@@ -751,6 +834,7 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
@@ -758,12 +842,13 @@ class BadLiteralArgument(UserInputError):
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.
@@ -772,8 +857,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for
i18n purposes.
"""
pass
class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
@@ -784,9 +871,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote: str) -> None:
self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string
@@ -799,9 +888,11 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char: str) -> None:
self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}')
super().__init__(f"Expected space after closing quotation but received {char!r}")
class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found.
@@ -816,7 +907,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.')
super().__init__(f"Expected closing {close_quote}.")
class ExtensionError(DiscordException):
"""Base exception for extension related errors.
@@ -828,37 +920,45 @@ class ExtensionError(DiscordException):
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name
message = message or f'Extension {name!r} had an error.'
message = message or f"Extension {name!r} had an error."
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name)
super().__init__(f"Extension {name!r} is already loaded.", name=name)
class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name)
super().__init__(f"Extension {name!r} has not been loaded.", name=name)
class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@@ -872,11 +972,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}'
msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found.
@@ -890,10 +992,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str`
The extension that had the error.
"""
def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.'
msg = f"Extension {name!r} could not be loaded."
super().__init__(msg, name=name)
class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added
because the name is already taken by a different command.
@@ -909,11 +1013,32 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name
self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.')
type_ = "alias" if alias_conflict else "command"
super().__init__(f"The {type_} {name} is already an existing command or alias.")
class ApplicationCommandRegistrationError(ClientException):
"""An exception raised when a command cannot be converted to an
application command.
This inherits from :exc:`discord.ClientException`
.. versionadded:: 2.0
Attributes
----------
command: :class:`Command`
The command that failed to be converted.
"""
def __init__(self, command: Command, msg: str = None) -> None:
self.command = command
super().__init__(msg or f"{command.qualified_name} failed to converted to an application command.")
class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors.
@@ -922,8 +1047,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0
"""
pass
class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values.
@@ -938,10 +1065,12 @@ class TooManyFlags(FlagError):
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.")
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
@@ -955,6 +1084,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
@@ -962,7 +1092,8 @@ class BadFlagArgument(FlagError):
except AttributeError:
name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given.
@@ -976,9 +1107,11 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
super().__init__(f"Flag {flag.name!r} is required and missing")
class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value.
@@ -992,6 +1125,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')
super().__init__(f"Flag {flag.name!r} does not have an argument")

View File

@@ -59,9 +59,9 @@ import sys
import re
__all__ = (
'Flag',
'flag',
'FlagConverter',
"Flag",
"flag",
"FlagConverter",
)
@@ -148,20 +148,20 @@ def flag(
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError('flag names should not be empty')
raise ValueError("flag names should not be empty")
for ch in name:
if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
@@ -178,7 +178,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default
if flag.aliases is MISSING:
@@ -229,7 +233,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag")
if flag.override is MISSING:
flag.override = False
@@ -237,7 +241,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
else:
names.add(name)
@@ -245,7 +249,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.")
else:
names.add(alias)
@@ -274,10 +278,10 @@ class FlagsMeta(type):
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs['__commands_is_flag__'] = True
attrs["__commands_is_flag__"] = True
try:
global_ns = sys.modules[attrs['__module__']].__dict__
global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError:
global_ns = {}
@@ -296,26 +300,26 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"]
if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"]
if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive
attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter
attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix
attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
@@ -337,11 +341,11 @@ class FlagsMeta(type):
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
joined = "|".join(keys)
pattern = re.compile(f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs)
@@ -432,7 +436,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
F = TypeVar("F", bound="FlagConverter")
class FlagConverter(metaclass=FlagsMeta):
@@ -493,8 +497,8 @@ class FlagConverter(metaclass=FlagsMeta):
return self
def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()])
return f"<{self.__class__.__name__} {pairs}>"
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@@ -507,7 +511,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
key = match.group("flag")
if case_insensitive:
key = key.casefold()

View File

@@ -39,10 +39,10 @@ if TYPE_CHECKING:
from .context import Context
__all__ = (
'Paginator',
'HelpCommand',
'DefaultHelpCommand',
'MinimalHelpCommand',
"Paginator",
"HelpCommand",
"DefaultHelpCommand",
"MinimalHelpCommand",
)
# help -> shows info of bot on top/bottom and lists subcommands
@@ -89,7 +89,7 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
@@ -118,7 +118,7 @@ class Paginator:
def _linesep_len(self):
return len(self.linesep)
def add_line(self, line='', *, empty=False):
def add_line(self, line="", *, empty=False):
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@@ -138,7 +138,7 @@ class Paginator:
"""
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
if len(line) > max_page_size:
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
self.close_page()
@@ -147,7 +147,7 @@ class Paginator:
self._current_page.append(line)
if empty:
self._current_page.append('')
self._current_page.append("")
self._count += self._linesep_len
def close_page(self):
@@ -176,7 +176,7 @@ class Paginator:
return self._pages
def __repr__(self):
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
return fmt.format(self)
@@ -197,7 +197,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback
on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overriden__'):
if not hasattr(on_error, "__help_command_not_overriden__"):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
@@ -224,7 +224,7 @@ class _HelpCommandImpl(Command):
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError('Missing context parameter') from None
raise ValueError("Missing context parameter") from None
else:
return result
@@ -296,13 +296,13 @@ class HelpCommand:
"""
MENTION_TRANSFORMS = {
'@everyone': '@\u200beveryone',
'@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role',
"@everyone": "@\u200beveryone",
"@here": "@\u200bhere",
r"<@!?[0-9]{17,22}>": "@deleted-user",
r"<@&[0-9]{17,22}>": "@deleted-role",
}
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing
@@ -321,11 +321,11 @@ class HelpCommand:
return self
def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault("name", "help")
attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
@@ -422,20 +422,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + ' ' + parent.signature)
entries.append(parent.name + " " + parent.signature)
parent = parent.parent
parent_sig = ' '.join(reversed(entries))
parent_sig = " ".join(reversed(entries))
if len(command.aliases) > 0:
aliases = '|'.join(command.aliases)
fmt = f'[{command.name}|{aliases}]'
aliases = "|".join(command.aliases)
fmt = f"[{command.name}|{aliases}]"
if parent_sig:
fmt = parent_sig + ' ' + fmt
fmt = parent_sig + " " + fmt
alias = fmt
else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
alias = command.name if not parent_sig else parent_sig + " " + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}'
return f"{self.context.clean_prefix}{alias} {command.signature}"
def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse.
@@ -449,7 +449,7 @@ class HelpCommand:
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), '@invalid')
return transforms.get(obj.group(0), "@invalid")
return self.MENTION_PATTERN.sub(replace, string)
@@ -615,7 +615,7 @@ class HelpCommand:
:class:`.abc.Messageable`
The destination where the help command will be output.
"""
return self.context.channel
return self.context
async def send_error_message(self, error):
"""|coro|
@@ -846,7 +846,7 @@ class HelpCommand:
# Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through
# the command group chain ourselves.
keys = command.split(' ')
keys = command.split(" ")
cmd = bot.all_commands.get(keys[0])
if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
@@ -907,14 +907,14 @@ class DefaultHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.width = options.pop("width", 80)
self.indent = options.pop("indent", 2)
self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator()
@@ -924,7 +924,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...'
return text[: self.width - 3].rstrip() + "..."
return text
def get_ending_note(self):
@@ -977,6 +977,10 @@ class DefaultHelpCommand(HelpCommand):
for page in self.paginator.pages:
await destination.send(page)
interaction = self.context.interaction
if interaction is not None and destination == self.context.author and not interaction.response.is_done():
await interaction.response.send_message("Sent help to your DMs!", ephemeral=True)
def add_command_formatting(self, command):
"""A utility function to format the non-indented block of commands and groups.
@@ -1007,7 +1011,7 @@ class DefaultHelpCommand(HelpCommand):
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
return ctx.author
else:
return ctx.channel
return ctx
async def prepare_help_command(self, ctx, command):
self.paginator.clear()
@@ -1021,11 +1025,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion
self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:'
no_category = f"\u200b{self.no_category}:"
def get_category(command, *, no_category=no_category):
cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category
return cog.qualified_name + ":" if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered)
@@ -1110,13 +1114,13 @@ class MinimalHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
@@ -1149,7 +1153,7 @@ class MinimalHelpCommand(HelpCommand):
)
def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
@@ -1180,8 +1184,8 @@ class MinimalHelpCommand(HelpCommand):
"""
if commands:
# U+2002 Middle Dot
joined = '\u2002'.join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__')
joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f"__**{heading}**__")
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
@@ -1197,7 +1201,7 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command`
The command to show information of.
"""
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases):
@@ -1268,7 +1272,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}'
no_category = f"\u200b{self.no_category}"
def get_category(command, *, no_category=no_category):
cog = command.cog
@@ -1302,7 +1306,7 @@ class MinimalHelpCommand(HelpCommand):
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
if filtered:
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)
@@ -1322,7 +1326,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
self.paginator.add_line(f'**{self.commands_heading}**')
self.paginator.add_line(f"**{self.commands_heading}**")
for command in filtered:
self.add_subcommand_formatting(command)

View File

@@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
_quotes = {
supported_quotes = {
'"': '"',
"": "",
"": "",
@@ -44,7 +44,8 @@ _quotes = {
"": "",
"": "",
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
_all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values())
class StringView:
def __init__(self, buffer):
@@ -129,7 +130,7 @@ class StringView:
if current is None:
return None
close_quote = _quotes.get(current)
close_quote = supported_quotes.get(current)
is_quoted = bool(close_quote)
if is_quoted:
result = []
@@ -144,11 +145,11 @@ class StringView:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return ''.join(result)
return "".join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\':
if current == "\\":
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
@@ -156,7 +157,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return ''.join(result)
return "".join(result)
if next_char in _escaped_quotes:
# escaped quote
@@ -179,14 +180,13 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return ''.join(result)
return "".join(result)
if current.isspace() and not is_quoted:
# end of word found
return ''.join(result)
return "".join(result)
result.append(current)
def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"

View File

@@ -48,19 +48,17 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
__all__ = (
'loop',
)
__all__ = ("loop",)
T = TypeVar('T')
T = TypeVar("T")
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
__slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
@@ -124,7 +122,7 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
@@ -132,10 +130,10 @@ class Loop(Generic[LF]):
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name)
coro = getattr(self, "_" + name)
if coro is None:
return
@@ -150,7 +148,7 @@ class Loop(Generic[LF]):
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
await self._call_loop_function("before_loop")
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
@@ -193,10 +191,10 @@ class Loop(Generic[LF]):
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function('error', exc)
await self._call_loop_function("error", exc)
raise exc
finally:
await self._call_loop_function('after_loop')
await self._call_loop_function("after_loop")
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
@@ -323,7 +321,7 @@ class Loop(Generic[LF]):
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None:
args = (self._injected, *args)
@@ -410,9 +408,9 @@ class Loop(Generic[LF]):
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f'{exc!r} must be a class.')
raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException):
raise TypeError(f'{exc!r} must inherit from BaseException.')
raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions)
@@ -466,7 +464,7 @@ class Loop(Generic[LF]):
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro: FT) -> FT:
@@ -489,7 +487,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._before_loop = coro
return coro
@@ -517,7 +515,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._after_loop = coro
return coro
@@ -543,7 +541,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._error = coro # type: ignore
return coro
@@ -601,16 +599,16 @@ class Loop(Generic[LF]):
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
@@ -663,7 +661,7 @@ class Loop(Generic[LF]):
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep
self._seconds = float(seconds)
@@ -672,7 +670,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time')
raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING

View File

@@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union
import os
import io
__all__ = (
'File',
)
__all__ = ("File",)
class File:
@@ -64,7 +62,7 @@ class File:
Whether the attachment is a spoiler.
"""
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer')
__slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer")
if TYPE_CHECKING:
fp: io.BufferedIOBase
@@ -80,12 +78,12 @@ class File:
):
if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()):
raise ValueError(f'File buffer {fp!r} must be seekable and readable')
raise ValueError(f"File buffer {fp!r} must be seekable and readable")
self.fp = fp
self._original_pos = fp.tell()
self._owner = False
else:
self.fp = open(fp, 'rb')
self.fp = open(fp, "rb")
self._original_pos = 0
self._owner = True
@@ -100,14 +98,14 @@ class File:
if isinstance(fp, str):
_, self.filename = os.path.split(fp)
else:
self.filename = getattr(fp, 'name', None)
self.filename = getattr(fp, "name", None)
else:
self.filename = filename
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
self.filename = 'SPOILER_' + self.filename
if spoiler and self.filename is not None and not self.filename.startswith("SPOILER_"):
self.filename = "SPOILER_" + self.filename
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_'))
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith("SPOILER_"))
def reset(self, *, seek: Union[int, bool] = True) -> None:
# The `seek` parameter is needed because

View File

@@ -29,16 +29,16 @@ from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optio
from .enums import UserFlags
__all__ = (
'SystemChannelFlags',
'MessageFlags',
'PublicUserFlags',
'Intents',
'MemberCacheFlags',
'ApplicationFlags',
"SystemChannelFlags",
"MessageFlags",
"PublicUserFlags",
"Intents",
"MemberCacheFlags",
"ApplicationFlags",
)
FV = TypeVar('FV', bound='flag_value')
BF = TypeVar('BF', bound='BaseFlags')
FV = TypeVar("FV", bound="flag_value")
BF = TypeVar("BF", bound="BaseFlags")
class flag_value:
@@ -63,7 +63,7 @@ class flag_value:
instance._set_flag(self.flag, value)
def __repr__(self):
return f'<flag_value flag={self.flag!r}>'
return f"<flag_value flag={self.flag!r}>"
class alias_flag_value(flag_value):
@@ -98,13 +98,13 @@ class BaseFlags:
value: int
__slots__ = ('value',)
__slots__ = ("value",)
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@@ -123,7 +123,7 @@ class BaseFlags:
return hash(self.value)
def __repr__(self) -> str:
return f'<{self.__class__.__name__} value={self.value}>'
return f"<{self.__class__.__name__} value={self.value}>"
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items():
@@ -142,7 +142,7 @@ class BaseFlags:
elif toggle is False:
self.value &= ~o
else:
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.')
raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.")
@fill_with_flags(inverted=True)
@@ -196,7 +196,7 @@ class SystemChannelFlags(BaseFlags):
elif toggle is False:
self.value |= o
else:
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
raise TypeError("Value to set for SystemChannelFlags must be a bool.")
@flag_value
def join_notifications(self):
@@ -461,7 +461,7 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@@ -907,7 +907,7 @@ class MemberCacheFlags(BaseFlags):
self.value = (1 << bits) - 1
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)
@classmethod
@@ -977,10 +977,10 @@ class MemberCacheFlags(BaseFlags):
def _verify_intents(self, intents: Intents):
if self.voice and not intents.voice_states:
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states')
raise ValueError("MemberCacheFlags.voice requires Intents.voice_states")
if self.joined and not intents.members:
raise ValueError('MemberCacheFlags.joined requires Intents.members')
raise ValueError("MemberCacheFlags.joined requires Intents.members")
@property
def _voice_only(self):

View File

@@ -22,8 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
TypedDict,
Any,
Optional,
List,
TypeVar,
Type,
Dict,
Callable,
Coroutine,
NamedTuple,
Deque,
)
import asyncio
from collections import namedtuple, deque
from collections import deque
import concurrent.futures
import logging
import struct
@@ -40,46 +57,77 @@ from .activity import BaseActivity
from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
T = TypeVar("T")
DWS = TypeVar("DWS", bound="DiscordWebSocket")
DVWS = TypeVar("DVWS", bound="DiscordVoiceWebSocket")
Coro = Callable[..., Coroutine[Any, Any, Any]]
Predicate = Callable[[Dict[str, Any]], bool]
DataCallable = Callable[[Dict[str, Any]], T]
Result = Optional[DataCallable[Any]]
_log: logging.Logger = logging.getLogger(__name__)
__all__ = (
'DiscordWebSocket',
'KeepAliveHandler',
'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket',
'ReconnectWebSocket',
"DiscordWebSocket",
"KeepAliveHandler",
"VoiceKeepAliveHandler",
"DiscordVoiceWebSocket",
"ReconnectWebSocket",
)
class Heartbeat(TypedDict):
op: int
d: int
class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True):
self.shard_id = shard_id
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id: Optional[int] = shard_id
self.resume: bool = resume
self.op = "RESUME" if resume else "IDENTIFY"
class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class EventListener(NamedTuple):
predicate: Predicate
event: str
result: Result
future: asyncio.Future
class GatewayRatelimiter:
def __init__(self, count=110, per=60.0):
def __init__(self, count: int = 110, per: float = 60.0) -> None:
# The default is 110 to give room for at least 10 heartbeats per minute
self.max = count
self.remaining = count
self.window = 0.0
self.per = per
self.lock = asyncio.Lock()
self.shard_id = None
self.max: int = count
self.remaining: int = count
self.window: float = 0.0
self.per: float = per
self.lock: asyncio.Lock = asyncio.Lock()
self.shard_id: Optional[int] = None
def is_ratelimited(self):
def is_ratelimited(self) -> bool:
current = time.time()
if current > self.window + self.per:
return False
return self.remaining == 0
def get_delay(self):
def get_delay(self) -> float:
current = time.time()
if current > self.window + self.per:
@@ -97,52 +145,54 @@ class GatewayRatelimiter:
return 0.0
async def block(self):
async def block(self) -> None:
async with self.lock:
delta = self.get_delay()
if delta:
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
_log.warning("WebSocket in shard ID %s is ratelimited, waiting %.2f seconds", self.shard_id, delta)
await asyncio.sleep(delta)
class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs):
ws = kwargs.pop('ws', None)
interval = kwargs.pop('interval', None)
shard_id = kwargs.pop('shard_id', None)
def __init__(self, *args: Any, **kwargs: Any) -> None:
ws = kwargs.pop("ws")
interval = kwargs.pop("interval", None)
shard_id = kwargs.pop("shard_id", None)
threading.Thread.__init__(self, *args, **kwargs)
self.ws = ws
self._main_thread_id = ws.thread_id
self.interval = interval
self.daemon = True
self.shard_id = shard_id
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self._stop_ev = threading.Event()
self._last_ack = time.perf_counter()
self._last_send = time.perf_counter()
self._last_recv = time.perf_counter()
self.latency = float('inf')
self.heartbeat_timeout = ws._max_heartbeat_timeout
self.ws: DiscordWebSocket = ws
self._main_thread_id: int = ws.thread_id
self.interval: Optional[float] = interval
self.daemon: bool = True
self.shard_id: Optional[int] = shard_id
self.msg: str = "Keeping shard ID %s websocket alive with sequence %s."
self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds."
self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind."
self._stop_ev: threading.Event = threading.Event()
self._last_ack: float = time.perf_counter()
self._last_send: float = time.perf_counter()
self._last_recv: float = time.perf_counter()
self.latency: float = float("inf")
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
def run(self):
def run(self) -> None:
while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
_log.warning(
"Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id
)
coro = self.ws.close(4000)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
f.result()
except Exception:
_log.exception('An error occurred while stopping the gateway. Ignoring.')
_log.exception("An error occurred while stopping the gateway. Ignoring.")
finally:
self.stop()
return
data = self.get_payload()
_log.debug(self.msg, self.shard_id, data['d'])
_log.debug(self.msg, self.shard_id, data["d"])
coro = self.ws.send_heartbeat(data)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
@@ -159,8 +209,8 @@ class KeepAliveHandler(threading.Thread):
except KeyError:
msg = self.block_msg
else:
stack = ''.join(traceback.format_stack(frame))
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
stack = "".join(traceback.format_stack(frame))
msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}"
_log.warning(msg, self.shard_id, total)
except Exception:
@@ -168,50 +218,51 @@ class KeepAliveHandler(threading.Thread):
else:
self._last_send = time.perf_counter()
def get_payload(self):
def get_payload(self) -> Heartbeat:
return {
'op': self.ws.HEARTBEAT,
'd': self.ws.sequence
"op": self.ws.HEARTBEAT,
# the websocket's sequence won't be None here
"d": self.ws.sequence, # type: ignore
}
def stop(self):
def stop(self) -> None:
self._stop_ev.set()
def tick(self):
def tick(self) -> None:
self._last_recv = time.perf_counter()
def ack(self):
def ack(self) -> None:
ack_time = time.perf_counter()
self._last_ack = ack_time
self.latency = ack_time - self._last_send
if self.latency > 10:
_log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s."
self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds"
self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind"
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
def get_payload(self) -> Heartbeat:
return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)}
def ack(self):
def ack(self) -> None:
ack_time = time.perf_counter()
self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
async def close(self, *, code: int = 4000, message: bytes = b"") -> bool:
return await super().close(code=code, message=message)
class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6.
@@ -266,41 +317,63 @@ class DiscordWebSocket:
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
def __init__(self, socket, *, loop):
self.socket = socket
self.loop = loop
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
# an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None
# generic event listeners
self._dispatch_listeners = []
self._dispatch_listeners: List[EventListener] = []
# the keep alive
self._keep_alive = None
self.thread_id = threading.get_ident()
self._keep_alive: Optional[KeepAliveHandler] = None
self.thread_id: int = threading.get_ident()
# ws related stuff
self.session_id = None
self.sequence = None
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib = zlib.decompressobj()
self._buffer = bytearray()
self._close_code = None
self._rate_limiter = GatewayRatelimiter()
self._buffer: bytearray = bytearray()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
# attributes that get set in from_client
self.token: str = utils.MISSING
self._connection: ConnectionState = utils.MISSING
self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING
self.gateway: str = utils.MISSING
self.call_hooks: Coro = utils.MISSING
self._initial_identify: bool = utils.MISSING
self.shard_id: Optional[int] = utils.MISSING
self.shard_count: Optional[int] = utils.MISSING
self.session_id: Optional[str] = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
@property
def open(self):
def open(self) -> bool:
return not self.socket.closed
def is_ratelimited(self):
def is_ratelimited(self) -> bool:
return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /):
self._dispatch('socket_raw_receive', data)
def debug_log_receive(self, data, /) -> None:
self._dispatch("socket_raw_receive", data)
def log_receive(self, _, /):
def log_receive(self, _, /) -> None:
pass
@classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
async def from_client(
cls: Type[DWS],
client: Client,
*,
initial: bool = False,
gateway: Optional[str] = None,
shard_id: Optional[int] = None,
session: Optional[str] = None,
sequence: Optional[int] = None,
resume: bool = False,
) -> DWS:
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@@ -310,7 +383,9 @@ class DiscordWebSocket:
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
ws.token = client.http.token
# the token won't be None here
ws.token = client.http.token # type: ignore
ws._connection = client._connection
ws._discord_parsers = client._connection.parsers
ws._dispatch = client.dispatch
@@ -330,7 +405,7 @@ class DiscordWebSocket:
client._connection._update_references(ws)
_log.debug('Created websocket connected to %s', gateway)
_log.debug("Created websocket connected to %s", gateway)
# poll event for OP Hello
await ws.poll_event()
@@ -342,7 +417,7 @@ class DiscordWebSocket:
await ws.resume()
return ws
def wait_for(self, event, predicate, result=None):
def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future:
"""Waits for a DISPATCH'd event that meets the predicate.
Parameters
@@ -367,79 +442,67 @@ class DiscordWebSocket:
self._dispatch_listeners.append(entry)
return future
async def identify(self):
async def identify(self) -> None:
"""Sends the IDENTIFY packet."""
payload = {
'op': self.IDENTIFY,
'd': {
'token': self.token,
'properties': {
'$os': sys.platform,
'$browser': 'discord.py',
'$device': 'discord.py',
'$referrer': '',
'$referring_domain': ''
"op": self.IDENTIFY,
"d": {
"token": self.token,
"properties": {
"$os": sys.platform,
"$browser": "discord.py",
"$device": "discord.py",
"$referrer": "",
"$referring_domain": "",
},
"compress": True,
"large_threshold": 250,
"v": 3,
},
'compress': True,
'large_threshold': 250,
'v': 3
}
}
if self.shard_id is not None and self.shard_count is not None:
payload['d']['shard'] = [self.shard_id, self.shard_count]
payload["d"]["shard"] = [self.shard_id, self.shard_count]
state = self._connection
if state._activity is not None or state._status is not None:
payload['d']['presence'] = {
'status': state._status,
'game': state._activity,
'since': 0,
'afk': False
}
payload["d"]["presence"] = {"status": state._status, "game": state._activity, "since": 0, "afk": False}
if state._intents is not None:
payload['d']['intents'] = state._intents.value
payload["d"]["intents"] = state._intents.value
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.call_hooks("before_identify", self.shard_id, initial=self._initial_identify)
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
_log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id)
async def resume(self):
async def resume(self) -> None:
"""Sends the RESUME packet."""
payload = {
'op': self.RESUME,
'd': {
'seq': self.sequence,
'session_id': self.session_id,
'token': self.token
}
}
payload = {"op": self.RESUME, "d": {"seq": self.sequence, "session_id": self.session_id, "token": self.token}}
await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
_log.info("Shard ID %s has sent the RESUME payload.", self.shard_id)
async def received_message(self, msg, /):
async def received_message(self, msg, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff":
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
msg = msg.decode("utf-8")
self._buffer = bytearray()
self.log_receive(msg)
msg = utils._from_json(msg)
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
event = msg.get('t')
_log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg)
event = msg.get("t")
if event:
self._dispatch('socket_event_type', event)
self._dispatch("socket_event_type", event)
op = msg.get('op')
data = msg.get('d')
seq = msg.get('s')
op = msg.get("op")
data = msg.get("d")
seq = msg.get("s")
if seq is not None:
self.sequence = seq
@@ -451,7 +514,7 @@ class DiscordWebSocket:
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
_log.debug('Received RECONNECT opcode.')
_log.debug("Received RECONNECT opcode.")
await self.close()
raise ReconnectWebSocket(self.shard_id)
@@ -467,7 +530,7 @@ class DiscordWebSocket:
return
if op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
interval = data["heartbeat_interval"] / 1000.0
self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id)
# send a heartbeat immediately
await self.send_as_json(self._keep_alive.get_payload())
@@ -481,33 +544,41 @@ class DiscordWebSocket:
self.sequence = None
self.session_id = None
_log.info('Shard ID %s session has been invalidated.', self.shard_id)
_log.info("Shard ID %s session has been invalidated.", self.shard_id)
await self.close(code=1000)
raise ReconnectWebSocket(self.shard_id, resume=False)
_log.warning('Unknown OP code %s.', op)
_log.warning("Unknown OP code %s.", op)
return
if event == 'READY':
self._trace = trace = data.get('_trace', [])
self.sequence = msg['s']
self.session_id = data['session_id']
if event == "READY":
self._trace = trace = data.get("_trace", [])
self.sequence = msg["s"]
self.session_id = data["session_id"]
# pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id, ', '.join(trace), self.session_id)
data["__shard_id__"] = self.shard_id
_log.info(
"Shard ID %s has connected to Gateway: %s (Session ID: %s).",
self.shard_id,
", ".join(trace),
self.session_id,
)
elif event == 'RESUMED':
self._trace = trace = data.get('_trace', [])
elif event == "RESUMED":
self._trace = trace = data.get("_trace", [])
# pass back the shard ID to the resumed handler
data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
self.shard_id, self.session_id, ', '.join(trace))
data["__shard_id__"] = self.shard_id
_log.info(
"Shard ID %s has successfully RESUMED session %s under trace %s.",
self.shard_id,
self.session_id,
", ".join(trace),
)
try:
func = self._discord_parsers[event]
except KeyError:
_log.debug('Unknown event %s.', event)
_log.debug("Unknown event %s.", event)
else:
func(data)
@@ -537,16 +608,16 @@ class DiscordWebSocket:
del self._dispatch_listeners[index]
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
return float("inf") if heartbeat is None else heartbeat.latency
def _can_handle_close(self):
def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self):
async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop.
Raises
@@ -561,10 +632,10 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise WebSocketClosure
except (asyncio.TimeoutError, WebSocketClosure) as e:
# Ensure the keep alive handler is closed
@@ -573,34 +644,34 @@ class DiscordWebSocket:
self._keep_alive = None
if isinstance(e, asyncio.TimeoutError):
_log.info('Timed out receiving packet. Attempting a reconnect.')
_log.info("Timed out receiving packet. Attempting a reconnect.")
raise ReconnectWebSocket(self.shard_id) from None
code = self._close_code or self.socket.close_code
if self._can_handle_close():
_log.info('Websocket closed with %s, attempting a reconnect.', code)
_log.info("Websocket closed with %s, attempting a reconnect.", code)
raise ReconnectWebSocket(self.shard_id) from None
else:
_log.info('Websocket closed with %s, cannot reconnect.', code)
_log.info("Websocket closed with %s, cannot reconnect.", code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /):
async def debug_send(self, data, /) -> None:
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data)
self._dispatch("socket_raw_send", data)
await self.socket.send_str(data)
async def send(self, data, /):
async def send(self, data, /) -> None:
await self._rate_limiter.block()
await self.socket.send_str(data)
async def send_as_json(self, data):
async def send_as_json(self, data) -> None:
try:
await self.send(utils._to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data):
async def send_heartbeat(self, data: Heartbeat) -> None:
# This bypasses the rate limit handling code since it has a higher priority
try:
await self.socket.send_str(utils._to_json(data))
@@ -608,68 +679,60 @@ class DiscordWebSocket:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0):
async def change_presence(
self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0
) -> None:
if activity is not None:
if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()]
raise InvalidArgument("activity must derive from BaseActivity.")
activities = [activity.to_dict()]
else:
activity = []
activities = []
if status == 'idle':
if status == "idle":
since = int(time.time() * 1000)
payload = {
'op': self.PRESENCE,
'd': {
'activities': activity,
'afk': False,
'since': since,
'status': status
}
}
payload = {"op": self.PRESENCE, "d": {"activities": activities, "afk": False, "since": since, "status": status}}
sent = utils._to_json(payload)
_log.debug('Sending "%s" to change status', sent)
await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
payload = {
'op': self.REQUEST_MEMBERS,
'd': {
'guild_id': guild_id,
'presences': presences,
'limit': limit
}
}
async def request_chunks(
self,
guild_id: int,
query: Optional[str] = None,
*,
limit: int,
user_ids: Optional[List[int]] = None,
presences: bool = False,
nonce: Optional[int] = None,
) -> None:
payload = {"op": self.REQUEST_MEMBERS, "d": {"guild_id": guild_id, "presences": presences, "limit": limit}}
if nonce:
payload['d']['nonce'] = nonce
payload["d"]["nonce"] = nonce
if user_ids:
payload['d']['user_ids'] = user_ids
payload["d"]["user_ids"] = user_ids
if query is not None:
payload['d']['query'] = query
payload["d"]["query"] = query
await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
async def voice_state(
self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False
) -> None:
payload = {
'op': self.VOICE_STATE,
'd': {
'guild_id': guild_id,
'channel_id': channel_id,
'self_mute': self_mute,
'self_deaf': self_deaf
}
"op": self.VOICE_STATE,
"d": {"guild_id": guild_id, "channel_id": channel_id, "self_mute": self_mute, "self_deaf": self_deaf},
}
_log.debug('Updating our voice state to %s.', payload)
_log.debug("Updating our voice state to %s.", payload)
await self.send_as_json(payload)
async def close(self, code=4000):
async def close(self, code: int = 4000) -> None:
if self._keep_alive:
self._keep_alive.stop()
self._keep_alive = None
@@ -677,6 +740,7 @@ class DiscordWebSocket:
self._close_code = code
await self.socket.close(code=code)
class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections.
@@ -721,53 +785,58 @@ class DiscordVoiceWebSocket:
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None):
self.ws = socket
self.loop = loop
self._keep_alive = None
self._close_code = None
self.secret_key = None
def __init__(
self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None
) -> None:
self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: VoiceKeepAliveHandler = utils.MISSING
self._close_code: Optional[int] = None
self.secret_key: Optional[List[int]] = None
self.gateway: str = utils.MISSING
self._connection: VoiceClient = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
self.thread_id: int = utils.MISSING
if hook:
self._hook = hook
# we want to redeclare self._hook
self._hook = hook # type: ignore
async def _hook(self, *args):
async def _hook(self, *args: Any) -> Any:
pass
async def send_as_json(self, data):
_log.debug('Sending voice websocket frame: %s.', data)
async def send_as_json(self, data) -> None:
_log.debug("Sending voice websocket frame: %s.", data)
await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json
async def resume(self):
async def resume(self) -> None:
state = self._connection
payload = {
'op': self.RESUME,
'd': {
'token': state.token,
'server_id': str(state.server_id),
'session_id': state.session_id
}
"op": self.RESUME,
"d": {"token": state.token, "server_id": str(state.server_id), "session_id": state.session_id},
}
await self.send_as_json(payload)
async def identify(self):
state = self._connection
payload = {
'op': self.IDENTIFY,
'd': {
'server_id': str(state.server_id),
'user_id': str(state.user.id),
'session_id': state.session_id,
'token': state.token
}
"op": self.IDENTIFY,
"d": {
"server_id": str(state.server_id),
"user_id": str(state.user.id),
"session_id": state.session_id,
"token": state.token,
},
}
await self.send_as_json(payload)
@classmethod
async def from_client(cls, client, *, resume=False, hook=None):
async def from_client(
cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None
) -> DVWS:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = "wss://" + client.endpoint + "/?v=4"
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
@@ -783,127 +852,109 @@ class DiscordVoiceWebSocket:
return ws
async def select_protocol(self, ip, port, mode):
async def select_protocol(self, ip, port, mode) -> None:
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
'protocol': 'udp',
'data': {
'address': ip,
'port': port,
'mode': mode
}
}
"op": self.SELECT_PROTOCOL,
"d": {"protocol": "udp", "data": {"address": ip, "port": port, "mode": mode}},
}
await self.send_as_json(payload)
async def client_connect(self):
payload = {
'op': self.CLIENT_CONNECT,
'd': {
'audio_ssrc': self._connection.ssrc
}
}
async def client_connect(self) -> None:
payload = {"op": self.CLIENT_CONNECT, "d": {"audio_ssrc": self._connection.ssrc}}
await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice):
payload = {
'op': self.SPEAKING,
'd': {
'speaking': int(state),
'delay': 0
}
}
async def speak(self, state=SpeakingState.voice) -> None:
payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}}
await self.send_as_json(payload)
async def received_message(self, msg):
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
data = msg.get('d')
async def received_message(self, msg) -> None:
_log.debug("Voice websocket frame received: %s", msg)
op = msg["op"]
data = msg.get("d")
if op == self.READY:
await self.initial_connection(data)
elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack()
elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.')
_log.info("Voice RESUME succeeded.")
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
self._connection.mode = data["mode"]
await self.load_secret_key(data)
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
interval = data["heartbeat_interval"] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start()
await self._hook(self, msg)
async def initial_connection(self, data):
async def initial_connection(self, data) -> None:
state = self._connection
state.ssrc = data['ssrc']
state.voice_port = data['port']
state.endpoint_ip = data['ip']
state.ssrc = data["ssrc"]
state.voice_port = data["port"]
state.endpoint_ip = data["ip"]
packet = bytearray(70)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
struct.pack_into(">H", packet, 0, 1) # 1 = Send
struct.pack_into(">H", packet, 2, 70) # 70 = Length
struct.pack_into(">I", packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70)
_log.debug('received packet in initial_connection: %s', recv)
_log.debug("received packet in initial_connection: %s", recv)
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
state.ip = recv[ip_start:ip_end].decode("ascii")
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0]
_log.debug("detected ip: %s port: %s", state.ip, state.port)
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))
modes = [mode for mode in data["modes"] if mode in self._connection.supported_modes]
_log.debug("received supported encryption modes: %s", ", ".join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.info('selected the voice protocol for use (%s)', mode)
_log.info("selected the voice protocol for use (%s)", mode)
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
return float("inf") if heartbeat is None else heartbeat.latency
@property
def average_latency(self):
def average_latency(self) -> float:
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies:
return float('inf')
return float("inf")
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
_log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key')
async def load_secret_key(self, data) -> None:
_log.info("received secret key for voice connection")
self.secret_key = self._connection.secret_key = data.get("secret_key")
await self.speak()
await self.speak(False)
async def poll_event(self):
async def poll_event(self) -> None:
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
_log.debug("Received %s", msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000):
async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None:
self._keep_alive.stop()

View File

@@ -46,7 +46,7 @@ from . import utils, abc
from .role import Role
from .member import Member, VoiceState
from .emoji import Emoji
from .errors import InvalidData
from .errors import InvalidData, NotFound
from .permissions import PermissionOverwrite
from .colour import Colour
from .errors import InvalidArgument, ClientException
@@ -78,9 +78,7 @@ from .sticker import GuildSticker
from .file import File
__all__ = (
'Guild',
)
__all__ = ("Guild",)
MISSING = utils.MISSING
@@ -239,45 +237,45 @@ class Guild(Hashable):
"""
__slots__ = (
'afk_timeout',
'afk_channel',
'name',
'id',
'unavailable',
'region',
'owner_id',
'mfa_level',
'emojis',
'stickers',
'features',
'verification_level',
'explicit_content_filter',
'default_notifications',
'description',
'max_presences',
'max_members',
'max_video_channel_users',
'premium_tier',
'premium_subscription_count',
'preferred_locale',
'nsfw_level',
'_members',
'_channels',
'_icon',
'_banner',
'_state',
'_roles',
'_member_count',
'_large',
'_splash',
'_voice_states',
'_system_channel_id',
'_system_channel_flags',
'_discovery_splash',
'_rules_channel_id',
'_public_updates_channel_id',
'_stage_instances',
'_threads',
"afk_timeout",
"afk_channel",
"name",
"id",
"unavailable",
"region",
"owner_id",
"mfa_level",
"emojis",
"stickers",
"features",
"verification_level",
"explicit_content_filter",
"default_notifications",
"description",
"max_presences",
"max_members",
"max_video_channel_users",
"premium_tier",
"premium_subscription_count",
"preferred_locale",
"nsfw_level",
"_members",
"_channels",
"_icon",
"_banner",
"_state",
"_roles",
"_member_count",
"_large",
"_splash",
"_voice_states",
"_system_channel_id",
"_system_channel_flags",
"_discovery_splash",
"_rules_channel_id",
"_public_updates_channel_id",
"_stage_instances",
"_threads",
)
_PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = {
@@ -337,21 +335,23 @@ class Guild(Hashable):
return to_remove
def __str__(self) -> str:
return self.name or ''
return self.name or ""
def __repr__(self) -> str:
attrs = (
('id', self.id),
('name', self.name),
('shard_id', self.shard_id),
('chunked', self.chunked),
('member_count', getattr(self, '_member_count', None)),
("id", self.id),
("name", self.name),
("shard_id", self.shard_id),
("chunked", self.chunked),
("member_count", getattr(self, "_member_count", None)),
)
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Guild {inner}>'
inner = " ".join("%s=%r" % t for t in attrs)
return f"<Guild {inner}>"
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]:
user_id = int(data['user_id'])
def _update_voice_state(
self, data: GuildVoiceState, channel_id: int
) -> Tuple[Optional[Member], VoiceState, VoiceState]:
user_id = int(data["user_id"])
channel = self.get_channel(channel_id)
try:
# check if we should remove the voice state from cache
@@ -371,7 +371,7 @@ class Guild(Hashable):
member = self.get_member(user_id)
if member is None:
try:
member = Member(data=data['member'], state=self._state, guild=self)
member = Member(data=data["member"], state=self._state, guild=self)
except KeyError:
member = None
@@ -403,57 +403,57 @@ class Guild(Hashable):
def _from_data(self, guild: GuildPayload) -> None:
# according to Stan, this is always available even if the guild is unavailable
# I don't have this guarantee when someone updates the guild.
member_count = guild.get('member_count', None)
member_count = guild.get("member_count", None)
if member_count is not None:
self._member_count: int = member_count
self.name: str = guild.get('name')
self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region'))
self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level'))
self.name: str = guild.get("name")
self.region: VoiceRegion = try_enum(VoiceRegion, guild.get("region"))
self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get("verification_level"))
self.default_notifications: NotificationLevel = try_enum(
NotificationLevel, guild.get('default_message_notifications')
NotificationLevel, guild.get("default_message_notifications")
)
self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0))
self.afk_timeout: int = guild.get('afk_timeout')
self._icon: Optional[str] = guild.get('icon')
self._banner: Optional[str] = guild.get('banner')
self.unavailable: bool = guild.get('unavailable', False)
self.id: int = int(guild['id'])
self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get("explicit_content_filter", 0))
self.afk_timeout: int = guild.get("afk_timeout")
self._icon: Optional[str] = guild.get("icon")
self._banner: Optional[str] = guild.get("banner")
self.unavailable: bool = guild.get("unavailable", False)
self.id: int = int(guild["id"])
self._roles: Dict[int, Role] = {}
state = self._state # speed up attribute access
for r in guild.get('roles', []):
for r in guild.get("roles", []):
role = Role(guild=self, data=r, state=state)
self._roles[role.id] = role
self.mfa_level: MFALevel = guild.get('mfa_level')
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', [])))
self.mfa_level: MFALevel = guild.get("mfa_level")
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", [])))
self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get('stickers', []))
map(lambda d: state.store_sticker(self, d), guild.get("stickers", []))
)
self.features: List[GuildFeature] = guild.get('features', [])
self._splash: Optional[str] = guild.get('splash')
self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id')
self.description: Optional[str] = guild.get('description')
self.max_presences: Optional[int] = guild.get('max_presences')
self.max_members: Optional[int] = guild.get('max_members')
self.max_video_channel_users: Optional[int] = guild.get('max_video_channel_users')
self.premium_tier: int = guild.get('premium_tier', 0)
self.premium_subscription_count: int = guild.get('premium_subscription_count') or 0
self._system_channel_flags: int = guild.get('system_channel_flags', 0)
self.preferred_locale: Optional[str] = guild.get('preferred_locale')
self._discovery_splash: Optional[str] = guild.get('discovery_splash')
self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'rules_channel_id')
self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'public_updates_channel_id')
self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get('nsfw_level', 0))
self.features: List[GuildFeature] = guild.get("features", [])
self._splash: Optional[str] = guild.get("splash")
self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id")
self.description: Optional[str] = guild.get("description")
self.max_presences: Optional[int] = guild.get("max_presences")
self.max_members: Optional[int] = guild.get("max_members")
self.max_video_channel_users: Optional[int] = guild.get("max_video_channel_users")
self.premium_tier: int = guild.get("premium_tier", 0)
self.premium_subscription_count: int = guild.get("premium_subscription_count") or 0
self._system_channel_flags: int = guild.get("system_channel_flags", 0)
self.preferred_locale: Optional[str] = guild.get("preferred_locale")
self._discovery_splash: Optional[str] = guild.get("discovery_splash")
self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, "rules_channel_id")
self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, "public_updates_channel_id")
self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0))
self._stage_instances: Dict[int, StageInstance] = {}
for s in guild.get('stage_instances', []):
for s in guild.get("stage_instances", []):
stage_instance = StageInstance(guild=self, data=s, state=state)
self._stage_instances[stage_instance.id] = stage_instance
cache_joined = self._state.member_cache_flags.joined
self_id = self._state.self_id
for mdata in guild.get('members', []):
for mdata in guild.get("members", []):
member = Member(data=mdata, guild=self, state=state)
if cache_joined or member.id == self_id:
self._add_member(member)
@@ -461,35 +461,35 @@ class Guild(Hashable):
self._sync(guild)
self._large: Optional[bool] = None if member_count is None else self._member_count >= 250
self.owner_id: Optional[int] = utils._get_as_snowflake(guild, 'owner_id')
self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) # type: ignore
self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id")
self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore
for obj in guild.get('voice_states', []):
self._update_voice_state(obj, int(obj['channel_id']))
for obj in guild.get("voice_states", []):
self._update_voice_state(obj, int(obj["channel_id"]))
# TODO: refactor/remove?
def _sync(self, data: GuildPayload) -> None:
try:
self._large = data['large']
self._large = data["large"]
except KeyError:
pass
empty_tuple = tuple()
for presence in data.get('presences', []):
user_id = int(presence['user']['id'])
for presence in data.get("presences", []):
user_id = int(presence["user"]["id"])
member = self.get_member(user_id)
if member is not None:
member._presence_update(presence, empty_tuple) # type: ignore
if 'channels' in data:
channels = data['channels']
if "channels" in data:
channels = data["channels"]
for c in channels:
factory, ch_type = _guild_channel_factory(c['type'])
factory, ch_type = _guild_channel_factory(c["type"])
if factory:
self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore
if 'threads' in data:
threads = data['threads']
if "threads" in data:
threads = data["threads"]
for thread in threads:
self._add_thread(Thread(guild=self, state=self._state, data=thread))
@@ -712,7 +712,7 @@ class Guild(Hashable):
@property
def emoji_limit(self) -> int:
""":class:`int`: The maximum number of emoji slots this guild has."""
more_emoji = 200 if 'MORE_EMOJI' in self.features else 50
more_emoji = 200 if "MORE_EMOJI" in self.features else 50
return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji)
@property
@@ -721,13 +721,13 @@ class Guild(Hashable):
.. versionadded:: 2.0
"""
more_stickers = 60 if 'MORE_STICKERS' in self.features else 0
more_stickers = 60 if "MORE_STICKERS" in self.features else 0
return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers)
@property
def bitrate_limit(self) -> float:
""":class:`float`: The maximum bitrate for voice channels this guild can have."""
vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if 'VIP_REGIONS' in self.features else 96e3
vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if "VIP_REGIONS" in self.features else 96e3
return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate)
@property
@@ -871,21 +871,21 @@ class Guild(Hashable):
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None:
return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners')
return Asset._from_guild_image(self._state, self.id, self._banner, path="banners")
@property
def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None:
return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes")
@property
def discovery_splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available."""
if self._discovery_splash is None:
return None
return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes')
return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path="discovery-splashes")
@property
def member_count(self) -> int:
@@ -909,7 +909,7 @@ class Guild(Hashable):
If this value returns ``False``, then you should request for
offline members.
"""
count = getattr(self, '_member_count', None)
count = getattr(self, "_member_count", None)
if count is None:
return False
return count == len(self._members)
@@ -956,7 +956,7 @@ class Guild(Hashable):
result = None
members = self.members
if len(name) > 5 and name[-5] == '#':
if len(name) > 5 and name[-5] == "#":
# The 5 length is checking to see if #0000 is in the string,
# as a#0000 has a length of 6, the minimum for a potential
# discriminator lookup.
@@ -984,20 +984,20 @@ class Guild(Hashable):
if overwrites is MISSING:
overwrites = {}
elif not isinstance(overwrites, dict):
raise InvalidArgument('overwrites parameter expects a dict.')
raise InvalidArgument("overwrites parameter expects a dict.")
perms = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}')
raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")
allow, deny = perm.pair()
payload = {'allow': allow.value, 'deny': deny.value, 'id': target.id}
payload = {"allow": allow.value, "deny": deny.value, "id": target.id}
if isinstance(target, Role):
payload['type'] = abc._Overwrites.ROLE
payload["type"] = abc._Overwrites.ROLE
else:
payload['type'] = abc._Overwrites.MEMBER
payload["type"] = abc._Overwrites.MEMBER
perms.append(payload)
@@ -1098,16 +1098,16 @@ class Guild(Hashable):
options = {}
if position is not MISSING:
options['position'] = position
options["position"] = position
if topic is not MISSING:
options['topic'] = topic
options["topic"] = topic
if slowmode_delay is not MISSING:
options['rate_limit_per_user'] = slowmode_delay
options["rate_limit_per_user"] = slowmode_delay
if nsfw is not MISSING:
options['nsfw'] = nsfw
options["nsfw"] = nsfw
data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options
@@ -1182,19 +1182,19 @@ class Guild(Hashable):
"""
options = {}
if position is not MISSING:
options['position'] = position
options["position"] = position
if bitrate is not MISSING:
options['bitrate'] = bitrate
options["bitrate"] = bitrate
if user_limit is not MISSING:
options['user_limit'] = user_limit
options["user_limit"] = user_limit
if rtc_region is not MISSING:
options['rtc_region'] = None if rtc_region is None else str(rtc_region)
options["rtc_region"] = None if rtc_region is None else str(rtc_region)
if video_quality_mode is not MISSING:
options['video_quality_mode'] = video_quality_mode.value
options["video_quality_mode"] = video_quality_mode.value
data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options
@@ -1257,13 +1257,18 @@ class Guild(Hashable):
"""
options: Dict[str, Any] = {
'topic': topic,
"topic": topic,
}
if position is not MISSING:
options['position'] = position
options["position"] = position
data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.stage_voice, category=category, reason=reason, **options
name,
overwrites=overwrites,
channel_type=ChannelType.stage_voice,
category=category,
reason=reason,
**options,
)
channel = StageChannel(state=self._state, guild=self, data=data)
@@ -1304,7 +1309,7 @@ class Guild(Hashable):
"""
options: Dict[str, Any] = {}
if position is not MISSING:
options['position'] = position
options["position"] = position
data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options
@@ -1479,108 +1484,108 @@ class Guild(Hashable):
fields: Dict[str, Any] = {}
if name is not MISSING:
fields['name'] = name
fields["name"] = name
if description is not MISSING:
fields['description'] = description
fields["description"] = description
if preferred_locale is not MISSING:
fields['preferred_locale'] = preferred_locale
fields["preferred_locale"] = preferred_locale
if afk_timeout is not MISSING:
fields['afk_timeout'] = afk_timeout
fields["afk_timeout"] = afk_timeout
if icon is not MISSING:
if icon is None:
fields['icon'] = icon
fields["icon"] = icon
else:
fields['icon'] = utils._bytes_to_base64_data(icon)
fields["icon"] = utils._bytes_to_base64_data(icon)
if banner is not MISSING:
if banner is None:
fields['banner'] = banner
fields["banner"] = banner
else:
fields['banner'] = utils._bytes_to_base64_data(banner)
fields["banner"] = utils._bytes_to_base64_data(banner)
if splash is not MISSING:
if splash is None:
fields['splash'] = splash
fields["splash"] = splash
else:
fields['splash'] = utils._bytes_to_base64_data(splash)
fields["splash"] = utils._bytes_to_base64_data(splash)
if discovery_splash is not MISSING:
if discovery_splash is None:
fields['discovery_splash'] = discovery_splash
fields["discovery_splash"] = discovery_splash
else:
fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash)
fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash)
if default_notifications is not MISSING:
if not isinstance(default_notifications, NotificationLevel):
raise InvalidArgument('default_notifications field must be of type NotificationLevel')
fields['default_message_notifications'] = default_notifications.value
raise InvalidArgument("default_notifications field must be of type NotificationLevel")
fields["default_message_notifications"] = default_notifications.value
if afk_channel is not MISSING:
if afk_channel is None:
fields['afk_channel_id'] = afk_channel
fields["afk_channel_id"] = afk_channel
else:
fields['afk_channel_id'] = afk_channel.id
fields["afk_channel_id"] = afk_channel.id
if system_channel is not MISSING:
if system_channel is None:
fields['system_channel_id'] = system_channel
fields["system_channel_id"] = system_channel
else:
fields['system_channel_id'] = system_channel.id
fields["system_channel_id"] = system_channel.id
if rules_channel is not MISSING:
if rules_channel is None:
fields['rules_channel_id'] = rules_channel
fields["rules_channel_id"] = rules_channel
else:
fields['rules_channel_id'] = rules_channel.id
fields["rules_channel_id"] = rules_channel.id
if public_updates_channel is not MISSING:
if public_updates_channel is None:
fields['public_updates_channel_id'] = public_updates_channel
fields["public_updates_channel_id"] = public_updates_channel
else:
fields['public_updates_channel_id'] = public_updates_channel.id
fields["public_updates_channel_id"] = public_updates_channel.id
if owner is not MISSING:
if self.owner_id != self._state.self_id:
raise InvalidArgument('To transfer ownership you must be the owner of the guild.')
raise InvalidArgument("To transfer ownership you must be the owner of the guild.")
fields['owner_id'] = owner.id
fields["owner_id"] = owner.id
if region is not MISSING:
fields['region'] = str(region)
fields["region"] = str(region)
if verification_level is not MISSING:
if not isinstance(verification_level, VerificationLevel):
raise InvalidArgument('verification_level field must be of type VerificationLevel')
raise InvalidArgument("verification_level field must be of type VerificationLevel")
fields['verification_level'] = verification_level.value
fields["verification_level"] = verification_level.value
if explicit_content_filter is not MISSING:
if not isinstance(explicit_content_filter, ContentFilter):
raise InvalidArgument('explicit_content_filter field must be of type ContentFilter')
raise InvalidArgument("explicit_content_filter field must be of type ContentFilter")
fields['explicit_content_filter'] = explicit_content_filter.value
fields["explicit_content_filter"] = explicit_content_filter.value
if system_channel_flags is not MISSING:
if not isinstance(system_channel_flags, SystemChannelFlags):
raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags')
raise InvalidArgument("system_channel_flags field must be of type SystemChannelFlags")
fields['system_channel_flags'] = system_channel_flags.value
fields["system_channel_flags"] = system_channel_flags.value
if community is not MISSING:
features = []
if community:
if 'rules_channel_id' in fields and 'public_updates_channel_id' in fields:
features.append('COMMUNITY')
if "rules_channel_id" in fields and "public_updates_channel_id" in fields:
features.append("COMMUNITY")
else:
raise InvalidArgument(
'community field requires both rules_channel and public_updates_channel fields to be provided'
"community field requires both rules_channel and public_updates_channel fields to be provided"
)
fields['features'] = features
fields["features"] = features
data = await http.edit_guild(self.id, reason=reason, **fields)
return Guild(data=data, state=self._state)
@@ -1611,9 +1616,9 @@ class Guild(Hashable):
data = await self._state.http.get_all_guild_channels(self.id)
def convert(d):
factory, ch_type = _guild_channel_factory(d['type'])
factory, ch_type = _guild_channel_factory(d["type"])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(d))
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d))
channel = factory(guild=self, state=self._state, data=d)
return channel
@@ -1640,10 +1645,10 @@ class Guild(Hashable):
The active threads
"""
data = await self._state.http.get_active_threads(self.id)
threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])]
threads = [Thread(guild=self, state=self._state, data=d) for d in data.get("threads", [])]
thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads}
for member in data.get('members', []):
thread = thread_lookup.get(int(member['id']))
for member in data.get("members", []):
thread = thread_lookup.get(int(member["id"]))
if thread is not None:
thread._add_member(ThreadMember(parent=thread, data=member))
@@ -1699,7 +1704,7 @@ class Guild(Hashable):
"""
if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.')
raise ClientException("Intents.members must be enabled to use this.")
return MemberIterator(self, limit=limit, after=after)
@@ -1723,6 +1728,8 @@ class Guild(Hashable):
You do not have access to the guild.
HTTPException
Fetching the member failed.
NotFound
A member with that ID does not exist.
Returns
--------
@@ -1732,6 +1739,34 @@ class Guild(Hashable):
data = await self._state.http.get_member(self.id, member_id)
return Member(data=data, state=self._state, guild=self)
async def try_member(self, member_id: int, /) -> Optional[Member]:
"""|coro|
Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`.
.. note::
This method might result in an API call.
Parameters
-----------
member_id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
member = self.get_member(member_id)
if member:
return member
else:
try:
return await self.fetch_member(member_id)
except NotFound:
return None
async def fetch_ban(self, user: Snowflake) -> BanEntry:
"""|coro|
@@ -1760,7 +1795,7 @@ class Guild(Hashable):
The :class:`BanEntry` object for the specified user.
"""
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"])
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
"""|coro|
@@ -1793,16 +1828,16 @@ class Guild(Hashable):
"""
data = await self._state.http.get_channel(channel_id)
factory, ch_type = _threaded_guild_channel_factory(data['type'])
factory, ch_type = _threaded_guild_channel_factory(data["type"])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
if ch_type in (ChannelType.group, ChannelType.private):
raise InvalidData('Channel ID resolved to a private channel')
raise InvalidData("Channel ID resolved to a private channel")
guild_id = int(data['guild_id'])
guild_id = int(data["guild_id"])
if self.id != guild_id:
raise InvalidData('Guild ID resolved to a different guild')
raise InvalidData("Guild ID resolved to a different guild")
channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore
return channel
@@ -1829,7 +1864,7 @@ class Guild(Hashable):
"""
data: List[BanPayload] = await self._state.http.get_bans(self.id)
return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data]
return [BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) for e in data]
async def prune_members(
self,
@@ -1889,7 +1924,7 @@ class Guild(Hashable):
"""
if not isinstance(days, int):
raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.')
raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.")
if roles:
role_ids = [str(role.id) for role in roles]
@@ -1899,7 +1934,7 @@ class Guild(Hashable):
data = await self._state.http.prune_members(
self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason
)
return data['pruned']
return data["pruned"]
async def templates(self) -> List[Template]:
"""|coro|
@@ -1981,7 +2016,7 @@ class Guild(Hashable):
"""
if not isinstance(days, int):
raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.')
raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.")
if roles:
role_ids = [str(role.id) for role in roles]
@@ -1989,7 +2024,7 @@ class Guild(Hashable):
role_ids = []
data = await self._state.http.estimate_pruned_members(self.id, days, role_ids)
return data['pruned']
return data["pruned"]
async def invites(self) -> List[Invite]:
"""|coro|
@@ -2015,7 +2050,7 @@ class Guild(Hashable):
data = await self._state.http.invites_from(self.id)
result = []
for invite in data:
channel = self.get_channel(int(invite['channel']['id']))
channel = self.get_channel(int(invite["channel"]["id"]))
result.append(Invite(state=self._state, data=invite, guild=self, channel=channel))
return result
@@ -2039,10 +2074,10 @@ class Guild(Hashable):
"""
from .template import Template
payload = {'name': name}
payload = {"name": name}
if description:
payload['description'] = description
payload["description"] = description
data = await self._state.http.create_template(self.id, payload)
@@ -2099,9 +2134,9 @@ class Guild(Hashable):
data = await self._state.http.get_all_integrations(self.id)
def convert(d):
factory, _ = _integration_factory(d['type'])
factory, _ = _integration_factory(d["type"])
if factory is None:
raise InvalidData('Unknown integration type {type!r} for integration ID {id}'.format_map(d))
raise InvalidData("Unknown integration type {type!r} for integration ID {id}".format_map(d))
return factory(guild=self, data=d)
return [convert(d) for d in data]
@@ -2206,20 +2241,20 @@ class Guild(Hashable):
The created sticker.
"""
payload = {
'name': name,
"name": name,
}
if description:
payload['description'] = description
payload["description"] = description
try:
emoji = unicodedata.name(emoji)
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
emoji = emoji.replace(" ", "_")
payload['tags'] = emoji
payload["tags"] = emoji
data = await self._state.http.create_guild_sticker(self.id, payload, file, reason)
return self._state.store_sticker(self, data)
@@ -2486,24 +2521,24 @@ class Guild(Hashable):
"""
fields: Dict[str, Any] = {}
if permissions is not MISSING:
fields['permissions'] = str(permissions.value)
fields["permissions"] = str(permissions.value)
else:
fields['permissions'] = '0'
fields["permissions"] = "0"
actual_colour = colour or color or Colour.default()
if isinstance(actual_colour, int):
fields['color'] = actual_colour
fields["color"] = actual_colour
else:
fields['color'] = actual_colour.value
fields["color"] = actual_colour.value
if hoist is not MISSING:
fields['hoist'] = hoist
fields["hoist"] = hoist
if mentionable is not MISSING:
fields['mentionable'] = mentionable
fields["mentionable"] = mentionable
if name is not MISSING:
fields['name'] = name
fields["name"] = name
data = await self._state.http.create_role(self.id, reason=reason, **fields)
role = Role(guild=self, data=data, state=self._state)
@@ -2556,12 +2591,12 @@ class Guild(Hashable):
A list of all the roles in the guild.
"""
if not isinstance(positions, dict):
raise InvalidArgument('positions parameter expects a dict.')
raise InvalidArgument("positions parameter expects a dict.")
role_positions: List[Dict[str, Any]] = []
for role, position in positions.items():
payload = {'id': role.id, 'position': position}
payload = {"id": role.id, "position": position}
role_positions.append(payload)
@@ -2687,19 +2722,19 @@ class Guild(Hashable):
# we start with { code: abc }
payload = await self._state.http.get_vanity_code(self.id)
if not payload['code']:
if not payload["code"]:
return None
# get the vanity URL channel since default channels aren't
# reliable or a thing anymore
data = await self._state.http.get_invite(payload['code'])
data = await self._state.http.get_invite(payload["code"])
channel = self.get_channel(int(data['channel']['id']))
payload['revoked'] = False
payload['temporary'] = False
payload['max_uses'] = 0
payload['max_age'] = 0
payload['uses'] = payload.get('uses', 0)
channel = self.get_channel(int(data["channel"]["id"]))
payload["revoked"] = False
payload["temporary"] = False
payload["max_uses"] = 0
payload["max_age"] = 0
payload["uses"] = payload.get("uses", 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel)
# TODO: use MISSING when async iterators get refactored
@@ -2776,7 +2811,13 @@ class Guild(Hashable):
action = action.value
return AuditLogIterator(
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action
self,
before=before,
after=after,
limit=limit,
oldest_first=oldest_first,
user_id=user_id,
action_type=action,
)
async def widget(self) -> Widget:
@@ -2830,9 +2871,9 @@ class Guild(Hashable):
"""
payload = {}
if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id
payload["channel_id"] = None if channel is None else channel.id
if enabled is not MISSING:
payload['enabled'] = enabled
payload["enabled"] = enabled
await self._state.http.edit_widget(self.id, payload=payload)
@@ -2858,7 +2899,7 @@ class Guild(Hashable):
"""
if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.')
raise ClientException("Intents.members must be enabled to use this.")
if not self._state.is_guild_evicted(self):
return await self._state.chunk_guild(self, cache=cache)
@@ -2919,20 +2960,20 @@ class Guild(Hashable):
"""
if presences and not self._state._intents.presences:
raise ClientException('Intents.presences must be enabled to use this.')
raise ClientException("Intents.presences must be enabled to use this.")
if query is None:
if query == '':
raise ValueError('Cannot pass empty query string.')
if query == "":
raise ValueError("Cannot pass empty query string.")
if user_ids is None:
raise ValueError('Must pass either query or user_ids')
raise ValueError("Must pass either query or user_ids")
if user_ids is not None and query is not None:
raise ValueError('Cannot pass both query and user_ids')
raise ValueError("Cannot pass both query and user_ids")
if user_ids is not None and not user_ids:
raise ValueError('user_ids must contain at least 1 value')
raise ValueError("user_ids must contain at least 1 value")
limit = min(100, limit or 5)
return await self._state.query_members(

File diff suppressed because it is too large Load Diff

View File

@@ -32,11 +32,11 @@ from .errors import InvalidArgument
from .enums import try_enum, ExpireBehaviour
__all__ = (
'IntegrationAccount',
'IntegrationApplication',
'Integration',
'StreamIntegration',
'BotIntegration',
"IntegrationAccount",
"IntegrationApplication",
"Integration",
"StreamIntegration",
"BotIntegration",
)
if TYPE_CHECKING:
@@ -65,14 +65,14 @@ class IntegrationAccount:
The account name.
"""
__slots__ = ('id', 'name')
__slots__ = ("id", "name")
def __init__(self, data: IntegrationAccountPayload) -> None:
self.id: str = data['id']
self.name: str = data['name']
self.id: str = data["id"]
self.name: str = data["name"]
def __repr__(self) -> str:
return f'<IntegrationAccount id={self.id} name={self.name!r}>'
return f"<IntegrationAccount id={self.id} name={self.name!r}>"
class Integration:
@@ -99,14 +99,14 @@ class Integration:
"""
__slots__ = (
'guild',
'id',
'_state',
'type',
'name',
'account',
'user',
'enabled',
"guild",
"id",
"_state",
"type",
"name",
"account",
"user",
"enabled",
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
@@ -118,14 +118,14 @@ class Integration:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
self.id: int = int(data['id'])
self.type: IntegrationType = data['type']
self.name: str = data['name']
self.account: IntegrationAccount = IntegrationAccount(data['account'])
self.id: int = int(data["id"])
self.type: IntegrationType = data["type"]
self.name: str = data["name"]
self.account: IntegrationAccount = IntegrationAccount(data["account"])
user = data.get('user')
user = data.get("user")
self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
self.enabled: bool = data["enabled"]
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
@@ -186,26 +186,26 @@ class StreamIntegration(Integration):
"""
__slots__ = (
'revoked',
'expire_behaviour',
'expire_grace_period',
'synced_at',
'_role_id',
'syncing',
'enable_emoticons',
'subscriber_count',
"revoked",
"expire_behaviour",
"expire_grace_period",
"synced_at",
"_role_id",
"syncing",
"enable_emoticons",
"subscriber_count",
)
def _from_data(self, data: StreamIntegrationPayload) -> None:
super()._from_data(data)
self.revoked: bool = data['revoked']
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior'])
self.expire_grace_period: int = data['expire_grace_period']
self.synced_at: datetime.datetime = parse_time(data['synced_at'])
self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id')
self.syncing: bool = data['syncing']
self.enable_emoticons: bool = data['enable_emoticons']
self.subscriber_count: int = data['subscriber_count']
self.revoked: bool = data["revoked"]
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"])
self.expire_grace_period: int = data["expire_grace_period"]
self.synced_at: datetime.datetime = parse_time(data["synced_at"])
self._role_id: Optional[int] = _get_as_snowflake(data, "role_id")
self.syncing: bool = data["syncing"]
self.enable_emoticons: bool = data["enable_emoticons"]
self.subscriber_count: int = data["subscriber_count"]
@property
def expire_behavior(self) -> ExpireBehaviour:
@@ -252,15 +252,15 @@ class StreamIntegration(Integration):
payload: Dict[str, Any] = {}
if expire_behaviour is not MISSING:
if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour')
raise InvalidArgument("expire_behaviour field must be of type ExpireBehaviour")
payload['expire_behavior'] = expire_behaviour.value
payload["expire_behavior"] = expire_behaviour.value
if expire_grace_period is not MISSING:
payload['expire_grace_period'] = expire_grace_period
payload["expire_grace_period"] = expire_grace_period
if enable_emoticons is not MISSING:
payload['enable_emoticons'] = enable_emoticons
payload["enable_emoticons"] = enable_emoticons
# This endpoint is undocumented.
# Unsure if it returns the data or not as a result
@@ -307,21 +307,21 @@ class IntegrationApplication:
"""
__slots__ = (
'id',
'name',
'icon',
'description',
'summary',
'user',
"id",
"name",
"icon",
"description",
"summary",
"user",
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
self.description: str = data['description']
self.summary: str = data['summary']
user = data.get('bot')
self.id: int = int(data["id"])
self.name: str = data["name"]
self.icon: Optional[str] = data["icon"]
self.description: str = data["description"]
self.summary: str = data["summary"]
user = data.get("bot")
self.user: Optional[User] = User(state=state, data=user) if user else None
@@ -350,17 +350,17 @@ class BotIntegration(Integration):
The application tied to this integration.
"""
__slots__ = ('application',)
__slots__ = ("application",)
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state)
self.application = IntegrationApplication(data=data["application"], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
if value == 'discord':
if value == "discord":
return BotIntegration, value
elif value in ('twitch', 'youtube'):
elif value in ("twitch", "youtube"):
return StreamIntegration, value
else:
return Integration, value

View File

@@ -41,12 +41,14 @@ from .permissions import Permissions
from .webhook.async_ import async_context, Webhook, handle_message_parameters
__all__ = (
'Interaction',
'InteractionMessage',
'InteractionResponse',
"Interaction",
"InteractionMessage",
"InteractionResponse",
)
if TYPE_CHECKING:
from datetime import datetime
from .types.interactions import (
Interaction as InteractionPayload,
InteractionData,
@@ -58,12 +60,10 @@ if TYPE_CHECKING:
from aiohttp import ClientSession
from .embeds import Embed
from .ui.view import View
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .channel import TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .threads import Thread
InteractionChannel = Union[
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
]
InteractionChannel = Union[TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable]
MISSING: Any = utils.MISSING
@@ -100,23 +100,23 @@ class Interaction:
"""
__slots__: Tuple[str, ...] = (
'id',
'type',
'guild_id',
'channel_id',
'data',
'application_id',
'message',
'user',
'token',
'version',
'_permissions',
'_state',
'_session',
'_original_message',
'_cs_response',
'_cs_followup',
'_cs_channel',
"id",
"type",
"guild_id",
"channel_id",
"data",
"application_id",
"message",
"user",
"token",
"version",
"_permissions",
"_state",
"_session",
"_original_message",
"_cs_response",
"_cs_followup",
"_cs_channel",
)
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
@@ -126,18 +126,18 @@ class Interaction:
self._from_data(data)
def _from_data(self, data: InteractionPayload):
self.id: int = int(data['id'])
self.type: InteractionType = try_enum(InteractionType, data['type'])
self.data: Optional[InteractionData] = data.get('data')
self.token: str = data['token']
self.version: int = data['version']
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.application_id: int = int(data['application_id'])
self.id: int = int(data["id"])
self.type: InteractionType = try_enum(InteractionType, data["type"])
self.data: Optional[InteractionData] = data.get("data")
self.token: str = data["token"]
self.version: int = data["version"]
self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id")
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.application_id: int = int(data["application_id"])
self.message: Optional[Message]
try:
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore
except KeyError:
self.message = None
@@ -148,15 +148,15 @@ class Interaction:
if self.guild_id:
guild = self.guild or Object(id=self.guild_id)
try:
member = data['member'] # type: ignore
member = data["member"] # type: ignore
except KeyError:
pass
else:
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
self._permissions = int(member.get('permissions', 0))
self._permissions = int(member.get("permissions", 0))
else:
try:
self.user = User(state=self._state, data=data['user'])
self.user = User(state=self._state, data=data["user"])
except KeyError:
pass
@@ -165,7 +165,7 @@ class Interaction:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id)
@utils.cached_slot_property('_cs_channel')
@utils.cached_slot_property("_cs_channel")
def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
@@ -179,7 +179,7 @@ class Interaction:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
return None
return channel
return channel # type: ignore
@property
def permissions(self) -> Permissions:
@@ -189,7 +189,7 @@ class Interaction:
"""
return Permissions(self._permissions)
@utils.cached_slot_property('_cs_response')
@utils.cached_slot_property("_cs_response")
def response(self) -> InteractionResponse:
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
@@ -198,13 +198,13 @@ class Interaction:
"""
return InteractionResponse(self)
@utils.cached_slot_property('_cs_followup')
@utils.cached_slot_property("_cs_followup")
def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = {
'id': self.application_id,
'type': 3,
'token': self.token,
"id": self.application_id,
"type": 3,
"token": self.token,
}
return Webhook.from_state(data=payload, state=self._state)
@@ -238,7 +238,7 @@ class Interaction:
# TODO: fix later to not raise?
channel = self.channel
if channel is None:
raise ClientException('Channel for message could not be resolved')
raise ClientException("Channel for message could not be resolved")
adapter = async_context.get()
data = await adapter.get_original_interaction_response(
@@ -369,20 +369,20 @@ class InteractionResponse:
"""
__slots__: Tuple[str, ...] = (
'_responded',
'_parent',
"responded_at",
"_parent",
)
def __init__(self, parent: Interaction):
self.responded_at: Optional[datetime] = None
self._parent: Interaction = parent
self._responded: bool = False
def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before.
An interaction can only be responded to once.
"""
return self._responded
return self.responded_at is not None
async def defer(self, *, ephemeral: bool = False) -> None:
"""|coro|
@@ -405,7 +405,7 @@ class InteractionResponse:
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
if self.is_done():
raise InteractionResponded(self._parent)
defer_type: int = 0
@@ -416,14 +416,15 @@ class InteractionResponse:
elif parent.type is InteractionType.application_command:
defer_type = InteractionResponseType.deferred_channel_message.value
if ephemeral:
data = {'flags': 64}
data = {"flags": 64}
if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=defer_type, data=data
)
self._responded = True
self.responded_at = utils.utcnow()
async def pong(self) -> None:
"""|coro|
@@ -439,7 +440,7 @@ class InteractionResponse:
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
if self.is_done():
raise InteractionResponded(self._parent)
parent = self._parent
@@ -448,7 +449,7 @@ class InteractionResponse:
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
)
self._responded = True
self.responded_at = utils.utcnow()
async def send_message(
self,
@@ -494,32 +495,32 @@ class InteractionResponse:
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
if self.is_done():
raise InteractionResponded(self._parent)
payload: Dict[str, Any] = {
'tts': tts,
"tts": tts,
}
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix embed and embeds keyword arguments')
raise TypeError("cannot mix embed and embeds keyword arguments")
if embed is not MISSING:
embeds = [embed]
if embeds:
if len(embeds) > 10:
raise ValueError('embeds cannot exceed maximum of 10 elements')
payload['embeds'] = [e.to_dict() for e in embeds]
raise ValueError("embeds cannot exceed maximum of 10 elements")
payload["embeds"] = [e.to_dict() for e in embeds]
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
if ephemeral:
payload['flags'] = 64
payload["flags"] = 64
if view is not MISSING:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
parent = self._parent
adapter = async_context.get()
@@ -537,7 +538,7 @@ class InteractionResponse:
self._parent._state.store_view(view)
self._responded = True
self.responded_at = utils.utcnow()
async def edit_message(
self,
@@ -578,7 +579,7 @@ class InteractionResponse:
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
if self.is_done():
raise InteractionResponded(self._parent)
parent = self._parent
@@ -591,12 +592,12 @@ class InteractionResponse:
payload = {}
if content is not MISSING:
if content is None:
payload['content'] = None
payload["content"] = None
else:
payload['content'] = str(content)
payload["content"] = str(content)
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix both embed and embeds keyword arguments')
raise TypeError("cannot mix both embed and embeds keyword arguments")
if embed is not MISSING:
if embed is None:
@@ -605,17 +606,17 @@ class InteractionResponse:
embeds = [embed]
if embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds]
payload["embeds"] = [e.to_dict() for e in embeds]
if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments]
payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING:
state.prevent_view_updates_for(message_id)
if view is None:
payload['components'] = []
payload["components"] = []
else:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
@@ -629,11 +630,11 @@ class InteractionResponse:
if view and not view.is_finished():
state.store_view(view, message_id)
self._responded = True
self.responded_at = utils.utcnow()
class _InteractionMessageState:
__slots__ = ('_parent', '_interaction')
__slots__ = ("_parent", "_interaction")
def __init__(self, interaction: Interaction, parent: ConnectionState):
self._interaction: Interaction = interaction

View File

@@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum
from .appinfo import PartialAppInfo
__all__ = (
'PartialInviteChannel',
'PartialInviteGuild',
'Invite',
"PartialInviteChannel",
"PartialInviteGuild",
"Invite",
)
if TYPE_CHECKING:
@@ -52,8 +52,8 @@ if TYPE_CHECKING:
from .abc import GuildChannel
from .user import User
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object]
InviteGuildType = Union[Guild, "PartialInviteGuild", Object]
InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object]
import datetime
@@ -92,23 +92,23 @@ class PartialInviteChannel:
The partial channel's type.
"""
__slots__ = ('id', 'name', 'type')
__slots__ = ("id", "name", "type")
def __init__(self, data: InviteChannelPayload):
self.id: int = int(data['id'])
self.name: str = data['name']
self.type: ChannelType = try_enum(ChannelType, data['type'])
self.id: int = int(data["id"])
self.name: str = data["name"]
self.type: ChannelType = try_enum(ChannelType, data["type"])
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>'
return f"<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>"
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime.datetime:
@@ -154,26 +154,26 @@ class PartialInviteGuild:
The partial guild's description.
"""
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description')
__slots__ = ("_state", "features", "_icon", "_banner", "id", "name", "_splash", "verification_level", "description")
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state
self.id: int = id
self.name: str = data['name']
self.features: List[str] = data.get('features', [])
self._icon: Optional[str] = data.get('icon')
self._banner: Optional[str] = data.get('banner')
self._splash: Optional[str] = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description')
self.name: str = data["name"]
self.features: List[str] = data.get("features", [])
self._icon: Optional[str] = data.get("icon")
self._banner: Optional[str] = data.get("banner")
self._splash: Optional[str] = data.get("splash")
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get("verification_level"))
self.description: Optional[str] = data.get("description")
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} '
f'description={self.description!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} "
f"description={self.description!r}>"
)
@property
@@ -193,17 +193,17 @@ class PartialInviteGuild:
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None:
return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners')
return Asset._from_guild_image(self._state, self.id, self._banner, path="banners")
@property
def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None:
return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes")
I = TypeVar('I', bound='Invite')
I = TypeVar("I", bound="Invite")
class Invite(Hashable):
@@ -308,26 +308,26 @@ class Invite(Hashable):
"""
__slots__ = (
'max_age',
'code',
'guild',
'revoked',
'created_at',
'uses',
'temporary',
'max_uses',
'inviter',
'channel',
'target_user',
'target_type',
'_state',
'approximate_member_count',
'approximate_presence_count',
'target_application',
'expires_at',
"max_age",
"code",
"guild",
"revoked",
"created_at",
"uses",
"temporary",
"max_uses",
"inviter",
"channel",
"target_user",
"target_type",
"_state",
"approximate_member_count",
"approximate_presence_count",
"target_application",
"expires_at",
)
BASE = 'https://discord.gg'
BASE = "https://discord.gg"
def __init__(
self,
@@ -338,31 +338,33 @@ class Invite(Hashable):
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
):
self._state: ConnectionState = state
self.max_age: Optional[int] = data.get('max_age')
self.code: str = data['code']
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
self.revoked: Optional[bool] = data.get('revoked')
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.temporary: Optional[bool] = data.get('temporary')
self.uses: Optional[int] = data.get('uses')
self.max_uses: Optional[int] = data.get('max_uses')
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
self.max_age: Optional[int] = data.get("max_age")
self.code: str = data["code"]
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild)
self.revoked: Optional[bool] = data.get("revoked")
self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.temporary: Optional[bool] = data.get("temporary")
self.uses: Optional[int] = data.get("uses")
self.max_uses: Optional[int] = data.get("max_uses")
self.approximate_presence_count: Optional[int] = data.get("approximate_presence_count")
self.approximate_member_count: Optional[int] = data.get("approximate_member_count")
expires_at = data.get('expires_at', None)
expires_at = data.get("expires_at", None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
inviter_data = data.get('inviter')
inviter_data = data.get("inviter")
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get("channel"), channel)
target_user_data = data.get('target_user')
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
target_user_data = data.get("target_user")
self.target_user: Optional[User] = (
None if target_user_data is None else self._state.create_user(target_user_data)
)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
application = data.get('target_application')
application = data.get("target_application")
self.target_application: Optional[PartialAppInfo] = (
PartialAppInfo(data=application, state=state) if application else None
)
@@ -371,12 +373,12 @@ class Invite(Hashable):
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I:
guild: Optional[Union[Guild, PartialInviteGuild]]
try:
guild_data = data['guild']
guild_data = data["guild"]
except KeyError:
# If we're here, then this is a group DM
guild = None
else:
guild_id = int(guild_data['id'])
guild_id = int(guild_data["id"])
guild = state._get_guild(guild_id)
if guild is None:
# If it's not cached, then it has to be a partial guild
@@ -384,7 +386,7 @@ class Invite(Hashable):
# As far as I know, invites always need a channel
# So this should never raise.
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel'])
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data["channel"])
if guild is not None and not isinstance(guild, PartialInviteGuild):
# Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel
@@ -393,9 +395,9 @@ class Invite(Hashable):
@classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id'])
channel_id = int(data["channel_id"])
if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
else:
@@ -415,7 +417,7 @@ class Invite(Hashable):
if data is None:
return None
guild_id = int(data['id'])
guild_id = int(data["id"])
return PartialInviteGuild(self._state, data, guild_id)
def _resolve_channel(
@@ -439,9 +441,9 @@ class Invite(Hashable):
def __repr__(self) -> str:
return (
f'<Invite code={self.code!r} guild={self.guild!r} '
f'online={self.approximate_presence_count} '
f'members={self.approximate_member_count}>'
f"<Invite code={self.code!r} guild={self.guild!r} "
f"online={self.approximate_presence_count} "
f"members={self.approximate_member_count}>"
)
def __hash__(self) -> int:
@@ -455,7 +457,7 @@ class Invite(Hashable):
@property
def url(self) -> str:
""":class:`str`: A property that retrieves the invite URL."""
return self.BASE + '/' + self.code
return self.BASE + "/" + self.code
async def delete(self, *, reason: Optional[str] = None):
"""|coro|

View File

@@ -34,11 +34,11 @@ from .object import Object
from .audit_logs import AuditLogEntry
__all__ = (
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
"ReactionIterator",
"HistoryIterator",
"AuditLogIterator",
"GuildIterator",
"MemberIterator",
)
if TYPE_CHECKING:
@@ -67,8 +67,8 @@ if TYPE_CHECKING:
from .threads import Thread
from .abc import Snowflake
T = TypeVar('T')
OT = TypeVar('OT')
T = TypeVar("T")
OT = TypeVar("OT")
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
@@ -83,7 +83,7 @@ class _AsyncIterator(AsyncIterator[T]):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split('__')
nested = attr.split("__")
obj = elem
for attribute in nested:
obj = getattr(obj, attribute)
@@ -107,7 +107,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.')
raise ValueError("async iterator chunk sizes must be greater than 0.")
return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@@ -182,7 +182,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@@ -218,14 +218,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
if data:
self.limit -= retrieve
self.after = Object(id=int(data[-1]['id']))
self.after = Object(id=int(data[-1]["id"]))
if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data):
await self.users.put(User(state=self.state, data=element))
else:
for element in reversed(data):
member_id = int(element['id'])
member_id = int(element["id"])
member = self.guild.get_member(member_id)
if member is not None:
await self.users.put(member)
@@ -233,7 +233,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']):
class HistoryIterator(_AsyncIterator["Message"]):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@@ -295,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self.around:
if self.limit is None:
raise ValueError('history does not support around with limit=None')
raise ValueError("history does not support around with limit=None")
if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101:
@@ -303,20 +303,20 @@ class HistoryIterator(_AsyncIterator['Message']):
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
elif self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < int(m['id'])
self._filter = lambda m: self.after.id < int(m["id"])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def next(self) -> Message:
if self.messages.empty():
@@ -337,7 +337,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return r > 0
async def fill_messages(self):
if not hasattr(self, 'channel'):
if not hasattr(self, "channel"):
# do the required set up
channel = await self.messageable._get_channel()
self.channel = channel
@@ -367,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_messages_after_strategy(self, retrieve):
@@ -377,7 +377,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
async def _retrieve_messages_around_strategy(self, retrieve):
@@ -390,7 +390,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@@ -420,11 +420,11 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m['id']) < self.before.id
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
@@ -432,24 +432,24 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries
self.before = Object(id=int(entries[-1]["id"]))
return data.get("users", []), entries
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
)
entries = data.get('audit_log_entries', [])
entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
self.after = Object(id=int(entries[0]["id"]))
return data.get("users", []), entries
async def next(self) -> AuditLogEntry:
if self.entries.empty():
@@ -488,13 +488,13 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
for element in data:
# TODO: remove this if statement later
if element['action_type'] is None:
if element["action_type"] is None:
continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator['Guild']):
class GuildIterator(_AsyncIterator["Guild"]):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
@@ -543,7 +543,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id
self._filter = lambda m: int(m["id"]) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
@@ -595,7 +595,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]['id']))
self.before = Object(id=int(data[-1]["id"]))
return data
async def _retrieve_guilds_after_strategy(self, retrieve):
@@ -605,11 +605,11 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]['id']))
self.after = Object(id=int(data[0]["id"]))
return data
class MemberIterator(_AsyncIterator['Member']):
class MemberIterator(_AsyncIterator["Member"]):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
@@ -652,7 +652,7 @@ class MemberIterator(_AsyncIterator['Member']):
if len(data) < 1000:
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id']))
self.after = Object(id=int(data[-1]["user"]["id"]))
for element in reversed(data):
await self.members.put(self.create_member(element))
@@ -663,7 +663,7 @@ class MemberIterator(_AsyncIterator['Member']):
return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']):
class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def __init__(
self,
channel_id: int,
@@ -681,7 +681,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
self.http = guild._state.http
if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads')
raise ValueError("Cannot iterate over joined public archived threads")
self.before: Optional[str]
if before is None:
@@ -721,11 +721,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp']
return data["thread_metadata"]["archive_timestamp"]
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore
return data["id"] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
@@ -735,11 +735,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', [])
threads: List[ThreadPayload] = data.get("threads", [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False)
self.has_more = data.get("has_more", False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
@@ -750,4 +750,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)

View File

@@ -44,8 +44,8 @@ from .colour import Colour
from .object import Object
__all__ = (
'VoiceState',
'Member',
"VoiceState",
"Member",
)
if TYPE_CHECKING:
@@ -113,52 +113,54 @@ class VoiceState:
"""
__slots__ = (
'session_id',
'deaf',
'mute',
'self_mute',
'self_stream',
'self_video',
'self_deaf',
'afk',
'channel',
'requested_to_speak_at',
'suppress',
"session_id",
"deaf",
"mute",
"self_mute",
"self_stream",
"self_video",
"self_deaf",
"afk",
"channel",
"requested_to_speak_at",
"suppress",
)
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get('session_id')
self.session_id: str = data.get("session_id")
self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
self.self_video: bool = data.get('self_video', False)
self.afk: bool = data.get('suppress', False)
self.mute: bool = data.get('mute', False)
self.deaf: bool = data.get('deaf', False)
self.suppress: bool = data.get('suppress', False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp'))
self.self_mute: bool = data.get("self_mute", False)
self.self_deaf: bool = data.get("self_deaf", False)
self.self_stream: bool = data.get("self_stream", False)
self.self_video: bool = data.get("self_video", False)
self.afk: bool = data.get("suppress", False)
self.mute: bool = data.get("mute", False)
self.deaf: bool = data.get("deaf", False)
self.suppress: bool = data.get("suppress", False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(
data.get("request_to_speak_timestamp")
)
self.channel: Optional[VocalGuildChannel] = channel
def __repr__(self) -> str:
attrs = [
('self_mute', self.self_mute),
('self_deaf', self.self_deaf),
('self_stream', self.self_stream),
('suppress', self.suppress),
('requested_to_speak_at', self.requested_to_speak_at),
('channel', self.channel),
("self_mute", self.self_mute),
("self_deaf", self.self_deaf),
("self_stream", self.self_stream),
("suppress", self.suppress),
("requested_to_speak_at", self.requested_to_speak_at),
("channel", self.channel),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
inner = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods
if attr.startswith('_'):
if attr.startswith("_"):
continue
# don't override what we already have
@@ -167,9 +169,9 @@ def flatten_user(cls):
# if it's a slotted attribute or a property, redirect it
# slotted members are implemented as member_descriptors in Type.__dict__
if not hasattr(value, '__annotations__'):
getter = attrgetter('_user.' + attr)
setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`'))
if not hasattr(value, "__annotations__"):
getter = attrgetter("_user." + attr)
setattr(cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`"))
else:
# Technically, this can also use attrgetter
# However I'm not sure how I feel about "functions" returning properties
@@ -197,7 +199,7 @@ def flatten_user(cls):
return cls
M = TypeVar('M', bound='Member')
M = TypeVar("M", bound="Member")
@flatten_user
@@ -258,17 +260,17 @@ class Member(discord.abc.Messageable, _UserTag):
"""
__slots__ = (
'_roles',
'joined_at',
'premium_since',
'activities',
'guild',
'pending',
'nick',
'_client_status',
'_user',
'_state',
'_avatar',
"_roles",
"joined_at",
"premium_since",
"activities",
"guild",
"pending",
"nick",
"_client_status",
"_user",
"_state",
"_avatar",
)
if TYPE_CHECKING:
@@ -290,16 +292,16 @@ class Member(discord.abc.Messageable, _UserTag):
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state
self._user: User = state.store_user(data['user'])
self._user: User = state.store_user(data["user"])
self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
self._client_status: Dict[Optional[str], str] = {None: 'offline'}
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get("joined_at"))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get("premium_since"))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"]))
self._client_status: Dict[Optional[str], str] = {None: "offline"}
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
self._avatar: Optional[str] = data.get('avatar')
self.nick: Optional[str] = data.get("nick", None)
self.pending: bool = data.get("pending", False)
self._avatar: Optional[str] = data.get("avatar")
def __str__(self) -> str:
return str(self._user)
@@ -309,8 +311,8 @@ class Member(discord.abc.Messageable, _UserTag):
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
f"<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}"
f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>"
)
def __eq__(self, other: Any) -> bool:
@@ -325,25 +327,27 @@ class Member(discord.abc.Messageable, _UserTag):
@classmethod
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
author = message.author
data['user'] = author._to_minimal_user_json() # type: ignore
data["user"] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
self.joined_at = utils.parse_time(data.get("joined_at"))
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self.nick = data.get("nick", None)
self.pending = data.get("pending", False)
@classmethod
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]:
def _try_upgrade(
cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState
) -> Union[User, M]:
# A User object with a 'member' key
try:
member_data = data.pop('member')
member_data = data.pop("member")
except KeyError:
return state.create_user(data)
else:
member_data['user'] = data # type: ignore
member_data["user"] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore
@classmethod
@@ -374,25 +378,25 @@ class Member(discord.abc.Messageable, _UserTag):
# the nickname change is optional,
# if it isn't in the payload then it didn't change
try:
self.nick = data['nick']
self.nick = data["nick"]
except KeyError:
pass
try:
self.pending = data['pending']
self.pending = data["pending"]
except KeyError:
pass
self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles']))
self._avatar = data.get('avatar')
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self._avatar = data.get("avatar")
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data['activities']))
self.activities = tuple(map(create_activity, data["activities"]))
self._client_status = {
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore
sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore
}
self._client_status[None] = sys.intern(data['status'])
self._client_status[None] = sys.intern(data["status"])
if len(user) > 1:
return self._update_inner_user(user)
@@ -402,7 +406,7 @@ class Member(discord.abc.Messageable, _UserTag):
u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags)
# These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
modified = (user["username"], user["avatar"], user["discriminator"], user.get("public_flags", 0))
if original != modified:
to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified
@@ -430,21 +434,21 @@ class Member(discord.abc.Messageable, _UserTag):
@property
def mobile_status(self) -> Status:
""":class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get('mobile', 'offline'))
return try_enum(Status, self._client_status.get("mobile", "offline"))
@property
def desktop_status(self) -> Status:
""":class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get('desktop', 'offline'))
return try_enum(Status, self._client_status.get("desktop", "offline"))
@property
def web_status(self) -> Status:
""":class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get('web', 'offline'))
return try_enum(Status, self._client_status.get("web", "offline"))
def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
return 'mobile' in self._client_status
return "mobile" in self._client_status
@property
def colour(self) -> Colour:
@@ -497,8 +501,8 @@ class Member(discord.abc.Messageable, _UserTag):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the member."""
if self.nick:
return f'<@!{self._user.id}>'
return f'<@{self._user.id}>'
return f"<@!{self._user.id}>"
return f"<@{self._user.id}>"
@property
def display_name(self) -> str:
@@ -720,39 +724,39 @@ class Member(discord.abc.Messageable, _UserTag):
payload: Dict[str, Any] = {}
if nick is not MISSING:
nick = nick or ''
nick = nick or ""
if me:
await http.change_my_nickname(guild_id, nick, reason=reason)
else:
payload['nick'] = nick
payload["nick"] = nick
if deafen is not MISSING:
payload['deaf'] = deafen
payload["deaf"] = deafen
if mute is not MISSING:
payload['mute'] = mute
payload["mute"] = mute
if suppress is not MISSING:
voice_state_payload = {
'channel_id': self.voice.channel.id,
'suppress': suppress,
"channel_id": self.voice.channel.id,
"suppress": suppress,
}
if suppress or self.bot:
voice_state_payload['request_to_speak_timestamp'] = None
voice_state_payload["request_to_speak_timestamp"] = None
if me:
await http.edit_my_voice_state(guild_id, voice_state_payload)
else:
if not suppress:
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat()
voice_state_payload["request_to_speak_timestamp"] = datetime.datetime.utcnow().isoformat()
await http.edit_voice_state(guild_id, self.id, voice_state_payload)
if voice_channel is not MISSING:
payload['channel_id'] = voice_channel and voice_channel.id
payload["channel_id"] = voice_channel and voice_channel.id
if roles is not MISSING:
payload['roles'] = tuple(r.id for r in roles)
payload["roles"] = tuple(r.id for r in roles)
if payload:
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
@@ -780,12 +784,12 @@ class Member(discord.abc.Messageable, _UserTag):
The operation failed.
"""
payload = {
'channel_id': self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),
"channel_id": self.voice.channel.id,
"request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(),
}
if self._state.self_id != self.id:
payload['suppress'] = False
payload["suppress"] = False
await self._state.http.edit_voice_state(self.guild.id, self.id, payload)
else:
await self._state.http.edit_my_voice_state(self.guild.id, payload)

View File

@@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
__all__ = (
'AllowedMentions',
)
__all__ = ("AllowedMentions",)
if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload
@@ -36,7 +34,7 @@ if TYPE_CHECKING:
class _FakeBool:
def __repr__(self):
return 'True'
return "True"
def __eq__(self, other):
return other is True
@@ -47,7 +45,7 @@ class _FakeBool:
default: Any = _FakeBool()
A = TypeVar('A', bound='AllowedMentions')
A = TypeVar("A", bound="AllowedMentions")
class AllowedMentions:
@@ -80,7 +78,7 @@ class AllowedMentions:
.. versionadded:: 1.6
"""
__slots__ = ('everyone', 'users', 'roles', 'replied_user')
__slots__ = ("everyone", "users", "roles", "replied_user")
def __init__(
self,
@@ -116,22 +114,22 @@ class AllowedMentions:
data = {}
if self.everyone:
parse.append('everyone')
parse.append("everyone")
if self.users == True:
parse.append('users')
parse.append("users")
elif self.users != False:
data['users'] = [x.id for x in self.users]
data["users"] = [x.id for x in self.users]
if self.roles == True:
parse.append('roles')
parse.append("roles")
elif self.roles != False:
data['roles'] = [x.id for x in self.roles]
data["roles"] = [x.id for x in self.roles]
if self.replied_user:
data['replied_user'] = True
data["replied_user"] = True
data['parse'] = parse
data["parse"] = parse
return data # type: ignore
def merge(self, other: AllowedMentions) -> AllowedMentions:
@@ -146,6 +144,6 @@ class AllowedMentions:
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(everyone={self.everyone}, '
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})'
f"{self.__class__.__name__}(everyone={self.everyone}, "
f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})"
)

View File

@@ -29,7 +29,21 @@ import datetime
import re
import io
from os import PathLike
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type
from typing import (
Dict,
TYPE_CHECKING,
Union,
List,
Optional,
Any,
Callable,
Tuple,
ClassVar,
Optional,
overload,
TypeVar,
Type,
)
from . import utils
from .reaction import Reaction
@@ -76,15 +90,15 @@ if TYPE_CHECKING:
from .role import Role
from .ui.view import View
MR = TypeVar('MR', bound='MessageReference')
MR = TypeVar("MR", bound="MessageReference")
EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = (
'Attachment',
'Message',
'PartialMessage',
'MessageReference',
'DeletedReferencedMessage',
"Attachment",
"Message",
"PartialMessage",
"MessageReference",
"DeletedReferencedMessage",
)
@@ -93,15 +107,15 @@ def convert_emoji_reaction(emoji):
emoji = emoji.emoji
if isinstance(emoji, Emoji):
return f'{emoji.name}:{emoji.id}'
return f"{emoji.name}:{emoji.id}"
if isinstance(emoji, PartialEmoji):
return emoji._as_reaction()
if isinstance(emoji, str):
# Reactions can be in :name:id format, but not <:name:id>.
# No existing emojis have <> in them, so this should be okay.
return emoji.strip('<>')
return emoji.strip("<>")
raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.')
raise InvalidArgument(f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.")
class Attachment(Hashable):
@@ -155,30 +169,36 @@ class Attachment(Hashable):
The attachment's `media type <https://en.wikipedia.org/wiki/Media_type>`_
.. versionadded:: 1.7
ephemeral: Optional[:class:`bool`]
If the attachment is ephemeral. Ephemeral attachments are temporary and
will automatically be removed after a set period of time.
.. versionadded:: 2.0
"""
__slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type')
__slots__ = ("id", "size", "height", "width", "filename", "url", "proxy_url", "ephemeral", "_http", "content_type")
def __init__(self, *, data: AttachmentPayload, state: ConnectionState):
self.id: int = int(data['id'])
self.size: int = data['size']
self.height: Optional[int] = data.get('height')
self.width: Optional[int] = data.get('width')
self.filename: str = data['filename']
self.url: str = data.get('url')
self.proxy_url: str = data.get('proxy_url')
self.id: int = int(data["id"])
self.size: int = data["size"]
self.height: Optional[int] = data.get("height")
self.width: Optional[int] = data.get("width")
self.filename: str = data["filename"]
self.url: str = data.get("url")
self.proxy_url: str = data.get("proxy_url")
self._http = state.http
self.content_type: Optional[str] = data.get('content_type')
self.content_type: Optional[str] = data.get("content_type")
self.ephemeral: Optional[bool] = data.get("ephemeral")
def is_spoiler(self) -> bool:
""":class:`bool`: Whether this attachment contains a spoiler."""
return self.filename.startswith('SPOILER_')
return self.filename.startswith("SPOILER_")
def __repr__(self) -> str:
return f'<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>'
return f"<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>"
def __str__(self) -> str:
return self.url or ''
return self.url or ""
async def save(
self,
@@ -227,7 +247,7 @@ class Attachment(Hashable):
fp.seek(0)
return written
else:
with open(fp, 'wb') as f:
with open(fp, "wb") as f:
return f.write(data)
async def read(self, *, use_cached: bool = False) -> bytes:
@@ -309,19 +329,19 @@ class Attachment(Hashable):
def to_dict(self) -> AttachmentPayload:
result: AttachmentPayload = {
'filename': self.filename,
'id': self.id,
'proxy_url': self.proxy_url,
'size': self.size,
'url': self.url,
'spoiler': self.is_spoiler(),
"filename": self.filename,
"id": self.id,
"proxy_url": self.proxy_url,
"size": self.size,
"url": self.url,
"spoiler": self.is_spoiler(),
}
if self.height:
result['height'] = self.height
result["height"] = self.height
if self.width:
result['width'] = self.width
result["width"] = self.width
if self.content_type:
result['content_type'] = self.content_type
result["content_type"] = self.content_type
return result
@@ -335,7 +355,7 @@ class DeletedReferencedMessage:
.. versionadded:: 1.6
"""
__slots__ = ('_parent',)
__slots__ = ("_parent",)
def __init__(self, parent: MessageReference):
self._parent: MessageReference = parent
@@ -394,9 +414,11 @@ class MessageReference:
.. versionadded:: 1.6
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state')
__slots__ = ("message_id", "channel_id", "guild_id", "fail_if_not_exists", "resolved", "_state")
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True):
def __init__(
self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True
):
self._state: Optional[ConnectionState] = None
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
self.message_id: Optional[int] = message_id
@@ -407,10 +429,10 @@ class MessageReference:
@classmethod
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
self = cls.__new__(cls)
self.message_id = utils._get_as_snowflake(data, 'message_id')
self.channel_id = int(data.pop('channel_id'))
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.fail_if_not_exists = data.get('fail_if_not_exists', True)
self.message_id = utils._get_as_snowflake(data, "message_id")
self.channel_id = int(data.pop("channel_id"))
self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.fail_if_not_exists = data.get("fail_if_not_exists", True)
self._state = state
self.resolved = None
return self
@@ -439,7 +461,7 @@ class MessageReference:
self = cls(
message_id=message.id,
channel_id=message.channel.id,
guild_id=getattr(message.guild, 'id', None),
guild_id=getattr(message.guild, "id", None),
fail_if_not_exists=fail_if_not_exists,
)
self._state = message._state
@@ -456,36 +478,36 @@ class MessageReference:
.. versionadded:: 1.7
"""
guild_id = self.guild_id if self.guild_id is not None else '@me'
return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}'
guild_id = self.guild_id if self.guild_id is not None else "@me"
return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}"
def __repr__(self) -> str:
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>'
return f"<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>"
def to_dict(self) -> MessageReferencePayload:
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {}
result['channel_id'] = self.channel_id
result: MessageReferencePayload = {"message_id": self.message_id} if self.message_id is not None else {}
result["channel_id"] = self.channel_id
if self.guild_id is not None:
result['guild_id'] = self.guild_id
result["guild_id"] = self.guild_id
if self.fail_if_not_exists is not None:
result['fail_if_not_exists'] = self.fail_if_not_exists
result["fail_if_not_exists"] = self.fail_if_not_exists
return result
to_message_reference_dict = to_dict
def flatten_handlers(cls):
prefix = len('_handle_')
prefix = len("_handle_")
handlers = [
(key[prefix:], value)
for key, value in cls.__dict__.items()
if key.startswith('_handle_') and key != '_handle_member'
if key.startswith("_handle_") and key != "_handle_member"
]
# store _handle_member last
handlers.append(('member', cls._handle_member))
handlers.append(("member", cls._handle_member))
cls._HANDLERS = handlers
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')]
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")]
return cls
@@ -615,36 +637,36 @@ class Message(Hashable):
"""
__slots__ = (
'_state',
'_edited_timestamp',
'_cs_channel_mentions',
'_cs_raw_mentions',
'_cs_clean_content',
'_cs_raw_channel_mentions',
'_cs_raw_role_mentions',
'_cs_system_content',
'tts',
'content',
'channel',
'webhook_id',
'mention_everyone',
'embeds',
'id',
'mentions',
'author',
'attachments',
'nonce',
'pinned',
'role_mentions',
'type',
'flags',
'reactions',
'reference',
'application',
'activity',
'stickers',
'components',
'guild',
"_state",
"_edited_timestamp",
"_cs_channel_mentions",
"_cs_raw_mentions",
"_cs_clean_content",
"_cs_raw_channel_mentions",
"_cs_raw_role_mentions",
"_cs_system_content",
"tts",
"content",
"channel",
"webhook_id",
"mention_everyone",
"embeds",
"id",
"mentions",
"author",
"attachments",
"nonce",
"pinned",
"role_mentions",
"type",
"flags",
"reactions",
"reference",
"application",
"activity",
"stickers",
"components",
"guild",
)
if TYPE_CHECKING:
@@ -664,39 +686,39 @@ class Message(Hashable):
data: MessagePayload,
):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id')
self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])]
self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']]
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']]
self.application: Optional[MessageApplicationPayload] = data.get('application')
self.activity: Optional[MessageActivityPayload] = data.get('activity')
self.id: int = int(data["id"])
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id")
self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])]
self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]]
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]]
self.application: Optional[MessageApplicationPayload] = data.get("application")
self.activity: Optional[MessageActivityPayload] = data.get("activity")
self.channel: MessageableChannel = channel
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp'])
self.type: MessageType = try_enum(MessageType, data['type'])
self.pinned: bool = data['pinned']
self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0))
self.mention_everyone: bool = data['mention_everyone']
self.tts: bool = data['tts']
self.content: str = data['content']
self.nonce: Optional[Union[int, str]] = data.get('nonce')
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])]
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data["edited_timestamp"])
self.type: MessageType = try_enum(MessageType, data["type"])
self.pinned: bool = data["pinned"]
self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0))
self.mention_everyone: bool = data["mention_everyone"]
self.tts: bool = data["tts"]
self.content: str = data["content"]
self.nonce: Optional[Union[int, str]] = data.get("nonce")
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])]
self.components: List[Component] = [_component_factory(d) for d in data.get("components", [])]
try:
# if the channel doesn't have a guild attribute, we handle that
self.guild = channel.guild # type: ignore
except AttributeError:
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id'))
self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id"))
try:
ref = data['message_reference']
ref = data["message_reference"]
except KeyError:
self.reference = None
else:
self.reference = ref = MessageReference.with_state(state, ref)
try:
resolved = data['referenced_message']
resolved = data["referenced_message"]
except KeyError:
pass
else:
@@ -712,18 +734,15 @@ class Message(Hashable):
# the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'):
for handler in ("author", "member", "mentions", "mention_roles"):
try:
getattr(self, f'_handle_{handler}')(data[handler])
getattr(self, f"_handle_{handler}")(data[handler])
except KeyError:
continue
def __repr__(self) -> str:
name = self.__class__.__name__
return (
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
)
return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>"
def __str__(self) -> Optional[str]:
return self.content
@@ -741,7 +760,7 @@ class Message(Hashable):
def _add_reaction(self, data, emoji, user_id) -> Reaction:
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
is_me = data['me'] = user_id == self._state.self_id
is_me = data["me"] = user_id == self._state.self_id
if reaction is None:
reaction = Reaction(message=self, data=data, emoji=emoji)
@@ -758,7 +777,7 @@ class Message(Hashable):
if reaction is None:
# already removed?
raise ValueError('Emoji already removed?')
raise ValueError("Emoji already removed?")
# if reaction isn't in the list, we crash. This means discord
# sent bad data, or we stored improperly
@@ -872,7 +891,7 @@ class Message(Hashable):
return
for mention in filter(None, mentions):
id_search = int(mention['id'])
id_search = int(mention["id"])
member = guild.get_member(id_search)
if member is not None:
r.append(member)
@@ -894,7 +913,7 @@ class Message(Hashable):
self.guild = new_guild
self.channel = new_channel
@utils.cached_slot_property('_cs_raw_mentions')
@utils.cached_slot_property("_cs_raw_mentions")
def raw_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of user IDs matched with
the syntax of ``<@user_id>`` in the message content.
@@ -902,30 +921,30 @@ class Message(Hashable):
This allows you to receive the user IDs of mentioned users
even in a private message context.
"""
return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_channel_mentions')
@utils.cached_slot_property("_cs_raw_channel_mentions")
def raw_channel_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of channel IDs matched with
the syntax of ``<#channel_id>`` in the message content.
"""
return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_role_mentions')
@utils.cached_slot_property("_cs_raw_role_mentions")
def raw_role_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of role IDs matched with
the syntax of ``<@&role_id>`` in the message content.
"""
return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)]
return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_channel_mentions')
@utils.cached_slot_property("_cs_channel_mentions")
def channel_mentions(self) -> List[GuildChannel]:
if self.guild is None:
return []
it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions))
return utils._unique(it)
@utils.cached_slot_property('_cs_clean_content')
@utils.cached_slot_property("_cs_clean_content")
def clean_content(self) -> str:
""":class:`str`: A property that returns the content in a "cleaned up"
manner. This basically means that mentions are transformed
@@ -972,9 +991,9 @@ class Message(Hashable):
# fmt: on
def repl(obj):
return transformations.get(re.escape(obj.group(0)), '')
return transformations.get(re.escape(obj.group(0)), "")
pattern = re.compile('|'.join(transformations.keys()))
pattern = re.compile("|".join(transformations.keys()))
result = pattern.sub(repl, self.content)
return escape_mentions(result)
@@ -991,8 +1010,8 @@ class Message(Hashable):
@property
def jump_url(self) -> str:
""":class:`str`: Returns a URL that allows the client to jump to this message."""
guild_id = getattr(self.guild, 'id', '@me')
return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}'
guild_id = getattr(self.guild, "id", "@me")
return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}"
def is_system(self) -> bool:
""":class:`bool`: Whether the message is a system message.
@@ -1009,7 +1028,7 @@ class Message(Hashable):
MessageType.thread_starter_message,
)
@utils.cached_slot_property('_cs_system_content')
@utils.cached_slot_property("_cs_system_content")
def system_content(self):
r""":class:`str`: A property that returns the content that is rendered
regardless of the :attr:`Message.type`.
@@ -1024,24 +1043,24 @@ class Message(Hashable):
if self.type is MessageType.recipient_add:
if self.channel.type is ChannelType.group:
return f'{self.author.name} added {self.mentions[0].name} to the group.'
return f"{self.author.name} added {self.mentions[0].name} to the group."
else:
return f'{self.author.name} added {self.mentions[0].name} to the thread.'
return f"{self.author.name} added {self.mentions[0].name} to the thread."
if self.type is MessageType.recipient_remove:
if self.channel.type is ChannelType.group:
return f'{self.author.name} removed {self.mentions[0].name} from the group.'
return f"{self.author.name} removed {self.mentions[0].name} from the group."
else:
return f'{self.author.name} removed {self.mentions[0].name} from the thread.'
return f"{self.author.name} removed {self.mentions[0].name} from the thread."
if self.type is MessageType.channel_name_change:
return f'{self.author.name} changed the channel name: **{self.content}**'
return f"{self.author.name} changed the channel name: **{self.content}**"
if self.type is MessageType.channel_icon_change:
return f'{self.author.name} changed the channel icon.'
return f"{self.author.name} changed the channel icon."
if self.type is MessageType.pins_add:
return f'{self.author.name} pinned a message to this channel.'
return f"{self.author.name} pinned a message to this channel."
if self.type is MessageType.new_member:
formats = [
@@ -1065,62 +1084,62 @@ class Message(Hashable):
if self.type is MessageType.premium_guild_subscription:
if not self.content:
return f'{self.author.name} just boosted the server!'
return f"{self.author.name} just boosted the server!"
else:
return f'{self.author.name} just boosted the server **{self.content}** times!'
return f"{self.author.name} just boosted the server **{self.content}** times!"
if self.type is MessageType.premium_guild_tier_1:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**"
if self.type is MessageType.premium_guild_tier_2:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**"
if self.type is MessageType.premium_guild_tier_3:
if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**'
return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**"
else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**'
return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**"
if self.type is MessageType.channel_follow_add:
return f'{self.author.name} has added {self.content} to this channel'
return f"{self.author.name} has added {self.content} to this channel"
if self.type is MessageType.guild_stream:
# the author will be a Member
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore
return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore
if self.type is MessageType.guild_discovery_disqualified:
return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.'
return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details."
if self.type is MessageType.guild_discovery_requalified:
return 'This server is eligible for Server Discovery again and has been automatically relisted!'
return "This server is eligible for Server Discovery again and has been automatically relisted!"
if self.type is MessageType.guild_discovery_grace_period_initial_warning:
return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.'
return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery."
if self.type is MessageType.guild_discovery_grace_period_final_warning:
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.'
return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery."
if self.type is MessageType.thread_created:
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.'
return f"{self.author.name} started a thread: **{self.content}**. See all **threads**."
if self.type is MessageType.reply:
return self.content
if self.type is MessageType.thread_starter_message:
if self.reference is None or self.reference.resolved is None:
return 'Sorry, we couldn\'t load the first message in this thread'
return "Sorry, we couldn't load the first message in this thread"
# the resolved message for the reference will be a Message
return self.reference.resolved.content # type: ignore
if self.type is MessageType.guild_invite_reminder:
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!'
return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!"
async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
"""|coro|
@@ -1271,45 +1290,45 @@ class Message(Hashable):
payload: Dict[str, Any] = {}
if content is not MISSING:
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
else:
payload['content'] = None
payload["content"] = None
if embed is not MISSING and embeds is not MISSING:
raise InvalidArgument('cannot pass both embed and embeds parameter to edit()')
raise InvalidArgument("cannot pass both embed and embeds parameter to edit()")
if embed is not MISSING:
if embed is None:
payload['embeds'] = []
payload["embeds"] = []
else:
payload['embeds'] = [embed.to_dict()]
payload["embeds"] = [embed.to_dict()]
elif embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds]
payload["embeds"] = [e.to_dict() for e in embeds]
if suppress is not MISSING:
flags = MessageFlags._from_value(self.flags.value)
flags.suppress_embeds = suppress
payload['flags'] = flags.value
payload["flags"] = flags.value
if allowed_mentions is MISSING:
if self._state.allowed_mentions is not None and self.author.id == self._state.self_id:
payload['allowed_mentions'] = self._state.allowed_mentions.to_dict()
payload["allowed_mentions"] = self._state.allowed_mentions.to_dict()
else:
if allowed_mentions is not None:
if self._state.allowed_mentions is not None:
payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
payload["allowed_mentions"] = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
else:
payload['allowed_mentions'] = allowed_mentions.to_dict()
payload["allowed_mentions"] = allowed_mentions.to_dict()
if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments]
payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING:
self._state.prevent_view_updates_for(self.id)
if view:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
else:
payload['components'] = []
payload["components"] = []
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data)
@@ -1551,9 +1570,11 @@ class Message(Hashable):
The created thread.
"""
if self.guild is None:
raise InvalidArgument('This message does not have guild info attached.')
raise InvalidArgument("This message does not have guild info attached.")
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440)
default_auto_archive_duration: ThreadArchiveDuration = getattr(
self.channel, "default_auto_archive_duration", 1440
)
data = await self._state.http.start_thread_with_message(
self.channel.id,
self.id,
@@ -1611,12 +1632,12 @@ class Message(Hashable):
def to_message_reference_dict(self) -> MessageReferencePayload:
data: MessageReferencePayload = {
'message_id': self.id,
'channel_id': self.channel.id,
"message_id": self.id,
"channel_id": self.channel.id,
}
if self.guild is not None:
data['guild_id'] = self.guild.id
data["guild_id"] = self.guild.id
return data
@@ -1662,7 +1683,7 @@ class PartialMessage(Hashable):
The message ID.
"""
__slots__ = ('channel', 'id', '_cs_guild', '_state')
__slots__ = ("channel", "id", "_cs_guild", "_state")
jump_url: str = Message.jump_url # type: ignore
delete = Message.delete
@@ -1686,7 +1707,7 @@ class PartialMessage(Hashable):
ChannelType.public_thread,
ChannelType.private_thread,
):
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}')
raise TypeError(f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}")
self.channel: PartialMessageableChannel = channel
self._state: ConnectionState = channel._state
@@ -1702,17 +1723,17 @@ class PartialMessage(Hashable):
pinned = property(None, lambda x, y: None)
def __repr__(self) -> str:
return f'<PartialMessage id={self.id} channel={self.channel!r}>'
return f"<PartialMessage id={self.id} channel={self.channel!r}>"
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: The partial message's creation time in UTC."""
return utils.snowflake_time(self.id)
@utils.cached_slot_property('_cs_guild')
@utils.cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable."""
return getattr(self.channel, 'guild', None)
return getattr(self.channel, "guild", None)
async def fetch(self) -> Message:
"""|coro|
@@ -1794,34 +1815,34 @@ class PartialMessage(Hashable):
"""
try:
content = fields['content']
content = fields["content"]
except KeyError:
pass
else:
if content is not None:
fields['content'] = str(content)
fields["content"] = str(content)
try:
embed = fields['embed']
embed = fields["embed"]
except KeyError:
pass
else:
if embed is not None:
fields['embed'] = embed.to_dict()
fields["embed"] = embed.to_dict()
try:
suppress: bool = fields.pop('suppress')
suppress: bool = fields.pop("suppress")
except KeyError:
pass
else:
flags = MessageFlags._from_value(0)
flags.suppress_embeds = suppress
fields['flags'] = flags.value
fields["flags"] = flags.value
delete_after = fields.pop('delete_after', None)
delete_after = fields.pop("delete_after", None)
try:
allowed_mentions = fields.pop('allowed_mentions')
allowed_mentions = fields.pop("allowed_mentions")
except KeyError:
pass
else:
@@ -1830,19 +1851,19 @@ class PartialMessage(Hashable):
allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
else:
allowed_mentions = allowed_mentions.to_dict()
fields['allowed_mentions'] = allowed_mentions
fields["allowed_mentions"] = allowed_mentions
try:
view = fields.pop('view')
view = fields.pop("view")
except KeyError:
# To check for the view afterwards
view = None
else:
self._state.prevent_view_updates_for(self.id)
if view:
fields['components'] = view.to_components()
fields["components"] = view.to_components()
else:
fields['components'] = []
fields["components"] = []
if fields:
data = await self._state.http.edit_message(self.channel.id, self.id, **fields)

View File

@@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE.
"""
__all__ = (
'EqualityComparable',
'Hashable',
"EqualityComparable",
"Hashable",
)
class EqualityComparable:
__slots__ = ()
@@ -40,6 +41,7 @@ class EqualityComparable:
return other.id != self.id
return True
class Hashable(EqualityComparable):
__slots__ = ()

View File

@@ -35,11 +35,11 @@ from typing import (
if TYPE_CHECKING:
import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
__all__ = (
'Object',
)
__all__ = ("Object",)
class Object(Hashable):
"""Represents a generic Discord object.
@@ -83,12 +83,12 @@ class Object(Hashable):
try:
id = int(id)
except ValueError:
raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None
raise TypeError(f"id parameter must be convertable to int not {id.__class__!r}") from None
else:
self.id = id
def __repr__(self) -> str:
return f'<Object id={self.id!r}>'
return f"<Object id={self.id!r}>"
@property
def created_at(self) -> datetime.datetime:

View File

@@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException
__all__ = (
'OggError',
'OggPage',
'OggStream',
"OggError",
"OggPage",
"OggStream",
)
class OggError(DiscordException):
"""An exception that is thrown for Ogg stream parsing errors."""
pass
# https://tools.ietf.org/html/rfc3533
# https://tools.ietf.org/html/rfc7845
class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB')
_header: ClassVar[struct.Struct] = struct.Struct("<xBQIIIB")
if TYPE_CHECKING:
flag: int
gran_pos: int
@@ -57,14 +61,13 @@ class OggPage:
try:
header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.flag, self.gran_pos, self.serial, self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
bodylen = sum(struct.unpack("B" * self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen)
except Exception:
raise OggError('bad data stream') from None
raise OggError("bad data stream") from None
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
packetlen = offset = 0
@@ -84,18 +87,19 @@ class OggPage:
if partial:
yield self.data[offset:], False
class OggStream:
def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream
def _next_page(self) -> Optional[OggPage]:
head = self.stream.read(4)
if head == b'OggS':
if head == b"OggS":
return OggPage(self.stream)
elif not head:
return None
else:
raise OggError('invalid header magic')
raise OggError("invalid header magic")
def _iter_pages(self) -> Generator[OggPage, None, None]:
page = self._next_page()
@@ -104,10 +108,10 @@ class OggStream:
page = self._next_page()
def iter_packets(self) -> Generator[bytes, None, None]:
partial = b''
partial = b""
for page in self._iter_pages():
for data, complete in page.iter_packets():
partial += data
if complete:
yield partial
partial = b''
partial = b""

View File

@@ -38,9 +38,10 @@ import sys
from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING:
T = TypeVar('T')
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
T = TypeVar("T")
BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"]
SIGNAL_CTL = Literal["auto", "voice", "music"]
class BandCtl(TypedDict):
narrow: int
@@ -49,15 +50,17 @@ class BandCtl(TypedDict):
superwide: int
full: int
class SignalCtl(TypedDict):
auto: int
voice: int
music: int
__all__ = (
'Encoder',
'OpusError',
'OpusNotLoaded',
"Encoder",
"OpusError",
"OpusNotLoaded",
)
_log = logging.getLogger(__name__)
@@ -68,12 +71,15 @@ c_float_ptr = ctypes.POINTER(ctypes.c_float)
_lib = None
class EncoderStruct(ctypes.Structure):
pass
class DecoderStruct(ctypes.Structure):
pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct)
@@ -98,32 +104,35 @@ CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039
band_ctl: BandCtl = {
'narrow': 1101,
'medium': 1102,
'wide': 1103,
'superwide': 1104,
'full': 1105,
"narrow": 1101,
"medium": 1102,
"wide": 1103,
"superwide": 1104,
"full": 1105,
}
signal_ctl: SignalCtl = {
'auto': -1000,
'voice': 3001,
'music': 3002,
"auto": -1000,
"voice": 3001,
"music": 3002,
}
def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK:
_log.info('error has happened in %s', func.__name__)
_log.info("error has happened in %s", func.__name__)
raise OpusError(result)
return result
def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj
if ret.value != OK:
_log.info('error has happened in %s', func.__name__)
_log.info("error has happened in %s", func.__name__)
raise OpusError(ret.value)
return result
# A list of exported functions.
# The first argument is obviously the name.
# The second one are the types of arguments it takes.
@@ -131,54 +140,51 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
# The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [
# Generic
('opus_get_version_string',
None, ctypes.c_char_p, None),
('opus_strerror',
[ctypes.c_int], ctypes.c_char_p, None),
("opus_get_version_string", None, ctypes.c_char_p, None),
("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None),
# Encoder functions
('opus_encoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_encoder_create',
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
('opus_encode',
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encode_float',
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_encoder_destroy',
[EncoderStructPtr], None, None),
("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None),
("opus_encoder_create", [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
(
"opus_encode",
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
(
"opus_encode_float",
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
("opus_encoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_encoder_destroy", [EncoderStructPtr], None, None),
# Decoder functions
('opus_decoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_decoder_create',
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
('opus_decode',
("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None),
("opus_decoder_create", [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
(
"opus_decode",
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt),
('opus_decode_float',
ctypes.c_int,
_err_lt,
),
(
"opus_decode_float",
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt),
('opus_decoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_decoder_destroy',
[DecoderStructPtr], None, None),
('opus_decoder_get_nb_samples',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
ctypes.c_int,
_err_lt,
),
("opus_decoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_decoder_destroy", [DecoderStructPtr], None, None),
("opus_decoder_get_nb_samples", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
# Packet functions
('opus_packet_get_bandwidth',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_channels',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_frames',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_samples_per_frame',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt),
("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt),
("opus_packet_get_nb_frames", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
("opus_packet_get_samples_per_frame", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
]
def libopus_loader(name: str) -> Any:
# create the library...
lib = ctypes.cdll.LoadLibrary(name)
@@ -203,22 +209,24 @@ def libopus_loader(name: str) -> Any:
return lib
def _load_default() -> bool:
global _lib
try:
if sys.platform == 'win32':
if sys.platform == "win32":
_basedir = os.path.dirname(os.path.abspath(__file__))
_bitness = struct.calcsize('P') * 8
_target = 'x64' if _bitness > 32 else 'x86'
_filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll')
_bitness = struct.calcsize("P") * 8
_target = "x64" if _bitness > 32 else "x86"
_filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll")
_lib = libopus_loader(_filename)
else:
_lib = libopus_loader(ctypes.util.find_library('opus'))
_lib = libopus_loader(ctypes.util.find_library("opus"))
except Exception:
_lib = None
return _lib is not None
def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice.
@@ -257,6 +265,7 @@ def load_opus(name: str) -> None:
global _lib
_lib = libopus_loader(name)
def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@@ -271,6 +280,7 @@ def is_loaded() -> bool:
global _lib
return _lib is not None
class OpusError(DiscordException):
"""An exception that is thrown for libopus related errors.
@@ -282,19 +292,22 @@ class OpusError(DiscordException):
def __init__(self, code: int):
self.code: int = code
msg = _lib.opus_strerror(self.code).decode('utf-8')
msg = _lib.opus_strerror(self.code).decode("utf-8")
_log.info('"%s" has happened', msg)
super().__init__(msg)
class OpusNotLoaded(DiscordException):
"""An exception that is thrown for when libopus is not loaded."""
pass
class _OpusStruct:
SAMPLING_RATE = 48000
CHANNELS = 2
FRAME_LENGTH = 20 # in milliseconds
SAMPLE_SIZE = struct.calcsize('h') * CHANNELS
SAMPLE_SIZE = struct.calcsize("h") * CHANNELS
SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH)
FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE
@@ -304,7 +317,8 @@ class _OpusStruct:
if not is_loaded() and not _load_default():
raise OpusNotLoaded()
return _lib.opus_get_version_string().decode('utf-8')
return _lib.opus_get_version_string().decode("utf-8")
class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO):
@@ -315,11 +329,11 @@ class Encoder(_OpusStruct):
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full')
self.set_signal_type('auto')
self.set_bandwidth("full")
self.set_signal_type("auto")
def __del__(self) -> None:
if hasattr(self, '_state'):
if hasattr(self, "_state"):
_lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
@@ -363,7 +377,8 @@ class Encoder(_OpusStruct):
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
return array.array("b", data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct):
def __init__(self):
@@ -372,7 +387,7 @@ class Decoder(_OpusStruct):
self._state: DecoderStruct = self._create_state()
def __del__(self) -> None:
if hasattr(self, '_state'):
if hasattr(self, "_state"):
_lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
@@ -451,4 +466,4 @@ class Decoder(_OpusStruct):
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
return array.array('h', pcm[:ret * channel_count]).tobytes()
return array.array("h", pcm[: ret * channel_count]).tobytes()

View File

@@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin
from .errors import InvalidArgument
from . import utils
__all__ = (
'PartialEmoji',
)
__all__ = ("PartialEmoji",)
if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag:
__slots__ = ()
@@ -49,7 +48,7 @@ class _EmojiTag:
raise NotImplementedError
PE = TypeVar('PE', bound='PartialEmoji')
PE = TypeVar("PE", bound="PartialEmoji")
class PartialEmoji(_EmojiTag, AssetMixin):
@@ -90,9 +89,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
The ID of the custom emoji, if applicable.
"""
__slots__ = ('animated', 'name', 'id', '_state')
__slots__ = ("animated", "name", "id", "_state")
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
_CUSTOM_EMOJI_RE = re.compile(r"<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?")
if TYPE_CHECKING:
id: Optional[int]
@@ -106,9 +105,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
@classmethod
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE:
return cls(
animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'),
name=data.get('name') or '',
animated=data.get("animated", False),
id=utils._get_as_snowflake(data, "id"),
name=data.get("name") or "",
)
@classmethod
@@ -139,19 +138,19 @@ class PartialEmoji(_EmojiTag, AssetMixin):
match = cls._CUSTOM_EMOJI_RE.match(value)
if match is not None:
groups = match.groupdict()
animated = bool(groups['animated'])
emoji_id = int(groups['id'])
name = groups['name']
animated = bool(groups["animated"])
emoji_id = int(groups["id"])
name = groups["name"]
return cls(name=name, animated=animated, id=emoji_id)
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name}
o: Dict[str, Any] = {"name": self.name}
if self.id:
o['id'] = self.id
o["id"] = self.id
if self.animated:
o['animated'] = self.animated
o["animated"] = self.animated
return o
def _to_partial(self) -> PartialEmoji:
@@ -169,11 +168,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
if self.id is None:
return self.name
if self.animated:
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
def __repr__(self):
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>"
def __eq__(self, other: Any) -> bool:
if self.is_unicode_emoji():
@@ -200,7 +199,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
def _as_reaction(self) -> str:
if self.id is None:
return self.name
return f'{self.name}:{self.id}'
return f"{self.name}:{self.id}"
@property
def created_at(self) -> Optional[datetime]:
@@ -220,13 +219,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
If this isn't a custom emoji then an empty string is returned
"""
if self.is_unicode_emoji():
return ''
return ""
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
async def read(self) -> bytes:
if self.is_unicode_emoji():
raise InvalidArgument('PartialEmoji is not a custom emoji')
raise InvalidArgument("PartialEmoji is not a custom emoji")
return await super().read()

View File

@@ -28,8 +28,8 @@ from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING,
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
__all__ = (
'Permissions',
'PermissionOverwrite',
"Permissions",
"PermissionOverwrite",
)
# A permission alias works like a regular flag but is marked
@@ -46,7 +46,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis
return decorator
P = TypeVar('P', bound='Permissions')
P = TypeVar("P", bound="Permissions")
@fill_with_flags()
class Permissions(BaseFlags):
@@ -101,12 +103,12 @@ class Permissions(BaseFlags):
def __init__(self, permissions: int = 0, **kwargs: bool):
if not isinstance(permissions, int):
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.')
raise TypeError(f"Expected int parameter, received {permissions.__class__.__name__} instead.")
self.value = permissions
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid permission name.')
raise TypeError(f"{key!r} is not a valid permission name.")
setattr(self, key, value)
def is_subset(self, other: Permissions) -> bool:
@@ -299,6 +301,13 @@ class Permissions(BaseFlags):
"""
return 1 << 3
@make_permission_alias("administrator")
def admin(self) -> int:
""":class:`bool`: An alias for :attr:`administrator`.
.. versionadded:: 2.0
"""
return 1 << 3
@flag_value
def manage_channels(self) -> int:
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.
@@ -336,7 +345,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
return 1 << 10
@make_permission_alias('read_messages')
@make_permission_alias("read_messages")
def view_channel(self) -> int:
""":class:`bool`: An alias for :attr:`read_messages`.
@@ -389,7 +398,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
return 1 << 18
@make_permission_alias('external_emojis')
@make_permission_alias("external_emojis")
def use_external_emojis(self) -> int:
""":class:`bool`: An alias for :attr:`external_emojis`.
@@ -453,7 +462,7 @@ class Permissions(BaseFlags):
"""
return 1 << 28
@make_permission_alias('manage_roles')
@make_permission_alias("manage_roles")
def manage_permissions(self) -> int:
""":class:`bool`: An alias for :attr:`manage_roles`.
@@ -471,7 +480,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30
@make_permission_alias('manage_emojis')
@make_permission_alias("manage_emojis")
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
@@ -535,7 +544,7 @@ class Permissions(BaseFlags):
"""
return 1 << 37
@make_permission_alias('external_stickers')
@make_permission_alias("external_stickers")
def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`.
@@ -551,7 +560,9 @@ class Permissions(BaseFlags):
"""
return 1 << 38
PO = TypeVar('PO', bound='PermissionOverwrite')
PO = TypeVar("PO", bound="PermissionOverwrite")
def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
@@ -614,7 +625,7 @@ class PermissionOverwrite:
Set the value of permissions by their name.
"""
__slots__ = ('_values',)
__slots__ = ("_values",)
if TYPE_CHECKING:
VALID_NAMES: ClassVar[Set[str]]
@@ -670,7 +681,7 @@ class PermissionOverwrite:
for key, value in kwargs.items():
if key not in self.VALID_NAMES:
raise ValueError(f'no permission called {key}.')
raise ValueError(f"no permission called {key}.")
setattr(self, key, value)
@@ -679,7 +690,7 @@ class PermissionOverwrite:
def _set(self, key: str, value: Optional[bool]) -> None:
if value not in (True, None, False):
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}')
raise TypeError(f"Expected bool or NoneType, received {value.__class__.__name__}")
if value is None:
self._values.pop(key, None)

View File

@@ -47,27 +47,28 @@ if TYPE_CHECKING:
from .voice_client import VoiceClient
AT = TypeVar('AT', bound='AudioSource')
FT = TypeVar('FT', bound='FFmpegOpusAudio')
AT = TypeVar("AT", bound="AudioSource")
FT = TypeVar("FT", bound="FFmpegOpusAudio")
_log = logging.getLogger(__name__)
__all__ = (
'AudioSource',
'PCMAudio',
'FFmpegAudio',
'FFmpegPCMAudio',
'FFmpegOpusAudio',
'PCMVolumeTransformer',
"AudioSource",
"PCMAudio",
"FFmpegAudio",
"FFmpegPCMAudio",
"FFmpegOpusAudio",
"PCMVolumeTransformer",
)
CREATE_NO_WINDOW: int
if sys.platform != 'win32':
if sys.platform != "win32":
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource:
"""Represents an audio stream.
@@ -114,6 +115,7 @@ class AudioSource:
def __del__(self) -> None:
self.cleanup()
class PCMAudio(AudioSource):
"""Represents raw 16-bit 48KHz stereo PCM audio source.
@@ -122,15 +124,17 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM.
"""
def __init__(self, stream: io.BufferedIOBase) -> None:
self.stream: io.BufferedIOBase = stream
def read(self) -> bytes:
ret = self.stream.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b''
return b""
return ret
class FFmpegAudio(AudioSource):
"""Represents an FFmpeg (or AVConv) based AudioSource.
@@ -140,13 +144,15 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3
"""
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE
def __init__(
self, source: Union[str, io.BufferedIOBase], *, executable: str = "ffmpeg", args: Any, **subprocess_kwargs: Any
):
piping = subprocess_kwargs.get("stdin") == subprocess.PIPE
if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE}
kwargs = {"stdout": subprocess.PIPE}
kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
@@ -155,7 +161,7 @@ class FFmpegAudio(AudioSource):
self._pipe_thread: Optional[threading.Thread] = None
if piping:
n = f'popen-stdin-writer:{id(self):#x}'
n = f"popen-stdin-writer:{id(self):#x}"
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start()
@@ -165,10 +171,10 @@ class FFmpegAudio(AudioSource):
try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
except FileNotFoundError:
executable = args.partition(' ')[0] if isinstance(args, str) else args[0]
raise ClientException(executable + ' was not found.') from None
executable = args.partition(" ")[0] if isinstance(args, str) else args[0]
raise ClientException(executable + " was not found.") from None
except subprocess.SubprocessError as exc:
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc
raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc
else:
return process
@@ -177,20 +183,19 @@ class FFmpegAudio(AudioSource):
if proc is MISSING:
return
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
_log.info("Preparing to terminate ffmpeg process %s.", proc.pid)
try:
proc.kill()
except Exception:
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid)
_log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
if proc.poll() is None:
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
_log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid)
proc.communicate()
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
_log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode)
else:
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
_log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode)
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
@@ -202,7 +207,7 @@ class FFmpegAudio(AudioSource):
try:
self._stdin.write(data)
except Exception:
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
_log.debug("Write error for %s, this is probably not a problem", self, exc_info=True)
# at this point the source data is either exhausted or the process is fubar
self._process.terminate()
return
@@ -211,6 +216,7 @@ class FFmpegAudio(AudioSource):
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@@ -250,38 +256,39 @@ class FFmpegPCMAudio(FFmpegAudio):
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = 'ffmpeg',
executable: str = "ffmpeg",
pipe: bool = False,
stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None
options: Optional[str] = None,
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append('-i')
args.append('-' if pipe else source)
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
args.append("-i")
args.append("-" if pipe else source)
args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"))
if isinstance(options, str):
args.extend(shlex.split(options))
args.append('pipe:1')
args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
def read(self) -> bytes:
ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b''
return b""
return ret
def is_opus(self) -> bool:
return False
class FFmpegOpusAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@@ -349,7 +356,7 @@ class FFmpegOpusAudio(FFmpegAudio):
*,
bitrate: int = 128,
codec: Optional[str] = None,
executable: str = 'ffmpeg',
executable: str = "ffmpeg",
pipe=False,
stderr=None,
before_options=None,
@@ -357,28 +364,39 @@ class FFmpegOpusAudio(FFmpegAudio):
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append('-i')
args.append('-' if pipe else source)
args.append("-i")
args.append("-" if pipe else source)
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
codec = "copy" if codec in ("opus", "libopus") else "libopus"
args.extend(('-map_metadata', '-1',
'-f', 'opus',
'-c:a', codec,
'-ar', '48000',
'-ac', '2',
'-b:a', f'{bitrate}k',
'-loglevel', 'warning'))
args.extend(
(
"-map_metadata",
"-1",
"-f",
"opus",
"-c:a",
codec,
"-ar",
"48000",
"-ac",
"2",
"-b:a",
f"{bitrate}k",
"-loglevel",
"warning",
)
)
if isinstance(options, str):
args.extend(shlex.split(options))
args.append('pipe:1')
args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
self._packet_iter = OggStream(self._stdout).iter_packets()
@@ -446,7 +464,7 @@ class FFmpegOpusAudio(FFmpegAudio):
An instance of this class.
"""
executable = kwargs.get('executable')
executable = kwargs.get("executable")
codec, bitrate = await cls.probe(source, method=method, executable=executable)
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore
@@ -484,12 +502,12 @@ class FFmpegOpusAudio(FFmpegAudio):
A 2-tuple with the codec and bitrate of the input source.
"""
method = method or 'native'
executable = executable or 'ffmpeg'
method = method or "native"
executable = executable or "ffmpeg"
probefunc = fallback = None
if isinstance(method, str):
probefunc = getattr(cls, '_probe_codec_' + method, None)
probefunc = getattr(cls, "_probe_codec_" + method, None)
if probefunc is None:
raise AttributeError(f"Invalid probe method {method!r}")
@@ -500,8 +518,7 @@ class FFmpegOpusAudio(FFmpegAudio):
probefunc = method
fallback = cls._probe_codec_fallback
else:
raise TypeError("Expected str or callable for parameter 'probe', " \
f"not '{method.__class__.__name__}'")
raise TypeError("Expected str or callable for parameter 'probe', " f"not '{method.__class__.__name__}'")
codec = bitrate = None
loop = asyncio.get_event_loop()
@@ -525,28 +542,28 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate
@staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
def _probe_codec_native(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + "probe" if executable in ("ffmpeg", "avconv") else executable
args = [exe, "-v", "quiet", "-print_format", "json", "-show_streams", "-select_streams", "a:0", source]
output = subprocess.check_output(args, timeout=20)
codec = bitrate = None
if output:
data = json.loads(output)
streamdata = data['streams'][0]
streamdata = data["streams"][0]
codec = streamdata.get('codec_name')
bitrate = int(streamdata.get('bit_rate', 0))
codec = streamdata.get("codec_name")
bitrate = int(streamdata.get("bit_rate", 0))
bitrate = max(round(bitrate / 1000), 512)
return codec, bitrate
@staticmethod
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
args = [executable, '-hide_banner', '-i', source]
def _probe_codec_fallback(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
args = [executable, "-hide_banner", "-i", source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate(timeout=20)
output = out.decode('utf8')
output = out.decode("utf8")
codec = bitrate = None
codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output)
@@ -560,11 +577,12 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate
def read(self) -> bytes:
return next(self._packet_iter, b'')
return next(self._packet_iter, b"")
def is_opus(self) -> bool:
return True
class PCMVolumeTransformer(AudioSource, Generic[AT]):
"""Transforms a previous :class:`AudioSource` to have volume controls.
@@ -589,10 +607,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
def __init__(self, original: AT, volume: float = 1.0):
if not isinstance(original, AudioSource):
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.')
raise TypeError(f"expected AudioSource not {original.__class__.__name__}.")
if original.is_opus():
raise ClientException('AudioSource must not be Opus encoded.')
raise ClientException("AudioSource must not be Opus encoded.")
self.original: AT = original
self.volume = volume
@@ -613,6 +631,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
@@ -685,11 +704,11 @@ class AudioPlayer(threading.Thread):
try:
self.after(error)
except Exception as exc:
_log.exception('Calling the after function failed.')
_log.exception("Calling the after function failed.")
exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error:
msg = f'Exception in voice thread {self.name}'
msg = f"Exception in voice thread {self.name}"
_log.exception(msg, exc_info=error)
print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__)

View File

@@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING:
@@ -34,7 +35,8 @@ if TYPE_CHECKING:
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent
IntegrationDeleteEvent,
TypingEvent,
)
from .message import Message
from .partial_emoji import PartialEmoji
@@ -42,20 +44,21 @@ if TYPE_CHECKING:
__all__ = (
'RawMessageDeleteEvent',
'RawBulkMessageDeleteEvent',
'RawMessageUpdateEvent',
'RawReactionActionEvent',
'RawReactionClearEvent',
'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent',
"RawMessageDeleteEvent",
"RawBulkMessageDeleteEvent",
"RawMessageUpdateEvent",
"RawReactionActionEvent",
"RawReactionClearEvent",
"RawReactionClearEmojiEvent",
"RawIntegrationDeleteEvent",
"RawTypingEvent",
)
class _RawReprMixin:
def __repr__(self) -> str:
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>'
value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__)
return f"<{self.__class__.__name__} {value}>"
class RawMessageDeleteEvent(_RawReprMixin):
@@ -73,14 +76,14 @@ class RawMessageDeleteEvent(_RawReprMixin):
The cached message, if found in the internal message cache.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
__slots__ = ("message_id", "channel_id", "guild_id", "cached_message")
def __init__(self, data: MessageDeleteEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -100,15 +103,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
The cached messages, if found in the internal message cache.
"""
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
__slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages")
def __init__(self, data: BulkMessageDeleteEvent) -> None:
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])}
self.channel_id: int = int(data['channel_id'])
self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])}
self.channel_id: int = int(data["channel_id"])
self.cached_messages: List[Message] = []
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -136,16 +139,16 @@ class RawMessageUpdateEvent(_RawReprMixin):
it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
__slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message")
def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.data: MessageUpdateEvent = data
self.cached_message: Optional[Message] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -179,19 +182,18 @@ class RawReactionActionEvent(_RawReprMixin):
.. versionadded:: 1.3
"""
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
'event_type', 'member')
__slots__ = ("message_id", "user_id", "channel_id", "guild_id", "emoji", "event_type", "member")
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.user_id: int = int(data['user_id'])
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.emoji: PartialEmoji = emoji
self.event_type: str = event_type
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -209,14 +211,14 @@ class RawReactionClearEvent(_RawReprMixin):
The guild ID where the reactions got cleared.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id')
__slots__ = ("message_id", "channel_id", "guild_id")
def __init__(self, data: ReactionClearEvent) -> None:
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -238,15 +240,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
The custom or unicode emoji being removed.
"""
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
__slots__ = ("message_id", "channel_id", "guild_id", "emoji")
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
self.emoji: PartialEmoji = emoji
self.message_id: int = int(data['message_id'])
self.channel_id: int = int(data['channel_id'])
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
try:
self.guild_id: Optional[int] = int(data['guild_id'])
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
@@ -266,13 +268,46 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
The guild ID where the integration got deleted.
"""
__slots__ = ('integration_id', 'application_id', 'guild_id')
__slots__ = ("integration_id", "application_id", "guild_id")
def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data['id'])
self.guild_id: int = int(data['guild_id'])
self.integration_id: int = int(data["id"])
self.guild_id: int = int(data["guild_id"])
try:
self.application_id: Optional[int] = int(data['application_id'])
self.application_id: Optional[int] = int(data["application_id"])
except KeyError:
self.application_id: Optional[int] = None
class RawTypingEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_typing` event.
.. versionadded:: 2.0
Attributes
-----------
channel_id: :class:`int`
The channel ID where the typing originated from.
user_id: :class:`int`
The ID of the user that started typing.
when: :class:`datetime.datetime`
When the typing started as an aware datetime in UTC.
guild_id: Optional[:class:`int`]
The guild ID where the typing originated from, if applicable.
member: Optional[:class:`Member`]
The member who started typing. Only available if the member started typing in a guild.
"""
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
def __init__(self, data: TypingEvent) -> None:
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get("timestamp"), tz=datetime.timezone.utc)
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None

View File

@@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional
from .iterators import ReactionIterator
__all__ = (
'Reaction',
)
__all__ = ("Reaction",)
if TYPE_CHECKING:
from .types.message import Reaction as ReactionPayload
@@ -38,6 +36,7 @@ if TYPE_CHECKING:
from .emoji import Emoji
from .abc import Snowflake
class Reaction:
"""Represents a reaction to a message.
@@ -75,13 +74,16 @@ class Reaction:
message: :class:`Message`
Message this reaction is for.
"""
__slots__ = ('message', 'count', 'emoji', 'me')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None):
__slots__ = ("message", "count", "emoji", "me")
def __init__(
self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None
):
self.message: Message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji'])
self.count: int = data.get('count', 1)
self.me: bool = data.get('me')
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data["emoji"])
self.count: int = data.get("count", 1)
self.me: bool = data.get("me")
# TODO: typeguard
def is_custom_emoji(self) -> bool:
@@ -103,7 +105,7 @@ class Reaction:
return str(self.emoji)
def __repr__(self) -> str:
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>'
return f"<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>"
async def remove(self, user: Snowflake) -> None:
"""|coro|
@@ -201,7 +203,7 @@ class Reaction:
"""
if not isinstance(self.emoji, str):
emoji = f'{self.emoji.name}:{self.emoji.id}'
emoji = f"{self.emoji.name}:{self.emoji.id}"
else:
emoji = self.emoji

View File

@@ -32,8 +32,8 @@ from .mixins import Hashable
from .utils import snowflake_time, _get_as_snowflake, MISSING
__all__ = (
'RoleTags',
'Role',
"RoleTags",
"Role",
)
if TYPE_CHECKING:
@@ -68,19 +68,19 @@ class RoleTags:
"""
__slots__ = (
'bot_id',
'integration_id',
'_premium_subscriber',
"bot_id",
"integration_id",
"_premium_subscriber",
)
def __init__(self, data: RoleTagPayload):
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id')
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id')
self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id")
self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id")
# NOTE: The API returns "null" for this if it's valid, which corresponds to None.
# This is different from other fields where "null" means "not there".
# So in this case, a value of None is the same as True.
# Which means we would need a different sentinel.
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING)
self._premium_subscriber: Optional[Any] = data.get("premium_subscriber", MISSING)
def is_bot_managed(self) -> bool:
""":class:`bool`: Whether the role is associated with a bot."""
@@ -96,12 +96,12 @@ class RoleTags:
def __repr__(self) -> str:
return (
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} '
f'premium_subscriber={self.is_premium_subscriber()}>'
f"<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} "
f"premium_subscriber={self.is_premium_subscriber()}>"
)
R = TypeVar('R', bound='Role')
R = TypeVar("R", bound="Role")
class Role(Hashable):
@@ -181,23 +181,23 @@ class Role(Hashable):
"""
__slots__ = (
'id',
'name',
'_permissions',
'_colour',
'position',
'managed',
'mentionable',
'hoist',
'guild',
'tags',
'_state',
"id",
"name",
"_permissions",
"_colour",
"position",
"managed",
"mentionable",
"hoist",
"guild",
"tags",
"_state",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload):
self.guild: Guild = guild
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.id: int = int(data["id"])
self._update(data)
def __str__(self) -> str:
@@ -207,14 +207,14 @@ class Role(Hashable):
return self.id
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'
return f"<Role id={self.id} name={self.name!r}>"
def __lt__(self: R, other: R) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented
if self.guild != other.guild:
raise RuntimeError('cannot compare roles from two different guilds.')
raise RuntimeError("cannot compare roles from two different guilds.")
# the @everyone role is always the lowest role in hierarchy
guild_id = self.guild.id
@@ -246,17 +246,17 @@ class Role(Hashable):
return not r
def _update(self, data: RolePayload):
self.name: str = data['name']
self._permissions: int = int(data.get('permissions', 0))
self.position: int = data.get('position', 0)
self._colour: int = data.get('color', 0)
self.hoist: bool = data.get('hoist', False)
self.managed: bool = data.get('managed', False)
self.mentionable: bool = data.get('mentionable', False)
self.name: str = data["name"]
self._permissions: int = int(data.get("permissions", 0))
self.position: int = data.get("position", 0)
self._colour: int = data.get("color", 0)
self.hoist: bool = data.get("hoist", False)
self.managed: bool = data.get("managed", False)
self.mentionable: bool = data.get("mentionable", False)
self.tags: Optional[RoleTags]
try:
self.tags = RoleTags(data['tags'])
self.tags = RoleTags(data["tags"])
except KeyError:
self.tags = None
@@ -316,7 +316,7 @@ class Role(Hashable):
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention a role."""
return f'<@&{self.id}>'
return f"<@&{self.id}>"
@property
def members(self) -> List[Member]:
@@ -420,21 +420,21 @@ class Role(Hashable):
if colour is not MISSING:
if isinstance(colour, int):
payload['color'] = colour
payload["color"] = colour
else:
payload['color'] = colour.value
payload["color"] = colour.value
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if permissions is not MISSING:
payload['permissions'] = permissions.value
payload["permissions"] = permissions.value
if hoist is not MISSING:
payload['hoist'] = hoist
payload["hoist"] = hoist
if mentionable is not MISSING:
payload['mentionable'] = mentionable
payload["mentionable"] = mentionable
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
return Role(guild=self.guild, data=data, state=self._state)

View File

@@ -50,11 +50,11 @@ if TYPE_CHECKING:
from .activity import BaseActivity
from .enums import Status
EI = TypeVar('EI', bound='EventItem')
EI = TypeVar("EI", bound="EventItem")
__all__ = (
'AutoShardedClient',
'ShardInfo',
"AutoShardedClient",
"ShardInfo",
)
_log = logging.getLogger(__name__)
@@ -70,11 +70,11 @@ class EventType:
class EventItem:
__slots__ = ('type', 'shard', 'error')
__slots__ = ("type", "shard", "error")
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None:
def __init__(self, etype: int, shard: Optional["Shard"], error: Optional[Exception]) -> None:
self.type: int = etype
self.shard: Optional['Shard'] = shard
self.shard: Optional["Shard"] = shard
self.error: Optional[Exception] = error
def __lt__(self: EI, other: EI) -> bool:
@@ -129,11 +129,11 @@ class Shard:
async def disconnect(self) -> None:
await self.close()
self._dispatch('shard_disconnect', self.id)
self._dispatch("shard_disconnect", self.id)
async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
if not self._reconnect:
self._queue_put(EventItem(EventType.close, self, e))
return
@@ -156,7 +156,7 @@ class Shard:
return
retry = self._backoff.delay()
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
_log.error("Attempting a reconnect for shard ID %s in %.2fs", self.id, retry, exc_info=e)
await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e))
@@ -179,9 +179,9 @@ class Shard:
async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task()
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
_log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
try:
coro = DiscordWebSocket.from_client(
self._client,
@@ -231,7 +231,7 @@ class ShardInfo:
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
"""
__slots__ = ('_parent', 'id', 'shard_count')
__slots__ = ("_parent", "id", "shard_count")
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent: Shard = parent
@@ -321,15 +321,15 @@ class AutoShardedClient(Client):
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None)
kwargs.pop("shard_id", None)
self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
if self.shard_count is None:
raise ClientException('When passing manual shard_ids, you must provide a shard_count.')
raise ClientException("When passing manual shard_ids, you must provide a shard_count.")
elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.')
raise ClientException("shard_ids parameter must be a list or a tuple.")
# instead of a single websocket, we have multiple
# the key is the shard_id
@@ -363,7 +363,7 @@ class AutoShardedClient(Client):
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
"""
if not self.__shards:
return float('nan')
return float("nan")
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
@@ -393,7 +393,7 @@ class AutoShardedClient(Client):
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
_log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
@@ -503,10 +503,10 @@ class AutoShardedClient(Client):
"""
if status is None:
status_value = 'online'
status_value = "online"
status_enum = Status.online
elif status is Status.offline:
status_value = 'invisible'
status_value = "invisible"
status_enum = Status.offline
else:
status_enum = status

View File

@@ -31,9 +31,7 @@ from .mixins import Hashable
from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum
__all__ = (
'StageInstance',
)
__all__ = ("StageInstance",)
if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload
@@ -82,14 +80,14 @@ class StageInstance(Hashable):
"""
__slots__ = (
'_state',
'id',
'guild',
'channel_id',
'topic',
'privacy_level',
'discoverable_disabled',
'_cs_channel',
"_state",
"id",
"guild",
"channel_id",
"topic",
"privacy_level",
"discoverable_disabled",
"_cs_channel",
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
@@ -98,16 +96,16 @@ class StageInstance(Hashable):
self._update(data)
def _update(self, data: StageInstancePayload):
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
self.id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.topic: str = data["topic"]
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data["privacy_level"])
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
return f"<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>"
@cached_slot_property('_cs_channel')
@cached_slot_property("_cs_channel")
def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None
@@ -116,7 +114,9 @@ class StageInstance(Hashable):
def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
async def edit(
self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
) -> None:
"""|coro|
Edits the stage instance.
@@ -146,13 +146,13 @@ class StageInstance(Hashable):
payload = {}
if topic is not MISSING:
payload['topic'] = topic
payload["topic"] = topic
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
raise InvalidArgument("privacy_level field must be of type PrivacyLevel")
payload['privacy_level'] = privacy_level.value
payload["privacy_level"] = privacy_level.value
if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)

File diff suppressed because it is too large Load Diff

View File

@@ -33,11 +33,11 @@ from .errors import InvalidData
from .enums import StickerType, StickerFormatType, try_enum
__all__ = (
'StickerPack',
'StickerItem',
'Sticker',
'StandardSticker',
'GuildSticker',
"StickerPack",
"StickerItem",
"Sticker",
"StandardSticker",
"GuildSticker",
)
if TYPE_CHECKING:
@@ -102,15 +102,15 @@ class StickerPack(Hashable):
"""
__slots__ = (
'_state',
'id',
'stickers',
'name',
'sku_id',
'cover_sticker_id',
'cover_sticker',
'description',
'_banner',
"_state",
"id",
"stickers",
"name",
"sku_id",
"cover_sticker_id",
"cover_sticker",
"description",
"_banner",
)
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
@@ -118,15 +118,17 @@ class StickerPack(Hashable):
self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data['id'])
stickers = data['stickers']
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
self.name: str = data['name']
self.sku_id: int = int(data['sku_id'])
self.cover_sticker_id: int = int(data['cover_sticker_id'])
self.id: int = int(data["id"])
stickers = data["stickers"]
self.stickers: List[StandardSticker] = [
StandardSticker(state=self._state, data=sticker) for sticker in stickers
]
self.name: str = data["name"]
self.sku_id: int = int(data["sku_id"])
self.cover_sticker_id: int = int(data["cover_sticker_id"])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data['description']
self._banner: int = int(data['banner_asset_id'])
self.description: str = data["description"]
self._banner: int = int(data["banner_asset_id"])
@property
def banner(self) -> Asset:
@@ -134,7 +136,7 @@ class StickerPack(Hashable):
return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str:
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
return f"<StickerPack id={self.id} name={self.name!r} description={self.description!r}>"
def __str__(self) -> str:
return self.name
@@ -205,17 +207,17 @@ class StickerItem(_StickerTag):
The URL for the sticker's image.
"""
__slots__ = ('_state', 'name', 'id', 'format', 'url')
__slots__ = ("_state", "name", "id", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state
self.name: str = data['name']
self.id: int = int(data['id'])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
self.name: str = data["name"]
self.id: int = int(data["id"])
self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
return f"<StickerItem id={self.id} name={self.name!r} format={self.format}>"
def __str__(self) -> str:
return self.name
@@ -236,7 +238,7 @@ class StickerItem(_StickerTag):
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore
cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._state, data=data)
@@ -275,21 +277,21 @@ class Sticker(_StickerTag):
The URL for the sticker's image.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
__slots__ = ("_state", "id", "name", "description", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str:
return f'<Sticker id={self.id} name={self.name!r}>'
return f"<Sticker id={self.id} name={self.name!r}>"
def __str__(self) -> str:
return self.name
@@ -337,21 +339,21 @@ class StandardSticker(Sticker):
The sticker's sort order within its pack.
"""
__slots__ = ('sort_value', 'pack_id', 'type', 'tags')
__slots__ = ("sort_value", "pack_id", "type", "tags")
def _from_data(self, data: StandardStickerPayload) -> None:
super()._from_data(data)
self.sort_value: int = data['sort_value']
self.pack_id: int = int(data['pack_id'])
self.sort_value: int = data["sort_value"]
self.pack_id: int = int(data["pack_id"])
self.type: StickerType = StickerType.standard
try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")]
except KeyError:
self.tags = []
def __repr__(self) -> str:
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
return f"<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>"
async def pack(self) -> StickerPack:
"""|coro|
@@ -371,12 +373,12 @@ class StandardSticker(Sticker):
The retrieved sticker pack.
"""
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data['sticker_packs']
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
packs = data["sticker_packs"]
pack = find(lambda d: int(d["id"]) == self.pack_id, packs)
if pack:
return StickerPack(state=self._state, data=pack)
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
raise InvalidData(f"Could not find corresponding sticker pack for {self!r}")
class GuildSticker(Sticker):
@@ -419,21 +421,21 @@ class GuildSticker(Sticker):
The name of a unicode emoji that represents this sticker.
"""
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
__slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild")
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data['available']
self.guild_id: int = int(data['guild_id'])
user = data.get('user')
self.available: bool = data["available"]
self.guild_id: int = int(data["guild_id"])
user = data.get("user")
self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data['tags']
self.emoji: str = data["tags"]
self.type: StickerType = StickerType.guild
def __repr__(self) -> str:
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
return f"<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>"
@cached_slot_property('_cs_guild')
@cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild.
@@ -480,10 +482,10 @@ class GuildSticker(Sticker):
payload: EditGuildSticker = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if description is not MISSING:
payload['description'] = description
payload["description"] = description
if emoji is not MISSING:
try:
@@ -491,9 +493,9 @@ class GuildSticker(Sticker):
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
emoji = emoji.replace(" ", "_")
payload['tags'] = emoji
payload["tags"] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
return GuildSticker(state=self._state, data=data)
@@ -521,7 +523,9 @@ class GuildSticker(Sticker):
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
def _sticker_factory(
sticker_type: Literal[1, 2]
) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type)
if value == StickerType.standard:
return StandardSticker, value

View File

@@ -40,8 +40,8 @@ if TYPE_CHECKING:
)
__all__ = (
'Team',
'TeamMember',
"Team",
"TeamMember",
)
@@ -62,26 +62,26 @@ class Team:
.. versionadded:: 1.3
"""
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members')
__slots__ = ("_state", "id", "name", "_icon", "owner_id", "members")
def __init__(self, state: ConnectionState, data: TeamPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data['icon']
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id')
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']]
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data["icon"]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id")
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]]
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='team')
return Asset._from_icon(self._state, self.id, self._icon, path="team")
@property
def owner(self) -> Optional[TeamMember]:
@@ -130,16 +130,16 @@ class TeamMember(BaseUser):
The membership state of the member (e.g. invited or accepted)
"""
__slots__ = ('team', 'membership_state', 'permissions')
__slots__ = ("team", "membership_state", "permissions")
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload):
self.team: Team = team
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
self.permissions: List[str] = data['permissions']
super().__init__(state=state, data=data['user'])
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data["membership_state"])
self.permissions: List[str] = data["permissions"]
super().__init__(state=state, data=data["user"])
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>"
)

View File

@@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion
from .guild import Guild
__all__ = (
'Template',
)
__all__ = ("Template",)
if TYPE_CHECKING:
import datetime
@@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
raise AttributeError('PartialTemplateState does not support http methods.')
raise AttributeError("PartialTemplateState does not support http methods.")
class _PartialTemplateState:
@@ -84,7 +82,7 @@ class _PartialTemplateState:
return []
def __getattr__(self, attr):
raise AttributeError(f'PartialTemplateState does not support {attr!r}.')
raise AttributeError(f"PartialTemplateState does not support {attr!r}.")
class Template:
@@ -118,16 +116,16 @@ class Template:
"""
__slots__ = (
'code',
'uses',
'name',
'description',
'creator',
'created_at',
'updated_at',
'source_guild',
'is_dirty',
'_state',
"code",
"uses",
"name",
"description",
"creator",
"created_at",
"updated_at",
"source_guild",
"is_dirty",
"_state",
)
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
@@ -135,35 +133,35 @@ class Template:
self._store(data)
def _store(self, data: TemplatePayload) -> None:
self.code: str = data['code']
self.uses: int = data['usage_count']
self.name: str = data['name']
self.description: Optional[str] = data['description']
creator_data = data.get('creator')
self.code: str = data["code"]
self.uses: int = data["usage_count"]
self.name: str = data["name"]
self.description: Optional[str] = data["description"]
creator_data = data.get("creator")
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at'))
self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get("updated_at"))
guild_id = int(data['source_guild_id'])
guild_id = int(data["source_guild_id"])
guild: Optional[Guild] = self._state._get_guild(guild_id)
self.source_guild: Guild
if guild is None:
source_serialised = data['serialized_source_guild']
source_serialised['id'] = guild_id
source_serialised = data["serialized_source_guild"]
source_serialised["id"] = guild_id
state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
self.is_dirty: Optional[bool] = data.get('is_dirty', None)
self.is_dirty: Optional[bool] = data.get("is_dirty", None)
def __repr__(self) -> str:
return (
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}'
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
f"<Template code={self.code!r} uses={self.uses} name={self.name!r}"
f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>"
)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
@@ -279,9 +277,9 @@ class Template:
payload = {}
if name is not MISSING:
payload['name'] = name
payload["name"] = name
if description is not MISSING:
payload['description'] = description
payload["description"] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
return Template(state=self._state, data=data)
@@ -313,4 +311,4 @@ class Template:
.. versionadded:: 2.0
"""
return f'https://discord.new/{self.code}'
return f"https://discord.new/{self.code}"

View File

@@ -35,8 +35,8 @@ from .errors import ClientException
from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = (
'Thread',
'ThreadMember',
"Thread",
"ThreadMember",
)
if TYPE_CHECKING:
@@ -128,25 +128,25 @@ class Thread(Messageable, Hashable):
"""
__slots__ = (
'name',
'id',
'guild',
'_type',
'_state',
'_members',
'owner_id',
'parent_id',
'last_message_id',
'message_count',
'member_count',
'slowmode_delay',
'me',
'locked',
'archived',
'invitable',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
"name",
"id",
"guild",
"_type",
"_state",
"_members",
"owner_id",
"parent_id",
"last_message_id",
"message_count",
"member_count",
"slowmode_delay",
"me",
"locked",
"archived",
"invitable",
"archiver_id",
"auto_archive_duration",
"archive_timestamp",
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
@@ -160,50 +160,50 @@ class Thread(Messageable, Hashable):
def __repr__(self) -> str:
return (
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}'
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>'
f"<Thread id={self.id!r} name={self.name!r} parent={self.parent}"
f" owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>"
)
def __str__(self) -> str:
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self._unroll_metadata(data['thread_metadata'])
self.id = int(data["id"])
self.parent_id = int(data["parent_id"])
self.owner_id = int(data["owner_id"])
self.name = data["name"]
self._type = try_enum(ChannelType, data["type"])
self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self._unroll_metadata(data["thread_metadata"])
try:
member = data['member']
member = data["member"]
except KeyError:
self.me = None
else:
self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
self.archived = data["archived"]
self.archiver_id = _get_as_snowflake(data, "archiver_id")
self.auto_archive_duration = data["auto_archive_duration"]
self.archive_timestamp = parse_time(data["archive_timestamp"])
self.locked = data.get("locked", False)
self.invitable = data.get("invitable", True)
def _update(self, data):
try:
self.name = data['name']
self.name = data["name"]
except KeyError:
pass
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.slowmode_delay = data.get("rate_limit_per_user", 0)
try:
self._unroll_metadata(data['thread_metadata'])
self._unroll_metadata(data["thread_metadata"])
except KeyError:
pass
@@ -225,7 +225,7 @@ class Thread(Messageable, Hashable):
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def members(self) -> List[ThreadMember]:
@@ -275,7 +275,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.category
@property
@@ -295,7 +295,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.category_id
def is_private(self) -> bool:
@@ -352,7 +352,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
raise ClientException("Parent channel not found")
return parent.permissions_for(obj)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@@ -402,7 +402,7 @@ class Thread(Messageable, Hashable):
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
@@ -577,17 +577,17 @@ class Thread(Messageable, Hashable):
"""
payload = {}
if name is not MISSING:
payload['name'] = str(name)
payload["name"] = str(name)
if archived is not MISSING:
payload['archived'] = archived
payload["archived"] = archived
if auto_archive_duration is not MISSING:
payload['auto_archive_duration'] = auto_archive_duration
payload["auto_archive_duration"] = auto_archive_duration
if locked is not MISSING:
payload['locked'] = locked
payload["locked"] = locked
if invitable is not MISSING:
payload['invitable'] = invitable
payload["invitable"] = invitable
if slowmode_delay is not MISSING:
payload['rate_limit_per_user'] = slowmode_delay
payload["rate_limit_per_user"] = slowmode_delay
data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload
@@ -773,12 +773,12 @@ class ThreadMember(Hashable):
"""
__slots__ = (
'id',
'thread_id',
'joined_at',
'flags',
'_state',
'parent',
"id",
"thread_id",
"joined_at",
"flags",
"_state",
"parent",
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
@@ -787,24 +787,60 @@ class ThreadMember(Hashable):
self._from_data(data)
def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
return f"<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>"
def _from_data(self, data: ThreadMemberPayload):
try:
self.id = int(data['user_id'])
self.id = int(data["user_id"])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
try:
self.thread_id = int(data['id'])
self.thread_id = int(data["id"])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
self.joined_at = parse_time(data["join_timestamp"])
self.flags = data["flags"]
@property
def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to."""
return self.parent
async def fetch_member(self) -> Member:
"""|coro|
Retrieves a :class:`Member` from the ThreadMember object.
.. note::
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
Raises
-------
Forbidden
You do not have access to the guild.
HTTPException
Fetching the member failed.
Returns
--------
:class:`Member`
The member.
"""
return await self.thread.guild.fetch_member(self.id)
def get_member(self) -> Optional[Member]:
"""
Get the :class:`Member` from cache for the ThreadMember object.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
return self.thread.guild.get_member(self.id)

View File

@@ -29,7 +29,7 @@ from .user import PartialUser
from .snowflake import Snowflake
StatusType = Literal['idle', 'dnd', 'online', 'offline']
StatusType = Literal["idle", "dnd", "online", "offline"]
class PartialPresenceUpdate(TypedDict):

View File

@@ -30,6 +30,7 @@ from .user import User
from .team import Team
from .snowflake import Snowflake
class BaseAppInfo(TypedDict):
id: Snowflake
name: str
@@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict):
summary: str
description: str
class _AppInfoOptional(TypedDict, total=False):
team: Team
guild_id: Snowflake
@@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False):
hook: bool
max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str]
owner: User
bot_public: bool
bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str]
cover_image: str
@@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
max_participants: int
flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass

View File

@@ -84,31 +84,40 @@ AuditLogEvent = Literal[
class _AuditLogChange_Str(TypedDict):
key: Literal[
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
"name",
"description",
"preferred_locale",
"vanity_url_code",
"topic",
"code",
"allow",
"deny",
"permissions",
"tags",
]
new_value: str
old_value: str
class _AuditLogChange_AssetHash(TypedDict):
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
key: Literal["icon_hash", "splash_hash", "discovery_splash_hash", "banner_hash", "avatar_hash", "asset"]
new_value: str
old_value: str
class _AuditLogChange_Snowflake(TypedDict):
key: Literal[
'id',
'owner_id',
'afk_channel_id',
'rules_channel_id',
'public_updates_channel_id',
'widget_channel_id',
'system_channel_id',
'application_id',
'channel_id',
'inviter_id',
'guild_id',
"id",
"owner_id",
"afk_channel_id",
"rules_channel_id",
"public_updates_channel_id",
"widget_channel_id",
"system_channel_id",
"application_id",
"channel_id",
"inviter_id",
"guild_id",
]
new_value: Snowflake
old_value: Snowflake
@@ -116,20 +125,20 @@ class _AuditLogChange_Snowflake(TypedDict):
class _AuditLogChange_Bool(TypedDict):
key: Literal[
'widget_enabled',
'nsfw',
'hoist',
'mentionable',
'temporary',
'deaf',
'mute',
'nick',
'enabled_emoticons',
'region',
'rtc_region',
'available',
'archived',
'locked',
"widget_enabled",
"nsfw",
"hoist",
"mentionable",
"temporary",
"deaf",
"mute",
"nick",
"enabled_emoticons",
"region",
"rtc_region",
"available",
"archived",
"locked",
]
new_value: bool
old_value: bool
@@ -137,72 +146,72 @@ class _AuditLogChange_Bool(TypedDict):
class _AuditLogChange_Int(TypedDict):
key: Literal[
'afk_timeout',
'prune_delete_days',
'position',
'bitrate',
'rate_limit_per_user',
'color',
'max_uses',
'max_age',
'user_limit',
'auto_archive_duration',
'default_auto_archive_duration',
"afk_timeout",
"prune_delete_days",
"position",
"bitrate",
"rate_limit_per_user",
"color",
"max_uses",
"max_age",
"user_limit",
"auto_archive_duration",
"default_auto_archive_duration",
]
new_value: int
old_value: int
class _AuditLogChange_ListRole(TypedDict):
key: Literal['$add', '$remove']
key: Literal["$add", "$remove"]
new_value: List[Role]
old_value: List[Role]
class _AuditLogChange_MFALevel(TypedDict):
key: Literal['mfa_level']
key: Literal["mfa_level"]
new_value: MFALevel
old_value: MFALevel
class _AuditLogChange_VerificationLevel(TypedDict):
key: Literal['verification_level']
key: Literal["verification_level"]
new_value: VerificationLevel
old_value: VerificationLevel
class _AuditLogChange_ExplicitContentFilter(TypedDict):
key: Literal['explicit_content_filter']
key: Literal["explicit_content_filter"]
new_value: ExplicitContentFilterLevel
old_value: ExplicitContentFilterLevel
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
key: Literal['default_message_notifications']
key: Literal["default_message_notifications"]
new_value: DefaultMessageNotificationLevel
old_value: DefaultMessageNotificationLevel
class _AuditLogChange_ChannelType(TypedDict):
key: Literal['type']
key: Literal["type"]
new_value: ChannelType
old_value: ChannelType
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
key: Literal['expire_behavior']
key: Literal["expire_behavior"]
new_value: IntegrationExpireBehavior
old_value: IntegrationExpireBehavior
class _AuditLogChange_VideoQualityMode(TypedDict):
key: Literal['video_quality_mode']
key: Literal["video_quality_mode"]
new_value: VideoQualityMode
old_value: VideoQualityMode
class _AuditLogChange_Overwrites(TypedDict):
key: Literal['permission_overwrites']
key: Literal["permission_overwrites"]
new_value: List[PermissionOverwrite]
old_value: List[PermissionOverwrite]
@@ -232,7 +241,7 @@ class AuditEntryInfo(TypedDict):
message_id: Snowflake
count: str
id: Snowflake
type: Literal['0', '1']
type: Literal["0", "1"]
role_name: str

View File

@@ -24,49 +24,60 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, TypedDict
class _EmbedFooterOptional(TypedDict, total=False):
icon_url: str
proxy_icon_url: str
class EmbedFooter(_EmbedFooterOptional):
text: str
class _EmbedFieldOptional(TypedDict, total=False):
inline: bool
class EmbedField(_EmbedFieldOptional):
name: str
value: str
class EmbedThumbnail(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedVideo(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedImage(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedProvider(TypedDict, total=False):
name: str
url: str
class EmbedAuthor(TypedDict, total=False):
name: str
url: str
icon_url: str
proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal["rich", "image", "video", "gifv", "article", "link"]
class Embed(TypedDict, total=False):
title: str

View File

@@ -75,28 +75,28 @@ VerificationLevel = Literal[0, 1, 2, 3, 4]
NSFWLevel = Literal[0, 1, 2, 3]
PremiumTier = Literal[0, 1, 2, 3]
GuildFeature = Literal[
'ANIMATED_ICON',
'BANNER',
'COMMERCE',
'COMMUNITY',
'DISCOVERABLE',
'FEATURABLE',
'INVITE_SPLASH',
'MEMBER_VERIFICATION_GATE_ENABLED',
'MONETIZATION_ENABLED',
'MORE_EMOJI',
'MORE_STICKERS',
'NEWS',
'PARTNERED',
'PREVIEW_ENABLED',
'PRIVATE_THREADS',
'SEVEN_DAY_THREAD_ARCHIVE',
'THREE_DAY_THREAD_ARCHIVE',
'TICKETED_EVENTS_ENABLED',
'VANITY_URL',
'VERIFIED',
'VIP_REGIONS',
'WELCOME_SCREEN_ENABLED',
"ANIMATED_ICON",
"BANNER",
"COMMERCE",
"COMMUNITY",
"DISCOVERABLE",
"FEATURABLE",
"INVITE_SPLASH",
"MEMBER_VERIFICATION_GATE_ENABLED",
"MONETIZATION_ENABLED",
"MORE_EMOJI",
"MORE_STICKERS",
"NEWS",
"PARTNERED",
"PREVIEW_ENABLED",
"PRIVATE_THREADS",
"SEVEN_DAY_THREAD_ARCHIVE",
"THREE_DAY_THREAD_ARCHIVE",
"TICKETED_EVENTS_ENABLED",
"VANITY_URL",
"VERIFIED",
"VIP_REGIONS",
"WELCOME_SCREEN_ENABLED",
]

View File

@@ -56,7 +56,7 @@ class PartialIntegration(TypedDict):
account: IntegrationAccount
IntegrationType = Literal['twitch', 'youtube', 'discord']
IntegrationType = Literal["twitch", "youtube", "discord"]
class BaseIntegration(PartialIntegration):

View File

@@ -39,6 +39,7 @@ if TYPE_CHECKING:
ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
type: ApplicationCommandType
@@ -222,15 +223,12 @@ class MessageInteraction(TypedDict):
user: User
class _EditApplicationCommandOptional(TypedDict, total=False):
description: str
options: Optional[List[ApplicationCommandOption]]
type: ApplicationCommandType
default_permission: bool
class EditApplicationCommand(_EditApplicationCommandOptional):
name: str
default_permission: bool

View File

@@ -53,6 +53,7 @@ class _AttachmentOptional(TypedDict, total=False):
height: Optional[int]
width: Optional[int]
content_type: str
ephemeral: bool
spoiler: bool
@@ -128,7 +129,7 @@ class Message(_MessageOptional):
type: MessageType
AllowedMentionType = Literal['roles', 'users', 'everyone']
AllowedMentionType = Literal["roles", "users", "everyone"]
class AllowedMentions(TypedDict):

View File

@@ -85,3 +85,14 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False):
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake
guild_id: Snowflake
class _TypingEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class TypingEvent(_TypingEventOptional):
channel_id: Snowflake
user_id: Snowflake
timestamp: int

View File

@@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional
from .user import PartialUser
from .snowflake import Snowflake
class TeamMember(TypedDict):
user: PartialUser
membership_state: int
permissions: List[str]
team_id: Snowflake
class Team(TypedDict):
id: Snowflake
name: str

View File

@@ -27,7 +27,7 @@ from .snowflake import Snowflake
from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
SupportedModes = Literal["xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"]
class _PartialVoiceStateOptional(TypedDict, total=False):

View File

@@ -35,16 +35,16 @@ from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent
__all__ = (
'Button',
'button',
"Button",
"button",
)
if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
B = TypeVar('B', bound='Button')
V = TypeVar('V', bound='View', covariant=True)
B = TypeVar("B", bound="Button")
V = TypeVar("V", bound="View", covariant=True)
class Button(Item[V]):
@@ -76,12 +76,12 @@ class Button(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
'style',
'url',
'disabled',
'label',
'emoji',
'row',
"style",
"url",
"disabled",
"label",
"emoji",
"row",
)
def __init__(
@@ -97,7 +97,7 @@ class Button(Item[V]):
):
super().__init__()
if custom_id is not None and url is not None:
raise TypeError('cannot mix both url and custom_id with Button')
raise TypeError("cannot mix both url and custom_id with Button")
self._provided_custom_id = custom_id is not None
if url is None and custom_id is None:
@@ -112,7 +112,7 @@ class Button(Item[V]):
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
@@ -145,7 +145,7 @@ class Button(Item[V]):
@custom_id.setter
def custom_id(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str')
raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value
@@ -157,7 +157,7 @@ class Button(Item[V]):
@url.setter
def url(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str')
raise TypeError("url must be None or str")
self._underlying.url = value
@property
@@ -191,7 +191,7 @@ class Button(Item[V]):
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
raise TypeError(f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead")
else:
self._underlying.emoji = None
@@ -273,17 +273,17 @@ def button(
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
raise TypeError("button function must be a coroutine function")
func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = {
'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'row': row,
"style": style,
"custom_id": custom_id,
"url": None,
"disabled": disabled,
"label": label,
"emoji": emoji,
"row": row,
}
return func

View File

@@ -28,17 +28,15 @@ from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECK
from ..interactions import Interaction
__all__ = (
'Item',
)
__all__ = ("Item",)
if TYPE_CHECKING:
from ..enums import ComponentType
from .view import View
from ..components import Component
I = TypeVar('I', bound='Item')
V = TypeVar('V', bound='View', covariant=True)
I = TypeVar("I", bound="Item")
V = TypeVar("V", bound="View", covariant=True)
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
@@ -53,7 +51,7 @@ class Item(Generic[V]):
.. versionadded:: 2.0
"""
__item_repr_attributes__: Tuple[str, ...] = ('row',)
__item_repr_attributes__: Tuple[str, ...] = ("row",)
def __init__(self):
self._view: Optional[V] = None
@@ -91,8 +89,8 @@ class Item(Generic[V]):
return self._provided_custom_id
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
return f'<{self.__class__.__name__} {attrs}>'
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__)
return f"<{self.__class__.__name__} {attrs}>"
@property
def row(self) -> Optional[int]:
@@ -105,7 +103,7 @@ class Item(Generic[V]):
elif 5 > value >= 0:
self._row = value
else:
raise ValueError('row cannot be negative or greater than or equal to 5')
raise ValueError("row cannot be negative or greater than or equal to 5")
@property
def width(self) -> int:

View File

@@ -39,8 +39,8 @@ from ..components import (
)
__all__ = (
'Select',
'select',
"Select",
"select",
)
if TYPE_CHECKING:
@@ -50,8 +50,8 @@ if TYPE_CHECKING:
ComponentInteractionData,
)
S = TypeVar('S', bound='Select')
V = TypeVar('V', bound='View', covariant=True)
S = TypeVar("S", bound="Select")
V = TypeVar("V", bound="View", covariant=True)
class Select(Item[V]):
@@ -89,11 +89,11 @@ class Select(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
)
def __init__(
@@ -131,7 +131,7 @@ class Select(Item[V]):
@custom_id.setter
def custom_id(self, value: str):
if not isinstance(value, str):
raise TypeError('custom_id must be None or str')
raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value
@@ -143,7 +143,7 @@ class Select(Item[V]):
@placeholder.setter
def placeholder(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str')
raise TypeError("placeholder must be None or str")
self._underlying.placeholder = value
@@ -173,9 +173,9 @@ class Select(Item[V]):
@options.setter
def options(self, value: List[SelectOption]):
if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption')
raise TypeError("options must be a list of SelectOption")
if not all(isinstance(obj, SelectOption) for obj in value):
raise TypeError('all list items must subclass SelectOption')
raise TypeError("all list items must subclass SelectOption")
self._underlying.options = value
@@ -224,7 +224,6 @@ class Select(Item[V]):
default=default,
)
self.append_option(option)
def append_option(self, option: SelectOption):
@@ -242,7 +241,7 @@ class Select(Item[V]):
"""
if len(self._underlying.options) > 25:
raise ValueError('maximum number of options already provided')
raise ValueError("maximum number of options already provided")
self._underlying.options.append(option)
@@ -272,7 +271,7 @@ class Select(Item[V]):
def refresh_state(self, interaction: Interaction) -> None:
data: ComponentInteractionData = interaction.data # type: ignore
self._selected_values = data.get('values', [])
self._selected_values = data.get("values", [])
@classmethod
def from_component(cls: Type[S], component: SelectMenu) -> S:
@@ -340,17 +339,17 @@ def select(
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
raise TypeError("select function must be a coroutine function")
func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder,
'custom_id': custom_id,
'row': row,
'min_values': min_values,
'max_values': max_values,
'options': options,
'disabled': disabled,
"placeholder": placeholder,
"custom_id": custom_id,
"row": row,
"min_values": min_values,
"max_values": max_values,
"options": options,
"disabled": disabled,
}
return func

View File

@@ -41,9 +41,7 @@ from ..components import (
SelectMenu as SelectComponent,
)
__all__ = (
'View',
)
__all__ = ("View",)
if TYPE_CHECKING:
@@ -74,9 +72,7 @@ def _component_to_item(component: Component) -> Item:
class _ViewWeights:
__slots__ = (
'weights',
)
__slots__ = ("weights",)
def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0]
@@ -92,13 +88,13 @@ class _ViewWeights:
if weight + item.width <= 5:
return index
raise ValueError('could not find open space for item')
raise ValueError("could not find open space for item")
def add_item(self, item: Item) -> None:
if item.row is not None:
total = self.weights[item.row] + item.width
if total > 5:
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)")
self.weights[item.row] = total
item._rendered_row = item.row
else:
@@ -144,11 +140,11 @@ class View:
children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
if hasattr(member, '__discord_ui_model_type__'):
if hasattr(member, "__discord_ui_model_type__"):
children.append(member)
if len(children) > 25:
raise TypeError('View cannot have more than 25 children')
raise TypeError("View cannot have more than 25 children")
cls.__view_children_items__ = children
@@ -171,7 +167,7 @@ class View:
self.__stopped: asyncio.Future[bool] = loop.create_future()
def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>"
async def __timeout_task_impl(self) -> None:
while True:
@@ -203,8 +199,8 @@ class View:
components.append(
{
'type': 1,
'components': children,
"type": 1,
"components": children,
}
)
@@ -261,10 +257,10 @@ class View:
"""
if len(self.children) > 25:
raise ValueError('maximum number of children exceeded')
raise ValueError("maximum number of children exceeded")
if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}')
raise TypeError(f"expected Item not {item.__class__!r}")
self.__weights.add_item(item)
@@ -344,7 +340,7 @@ class View:
interaction: :class:`~discord.Interaction`
The interaction that led to the failure.
"""
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
async def _scheduled_task(self, item: Item, interaction: Interaction):
@@ -357,7 +353,7 @@ class View:
return
await item.callback(interaction)
if not interaction.response._responded:
if not interaction.response.is_done():
await interaction.response.defer()
except Exception as e:
return await self.on_error(e, item, interaction)
@@ -377,13 +373,13 @@ class View:
return
self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
asyncio.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}")
def _dispatch_item(self, item: Item, interaction: Interaction):
if self.__stopped.done():
return
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
asyncio.create_task(self._scheduled_task(item, interaction), name=f"discord-ui-view-dispatch-{self.id}")
def refresh(self, components: List[Component]):
# This is pretty hacky at the moment

View File

@@ -45,11 +45,11 @@ if TYPE_CHECKING:
__all__ = (
'User',
'ClientUser',
"User",
"ClientUser",
)
BU = TypeVar('BU', bound='BaseUser')
BU = TypeVar("BU", bound="BaseUser")
class _UserTag:
@@ -59,16 +59,16 @@ class _UserTag:
class BaseUser(_UserTag):
__slots__ = (
'name',
'id',
'discriminator',
'_avatar',
'_banner',
'_accent_colour',
'bot',
'system',
'_public_flags',
'_state',
"name",
"id",
"discriminator",
"_avatar",
"_banner",
"_accent_colour",
"bot",
"system",
"_public_flags",
"_state",
)
if TYPE_CHECKING:
@@ -80,7 +80,7 @@ class BaseUser(_UserTag):
_state: ConnectionState
_avatar: Optional[str]
_banner: Optional[str]
_accent_colour: Optional[str]
_accent_colour: Optional[int]
_public_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
@@ -94,7 +94,7 @@ class BaseUser(_UserTag):
)
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
return f"{self.name}#{self.discriminator}"
def __int__(self) -> int:
return self.id
@@ -109,15 +109,15 @@ class BaseUser(_UserTag):
return self.id >> 22
def _update(self, data: UserPayload) -> None:
self.name = data['username']
self.id = int(data['id'])
self.discriminator = data['discriminator']
self._avatar = data['avatar']
self._banner = data.get('banner', None)
self._accent_colour = data.get('accent_color', None)
self._public_flags = data.get('public_flags', 0)
self.bot = data.get('bot', False)
self.system = data.get('system', False)
self.name = data["username"]
self.id = int(data["id"])
self.discriminator = data["discriminator"]
self._avatar = data["avatar"]
self._banner = data.get("banner", None)
self._accent_colour = data.get("accent_color", None)
self._public_flags = data.get("public_flags", 0)
self.bot = data.get("bot", False)
self.system = data.get("system", False)
@classmethod
def _copy(cls: Type[BU], user: BU) -> BU:
@@ -137,11 +137,11 @@ class BaseUser(_UserTag):
def _to_minimal_user_json(self) -> Dict[str, Any]:
return {
'username': self.name,
'id': self.id,
'avatar': self._avatar,
'discriminator': self.discriminator,
'bot': self.bot,
"username": self.name,
"id": self.id,
"avatar": self._avatar,
"discriminator": self.discriminator,
"bot": self.bot,
}
@property
@@ -240,7 +240,7 @@ class BaseUser(_UserTag):
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>'
return f"<@{self.id}>"
@property
def created_at(self) -> datetime:
@@ -324,7 +324,7 @@ class ClientUser(BaseUser):
Specifies if the user has MFA turned on and working.
"""
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
__slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__")
if TYPE_CHECKING:
verified: bool
@@ -337,17 +337,17 @@ class ClientUser(BaseUser):
def __repr__(self) -> str:
return (
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
f"<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>"
)
def _update(self, data: UserPayload) -> None:
super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False)
self.locale = data.get('locale')
self._flags = data.get('flags', 0)
self.mfa_enabled = data.get('mfa_enabled', False)
self.verified = data.get("verified", False)
self.locale = data.get("locale")
self._flags = data.get("flags", 0)
self.mfa_enabled = data.get("mfa_enabled", False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
"""|coro|
@@ -388,10 +388,10 @@ class ClientUser(BaseUser):
"""
payload: Dict[str, Any] = {}
if username is not MISSING:
payload['username'] = username
payload["username"] = username
if avatar is not MISSING:
payload['avatar'] = _bytes_to_base64_data(avatar)
payload["avatar"] = _bytes_to_base64_data(avatar)
data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data)
@@ -436,14 +436,14 @@ class User(BaseUser, discord.abc.Messageable):
Specifies if the user is a system user (i.e. represents Discord officially).
"""
__slots__ = ('_stored',)
__slots__ = ("_stored",)
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
self._stored: bool = False
def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
return f"<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>"
def __del__(self) -> None:
try:

View File

@@ -72,18 +72,18 @@ else:
__all__ = (
'oauth_url',
'snowflake_time',
'time_snowflake',
'find',
'get',
'sleep_until',
'utcnow',
'remove_markdown',
'escape_markdown',
'escape_mentions',
'as_chunks',
'format_dt',
"oauth_url",
"snowflake_time",
"time_snowflake",
"find",
"get",
"sleep_until",
"utcnow",
"remove_markdown",
"escape_markdown",
"escape_mentions",
"as_chunks",
"format_dt",
)
DISCORD_EPOCH = 1420070400000
@@ -97,7 +97,7 @@ class _MissingSentinel:
return False
def __repr__(self):
return '...'
return "..."
MISSING: Any = _MissingSentinel()
@@ -106,7 +106,7 @@ MISSING: Any = _MissingSentinel()
class _cached_property:
def __init__(self, function):
self.function = function
self.__doc__ = getattr(function, '__doc__')
self.__doc__ = getattr(function, "__doc__")
def __get__(self, instance, owner):
if instance is None:
@@ -131,15 +131,14 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Mapping[str, Any]
P = ParamSpec('P')
P = ParamSpec("P")
else:
cached_property = _cached_property
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
_Iter = Union[Iterator[T], AsyncIterator[T]]
@@ -147,7 +146,7 @@ class CachedSlotProperty(Generic[T, T_co]):
def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.name = name
self.function = function
self.__doc__ = getattr(function, '__doc__')
self.__doc__ = getattr(function, "__doc__")
@overload
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
@@ -177,7 +176,7 @@ class classproperty(Generic[T_co]):
return self.fget(owner)
def __set__(self, instance, value) -> None:
raise AttributeError('cannot set attribute')
raise AttributeError("cannot set attribute")
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
@@ -249,14 +248,14 @@ def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Call
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.simplefilter("always", DeprecationWarning) # turn off filter
if instead:
fmt = "{0.__name__} is deprecated, use {1} instead."
else:
fmt = '{0.__name__} is deprecated.'
fmt = "{0.__name__} is deprecated."
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
warnings.simplefilter('default', DeprecationWarning) # reset filter
warnings.simplefilter("default", DeprecationWarning) # reset filter
return func(*args, **kwargs)
return decorated
@@ -301,18 +300,18 @@ def oauth_url(
:class:`str`
The OAuth2 URL for inviting the bot into guilds.
"""
url = f'https://discord.com/oauth2/authorize?client_id={client_id}'
url += '&scope=' + '+'.join(scopes or ('bot',))
url = f"https://discord.com/oauth2/authorize?client_id={client_id}"
url += "&scope=" + "+".join(scopes or ("bot",))
if permissions is not MISSING:
url += f'&permissions={permissions.value}'
url += f"&permissions={permissions.value}"
if guild is not MISSING:
url += f'&guild_id={guild.id}'
url += f"&guild_id={guild.id}"
if redirect_uri is not MISSING:
from urllib.parse import urlencode
url += '&response_type=code&' + urlencode({'redirect_uri': redirect_uri})
url += "&response_type=code&" + urlencode({"redirect_uri": redirect_uri})
if disable_guild_select:
url += '&disable_guild_select=true'
url += "&disable_guild_select=true"
return url
@@ -435,13 +434,13 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace('__', '.'))
pred = attrget(k.replace("__", "."))
for elem in iterable:
if pred(elem) == v:
return elem
return None
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()]
for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
@@ -463,46 +462,46 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
def _get_mime_type_for_image(data: bytes):
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
return 'image/png'
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
return 'image/jpeg'
elif data.startswith((b'\x47\x49\x46\x38\x37\x61', b'\x47\x49\x46\x38\x39\x61')):
return 'image/gif'
elif data.startswith(b'RIFF') and data[8:12] == b'WEBP':
return 'image/webp'
if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"):
return "image/png"
elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
return "image/jpeg"
elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")):
return "image/gif"
elif data.startswith(b"RIFF") and data[8:12] == b"WEBP":
return "image/webp"
else:
raise InvalidArgument('Unsupported image type given')
raise InvalidArgument("Unsupported image type given")
def _bytes_to_base64_data(data: bytes) -> str:
fmt = 'data:{mime};base64,{data}'
fmt = "data:{mime};base64,{data}"
mime = _get_mime_type_for_image(data)
b64 = b64encode(data).decode('ascii')
b64 = b64encode(data).decode("ascii")
return fmt.format(mime=mime, data=b64)
if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode('utf-8')
return orjson.dumps(obj).decode("utf-8")
_from_json = orjson.loads # type: ignore
else:
def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
_from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
reset_after: Optional[str] = request.headers.get("X-Ratelimit-Reset-After")
if use_clock or not reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc)
return (reset - now).total_seconds()
else:
return float(reset_after)
@@ -612,7 +611,7 @@ class SnowflakeList(array.array):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
i = bisect_left(self, element)
@@ -627,7 +626,7 @@ class SnowflakeList(array.array):
return i != len(self) and self[i] == element
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
_IS_ASCII = re.compile(r"^[\x00-\x7f]+$")
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
@@ -636,7 +635,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
if match:
return match.endpos
UNICODE_WIDE_CHAR_TYPE = 'WFA'
UNICODE_WIDE_CHAR_TYPE = "WFA"
func = unicodedata.east_asian_width
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
@@ -660,7 +659,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite):
return invite.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)"
m = re.match(rx, invite)
if m:
return m.group(1)
@@ -688,22 +687,24 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template):
return code.code
else:
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)"
m = re.match(rx, code)
if m:
return m.group(1)
return code
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|'))
_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?<!\{0})\{0})))".format(c) for c in ("*", "`", "_", "~", "|"))
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
_MARKDOWN_ESCAPE_COMMON = r"^>(?:>>)?\s|\[.+\]\(.+\)"
_MARKDOWN_ESCAPE_REGEX = re.compile(fr'(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})', re.MULTILINE)
_MARKDOWN_ESCAPE_REGEX = re.compile(
fr"(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", re.MULTILINE
)
_URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])'
_URL_REGEX = r"(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])"
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
_MARKDOWN_STOCK_REGEX = fr"(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})"
def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
@@ -732,11 +733,11 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
def replacement(match):
groupdict = match.groupdict()
return groupdict.get('url', '')
return groupdict.get("url", "")
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})'
regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE)
@@ -769,18 +770,18 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
def replacement(match):
groupdict = match.groupdict()
is_url = groupdict.get('url')
is_url = groupdict.get("url")
if is_url:
return is_url
return '\\' + groupdict['markdown']
return "\\" + groupdict["markdown"]
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})'
regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE)
else:
text = re.sub(r'\\', r'\\\\', text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
text = re.sub(r"\\", r"\\\\", text)
return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text)
def escape_mentions(text: str) -> str:
@@ -806,7 +807,7 @@ def escape_mentions(text: str) -> str:
:class:`str`
The text with the mentions removed.
"""
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text)
return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text)
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
@@ -870,7 +871,7 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
A new iterator which yields chunks of a given size.
"""
if max_size <= 0:
raise ValueError('Chunk sizes must be greater than 0.')
raise ValueError("Chunk sizes must be greater than 0.")
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
@@ -916,11 +917,11 @@ def evaluate_annotation(
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
if hasattr(tp, "__args__"):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
@@ -938,10 +939,12 @@ def evaluate_annotation(
implicit_str = False
is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
evaluated_args = tuple(
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.')
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")
if evaluated_args == args:
return tp
@@ -971,7 +974,7 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache)
TimestampStyle = Literal['f', 'F', 'd', 'D', 't', 'T', 'R']
TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
@@ -1015,5 +1018,5 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
The formatted string.
"""
if style is None:
return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>'
return f"<t:{int(dt.timestamp())}>"
return f"<t:{int(dt.timestamp())}:{style}>"

View File

@@ -72,20 +72,20 @@ has_nacl: bool
try:
import nacl.secret # type: ignore
has_nacl = True
except ImportError:
has_nacl = False
__all__ = (
'VoiceProtocol',
'VoiceClient',
"VoiceProtocol",
"VoiceClient",
)
_log = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@@ -195,6 +195,7 @@ class VoiceProtocol:
key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
@@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
endpoint_ip: str
voice_port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
@@ -255,18 +256,20 @@ class VoiceClient(VoiceProtocol):
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
"xsalsa20_poly1305_lite",
"xsalsa20_poly1305_suffix",
"xsalsa20_poly1305",
)
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, 'guild', None)
return getattr(self.channel, "guild", None)
@property
def user(self) -> ClientUser:
@@ -283,8 +286,8 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
channel_id = data['channel_id']
self.session_id = data["session_id"]
channel_id = data["channel_id"]
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
@@ -301,20 +304,22 @@ class VoiceClient(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
_log.info('Ignoring extraneous voice server update.')
_log.info("Ignoring extraneous voice server update.")
return
self.token = data.get('token')
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
self.token = data.get("token")
self.server_id = int(data["guild_id"])
endpoint = data.get("endpoint")
if endpoint is None or self.token is None:
_log.warning('Awaiting endpoint... This requires waiting. ' \
'If timeout occurred considering raising the timeout and reconnecting.')
_log.warning(
"Awaiting endpoint... This requires waiting. "
"If timeout occurred considering raising the timeout and reconnecting."
)
return
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
self.endpoint, _, _ = endpoint.rpartition(":")
if self.endpoint.startswith("wss://"):
# Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:]
@@ -335,18 +340,20 @@ class VoiceClient(VoiceProtocol):
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self) -> None:
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
_log.info(
"The voice handshake is being terminated for Channel ID %s (Guild ID %s)", self.channel.id, self.guild.id
)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
_log.info("Starting voice handshake... (connection attempt %d)", self._connections + 1)
self._connections += 1
def finish_handshake(self) -> None:
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
_log.info("Voice handshake complete. Endpoint found %s", self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
@@ -360,7 +367,7 @@ class VoiceClient(VoiceProtocol):
return ws
async def connect(self, *, reconnect: bool, timeout: float) -> None:
_log.info('Connecting to voice...')
_log.info("Connecting to voice...")
self.timeout = timeout
for i in range(5):
@@ -388,7 +395,7 @@ class VoiceClient(VoiceProtocol):
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
_log.exception('Failed to connect to voice... Retrying...')
_log.exception("Failed to connect to voice... Retrying...")
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
@@ -453,14 +460,14 @@ class VoiceClient(VoiceProtocol):
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
_log.info("Disconnecting from voice normally, close code %d.", exc.code)
await self.disconnect()
break
if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.')
_log.info("Disconnected from voice by force... potentially reconnecting.")
successful = await self.potential_reconnect()
if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
_log.info("Reconnect was unsuccessful, disconnecting from voice normally...")
await self.disconnect()
break
else:
@@ -471,7 +478,7 @@ class VoiceClient(VoiceProtocol):
raise
retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
_log.exception("Disconnected from voice... Reconnecting in %.2fs.", retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
@@ -479,7 +486,7 @@ class VoiceClient(VoiceProtocol):
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...')
_log.warning("Could not connect to voice... Retrying...")
continue
async def disconnect(self, *, force: bool = False) -> None:
@@ -527,11 +534,11 @@ class VoiceClient(VoiceProtocol):
# Formulate rtp header
header[0] = 0x80
header[1] = 0x78
struct.pack_into('>H', header, 2, self.sequence)
struct.pack_into('>I', header, 4, self.timestamp)
struct.pack_into('>I', header, 8, self.ssrc)
struct.pack_into(">H", header, 2, self.sequence)
struct.pack_into(">I", header, 4, self.timestamp)
struct.pack_into(">I", header, 8, self.ssrc)
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
encrypt_packet = getattr(self, "_encrypt_" + self.mode)
return encrypt_packet(header, data)
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
@@ -551,8 +558,8 @@ class VoiceClient(VoiceProtocol):
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:4] = struct.pack('>I', self._lite_nonce)
self.checked_add('_lite_nonce', 1, 4294967295)
nonce[:4] = struct.pack(">I", self._lite_nonce)
self.checked_add("_lite_nonce", 1, 4294967295)
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
@@ -586,13 +593,13 @@ class VoiceClient(VoiceProtocol):
"""
if not self.is_connected():
raise ClientException('Not connected to voice.')
raise ClientException("Not connected to voice.")
if self.is_playing():
raise ClientException('Already playing audio.')
raise ClientException("Already playing audio.")
if not isinstance(source, AudioSource):
raise TypeError(f'source must be an AudioSource not {source.__class__.__name__}')
raise TypeError(f"source must be an AudioSource not {source.__class__.__name__}")
if not self.encoder and not source.is_opus():
self.encoder = opus.Encoder()
@@ -635,10 +642,10 @@ class VoiceClient(VoiceProtocol):
@source.setter
def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource):
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
raise TypeError(f"expected AudioSource not {value.__class__.__name__}.")
if self._player is None:
raise ValueError('Not playing anything.')
raise ValueError("Not playing anything.")
self._player._set_source(value)
@@ -662,7 +669,7 @@ class VoiceClient(VoiceProtocol):
Encoding the data failed.
"""
self.checked_add('sequence', 1, 65535)
self.checked_add("sequence", 1, 65535)
if encode:
encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
else:
@@ -671,6 +678,6 @@ class VoiceClient(VoiceProtocol):
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
_log.warning("A packet has been dropped (seq: %s, timestamp: %s)", self.sequence, self.timestamp)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

View File

@@ -46,10 +46,10 @@ from ..mixins import Hashable
from ..channel import PartialMessageable
__all__ = (
'Webhook',
'WebhookMessage',
'PartialWebhookChannel',
'PartialWebhookGuild',
"Webhook",
"WebhookMessage",
"PartialWebhookChannel",
"PartialWebhookGuild",
)
_log = logging.getLogger(__name__)
@@ -120,14 +120,14 @@ class AsyncWebhookAdapter:
self._locks[bucket] = lock = asyncio.Lock()
if payload is not None:
headers['Content-Type'] = 'application/json'
headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
headers["Authorization"] = f"Bot {auth_token}"
if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[aiohttp.ClientResponse] = None
data: Optional[Union[Dict[str, Any], str]] = None
@@ -149,21 +149,23 @@ class AsyncWebhookAdapter:
try:
async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
"Webhook ID %s with %s %s has returned status code %s",
webhook_id,
method,
url,
response.status,
)
data = (await response.text(encoding='utf-8')) or None
if data and response.headers['Content-Type'] == 'application/json':
data = (await response.text(encoding="utf-8")) or None
if data and response.headers["Content-Type"] == "application/json":
data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status != 429:
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status != 429:
delta = utils._parse_ratelimit_header(response)
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
)
lock.delay_by(delta)
@@ -171,11 +173,13 @@ class AsyncWebhookAdapter:
return data
if response.status == 429:
if not response.headers.get('Via'):
if not response.headers.get("Via"):
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
retry_after: float = data["retry_after"] # type: ignore
_log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds", webhook_id, retry_after
)
await asyncio.sleep(retry_after)
continue
@@ -201,7 +205,7 @@ class AsyncWebhookAdapter:
raise DiscordServerError(response, data)
raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.')
raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook(
self,
@@ -211,7 +215,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token(
@@ -222,7 +226,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("DELETE", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
def edit_webhook(
@@ -234,7 +238,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
def edit_webhook_with_token(
@@ -246,7 +250,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
reason: Optional[str] = None,
) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("PATCH", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
def execute_webhook(
@@ -261,10 +265,10 @@ class AsyncWebhookAdapter:
thread_id: Optional[int] = None,
wait: bool = False,
) -> Response[Optional[MessagePayload]]:
params = {'wait': int(wait)}
params = {"wait": int(wait)}
if thread_id:
params['thread_id'] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
params["thread_id"] = thread_id
route = Route("POST", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def get_webhook_message(
@@ -276,8 +280,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[MessagePayload]:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -296,8 +300,8 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None,
) -> Response[Message]:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -313,8 +317,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[None]:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -328,7 +332,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token(
@@ -338,7 +342,7 @@ class AsyncWebhookAdapter:
*,
session: aiohttp.ClientSession,
) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("GET", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
def create_interaction_response(
@@ -351,15 +355,15 @@ class AsyncWebhookAdapter:
data: Optional[Dict[str, Any]] = None,
) -> Response[None]:
payload: Dict[str, Any] = {
'type': type,
"type": type,
}
if data is not None:
payload['data'] = data
payload["data"] = data
route = Route(
'POST',
'/interactions/{webhook_id}/{webhook_token}/callback',
"POST",
"/interactions/{webhook_id}/{webhook_token}/callback",
webhook_id=interaction_id,
webhook_token=token,
)
@@ -374,8 +378,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[MessagePayload]:
r = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id,
webhook_token=token,
)
@@ -392,8 +396,8 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None,
) -> Response[MessagePayload]:
r = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/@original',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id,
webhook_token=token,
)
@@ -407,8 +411,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
) -> Response[None]:
r = Route(
'DELETE',
'/webhooks/{webhook_id}/{wehook_token}/messages/@original',
"DELETE",
"/webhooks/{webhook_id}/{wehook_token}/messages/@original",
webhook_id=application_id,
wehook_token=token,
)
@@ -437,82 +441,82 @@ def handle_message_parameters(
previous_allowed_mentions: Optional[AllowedMentions] = None,
) -> ExecuteWebhookParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
raise TypeError("Cannot mix file and files keyword arguments.")
if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.')
raise TypeError("Cannot mix embed and embeds keyword arguments.")
payload = {}
if embeds is not MISSING:
if len(embeds) > 10:
raise InvalidArgument('embeds has a maximum of 10 elements.')
payload['embeds'] = [e.to_dict() for e in embeds]
raise InvalidArgument("embeds has a maximum of 10 elements.")
payload["embeds"] = [e.to_dict() for e in embeds]
if embed is not MISSING:
if embed is None:
payload['embeds'] = []
payload["embeds"] = []
else:
payload['embeds'] = [embed.to_dict()]
payload["embeds"] = [embed.to_dict()]
if content is not MISSING:
if content is not None:
payload['content'] = str(content)
payload["content"] = str(content)
else:
payload['content'] = None
payload["content"] = None
if view is not MISSING:
if view is not None:
payload['components'] = view.to_components()
payload["components"] = view.to_components()
else:
payload['components'] = []
payload["components"] = []
payload['tts'] = tts
payload["tts"] = tts
if avatar_url:
payload['avatar_url'] = str(avatar_url)
payload["avatar_url"] = str(avatar_url)
if username:
payload['username'] = username
payload["username"] = username
if ephemeral:
payload['flags'] = 64
payload["flags"] = 64
if allowed_mentions:
if previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
payload["allowed_mentions"] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
else:
payload['allowed_mentions'] = allowed_mentions.to_dict()
payload["allowed_mentions"] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
payload["allowed_mentions"] = previous_allowed_mentions.to_dict()
multipart = []
if file is not MISSING:
files = [file]
if files:
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
multipart.append({"name": "payload_json", "value": utils._to_json(payload)})
payload = None
if len(files) == 1:
file = files[0]
multipart.append(
{
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
"name": "file",
"value": file.fp,
"filename": file.filename,
"content_type": "application/octet-stream",
}
)
else:
for index, file in enumerate(files):
multipart.append(
{
'name': f'file{index}',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
"name": f"file{index}",
"value": file.fp,
"filename": file.filename,
"content_type": "application/octet-stream",
}
)
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar("async_webhook_context", default=AsyncWebhookAdapter())
class PartialWebhookChannel(Hashable):
@@ -530,14 +534,14 @@ class PartialWebhookChannel(Hashable):
The partial channel's name.
"""
__slots__ = ('id', 'name')
__slots__ = ("id", "name")
def __init__(self, *, data):
self.id = int(data['id'])
self.name = data['name']
self.id = int(data["id"])
self.name = data["name"]
def __repr__(self):
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
return f"<PartialWebhookChannel name={self.name!r} id={self.id}>"
class PartialWebhookGuild(Hashable):
@@ -555,16 +559,16 @@ class PartialWebhookGuild(Hashable):
The partial guild's name.
"""
__slots__ = ('id', 'name', '_icon', '_state')
__slots__ = ("id", "name", "_icon", "_state")
def __init__(self, *, data, state):
self._state = state
self.id = int(data['id'])
self.name = data['name']
self._icon = data['icon']
self.id = int(data["id"])
self.name = data["name"]
self._icon = data["icon"]
def __repr__(self):
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
return f"<PartialWebhookGuild name={self.name!r} id={self.id}>"
@property
def icon(self) -> Optional[Asset]:
@@ -578,11 +582,11 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
raise AttributeError('PartialWebhookState does not support http methods.')
raise AttributeError("PartialWebhookState does not support http methods.")
class _WebhookState:
__slots__ = ('_parent', '_webhook')
__slots__ = ("_parent", "_webhook")
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
self._webhook: Any = webhook
@@ -621,7 +625,7 @@ class _WebhookState:
if self._parent is not None:
return getattr(self._parent, attr)
raise AttributeError(f'PartialWebhookState does not support {attr!r}.')
raise AttributeError(f"PartialWebhookState does not support {attr!r}.")
class WebhookMessage(Message):
@@ -750,18 +754,18 @@ class WebhookMessage(Message):
class BaseWebhook(Hashable):
__slots__: Tuple[str, ...] = (
'id',
'type',
'guild_id',
'channel_id',
'token',
'auth_token',
'user',
'name',
'_avatar',
'source_channel',
'source_guild',
'_state',
"id",
"type",
"guild_id",
"channel_id",
"token",
"auth_token",
"user",
"name",
"_avatar",
"source_channel",
"source_guild",
"_state",
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
@@ -770,27 +774,27 @@ class BaseWebhook(Hashable):
self._update(data)
def _update(self, data: WebhookPayload):
self.id = int(data['id'])
self.type = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name')
self._avatar = data.get('avatar')
self.token = data.get('token')
self.id = int(data["id"])
self.type = try_enum(WebhookType, int(data["type"]))
self.channel_id = utils._get_as_snowflake(data, "channel_id")
self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.name = data.get("name")
self._avatar = data.get("avatar")
self.token = data.get("token")
user = data.get('user')
user = data.get("user")
self.user: Optional[Union[BaseUser, User]] = None
if user is not None:
# state parameter may be _WebhookState
self.user = User(state=self._state, data=user) # type: ignore
source_channel = data.get('source_channel')
source_channel = data.get("source_channel")
if source_channel:
source_channel = PartialWebhookChannel(data=source_channel)
self.source_channel: Optional[PartialWebhookChannel] = source_channel
source_guild = data.get('source_guild')
source_guild = data.get("source_guild")
if source_guild:
source_guild = PartialWebhookGuild(data=source_guild, state=self._state)
@@ -927,22 +931,24 @@ class Webhook(BaseWebhook):
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = ('session',)
__slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
super().__init__(data, token, state)
self.session = session
def __repr__(self):
return f'<Webhook id={self.id!r}>'
return f"<Webhook id={self.id!r}>"
@property
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def partial(
cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None
) -> Webhook:
"""Creates a partial :class:`Webhook`.
Parameters
@@ -970,9 +976,9 @@ class Webhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token.
"""
data: WebhookPayload = {
'id': id,
'type': 1,
'token': token,
"id": id,
"type": 1,
"token": token,
}
return cls(data, session, token=bot_token)
@@ -1008,24 +1014,24 @@ class Webhook(BaseWebhook):
A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token.
"""
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
m = re.search(r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})", url)
if m is None:
raise InvalidArgument('Invalid webhook URL given.')
raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict()
data['type'] = 1
data["type"] = 1
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
'type': 2,
'name': name,
'channel_id': channel.id,
'guild_id': channel.guild.id,
'user': {'username': user.name, 'discriminator': user.discriminator, 'id': user.id, 'avatar': user._avatar},
"id": data["webhook_id"],
"type": 2,
"name": name,
"channel_id": channel.id,
"guild_id": channel.guild.id,
"user": {"username": user.name, "discriminator": user.discriminator, "id": user.id, "avatar": user._avatar},
}
state = channel._state
@@ -1079,7 +1085,7 @@ class Webhook(BaseWebhook):
elif self.token:
data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
else:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
return Webhook(data, self.session, token=self.auth_token, state=self._state)
@@ -1112,7 +1118,7 @@ class Webhook(BaseWebhook):
This webhook does not have a token associated with it.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get()
@@ -1165,14 +1171,14 @@ class Webhook(BaseWebhook):
or it tried editing a channel without authentication.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
payload = {}
if name is not MISSING:
payload['name'] = str(name) if name is not None else None
payload["name"] = str(name) if name is not None else None
if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter = async_context.get()
@@ -1180,27 +1186,31 @@ class Webhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook')
raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
payload["channel_id"] = channel.id
data = await adapter.edit_webhook(
self.id, self.auth_token, payload=payload, session=self.session, reason=reason
)
if prefer_auth and self.auth_token:
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
data = await adapter.edit_webhook(
self.id, self.auth_token, payload=payload, session=self.session, reason=reason
)
elif self.token:
data = await adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
raise RuntimeError("Unreachable code hit: data was not assigned")
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
@@ -1350,22 +1360,22 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
if content is None:
content = MISSING
application_webhook = self.type is WebhookType.application
if ephemeral and not application_webhook:
raise InvalidArgument('ephemeral messages can only be sent from application webhooks')
raise InvalidArgument("ephemeral messages can only be sent from application webhooks")
if application_webhook:
wait = True
if view is not MISSING:
if isinstance(self._state, _WebhookState):
raise InvalidArgument('Webhook views require an associated state with the webhook')
raise InvalidArgument("Webhook views require an associated state with the webhook")
if ephemeral is True and view.timeout is None:
view.timeout = 15 * 60.0
@@ -1439,7 +1449,7 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get()
data = await adapter.get_webhook_message(
@@ -1525,15 +1535,15 @@ class Webhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
if view is not MISSING:
if isinstance(self._state, _WebhookState):
raise InvalidArgument('This webhook does not have state associated with it')
raise InvalidArgument("This webhook does not have state associated with it")
self._state.prevent_view_updates_for(message_id)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
params = handle_message_parameters(
content=content,
file=file,
@@ -1583,7 +1593,7 @@ class Webhook(BaseWebhook):
Deleted a message that is not yours.
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get()
await adapter.delete_webhook_message(

View File

@@ -48,8 +48,8 @@ from ..channel import PartialMessageable
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
__all__ = (
'SyncWebhook',
'SyncWebhookMessage',
"SyncWebhook",
"SyncWebhookMessage",
)
_log = logging.getLogger(__name__)
@@ -116,14 +116,14 @@ class WebhookAdapter:
self._locks[bucket] = lock = threading.Lock()
if payload is not None:
headers['Content-Type'] = 'application/json'
headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload)
if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}'
headers["Authorization"] = f"Bot {auth_token}"
if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[Response] = None
data: Optional[Union[Dict[str, Any], str]] = None
@@ -140,36 +140,38 @@ class WebhookAdapter:
if multipart:
file_data = {}
for p in multipart:
name = p['name']
if name == 'payload_json':
to_send = {'payload_json': p['value']}
name = p["name"]
if name == "payload_json":
to_send = {"payload_json": p["value"]}
else:
file_data[name] = (p['filename'], p['value'], p['content_type'])
file_data[name] = (p["filename"], p["value"], p["content_type"])
try:
with session.request(
method, url, data=to_send, files=file_data, headers=headers, params=params
) as response:
_log.debug(
'Webhook ID %s with %s %s has returned status code %s',
"Webhook ID %s with %s %s has returned status code %s",
webhook_id,
method,
url,
response.status_code,
)
response.encoding = 'utf-8'
response.encoding = "utf-8"
# Compatibility with aiohttp
response.status = response.status_code # type: ignore
data = response.text or None
if data and response.headers['Content-Type'] == 'application/json':
if data and response.headers["Content-Type"] == "application/json":
data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status_code != 429:
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status_code != 429:
delta = utils._parse_ratelimit_header(response)
_log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
)
lock.delay_by(delta)
@@ -177,11 +179,13 @@ class WebhookAdapter:
return data
if response.status_code == 429:
if not response.headers.get('Via'):
if not response.headers.get("Via"):
raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
retry_after: float = data["retry_after"] # type: ignore
_log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds", webhook_id, retry_after
)
time.sleep(retry_after)
continue
@@ -207,7 +211,7 @@ class WebhookAdapter:
raise DiscordServerError(response, data)
raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.')
raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook(
self,
@@ -217,7 +221,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token(
@@ -228,7 +232,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("DELETE", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
def edit_webhook(
@@ -240,7 +244,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
def edit_webhook_with_token(
@@ -252,7 +256,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("PATCH", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
def execute_webhook(
@@ -267,10 +271,10 @@ class WebhookAdapter:
thread_id: Optional[int] = None,
wait: bool = False,
):
params = {'wait': int(wait)}
params = {"wait": int(wait)}
if thread_id:
params['thread_id'] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
params["thread_id"] = thread_id
route = Route("POST", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def get_webhook_message(
@@ -282,8 +286,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -302,8 +306,8 @@ class WebhookAdapter:
files: Optional[List[File]] = None,
):
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -319,8 +323,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@@ -334,7 +338,7 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token(
@@ -344,7 +348,7 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
route = Route("GET", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
@@ -516,22 +520,24 @@ class SyncWebhook(BaseWebhook):
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = ('session',)
__slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
super().__init__(data, token, state)
self.session = session
def __repr__(self):
return f'<Webhook id={self.id!r}>'
return f"<Webhook id={self.id!r}>"
@property
def url(self) -> str:
""":class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
def partial(
cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None
) -> SyncWebhook:
"""Creates a partial :class:`Webhook`.
Parameters
@@ -556,15 +562,15 @@ class SyncWebhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token.
"""
data: WebhookPayload = {
'id': id,
'type': 1,
'token': token,
"id": id,
"type": 1,
"token": token,
}
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}')
raise TypeError(f"expected requests.Session not {session.__class__!r}")
else:
session = requests # type: ignore
return cls(data, session, token=bot_token)
@@ -597,17 +603,17 @@ class SyncWebhook(BaseWebhook):
A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token.
"""
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
m = re.search(r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})", url)
if m is None:
raise InvalidArgument('Invalid webhook URL given.')
raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict()
data['type'] = 1
data["type"] = 1
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}')
raise TypeError(f"expected requests.Session not {session.__class__!r}")
else:
session = requests # type: ignore
return cls(data, session, token=bot_token) # type: ignore
@@ -650,7 +656,7 @@ class SyncWebhook(BaseWebhook):
elif self.token:
data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
else:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
@@ -679,7 +685,7 @@ class SyncWebhook(BaseWebhook):
This webhook does not have a token associated with it.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter()
@@ -731,14 +737,14 @@ class SyncWebhook(BaseWebhook):
The newly edited webhook.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
payload = {}
if name is not MISSING:
payload['name'] = str(name) if name is not None else None
payload["name"] = str(name) if name is not None else None
if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter: WebhookAdapter = _get_webhook_adapter()
@@ -746,25 +752,27 @@ class SyncWebhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook')
raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id
payload["channel_id"] = channel.id
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
if prefer_auth and self.auth_token:
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
elif self.token:
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
data = adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
)
if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned')
raise RuntimeError("Unreachable code hit: data was not assigned")
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
@@ -887,9 +895,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
if content is None:
content = MISSING
@@ -951,7 +959,7 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message(
@@ -1015,9 +1023,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
params = handle_message_parameters(
content=content,
file=file,
@@ -1060,7 +1068,7 @@ class SyncWebhook(BaseWebhook):
Deleted a message that is not yours.
"""
if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it')
raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message(

View File

@@ -41,11 +41,12 @@ if TYPE_CHECKING:
)
__all__ = (
'WidgetChannel',
'WidgetMember',
'Widget',
"WidgetChannel",
"WidgetMember",
"Widget",
)
class WidgetChannel:
"""Represents a "partial" widget channel.
@@ -76,7 +77,8 @@ class WidgetChannel:
position: :class:`int`
The channel's position
"""
__slots__ = ('id', 'name', 'position')
__slots__ = ("id", "name", "position")
def __init__(self, id: int, name: str, position: int) -> None:
self.id: int = id
@@ -87,18 +89,19 @@ class WidgetChannel:
return self.name
def __repr__(self) -> str:
return f'<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>'
return f"<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>"
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
return f"<#{self.id}>"
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id)
class WidgetMember(BaseUser):
"""Represents a "partial" member of the widget's guild.
@@ -147,29 +150,37 @@ class WidgetMember(BaseUser):
connected_channel: Optional[:class:`WidgetChannel`]
Which channel the member is connected to.
"""
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted',
'connected_channel')
__slots__ = (
"name",
"status",
"nick",
"avatar",
"discriminator",
"id",
"bot",
"activity",
"deafened",
"suppress",
"muted",
"connected_channel",
)
if TYPE_CHECKING:
activity: Optional[Union[BaseActivity, Spotify]]
def __init__(
self,
*,
state: ConnectionState,
data: WidgetMemberPayload,
connected_channel: Optional[WidgetChannel] = None
self, *, state: ConnectionState, data: WidgetMemberPayload, connected_channel: Optional[WidgetChannel] = None
) -> None:
super().__init__(state=state, data=data)
self.nick: Optional[str] = data.get('nick')
self.status: Status = try_enum(Status, data.get('status'))
self.deafened: Optional[bool] = data.get('deaf', False) or data.get('self_deaf', False)
self.muted: Optional[bool] = data.get('mute', False) or data.get('self_mute', False)
self.suppress: Optional[bool] = data.get('suppress', False)
self.nick: Optional[str] = data.get("nick")
self.status: Status = try_enum(Status, data.get("status"))
self.deafened: Optional[bool] = data.get("deaf", False) or data.get("self_deaf", False)
self.muted: Optional[bool] = data.get("mute", False) or data.get("self_mute", False)
self.suppress: Optional[bool] = data.get("suppress", False)
try:
game = data['game']
game = data["game"]
except KeyError:
activity = None
else:
@@ -190,6 +201,7 @@ class WidgetMember(BaseUser):
""":class:`str`: Returns the member's display name."""
return self.nick or self.name
class Widget:
"""Represents a :class:`Guild` widget.
@@ -227,27 +239,28 @@ class Widget:
retrieved is capped.
"""
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
__slots__ = ("_state", "channels", "_invite", "id", "members", "name")
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state
self._invite = data['instant_invite']
self.name: str = data['name']
self.id: int = int(data['id'])
self._invite = data["instant_invite"]
self.name: str = data["name"]
self.id: int = int(data["id"])
self.channels: List[WidgetChannel] = []
for channel in data.get('channels', []):
_id = int(channel['id'])
self.channels.append(WidgetChannel(id=_id, name=channel['name'], position=channel['position']))
for channel in data.get("channels", []):
_id = int(channel["id"])
self.channels.append(WidgetChannel(id=_id, name=channel["name"], position=channel["position"]))
self.members: List[WidgetMember] = []
channels = {channel.id: channel for channel in self.channels}
for member in data.get('members', []):
connected_channel = _get_as_snowflake(member, 'channel_id')
for member in data.get("members", []):
connected_channel = _get_as_snowflake(member, "channel_id")
if connected_channel in channels:
connected_channel = channels[connected_channel] # type: ignore
elif connected_channel:
connected_channel = WidgetChannel(id=connected_channel, name='', position=0)
connected_channel = WidgetChannel(id=connected_channel, name="", position=0)
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore
@@ -260,7 +273,7 @@ class Widget:
return False
def __repr__(self) -> str:
return f'<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>'
return f"<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>"
@property
def created_at(self) -> datetime.datetime:

View File

@@ -369,6 +369,17 @@ to handle it, which defaults to print a traceback and ignoring the exception.
:param when: When the typing started as an aware datetime in UTC.
:type when: :class:`datetime.datetime`
.. function:: on_raw_typing(payload)
Called when someone begins typing a message. Unlike :func:`on_typing`, this is
called regardless if the user can be found or not. This most often happens
when a user types in DMs.
This requires :attr:`Intents.typing` to be enabled.
:param payload: The raw typing payload.
:type payload: :class:`RawTypingEvent`
.. function:: on_message(message)
Called when a :class:`Message` is created and sent.
@@ -3846,6 +3857,14 @@ GuildSticker
.. autoclass:: GuildSticker()
:members:
RawTypingEvent
~~~~~~~~~~~~~~~~~~~~~~~
.. attributetable:: RawTypingEvent
.. autoclass:: RawTypingEvent()
:members:
RawMessageDeleteEvent
~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -18,8 +18,8 @@ import re
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))
sys.path.append(os.path.abspath('extensions'))
sys.path.insert(0, os.path.abspath(".."))
sys.path.append(os.path.abspath("extensions"))
# -- General configuration ------------------------------------------------
@@ -30,33 +30,33 @@ sys.path.append(os.path.abspath('extensions'))
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'builder',
'sphinx.ext.autodoc',
'sphinx.ext.extlinks',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinxcontrib_trio',
'details',
'exception_hierarchy',
'attributetable',
'resourcelinks',
'nitpick_file_ignorer',
"builder",
"sphinx.ext.autodoc",
"sphinx.ext.extlinks",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinxcontrib_trio",
"details",
"exception_hierarchy",
"attributetable",
"resourcelinks",
"nitpick_file_ignorer",
]
autodoc_member_order = 'bysource'
autodoc_typehints = 'none'
autodoc_member_order = "bysource"
autodoc_typehints = "none"
# maybe consider this?
# napoleon_attr_annotations = False
extlinks = {
'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'GH-'),
"issue": ("https://github.com/Rapptz/discord.py/issues/%s", "GH-"),
}
# Links used for cross-referencing stuff in other documentation
intersphinx_mapping = {
'py': ('https://docs.python.org/3', None),
'aio': ('https://docs.aiohttp.org/en/stable/', None),
'req': ('https://docs.python-requests.org/en/latest/', None)
"py": ("https://docs.python.org/3", None),
"aio": ("https://docs.aiohttp.org/en/stable/", None),
"req": ("https://docs.python-requests.org/en/latest/", None),
}
rst_prolog = """
@@ -67,20 +67,20 @@ rst_prolog = """
"""
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix of source filenames.
source_suffix = '.rst'
source_suffix = ".rst"
# The encoding of source files.
# source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'discord.py'
copyright = '2015-present, Rapptz'
project = "discord.py"
copyright = "2015-present, Rapptz"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -88,15 +88,15 @@ copyright = '2015-present, Rapptz'
#
# The short X.Y version.
version = ''
with open('../discord/__init__.py') as f:
version = ""
with open("../discord/__init__.py") as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
# The full version, including alpha/beta/rc tags.
release = version
# This assumes a tag is available for final releases
branch = 'master' if version.endswith('a') else 'v' + version
branch = "master" if version.endswith("a") else "v" + version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -105,7 +105,7 @@ branch = 'master' if version.endswith('a') else 'v' + version
# Usually you set "language" from the command line for these cases.
language = None
locale_dirs = ['locale/']
locale_dirs = ["locale/"]
gettext_compact = False
# There are two options for replacing |today|: either, you set today to some
@@ -116,7 +116,7 @@ gettext_compact = False
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all
# documents.
@@ -134,7 +134,7 @@ exclude_patterns = ['_build']
# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'friendly'
pygments_style = "friendly"
# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
@@ -156,21 +156,21 @@ html_experimental_html5_writer = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'basic'
html_theme = "basic"
html_context = {
'discord_invite': 'https://discord.gg/r3sSKJJ',
'discord_extensions': [
('discord.ext.commands', 'ext/commands'),
('discord.ext.tasks', 'ext/tasks'),
"discord_invite": "https://discord.gg/TvqYBrGXEm",
"discord_extensions": [
("discord.ext.commands", "ext/commands"),
("discord.ext.tasks", "ext/tasks"),
],
}
resource_links = {
'discord': 'https://discord.gg/r3sSKJJ',
'issues': 'https://github.com/Rapptz/discord.py/issues',
'discussions': 'https://github.com/Rapptz/discord.py/discussions',
'examples': f'https://github.com/Rapptz/discord.py/tree/{branch}/examples',
"discord": "https://discord.gg/TvqYBrGXEm",
"issues": "https://github.com/Rapptz/discord.py/issues",
"discussions": "https://github.com/Rapptz/discord.py/discussions",
"examples": f"https://github.com/Rapptz/discord.py/tree/{branch}/examples",
}
# Theme options are theme-specific and customize the look and feel of a theme
@@ -196,12 +196,12 @@ resource_links = {
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
html_favicon = './images/discord_py_logo.ico'
html_favicon = "./images/discord_py_logo.ico"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
@@ -261,30 +261,22 @@ html_static_path = ['_static']
# The name of a javascript file (relative to the configuration directory) that
# implements a search results scorer. If empty, the default will be used.
html_search_scorer = '_static/scorer.js'
html_search_scorer = "_static/scorer.js"
html_js_files = [
'custom.js',
'settings.js',
'copy.js',
'sidebar.js'
]
html_js_files = ["custom.js", "settings.js", "copy.js", "sidebar.js"]
# Output file base name for HTML help builder.
htmlhelp_basename = 'discord.pydoc'
htmlhelp_basename = "discord.pydoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
@@ -293,8 +285,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
('index', 'discord.py.tex', 'discord.py Documentation',
'Rapptz', 'manual'),
("index", "discord.py.tex", "discord.py Documentation", "Rapptz", "manual"),
]
# The name of an image file (relative to this directory) to place at the top of
@@ -322,10 +313,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'discord.py', 'discord.py Documentation',
['Rapptz'], 1)
]
man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
# If true, show URL addresses after external links.
# man_show_urls = False
@@ -337,9 +325,15 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'discord.py', 'discord.py Documentation',
'Rapptz', 'discord.py', 'One line description of project.',
'Miscellaneous'),
(
"index",
"discord.py",
"discord.py Documentation",
"Rapptz",
"discord.py",
"One line description of project.",
"Miscellaneous",
),
]
# Documents to append as an appendix to all manuals.
@@ -354,8 +348,9 @@ texinfo_documents = [
# If true, do not generate a @detailmenu in the "Top" node's menu.
# texinfo_no_detailmenu = False
def setup(app):
if app.config.language == 'ja':
app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None)
app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg'
app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg'
if app.config.language == "ja":
app.config.intersphinx_mapping["py"] = ("https://docs.python.org/ja/3", None)
app.config.html_context["discord_invite"] = "https://discord.gg/TvqYBrGXEm"
app.config.resource_links["discord"] = "https://discord.gg/TvqYBrGXEm"

View File

@@ -429,6 +429,12 @@ Converters
.. autofunction:: discord.ext.commands.run_converters
Option
~~~~~~
.. autoclass:: discord.ext.commands.Option
:members:
Flag Converter
~~~~~~~~~~~~~~~

View File

@@ -61,6 +61,7 @@ the name to something other than the function would be as simple as doing this:
async def _list(ctx, arg):
pass
Parameters
------------
@@ -133,6 +134,11 @@ at all:
Since the ``args`` variable is a :class:`py:tuple`,
you can do anything you would usually do with one.
.. admonition:: Slash Command Only
This functionally is currently not supported by the slash command API, so is turned into
a single ``STRING`` parameter on discord's end which we do our own parsing on.
Keyword-Only Arguments
++++++++++++++++++++++++
@@ -179,6 +185,12 @@ know how the command was executed. It contains a lot of useful information:
The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you
can do on the :class:`~ext.commands.Context`.
.. admonition:: Slash Command Only
:attr:`.Context.message` will be fake if in a slash command, it is not
recommended to access if :attr:`.Context.interaction` is not None as most
methods will error due to the message not actually existing.
Converters
------------
@@ -400,47 +412,55 @@ specify.
Under the hood, these are implemented by the :ref:`ext_commands_adv_converters` interface. A table of the equivalent
converter is given below:
+--------------------------+-------------------------------------------------+
| Discord Class | Converter |
+--------------------------+-------------------------------------------------+
| :class:`Object` | :class:`~ext.commands.ObjectConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Member` | :class:`~ext.commands.MemberConverter` |
+--------------------------+-------------------------------------------------+
| :class:`User` | :class:`~ext.commands.UserConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Message` | :class:`~ext.commands.MessageConverter` |
+--------------------------+-------------------------------------------------+
| :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` |
+--------------------------+-------------------------------------------------+
| :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Invite` | :class:`~ext.commands.InviteConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Guild` | :class:`~ext.commands.GuildConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Role` | :class:`~ext.commands.RoleConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Game` | :class:`~ext.commands.GameConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Colour` | :class:`~ext.commands.ColourConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Emoji` | :class:`~ext.commands.EmojiConverter` |
+--------------------------+-------------------------------------------------+
| :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` |
+--------------------------+-------------------------------------------------+
+--------------------------+-------------------------------------------------+-----------------------------+
| Discord Class | Converter | Supported By Slash Commands |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Object` | :class:`~ext.commands.ObjectConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Member` | :class:`~ext.commands.MemberConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`User` | :class:`~ext.commands.UserConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Message` | :class:`~ext.commands.MessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Invite` | :class:`~ext.commands.InviteConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Guild` | :class:`~ext.commands.GuildConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Role` | :class:`~ext.commands.RoleConverter` | Yes, as type 8 (ROLE) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Game` | :class:`~ext.commands.GameConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Colour` | :class:`~ext.commands.ColourConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Emoji` | :class:`~ext.commands.EmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
.. admonition:: Slash Command Only
If a slash command is not marked on the table above as supported, it will be sent as type 3 (STRING)
and parsed by normal content parsing, see
`the discord documentation <https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-type>`_
for all supported types by the API.
By providing the converter it allows us to use them as building blocks for another converter:
@@ -487,6 +507,10 @@ then a special error is raised, :exc:`~ext.commands.BadUnionArgument`.
Note that any valid converter discussed above can be passed in to the argument list of a :data:`typing.Union`.
.. admonition:: Slash Command Only
These are not currently supported by the Discord API and will be sent as type 3 (STRING)
typing.Optional
^^^^^^^^^^^^^^^^^
@@ -680,6 +704,11 @@ In order to customise the flag syntax we also have a few options that can be pas
a command line parser. The syntax is mainly inspired by Discord's search bar input and as a result
all flags need a corresponding value.
.. admonition:: Slash Command Only
As these are built very similar to slash command options, they are converted into options and parsed
back into flags when the slash command is executed.
The flag converter is similar to regular commands and allows you to use most types of converters
(with the exception of :class:`~ext.commands.Greedy`) as the type annotation. Some extra support is added for specific
annotations as described below.

View File

@@ -15,4 +15,5 @@ extension library that handles this for you.
commands
cogs
extensions
slash-commands
api

View File

@@ -0,0 +1,23 @@
.. currentmodule:: discord
.. _ext_commands_slash_commands:
Slash Commands
==============
Slash Commands are currently supported in enhanced-discord.py using a system on top of ext.commands.
This system is very simple to use, and can be enabled via :attr:`.Bot.slash_commands` globally,
or only for specific commands via :attr:`.Command.slash_command`.
There is also the parameter ``slash_command_guilds`` which can be passed to either :class:`.Bot` or the command
decorator in order to only upload the commands as guild commands to these specific guild IDs, however this
should only be used for testing or small (<10 guilds) bots.
If you want to add option descriptions to your commands, you should use :class:`.Option`
For troubleshooting, see the :ref:`FAQ <ext_commands_slash_command_troubleshooting>`
.. admonition:: Slash Command Only
For parts of the docs specific to slash commands, look for this box!

View File

@@ -9,60 +9,78 @@ import inspect
import os
import re
class attributetable(nodes.General, nodes.Element):
pass
class attributetablecolumn(nodes.General, nodes.Element):
pass
class attributetabletitle(nodes.TextElement):
pass
class attributetableplaceholder(nodes.General, nodes.Element):
pass
class attributetablebadge(nodes.TextElement):
pass
class attributetable_item(nodes.Part, nodes.Element):
pass
def visit_attributetable_node(self, node):
class_ = node["python-class"]
self.body.append(f'<div class="py-attribute-table" data-move-to-id="{class_}">')
def visit_attributetablecolumn_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='py-attribute-table-column'))
self.body.append(self.starttag(node, "div", CLASS="py-attribute-table-column"))
def visit_attributetabletitle_node(self, node):
self.body.append(self.starttag(node, 'span'))
self.body.append(self.starttag(node, "span"))
def visit_attributetablebadge_node(self, node):
attributes = {
'class': 'py-attribute-table-badge',
'title': node['badge-type'],
"class": "py-attribute-table-badge",
"title": node["badge-type"],
}
self.body.append(self.starttag(node, 'span', **attributes))
self.body.append(self.starttag(node, "span", **attributes))
def visit_attributetable_item_node(self, node):
self.body.append(self.starttag(node, 'li', CLASS='py-attribute-table-entry'))
self.body.append(self.starttag(node, "li", CLASS="py-attribute-table-entry"))
def depart_attributetable_node(self, node):
self.body.append('</div>')
self.body.append("</div>")
def depart_attributetablecolumn_node(self, node):
self.body.append('</div>')
self.body.append("</div>")
def depart_attributetabletitle_node(self, node):
self.body.append('</span>')
self.body.append("</span>")
def depart_attributetablebadge_node(self, node):
self.body.append('</span>')
self.body.append("</span>")
def depart_attributetable_item_node(self, node):
self.body.append('</li>')
self.body.append("</li>")
_name_parser_regex = re.compile(r"(?P<module>[\w.]+\.)?(?P<name>\w+)")
_name_parser_regex = re.compile(r'(?P<module>[\w.]+\.)?(?P<name>\w+)')
class PyAttributeTable(SphinxDirective):
has_content = False
@@ -74,13 +92,13 @@ class PyAttributeTable(SphinxDirective):
def parse_name(self, content):
path, name = _name_parser_regex.match(content).groups()
if path:
modulename = path.rstrip('.')
modulename = path.rstrip(".")
else:
modulename = self.env.temp_data.get('autodoc:module')
modulename = self.env.temp_data.get("autodoc:module")
if not modulename:
modulename = self.env.ref_context.get('py:module')
modulename = self.env.ref_context.get("py:module")
if modulename is None:
raise RuntimeError('modulename somehow None for %s in %s.' % (content, self.env.docname))
raise RuntimeError("modulename somehow None for %s in %s." % (content, self.env.docname))
return modulename, name
@@ -112,29 +130,33 @@ class PyAttributeTable(SphinxDirective):
replaced.
"""
content = self.arguments[0].strip()
node = attributetableplaceholder('')
node = attributetableplaceholder("")
modulename, name = self.parse_name(content)
node['python-doc'] = self.env.docname
node['python-module'] = modulename
node['python-class'] = name
node['python-full-name'] = f'{modulename}.{name}'
node["python-doc"] = self.env.docname
node["python-module"] = modulename
node["python-class"] = name
node["python-full-name"] = f"{modulename}.{name}"
return [node]
def build_lookup_table(env):
# Given an environment, load up a lookup table of
# full-class-name: objects
result = {}
domain = env.domains['py']
domain = env.domains["py"]
ignored = {
'data', 'exception', 'module', 'class',
"data",
"exception",
"module",
"class",
}
for (fullname, _, objtype, docname, _, _) in domain.get_objects():
if objtype in ignored:
continue
classname, _, child = fullname.rpartition('.')
classname, _, child = fullname.rpartition(".")
try:
result[classname].append(child)
except KeyError:
@@ -143,36 +165,40 @@ def build_lookup_table(env):
return result
TableElement = namedtuple('TableElement', 'fullname label badge')
TableElement = namedtuple("TableElement", "fullname label badge")
def process_attributetable(app, doctree, fromdocname):
env = app.builder.env
lookup = build_lookup_table(env)
for node in doctree.traverse(attributetableplaceholder):
modulename, classname, fullname = node['python-module'], node['python-class'], node['python-full-name']
modulename, classname, fullname = node["python-module"], node["python-class"], node["python-full-name"]
groups = get_class_results(lookup, modulename, classname, fullname)
table = attributetable('')
table = attributetable("")
for label, subitems in groups.items():
if not subitems:
continue
table.append(class_results_to_node(label, sorted(subitems, key=lambda c: c.label)))
table['python-class'] = fullname
table["python-class"] = fullname
if not table:
node.replace_self([])
else:
node.replace_self([table])
def get_class_results(lookup, modulename, name, fullname):
module = importlib.import_module(modulename)
cls = getattr(module, name)
groups = OrderedDict([
(_('Attributes'), []),
(_('Methods'), []),
])
groups = OrderedDict(
[
(_("Attributes"), []),
(_("Methods"), []),
]
)
try:
members = lookup[fullname]
@@ -180,8 +206,8 @@ def get_class_results(lookup, modulename, name, fullname):
return groups
for attr in members:
attrlookup = f'{fullname}.{attr}'
key = _('Attributes')
attrlookup = f"{fullname}.{attr}"
key = _("Attributes")
badge = None
label = attr
value = None
@@ -192,53 +218,54 @@ def get_class_results(lookup, modulename, name, fullname):
break
if value is not None:
doc = value.__doc__ or ''
if inspect.iscoroutinefunction(value) or doc.startswith('|coro|'):
key = _('Methods')
badge = attributetablebadge('async', 'async')
badge['badge-type'] = _('coroutine')
doc = value.__doc__ or ""
if inspect.iscoroutinefunction(value) or doc.startswith("|coro|"):
key = _("Methods")
badge = attributetablebadge("async", "async")
badge["badge-type"] = _("coroutine")
elif isinstance(value, classmethod):
key = _('Methods')
label = f'{name}.{attr}'
badge = attributetablebadge('cls', 'cls')
badge['badge-type'] = _('classmethod')
key = _("Methods")
label = f"{name}.{attr}"
badge = attributetablebadge("cls", "cls")
badge["badge-type"] = _("classmethod")
elif inspect.isfunction(value):
if doc.startswith(('A decorator', 'A shortcut decorator')):
if doc.startswith(("A decorator", "A shortcut decorator")):
# finicky but surprisingly consistent
badge = attributetablebadge('@', '@')
badge['badge-type'] = _('decorator')
key = _('Methods')
badge = attributetablebadge("@", "@")
badge["badge-type"] = _("decorator")
key = _("Methods")
else:
key = _('Methods')
badge = attributetablebadge('def', 'def')
badge['badge-type'] = _('method')
key = _("Methods")
badge = attributetablebadge("def", "def")
badge["badge-type"] = _("method")
groups[key].append(TableElement(fullname=attrlookup, label=label, badge=badge))
return groups
def class_results_to_node(key, elements):
title = attributetabletitle(key, key)
ul = nodes.bullet_list('')
ul = nodes.bullet_list("")
for element in elements:
ref = nodes.reference('', '', internal=True,
refuri='#' + element.fullname,
anchorname='',
*[nodes.Text(element.label)])
para = addnodes.compact_paragraph('', '', ref)
ref = nodes.reference(
"", "", internal=True, refuri="#" + element.fullname, anchorname="", *[nodes.Text(element.label)]
)
para = addnodes.compact_paragraph("", "", ref)
if element.badge is not None:
ul.append(attributetable_item('', element.badge, para))
ul.append(attributetable_item("", element.badge, para))
else:
ul.append(attributetable_item('', para))
ul.append(attributetable_item("", para))
return attributetablecolumn("", title, ul)
return attributetablecolumn('', title, ul)
def setup(app):
app.add_directive('attributetable', PyAttributeTable)
app.add_directive("attributetable", PyAttributeTable)
app.add_node(attributetable, html=(visit_attributetable_node, depart_attributetable_node))
app.add_node(attributetablecolumn, html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node))
app.add_node(attributetabletitle, html=(visit_attributetabletitle_node, depart_attributetabletitle_node))
app.add_node(attributetablebadge, html=(visit_attributetablebadge_node, depart_attributetablebadge_node))
app.add_node(attributetable_item, html=(visit_attributetable_item_node, depart_attributetable_item_node))
app.add_node(attributetableplaceholder)
app.connect('doctree-resolved', process_attributetable)
app.connect("doctree-resolved", process_attributetable)

View File

@@ -2,15 +2,15 @@ from sphinx.builders.html import StandaloneHTMLBuilder
from sphinx.environment.adapters.indexentries import IndexEntries
from sphinx.writers.html5 import HTML5Translator
class DPYHTML5Translator(HTML5Translator):
def visit_section(self, node):
self.section_level += 1
self.body.append(
self.starttag(node, 'section'))
self.body.append(self.starttag(node, "section"))
def depart_section(self, node):
self.section_level -= 1
self.body.append('</section>\n')
self.body.append("</section>\n")
def visit_table(self, node):
self.body.append('<div class="table-wrapper">')
@@ -18,7 +18,8 @@ class DPYHTML5Translator(HTML5Translator):
def depart_table(self, node):
super().depart_table(node)
self.body.append('</div>')
self.body.append("</div>")
class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
# This is mostly copy pasted from Sphinx.
@@ -28,50 +29,48 @@ class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
genindex = IndexEntries(self.env).create_index(self, group_entries=False)
indexcounts = []
for _k, entries in genindex:
indexcounts.append(sum(1 + len(subitems)
for _, (_, subitems, _) in entries))
indexcounts.append(sum(1 + len(subitems) for _, (_, subitems, _) in entries))
genindexcontext = {
'genindexentries': genindex,
'genindexcounts': indexcounts,
'split_index': self.config.html_split_index,
"genindexentries": genindex,
"genindexcounts": indexcounts,
"split_index": self.config.html_split_index,
}
if self.config.html_split_index:
self.handle_page('genindex', genindexcontext,
'genindex-split.html')
self.handle_page('genindex-all', genindexcontext,
'genindex.html')
self.handle_page("genindex", genindexcontext, "genindex-split.html")
self.handle_page("genindex-all", genindexcontext, "genindex.html")
for (key, entries), count in zip(genindex, indexcounts):
ctx = {'key': key, 'entries': entries, 'count': count,
'genindexentries': genindex}
self.handle_page('genindex-' + key, ctx,
'genindex-single.html')
ctx = {"key": key, "entries": entries, "count": count, "genindexentries": genindex}
self.handle_page("genindex-" + key, ctx, "genindex-single.html")
else:
self.handle_page('genindex', genindexcontext, 'genindex.html')
self.handle_page("genindex", genindexcontext, "genindex.html")
def add_custom_jinja2(app):
env = app.builder.templates.environment
env.tests['prefixedwith'] = str.startswith
env.tests['suffixedwith'] = str.endswith
env.tests["prefixedwith"] = str.startswith
env.tests["suffixedwith"] = str.endswith
def add_builders(app):
"""This is necessary because RTD injects their own for some reason."""
app.set_translator('html', DPYHTML5Translator, override=True)
app.set_translator("html", DPYHTML5Translator, override=True)
app.add_builder(DPYStandaloneHTMLBuilder, override=True)
try:
original = app.registry.builders['readthedocs']
original = app.registry.builders["readthedocs"]
except KeyError:
pass
else:
injected_mro = tuple(base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder
for base in original.mro()[1:])
new_builder = type(original.__name__, injected_mro, {'name': 'readthedocs'})
app.set_translator('readthedocs', DPYHTML5Translator, override=True)
injected_mro = tuple(
base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder for base in original.mro()[1:]
)
new_builder = type(original.__name__, injected_mro, {"name": "readthedocs"})
app.set_translator("readthedocs", DPYHTML5Translator, override=True)
app.add_builder(new_builder, override=True)
def setup(app):
add_builders(app)
app.connect('builder-inited', add_custom_jinja2)
app.connect("builder-inited", add_custom_jinja2)

View File

@@ -3,32 +3,39 @@ from docutils.parsers.rst import states, directives
from docutils.parsers.rst.roles import set_classes
from docutils import nodes
class details(nodes.General, nodes.Element):
pass
class summary(nodes.General, nodes.Element):
pass
def visit_details_node(self, node):
self.body.append(self.starttag(node, 'details', CLASS=node.attributes.get('class', '')))
self.body.append(self.starttag(node, "details", CLASS=node.attributes.get("class", "")))
def visit_summary_node(self, node):
self.body.append(self.starttag(node, 'summary', CLASS=node.attributes.get('summary-class', '')))
self.body.append(self.starttag(node, "summary", CLASS=node.attributes.get("summary-class", "")))
self.body.append(node.rawsource)
def depart_details_node(self, node):
self.body.append('</details>\n')
self.body.append("</details>\n")
def depart_summary_node(self, node):
self.body.append('</summary>')
self.body.append("</summary>")
class DetailsDirective(Directive):
final_argument_whitespace = True
optional_arguments = 1
option_spec = {
'class': directives.class_option,
'summary-class': directives.class_option,
"class": directives.class_option,
"summary-class": directives.class_option,
}
has_content = True
@@ -37,7 +44,7 @@ class DetailsDirective(Directive):
set_classes(self.options)
self.assert_has_content()
text = '\n'.join(self.content)
text = "\n".join(self.content)
node = details(text, **self.options)
if self.arguments:
@@ -48,8 +55,8 @@ class DetailsDirective(Directive):
self.state.nested_parse(self.content, self.content_offset, node)
return [node]
def setup(app):
app.add_node(details, html=(visit_details_node, depart_details_node))
app.add_node(summary, html=(visit_summary_node, depart_summary_node))
app.add_directive('details', DetailsDirective)
app.add_directive("details", DetailsDirective)

View File

@@ -4,24 +4,29 @@ from docutils.parsers.rst.roles import set_classes
from docutils import nodes
from sphinx.locale import _
class exception_hierarchy(nodes.General, nodes.Element):
pass
def visit_exception_hierarchy_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='exception-hierarchy-content'))
self.body.append(self.starttag(node, "div", CLASS="exception-hierarchy-content"))
def depart_exception_hierarchy_node(self, node):
self.body.append('</div>\n')
self.body.append("</div>\n")
class ExceptionHierarchyDirective(Directive):
has_content = True
def run(self):
self.assert_has_content()
node = exception_hierarchy('\n'.join(self.content))
node = exception_hierarchy("\n".join(self.content))
self.state.nested_parse(self.content, self.content_offset, node)
return [node]
def setup(app):
app.add_node(exception_hierarchy, html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node))
app.add_directive('exception_hierarchy', ExceptionHierarchyDirective)
app.add_directive("exception_hierarchy", ExceptionHierarchyDirective)

View File

@@ -5,18 +5,17 @@ from sphinx.util import logging as sphinx_logging
class NitpickFileIgnorer(logging.Filter):
def __init__(self, app: Sphinx) -> None:
self.app = app
super().__init__()
def filter(self, record: sphinx_logging.SphinxLogRecord) -> bool:
if getattr(record, 'type', None) == 'ref':
return record.location.get('refdoc') not in self.app.config.nitpick_ignore_files
if getattr(record, "type", None) == "ref":
return record.location.get("refdoc") not in self.app.config.nitpick_ignore_files
return True
def setup(app: Sphinx):
app.add_config_value('nitpick_ignore_files', [], '')
app.add_config_value("nitpick_ignore_files", [], "")
f = NitpickFileIgnorer(app)
sphinx_logging.getLogger('sphinx.transforms.post_transforms').logger.addFilter(f)
sphinx_logging.getLogger("sphinx.transforms.post_transforms").logger.addFilter(f)

View File

@@ -16,13 +16,7 @@ from sphinx.util.typing import RoleFunction
def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
def role(
typ: str,
rawtext: str,
text: str,
lineno: int,
inliner: Inliner,
options: Dict = {},
content: List[str] = []
typ: str, rawtext: str, text: str, lineno: int, inliner: Inliner, options: Dict = {}, content: List[str] = []
) -> Tuple[List[Node], List[system_message]]:
text = utils.unescape(text)
@@ -32,13 +26,15 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
title = full_url
pnode = nodes.reference(title, title, internal=False, refuri=full_url)
return [pnode], []
return role
def add_link_role(app: Sphinx) -> None:
app.add_role('resource', make_link_role(app.config.resource_links))
app.add_role("resource", make_link_role(app.config.resource_links))
def setup(app: Sphinx) -> Dict[str, Any]:
app.add_config_value('resource_links', {}, 'env')
app.connect('builder-inited', add_link_role)
return {'version': sphinx.__display_version__, 'parallel_read_safe': True}
app.add_config_value("resource_links", {}, "env")
app.connect("builder-inited", add_link_role)
return {"version": sphinx.__display_version__, "parallel_read_safe": True}

View File

@@ -410,3 +410,34 @@ Example: ::
await ctx.send(f'Pushing to {remote} {branch}')
This could then be used as ``?git push origin master``.
How do I make slash commands?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :doc:`/ext/commands/slash-commands`
My slash commands aren't showing up!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _ext_commands_slash_command_troubleshooting:
You need to invite your bot with the ``application.commands`` scope on each guild and
you need the :attr:`Permissions.use_slash_commands` permission in order to see slash commands.
.. image:: /images/discord_oauth2_slash_scope.png
:alt: The scopes checkbox with "bot" and "applications.commands" ticked.
Global slash commands (created by not specifying :attr:`~ext.commands.Bot.slash_command_guilds`) will also take up an
hour to refresh on discord's end, so it is recommended to set :attr:`~ext.commands.Bot.slash_command_guilds` for development.
If none of this works, make sure you are actually running enhanced-discord.py by doing ``print(bot.slash_commands)``
My bot won't start after enabling slash commands!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This means some of your command metadata is invalid for slash commands.
Make sure your command names and option names are lowercase, and they have to match the regex ``^[\w-]{1,32}$``
If you cannot figure out the problem, you should disable slash commands globally (:attr:`~ext.commands.Bot.slash_commands`\=False)
then go through commands, enabling them specifically with :attr:`~.commands.Command.slash_command`\=True until it
errors, then you can debug the problem with that command specifically.

View File

@@ -2,6 +2,7 @@ from discord.ext import tasks
import discord
class MyClient(discord.Client):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -13,8 +14,8 @@ class MyClient(discord.Client):
self.my_background_task.start()
async def on_ready(self):
print(f'Logged in as {self.user} (ID: {self.user.id})')
print('------')
print(f"Logged in as {self.user} (ID: {self.user.id})")
print("------")
@tasks.loop(seconds=60) # task runs every 60 seconds
async def my_background_task(self):
@@ -26,5 +27,6 @@ class MyClient(discord.Client):
async def before_my_task(self):
await self.wait_until_ready() # wait until the bot logs in
client = MyClient(intents=discord.Intents(guilds=True))
client.run('token')
client.run("token")

Some files were not shown because too many files have changed in this diff Show More